Haskellでメモ化を行うもう一つの方法

はじめに

Haskell で動的計画法を書くための3つの方針 - tosの日記

これを読んで、私もちょっと前にHaskellでメモ化をやる方法を考えていたことを思い出したので、書いてみることにします。

Haskellでのメモ化は、私のかなり昔の駄文(リアルにびっくりするほど駄文なのでご注意。メモ化の綴りも間違ってます)や、このあたりに日本語の文章があります。

これらのページでのメモ化実現方針は、1. 計算済み値を保持するテーブルをMapなどを用いて用意する 2. そのテーブルを副作用を用いて更新する、というものになっています。なるほど、これは手続き型言語との対比でとても直接的な実装です。しかし、テーブルを更新するために関数全体がモナドになってしまったりして、あまり使い勝手が良くなさそうです。モナドであることを悟らせないために、演算子モナド化したり、あるいはモナドじゃなくするためにunsafePerformIOを使うなどのトリックも考えられますが、どちらもあまり綺麗には見えません。

メモ化においてやりたい事は、テーブルの作成でもテーブルの更新でも無くて、一回計算した値がもう一度計算されないようになっているということです。一度計算したものがもう一度計算されない―――おあつらえ向きの物がHaskellにはあります。そう、遅延評価です。

基本的な考え方

簡単のために、メモ化関数の引数は1つの自然数とします(この制約は後である程度ゆるくします)。無限のサイズのテーブルがあると考えます。そこに、その関数(fとしましょう)のすべての引数に対する値が格納されています。0番目の所にはf 0、1番目のところにはf 1、…というふうにすべて入っています。遅延評価を用いて、各々の値は参照されるまで計算されないようにします。f n を計算したければ、そのテーブルをルックアップするだけです。そのノードが参照されるのが初めてならその時点でf nが計算され、二回目以降ならすでに計算された値が返されます。問題は、そのテーブルをどうやって表現するかです。効率的にルックアップ出来なければいけませんし、計算されていないノードは存在しないようにしなりません。

無限リストからの類推で、無限ツリーというものを考えます。無限ツリーはノードが無限にあるツリーです。ここではリーフはないとします。つまり、ノードはすべて内部ノードで、子を持っています。その特殊な場合として、無限完全二分木を考えます。すべてのノードが2つの子をもち、無限に広がっているような木です。これにヒープよろしく、ルートには0を、その子には1と2、さらにその子には3と4と5と6…というように番号をふります。そして、その対応する番号を引数とする関数の値をノードに持たせます。これは完全にバランスした二分木なので、O(log n)でルックアップできます。

         0
    1         2
  3   4     5     6
 7 8 9 10 11 12 13 14
...

実装

メモ化のインターフェースは、fixと同じものを採用します(fixというのは、ここではControl.Monad.fixを指しています。これは一般には"不動点演算子"と呼ばれるものです。不動点演算子に関しては、私のブログの駄文や、Googleに聞いてみてください)。fixをこれから実装するmemofixに挿げ替えるだけで関数がメモ化されるという目論見です。利用例は次の様になります。

-- normal version
fib = fix $ \f n -> if n<2 then n else f (n-1) + f (n-2)
-- memoised version
memofib = memofix $ \f n -> if n<2 then n else f (n-1) + f (n-2)

まずは無限完全二分木を定義します。これはリーフを考えなくていいので、とても簡単です。

data Tree a = Tree a (Tree a) (Tree a)

次に、このツリーのルックアップを定義します。添字の上位ビットからみて、ツリーを辿っていくだけです。

findTree :: Tree b -> Integer -> b
findTree tree ix = f (bits $ ix + 1) tree
  where
    f []     (Tree v _ _) = v
    f (0:bs) (Tree _ l _) = f bs l
    f (_:bs) (Tree _ _ r) = f bs r

    bits = tail . reverse . map (`mod`2). takeWhile (>0) . iterate (`div`2)

次に、関数の値がすべて詰まった無限ツリーの生成です。ノードの値を計算するための関数を渡してもらって、適当に全部生成するだけです。遅延評価がすべて何とかしてくれます。

genTree :: (Integer -> b) -> Tree b
genTree f = gen 0 where
  gen ix = Tree (f ix) (gen $ ix*2+1) (gen $ ix*2+2)

さて、あとはメモ化コードを書くだけです。これもとてもシンプルです。

memofix :: ((Integer -> b) -> (Integer -> b)) -> (Integer -> b)
memofix f = memof where
  memof = f $ findTree tbl
  tbl = genTree memof

memofがメモ化versionのfで、tblがすべての答えの詰まった無限ツリーです。tblは先程のgenTreeにmemofを渡してやれば出来上がりで、memofはfにtblをルックアップする関数を渡してやればいいのです。簡単でしょ?

memofib = memofix $ \f n -> if n<2 then n else f (n-1) + f (n-2)
main = print $ map memofib [(0 :: Integer)..100]

実行してみます。

$ ghc --make Main.hs && time ./Main 
[1 of 1] Compiling Main             ( Main.hs, Main.o )
Linking Main ...
[0,1,1,2,3,5,8,13,21,34,55,89,144,233,377,610,987,1597,2584,4181,6765,10946,17711,28657,46368,75025,121393,196418,317811,514229,832040,1346269,2178309,3524578,5702887,9227465,14930352,24157817,39088169,63245986,102334155,165580141,267914296,433494437,701408733,1134903170,1836311903,2971215073,4807526976,7778742049,12586269025,20365011074,32951280099,53316291173,86267571272,139583862445,225851433717,365435296162,591286729879,956722026041,1548008755920,2504730781961,4052739537881,6557470319842,10610209857723,17167680177565,27777890035288,44945570212853,72723460248141,117669030460994,190392490709135,308061521170129,498454011879264,806515533049393,1304969544928657,2111485077978050,3416454622906707,5527939700884757,8944394323791464,14472334024676221,23416728348467685,37889062373143906,61305790721611591,99194853094755497,160500643816367088,259695496911122585,420196140727489673,679891637638612258,1100087778366101931,1779979416004714189,2880067194370816120,4660046610375530309,7540113804746346429,12200160415121876738,19740274219868223167,31940434634990099905,51680708854858323072,83621143489848422977,135301852344706746049,218922995834555169026,354224848179261915075]

real	0m0.012s
user	0m0.010s
sys	0m0.000s

動きました。メモ化しないと、

fib = fix $ \f n -> if n<2 then n else f (n-1) + f (n-2)
main = print $ map fib [(0 :: Integer)..100]
$ ghc --make Main.hs && time ./Main 
[1 of 1] Compiling Main             ( Main.hs, Main.o )
Linking Main ...
^C

帰って来ず。

添字の一般化

添字が自然数1つというのは嬉しくありません。自然数以外や二つ以上の引数を扱うことを考えます。無限完全二分木以外のデータ構造を考えるのが面倒なので、任意の引数を自然数に変換することを考えます。つまり、当該データ型のすべてのデータに番号を振ってやって、その番号によってツリーをルックアップするのです。もちろんこれはその型が可算無限集合でなければいけないのですが、可算無限でない引数をメモ化の引数にするのはここでは考えないことにします。複数の引数はそのまま扱うのはとても難しいので、タプルにするということで妥協します。

自然数に変換できるというクラスMemoIxを定義します。逆変換も後で必要になるので用意しておきます。

class MemoIx a where
  index   :: a -> Integer
  unindex :: Integer -> a

ところで、Haskellの標準ライブラリにIxというクラスがあります。MemoIxと同様に値に対して整数を振るためのクラスですが、これを利用するには値の上限・下限が必要になります。遅延評価でせっかく無限のものを扱えるようになったのに、値の上限・下限を要求されるのは大変イケてないので、これをそのまま使うことは出来ませんでした。

具体的な型に対するインスタンスを書きます。まずはIntegerです。Integerは整数であって自然数ではないので、自然数に変換してやる必要があります。正の数を偶数に、負の数を奇数にエンコードすることにします。

instance MemoIx Integer where
  index n | n>=0 = n*2
          | otherwise = -n*2-1

  unindex n | n`mod`2==0 = n`div`2
            | otherwise = -((n+1)`div`2)

次にタプルを考えます。タプルはペアで表現できるので、ペアのみを考えます。ペアにうまく番号を振るには、次のように斜めに振ってやれば良いです(方法は他にも色々考えられます)。

 | 0 1 2 3 ...
-+-----------
0| 0 2 5 9 ...
1| 1 4 8 ...
2| 3 7 ...
3| 6 ...
.| ...

この番号を頑張って計算します。unindexで二次方程式浮動小数を用いて解いているのですが、これのせいでIntegerのメリットが失われているのがあまり良くないです(何とかならないでしょうか)。

instance (MemoIx a, MemoIx b) => MemoIx (a, b) where
  index (a, b) = l*(l+1)`div`2 + ib
    where
      ia = index a 
      ib = index b
      l  = ia+ib

  unindex ix = (unindex ia, unindex ib)
    where
      l  = floor ((-1 + sqrt (1 + 8 * fromIntegral ix))/2)
      ib = ix - l*(l+1)`div`2
      ia = l-ib

ところでこのコードは自然数への変換の際にかなり数が大きくなります。これは大丈夫なのでしょうか?入力a, bに対して、出力の大きさはO(ab)程度です。ということは、これをルックアップするのはO(log(max(a, b)))程度ということになります。なので、大丈夫そうです。

文字列型など他の型のMemoIxのインスタンスも考えられますが、それぞれ皆様各自お考え下さいませ。

さて、後はツリーとメモ化コードをMemoIxに対応させれば完成です。

findTree :: MemoIx a => Tree b -> a -> b
findTree tree ix = f (bits $ index ix + 1) tree
  where
    f []     (Tree v _ _) = v
    f (0:bs) (Tree _ l _) = f bs l
    f (_:bs) (Tree _ _ r) = f bs r

    bits = tail . reverse . map (`mod`2). takeWhile (>0) . iterate (`div`2)

genTree :: MemoIx a => (a -> b) -> Tree b
genTree f = gen 0 where
  gen ix = Tree (f $ unindex ix) (gen $ ix*2+1) (gen $ ix*2+2)

memofix :: MemoIx a => ((a -> b) -> (a -> b)) -> (a -> b)
memofix f = memof where
  memof = f $ findTree tbl
  tbl = genTree memof

適当な関数を書いてみます。

comb :: (Integer, Integer) -> Integer
comb = memofix $ \f (i, j) -> if i==1||j==1 then 1 else f (i-1, j) + f (i, j-1)

main = print $ comb (100,50)

実行してみます。

$ ghc --make Main.hs -O2 && time ./Main 
[1 of 1] Compiling Main             ( Main.hs, Main.o )
Linking Main ...
4503056131931081050165600532646379362000

real	0m0.082s
user	0m0.070s
sys	0m0.010s

正しく動作しました。

インターフェースのバリエーション

fixのインターフェースではなくて、次のようなものも考えられます。テーブルを作ってルックアップする関数を返すだけのシンプルなものです。

memo :: MemoIx a => (a -> b) -> (a -> b)
memo f = findTree (genTree f)

これを用いてfibを定義するとこうなります。

fib :: Integer -> Integer
fib = memo $ \n -> if n<2 then n else fib (n-1) + fib (n-2)

fixのインターフェースよりもこちらの方が書きやすいかもしれません。

ソースコード

全体のソースコードを再掲しておきます。

class MemoIx a where
  index :: a -> Integer
  unindex :: Integer -> a

instance MemoIx Integer where
  index n | n>=0 = n*2
          | otherwise = -n*2-1

  unindex n | n`mod`2==0 = n`div`2
            | otherwise = -((n+1)`div`2)

instance (MemoIx a, MemoIx b) => MemoIx (a, b) where
  index (a, b) = l*(l+1)`div`2 + ib
    where
      ia = index a 
      ib = index b
      l  = ia+ib

  unindex ix = (unindex ia, unindex ib)
    where
      l  = floor ((-1 + sqrt (1 + 8 * fromIntegral ix))/2)
      ib = ix - l*(l+1)`div`2
      ia = l-ib

data Tree a = Tree a (Tree a) (Tree a)

findTree :: MemoIx a => Tree b -> a -> b
findTree tree ix = f (bits $ index ix + 1) tree
  where
    f []     (Tree v _ _) = v
    f (0:bs) (Tree _ l _) = f bs l
    f (_:bs) (Tree _ _ r) = f bs r

    bits = tail . reverse . map (`mod`2). takeWhile (>0) . iterate (`div`2)

genTree :: MemoIx a => (a -> b) -> Tree b
genTree f = gen 0 where
  gen ix = Tree (f $ unindex ix) (gen $ ix*2+1) (gen $ ix*2+2)

memofix :: MemoIx a => ((a -> b) -> (a -> b)) -> (a -> b)
memofix f = memof where
  memof = f $ findTree tbl
  tbl = genTree memof

memo :: MemoIx a => (a -> b) -> (a -> b)
memo f = findTree (genTree f)

とかいうことを

嬉々として書いていたら、
http://www.haskell.org/haskellwiki/Memoization
こんなのがあったのですよ…。ぶわっ。
こっちのほうがツリーのエンコード方法が上手いので、後でコード書きなおします。