Rustでメモ化を行うためのシンプルなライブラリを作った

TL;DR

一行追加するだけで関数をメモ化するマクロを作った。

成果物はこちら https://docs.rs/memoise/

背景

同じ引数に対して同じ値を返す関数(いわゆる参照透明だったり数学的だったりな関数)では、 関数の計算結果を保存しておくことによって計算を高速化したりすることができます。 このようなテクニックを関数のメモ化(memoise, memoize, memoization)などと呼びます。 特に再帰的に定義される関数についてメモ化を行うことによって、 動的計画法の実装をシンプルで直感的なものにできたりします。

しかし、関数のメモ化はやりたいことが自明なのにもかかわらず、 毎回手で書いていると微妙に面倒だったり、うっかりメモ化忘れで計算量が爆発してしまったり、 ちょっと辛いところがありました。

特にRustを使っていると、グローバル変数を雑に使うことを許して貰えないので、 毎回メモ化のためのテーブルを関数の引数として引き回さなければならなかったり、 メモ化テーブルのmutableリファレンスのスコープを短く抑える必要があったりするので、 C++などでやるよりも若干コードにノイズが多くなります。 メモ化のためのコードがコードの半分ぐらいを占めたりして、 関数の見通しが悪くなることもあります。

例えば、n個のものからk個を選ぶ組み合わせの数を求める動的計画法 (別に動的計画法で解く必要がある問題ではないけど、解説のためのシンプルな例題として)を考えてみます。

これを求める関数を comb(n, k) とすると

  • k = 0の時は1通り
  • k > 0 かつ n = 0 の時は0通り
  • それ以外の時は、1つめのものを選ぶ時と選ばないときを考えると、 comb(n-1, k-1) + comb(n-1, k)

と、再帰的な定義が考えられます。

これのnkに対して2次元配列を作って、 値の依存する方向を考えて組織的に表を埋めていくのが動的計画法ですが、 この再帰的定義をそのまま再帰関数として記述しても、計算する関数自体はできます。

fn comb(n: usize, k: usize) -> usize {
    if k == 0 {
        return 1;
    }
    if n == 0 {
        return 0;
    }
    comb(n - 1, k - 1) + comb(n - 1, k)
}

ただ、このままだと同じ引数に対して何回も計算をすることになるので、 引数のサイズに応じて指数的な計算時間が掛かってしまいます。 高速に実行するにはメモ化が必要になってくると言うわけです。

手動メモ化

この関数をメモ化することを考えてみます。 まず、計算結果を保存するテーブルが必要です。 引数nkに対して結果を保存したいので、2次元のVecを使うことにします。 要素の値として、計算済みと未計算を区別しないといけないので、Option型を使います。 なので、テーブルの型全体としてはVec<Vec<Option<size>>> になります。

fn comb(n: usize, k: usize, tbl: &mut Vec<Vec<Option<usize>>>) -> usize

すでにタイピングが死ぬほど面倒です。

Option<usize> ではメモリ上のオーバーヘッドが生じるので、 この関数の場合は結果が必ず正の値になることを利用して、 テーブルを-1などで初期化して、負の値なら未計算扱いにするなどのハックがありますが、 汎用性を考えると難しくなるので、今回はナイーブにOptionを使うことにします。

(少し話は逸れますが、Rustにはこういう用途でオーバーヘッドを回避するためか(?)NonZero型なるものがあったりしますが、 個人的には0を有効な値として使いつつOptionのオーバーヘッドを回避したいので、 NonMaxみたいなのがあると嬉しいなあとか思ってたりしました)

次に、計算済みかチェックするコードを追加します。

fn comb(n: usize, k: usize, tbl: &mut Vec<Vec<Option<usize>>>) -> usize {
    if let Some(ret) = tbl[n][k] {
        return ret;
    }
    ....
}

続いて、再帰の引数でテーブルを引き回す部分追加します。

    ...
    comb(n - 1, k - 1, tbl) + comb(n - 1, k, tbl)
    ...

最後に、計算結果をテーブルに保存する部分を書きます。

    ...
    let ret = comb(n - 1, k - 1, tbl) + comb(n - 1, k, tbl);

    tbl[n][k] = Some(ret);
    ret
}

全体としては次のようになります。

fn comb(n: usize, k: usize, tbl: &mut Vec<Vec<Option<usize>>>) -> usize {
    if let Some(ret) = tbl[n][k] {
        return ret;
    }

    if k == 0 {
        return 1;
    }
    if n == 0 {
        return 0;
    }

    let ret = comb(n - 1, k - 1, tbl) + comb(n - 1, k, tbl);

    tbl[n][k] = Some(ret);
    ret
}

さらに、呼び出し側をテーブルを作って渡してやるように変更します。

comb(n, k, &mut vec![vec![None; k + 1]; n + 1]);

これでようやく完成です。 個々のコードが難しいと言うわけでは決してないのですが、 元のコードと比較して、コードを変更しなければならない箇所がかなり散らばっている上に、 テーブルの参照と更新の部分は、忘れると計算量が爆発するにもかかわらず、 忘れてもコンパイルは通ってしまうのが質が悪いところです。

また、メモ化のキーにしたい引数が増えるに従ってリニアにテーブルのネストが増えるし、 引数の添字も間違えやすくなります(特にRustでは配列作るのと参照するので、見かけの順序が逆になっているので)。

やりたいのは「nとkでメモ化したい」というだけなのに、 こんな面倒をなことをするのはやはりなんだかおかしい気がします。

自動メモ化のアプローチ

というわけで、これらの操作を自動化したいのです。 メモ化を(半)自動的に行うにはいろいろアプローチが考えられて、

などがありますが(名称は勝手に僕が考えたものなので、この辺の分類学は詳しい方がいらっしゃったら教えて下さい)、 今回は一番ナイーブなメタプログラミングでやっていきたいと思います。

(遅延無限木をHaskellネイティブの遅延評価の上で実装して不動点演算子化するアプローチなどは僕の昔の記事 https://tanakh.hatenablog.com/entry/20100411/p1 があります。正直今回のアプローチよりもはるかに凝ったことしてるので話としては面白いと思います。)

さて、メタプログラミングによる関数の自動メモ化ですが、実は先行研究として cached というクレートがあります。こちらの方が機能は豊富で細かいところもカスタマイズできるので、普段使いにはこちらで良いような気もしますが、

  • 使い方がやや複雑
  • なんだか冗長
  • テーブルがHashMapなので遅い

なのがちょっと引っかかったので、そのあたりを解消するべくPoC的なライブラリを作って見ましたという話になります。

作ったもの

そういうわけで、設計思想としては

  • 死ぬほどシンプルに
  • 手でメモ化したときと同じぐらいの速度になるように

という2つを重点に考えました。

シンプルにするために、通常のマクロじゃなくて、attribute macroで実装することにしました。 基本的に関数の頭にattributeをポン置きするだけの使い方です。 性能については、キーをusizeに限定して、あらかじめ取りうる範囲を指定することにしました。

できたものがこちらになります。 https://docs.rs/memoise/

使い方としては、関数の頭に1行#[memoise(keys(...))]というのを追加するだけです。 呼び出し方も通常の関数通り呼び出すだけです。

use memoise::memoise;

#[memoise(keys(n = 100, k = 100))]
fn comb(n: usize, k: usize) -> usize {
    if m == 0 {
        return 1;
    }
    if n == 0 {
        return 0;
    }
    comb(n - 1, k - 1) + comb(n - 1, k)
}

keys の所にメモ化したい引数と取り得る最大の値を書きます。 この場合だと、nkがそれぞれ[0..100]の値をとることができます。

宣言的にメモ化を行うことができるようになったので、 手動での実装に比べると手間も間違える余地も大幅に減っていると思います。

実際にマクロが生成しているコードは手で書いた場合に追加するものに近いものになっています。 相違点として、引数でのテーブルの引き回しを避けるために、グローバルにテーブルを定義します。 スレッドセーフにテーブルへアクセスするために、thread_local!を使用しています。 なので、スレッド間ではテーブルは共有しないので、マルチスレッドでの高速化はできません。

Future work

使い方はこれだけなんですが、現状これではシンプル過ぎて汎用性に欠ける気もするので、

  • テーブル自動リサイズであらかじめサイズ指定しないようにできるはず
  • キーとして変数リテラルだけじゃなくて任意の式を取れるようにしたい
  • ハッシュテーブルやバランスツリーを使った実装もおいおいは追加したい

あたりの拡張を考えています。