Knuth-Morris-Pratt implementation in Haskell -- Index out of bounds

I've used the pseudocode from Wikipedia in an attempt to write a KMP algorithm in Haskell.
It's giving "index out of bounds" when I try to search beyond the length of the pattern and I can't seem to find the issue; my "fixes" have only ruined the result.
import Control.Monad
import Control.Lens
import qualified Data.ByteString.Char8 as C
import qualified Data.Vector.Unboxed as V
(!) :: C.ByteString -> Int -> Char
(!) = C.index
-- Make the table for the KMP. Directly from Wikipedia. Works as expected for inputs from Wikipedia article.
mkTable :: C.ByteString -> V.Vector Int
mkTable pat = make 2 0 (ix 0 .~ (negate 1) $ V.replicate l 0)
  where
    l = C.length pat
    make :: Int -> Int -> V.Vector Int -> V.Vector Int
    make p c t
      | p >= l    = t
      | otherwise = proc
      where
        proc | pat ! (p-1) == pat ! c
                         = make (p+1) (c+1) (ix p .~ (c+1) $ t)
             | c > 0     = make p (t V.! c) t
             | otherwise = make (p+1) c (ix p .~ 0 $ t)
kmp :: C.ByteString -> C.ByteString -> V.Vector Int -> Int
kmp text pat tbl = search 0 0
  where
    l = C.length text
    search m i
      | m + i >= l = l
      | otherwise  = cond
      where
        -- The conditions for the loop, given in the wiki article
        cond | pat ! i == text ! (m+i)
                 = if i == C.length pat - 1
                     then m
                     else search m (i+1)
             | tbl V.! i > (-1)
                 = search (m + i - (tbl V.! i)) (tbl V.! i)
             | otherwise
                 = search 0 (m+1)
main :: IO()
main = do
  t <- readLn
  replicateM_ t $ do
    text <- C.getLine
    pat <- C.getLine
    -- kmp returns an Int, so print it rather than passing it to putStrLn
    print $ kmp text pat (mkTable pat)

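Simple solution: I mixed up m and i in the last condition of kmp.
    | otherwise = search 0 (m+1)
becomes
    | otherwise = search (m+1) 0
and the issue is resolved. On a mismatch with no fallback in the table, the match start m has to advance by one and the pattern index i has to reset to 0; the swapped version reset m and let i grow past the end of the pattern, hence the out-of-bounds index.
Aside from that, the table should be built with unboxed arrays in the ST monad, or the table generation takes an absurd amount of time, since every pure update of an immutable vector copies it.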
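The corrected search and an in-place table build can be sketched as follows. This is a minimal sketch under stated assumptions: it works on plain String input and uses STUArray/UArray from the array package instead of the vector and lens libraries used in the question, and the names mkTableST and kmpSearch are illustrative only.

```haskell
import Control.Monad.ST (ST)
import Data.Array.ST (STUArray, newArray, readArray, writeArray, runSTUArray)
import Data.Array.Unboxed (UArray, listArray, (!))

-- Failure table (Wikipedia variant with T[0] = -1), filled in place in ST.
mkTableST :: String -> UArray Int Int
mkTableST pat = runSTUArray $ do
  t <- newArray (0, max 0 (n - 1)) 0
  writeArray t 0 (-1)
  fill t 2 0
  pure t
  where
    n = length pat
    w = listArray (0, n - 1) pat :: UArray Int Char
    fill :: STUArray s Int Int -> Int -> Int -> ST s ()
    fill t p c
      | p >= n               = pure ()
      | w ! (p - 1) == w ! c = writeArray t p (c + 1) >> fill t (p + 1) (c + 1)
      | c > 0                = readArray t c >>= fill t p
      | otherwise            = writeArray t p 0 >> fill t (p + 1) 0

-- Index of the first match, or the length of the text if there is none.
kmpSearch :: String -> String -> Int
kmpSearch text pat
  | np == 0   = 0
  | otherwise = search 0 0
  where
    nt  = length text
    np  = length pat
    s   = listArray (0, nt - 1) text :: UArray Int Char
    w   = listArray (0, np - 1) pat :: UArray Int Char
    tbl = mkTableST pat
    search m i
      | m + i >= nt        = nt
      | w ! i == s ! (m + i) = if i == np - 1 then m else search m (i + 1)
      | tbl ! i >= 0       = search (m + i - tbl ! i) (tbl ! i)
      | otherwise          = search (m + 1) 0  -- advance the match start, reset i
```

For the example from the Wikipedia article, kmpSearch "ABC ABCDAB ABCDABCDABDE" "ABCDABD" evaluates to 15.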

Related

Implementing LLL algorithm in Haskell

I'm implementing the LLL basis reduction algorithm in Haskell. I'm basing my code on the pseudocode on Wikipedia. Here is what I have so far. Apologies for the code dump; I strongly suspect the issue lies in lll but I'm giving everything just in case.
import Linear as L
import Data.List (sort)
f v x = v `L.dot` x
gram_schmidt b =
  let aux vs us =
        case vs of
          v:t -> let vus = map (\u -> project u v) us
                     s = foldr (^+^) zero vus
                     u = v ^-^ s in
                 aux t (us++[u])
          [] -> us
  in aux b []
swap :: Int -> Int -> [a] -> [a]
swap i j xs =
  let elemI = xs !! i
      elemJ = xs !! j
      left = take i xs
      middle = take (j - i - 1) (drop (i + 1) xs)
      right = drop (j + 1) xs
  in left ++ [elemJ] ++ middle ++ [elemI] ++ right
update i xs new =
  let left = take (i-1) xs
      right = drop (i) xs
  in left ++ [new] ++ right
sort_vecs vs = map snd (sort (zip (map norm vs) vs))
lll :: Int -> [[Double]] -> Double -> [[Double]]
lll d b delta =
  let b' = gram_schmidt b
      aux :: [[Double]] -> [[Double]] -> Int -> [[Double]]
      aux b b' k =
        if k >= d then
          b
        else
          let aux2 :: [[Double]] -> [[Double]] -> Int -> [[Double]]
              aux2 b b' j =
                if j < 0 then
                  let mu = (f (b!!k) (b'!!(k-1))) / (f (b'!!(k-1)) (b'!!(k-1))) in
                  if f (b'!!k) (b'!!k) >= (delta-mu^2) * f (b'!!(k-1)) (b'!!(k-1)) then
                    aux b b' (k+1)
                  else
                    let bb = swap k (k-1) b
                        bb' = gram_schmidt bb in
                    aux bb bb' (max (k-1) 1)
                else
                  let mu = (f (b!!k) (b'!!j)) / (f (b'!!j) (b'!!j)) in
                  if abs mu > 0.5 then
                    let bk = b!!k
                        bj = b!!j
                        bb = update k b (bk ^-^ (fromIntegral (round mu)) *^ bj)
                        bb' = gram_schmidt bb in
                    aux2 bb bb' (j-1)
                  else
                    aux2 b b' (j-1)
          in aux2 b b' (k-1)
  in sort_vecs (aux b b' 1)
My issue is that it seems to find a basis of a sublattice. In particular, lll d [[-0.8526334764831849,-3.125000000000004e-2],[-1.2941941738241598,4.419417382415916e-2]] 0.75 returns [[0.41107277914220997,0.10669417382415924],[-1.2941941738241598,4.419417382415916e-2]], a basis for an index-2 sublattice whose vectors are almost parallel. I've been staring at this code for ages to no avail (I thought there was an issue with update, where (i-1) should be i and i should be (i+1), but changing this caused an infinite loop). Any help is greatly appreciated.
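Two of the list helpers can be checked in isolation, independently of lll. As written, swap assumes i < j, so the call swap k (k-1) b returns a list two elements longer than its input, and update i replaces the element at 0-based index i-1 while the rest of the code indexes from 0 with !!. The sketch below shows order-insensitive, 0-based variants. The names replaceAt and swapAt are hypothetical, and whether these changes alone yield a correctly reduced basis has not been verified here.

```haskell
-- Replace the element at 0-based index i.
-- Assumes i >= 0; the list is returned unchanged if i is past the end.
replaceAt :: Int -> a -> [a] -> [a]
replaceAt i new xs = case splitAt i xs of
  (left, _ : right) -> left ++ new : right
  _                 -> xs

-- Swap the elements at 0-based indices i and j, in either order.
-- Assumes both indices are within bounds.
swapAt :: Int -> Int -> [a] -> [a]
swapAt i j xs
  | i == j    = xs
  | otherwise = replaceAt i (xs !! j) (replaceAt j (xs !! i) xs)
```

With these, swapAt 1 0 "ab" gives "ba" and keeps the length at 2, and replaceAt 1 'x' "abc" gives "axc".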

Competitive programming using Haskell

I am currently trying to refresh my Haskell knowledge by solving some Hackerrank problems.
For example:
https://www.hackerrank.com/challenges/maximum-palindromes/problem
I've already implemented an imperative solution in C++ which got accepted for all test cases. Now I am trying to come up with a pure functional solution in (reasonably idiomatic) Haskell.
My current code is
module Main where
import Control.Monad
import qualified Data.ByteString.Char8 as C
import Data.Bits
import Data.List
import qualified Data.Map.Strict as Map
import qualified Data.IntMap.Strict as IntMap
import Debug.Trace
-- precompute factorials
compFactorials :: Int -> Int -> IntMap.IntMap Int
compFactorials n m = go 0 1 IntMap.empty
  where
    go a acc map
      | a < 0     = map
      | a < n     = go a' acc' map'
      | otherwise = map'
      where
        map' = IntMap.insert a acc map
        a'   = a + 1
        acc' = (acc * a') `mod` m
-- precompute invs
compInvs :: Int -> Int -> IntMap.IntMap Int -> IntMap.IntMap Int
compInvs n m facts = go 0 IntMap.empty
  where
    go a map
      | a < 0     = map
      | a < n     = go a' map'
      | otherwise = map'
      where
        map' = IntMap.insert a v map
        a'   = a + 1
        v    = (modExp b (m-2) m) `mod` m
        b    = (IntMap.!) facts a
modExp :: Int -> Int -> Int -> Int
modExp b e m = go b e 1
  where
    go b e r
      | (.&.) e 1 == 1 = go b' e' r'
      | e > 0          = go b' e' r
      | otherwise      = r
      where
        r' = (r * b) `mod` m
        b' = (b * b) `mod` m
        e' = shift e (-1)
-- precompute frequency table
initFreqMap :: C.ByteString -> Map.Map Char (IntMap.IntMap Int)
initFreqMap inp = go 1 map1 map2 inp
  where
    map1 = Map.fromList $ zip ['a'..'z'] $ repeat 0
    map2 = Map.fromList $ zip ['a'..'z'] $ repeat IntMap.empty
    go idx m1 m2 inp
      | C.null inp = m2
      | otherwise  = go (idx+1) m1' m2' $ C.tail inp
      where
        m1' = Map.update (\v -> Just $ v+1) (C.head inp) m1
        m2' = foldl' (\m w -> Map.update (\v -> liftM (\c -> IntMap.insert idx c v) $ Map.lookup w m1') w m)
                     m2 ['a'..'z']
query :: Int -> Int -> Int -> Map.Map Char (IntMap.IntMap Int)
      -> IntMap.IntMap Int -> IntMap.IntMap Int -> Int
query l r m freqMap facts invs
  | x > 1     = (x * y) `mod` m
  | otherwise = y
  where
    calcCnt cs = cr - cl
      where
        cl = IntMap.findWithDefault 0 (l-1) cs
        cr = IntMap.findWithDefault 0 r cs
    f1 acc cs
      | even cnt  = acc
      | otherwise = acc + 1
      where
        cnt = calcCnt cs
    f2 (acc1,acc2) cs
      | cnt < 2   = (acc1 ,acc2)
      | otherwise = (acc1',acc2')
      where
        cnt   = calcCnt cs
        n     = cnt `div` 2
        acc1' = acc1 + n
        r     = choose acc1' n
        acc2' = (acc2 * r) `mod` m
    -- calc binomial coefficient using Fermat's little theorem
    choose n k
      | n < k     = 0
      | otherwise = (f1 * t) `mod` m
      where
        f1 = (IntMap.!) facts n
        i1 = (IntMap.!) invs k
        i2 = (IntMap.!) invs (n-k)
        t  = (i1 * i2) `mod` m
    x = Map.foldl' f1 0 freqMap
    y = snd $ Map.foldl' f2 (0,1) freqMap
main :: IO()
main = do
  inp <- C.getLine
  q <- readLn :: IO Int
  let modulo = 1000000007
  let facts = compFactorials (C.length inp) modulo
  let invs = compInvs (C.length inp) modulo facts
  let freqMap = initFreqMap inp
  forM_ [1..q] $ \_ -> do
    line <- getLine
    let [s1, s2] = words line
    let l = (read s1) :: Int
    let r = (read s2) :: Int
    let result = query l r modulo freqMap facts invs
    putStrLn $ show result
It passes all small and medium test cases, but I am getting a timeout on the large test cases.
The key to solving this problem is to precompute some values once at the beginning and use them to answer the individual queries efficiently.
Now, my main problem where I need help is:
The initial profiling shows that the lookup operation of the IntMap seems to be the main bottleneck. Is there a better alternative to IntMap for memoization? Or should I look at Vector or Array, which I believe will lead to more "ugly" code?
Even in its current state, the code doesn't look nice (by functional standards) and is as verbose as my C++ solution. Any tips to make it more idiomatic? Other than the IntMap usage for memoization, do you spot any other obvious problems that could hurt performance?
And are there any good sources where I can learn how to use Haskell more effectively for competitive programming?
A sample large testcase, where the current code gets timeout:
input.txt
output.txt
For comparison my C++ solution:
#include <vector>
#include <iostream>
#define MOD 1000000007L
long mod_exp(long b, long e) {
long r = 1;
while (e > 0) {
if ((e & 1) == 1) {
r = (r * b) % MOD;
}
b = (b * b) % MOD;
e >>= 1;
}
return r;
}
long n_choose_k(int n, int k, const std::vector<long> &fact_map, const std::vector<long> &inv_map) {
if (n < k) {
return 0;
}
long l1 = fact_map[n];
long l2 = (inv_map[k] * inv_map[n-k]) % MOD;
return (l1 * l2) % MOD;
}
int main() {
std::string s;
int q;
std::cin >> s >> q;
std::vector<std::vector<long>> freq_map;
std::vector<long> fact_map(s.size()+1);
std::vector<long> inv_map(s.size()+1);
for (int i = 0; i < 26; i++) {
freq_map.emplace_back(std::vector<long>(s.size(), 0));
}
std::vector<long> acc_map(26, 0);
for (int i = 0; i < s.size(); i++) {
acc_map[s[i]-'a']++;
for (int j = 0; j < 26; j++) {
freq_map[j][i] = acc_map[j];
}
}
fact_map[0] = 1;
inv_map[0] = 1;
for (int i = 1; i <= s.size(); i++) {
fact_map[i] = (i * fact_map[i-1]) % MOD;
inv_map[i] = mod_exp(fact_map[i], MOD-2) % MOD;
}
while (q--) {
int l, r;
std::cin >> l >> r;
std::vector<long> x(26, 0);
long t = 0;
long acc = 0;
long result = 1;
for (int i = 0; i < 26; i++) {
auto cnt = freq_map[i][r-1] - (l > 1 ? freq_map[i][l-2] : 0);
if (cnt % 2 != 0) {
t++;
}
long n = cnt / 2;
if (n > 0) {
acc += n;
result *= n_choose_k(acc, n, fact_map, inv_map);
result = result % MOD;
}
}
if (t > 0) {
result *= t;
result = result % MOD;
}
std::cout << result << std::endl;
}
}
UPDATE:
DanielWagner's answer has confirmed my suspicion that the main problem in my code was the usage of IntMap for memoization. Replacing IntMap with Array made my code perform similar to DanielWagner's solution.
module Main where
import Control.Monad
import Data.Array (Array)
import qualified Data.Array as A
import qualified Data.ByteString.Char8 as C
import Data.Bits
import Data.List
import qualified Data.Map.Strict as Map
import Debug.Trace
-- precompute factorials
compFactorials :: Int -> Int -> Array Int Int
compFactorials n m = A.listArray (0,n) $ scanl' f 1 [1..n]
  where
    f acc a = (acc * a) `mod` m
-- precompute invs
compInvs :: Int -> Int -> Array Int Int -> Array Int Int
compInvs n m facts = A.listArray (0,n) $ map f [0..n]
  where
    f a = (modExp ((A.!) facts a) (m-2) m) `mod` m
modExp :: Int -> Int -> Int -> Int
modExp b e m = go b e 1
  where
    go b e r
      | (.&.) e 1 == 1 = go b' e' r'
      | e > 0          = go b' e' r
      | otherwise      = r
      where
        r' = (r * b) `mod` m
        b' = (b * b) `mod` m
        e' = shift e (-1)
-- precompute frequency table
initFreqMap :: C.ByteString -> Map.Map Char (Array Int Int)
initFreqMap inp = Map.fromList $ map f ['a'..'z']
  where
    n = C.length inp
    f c = (c, A.listArray (0,n) $ scanl' g 0 [0..n-1])
      where
        g x j
          | C.index inp j == c = x+1
          | otherwise          = x
query :: Int -> Int -> Int -> Map.Map Char (Array Int Int)
      -> Array Int Int -> Array Int Int -> Int
query l r m freqMap facts invs
  | x > 1     = (x * y) `mod` m
  | otherwise = y
  where
    calcCnt freqMap = cr - cl
      where
        cl = (A.!) freqMap (l-1)
        cr = (A.!) freqMap r
    f1 acc cs
      | even cnt  = acc
      | otherwise = acc + 1
      where
        cnt = calcCnt cs
    f2 (acc1,acc2) cs
      | cnt < 2   = (acc1 ,acc2)
      | otherwise = (acc1',acc2')
      where
        cnt   = calcCnt cs
        n     = cnt `div` 2
        acc1' = acc1 + n
        r     = choose acc1' n
        acc2' = (acc2 * r) `mod` m
    -- calc binomial coefficient using Fermat's little theorem
    choose n k
      | n < k     = 0
      | otherwise = (f1 * t) `mod` m
      where
        f1 = (A.!) facts n
        i1 = (A.!) invs k
        i2 = (A.!) invs (n-k)
        t  = (i1 * i2) `mod` m
    x = Map.foldl' f1 0 freqMap
    y = snd $ Map.foldl' f2 (0,1) freqMap
main :: IO()
main = do
  inp <- C.getLine
  q <- readLn :: IO Int
  let modulo = 1000000007
  let facts = compFactorials (C.length inp) modulo
  let invs = compInvs (C.length inp) modulo facts
  let freqMap = initFreqMap inp
  replicateM_ q $ do
    line <- getLine
    let [s1, s2] = words line
    let l = (read s1) :: Int
    let r = (read s2) :: Int
    let result = query l r modulo freqMap facts invs
    putStrLn $ show result
I think you've shot yourself in the foot by trying to be too clever. Below I'll show a straightforward implementation of a slightly different algorithm that is about 5x faster than your Haskell code.
Here's the core combinatoric computation. Given a character frequency count for a substring, we can compute the number of maximum-length palindromes this way:
1. Divide all the frequencies by two, rounding down; call this the div2-frequencies. We'll also want the mod2-frequencies, which is the set of letters for which we had to round down.
2. Sum the div2-frequencies to get the total length of the palindrome prefix; its factorial gives an overcount of the number of possible prefixes for the palindrome.
3. Take the product of the factorials of the div2-frequencies. This tells the factor by which we overcounted above.
4. Take the size of the mod2-frequencies, or choose 1 if there are none. We can extend any of the palindrome prefixes by one of the values in this set, if there are any, so we have to multiply by this size.
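The steps above can be sketched directly on a string. This sketch assumes exact Integer arithmetic instead of arithmetic modulo 1000000007 and recomputes factorials on every call, so it only illustrates the counting argument; the names factorial and maxPalindromes are illustrative.

```haskell
import Data.List (group, sort)

factorial :: Integer -> Integer
factorial n = product [1 .. n]

-- Number of maximum-length palindromes that can be built from the letters of s.
maxPalindromes :: String -> Integer
maxPalindromes s =
  middles * factorial (sum halves) `div` product (map factorial halves)
  where
    freqs   = map (fromIntegral . length) (group (sort s)) :: [Integer]
    halves  = map (`div` 2) freqs                        -- div2-frequencies
    odds    = fromIntegral (length (filter odd freqs))   -- size of the mod2-frequencies
    middles = max 1 odds                                  -- choices for the middle letter
```

For "week" the frequencies are e: 2, k: 1, w: 1, so there is one prefix ("e") and two possible middle letters, giving 2 palindromes ("ewe" and "eke").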
For the overcounting step, it's not super obvious to me whether it would be faster to store precomputed inverses for factorials, and take their product, or whether it's faster to just take the product of all the factorials and do one inverse operation at the very end. I'll do the latter, because it just intuitively seems faster to do one inversion per query than one lookup per repeated letter, but what do I know? Should be easy to test if you want to try to adapt the code yourself.
There's only one other quick insight I had vs. your code, which is that we can cache the frequency counts for prefixes of the input; then computing the frequency count for a substring is just pointwise subtraction of two cached counts. Your precomputation on the input I find to be a bit excessive in comparison.
Without further ado, let's see some code. As usual there's some preamble.
module Main where
import Control.Monad
import Data.Array (Array)
import qualified Data.Array as A
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Monoid
Like you, I want to do all my computations on cheap Ints and bake in the modular operations where possible. I'll make a newtype to make sure this happens for me.
newtype Mod1000000007 = Mod Int deriving (Eq, Ord)
instance Num Mod1000000007 where
    fromInteger = Mod . (`mod` 1000000007) . fromInteger
    Mod l + Mod r = Mod ((l+r) `rem` 1000000007)
    Mod l * Mod r = Mod ((l*r) `rem` 1000000007)
    negate (Mod v) = Mod ((1000000007 - v) `rem` 1000000007)
    abs = id
    signum = id
instance Integral Mod1000000007 where
    toInteger (Mod n) = toInteger n
    quotRem a b = (a * b^1000000005, 0)
I baked in the base of 1000000007 in several places, but it's easy to generalize by giving Mod a phantom parameter and making a HasBase class to pick the base. Ask a fresh question if you're not sure how and are interested; I'll be happy to do a more thorough writeup. There's a few more instances for Mod that are basically uninteresting and primarily needed because of Haskell's wacko numeric class hierarchy:
instance Show Mod1000000007 where show (Mod n) = show n
instance Real Mod1000000007 where toRational (Mod n) = toRational n
instance Enum Mod1000000007 where
    toEnum = Mod . (`mod` 1000000007)
    fromEnum (Mod n) = n
Here's the precomputation we want to do for factorials...
type FactMap = Array Int Mod1000000007
factMap :: Int -> FactMap
factMap n = A.listArray (0,n) (scanl (*) 1 [1..])
...and for precomputing frequency maps for each prefix, plus getting a frequency map given a start and end point.
type FreqMap = Map Char Int
freqMaps :: String -> Array Int FreqMap
freqMaps s = go where
    go = A.listArray (0, length s)
        (M.empty : [M.insertWith (+) c 1 (go A.! i) | (i, c) <- zip [0..] s])
substringFreqMap :: Array Int FreqMap -> Int -> Int -> FreqMap
substringFreqMap maps l r = M.unionWith (-) (maps A.! r) (maps A.! (l-1))
Implementing the core computation described above is just a few lines of code, now that we have suitable Num and Integral instances for Mod1000000007:
palindromeCount :: FactMap -> FreqMap -> Mod1000000007
palindromeCount facts freqs
    =     toEnum (max 1 mod2Freqs)
    *     (facts A.! sum div2Freqs)
    `div` product (map (facts A.!) div2Freqs)
    where
    (div2Freqs, Sum mod2Freqs) = foldMap (\n -> ([n `quot` 2], Sum (n `rem` 2))) freqs
Now we just need a short driver to read stuff and pass it around to the appropriate functions.
main :: IO ()
main = do
    inp <- getLine
    q <- readLn
    let freqs = freqMaps inp
        facts = factMap (length inp)
    replicateM_ q $ do
        [l,r] <- map read . words <$> getLine
        print . palindromeCount facts $ substringFreqMap freqs l r
That's it. Notably I made no attempt to be fancy about bitwise operations and didn't do anything fancy with accumulators; everything is in what I would consider idiomatic purely-functional style. The final count is about half as much code that runs about 5x faster.
P.S. Just for fun, I replaced the last line with print (l+r :: Int)... and discovered that about half the time is spent in read. Ouch! Seems there's still plenty of low-hanging fruit if this isn't fast enough yet.
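One common way to avoid read on String is to parse the query lines with Data.ByteString.Char8.readInt. The following is a sketch under that assumption; parseInts and readQuery are illustrative names, and whether this recovers the time observed above has not been measured here.

```haskell
import qualified Data.ByteString.Char8 as C
import Data.Maybe (mapMaybe)

-- Parse the whitespace-separated integers on a line,
-- skipping tokens that do not start with a number.
parseInts :: C.ByteString -> [Int]
parseInts = map fst . mapMaybe C.readInt . C.words

-- Read one query line of the form "l r".
readQuery :: IO (Int, Int)
readQuery = do
  line <- C.getLine
  case parseInts line of
    [l, r] -> pure (l, r)
    _      -> ioError (userError "expected two integers")
```

In the driver, readQuery would take the place of the line that pattern matches on map read . words <$> getLine.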

Why is head-tail pattern matching so much faster than indexing?

I was working on a HackerRank problem today and initially wrote it with indexing, and it was incredibly slow for most of the test cases because they were huge. I then decided to switch it to head:tail pattern matching and it just zoomed. The difference was absolutely night and day, but I can't figure out why it made such a difference in efficiency. Here is the code for reference, in case it is useful.
Most efficient attempt with indexing
count :: Eq a => Integral b => a -> [a] -> b
count e [] = 0
count e (a:xs) = (count e xs +) $ if a == e then 1 else 0
fullCheck :: String -> Bool
fullCheck a = prefixCheck 0 (0,0,0,0) a (length a) && (count 'R' a == count 'G' a) && (count 'Y' a == count 'B' a)
prefixCheck :: Int -> (Int, Int, Int, Int) -> String -> Int -> Bool
prefixCheck n (r',g',y',b') s l
  | n == l = True
  | otherwise =
      ((<= 1) $ abs $ r - g) && ((<= 1) $ abs $ y - b)
      && prefixCheck (n+1) (r,g,y,b) s l
  where c = s !! n
        r = if c == 'R' then r' + 1 else r'
        g = if c == 'G' then g' + 1 else g'
        y = if c == 'Y' then y' + 1 else y'
        b = if c == 'B' then b' + 1 else b'
run :: Int -> IO ()
run 0 = putStr ""
run n = do
  a <- getLine
  print $ fullCheck a
  run $ n - 1
main :: IO ()
main = do
  b <- getLine
  run $ read b
head:tail pattern matching attempt
count :: Eq a => Integral b => a -> [a] -> b
count e [] = 0
count e (a:xs) = (count e xs +) $ if a == e then 1 else 0
fullCheck :: String -> Bool
fullCheck a = prefixCheck (0,0,0,0) a && (count 'R' a == count 'G' a) && (count 'Y' a == count 'B' a)
prefixCheck :: (Int, Int, Int, Int) -> String -> Bool
prefixCheck (r,g,y,b) [] = r == g && y == b
prefixCheck (r',g',y',b') (h:s) = ((<= 1) $ abs $ r - g) && ((<= 1) $ abs $ y - b)
                                  && prefixCheck (r,g,y,b) s
  where r = if h == 'R' then r' + 1 else r'
        g = if h == 'G' then g' + 1 else g'
        y = if h == 'Y' then y' + 1 else y'
        b = if h == 'B' then b' + 1 else b'
run :: Int -> IO ()
run 0 = putStr ""
run n = do
  a <- getLine
  print $ fullCheck a
  run $ n - 1
main :: IO ()
main = do
  b <- getLine
  run $ read b
For reference as well, the question was
You are given a sequence of N balls in 4 colors: red, green, yellow and blue. The sequence is full of colors if and only if all of the following conditions are true:
There are as many red balls as green balls.
There are as many yellow balls as blue balls.
Difference between the number of red balls and green balls in every prefix of the sequence is at most 1.
Difference between the number of yellow balls and blue balls in every prefix of the sequence is at most 1.
Where a prefix of a string is any substring from the beginning to m where m is less than the size of the string
You have already got the answer in the comments as to why list indexing takes time linear in the index. But if you are interested in a more Haskell-style solution to the HackerRank problem you're referring to, even head-tail pattern matching is unnecessary. A more performant solution can be written with a right fold:
import Control.Applicative ((<$>))
import Control.Monad (replicateM_)
solve :: String -> Bool
solve s = foldr go (\r g y b -> r == g && y == b) s 0 0 0 0
  where
    go x run r g y b
      | 1 < abs (r - g) || 1 < abs (y - b) = False
      | x == 'R' = run (r + 1) g y b
      | x == 'G' = run r (g + 1) y b
      | x == 'Y' = run r g (y + 1) b
      | x == 'B' = run r g y (b + 1)
main :: IO ()
main = do
  n <- read <$> getLine
  replicateM_ n $ getLine >>= print . solve
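To make the linear cost of list indexing concrete, here is a small sketch. The function index is a simplified stand-in for (!!), not the library definition, and the two counting functions are illustrative names that tally how many list cells each approach walks for a list of length n.

```haskell
-- A simplified definition of list indexing: reaching position i walks past i cells.
index :: [a] -> Int -> a
index (x : _) 0 = x
index (_ : xs) i
  | i > 0 = index xs (i - 1)
index _ _ = error "index out of range"

-- Cells visited when every position 0 .. n-1 is fetched by indexing from the head.
cellsByIndexing :: Int -> Int
cellsByIndexing n = sum [i + 1 | i <- [0 .. n - 1]]

-- Cells visited when the list is consumed once with (h:s) pattern matching.
cellsByPatternMatching :: Int -> Int
cellsByPatternMatching n = n
```

For a line of 100000 characters this is roughly 5 billion cell visits with indexing against 100000 with pattern matching, which matches the night-and-day difference described in the question.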

Why is the "better" digit-listing function slower?

I was playing around with Project Euler #34, and I wrote these functions:
import Data.Time.Clock.POSIX
import Data.Char
digits :: (Integral a) => a -> [Int]
digits x
  | x < 10 = [fromIntegral x]
  | otherwise = let (q, r) = x `quotRem` 10 in (fromIntegral r) : (digits q)
digitsByShow :: (Integral a, Show a) => a -> [Int]
digitsByShow = map (\x -> ord x - ord '0') . show
I thought that for sure digits has to be the faster one, as we don't convert to a String. I could not have been more wrong. I ran the two versions via pe034:
pe034 digitFunc = sum $ filter sumFactDigit [3..2540160]
  where
    sumFactDigit :: Int -> Bool
    sumFactDigit n = n == (sum $ map sFact $ digitFunc n)
    sFact :: Int -> Int
    sFact n
      | n == 0 = 1
      | n == 1 = 1
      | n == 2 = 2
      | n == 3 = 6
      | n == 4 = 24
      | n == 5 = 120
      | n == 6 = 720
      | n == 7 = 5040
      | n == 8 = 40320
      | n == 9 = 362880
main = do
  begin <- getPOSIXTime
  print $ pe034 digitsByShow -- or digits
  end <- getPOSIXTime
  print $ end - begin
After compiling with ghc -O, digits consistently takes .5 seconds, while digitsByShow consistently takes .3 seconds. Why is this so? Why is the function which stays within Integer arithmetic slower, whereas the function which goes into string comparison is faster?
I ask this because I come from programming in Java and similar languages, where the % 10 trick of generating digits is way faster than the "convert to String" method. I haven't been able to wrap my head around the fact that converting to a string could be faster.
This is the best I can come up with.
digitsV2 :: (Integral a) => a -> [Int]
digitsV2 n = go n []
  where
    go x xs
      | x < 10 = fromIntegral x : xs
      | otherwise = case quotRem x 10 of
          (q,r) -> go q (fromIntegral r : xs)
when compiled with -O2 and tested with Criterion
digits runs in 470.4 ms
digitsByShow runs in 421.8 ms
digitsV2 runs in 258.0 ms
results may vary
edit:
I am not sure why building the list like this helps so much.
But you can improve your code's speed by strictly evaluating quotRem x 10.
You can do this with a bang pattern (this needs the BangPatterns language extension)
  | otherwise = let !(q, r) = x `quotRem` 10 in (fromIntegral r) : (digits q)
or with case
  | otherwise = case quotRem x 10 of
      (q,r) -> fromIntegral r : digits q
Doing this drops digits down to 323.5 ms.
edit: time without using Criterion
digits = 464.3 ms
digitsStrict = 328.2 ms
digitsByShow = 259.2 ms
digitV2 = 252.5 ms
note: The criterion package measures software performance.
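One detail worth checking when swapping these functions for one another: digits yields the least-significant digit first, while digitsV2 and digitsByShow yield the most-significant digit first. That does not matter for pe034, which only sums over the digits, but it would for order-sensitive uses. A small sketch, assuming non-negative inputs (the helper agree is an illustrative name):

```haskell
import Data.Char (ord)

digits :: Integral a => a -> [Int]
digits x
  | x < 10    = [fromIntegral x]
  | otherwise = case x `quotRem` 10 of
      (q, r) -> fromIntegral r : digits q

digitsV2 :: Integral a => a -> [Int]
digitsV2 n = go n []
  where
    go x xs
      | x < 10    = fromIntegral x : xs
      | otherwise = case x `quotRem` 10 of
          (q, r) -> go q (fromIntegral r : xs)

digitsByShow :: (Integral a, Show a) => a -> [Int]
digitsByShow = map (\c -> ord c - ord '0') . show

-- True when all three functions agree once the digit order is normalised.
agree :: Int -> Bool
agree n = reverse (digits n) == digitsV2 n && digitsV2 n == digitsByShow n
```

Here digits 123 is [3,2,1], whereas digitsV2 123 and digitsByShow 123 are both [1,2,3].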
Let's investigate why #No_signal's solution is faster.
I made three runs of ghc:
ghc -O2 -ddump-simpl digits.hs >digits.txt
ghc -O2 -ddump-simpl digitsV2.hs >digitsV2.txt
ghc -O2 -ddump-simpl show.hs >show.txt
digits.hs
digits :: (Integral a) => a -> [Int]
digits x
  | x < 10 = [fromIntegral x]
  | otherwise = let (q, r) = x `quotRem` 10 in (fromIntegral r) : (digits q)
main = return $ digits 1
digitsV2.hs
digitsV2 :: (Integral a) => a -> [Int]
digitsV2 n = go n []
  where
    go x xs
      | x < 10 = fromIntegral x : xs
      | otherwise = let (q, r) = x `quotRem` 10 in go q (fromIntegral r : xs)
main = return $ digitsV2 1
show.hs
import Data.Char
digitsByShow :: (Integral a, Show a) => a -> [Int]
digitsByShow = map (\x -> ord x - ord '0') . show
main = return $ digitsByShow 1
If you'd like to view the complete txt files, I placed them on ideone (rather than paste a 10000 char dump here):
digits.txt
digitsV2.txt
show.txt
If we carefully look through digits.txt, it appears that this is the relevant section:
lvl_r1qU = __integer 10
Rec {
Main.$w$sdigits [InlPrag=[0], Occ=LoopBreaker]
:: Integer -> (# Int, [Int] #)
[GblId, Arity=1, Str=DmdType <S,U>]
Main.$w$sdigits =
\ (w_s1pI :: Integer) ->
case integer-gmp-1.0.0.0:GHC.Integer.Type.ltInteger#
w_s1pI lvl_r1qU
of wild_a17q { __DEFAULT ->
case GHC.Prim.tagToEnum# # Bool wild_a17q of _ [Occ=Dead] {
False ->
let {
ds_s16Q [Dmd=<L,U(U,U)>] :: (Integer, Integer)
[LclId, Str=DmdType]
ds_s16Q =
case integer-gmp-1.0.0.0:GHC.Integer.Type.quotRemInteger
w_s1pI lvl_r1qU
of _ [Occ=Dead] { (# ipv_a17D, ipv1_a17E #) ->
(ipv_a17D, ipv1_a17E)
} } in
(# case ds_s16Q of _ [Occ=Dead] { (q_a11V, r_X12h) ->
case integer-gmp-1.0.0.0:GHC.Integer.Type.integerToInt r_X12h
of wild3_a17c { __DEFAULT ->
GHC.Types.I# wild3_a17c
}
},
case ds_s16Q of _ [Occ=Dead] { (q_X12h, r_X129) ->
case Main.$w$sdigits q_X12h
of _ [Occ=Dead] { (# ww1_s1pO, ww2_s1pP #) ->
GHC.Types.: # Int ww1_s1pO ww2_s1pP
}
} #);
True ->
(# GHC.Num.$fNumInt_$cfromInteger w_s1pI, GHC.Types.[] # Int #)
}
}
end Rec }
digitsV2.txt:
lvl_r1xl = __integer 10
Rec {
Main.$wgo [InlPrag=[0], Occ=LoopBreaker]
:: Integer -> [Int] -> (# Int, [Int] #)
[GblId, Arity=2, Str=DmdType <S,U><L,U>]
Main.$wgo =
\ (w_s1wh :: Integer) (w1_s1wi :: [Int]) ->
case integer-gmp-1.0.0.0:GHC.Integer.Type.ltInteger#
w_s1wh lvl_r1xl
of wild_a1dp { __DEFAULT ->
case GHC.Prim.tagToEnum# # Bool wild_a1dp of _ [Occ=Dead] {
False ->
case integer-gmp-1.0.0.0:GHC.Integer.Type.quotRemInteger
w_s1wh lvl_r1xl
of _ [Occ=Dead] { (# ipv_a1dB, ipv1_a1dC #) ->
Main.$wgo
ipv_a1dB
(GHC.Types.:
# Int
(case integer-gmp-1.0.0.0:GHC.Integer.Type.integerToInt ipv1_a1dC
of wild2_a1ea { __DEFAULT ->
GHC.Types.I# wild2_a1ea
})
w1_s1wi)
};
True -> (# GHC.Num.$fNumInt_$cfromInteger w_s1wh, w1_s1wi #)
}
}
end Rec }
I actually couldn't find the relevant section for show.txt. I'll work on that later.
Right off the bat, digitsV2.hs produces shorter code. That's probably a good sign for it.
digits.hs seems to be following this pseudocode:
def digits(w_s1pI):
    if w_s1pI < 10: return [fromInteger(w_s1pI)]
    else:
        ds_s16Q = quotRem(w_s1pI, 10)
        q_X12h = ds_s16Q[0]
        r_X12h = ds_s16Q[1]
        wild3_a17c = integerToInt(r_X12h)
        ww1_s1pO = wild3_a17c
        ww2_s1pP = digits(q_X12h)
        ww2_s1pP.pushFront(ww1_s1pO)
        return ww2_s1pP
digitsV2.hs seems to be following this pseudocode:
def digitsV2(w_s1wh, w1_s1wi=[]): # actually disguised as go(), as #No_signal wrote
    if w_s1wh < 10:
        w1_s1wi.pushFront(fromInteger(w_s1wh))
        return w1_s1wi
    else:
        ipv_a1dB, ipv1_a1dC = quotRem(w_s1wh, 10)
        w1_s1wi.pushFront(integerToInt(ipv1_a1dC))
        return digitsV2(ipv_a1dB, w1_s1wi)
It might not be that these functions mutate lists like my pseudocode suggests, but this immediately suggests something: it looks as if digitsV2 is fully tail-recursive, whereas digits is actually not (may have to use some Haskell trampoline or something). It appears as if Haskell needs to store all the remainders in digits before pushing them all to the front of the list, whereas it can just push them and forget about them in digitsV2. This is purely speculation, but it is well-founded speculation.

Haskell performance of memoization implement of dynamic programming is poor [duplicate]

I'm trying to memoize the following function:
gridwalk x y
  | x == 0 = 1
  | y == 0 = 1
  | otherwise = (gridwalk (x - 1) y) + (gridwalk x (y - 1))
Looking at this I came up with the following solution:
gw :: (Int -> Int -> Int) -> Int -> Int -> Int
gw f x y
  | x == 0 = 1
  | y == 0 = 1
  | otherwise = (f (x - 1) y) + (f x (y - 1))
gwlist :: [Int]
gwlist = map (\i -> gw fastgw (i `mod` 20) (i `div` 20)) [0..]
fastgw :: Int -> Int -> Int
fastgw x y = gwlist !! (x + y * 20)
Which I then can call like this:
gw fastgw 20 20
Is there an easier, more concise and general way (notice how I had to hardcode the max grid dimensions in the gwlist function in order to convert from 2D to 1D space so I can access the memoizing list) to memoize functions with multiple parameters in Haskell?
You can use a list of lists to memoize the function result for both parameters:
memo :: (Int -> Int -> a) -> [[a]]
memo f = map (\x -> map (f x) [0..]) [0..]
gw :: Int -> Int -> Int
gw 0 _ = 1
gw _ 0 = 1
gw x y = (fastgw (x - 1) y) + (fastgw x (y - 1))
gwstore :: [[Int]]
gwstore = memo gw
fastgw :: Int -> Int -> Int
fastgw x y = gwstore !! x !! y
Use the data-memocombinators package from Hackage. It provides memoization combinators that are easy to use and concise:
import Data.MemoCombinators (memo2,integral)
gridwalk = memo2 integral integral gridwalk' where
  gridwalk' x y
    | x == 0 = 1
    | y == 0 = 1
    | otherwise = (gridwalk (x - 1) y) + (gridwalk x (y - 1))
Here is a version using Data.MemoTrie from the MemoTrie package to memoize the function:
import Data.MemoTrie(memo2)
gridwalk :: Int -> Int -> Int
gridwalk = memo2 gw
  where
    gw 0 _ = 1
    gw _ 0 = 1
    gw x y = gridwalk (x - 1) y + gridwalk x (y - 1)
If you want maximum generality, you can memoize a memoizing function.
memo :: (Num a, Enum a) => (a -> b) -> [b]
memo f = map f (enumFrom 0)
gwvals = fmap memo (memo gw)
fastgw :: Int -> Int -> Int
fastgw x y = gwvals !! x !! y
This technique will work with functions that have any number of arguments.
Edit: thanks to Philip K. for pointing out a bug in the original code. Originally memo had a "Bounded" constraint instead of "Num" and began the enumeration at minBound, which would only be valid for natural numbers.
Lists aren't a good data structure for memoizing, though, because they have linear lookup complexity. You might be better off with a Map or IntMap. Or look on Hackage.
Note that this particular code does rely on laziness, so if you wanted to switch to using a Map you would need to take a bounded amount of elements from the list, as in:
gwByMap :: Int -> Int -> Int -> Int -> Int
gwByMap maxX maxY x y = fromMaybe (gw x y) $ M.lookup (x,y) memomap
  where
    memomap = M.fromList $ concat [[((x',y'),z) | (y',z) <- zip [0..maxY] ys]
                                  | (x',ys) <- zip [0..maxX] gwvals]
fastgw2 :: Int -> Int -> Int
fastgw2 = gwByMap 20 20
Here fromMaybe comes from Data.Maybe and M is Data.Map imported qualified. I think GHC may be stupid about sharing memomap in this case; you may need to lift out the x and y parameters, like this:
gwByMap maxX maxY = \x y -> fromMaybe (gw x y) $ M.lookup (x,y) memomap
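Another option with constant-time lookup is a lazy boxed array whose cells refer back to the array itself. This is a sketch assuming the array package's Data.Array and known grid bounds; gridwalkArray is an illustrative name, and the table is rebuilt on every call rather than shared between calls.

```haskell
import Data.Array (Array, listArray, range, (!))

-- Number of monotone grid walks to (maxX, maxY), memoized in a lazy array.
gridwalkArray :: Int -> Int -> Int
gridwalkArray maxX maxY = table ! (maxX, maxY)
  where
    bnds = ((0, 0), (maxX, maxY))
    -- Each cell is a thunk that is evaluated at most once.
    table :: Array (Int, Int) Int
    table = listArray bnds [cell x y | (x, y) <- range bnds]
    cell 0 _ = 1
    cell _ 0 = 1
    cell x y = table ! (x - 1, y) + table ! (x, y - 1)
```

gridwalkArray 20 20 evaluates to 137846528820, the central binomial coefficient C(40,20), assuming a 64-bit Int.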

Resources