素数表を作ってみる

Google Code Jamのこの問題Haskellで解こうと思い、アルゴリズムが思いついたのでコードを書こうとして気付きました。Haskell素数表ってどうやって作るんだ?

ちなみにJavaやCではbooleanの配列を用意し、2から初めて倍数をfalseにしていき、次に3を見て倍数をfalseにしていき...っていう方法をよく使います。有名なアリストテレスの篩です。
ただ、Haskellでは変数の値が基本的に変えられないので、この方法は駄目そうです。

調べてみたところ、Haskellはリストの構造を上手く活かしたコードでアリストテレスの篩を非常に短く書くことが出来ます。以下はTech Tipsに載っていた有名らしいコード。

eratos []     = [] 
eratos (x:xs) = x : eratos [y | y <- xs, y `mod` x /= 0]  

main = do
	print "start."
	print $ length $ eratos [3,5..999999]
	print "end."

eratos (x : xs)の(x : xs)はパターンマッチングと呼ばれる表現法で、eratos [1,2,3]とすると、x = 1, xs = [2,3]となります。xがSchemeでいうcarで、xsがcdrって感じなのかな。
/=はJavaやCの!=と同じでNot Equalを表します。

このコードは簡潔でいいのですが、いかんせん遅い。1000000までの素数を求めるのに、15分やっても終わりませんでした。100000まででも5秒くらいかかります。

上記のTech Tipsで言及されていますが、リストの基本操作それぞれにO(n)の時間がかかってしまうために遅くなっているようです。「100万くらいなら操作がO(logn)で行えるSetで十分」ということだったので、Setを使ってみることに。

import Data.Setと書いてコンパイルするもなぜかエラー。こんなエラーです。

Undefined symbols:
"___stginit_containerszm0zi3zi0zi0_DataziSet_", referenced from:
___stginit_Main_ in test.o
ld: symbol(s) not found
collect2: ld returned 1 exit status

しかも対話環境のghciでは上手く動いています。
よく分からなかったので、検索してみると同じような症状に悩んでいる人を発見。
解答を見てみると、今までは「ghc test.hs」でコンパイルしていましたが、「ghc --make test.hs -o a.out」でコンパイルすればいいそうです。--makeをつけるとGHCでもGHCiと同じように必要なモジュールを探してくれるそうです。

確かにコンパイルが通るようになったのでソースの実装開始。
今回はIntの範囲なので、Data.IntSetを使いました。

とりあえず2から順に倍数をIntSetに入れていき、素数でない数字の集合を作ろうとやってみたらスタックオーバーフロー。

Stack space overflow: current size 8388608 bytes.
Use `+RTS -Ksize -RTS' to increase it.

逆に素数の方をIntSetに入れていったらどうなのかなとやってみてもスタックオーバーフロー。

関数の呼び出し過ぎが原因のようなので、偶数を始めから省いてみたところ、成功。やっぱ50万回ってデカイですよね。

import Data.IntSet
import qualified Data.IntSet as IntSet

prime = insert 2 $ check (fromList [3,5..999999]) 3 1000000


check s n lim
	| n > lim = s
	| notMember n s = check s (n+2) lim
	| otherwise = check (del (n + 2*n) n lim s) (n+2) lim

del n m lim s
	| n <= lim = del (n + 2*m) m lim $ delete n s
	| otherwise = s


main = do
	print "start."
	print $ size prime
	print "end."

始めに奇数をIntSetに入れておき、IntSetの下の数字から見て倍数を消していっています。以下がtimeコマンドで測った実行結果です。

"start."
78498
"end."

real 0m2.547s
user 0m2.433s
sys 0m0.092s

こういう問題を解くにはまだ全然速度が遅いけど、とりあえず今やろうと思っている問題にはこれで十分なので満足。10万くらいなら0.1秒で出してくれます。




追記 (2012/09/20)

よく考えたらアリストテレスの篩ってnじゃなくてsqrt(n)までやればいいんですよね。
ちょっと修正してみました。

import Data.IntSet
import qualified Data.IntSet as IntSet

prime n  = insert 2 $ check (fromList [3,5..n]) 3 n $ (ceiling . sqrt . fromIntegral) n

check s n lim sqlim
	| n > sqlim = s
	| notMember n s = check s (n+2) lim sqlim
	| otherwise = check (del (n + 2*n) n lim s) (n+2) lim sqlim

del n m lim s
	| n <= lim = del (n + 2*m) m lim $ delete n s
	| otherwise = s

main = do
	print "start."
	print $ size $ prime 1000000
	print "end."

checkは1000000までじゃなくて1000まで見れば十分でした。
0.5秒くらい速くなりました。