スキップしてメイン コンテンツに移動

「フカシギの数え方」の問題を解いてみた

先日、「『フカシギの数え方』 おねえさんといっしょ! みんなで数えてみよう!」という動画を見た。格子状のマスの左上から右下までの経路が何通りあるのかを調べて、格子が多くなればなるほど組み合わせの数が爆発的に増えることを教えてくれる動画だ。これは自己回避歩行(Self-avoiding walk)と呼ばれている問題らしい。

これだけ聞いてもそれほどインパクトはないのだが、動画に出てくるおねえさんの経路を調べあげる執念がもの凄く、ネット上でも結構な話題になっている。執念と言うよりも狂気に近い。しかし、話題になった割には動画内で言及されている高速なアルゴリズムを実装したという話を聞かなかったので、自分で確かめることにした。


動画のおねえさんは深さ優先探索によるプログラムを使っていると思われるが、それだとスパコンを使っても10×10マスの格子を解くのに25万年も掛かってしまう。そこで、高速化のためにゼロサプレス型二分決定グラフ(ZDD; Zero-Suppressed Binary Decision Diagram)と呼ばれるアルゴリズムを利用することにした。このアルゴリズムを開発したのは北大の湊先生で、ZDDによりすべての経路を見つけ出すアルゴリズムとしてクヌース先生のSIMPATHを使った。ZDDについてはクヌース先生も強い関心を持っていて、 The Art of Computer Programming Volume 4, Fascicle 1(TAOCP 4-1)ではBDD/ZDDの詳細な解説を読むことができる。演習問題の解説だけで書籍の半分を使っていることからしても気合の入れようがわかるだろう。

実際に自分のノートPCでZDDアルゴリズムを使ったコードを走らせたら、ほんの10秒程度で10×10マスの問題を解いてしまった。おねえさんがスパコンで25万年かかった問題をノートPCでたった10秒である。約8千億倍の高速化だ。これだけ劇的に変わるとやっぱり楽しい。そして、アルゴリズムの重要性を再認識させられた。

さて、以下におねえさんが利用したであろう自作の深さ優先探索(DFS)プログラムと高速な解法であるZDDアルゴリズムの両方を載せておく。ZDDについては4つのプログラムを1つのPythonスクリプトで統合している。これは、クヌース先生のSIMPATHおよびSIMPATH-REDUCEを利用しており、また、SGB (Stanford GraphBase)ライブラリでグラフを作成し、経路の数え上げにGMPを使ったC++で自作コードを組んだためである。最初はSIMPATHを参考に実装して一つのソースコードに統一しようかと思ったが、解説もちゃんと書かれているクヌース先生のコードをそのまま利用することにした。これらのプログラムの簡単な解説と使い方を載せておく。

まずはDFSによる実装を以下に示す。Node構造体から経路を辿って数え上げるだけの単純なプログラムである。

// Count paths from (0, 0) to (N, N) in the NxN lattice graph using DFS.
#include <iostream>
#include <sstream>
#include <vector>
using namespace std;
static unsigned long long cnt = 0;
struct Node {
int idx;
vector<Node*> neighbors;
bool is_visit;
bool is_goal;
Node(int idx) : idx(idx), is_visit(false), is_goal(false) { }
void Count() {
if (is_goal) { cnt++; return; }
is_visit = true;
for (vector<Node*>::iterator p = neighbors.begin(); p != neighbors.end(); ++p)
if (!(*p)->is_visit) (*p)->Count();
is_visit = false;
}
};
int main(int argc, char* argv[])
{
if (argc < 2) {
cerr << "Usage: " << argv[0] << " N" << endl;
return 1;
}
int N;
vector<Node*> nodes;
stringstream ss;
ss << argv[1]; ss >> N;
for (int i = 0; i < N * N; i++) nodes.push_back(new Node(i));
nodes[N*N-1]->is_goal = true;
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
int x = i * N + j;
if (x - N >= 0) nodes[x]->neighbors.push_back(nodes[x-N]);
if (x + N < N * N) nodes[x]->neighbors.push_back(nodes[x+N]);
if (x % N != 0) nodes[x]->neighbors.push_back(nodes[x-1]);
if (x % N != N - 1) nodes[x]->neighbors.push_back(nodes[x+1]);
}
}
nodes[0]->Count();
cout << "Count all paths from (0, 0) to ("
<< N - 1 << ", " << N - 1 << ") in the "
<< N << "x" << N << " lattice graph:" << endl;
cout << "Count: " << cnt << endl;
return 0;
}

実行方法は以下の通り。引数の7は7×7のノードを表しており、マスで表現すれば6×6マスとなる。そして計算結果として、575,780,564通りの経路が探索されたことを示している。因みに7×7ノード(6×6マス)で約3分ほどの時間が掛かっており、それ以上の大きさだと現実的な時間で解くことが難しくなってくる。

% ./count_lattice_path_dfs 7 Count all paths from (0, 0) to (6, 6) in the 7x7 lattice graph: Count: 575780564

次に、ZDDを利用したプログラムのビルド方法を示す。

ビルドをする前にSGBおよびGMPを準備しておく。また、SIMPATHおよびSIMPATH-REDUCEについてはCWEBを利用しているので、それも入れておく。

最初に、SGBライブラリを用いて格子状のグラフを作成する。

% gcc gen_lattice.c -O3 -lgb -o gen_lattice

ソースコード内でグラフを下記のように作成している。board関数はチェスのような格子状のグラフを作成することができる。他にもグラフを作成する便利な関数が多数あるので、興味のある方はSGBライブラリのgb_basic.wあたりを参照して欲しい。

Graph* g = board(N, M, 0L, 0L, 1L, 0L, 0L);

/* Generate lattice graphs. */
#include "gb_graph.h"
#include "gb_basic.h"
#include "gb_save.h"
int main(int argc, char* argv[])
{
if (argc < 4) {
fprintf(stderr, "Usage %s N M outfile\n", argv[0]);
exit(1);
}
int N = atoi(argv[1]);
int M = atoi(argv[2]);
Graph* g;
g = board(N, M, 0L, 0L, 1L, 0L, 0L);
save_graph(g, argv[3]);
if (g) {
register Vertex* v;
printf("Graph->ID: %s\n", g->id);
printf("N of vertices: %ld, N of arcs: %ld\n\n", g->n, g->m);
for (v = g->vertices; v < g->vertices + g->n; v++) {
printf("%s\n", v->name);
register Arc* a;
for(a = v->arcs; a; a= a->next)
printf(" -> %s, length %ld\n", a->tip->name, a->len);
}
} else printf("Something went wrong (panic code %ld)!\n", panic_code);
return 0;
}
view raw gen_lattice.c hosted with ❤ by GitHub

次に、SIMPATHおよびSIMPATH-REDUCEを以下のようにビルドする。

% wget http://www-cs-faculty.stanford.edu/~uno/programs/simpath.w % ctangle simpath.w % gcc simpath.c -O3 -lgb -o simpath

% wget http://www-cs-faculty.stanford.edu/~uno/programs/simpath-reduce.w % ctangle simpath-reduce.w % gcc simpath-reduce.c -O3 -lgb -o simpath-reduce

因みに、ctangleでC言語への変換、cweaveでTeX形式への変換を行うことができるので、以下のようにTeX形式のドキュメントも一緒に作っておくことをお勧めする。また、これらのアルゴリズムを理解するために前述のTAOCP Vol. 4を読んでおくことが望ましい。

% cweave simpath.w % tex simpath.tex % dvipdf simpath.dvi # PDFファイルが必要なら. % cweave simpath-reduce.w % tex simpath-reduce.tex % dvipdf simpath-reduce.dvi # PDFファイルが必要なら.

最後に、ZDDから経路を数え上げるプログラムをコンパイルする。

% g++ count_path_mp.cpp -O3 -lgmp -o count_path_mp

経路を一つ一つ数えていては高速に数え上げるという点で意味がなくなってしまうので、ちゃんとメモ化再帰で数える。ZDDは高度に圧縮されているのでメモ化再帰でほとんど瞬時に数え上げることができる。

// Count paths on a lattice graph.
#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <algorithm>
#include <gmp.h>
using namespace std;
struct Node {
int idx;
Node* child[3];
mpz_t cnt;
Node(int idx) : idx(idx) {
mpz_init_set_ui(cnt, 0);
fill(child, child + 3, static_cast<Node*>(NULL));
}
void Count(mpz_t& n) {
if (mpz_cmp_ui(cnt, 0) > 0) {
mpz_set(n, cnt);
return;
}
mpz_t w;
mpz_init(w);
for (int i = 0; child[i]; i++) {
child[i]->Count(w);
mpz_add(cnt, cnt, w);
}
mpz_clear(w);
mpz_set(n, cnt);
}
};
void read_data(const char* filename, vector<Node*>& nodes)
{
fstream fs(filename, ios_base::in);
string line;
stringstream ss;
int x[4];
int n;
getline(fs, line);
ss.str(line);
ss >> n;
for (int i = 0; i < n; i++) nodes.push_back(new Node(i));
mpz_set_ui(nodes[1]->cnt, 1);
while (getline(fs, line)) {
ss.clear(); ss.str(line);
for (int i = 0; ss >> x[i]; i++);
n = 0;
for (int i = 1; i < 3; i++) {
if (x[i] == 0) continue;
nodes[x[0]]->child[n++] = nodes[x[i]];
}
}
}
int main(int argc, char* argv[])
{
if (argc < 2) {
cerr << "Usage: " << argv[0] << " datafile" <<endl;
exit(1);
}
vector<Node*> nodes;
read_data(argv[1], nodes);
mpz_t cnt;
mpz_init(cnt);
nodes[2]->Count(cnt);
cout << "Count: " << mpz_get_str(NULL, 10, cnt) << endl;
mpz_clear(cnt);
return 0;
}

そして、上記のプログラムを統括するPythonスクリプトが以下となる。グルー言語としてもPythonは優秀だ。

#!/usr/bin/env python
"""
Count all paths from upper left to lower right on NxN lattice graph.
Build (required SGB and GMP libraries):
% gcc gen_lattice.c -O3 -lgb -o gen_lattice
% wget http://www-cs-faculty.stanford.edu/~uno/programs/simpath.w
% ctangle simpath.w
% gcc simpath.c -O3 -lgb -o simpath
% wget http://www-cs-faculty.stanford.edu/~uno/programs/simpath-reduce.w
% ctangle simpath-reduce.w
% gcc simpath-reduce.c -O3 -lgb -o simpath-reduce
% g++ count_path_mp.cpp -O3 -lgmp -o count_path_mp
"""
import sys, os, subprocess
COMMAND_GEN_LATTICE_GB = "./gen_lattice %d %d %s"
COMMAND_SIMPATH = "./simpath %s %s %s > %s"
COMMAND_SIMPATH_REDUCE = "./simpath-reduce < %s > %s"
COMMAND_COUNT_PATH = "./count_path_mp %s"
def renumber_zdd(infile, outfile):
data = []
for l in file(infile):
l = l.replace(":", " ").replace("(~", "").replace("?", " ").replace(")", "").strip()
x = map(lambda x: int(x, 16), l.split())
data.append([x[0], x[2], x[3]])
data.sort()
m = {0:0, 1:1}
cnt = 2
for d in data:
m[d[0]] = cnt
cnt += 1
for i in range(len(data)):
for j in range(3):
data[i][j] = m[data[i][j]]
outf = file(outfile, "w")
print >>outf, len(data) + 2
for d in data:
print >>outf, "%d %d %d" % tuple(d)
def main(args):
if len(args) < 2:
print >>sys.stderr, "Usage: %s N" % os.path.basename(args[0])
sys.exit()
N = int(args[1])
fname = "lattice_%02d" % N
log = subprocess.Popen(COMMAND_GEN_LATTICE_GB % (N, N, fname + ".gb"), shell=True, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.read()
log += subprocess.Popen(COMMAND_SIMPATH % (fname + ".gb", "0.0", "%d.%d" % (N - 1, N - 1), fname + ".out"), shell=True, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.read()
log += subprocess.Popen(COMMAND_SIMPATH_REDUCE % (fname + ".out", fname + ".out.r"), shell=True, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.read()
renumber_zdd(fname + ".out.r", fname + ".out.rc")
file(fname + ".log", "w").write(log)
print "Count all paths from (0, 0) to (%d, %d) in the %dx%d lattice graph:" % (N - 1, N - 1, N, N)
sys.stdout.flush()
os.system(COMMAND_COUNT_PATH % fname + ".out.rc")
if __name__ == "__main__": main(sys.argv)

実行方法は以下の通り。

% ./count_lattice_path.py 8 Count all paths from (0, 0) to (7, 7) in the 8x8 lattice graph: Count: 789360053252

DFSのプログラム同様、引数の8は8×8ノード(7×7マス)を表している。そして計算結果として、789,360,053,252通りの経路が探索されたことを示している。また、中間ファイルとして以下が出力される。

lattice_08.gb # SGBによるグラフデータ. lattice_08.out # SIMPATHの出力: not-necessarily-reduced BDD lattice_08.out.r # SIMPATH-REDUCEの出力: ZDD lattice_08.out.rc # SIMPATH-REDUCEの出力を再番付. lattice_08.log # 一連の実行ファイルのログ.

因みに、DFSで7×7ノード(6×6マス)を解くと3分ほどかかったが、ZDDでは0.1秒未満で完了した。また、ZDDで14×14ノード(13×13マス)は数分で終わることを確認したが、それ以上は途中で搭載メモリ以上に必要なメモリが大きくなるので実行していない。

コメント