ICFPC 2006 task on Haskell is too slow - performance

I am learning Haskell, and as a task for myself I was trying to implement a Universal Machine from ICFP Contest 2006. I came up with a code, which, at a first glance, seems to be working. However, when I try to execute any application for that Universal Machine provided on the contest's website (e.g. sandmark.umz), my implementation is too slow to actually run anything. Self-check did not finish in a couple of hours, and I had to kill the process. So, I am clearly doing something wrong, I just don't know what.
I have tried to use Haskell's profiler, but I couldn't make any sense out of those numbers as well. Garbage collection doesn't seem to be taking a lot of time (3 seconds out of 173 seconds of a sample). However, total allocated memory during those 173 seconds was almost 6 GB, while the maximum heap size was 13 MB.
Could you help me understand, what is wrong with my code? I know that the amount of code is quite large, but I am not sure how to come up with a minimum reproducible example in my case, when I don't really know what is relevant, and what is not. Thank you.
module Main where
import System.Environment (getArgs)
import System.Exit (exitSuccess)
import System.IO (hPutStrLn, stderr)
import System.IO.Error (catchIOError)
import Control.Monad (when)
import Control.Monad.Loops (iterateM_)
import Data.Array.IO (IOUArray, newArray, newListArray, readArray, writeArray, mapArray)
import Data.Bits
import Data.Binary.Get (getWord32be, runGet, isEmpty, Get)
import Data.Char (chr, ord)
import Data.Word (Word32)
import Data.Maybe (fromJust)
import qualified Data.IntMap.Strict as M
import qualified Data.ByteString.Lazy as B
import qualified Data.IntSet as IntSet
data UMState = UMState {
getFinger :: Word32,
getRegisters :: IOUArray Int Word32,
getArrays :: M.IntMap (IOUArray Word32 Word32),
getFreeIds :: [Int],
getMaxId :: Int,
getCopiedPlatters :: IntSet.IntSet
}
getOperation :: Word32 -> Int
getOperation x = fromIntegral $ (x `shiftR` 28) .&. 15
getRegisterIds :: Word32 -> (Int, Int, Int)
getRegisterIds x = (fromIntegral $ (x `shiftR` 6) .&. 7, fromIntegral $ (x `shiftR` 3) .&. 7, fromIntegral $ x .&. 7)
getOrthography :: Word32 -> (Int, Word32)
getOrthography x = (fromIntegral $ (x `shiftR` 25) .&. 7, x .&. 33554431)
setFinger :: UMState -> Word32 -> UMState
setFinger (UMState {
getFinger = _,
getRegisters = regs,
getArrays = arr,
getFreeIds = fids,
getMaxId = mid,
getCopiedPlatters = cp
}) f' = UMState {
getFinger = f',
getRegisters = regs,
getArrays = arr,
getFreeIds = fids,
getMaxId = mid,
getCopiedPlatters = cp
}
removePlatter :: UMState -> Int -> UMState
removePlatter (UMState {
getFinger = f,
getRegisters = regs,
getArrays = arr,
getFreeIds = fids,
getMaxId = mid,
getCopiedPlatters = cp
}) pid = UMState {
getFinger = f,
getRegisters = regs,
getArrays = M.delete pid arr,
getFreeIds = (pid:fids),
getMaxId = mid,
getCopiedPlatters = cp
}
insertPlatter :: UMState -> Int -> IOUArray Word32 Word32 -> UMState
insertPlatter (UMState {
getFinger = f,
getRegisters = regs,
getArrays = arr,
getFreeIds = fids#(hfid:tfids),
getMaxId = mid,
getCopiedPlatters = cp
}) pid platter = UMState {
getFinger = f,
getRegisters = regs,
getArrays = M.insert pid platter arr,
getFreeIds = if pid == hfid then tfids else fids,
getMaxId = max mid pid,
getCopiedPlatters = cp
}
insertPlatter (UMState {
getFinger = f,
getRegisters = regs,
getArrays = arr,
getFreeIds = [],
getMaxId = mid,
getCopiedPlatters = cp
}) pid platter = UMState {
getFinger = f,
getRegisters = regs,
getArrays = M.insert pid platter arr,
getFreeIds = [],
getMaxId = max mid pid,
getCopiedPlatters = cp
}
setCopiedPlatters :: UMState -> IntSet.IntSet -> UMState
setCopiedPlatters (UMState {
getFinger = f,
getRegisters = regs,
getArrays = arr,
getFreeIds = fids,
getMaxId = mid,
getCopiedPlatters = _
}) copied' = UMState {
getFinger = f,
getRegisters = regs,
getArrays = arr,
getFreeIds = fids,
getMaxId = mid,
getCopiedPlatters = copied'
}
main = do
args <- getArgs
fileName <- parseArgs $ filter (\arg -> arg /= "--") args
platters <- B.readFile fileName
array0 <- listToArray (runGet readPlatters platters)
regs <- (newArray (0, 7) 0 :: IO (IOUArray Int Word32))
let initState = (UMState {
getFinger = 0,
getRegisters = regs,
getArrays = M.insert 0 array0 M.empty,
getFreeIds = [],
getMaxId = 0,
getCopiedPlatters = IntSet.empty
})
in iterateM_ spinCycle initState
parseArgs :: [String] -> IO (String)
parseArgs [arg] = return arg
parseArgs args = fail $ "Exactly one argument expected. Found: " ++ (show args)
readPlatters :: Get [Word32]
readPlatters = do
empty <- isEmpty
if empty
then return []
else do
platter <- getWord32be
theRest <- readPlatters
return (platter:theRest)
listToArray :: [Word32] -> IO (IOUArray Word32 Word32)
listToArray lst = newListArray (fromIntegral 0, fromIntegral (length lst) - 1) lst
spinCycle :: UMState -> IO (UMState)
spinCycle state = do
platter <- readArray (fromJust (M.lookup 0 (getArrays state))) (getFinger state)
let state' = setFinger state $ getFinger state + 1
(aId, bId, cId) = getRegisterIds platter
regs = getRegisters state'
arrays = getArrays state' in (
case (getOperation platter) of
0 -> do
runConditionalMove aId bId cId regs
return state'
1 -> do
runArrayIndex aId bId cId regs arrays
return state'
2 -> runArrayAmendment aId bId cId state'
3 -> do
runAddition aId bId cId regs
return state'
4 -> do
runMultiplication aId bId cId regs
return state'
5 -> do
runDivision aId bId cId regs
return state'
6 -> do
runNand aId bId cId regs
return state'
7 -> runHalt
8 -> runAllocation bId cId state'
9 -> runAbandonment cId state'
10 -> do
runOutput cId regs
return state'
11 -> do
runInput cId regs
return state'
12 -> runLoadProgram bId cId state'
13 -> do
let (reg, val) = getOrthography platter
in (runOrthography reg val regs)
return state'
)
-- #0. Conditional Move.
runConditionalMove :: Int -> Int -> Int -> IOUArray Int Word32 -> IO ()
runConditionalMove a b c regs = do
hPutStrLn stderr ("conditionalMove " ++ (show a) ++ " " ++ (show b) ++ " " ++ (show c))
cRead <- readArray regs c
when (cRead /= 0) $ do
bRead <- readArray regs b
writeArray regs a bRead
-- #1. Array Index.
runArrayIndex :: Int -> Int -> Int -> IOUArray Int Word32 -> M.IntMap (IOUArray Word32 Word32) -> IO ()
runArrayIndex a b c regs arrays = do
hPutStrLn stderr ("arrayIndex " ++ (show a) ++ " " ++ (show b) ++ " " ++ (show c))
bRead <- readArray regs b
cRead <- readArray regs c
val <- readArray (fromJust (M.lookup (fromIntegral bRead) arrays)) cRead
writeArray regs a val
-- #2. Array Amendment.
runArrayAmendment :: Int -> Int -> Int -> UMState -> IO (UMState)
runArrayAmendment a b c state = do
hPutStrLn stderr ("arrayAmendment " ++ (show a) ++ " " ++ (show b) ++ " " ++ (show c))
aRead <- readArray (getRegisters state) a
bRead <- readArray (getRegisters state) b
cRead <- readArray (getRegisters state) c
stateToWrite <- if IntSet.member (fromIntegral aRead) (getCopiedPlatters state) then (do
pCopy <- mapArray id (fromJust (M.lookup (fromIntegral aRead) (getArrays state)))
let state' = insertPlatter state (fromIntegral aRead) pCopy
state'' = setCopiedPlatters state' $ IntSet.delete (fromIntegral aRead) (getCopiedPlatters state')
in return state''
) else return state
writeArray (fromJust (M.lookup (fromIntegral aRead) (getArrays stateToWrite))) bRead cRead
return stateToWrite
-- #3. Addition.
runAddition :: Int -> Int -> Int -> IOUArray Int Word32 -> IO ()
runAddition a b c regs = do
hPutStrLn stderr ("addition " ++ (show a) ++ " " ++ (show b) ++ " " ++ (show c))
bRead <- readArray regs b
cRead <- readArray regs c
writeArray regs a (bRead + cRead)
-- #4. Multiplication.
runMultiplication :: Int -> Int -> Int -> IOUArray Int Word32 -> IO ()
runMultiplication a b c regs = do
hPutStrLn stderr ("multiplication " ++ (show a) ++ " " ++ (show b) ++ " " ++ (show c))
bRead <- readArray regs b
cRead <- readArray regs c
writeArray regs a (bRead * cRead)
-- #5. Division.
runDivision :: Int -> Int -> Int -> IOUArray Int Word32 -> IO ()
runDivision a b c regs = do
hPutStrLn stderr ("division " ++ (show a) ++ " " ++ (show b) ++ " " ++ (show c))
bRead <- readArray regs b
cRead <- readArray regs c
writeArray regs a (bRead `div` cRead)
-- #6. Not-And.
runNand :: Int -> Int -> Int -> IOUArray Int Word32 -> IO ()
runNand a b c regs = do
hPutStrLn stderr ("nand " ++ (show a) ++ " " ++ (show b) ++ " " ++ (show c))
bRead <- readArray regs b
cRead <- readArray regs c
writeArray regs a (complement $ bRead .&. cRead)
-- #7. Halt.
runHalt = exitSuccess :: IO (UMState)
-- #8. Allocation.
runAllocation :: Int -> Int -> UMState -> IO (UMState)
runAllocation b c state = do
hPutStrLn stderr ("allocation " ++ (show b) ++ " " ++ (show c))
cRead <- readArray (getRegisters state) c
pArray <- (newArray (0, cRead) 0 :: IO (IOUArray Word32 Word32))
(state', newId) <-
case (getFreeIds state) of
(freeId:_) -> return (insertPlatter state freeId pArray, freeId)
[] -> let maxId' = getMaxId state + 1 in return (insertPlatter state maxId' pArray, maxId')
writeArray (getRegisters state') b (fromIntegral newId)
return state'
-- #9. Abandonment.
runAbandonment :: Int -> UMState -> IO (UMState)
runAbandonment c state = do
hPutStrLn stderr ("abandonment " ++ (show c))
cRead <- readArray (getRegisters state) c
return (removePlatter state $ fromIntegral cRead)
-- #10. Output.
runOutput :: Int -> IOUArray Int Word32 -> IO ()
runOutput c regs = do
cRead <- readArray regs c
when (cRead < 256) $ putChar . chr . fromIntegral $ cRead
-- #11. Input.
runInput :: Int -> IOUArray Int Word32 -> IO ()
runInput c regs = do
cRead <- getChar `catchIOError` (\_ -> return $ chr 255)
writeArray regs c (fromIntegral $ ord cRead)
-- #12. Load Program.
runLoadProgram :: Int -> Int -> UMState -> IO (UMState)
runLoadProgram b c state = do
hPutStrLn stderr ("loadProgram " ++ (show b) ++ " " ++ (show c))
bRead <- readArray (getRegisters state) b
cRead <- readArray (getRegisters state) c
let bReadInt = fromIntegral bRead
pCopy = fromJust (M.lookup bReadInt (getArrays state))
copied = IntSet.insert 0 (getCopiedPlatters state)
copied' = IntSet.insert bReadInt copied
state' = insertPlatter state 0 pCopy
state'' = setFinger state' cRead
state''' = setCopiedPlatters state'' copied'
in return state'''
-- #13. Orthography.
runOrthography :: Int -> Word32 -> IOUArray Int Word32 -> IO ()
runOrthography reg val regs = writeArray regs reg val

Total allocation of 3 gigabytes for a Haskell programming running for 176 seconds is miniscule. Most Haskell programs allocate 3-6 gigabytes per second for their entire runtime. In your case, much of the program is running in tight, allocation-free loops (generally a good thing when you're trying to write a fast program), which may explain the small amount of allocation. The small proportion of time spent garbage collecting is also a good sign.
I tested your program on sandmark.umz andcodex.umz, built with -O2 and no profiling.
I believe the main problem is that the hPutStrLn logging lines are generating tons of output, so your universal machine is spending all its time writing logs.
Comment out all the hPutStrLn lines, and SANDmark prints a line every few seconds. I have no idea how fast it's supposed to be, but it's certainly running.
For Codex, it completes the self-check succeeded in a few seconds and accepts a 32-character key. If you enter the wrong key, it prints "wrong key". If you enter the right key, it prints "decrypting..." At this point, it seems to freeze up, so I suspect your implementation is too slow, but not nearly as slow as you were reporting.
Note that you may find it helpful to turn off buffering on stdin and stdout at the start of main:
main = do
hSetBuffering stdin NoBuffering
hSetBuffering stdout NoBuffering
...
so that the getChar and putChar-based I/O operate immediately. This isn't strictly necessary, but might help avoid apparent lockups that are actually just buffering issues.

Related

Competitive programming using Haskell

I am currently trying to refresh my Haskell knowledge by solving some Hackerrank problems.
For example:
https://www.hackerrank.com/challenges/maximum-palindromes/problem
I've already implemented an imperative solution in C++ which got accepted for all test cases. Now I am trying to come up with a pure functional solution in (reasonably idiomatic) Haskell.
My current code is
module Main where
import Control.Monad
import qualified Data.ByteString.Char8 as C
import Data.Bits
import Data.List
import qualified Data.Map.Strict as Map
import qualified Data.IntMap.Strict as IntMap
import Debug.Trace
-- precompute factorials
compFactorials :: Int -> Int -> IntMap.IntMap Int
compFactorials n m = go 0 1 IntMap.empty
where
go a acc map
| a < 0 = map
| a < n = go a' acc' map'
| otherwise = map'
where
map' = IntMap.insert a acc map
a' = a + 1
acc' = (acc * a') `mod` m
-- precompute invs
compInvs :: Int -> Int -> IntMap.IntMap Int -> IntMap.IntMap Int
compInvs n m facts = go 0 IntMap.empty
where
go a map
| a < 0 = map
| a < n = go a' map'
| otherwise = map'
where
map' = IntMap.insert a v map
a' = a + 1
v = (modExp b (m-2) m) `mod` m
b = (IntMap.!) facts a
modExp :: Int -> Int -> Int -> Int
modExp b e m = go b e 1
where
go b e r
| (.&.) e 1 == 1 = go b' e' r'
| e > 0 = go b' e' r
| otherwise = r
where
r' = (r * b) `mod` m
b' = (b * b) `mod` m
e' = shift e (-1)
-- precompute frequency table
initFreqMap :: C.ByteString -> Map.Map Char (IntMap.IntMap Int)
initFreqMap inp = go 1 map1 map2 inp
where
map1 = Map.fromList $ zip ['a'..'z'] $ repeat 0
map2 = Map.fromList $ zip ['a'..'z'] $ repeat IntMap.empty
go idx m1 m2 inp
| C.null inp = m2
| otherwise = go (idx+1) m1' m2' $ C.tail inp
where
m1' = Map.update (\v -> Just $ v+1) (C.head inp) m1
m2' = foldl' (\m w -> Map.update (\v -> liftM (\c -> IntMap.insert idx c v) $ Map.lookup w m1') w m)
m2 ['a'..'z']
query :: Int -> Int -> Int -> Map.Map Char (IntMap.IntMap Int)
-> IntMap.IntMap Int -> IntMap.IntMap Int -> Int
query l r m freqMap facts invs
| x > 1 = (x * y) `mod` m
| otherwise = y
where
calcCnt cs = cr - cl
where
cl = IntMap.findWithDefault 0 (l-1) cs
cr = IntMap.findWithDefault 0 r cs
f1 acc cs
| even cnt = acc
| otherwise = acc + 1
where
cnt = calcCnt cs
f2 (acc1,acc2) cs
| cnt < 2 = (acc1 ,acc2)
| otherwise = (acc1',acc2')
where
cnt = calcCnt cs
n = cnt `div` 2
acc1' = acc1 + n
r = choose acc1' n
acc2' = (acc2 * r) `mod` m
-- calc binomial coefficient using Fermat's little theorem
choose n k
| n < k = 0
| otherwise = (f1 * t) `mod` m
where
f1 = (IntMap.!) facts n
i1 = (IntMap.!) invs k
i2 = (IntMap.!) invs (n-k)
t = (i1 * i2) `mod` m
x = Map.foldl' f1 0 freqMap
y = snd $ Map.foldl' f2 (0,1) freqMap
main :: IO()
main = do
inp <- C.getLine
q <- readLn :: IO Int
let modulo = 1000000007
let facts = compFactorials (C.length inp) modulo
let invs = compInvs (C.length inp) modulo facts
let freqMap = initFreqMap inp
forM_ [1..q] $ \_ -> do
line <- getLine
let [s1, s2] = words line
let l = (read s1) :: Int
let r = (read s2) :: Int
let result = query l r modulo freqMap facts invs
putStrLn $ show result
It passes all small and medium test cases but I am getting timeout with large test cases.
The key to solve this problem is to precompute some stuff once at the beginning and use them to answer the individual queries efficiently.
Now, my main problem where I need help is:
The initital profiling shows that the lookup operation of the IntMap seems to be the main bottleneck. Is there better alternative to IntMap for memoization? Or should I look at Vector or Array, which I believe will lead to more "ugly" code.
Even in current state, the code doesn't look nice (by functional standards) and as verbose as my C++ solution. Any tips to make it more idiomatic? Other than IntMap usage for memoization, do you spot any other obvious problems which can lead to performance problems?
And is there any good sources, where I can learn how to use Haskell more effectively for competitive programming?
A sample large testcase, where the current code gets timeout:
input.txt
output.txt
For comparison my C++ solution:
#include <vector>
#include <iostream>
#define MOD 1000000007L
long mod_exp(long b, long e) {
long r = 1;
while (e > 0) {
if ((e & 1) == 1) {
r = (r * b) % MOD;
}
b = (b * b) % MOD;
e >>= 1;
}
return r;
}
long n_choose_k(int n, int k, const std::vector<long> &fact_map, const std::vector<long> &inv_map) {
if (n < k) {
return 0;
}
long l1 = fact_map[n];
long l2 = (inv_map[k] * inv_map[n-k]) % MOD;
return (l1 * l2) % MOD;
}
int main() {
std::string s;
int q;
std::cin >> s >> q;
std::vector<std::vector<long>> freq_map;
std::vector<long> fact_map(s.size()+1);
std::vector<long> inv_map(s.size()+1);
for (int i = 0; i < 26; i++) {
freq_map.emplace_back(std::vector<long>(s.size(), 0));
}
std::vector<long> acc_map(26, 0);
for (int i = 0; i < s.size(); i++) {
acc_map[s[i]-'a']++;
for (int j = 0; j < 26; j++) {
freq_map[j][i] = acc_map[j];
}
}
fact_map[0] = 1;
inv_map[0] = 1;
for (int i = 1; i <= s.size(); i++) {
fact_map[i] = (i * fact_map[i-1]) % MOD;
inv_map[i] = mod_exp(fact_map[i], MOD-2) % MOD;
}
while (q--) {
int l, r;
std::cin >> l >> r;
std::vector<long> x(26, 0);
long t = 0;
long acc = 0;
long result = 1;
for (int i = 0; i < 26; i++) {
auto cnt = freq_map[i][r-1] - (l > 1 ? freq_map[i][l-2] : 0);
if (cnt % 2 != 0) {
t++;
}
long n = cnt / 2;
if (n > 0) {
acc += n;
result *= n_choose_k(acc, n, fact_map, inv_map);
result = result % MOD;
}
}
if (t > 0) {
result *= t;
result = result % MOD;
}
std::cout << result << std::endl;
}
}
UPDATE:
DanielWagner's answer has confirmed my suspicion that the main problem in my code was the usage of IntMap for memoization. Replacing IntMap with Array made my code perform similar to DanielWagner's solution.
module Main where
import Control.Monad
import Data.Array (Array)
import qualified Data.Array as A
import qualified Data.ByteString.Char8 as C
import Data.Bits
import Data.List
import Debug.Trace
-- precompute factorials
compFactorials :: Int -> Int -> Array Int Int
compFactorials n m = A.listArray (0,n) $ scanl' f 1 [1..n]
where
f acc a = (acc * a) `mod` m
-- precompute invs
compInvs :: Int -> Int -> Array Int Int -> Array Int Int
compInvs n m facts = A.listArray (0,n) $ map f [0..n]
where
f a = (modExp ((A.!) facts a) (m-2) m) `mod` m
modExp :: Int -> Int -> Int -> Int
modExp b e m = go b e 1
where
go b e r
| (.&.) e 1 == 1 = go b' e' r'
| e > 0 = go b' e' r
| otherwise = r
where
r' = (r * b) `mod` m
b' = (b * b) `mod` m
e' = shift e (-1)
-- precompute frequency table
initFreqMap :: C.ByteString -> Map.Map Char (Array Int Int)
initFreqMap inp = Map.fromList $ map f ['a'..'z']
where
n = C.length inp
f c = (c, A.listArray (0,n) $ scanl' g 0 [0..n-1])
where
g x j
| C.index inp j == c = x+1
| otherwise = x
query :: Int -> Int -> Int -> Map.Map Char (Array Int Int)
-> Array Int Int -> Array Int Int -> Int
query l r m freqMap facts invs
| x > 1 = (x * y) `mod` m
| otherwise = y
where
calcCnt freqMap = cr - cl
where
cl = (A.!) freqMap (l-1)
cr = (A.!) freqMap r
f1 acc cs
| even cnt = acc
| otherwise = acc + 1
where
cnt = calcCnt cs
f2 (acc1,acc2) cs
| cnt < 2 = (acc1 ,acc2)
| otherwise = (acc1',acc2')
where
cnt = calcCnt cs
n = cnt `div` 2
acc1' = acc1 + n
r = choose acc1' n
acc2' = (acc2 * r) `mod` m
-- calc binomial coefficient using Fermat's little theorem
choose n k
| n < k = 0
| otherwise = (f1 * t) `mod` m
where
f1 = (A.!) facts n
i1 = (A.!) invs k
i2 = (A.!) invs (n-k)
t = (i1 * i2) `mod` m
x = Map.foldl' f1 0 freqMap
y = snd $ Map.foldl' f2 (0,1) freqMap
main :: IO()
main = do
inp <- C.getLine
q <- readLn :: IO Int
let modulo = 1000000007
let facts = compFactorials (C.length inp) modulo
let invs = compInvs (C.length inp) modulo facts
let freqMap = initFreqMap inp
replicateM_ q $ do
line <- getLine
let [s1, s2] = words line
let l = (read s1) :: Int
let r = (read s2) :: Int
let result = query l r modulo freqMap facts invs
putStrLn $ show result
I think you've shot yourself in the foot by trying to be too clever. Below I'll show a straightforward implementation of a slightly different algorithm that is about 5x faster than your Haskell code.
Here's the core combinatoric computation. Given a character frequency count for a substring, we can compute the number of maximum-length palindromes this way:
Divide all the frequencies by two, rounding down; call this the div2-frequencies. We'll also want the mod2-frequencies, which is the set of letters for which we had to round down.
Sum the div2-frequencies to get the total length of the palindrome prefix; its factorial gives an overcount of the number of possible prefixes for the palindrome.
Take the product of the factorials of the div2-frequencies. This tells the factor by which we overcounted above.
Take the size of the mod2-frequencies, or choose 1 if there are none. We can extend any of the palindrome prefixes by one of the values in this set, if there are any, so we have to multiply by this size.
For the overcounting step, it's not super obvious to me whether it would be faster to store precomputed inverses for factorials, and take their product, or whether it's faster to just take the product of all the factorials and do one inverse operation at the very end. I'll do the latter, because it just intuitively seems faster to do one inversion per query than one lookup per repeated letter, but what do I know? Should be easy to test if you want to try to adapt the code yourself.
There's only one other quick insight I had vs. your code, which is that we can cache the frequency counts for prefixes of the input; then computing the frequency count for a substring is just pointwise subtraction of two cached counts. Your precomputation on the input I find to be a bit excessive in comparison.
Without further ado, let's see some code. As usual there's some preamble.
module Main where
import Control.Monad
import Data.Array (Array)
import qualified Data.Array as A
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Monoid
Like you, I want to do all my computations on cheap Ints and bake in the modular operations where possible. I'll make a newtype to make sure this happens for me.
newtype Mod1000000007 = Mod Int deriving (Eq, Ord)
instance Num Mod1000000007 where
fromInteger = Mod . (`mod` 1000000007) . fromInteger
Mod l + Mod r = Mod ((l+r) `rem` 1000000007)
Mod l * Mod r = Mod ((l*r) `rem` 1000000007)
negate (Mod v) = Mod ((1000000007 - v) `rem` 1000000007)
abs = id
signum = id
instance Integral Mod1000000007 where
toInteger (Mod n) = toInteger n
quotRem a b = (a * b^1000000005, 0)
I baked in the base of 1000000007 in several places, but it's easy to generalize by giving Mod a phantom parameter and making a HasBase class to pick the base. Ask a fresh question if you're not sure how and are interested; I'll be happy to do a more thorough writeup. There's a few more instances for Mod that are basically uninteresting and primarily needed because of Haskell's wacko numeric class hierarchy:
instance Show Mod1000000007 where show (Mod n) = show n
instance Real Mod1000000007 where toRational (Mod n) = toRational n
instance Enum Mod1000000007 where
toEnum = Mod . (`mod` 1000000007)
fromEnum (Mod n) = n
Here's the precomputation we want to do for factorials...
type FactMap = Array Int Mod1000000007
factMap :: Int -> FactMap
factMap n = A.listArray (0,n) (scanl (*) 1 [1..])
...and for precomputing frequency maps for each prefix, plus getting a frequency map given a start and end point.
type FreqMap = Map Char Int
freqMaps :: String -> Array Int FreqMap
freqMaps s = go where
go = A.listArray (0, length s)
(M.empty : [M.insertWith (+) c 1 (go A.! i) | (i, c) <- zip [0..] s])
substringFreqMap :: Array Int FreqMap -> Int -> Int -> FreqMap
substringFreqMap maps l r = M.unionWith (-) (maps A.! r) (maps A.! (l-1))
Implementing the core computation described above is just a few lines of code, now that we have suitable Num and Integral instances for Mod1000000007:
palindromeCount :: FactMap -> FreqMap -> Mod1000000007
palindromeCount facts freqs
= toEnum (max 1 mod2Freqs)
* (facts A.! sum div2Freqs)
`div` product (map (facts A.!) div2Freqs)
where
(div2Freqs, Sum mod2Freqs) = foldMap (\n -> ([n `quot` 2], Sum (n `rem` 2))) freqs
Now we just need a short driver to read stuff and pass it around to the appropriate functions.
main :: IO ()
main = do
inp <- getLine
q <- readLn
let freqs = freqMaps inp
facts = factMap (length inp)
replicateM_ q $ do
[l,r] <- map read . words <$> getLine
print . palindromeCount facts $ substringFreqMap freqs l r
That's it. Notably I made no attempt to be fancy about bitwise operations and didn't do anything fancy with accumulators; everything is in what I would consider idiomatic purely-functional style. The final count is about half as much code that runs about 5x faster.
P.S. Just for fun, I replaced the last line with print (l+r :: Int)... and discovered that about half the time is spent in read. Ouch! Seems there's still plenty of low-hanging fruit if this isn't fast enough yet.

Optimizing n-queens in Haskell

This code:
{-# LANGUAGE BangPatterns #-}
module Main where
import Data.Bits
import Data.Word
import Control.Monad
import System.CPUTime
import Data.List
-- The Damenproblem.
-- Wiki: https://de.wikipedia.org/wiki/Damenproblem
main :: IO ()
main = do
start <- getCPUTime
print $ dame 14
end <- getCPUTime
print $ "Needed " ++ (show ((fromIntegral (end - start)) / (10^12))) ++ " Seconds"
type BitState = (Word64, Word64, Word64)
dame :: Int -> Int
dame max = foldl' (+) 0 $ map fn row
where fn x = recur (max - 2) $ nextState (x, x, x)
recur !depth !state = foldl' (+) 0 $ flip map row $ getPossible depth (getStateVal state) state
getPossible depth !stateVal state bit
| (bit .&. stateVal) > 0 = 0
| depth == 0 = 1
| otherwise = recur (depth - 1) (nextState (addBitToState bit state))
row = take max $ iterate moveLeft 1
getStateVal :: BitState -> Word64
getStateVal (l, r, c) = l .|. r .|. c
addBitToState :: Word64 -> BitState -> BitState
addBitToState l (ol, or, oc) = (ol .|. l, or .|. l, oc .|. l)
nextState :: BitState -> BitState
nextState (l, r, c) = (moveLeft l, moveRight r, c)
moveRight :: Word64 -> Word64
moveRight x = shiftR x 1
moveLeft :: Word64 -> Word64
moveLeft x = shift x 1
needs about 60 seconds to execute. If i enable compiler optimisation with -O2, it takes about 7 seconds. -O1 is faster and takes about 5 seconds.
Tested a java version of this code, with for-loops in place of mapped lists, it takes about 1s (!). Been trying my hardest to optimize yet none of the tips i found online helped more than half a second. Please help
Edit: Java version:
public class Queens{
static int getQueens(){
int res = 0;
for (int i = 0; i < N; i++) {
int pos = 1 << i;
res += run(pos << 1, pos >> 1, pos, N - 2);
}
return res;
}
static int run(long diagR, long diagL, long mid, int depth) {
long valid = mid | diagL | diagR;
int resBuffer = 0;
for (int i = 0; i < N; i++) {
int pos = 1 << i;
if ((valid & pos) > 0) {
continue;
}
if (depth == 0) {
resBuffer++;
continue;
}
long n_mid = mid | pos;
long n_diagL = (diagL >> 1) | (pos >> 1);
long n_diagR = (diagR << 1) | (pos << 1);
resBuffer += run(n_diagR, n_diagL, n_mid, depth - 1);
}
return resBuffer;
}
}
Edit: Running on windows with ghc 8.4.1 on an i5 650 with 3.2GHz.
Assuming your algorithm is correct (I haven't verified this), I was able to get consistently 900ms (faster than the Java implementation!). -O2 and -O3 were both comparable on my machine.
Notable changes: (EDIT: Most important change: switch from List to Vector) Switched to GHC 8.4.1, used strictness liberally, BitState is now a strict 3-tuple
Using Vectors is important to achieve better speed - in my opinion you can't achieve comparable speed with just linked lists, even with fusion. The Unboxed Vector is important because you know the Vector will always be of Word64s or Ints.
{-# LANGUAGE BangPatterns #-}
module Main (main) where
import Data.Bits ((.&.), (.|.), shiftR, shift)
import Data.Vector.Unboxed (Vector)
import qualified Data.Vector.Unboxed as Vector
import Data.Word (Word64)
import Prelude hiding (max, sum)
import System.CPUTime (getCPUTime)
--
-- The Damenproblem.
-- Wiki: https://de.wikipedia.org/wiki/Damenproblem
main :: IO ()
main = do
start <- getCPUTime
print $ dame 14
end <- getCPUTime
print $ "Needed " ++ (show ((fromIntegral (end - start)) / (10^12))) ++ " Seconds"
data BitState = BitState !Word64 !Word64 !Word64
bmap :: (Word64 -> Word64) -> BitState -> BitState
bmap f (BitState x y z) = BitState (f x) (f y) (f z)
{-# INLINE bmap #-}
bfold :: (Word64 -> Word64 -> Word64) -> BitState -> Word64
bfold f (BitState x y z) = x `f` y `f` z
{-# INLINE bfold #-}
singleton :: Word64 -> BitState
singleton !x = BitState x x x
{-# INLINE singleton #-}
dame :: Int -> Int
dame !x = sumWith fn row
where
fn !x' = recur (x - 2) $ nextState $ singleton x'
getPossible !depth !stateVal !state !bit
| (bit .&. stateVal) > 0 = 0
| depth == 0 = 1
| otherwise = recur (depth - 1) (nextState (addBitToState bit state))
recur !depth !state = sumWith (getPossible depth (getStateVal state) state) row
!row = Vector.iterateN x moveLeft 1
sumWith :: (Vector.Unbox a, Vector.Unbox b, Num b) => (a -> b) -> Vector a -> b
sumWith f as = Vector.sum $ Vector.map f as
{-# INLINE sumWith #-}
getStateVal :: BitState -> Word64
getStateVal !b = bfold (.|.) b
addBitToState :: Word64 -> BitState -> BitState
addBitToState !l !b = bmap (.|. l) b
nextState :: BitState -> BitState
nextState !(BitState l r c) = BitState (moveLeft l) (moveRight r) c
moveRight :: Word64 -> Word64
moveRight !x = shiftR x 1
{-# INLINE moveRight #-}
moveLeft :: Word64 -> Word64
moveLeft !x = shift x 1
{-# INLINE moveLeft #-}
I checked the core with ghc dame.hs -O2 -fforce-recomp -ddump-simpl -dsuppress-all, and it looked pretty good (i.e. everything unboxed, loops looked good). I was concerned that the partial application of getPossible might be a problem, but it turned out to not be. I feel like if I understood the algorithm better it might be possible to write in a better/more efficient way, however I'm not too concerned - this still manages to beat the Java implementation.

Why is the "better" digit-listing function slower?

I was playing around with Project Euler #34, and I wrote these functions:
import Data.Time.Clock.POSIX
import Data.Char
digits :: (Integral a) => a -> [Int]
digits x
| x < 10 = [fromIntegral x]
| otherwise = let (q, r) = x `quotRem` 10 in (fromIntegral r) : (digits q)
digitsByShow :: (Integral a, Show a) => a -> [Int]
digitsByShow = map (\x -> ord x - ord '0') . show
I thought that for sure digits has to be the faster one, as we don't convert to a String. I could not have been more wrong. I ran the two versions via pe034:
pe034 digitFunc = sum $ filter sumFactDigit [3..2540160]
where
sumFactDigit :: Int -> Bool
sumFactDigit n = n == (sum $ map sFact $ digitFunc n)
sFact :: Int -> Int
sFact n
| n == 0 = 1
| n == 1 = 1
| n == 2 = 2
| n == 3 = 6
| n == 4 = 24
| n == 5 = 120
| n == 6 = 720
| n == 7 = 5040
| n == 8 = 40320
| n == 9 = 362880
main = do
begin <- getPOSIXTime
print $ pe034 digitsByShow -- or digits
end <- getPOSIXTime
print $ end - begin
After compiling with ghc -O, digits consistently takes .5 seconds, while digitsByShow consistently takes .3 seconds. Why is this so? Why is the function which stays within Integer arithmetic slower, whereas the function which goes into string comparison is faster?
I ask this because I come from programming in Java and similar languages, where the % 10 trick of generating digits is way faster than the "convert to String" method. I haven't been able to wrap my head around the fact that converting to a string could be faster.
This is the best I can come up with.
digitsV2 :: (Integral a) => a -> [Int]
digitsV2 n = go n []
where
go x xs
| x < 10 = fromIntegral x : xs
| otherwise = case quotRem x 10 of
(q,r) -> go q (fromIntegral r : xs)
when compiled with -O2 and tested with Criterion
digits runs in 470.4 ms
digitsByShow runs in 421.8 ms
digitsV2 runs in 258.0 ms
results may vary
edit:
I am not sure why building the list like this helps so much.
But you can improve your codes speed by strictly evaluating quotRem x 10
You can do this with BangPatterns
| otherwise = let !(q, r) = x `quotRem` 10 in (fromIntegral r) : (digits q)
or with case
| otherwise = case quotRem x 10 of
(q,r) -> fromIntegral r : digits q
Doing this drops digits down to 323.5 ms
edit: time without using Criterion
digits = 464.3 ms
digitsStrict = 328.2 ms
digitsByShow = 259.2 ms
digitV2 = 252.5 ms
note: The criterion package measures software performance.
Let's investigate why #No_signal's solution is faster.
I made three runs of ghc:
ghc -O2 -ddump-simpl digits.hs >digits.txt
ghc -O2 -ddump-simpl digitsV2.hs >digitsV2.txt
ghc -O2 -ddump-simpl show.hs >show.txt
digits.hs
digits :: (Integral a) => a -> [Int]
digits x
| x < 10 = [fromIntegral x]
| otherwise = let (q, r) = x `quotRem` 10 in (fromIntegral r) : (digits q)
main = return $ digits 1
digitsV2.hs
digitsV2 :: (Integral a) => a -> [Int]
digitsV2 n = go n []
where
go x xs
| x < 10 = fromIntegral x : xs
| otherwise = let (q, r) = x `quotRem` 10 in go q (fromIntegral r : xs)
main = return $ digits 1
show.hs
import Data.Char
digitsByShow :: (Integral a, Show a) => a -> [Int]
digitsByShow = map (\x -> ord x - ord '0') . show
main = return $ digitsByShow 1
If you'd like to view the complete txt files, I placed them on ideone (rather than paste a 10000 char dump here):
digits.txt
digitsV2.txt
show.txt
If we carefully look through digits.txt, it appears that this is the relevant section:
lvl_r1qU = __integer 10
Rec {
Main.$w$sdigits [InlPrag=[0], Occ=LoopBreaker]
:: Integer -> (# Int, [Int] #)
[GblId, Arity=1, Str=DmdType <S,U>]
Main.$w$sdigits =
\ (w_s1pI :: Integer) ->
case integer-gmp-1.0.0.0:GHC.Integer.Type.ltInteger#
w_s1pI lvl_r1qU
of wild_a17q { __DEFAULT ->
case GHC.Prim.tagToEnum# # Bool wild_a17q of _ [Occ=Dead] {
False ->
let {
ds_s16Q [Dmd=<L,U(U,U)>] :: (Integer, Integer)
[LclId, Str=DmdType]
ds_s16Q =
case integer-gmp-1.0.0.0:GHC.Integer.Type.quotRemInteger
w_s1pI lvl_r1qU
of _ [Occ=Dead] { (# ipv_a17D, ipv1_a17E #) ->
(ipv_a17D, ipv1_a17E)
} } in
(# case ds_s16Q of _ [Occ=Dead] { (q_a11V, r_X12h) ->
case integer-gmp-1.0.0.0:GHC.Integer.Type.integerToInt r_X12h
of wild3_a17c { __DEFAULT ->
GHC.Types.I# wild3_a17c
}
},
case ds_s16Q of _ [Occ=Dead] { (q_X12h, r_X129) ->
case Main.$w$sdigits q_X12h
of _ [Occ=Dead] { (# ww1_s1pO, ww2_s1pP #) ->
GHC.Types.: # Int ww1_s1pO ww2_s1pP
}
} #);
True ->
(# GHC.Num.$fNumInt_$cfromInteger w_s1pI, GHC.Types.[] # Int #)
}
}
end Rec }
digitsV2.txt:
lvl_r1xl = __integer 10
Rec {
Main.$wgo [InlPrag=[0], Occ=LoopBreaker]
:: Integer -> [Int] -> (# Int, [Int] #)
[GblId, Arity=2, Str=DmdType <S,U><L,U>]
Main.$wgo =
\ (w_s1wh :: Integer) (w1_s1wi :: [Int]) ->
case integer-gmp-1.0.0.0:GHC.Integer.Type.ltInteger#
w_s1wh lvl_r1xl
of wild_a1dp { __DEFAULT ->
case GHC.Prim.tagToEnum# # Bool wild_a1dp of _ [Occ=Dead] {
False ->
case integer-gmp-1.0.0.0:GHC.Integer.Type.quotRemInteger
w_s1wh lvl_r1xl
of _ [Occ=Dead] { (# ipv_a1dB, ipv1_a1dC #) ->
Main.$wgo
ipv_a1dB
(GHC.Types.:
# Int
(case integer-gmp-1.0.0.0:GHC.Integer.Type.integerToInt ipv1_a1dC
of wild2_a1ea { __DEFAULT ->
GHC.Types.I# wild2_a1ea
})
w1_s1wi)
};
True -> (# GHC.Num.$fNumInt_$cfromInteger w_s1wh, w1_s1wi #)
}
}
end Rec }
I actually couldn't find the relevant section for show.txt. I'll work on that later.
Right off the bat, digitsV2.hs produces shorter code. That's probably a good sign for it.
digits.hs seems to be following this psuedocode:
def digits(w_s1pI):
if w_s1pI < 10: return [fromInteger(w_s1pI)]
else:
ds_s16Q = quotRem(w_s1pI, 10)
q_X12h = ds_s16Q[0]
r_X12h = ds_s16Q[1]
wild3_a17c = integerToInt(r_X12h)
ww1_s1pO = r_X12h
ww2_s1pP = digits(q_X12h)
ww2_s1pP.pushFront(ww1_s1pO)
return ww2_s1pP
digitsV2.hs seems to be following this psuedocode:
def digitsV2(w_s1wh, w1_s1wi=[]): # actually disguised as go(), as #No_signal wrote
if w_s1wh < 10:
w1_s1wi.pushFront(fromInteger(w_s1wh))
return w1_s1wi
else:
ipv_a1dB, ipv1_a1dC = quotRem(w_s1wh, 10)
w1_s1wi.pushFront(integerToIn(ipv1a1dC))
return digitsV2(ipv1_a1dC, w1_s1wi)
It might not be that these functions mutate lists like my psuedocode suggests, but this immediately suggests something: it looks as if digitsV2 is fully tail-recursive, whereas digits is actually not (may have to use some Haskell trampoline or something). It appears as if Haskell needs to store all the remainders in digits before pushing them all to the front of the list, whereas it can just push them and forget about them in digitsV2. This is purely speculation, but it is well-founded speculation.

Netwire 5 - A box cannot bounce

I am trying to convert challenge 3 ( https://ocharles.org.uk/blog/posts/2013-08-01-getting-started-with-netwire-and-sdl.html ) from netwire 4.0 to netwire 5.0 using OpenGL. Unfortunately, the box cannot bounce. The entire code is following. It seems to me that the function velocity does not work. When the box collides with a wall, it does not bounce but stops. How do I correct my program? Thanks in advance.
{-# LANGUAGE Arrows #-}
import Graphics.Rendering.OpenGL
import Graphics.UI.GLFW
import Data.IORef
import Prelude hiding ((.))
import Control.Monad.Fix (MonadFix)
import Control.Wire
import FRP.Netwire
isKeyDown :: (Enum k, Monoid e) => k -> Wire s e IO a e
isKeyDown k = mkGen_ $ \_ -> do
s <- getKey k
return $ case s of
Press -> Right mempty
Release -> Left mempty
acceleration :: (Monoid e) => Wire s e IO a Double
acceleration = pure ( 0) . isKeyDown (CharKey 'A') . isKeyDown (CharKey 'D')
<|> pure (-0.5) . isKeyDown (CharKey 'A')
<|> pure ( 0.5) . isKeyDown (CharKey 'D')
<|> pure ( 0)
velocity :: (Monad m, HasTime t s, Monoid e) => Wire s e m (Double, Bool) Double
velocity = integralWith bounce 0
where bounce c v
| c = (-v)
| otherwise = v
collided :: (Ord a, Fractional a) => (a, a) -> a -> (a, Bool)
collided (a, b) x
| x < a = (a, True)
| x > b = (b, True)
| otherwise = (x, False)
position' :: (Monad m, HasTime t s) => Wire s e m Double (Double, Bool)
position' = integral 0 >>> (arr $ collided (-0.8, 0.8))
challenge3 :: (HasTime t s) => Wire s () IO a Double
challenge3 = proc _ -> do
rec a <- acceleration -< ()
v <- velocity -< (a, c)
(p, c) <- position' -< v
returnA -< p
s :: Double
s = 0.05
y :: Double
y = 0.0
renderPoint :: (Double, Double) -> IO ()
renderPoint (x, y) = vertex $ Vertex2 (realToFrac x :: GLfloat) (realToFrac y :: GLfloat)
generatePoints :: Double -> Double -> Double -> [(Double, Double)]
generatePoints x y s =
[ (x - s, y - s)
, (x + s, y - s)
, (x + s, y + s)
, (x - s, y + s) ]
runNetwork :: (HasTime t s) => IORef Bool -> Session IO s -> Wire s e IO a Double -> IO ()
runNetwork closedRef session wire = do
pollEvents
closed <- readIORef closedRef
if closed
then return ()
else do
(st , session') <- stepSession session
(wt', wire' ) <- stepWire wire st $ Right undefined
case wt' of
Left _ -> return ()
Right x -> do
clear [ColorBuffer]
renderPrimitive Quads $
mapM_ renderPoint $ generatePoints x y s
swapBuffers
runNetwork closedRef session' wire'
main :: IO ()
main = do
initialize
openWindow (Size 1024 512) [DisplayRGBBits 8 8 8, DisplayAlphaBits 8, DisplayDepthBits 24] Window
closedRef <- newIORef False
windowCloseCallback $= do
writeIORef closedRef True
return True
runNetwork closedRef clockSession_ challenge3
closeWindow
By experience, I think the trick here is the fact that you technically have to bounce a few pixels before the actual collision, because if you detect it when it happens, then the inertia put your square a little bit "in" the wall, and so velocity is constantly reversed, causing your square to be blocked.
Ocharles actually nods to it in the blog post :
If this position falls outside the world bounds, we adjust the square (with a small epsilon to stop it getting stuck in the wall) and return the collision information.
Good luck with Netwire 5, I'm playing with it too, and I just begin to like it. ;)

Unboxing a function

I have a function that I am trying to optimize. This is part of a bigger code where I suspect this function is preventing GHC from unboxing Int arguments at higher level function that calls it. So, I wrote a simple test with two things in mind - understand the core, and try different things to see what makes GHC unbox it, so that I can apply the lessons to bigger code. Here is the function cmp with a test function wrapper:
{-# LANGUAGE BangPatterns #-}
module Cmp
( cmp,
test )
where
import Data.Vector.Unboxed as U hiding (mapM_)
import Data.Word
cmp :: (U.Unbox a, Eq a) => U.Vector a -> U.Vector a -> Int -> Int -> Int
cmp a b !i !j = go a b 0 i j
where
go v1 v2 !len !i !j| (i<n) && (j<m) && ((unsafeIndex v1 i) == (unsafeIndex v2 j)) = go v1 v2 (len+1) (i+1) (j+1)
| otherwise = len
where
n = U.length a
m = U.length b
{-# INLINABLE cmp #-}
test :: (U.Unbox a, Eq a) => U.Vector a -> U.Vector a -> U.Vector Int -> Int
test a b i = U.sum $ U.map (\x -> cmp a b x x) i
Ideally, test should call unboxed version of cmp with following signature (of course, correct me if I am wrong):
U.Vector a -> U.Vector a -> Int# -> Int# -> Int#
Looking at the core generated in ghc 7.6.1 (command line option:ghc -fforce-recomp -ddump-simpl -dsuppress-uniques -dsuppress-idinfo -dsuppress-module-prefixes -O2 -fllvm), I see this for inner loop for test - snippets from core below, with my comments added:
-- cmp function doesn't have any helper functions with unboxed Int
--
cmp
:: forall a.
(Unbox a, Eq a) =>
Vector a -> Vector a -> Int -> Int -> Int
...
-- This is the function that is called by test - it does keep the result
-- unboxed, but calls boxed cmp, and unboxes the result of cmp (I# y)
--
$wa
:: forall a.
(Unbox a, Eq a) =>
Vector a -> Vector a -> Vector Int -> Int#
$wa =
\ (# a)
(w :: Unbox a)
(w1 :: Eq a)
(w2 :: Vector a)
(w3 :: Vector a)
(w4 :: Vector Int) ->
case w4
`cast` (<TFCo:R:VectorInt> ; <NTCo:R:VectorInt>
:: Vector Int ~# Vector Int)
of _ { Vector ipv ipv1 ipv2 ->
letrec {
$s$wfoldlM'_loop :: Int# -> Int# -> Int#
$s$wfoldlM'_loop =
\ (sc :: Int#) (sc1 :: Int#) ->
case >=# sc1 ipv1 of _ {
False ->
case indexIntArray# ipv2 (+# ipv sc1) of wild { __DEFAULT ->
let {
x :: Int
x = I# wild } in
--
-- Calls cmp and unboxes the Int result as I# y
--
case cmp # a w w1 w2 w3 x x of _ { I# y ->
$s$wfoldlM'_loop (+# sc y) (+# sc1 1)
}
};
True -> sc
}; } in
$s$wfoldlM'_loop 0 0
}
-- helper function called by test - it calls $wa which calls boxed cmp
--
test1
:: forall a.
(Unbox a, Eq a) =>
Vector a -> Vector a -> Vector Int -> Id Int
test1 =
\ (# a)
(w :: Unbox a)
(w1 :: Eq a)
(w2 :: Vector a)
(w3 :: Vector a)
(w4 :: Vector Int) ->
case $wa # a w w1 w2 w3 w4 of ww { __DEFAULT ->
(I# ww) `cast` (Sym <(NTCo:Id <Int>)> :: Int ~# Id Int)
}
I will appreciate pointers on how to force unboxed version of cmp to be called from test. I tried strictifying different arguments, but that was like throwing the kitchen sink at it, which of course didn't work. I hope to use the lessons learnt here to solve the boxing/unboxing performance issue in the more complicated code.
Also, one more question - I have seen cast being used in the core, but haven't found any core references on Haskell/GHC wiki that explain what it is. It seems a type casting operation. I would appreciate explanation of what it is, and how to interpret it in the test1 function above.
Now I don't have ghc, so my advices would be verbal:
Why do you avoid {-# INLINE #-} pragma? High performance in Haskell is significantly based on function inlining. Add INLINE pragma to the go function.
Remove first two excessive parameters of go function. Read more about interoperation of inlining, specializing (unboxing) of parameters here: http://www.haskell.org/ghc/docs/latest/html/users_guide/pragmas.html#inline-pragma
Move m and n definitions one level up, along with go.

Resources