pure Knuth/Fisher-Yates shuffle in haskell - algorithm

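In Go I could write a function like this:
func pureFisherYates(s []int, swaps []int) []int {
    // copy(dst, src) returns an element count, so allocate the destination first
    newS := make([]int, len(s))
    copy(newS, s)
    for i := range newS {
        for _, j := range swaps {
            newS[i], newS[j] = newS[j], newS[i]
        }
    }
    return newS
}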
To me this seems like a pure function. It always returns the same output given the same input, and it doesn't mutate the state of the world (except in the strict sense that any function does: using CPU resources, generating heat, etc.). Yet whenever I look for how to do pure shuffling I find stuff like this, and whenever I look specifically for a Haskell implementation of Fisher-Yates I get either an O(n^2) Fisher-Yates implemented with a list or an [a] -> IO [a] implementation. Does an O(n) shuffle of type [a] -> [a] exist, and if not, why is my Go implementation above impure?

The ST monad allows exactly such encapsulated mutability, and Data.Array.ST contains arrays which can be mutated in ST and then an immutable version returned outside.
https://wiki.haskell.org/Random_shuffle gives two implementations of Fisher-Yates shuffle using ST. They aren't literally [a] -> [a], but that's because random number generation needs to be handled as well:
import System.Random
import Data.Array.ST
import Control.Monad
import Control.Monad.ST
import Data.STRef
-- | Randomly shuffle a list without the IO Monad
--   /O(N)/
shuffle' :: [a] -> StdGen -> ([a],StdGen)
shuffle' xs gen = runST (do
        g <- newSTRef gen
        let randomRST lohi = do
              (a,s') <- liftM (randomR lohi) (readSTRef g)
              writeSTRef g s'
              return a
        ar <- newArray n xs
        xs' <- forM [1..n] $ \i -> do
                j <- randomRST (i,n)
                vi <- readArray ar i
                vj <- readArray ar j
                writeArray ar j vi
                return vj
        gen' <- readSTRef g
        return (xs',gen'))
  where
    n = length xs
    newArray :: Int -> [a] -> ST s (STArray s Int a)
    newArray n xs = newListArray (1,n) xs
and
import Control.Monad
import Control.Monad.ST
import Control.Monad.Random
import System.Random
import Data.Array.ST
import GHC.Arr
shuffle :: RandomGen g => [a] -> Rand g [a]
shuffle xs = do
    let l = length xs
    rands <- forM [0..(l-2)] $ \i -> getRandomR (i, l-1)
    let ar = runSTArray $ do
            ar <- thawSTArray $ listArray (0, l-1) xs
            forM_ (zip [0..] rands) $ \(i, j) -> do
                vi <- readSTArray ar i
                vj <- readSTArray ar j
                writeSTArray ar j vi
                writeSTArray ar i vj
            return ar
    return (elems ar)
*Main> evalRandIO (shuffle [1..10])
[6,5,1,7,10,4,9,2,8,3]
EDIT: with a fixed swaps argument as in your Go code, the code is quite simple
{-# LANGUAGE ScopedTypeVariables #-}
import Data.Array.ST
import Data.Foldable
import Control.Monad.ST
shuffle :: forall a. [a] -> [Int] -> [a]
shuffle xs swaps = runST $ do
    let n = length xs
    ar <- newListArray (1,n) xs :: ST s (STArray s Int a)
    for_ [1..n] $ \i ->
        for_ swaps $ \j -> do
            vi <- readArray ar i
            vj <- readArray ar j
            writeArray ar j vi
            writeArray ar i vj
    getElems ar
but I am not sure you can reasonably call it Fisher-Yates shuffle.
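For comparison, here is a minimal sketch of what a pure O(n) Fisher-Yates could look like when the randomness is supplied from outside as a list of target indices, one per position. The name fisherYatesWith and the clamping of out-of-range targets are assumptions made for illustration; they come from neither the question nor the wiki page. It depends only on the array package.

```haskell
import Control.Monad (forM_)
import Data.Array (elems)
import Data.Array.ST (newListArray, readArray, runSTArray, writeArray)

-- | Fisher-Yates driven by caller-supplied target indices (hypothetical helper).
-- The i-th target should lie in [i, n-1]; out-of-range values are clamped,
-- and missing targets leave the remaining positions untouched.
fisherYatesWith :: [Int] -> [a] -> [a]
fisherYatesWith targets xs = elems (runSTArray (do
    arr <- newListArray (0, n - 1) xs
    -- One swap per position: position i trades places with its target j >= i.
    forM_ (zip [0 .. n - 2] targets) $ \(i, j0) -> do
      let j = max i (min (n - 1) j0)
      vi <- readArray arr i
      vj <- readArray arr j
      writeArray arr i vj
      writeArray arr j vi
    return arr))
  where
    n = length xs
```

All mutation stays inside runSTArray, so the function has an ordinary pure type; generating uniformly distributed targets (for example with randomR (i, n-1)) is left to the caller.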

Related

Is runInBoundThread the best tool for parallelism?

Say, I want to fold monoids in parallel. My computer has 8 cores. I have this function to split a list into equal-sized smaller lists (with bounded modulo-bias):
import Data.List
parallelize :: Int -> [a] -> [[a]]
parallelize 0 _ = []
parallelize n [] = replicate n []
parallelize n xs = let
    (us,vs) = splitAt (quot (length xs) n) xs
    in us : parallelize (n-1) vs
The first version of parallel fold I made was:
import Control.Concurrent
import Control.Concurrent.QSemN
import Data.Foldable
import Data.IORef
foldP :: Monoid m => [m] -> IO m
foldP xs = do
    result <- newIORef mempty
    sem <- newQSemN 0
    n <- getNumCapabilities
    let yss = parallelize n xs
    for_ yss (\ys -> forkIO (modifyIORef result (fold ys <>) >> signalQSemN sem 1))
    waitQSemN sem n
    readIORef result
But usage of IORefs and semaphores seemed ugly to me. So I made another version:
import Data.Traversable
foldP :: Monoid m => [m] -> IO m
foldP xs = do
    n <- getNumCapabilities
    let yss = parallelize n xs
    rs <- for yss (\ys -> runInUnboundThread (return (fold ys)))
    return (fold rs)
The test code I used is:
import Data.Monoid
import System.CPUTime
main :: IO ()
main = do
    start <- getCPUTime
    Product result <- foldP (fmap Product [1 .. 100])
    end <- getCPUTime
    putStrLn ("Time took: " ++ show (end - start) ++ "ps.")
    putStrLn ("Result: " ++ show result)
The second version of foldP outperformed the first. When I used runInBoundThread instead of runInUnboundThread, it became even faster.
What causes these performance differences?
TL;DR: Use the fold function from the massiv package and you will likely get the most efficient solution in Haskell.
I would like to start by saying that the first thing people forget when implementing concurrent patterns like this is exception handling. The solution in the question has no exception handling at all, so it is simply wrong. I'd therefore recommend using existing implementations of common concurrency patterns. async is the go-to library for concurrency, but for this use case it will not be the most efficient solution.
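To make the exception-handling point concrete, here is a small base-only sketch in which every worker reports either its result or its exception through an MVar, and the caller rethrows the first failure it sees. The name foldChunks is made up for this illustration, and the sketch is not meant to replace async or scheduler. Note also that evaluate only forces each chunk's fold to weak head normal form.

```haskell
import Control.Concurrent (forkFinally)
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)
import Control.Exception (evaluate, throwIO)

-- | Fold each chunk in its own thread (hypothetical helper).
-- A worker that throws hands its exception back to the caller instead of
-- dying silently and leaving the caller blocked forever.
foldChunks :: Monoid m => [[m]] -> IO m
foldChunks chunks = do
  vars <- mapM spawn chunks
  -- Results are collected in chunk order, so non-commutative monoids are safe.
  results <- mapM takeMVar vars
  mconcat <$> mapM (either throwIO pure) results
  where
    spawn chunk = do
      var <- newEmptyMVar
      -- forkFinally runs the handler with Left on an exception, Right otherwise.
      _ <- forkFinally (evaluate (mconcat chunk)) (putMVar var)
      pure var
```

In the question's first version, a fold that throws would kill its thread before signalQSemN runs, so waitQSemN would never return; here the same failure surfaces in the calling thread.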
This particular example can easily be solved with the scheduler package; in fact, it is exactly the kind of thing it was designed for. Here is how you can use it to fold monoids:
import Control.Scheduler
import Control.Monad.IO.Unlift
import Data.Foldable (fold)
foldP :: (MonadUnliftIO m, Monoid n) => Comp -> [n] -> m n
foldP comp xs = do
  rs <-
    withScheduler comp $ \scheduler ->
      mapM_ (scheduleWork scheduler . pure . fold) (parallelize (numWorkers scheduler) xs)
  pure $ fold rs
See the Comp type for an explanation of the best parallelization strategies. From what I have found in practice, Par will usually work best, because it uses pinned threads created with forkOn.
Note that the parallelize function is implemented inefficiently and dangerously as well; it is better to write it this way:
parallelize :: Int -> [a] -> [[a]]
parallelize n' xs' = go 0 id xs'
  where
    n = max 1 n'
    -- at least two elements make sense to get benefit of parallel fold
    k = max 2 $ quot (length xs') n
    go i acc xs
      | null xs = acc []
      | i < n =
        case splitAt k xs of
          (ls, rs) -> go (i + 1) (acc . (ls :)) rs
      | otherwise = acc . (xs:) $ []
One more bit of advice: a list is far from an ideal data structure for parallelization, or for efficiency in general. To split the list into chunks before parallelizing the computation, you already have to traverse the data structure with parallelize, which could be avoided by using an array. What I am getting at is: use an array instead, as suggested at the beginning of this answer.

Why is my Haskell selection sort implementation extremely fast?

I implemented selection sort and compared it to Data.List's sort. It is orders of magnitude faster than Data.List's sort. If I apply it to 10,000 randomly generated numbers, the results are as follows:
✓ in 1.22µs: Selection sort
✓ in 9.84ms: Merge sort (Data.List)
This can't be right. First I thought maybe merge sort's intermediate results are cached and selection sort uses those to be much faster. Even when I comment out merge sort and only time selection sort, it is this fast however. I also verified the output and it is correctly sorted.
What causes this behaviour?
I use this code to test:
{-# LANGUAGE BangPatterns #-}
module Lib
    ( testSortingAlgorithms
    ) where
import System.Random (randomRIO)
import Text.Printf
import Control.Exception
import System.CPUTime
import Data.List (sort, sortOn)
selectionSort :: Ord a => [a] -> [a]
selectionSort [] = []
selectionSort nrs =
    let (smallest, rest) = getSmallest nrs
    in smallest : selectionSort rest
    where getSmallest :: Ord a => [a] -> (a, [a])
          getSmallest [a] = (a, [])
          getSmallest (a:as) = let (smallest, rest) = getSmallest as
                               in if smallest > a then (a, smallest : rest)
                                                  else (smallest, a : rest)
main :: IO ()
main = testSortingAlgorithms
testSortingAlgorithms :: IO ()
testSortingAlgorithms = do
    !list' <- list (10000)
    results <- mapM (timeIt list') sorts
    let results' = sortOn fst results
    mapM_ (\(diff, msg) -> printf (msg) (diff::Double)) results'
    return ()
sorts :: Ord a => [(String, [a] -> [a])]
sorts = [
      ("Selection sort", selectionSort)
    , ("Merge sort (Data.List)", sort)
    ]
list :: Int -> IO [Int]
list n = sequence $ replicate n $ randomRIO (-127,127::Int)
timeIt :: (Ord a, Show a)
    => [a] -> (String, [a] -> [a]) -> IO (Double, [Char])
timeIt vals (name, sorter) = do
    start <- getCPUTime
    --v <- sorter vals `seq` return ()
    let !v = sorter vals
    --putStrLn $ show v
    end <- getCPUTime
    let (diff, ext) = unit $ (fromIntegral (end - start)) / (10^3)
    let msg = if correct v
              then (" ✓ in %0.2f" ++ ext ++ ": " ++ name ++ "\n")
              else (" ✗ in %0.2f" ++ ext ++ ": " ++ name ++ "\n")
    return (diff, msg)
correct :: (Ord a) => [a] -> Bool
correct [] = True
correct (a:[]) = True
correct (a1:a2:as) = a1 <= a2 && correct (a2:as)
unit :: Double -> (Double, String)
unit v | v < 10^3 = (v, "ns")
       | v < 10^6 = (v / 10^3, "µs")
       | v < 10^9 = (v / 10^6, "ms")
       | otherwise = (v / 10^9, "s")
You write
let !v = sorter vals
which is "strict", but only to WHNF. So you are only timing how long it takes to find the smallest element of the list, not how long it takes to sort the whole thing. Selection sort starts by doing exactly that, so it is "optimal" for this incorrect benchmark, while mergesort does a bunch more work that's "wasted" if you only look at the first element.
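One way to time the whole sort is to force every cons cell and element before reading the clock a second time. The sketch below uses only base; forceList and runFully are names invented here, and for element types with nested structure you would want force from the deepseq package instead.

```haskell
import Control.Exception (evaluate)

-- | Walk the entire list, forcing the spine and each element to WHNF.
-- For flat element types such as Int this evaluates the list completely.
forceList :: [a] -> ()
forceList = foldr seq ()

-- | Run a sorter and return only once its whole result has been computed
-- (hypothetical helper; the timing calls would go around this action).
runFully :: ([a] -> [a]) -> [a] -> IO [a]
runFully sorter vals = do
  let sorted = sorter vals
  -- evaluate sequences the forcing inside IO, before the action returns.
  evaluate (forceList sorted)
  return sorted
```

In timeIt, replacing let !v = sorter vals with v <- runFully sorter vals would make both algorithms pay for producing every element, not just the first one.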

Haskell: Sort an almost-sorted array

I've been learning Haskell in my spare time, working through LYAH. I would like to improve my Haskell (and functional programming) skills by solving some problems from the imperative world. One of the problems from EPI is to print an "almost sorted array" in sorted order, where it is guaranteed that no element in the array is more than k positions away from its correct position. The input is a stream of elements, and the requirement is O(n log k) time complexity and O(k) space complexity.
I've attempted to re-implement the imperative solution in Haskell as follows:
import qualified Data.Heap as Heap
-- print the k-sorted list in a sorted fashion
ksorted :: (Ord a, Show a) => [a] -> Int -> IO ()
ksorted [] _ = return ()
ksorted xs k = do
    heap <- ksorted' xs Heap.empty
    mapM_ print $ (Heap.toAscList heap) -- print the remaining elements in the heap.
  where
    ksorted' :: (Ord a, Show a) => [a] -> Heap.MinHeap a -> IO (Heap.MinHeap a)
    ksorted' [] h = return h
    ksorted' (x:xs) h = do let (m, h') = getMinAndBuildHeap h x in
                             (printMin m >> ksorted' xs h')
    printMin :: (Show a) => Maybe a -> IO ()
    printMin m = case m of
                   Nothing -> return ()
                   (Just item) -> print item
    getMinAndBuildHeap :: (Ord a, Show a) => Heap.MinHeap a -> a -> (Maybe a, Heap.MinHeap a)
    getMinAndBuildHeap h item= if (Heap.size h) > k
                                 then ((Heap.viewHead h), (Heap.insert item (Heap.drop 1 h)))
                                 else (Nothing, (Heap.insert item h))
I would like to know a better way of solving this in Haskell. Any inputs would be appreciated.
[Edit 1]: The input is stream, but for now I assumed a list instead (with only a forward iterator/ input iterator in some sense.)
[Edit 2]: added Data.Heap import to the code.
Thanks.
I think the main improvement is to separate the production of the sorted list from the printing of the sorted list. So:
import Data.Heap (MinHeap)
import qualified Data.Heap as Heap
ksort :: Ord a => Int -> [a] -> [a]
ksort k xs = go (Heap.fromList b) e where
    (b, e) = splitAt (k-1) xs
    go :: Ord a => MinHeap a -> [a] -> [a]
    go heap [] = Heap.toAscList heap
    go heap (x:xs) = x' : go heap' xs where
        Just (x', heap') = Heap.view (Heap.insert x heap)
printKSorted :: (Ord a, Show a) => Int -> [a] -> IO ()
printKSorted k xs = mapM_ print (ksort k xs)
If I were feeling extra-special-fancy, I might try to turn go into a foldr or perhaps a mapAccumR, but in this case I think the explicit recursion is relatively readable, too.
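If you would rather not depend on the heap package, the same sliding-window idea can be sketched with Data.Map from containers used as a multiset. This is a substitute technique, not the code above: Bag, insertBag, popMin and ksortBag are names invented here, and the window holds k+1 elements (as the question's code does), on the assumption that "at most k away" means an element may sit up to k positions from its sorted position.

```haskell
import qualified Data.Map.Strict as Map

-- | A multiset: each key maps to its multiplicity.
type Bag a = Map.Map a Int

insertBag :: Ord a => a -> Bag a -> Bag a
insertBag x = Map.insertWith (+) x 1

-- | Remove and return one copy of the smallest element, if any.
popMin :: Ord a => Bag a -> Maybe (a, Bag a)
popMin bag = case Map.lookupMin bag of
  Nothing -> Nothing
  Just (x, c)
    | c > 1     -> Just (x, Map.insert x (c - 1) bag)
    | otherwise -> Just (x, Map.delete x bag)

-- | Sort a list in which no element is more than k positions out of place.
-- The bag never holds more than k+1 elements, so each step costs O(log k).
ksortBag :: Ord a => Int -> [a] -> [a]
ksortBag k xs = go (foldr insertBag Map.empty front) rest
  where
    (front, rest) = splitAt k xs
    -- Input exhausted: emit what is left in ascending order.
    go bag [] = concatMap (\(x, c) -> replicate c x) (Map.toAscList bag)
    go bag (y:ys) = case popMin (insertBag y bag) of
      Just (m, bag') -> m : go bag' ys
      Nothing        -> go bag ys  -- unreachable: the bag is non-empty after an insert
```

The result is still produced lazily, so mapM_ print (ksortBag k xs) starts printing before the whole input has been consumed.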

How to optimize this Haskell code summing up the primes in sublinear time?

Problem 10 from Project Euler is to find the sum of all the primes below given n.
I solved it simply by summing up the primes generated by the sieve of Eratosthenes. Then I came across much more efficient solution by Lucy_Hedgehog (sub-linear!).
For n = 2⋅10^9:
Python code (from the quote above) runs in 1.2 seconds in Python 2.7.3.
C++ code (mine) runs in about 0.3 seconds (compiled with g++ 4.8.4).
I re-implemented the same algorithm in Haskell, since I'm learning it:
import Data.List
import Data.Map (Map, (!))
import qualified Data.Map as Map
problem10 :: Integer -> Integer
problem10 n = (sieve (Map.fromList [(i, i * (i + 1) `div` 2 - 1) | i <- vs]) 2 r vs) ! n
    where vs = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]
          r = floor (sqrt (fromIntegral n))
sieve :: Map Integer Integer -> Integer -> Integer -> [Integer] -> Map Integer Integer
sieve m p r vs | p > r = m
               | otherwise = sieve (if m ! p > m ! (p - 1) then update m vs p else m) (p + 1) r vs
update :: Map Integer Integer -> [Integer] -> Integer -> Map Integer Integer
update m vs p = foldl' decrease m (map (\v -> (v, sumOfSieved m v p)) (takeWhile (>= p*p) vs))
decrease :: Map Integer Integer -> (Integer, Integer) -> Map Integer Integer
decrease m (k, v) = Map.insertWith (flip (-)) k v m
sumOfSieved :: Map Integer Integer -> Integer -> Integer -> Integer
sumOfSieved m v p = p * (m ! (v `div` p) - m ! (p - 1))
main = print $ problem10 $ 2*10^9
I compiled it with ghc -O2 10.hs and ran it with time ./10.
It gives the correct answer, but takes about 7 seconds.
I compiled it with ghc -prof -fprof-auto -rtsopts 10 and ran it with ./10 +RTS -p -h.
10.prof shows that decrease takes 52.2% of the time and 67.5% of the allocations.
After running hp2ps 10.hp I got a heap profile in which, again, decrease appears to take most of the heap. GHC version 7.6.3.
How would you optimize run time of this Haskell code?
Update 13.06.17:
I tried replacing immutable Data.Map with mutable Data.HashTable.IO.BasicHashTable from the hashtables package, but I'm probably doing something bad, since for tiny n = 30 it already takes too long, about 10 seconds. What's wrong?
Update 18.06.17:
Curious about the HashTable performance issues is a good read. I took Shersh's code using mutable Data.HashTable.ST.Linear, but dropped Data.Judy in instead. It runs in 1.1 seconds, still relatively slow.
I've made some small improvements, so it now runs in 3.4-3.5 seconds on my machine.
Using IntMap.Strict helped a lot. Other than that, I manually performed some optimizations that GHC might do anyway, just to be sure, and made the Haskell code closer to the Python code from your link. As a next step you could try a mutable HashMap, but I'm not sure about that. IntMap can't be much faster than a mutable container, because it is an immutable one; still, I'm surprised by its efficiency. I hope this can be implemented faster.
Here is the code:
import Data.List (foldl')
import Data.IntMap.Strict (IntMap, (!))
import qualified Data.IntMap.Strict as IntMap
p :: Int -> Int
p n = (sieve (IntMap.fromList [(i, i * (i + 1) `div` 2 - 1) | i <- vs]) 2 r vs) ! n
  where vs = [n `div` i | i <- [1..r]] ++ [n', n' - 1 .. 1]
        r = floor (sqrt (fromIntegral n) :: Double)
        n' = n `div` r - 1
sieve :: IntMap Int -> Int -> Int -> [Int] -> IntMap Int
sieve m' p' r vs = go m' p'
  where
    go m p | p > r = m
           | m ! p > m ! (p - 1) = go (update m vs p) (p + 1)
           | otherwise = go m (p + 1)
update :: IntMap Int -> [Int] -> Int -> IntMap Int
update s vs p = foldl' decrease s (takeWhile (>= p2) vs)
  where
    sp = s ! (p - 1)
    p2 = p * p
    sumOfSieved v = p * (s ! (v `div` p) - sp)
    decrease m v = IntMap.adjust (subtract $ sumOfSieved v) v m
main :: IO ()
main = print $ p $ 2*10^(9 :: Int)
UPDATE:
Using mutable hashtables, I got this implementation to run in ~5.5 sec in Haskell.
Also, I used unboxed vectors instead of lists in several places. Linear hashing seems to be the fastest. I think this can be made even faster. I noticed the sse42 option in the hashtables package. I'm not sure I managed to set it correctly, but even without it the code runs that fast.
UPDATE 2 (19.06.2017)
I've managed to make it 3x faster than the best solution from @Krom (using my code + his map) by dropping the Judy hashmap altogether. Plain arrays are used instead. You can come up with the same idea if you notice that the keys of the S hashmap are either the sequence from 1 to n' or n div i for i from 1 to r. So we can represent such a HashMap as two arrays, choosing which array to look in depending on the key being searched.
My code + Judy HashMap
$ time ./judy
95673602693282040
real 0m0.590s
user 0m0.588s
sys 0m0.000s
My code + my sparse map
$ time ./sparse
95673602693282040
real 0m0.203s
user 0m0.196s
sys 0m0.004s
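The two-array representation can be sketched as follows. This is an independent illustration rather than the answer's linked code: primeSum, isqrt, small and large are names chosen here, Int is assumed to be 64-bit, and the arrays are plain STUArrays from the array package. A key v <= r lives at small[v]; a key v > r always has the form n div i and lives at large[i], and because i = n div v can be recovered from v, no hashing is needed.

```haskell
import Control.Monad (forM_, when)
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STUArray, newListArray, readArray, writeArray)

-- | Integer square root, correcting the floating-point estimate.
isqrt :: Int -> Int
isqrt n = adjust (floor (sqrt (fromIntegral n :: Double)))
  where
    adjust r
      | r * r > n              = adjust (r - 1)
      | (r + 1) * (r + 1) <= n = adjust (r + 1)
      | otherwise              = r

-- | Sum of the primes <= n, with S(v) kept in two dense arrays
-- instead of a map keyed by v.
primeSum :: Int -> Int
primeSum n
  | n < 2 = 0
  | otherwise = runST $ do
      let r = isqrt n
          tri v = v * (v + 1) `div` 2 - 1   -- 2 + 3 + ... + v
      small <- newListArray (0, r) (0 : [tri v | v <- [1 .. r]])
                 :: ST s (STUArray s Int Int)
      large <- newListArray (0, r) (0 : [tri (n `div` i) | i <- [1 .. r]])
                 :: ST s (STUArray s Int Int)
      let get v
            | v <= r    = readArray small v
            | otherwise = readArray large (n `div` v)
      forM_ [2 .. r] $ \p -> do
        sp  <- readArray small p
        sp1 <- readArray small (p - 1)
        when (sp > sp1) $ do                -- p survived the sieve, so it is prime
          let p2 = p * p
          -- Keys are visited in descending order, so every S(v div p) read
          -- below still holds the value from before this round.
          forM_ (takeWhile (\i -> n `div` i >= p2) [1 .. r]) $ \i -> do
            x   <- get (n `div` i `div` p)
            cur <- readArray large i
            writeArray large i (cur - p * (x - sp1))
          forM_ [r, r - 1 .. p2] $ \v -> do
            x   <- readArray small (v `div` p)
            cur <- readArray small v
            writeArray small v (cur - p * (x - sp1))
      readArray large 1                     -- S(n) is stored at i = 1
```

Every lookup is a single array read plus at most one division, which is where the advantage over a general-purpose hash map would come from.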
This could be made even faster by using pre-generated vectors from the vector library instead of IOUArray, and by replacing readArray with unsafeRead. But I don't think that is worth doing unless you are really interested in optimizing this as much as possible.
Comparing against this solution is cheating and is not fair. I expect the same ideas implemented in Python and C++ would be even faster. But @Krom's solution with a closed hashmap is already cheating, because it uses a custom data structure instead of a standard one. At least you can see that the standard and most popular hash maps in Haskell are not that fast. Better algorithms and better ad-hoc data structures can serve such problems better.
Here's the resulting code.
First as a baseline, the timings of the existing approaches
on my machine:
Original program posted in the question:
time stack exec primorig
95673602693282040
real 0m4.601s
user 0m4.387s
sys 0m0.251s
Second the version using Data.IntMap.Strict from
here
time stack exec primIntMapStrict
95673602693282040
real 0m2.775s
user 0m2.753s
sys 0m0.052s
Shersh's code with Data.Judy dropped in, from here
time stack exec prim-hash2
95673602693282040
real 0m0.945s
user 0m0.955s
sys 0m0.028s
Your python solution.
I compiled it with
python -O -m py_compile problem10.py
and the timing:
time python __pycache__/problem10.cpython-36.opt-1.pyc
95673602693282040
real 0m1.163s
user 0m1.160s
sys 0m0.003s
Your C++ version:
$ g++ -O2 --std=c++11 p10.cpp -o p10
$ time ./p10
sum(2000000000) = 95673602693282040
real 0m0.314s
user 0m0.310s
sys 0m0.003s
I didn't bother to provide a baseline for slow.hs, as I didn't
want to wait for it to complete when run with an argument of
2*10^9.
Subsecond performance
The following program runs in under a second on my machine.
It uses a hand-rolled hashmap, which uses closed hashing with
linear probing and some variant of Knuth's hash function,
see here.
Certainly it is somewhat tailored to the case, as the lookup
function for example expects the searched keys to be present.
Timings:
time stack exec prim
95673602693282040
real 0m0.725s
user 0m0.714s
sys 0m0.047s
First I implemented my hand rolled hashmap simply to hash
the keys with
key `mod` size
and selected a size multiple times higher than the expected
input, but the program took 22s or more to complete.
Finally it was a matter of choosing a hash function which was
good for the workload.
Here is the program:
import Data.Maybe
import Control.Monad
import Data.Array.IO
import Data.Array.Base (unsafeRead)
type Number = Int
data Map = Map { keys :: IOUArray Int Number
               , values :: IOUArray Int Number
               , size :: !Int
               , factor :: !Int
               }
newMap :: Int -> Int -> IO Map
newMap s f = do
  k <- newArray (0, s-1) 0
  v <- newArray (0, s-1) 0
  return $ Map k v s f
storeKey :: IOUArray Int Number -> Int -> Int -> Number -> IO Int
storeKey arr s f key = go ((key * f) `mod` s)
  where
    go :: Int -> IO Int
    go ind = do
      v <- readArray arr ind
      go2 v ind
    go2 v ind
      | v == 0 = do { writeArray arr ind key; return ind; }
      | v == key = return ind
      | otherwise = go ((ind + 1) `mod` s)
loadKey :: IOUArray Int Number -> Int -> Int -> Number -> IO Int
loadKey arr s f key = s `seq` key `seq` go ((key *f) `mod` s)
  where
    go :: Int -> IO Int
    go ix = do
      v <- unsafeRead arr ix
      if v == key then return ix else go ((ix + 1) `mod` s)
insertIntoMap :: Map -> (Number, Number) -> IO Map
insertIntoMap m@(Map ks vs s f) (k, v) = do
  ix <- storeKey ks s f k
  writeArray vs ix v
  return m
fromList :: Int -> Int -> [(Number, Number)] -> IO Map
fromList s f xs = do
  m <- newMap s f
  foldM insertIntoMap m xs
(!) :: Map -> Number -> IO Number
(!) (Map ks vs s f) k = do
  ix <- loadKey ks s f k
  readArray vs ix
mupdate :: Map -> Number -> (Number -> Number) -> IO ()
mupdate (Map ks vs s fac) i f = do
  ix <- loadKey ks s fac i
  old <- readArray vs ix
  let x' = f old
  x' `seq` writeArray vs ix x'
r' :: Number -> Number
r' = floor . sqrt . fromIntegral
vs' :: Integral a => a -> a -> [a]
vs' n r = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]
vss' n r = r + n `div` r -1
list' :: Int -> Int -> [Number] -> IO Map
list' s f vs = fromList s f [(i, i * (i + 1) `div` 2 - 1) | i <- vs]
problem10 :: Number -> IO Number
problem10 n = do
  m <- list' (19*vss) (19*vss+7) vs
  nm <- sieve m 2 r vs
  nm ! n
  where vs = vs' n r
        vss = vss' n r
        r = r' n
sieve :: Map -> Number -> Number -> [Number] -> IO Map
sieve m p r vs | p > r = return m
               | otherwise = do
                   v1 <- m ! p
                   v2 <- m ! (p - 1)
                   nm <- if v1 > v2 then update m vs p else return m
                   sieve nm (p + 1) r vs
update :: Map -> [Number] -> Number -> IO Map
update m vs p = foldM (decrease p) m $ takeWhile (>= p*p) vs
decrease :: Number -> Map -> Number -> IO Map
decrease p m k = do
  v <- sumOfSieved m k p
  mupdate m k (subtract v)
  return m
sumOfSieved :: Map -> Number -> Number -> IO Number
sumOfSieved m v p = do
  v1 <- m ! (v `div` p)
  v2 <- m ! (p - 1)
  return $ p * (v1 - v2)
main = do { n <- problem10 (2*10^9) ; print n; } -- 2*10^9
I am not a professional with hashing and that sort of stuff, so
this can certainly be improved a lot. Maybe we Haskellers should
improve the off-the-shelf hash maps or provide some simpler ones.
My hashmap, Shersh's code
If I plug my hashmap into Shersh's code (see answer below), see here,
we are even down to
time stack exec prim-hash2
95673602693282040
real 0m0.601s
user 0m0.604s
sys 0m0.034s
Why is slow.hs slow?
If you read through the source
for the function insert in Data.HashTable.ST.Basic, you
will see that it deletes the old key value pair and inserts
a new one. It doesn't look up the "place" for the value and
mutate it, as one might imagine, if one reads that it is
a "mutable" hashtable. Here the hashtable itself is mutable,
so you don't need to copy the whole hashtable for insertion
of a new key value pair, but the value places for the pairs
are not. I don't know if that is the whole story of slow.hs
being slow, but my guess is, it is a pretty big part of it.
A few minor improvements
So that's the idea I followed while trying to improve
your program the first time.
See, you don't need a mutable mapping from keys to values.
Your key set is fixed. You want a mapping from keys to mutable
places. (Which is, by the way, what you get from C++ by default.)
And so I tried to come up with that. I used IntMap IORef from
Data.IntMap.Strict and Data.IORef first and got a timing
of
time stack exec prim
95673602693282040
real 0m2.134s
user 0m2.141s
sys 0m0.028s
I thought maybe it would help to work with unboxed values
and to get that, I used IOUArray Int Int with 1 element
each instead of IORef and got those timings:
time stack exec prim
95673602693282040
real 0m2.015s
user 0m2.018s
sys 0m0.038s
Not much of a difference and so I tried to get rid of bounds
checking in the 1 element arrays by using unsafeRead and
unsafeWrite and got a timing of
time stack exec prim
95673602693282040
real 0m1.845s
user 0m1.850s
sys 0m0.030s
which was the best I got using Data.IntMap.Strict.
Of course I ran each program multiple times to see if
the times are stable and the differences in run time aren't
just noise.
It looks like these are all just micro-optimizations.
And here is the program that ran fastest for me without using a hand rolled data structure:
import qualified Data.IntMap.Strict as M
import Control.Monad
import Data.Array.IO
import Data.Array.Base (unsafeRead, unsafeWrite)
type Number = Int
type Place = IOUArray Number Number
type Map = M.IntMap Place
tupleToRef :: (Number, Number) -> IO (Number, Place)
tupleToRef = traverse (newArray (0,0))
insertRefs :: [(Number, Number)] -> IO [(Number, Place)]
insertRefs = traverse tupleToRef
fromList :: [(Number, Number)] -> IO Map
fromList xs = M.fromList <$> insertRefs xs
(!) :: Map -> Number -> IO Number
(!) m i = unsafeRead (m M.! i) 0
mupdate :: Map -> Number -> (Number -> Number) -> IO ()
mupdate m i f = do
  let place = m M.! i
  old <- unsafeRead place 0
  let x' = f old
  -- make the application of f strict
  x' `seq` unsafeWrite place 0 x'
r' :: Number -> Number
r' = floor . sqrt . fromIntegral
vs' :: Integral a => a -> a -> [a]
vs' n r = [n `div` i | i <- [1..r]] ++ reverse [1..n `div` r - 1]
list' :: [Number] -> IO Map
list' vs = fromList [(i, i * (i + 1) `div` 2 - 1) | i <- vs]
problem10 :: Number -> IO Number
problem10 n = do
  m <- list' vs
  nm <- sieve m 2 r vs
  nm ! n
  where vs = vs' n r
        r = r' n
sieve :: Map -> Number -> Number -> [Number] -> IO Map
sieve m p r vs | p > r = return m
               | otherwise = do
                   v1 <- m ! p
                   v2 <- m ! (p - 1)
                   nm <- if v1 > v2 then update m vs p else return m
                   sieve nm (p + 1) r vs
update :: Map -> [Number] -> Number -> IO Map
update m vs p = foldM (decrease p) m $ takeWhile (>= p*p) vs
decrease :: Number -> Map -> Number -> IO Map
decrease p m k = do
  v <- sumOfSieved m k p
  mupdate m k (subtract v)
  return m
sumOfSieved :: Map -> Number -> Number -> IO Number
sumOfSieved m v p = do
  v1 <- m ! (v `div` p)
  v2 <- m ! (p - 1)
  return $ p * (v1 - v2)
main = do { n <- problem10 (2*10^9) ; print n; } -- 2*10^9
If you profile that, you see that it spends most of the time in the custom lookup function (!);
I don't know how to improve that further. Trying to inline (!) with {-# INLINE (!) #-}
didn't yield better results; maybe GHC already did this.
This code of mine evaluates the sum to 2⋅10^9 in 0.3 seconds and the sum to 10^12 (18435588552550705911377) in 19.6 seconds (if given sufficient RAM).
import Control.DeepSeq
import qualified Control.Monad as ControlMonad
import qualified Data.Array as Array
import qualified Data.Array.ST as ArrayST
import qualified Data.Array.Base as ArrayBase
primeLucy :: (Integer -> Integer) -> (Integer -> Integer) -> Integer -> (Integer->Integer)
primeLucy f sf n = g
  where
    r = fromIntegral $ integerSquareRoot n
    ni = fromIntegral n
    loop from to c = let go i = ControlMonad.when (to<=i) (c i >> go (i-1)) in go from
    k = ArrayST.runSTArray $ do
      k <- ArrayST.newListArray (-r,r) $ force $
        [sf (div n (toInteger i)) - sf 1|i<-[r,r-1..1]] ++
        [0] ++
        [sf (toInteger i) - sf 1|i<-[1..r]]
      ControlMonad.forM_ (takeWhile (<=r) primes) $ \p -> do
        l <- ArrayST.readArray k (p-1)
        let q = force $ f (toInteger p)
        let adjust = \i j -> do { v <- ArrayBase.unsafeRead k (i+r); w <- ArrayBase.unsafeRead k (j+r); ArrayBase.unsafeWrite k (i+r) $!! v+q*(l-w) }
        loop (-1) (-div r p) $ \i -> adjust i (i*p)
        loop (-div r p-1) (-min r (div ni (p*p))) $ \i -> adjust i (div (-ni) (i*p))
        loop r (p*p) $ \i -> adjust i (div i p)
      return k
    g :: Integer -> Integer
    g m
      | m >= 1 && m <= integerSquareRoot n = k Array.! (fromIntegral m)
      | m >= integerSquareRoot n && m <= n && div n (div n m)==m = k Array.! (fromIntegral (negate (div n m)))
      | otherwise = error $ "Function not precalculated for value " ++ show m
primeSum :: Integer -> Integer
primeSum n = (primeLucy id (\m -> div (m*m+m) 2) n) n
If your integerSquareRoot function is buggy (as reportedly some are), you can replace it here with floor . sqrt . fromIntegral.
Explanation:
As the name suggests it is based upon a generalization of the famous method by "Lucy Hedgehog" eventually discovered by the original poster.
It allows you to calculate many sums of the form $\sum_{p \le N} f(p)$ (with p prime) without enumerating all the primes up to N and in time O(N^0.75).
Its inputs are the function f (i.e., id if you want the prime sum), its summatory function over all the integers (i.e., in that case the sum of the first m integers or div (m*m+m) 2), and N.
primeLucy returns a lookup function $n \mapsto \sum_{p \le n} f(p)$ (with p prime) restricted to certain values of n: those with $n \le \lfloor\sqrt{N}\rfloor$ or $n = \lfloor N/k \rfloor$ for some integer $k \ge 1$, as checked by the guards of g.
Try this and let me know how fast it is:
-- sum of primes
import Control.Monad (forM_, when)
import Control.Monad.ST
import Data.Array.ST
import Data.Array.Unboxed
sieve :: Int -> UArray Int Bool
sieve n = runSTUArray $ do
    let m = (n-1) `div` 2
        r = floor . sqrt $ fromIntegral n
    bits <- newArray (0, m-1) True
    forM_ [0 .. r `div` 2 - 1] $ \i -> do
        isPrime <- readArray bits i
        when isPrime $ do
            let a = 2*i*i + 6*i + 3
                b = 2*i*i + 8*i + 6
            forM_ [a, b .. (m-1)] $ \j -> do
                writeArray bits j False
    return bits
primes :: Int -> [Int]
primes n = 2 : [2*i+3 | (i, True) <- assocs $ sieve n]
main = do
    print $ sum $ primes 1000000
You can run it on ideone. My algorithm is the Sieve of Eratosthenes, and it should be quite fast for small n. For n = 2,000,000,000, the array size may be a problem, in which case you will need to use a segmented sieve. See my blog for more information about the Sieve of Eratosthenes. See this answer for information about a segmented sieve (but not in Haskell, unfortunately).

Performance of looping over an Unboxed array in Haskell

First of all, it's great. However, I came across a situation where my benchmarks turned up weird results. I am new to Haskell, and this is first time I've gotten my hands dirty with mutable arrays and Monads. The code below is based on this example.
I wrote a generic monadic for function that takes numbers and a step function rather than a range (like forM_ does). I compared using my generic for function (Loop A) against embedding an equivalent recursive function (Loop B). Having Loop A is noticeably faster than having Loop B. Weirder, having both Loop A and B together is faster than having Loop B by itself (but slightly slower than Loop A by itself).
Some possible explanations I can think of for the discrepancies. Note that these are just guesses:
Something I haven't learned yet about how Haskell extracts results from monadic functions.
Loop B faults the array in a less cache efficient manner than Loop A. Why?
I made a dumb mistake; Loop A and Loop B are actually different.
Note that in all 3 cases of having either or both Loop A and Loop B, the program produces the same output.
Here is the code. I tested it with ghc -O2 for.hs using GHC version 6.10.4.
import Control.Monad
import Control.Monad.ST
import Data.Array.IArray
import Data.Array.MArray
import Data.Array.ST
import Data.Array.Unboxed
for :: (Num a, Ord a, Monad m) => a -> a -> (a -> a) -> (a -> m b) -> m ()
for start end step f = loop start where
    loop i
        | i <= end = do
            f i
            loop (step i)
        | otherwise = return ()
primesToNA :: Int -> UArray Int Bool
primesToNA n = runSTUArray $ do
    a <- newArray (2,n) True :: ST s (STUArray s Int Bool)
    let sr = floor . (sqrt::Double->Double) . fromIntegral $ n+1
    -- Loop A
    for 4 n (+ 2) $ \j -> writeArray a j False
    -- Loop B
    let f i
            | i <= n = do
                writeArray a i False
                f (i+2)
            | otherwise = return ()
        in f 4
    forM_ [3,5..sr] $ \i -> do
        si <- readArray a i
        when si $
            forM_ [i*i,i*i+i+i..n] $ \j -> writeArray a j False
    return a
primesTo :: Int -> [Int]
primesTo n = [i | (i,p) <- assocs . primesToNA $ n, p]
main = print $ primesTo 30000000
I just tried benchmarking this with Criterion and GHC 6.12.1, and Loop A looks only slightly faster for me. I definitely don't get the weird "both together are faster than B alone" effect.
Also, if your step function really is just a step and doesn't do anything wacky with its argument, the following version of for seems a bit faster, especially for smaller arrays:
for' :: (Enum a, Num a, Ord a, Monad m) => a -> a -> (a -> a) -> (a -> m b) -> m ()
for' start end step = forM_ $ enumFromThenTo start (step start) end
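To see why the caveat about the step function matters, the sketch below records which indices each combinator visits. It runs the loops in the pair monad ((,) [a]) from base, which behaves like a small writer monad; visited is a helper name introduced here purely for illustration.

```haskell
import Control.Monad (forM_)

for :: (Num a, Ord a, Monad m) => a -> a -> (a -> a) -> (a -> m b) -> m ()
for start end step f = loop start
  where
    loop i
      | i <= end  = f i >> loop (step i)
      | otherwise = return ()

for' :: (Enum a, Num a, Ord a, Monad m) => a -> a -> (a -> a) -> (a -> m b) -> m ()
for' start end step = forM_ (enumFromThenTo start (step start) end)

-- | Indices visited by a loop combinator, in order (hypothetical helper).
-- Each iteration appends its index to the first component of the pair.
visited :: (a -> a -> (a -> a) -> (a -> ([a], ())) -> ([a], ()))
        -> a -> a -> (a -> a) -> [a]
visited loop start end step = fst (loop start end step (\i -> ([i], ())))
```

With a constant step both agree: visited for 4 10 (+ 2) and visited for' 4 10 (+ 2) both give [4,6,8,10]. With a doubling step they diverge: visited for 1 20 (* 2) gives [1,2,4,8,16], while visited for' 1 20 (* 2) gives [1..20], because enumFromThenTo only looks at the first step.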
Here are the results from Criterion, where loopA' is your loop A using my for', and where loopC is both A and B together:
benchmarking loopA...
mean: 2.372893 s, lb 2.370982 s, ub 2.374914 s, ci 0.950
std dev: 10.06753 ms, lb 8.820194 ms, ub 11.66965 ms, ci 0.950
benchmarking loopA'...
mean: 2.368167 s, lb 2.354312 s, ub 2.381413 s, ci 0.950
std dev: 69.50334 ms, lb 65.94236 ms, ub 73.17173 ms, ci 0.950
benchmarking loopB...
mean: 2.423160 s, lb 2.419131 s, ub 2.427260 s, ci 0.950
std dev: 20.78412 ms, lb 18.06613 ms, ub 24.99021 ms, ci 0.950
benchmarking loopC...
mean: 4.308503 s, lb 4.304875 s, ub 4.312110 s, ci 0.950
std dev: 18.48732 ms, lb 16.19325 ms, ub 21.32299 ms, ci 0.950
And here's the code:
module Main where
import Control.Monad
import Control.Monad.ST
import Data.Array.ST
import Data.Array.Unboxed
import Criterion.Main
for :: (Num a, Ord a, Monad m) => a -> a -> (a -> a) -> (a -> m b) -> m ()
for start end step f = loop start where
    loop i
        | i <= end = do
            f i
            loop (step i)
        | otherwise = return ()
for' :: (Enum a, Num a, Ord a, Monad m) => a -> a -> (a -> a) -> (a -> m b) -> m ()
for' start end step = forM_ $ enumFromThenTo start (step start) end
loopA arr n = for 4 n (+ 2) $ flip (writeArray arr) False
loopA' arr n = for' 4 n (+ 2) $ flip (writeArray arr) False
loopB arr n =
    let f i | i <= n = do writeArray arr i False
                          f (i+2)
            | otherwise = return ()
    in f 4
loopC arr n = do
    loopA arr n
    loopB arr n
runPrimes loop n = do
    let sr = floor . (sqrt::Double->Double) . fromIntegral $ n+1
    a <- newArray (2,n) True :: (ST s (STUArray s Int Bool))
    loop a n
    forM_ [3,5..sr] $ \i -> do
        si <- readArray a i
        when si $
            forM_ [i*i,i*i+i+i..n] $ \j -> writeArray a j False
    return a
primesA n = [i | (i,p) <- assocs $ runSTUArray $ runPrimes loopA n, p]
primesA' n = [i | (i,p) <- assocs $ runSTUArray $ runPrimes loopA' n, p]
primesB n = [i | (i,p) <- assocs $ runSTUArray $ runPrimes loopB n, p]
primesC n = [i | (i,p) <- assocs $ runSTUArray $ runPrimes loopC n, p]
main = let n = 10000000 in
    defaultMain [ bench "loopA" $ nf primesA n
                , bench "loopA'" $ nf primesA' n
                , bench "loopB" $ nf primesB n
                , bench "loopC" $ nf primesC n ]
Perhaps compare and contrast with the Shootout nsieve program? In any case, the only way to know what is really happening is to look at the Core (e.g. with the ghc-core tool).
{-# OPTIONS -O2 -optc-O -fbang-patterns -fglasgow-exts -optc-march=pentium4 #-}
--
-- The Computer Language Shootout
-- http://shootout.alioth.debian.org/
--
-- Contributed by Don Stewart 2005
-- nsieve over an ST monad Bool array
--
import Control.Monad.ST
import Data.Array.ST
import Data.Array.Base
import System
import Control.Monad
import Data.Bits
import Text.Printf
main = do
    n <- getArgs >>= readIO . head :: IO Int
    mapM_ (\i -> sieve (10000 `shiftL` (n-i))) [0, 1, 2]
sieve n = do
    let r = runST (do a <- newArray (2,n) True :: ST s (STUArray s Int Bool)
                      go a n 2 0)
    printf "Primes up to %8d %8d\n" (n::Int) (r::Int) :: IO ()
go !a !m !n !c
    | n == m = return c
    | otherwise = do
          e <- unsafeRead a n
          if e then let loop j
                          | j < m = do
                              x <- unsafeRead a j
                              when x $ unsafeWrite a j False
                              loop (j+n)
                          | otherwise = go a m (n+1) (c+1)
                    in loop (n `shiftL` 1)
               else go a m (n+1) c
