多値の結合

Lua は、関数から複数の値(多値)を返すことの出来る言語です。

function f()
  return 1, 2, 3
end
print( f() )  -- 1 2 3

そんな Lua を使っていると、割としばしば、「多値を結合したい」と思う場面に出くわします。
例えば…例えば、えっと、とっさに例が思いつきませんでしたが、とにかく出くわすのです。


しかしながら、少し Lua を触ったことのある人なら分かるように、これは容易ではありません。
単純に、こんな感じ(↓)に書けばいいと思うかもしれませんが、

function f()
  return 1, 2, 3
end
print( f(), 4, 5 )

これだと、「カンマの前の値は強制的に一つにする」という仕様があるため、結果は

1	4	5

このようになってしまいます。


もちろん、これは非常に妥当な仕様であり、例えば

local x, y, z = f(), g(), h()

という文があった場合にどうなるか、というのを考えれば分かると思いますが、
しかし、一方で、関数の戻り値を捨てずに結合したいと思う時もあります。
具体例は相変わらず思いつかないですが、とにかくあるのです。


で、この辺りに関しては、 Lua の開発側も気にしているようで、
メーリスのログをざっと眺めた感じでは、割と積極的に議論されているようです。
なので、構文さえ安定すれば、そのうち言語仕様に取り入れられるはずです。
が、そのうち、というのは嫌だなぁとも思うのです。とはいえ、私家パッチを作るのも面倒。


というわけで、なんとか既存の枠組みで 多値を結合できないか考えてみたところ、
こんな感じの関数 combine が出来上がりました:

function combine(...)
  return function(...) 〜 end
end

print( combine( f() )( 4, 5 ) )  -- 1 2 3 4 5

見たとおり、 combine( explist1 )( explist2 ) と書けば、
explist1 と explist2 を結合して、返してくれます。
また、即結合するのではなく、一段目の関数呼び出しを何処かに格納することで、

local f = combine( 1, 2 )
print( f( 3, 4, 5 ) )  -- 1 2 3 4 5
print( f( 4, 8 ) )  -- 1 2 4 8

このように使うことも出来ます。


実装は、こんな感じです:
http://ideone.com/2kPFk

-- 補助関数
 
-- 引数をそのまま帰す
function id(...)
  return ...
end
 
-- 引数の数を数える
do
  local select = select
  function count(...)
    return select( "#", ... )
  end
end
 
-- 整数 n に対し、 n/2 と n - n/2 のペアを返す
do
  local floor = math.floor
  function halfDiv(n)
    local nh = floor( n / 2 )
    return nh, n - nh
  end
end
 
 
-- 実装本体
 
-- まず引数一定の場合から
 
-- 0 の場合も一応定義。引数をそのまま返せばいい
function combine0()
  return id
end
 
-- 1 以降は手で書く。
function combine1( x )
  return function(...)
    return x, ...
  end
end
function combine2( x, y )
  return function(...)
    return x, y, ...
  end
end
function combine3( x, y, z )
  return function(...)
    return x, y, z, ...
  end
end
function combine4( x, y, z, w )
  return function(...)
    return x, y, z, w, ...
  end
end
 
-- 一般引数の場合
do
  local combine0, combine1, combine2, combine3, combine4 =
        combine0, combine1, combine2, combine3, combine4
  
  local halfDiv = halfDiv
  local select = select
  
  -- ... の先頭 n 個の引数に対し(仮に a1, ... an と置く)、
  -- f( args ) が a1, ... an, args を返すような関数 f を作って返す。
  local function combine_( n, ... )
    -- 引数の数が少ない場合は固定定数版を返す。
    if n == 0 then return combine0(...) end
    if n == 1 then return combine1(...) end
    if n == 2 then return combine2(...) end
    if n == 3 then return combine3(...) end
    if n == 4 then return combine4(...) end
    
    -- 引数の数が多い場合は半分に分割する
    local comb1, comb2
    do
      -- n の半分の値を得る
      local nh, nh_ = halfDiv(n)
      -- 半分に分割し、再帰的に適用
      comb1, comb2 = combine_( nh, ... ), combine_( nh_, select( nh+1, ... ) )
    end
    
    -- combine( a, b, c, d )(...) == combine( a, b )( combine( c, d )( ... ) )
    return function (...)
      return comb1( comb2(...) )
    end
  end
  
  -- 本体は n を調べて combine_ を呼ぶだけ
  local count = count
  function combine( ... )
    return combine_( count(...), ... )
  end
end


ポイントとしては、テーブルを一切使わず、クロージャのみで完結させている点です。
もちろんテーブルを使って実装してもよく、その場合には、おおよそ

function combine(...)
  local args = {...}
  return function(...)
    local results = { unpack(args) }
    local n = #results
    for i = 1, select( '#', ... ) do
      n = n + 1
      results[n] = select( i, ... )
    end
    return unpack( results )
  end
end

このようなコードになるはずです。
実際にはもっと効率をよく出来る筈なので、気が向いたらテーブル版も作ってみることにします。