間違った分割統治法

多値の結合 - 野良C++erの雑記帳 の実装コードが非効率的だったので、訂正します。

-- 補助関数

-- 引数をそのまま帰す
function id(...)
  return ...
end

-- 引数の数を数える
do
  local select = select
  function count(...)
    return select( "#", ... )
  end
end


-- 実装本体

-- まず引数一定の場合から

-- 0 の場合も一応定義。引数をそのまま返せばいい
function combine0()
  return id
end

-- 1 以降は手で書く。
function combine1( x )
  return function(...)
    return x, ...
  end
end
function combine2( x, y )
  return function(...)
    return x, y, ...
  end
end
function combine3( x, y, z )
  return function(...)
    return x, y, z, ...
  end
end
function combine4( x, y, z, w )
  return function(...)
    return x, y, z, w, ...
  end
end

-- 一般引数の場合
do
  local combine0, combine1, combine2, combine3, combine4 =
        combine0, combine1, combine2, combine3, combine4
  
  local select = select
  
  -- ... の先頭 n 個の引数に対し(仮に a1, ... an と置く)、
  -- f( args ) が a1, ... an, args を返すような関数 f を作って返す。
  local function combine_( n, ... )
    -- 引数の数が少ない場合は固定定数版を返す。
    if n == 0 then return combine0(...) end
    if n == 1 then return combine1(...) end
    if n == 2 then return combine2(...) end
    if n == 3 then return combine3(...) end
    if n == 4 then return combine4(...) end
    
    -- 引数の数が多い場合は、後ろ4つを分割
    local comb1, comb2 = combine_( n-4, ... ), combine4( select( n-3, ... ) )
    
    -- combine( a, b, c, d, e, f )(...) == combine( a, b )( combine( c, d, e, f )( ... ) )
    return function (...)
      return comb1( comb2(...) )
    end
  end
  
  -- 本体は n を調べて combine_ を呼ぶだけ
  local count = count
  function combine( ... )
    return combine_( count(...), ... )
  end
end

以上。


前回の記事のコードは、見た感じ分割統治法で効率的に見えますが、
実は素直に定数で分割していったほうが効率的だったというオチ。
実際にクロージャが作られる回数を計算してみると分かると思います。


んで、それだけじゃ寂しいので、定数で処理できる combineN の数を増やし、
分岐のコストを抑えるためテーブルに入れることで、更に効率化してみます:

-- 補助関数

-- 引数をそのまま帰す
function id(...)
  return ...
end

-- 引数の数を数える
do
  local select = select
  function count(...)
    return select( "#", ... )
  end
end


-- 実装本体

-- まず引数一定の場合から

-- 0 の場合も一応定義。引数をそのまま返せばいい
function combine0()
  return id
end

-- 1 以降は手で書く。
function combine1( a1 )
  return function(...)
    return a1, ...
  end
end
function combine2( a1, a2 )
  return function(...)
    return a1, a2, ...
  end
end
function combine3( a1, a2, a3 )
  return function(...)
    return a1, a2, a3, ...
  end
end
function combine4( a1, a2, a3, a4 )
  return function(...)
    return a1, a2, a3, a4, ...
  end
end
function combine5( a1, a2, a3, a4, a5 )
  return function(...)
    return a1, a2, a3, a4, a5, ...
  end
end
function combine6( a1, a2, a3, a4, a5, a6 )
  return function(...)
    return a1, a2, a3, a4, a5, a6, ...
  end
end
function combine7( a1, a2, a3, a4, a5, a6, a7 )
  return function(...)
    return a1, a2, a3, a4, a5, a6, a7, ...
  end
end
function combine8( a1, a2, a3, a4, a5, a6, a7, a8 )
  return function(...)
    return a1, a2, a3, a4, a5, a6, a7, a8, ...
  end
end

-- 一般引数の場合
do
  local combine_fs = {
    combine1, combine2, combine3, combine4,
    combine5, combine6, combine7, combine8,
    [0] = combine0  -- Lua のテーブルは 1-base だけど 0 にも代入して構わない
  }
  local m = #combine_fs  -- 定数で処理できる最大の n (=8)
  local combineM = combine_fs[m]
  local select = select
  
  -- ... の先頭 n 個の引数に対し(仮に a1, ... an と置く)、
  -- f( args ) が a1, ... an, args を返すような関数 f を作って返す。
  local function combine_( n, ... )
    -- 引数の数が少ない場合は固定定数版を返す。
    if n <= m then
      local combineN = combine_fs[n]
      return combineN(...)
    end
    
    -- 引数の数が多い場合は、後ろ m 個で分割
    local comb1 = combine_( n-m, ... )
    local comb2 = combineM( select( n-m+1, ... ) )
    
    -- combine( a, b, c, d )(...) == combine( a, b )( combine( c, d )( ... ) )
    return function (...)
      return comb1( comb2(...) )
    end
  end
  
  -- 本体は n を調べて combine_ を呼ぶだけ
  local count = count
  function combine( ... )
    return combine_( count(...), ... )
  end
end

この例では 8 引数で打ち止めにしましたが、
力技で更に並べる(あるいは、もっと賢く、スクリプトに自動生成させる)ことで、
殆どのケースにおいて、クロージャの製作個数を劇的に減らせます。
反面、ファイルの読み込みには時間がかかるようになりますが、
これは予め luac でコンパイルしておく等の工夫によって解決できます。