This problem involves an arbitrary number of dice, each with an arbitrary number of sides. We then find the maximal number of dice that can be put in a straight; see Google's Code Jam explanation. The implementation should be reasonably efficient for around 10^5 dice, each having up to 10^6 sides.
I've been trying to solve the problem in Haskell and this is my solution. However, it is not fast enough to earn full points on the problem, so can this be optimized?
import Data.List (sort)
import Data.Foldable (foldl')

getMaxStraight :: [Int] -> Int
getMaxStraight sides =
  foldl'
    (\maxStraight side -> if side > maxStraight then succ maxStraight else maxStraight)
    0
    (sort sides)
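-- A quick worked example (added for illustration, not in the original):
-- sorting [4,1,2,1] gives [1,1,2,4], and the fold steps 0 -> 1 -> 1 -> 2 -> 3,
-- so three dice can show the straight 1, 2, 3.
-- >>> getMaxStraight [4, 1, 2, 1]
-- 3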
-- Doing IO in Haskell
-- Assuming the above works perfectly, something might not perform well below
main :: IO ()
main = do
  line <- getLine :: IO String
  let numberOfCases = read line :: Int
  mapM_ solveCase [1 .. numberOfCases]
solveCase :: Show a => a -> IO ()
solveCase i = do
  line1 <- getLine :: IO String
  line2 <- getLine :: IO String
  let numberOfDice = read line1 :: Int
  let diceSides = map read (words line2) :: [Int]
  let maxStraight = getMaxStraight diceSides
  putStrLn $ "Case #" ++ show i ++ ": " ++ show maxStraight
Aside from this, I've also written a Python solution, which did run in time. I'd expect Haskell to run faster than Python. What is going on?
def get_max_straight(sides):
    max_straight = 0
    for side in sorted(sides):
        if side > max_straight:
            max_straight += 1
    return max_straight
# Doing IO in Python
# This works as needed
if __name__ == '__main__':
    tests_len = int(input())
    for case_num in range(1, 1 + tests_len):
        input()  # Discard unneeded input
        sides = [int(s) for s in input().split(' ')]
        print(f'Case #{case_num}: {get_max_straight(sides)}')
I'm a Haskell beginner and have chosen it to solve a programming task for my class; however, my solution is too slow and doesn't get accepted. I'm trying to profile it and was hoping to get some pointers from more advanced Haskellers here.
The only other solution in my class that got accepted so far was written in Rust. I'm sure I should be able to achieve similar performance in Haskell, and I wrote horrible imperative code in the hope of improving performance, alas to no avail.
My first suspicion relates to the work loop, where I am using forever to go over the in-degree array until I get an out-of-bounds exception. I was hoping for this to be tail-recursive and to compile to a while (true)-style loop.
My second suspicion is that I/O is perhaps slowing things down.
EDIT: The problem likely has to do with my algorithm, because I am not keeping a queue of nodes with in-degree 0. Thank you @luqui.
EDIT 2: It seems that the real bottleneck was I/O; I fixed that thanks to @Davislor.
The task is based on this: http://www.spoj.com/UKCPLAD/problems/TOPOSORT/, and I am constrained to use only the libraries in the Haskell Platform.
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}
{-# OPTIONS_GHC -O3 #-}

import Control.Monad
import Data.Array.IO
import Data.IORef
import Data.Int
import Control.Exception

type List = []
type Node = Int32
type Edge = (Node, Node)
type Indegree = Int32

main = do
  (numNodes, _) <- readPair <$> getLine
  edges <- map readPair . lines <$> getContents
  topo numNodes edges

-- lower bound
{-# INLINE lb #-}
lb = 1

topo :: Node -> List Edge -> IO ()
topo numNodes edges = do
  result <- newIORef []
  count <- newIORef 0
  indegrees <- newArray (lb, numNodes) 0 :: IO (IOUArray Node Indegree)
  neighbours <- newArray (lb, numNodes) [] :: IO (IOArray Node (List Node))
  forM_ edges $ \(from, to) -> do
    update indegrees to (+1)
    update neighbours from (to:)
  let work = forever $ do
        z <- getNext indegrees
        modifyIORef' result (z:)
        modifyIORef' count (+1)
        ns <- readArray neighbours z
        forM_ ns $ \n -> update indegrees n pred
  work `catch`
    \(_ :: SomeException) -> do
      count <- readIORef count
      if numNodes == count
        then (mapM_ (\n -> putStr (show n ++ " ")) . reverse) =<< readIORef result
        else putStrLn "Sandro fails."

{-# INLINE update #-}
update a i f = do
  x <- readArray a i
  writeArray a i (f x)

{-# INLINE getNext #-}
getNext indegrees = getNext' indegrees =<< getBounds indegrees

{-# INLINE getNext' #-}
getNext' indegrees (lb, ub) = readArray indegrees lb >>= \case
  0 -> writeArray indegrees lb (-1) >> return lb
  _ -> getNext' indegrees (lb + 1, ub)

readPair :: String -> (Node, Node)
{-# INLINE readPair #-}
readPair = toPair . map read . words
  where toPair [x, y] = (x, y)
        toPair _ = error "Only two entries per line allowed"
Example output
$ ./topo
8 9
1 4
1 2
4 2
4 3
3 2
5 2
3 5
8 2
8 6
^D
1 4 3 5 7 8 2 6
If you haven’t already, profile your program by compiling with -prof -fprof-auto and then executing with the command-line options +RTS -p. This will generate a profile *.prof that will tell you which functions the program is spending all its time in. However, I can see immediately where the biggest time-waster is. Your instincts were right: it’s the I/O.
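For reference, that cycle looks something like this (program and input names are placeholders):

ghc -O2 -prof -fprof-auto topo.hs
./topo +RTS -p -RTS < input.txt
less topo.prof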
Having done that a lot, I can guarantee you that you’ll find that it’s spending the vast majority of its time doing I/O. The first thing you should always do to speed up your program is rewrite it to use fast I/O. Haskell is a fast language, when you use the right data structures. The default I/O library in the Prelude uses singly-linked lists with lazily-evaluated thunks where each node holds a single Unicode character. That would be slow in C, too!
I’ve gotten the best results with Data.ByteString.Lazy.Char8 when the input is ASCII, and Data.ByteString.Builder to generate the output. (An alternative is Data.Text.) That gets you a lazily-evaluated list of strict character buffers on input (so interactive input and output still works), and fills a single buffer on output.
After you’ve written the skeleton of the program with fast I/O, the next step is to look at your algorithm, and especially your data structures. Use profiling to see where all the time goes. But I’d recommend you use a functional algorithm rather than trying to write imperative programs in Haskell with do.
I almost always approach problems like this in Haskell with a more functional style: in particular, my main function is almost always something similar to:
import qualified Data.ByteString.Lazy.Char8 as B8

main :: IO ()
main = B8.interact ( output . compute . input )
This makes everything except the call to interact a pure function, and isolates the parsing code and the formatting code so the compute part in the middle can be independent of that.
Since this is an assignment and you want to solve the problem yourself, I’ll refrain from refactoring the program for you, but here’s an example I wrote in response to a question on another forum to perform a counting sort. It should be suitable as a skeleton for other kinds of problems.
import Data.Array.IArray (accumArray, assocs)
import Data.Array.Unboxed (UArray)
import Data.ByteString.Builder (Builder, char7, intDec, toLazyByteString)
import qualified Data.ByteString.Lazy.Char8 as B8
import Data.Monoid ((<>))

main :: IO ()
main = B8.interact ( output . compute . input ) where
  input :: B8.ByteString -> [Int]
  input = map perLine . tail . B8.lines where
    perLine = decode . B8.readInt
    decode (Just (x, _)) = x
    decode Nothing = error "Invalid input: expected integer."
  compute :: [Int] -> [Int]
  compute = concatMap expand . assocs . countingSort . map encode where
    encode i = (i, 1)
    countingSort :: [(Int, Int)] -> UArray Int Int
    countingSort = accumArray (+) 0 (lower, upper)
    lower = 0
    upper = 1000000
    expand (i, c) = replicate c i
  output :: [Int] -> B8.ByteString
  output = toLazyByteString . foldMap perCase where
    perCase :: Int -> Builder
    perCase x = intDec x <> char7 '\n'
When I tested it, this version ran in less than half the time of anyone else's Haskell solution to the same problem; the same holds true for the actual contest problems I've used it on, and the approach generalizes.
So I suggest changing the I/O to be similar to that, first, then profiling, and coming back with the profiling output if that doesn’t make enough of a difference. This might also be a good Code Review question.
Thanks to @Davislor's suggestions I managed to make it much faster, and I also refactored the code for the better; now I actually have an O(m log n) algorithm. Surprisingly, this doesn't make that much of a difference: the I/O far outweighed the suboptimal complexity of the algorithm.
EDIT: got rid of unsafePerformIO and it actually runs a teeny-weeny bit faster. Plus adding -XStrict shaves off even more time.
{-# LANGUAGE Strict #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -O2 #-}

import Control.Monad
import Data.Array.IO
import Data.Int
import Data.Set (Set)
import qualified Data.Set as Set
import Data.ByteString.Builder (Builder, char7, intDec, toLazyByteString)
import qualified Data.ByteString.Lazy.Char8 as B8
import Data.Monoid ((<>))

type List = []
type Node = Int
type Edge = (Node, Node)
type Indegree = Int

main = B8.putStrLn =<< topo . map readPair . B8.lines =<< B8.getContents

readPair :: B8.ByteString -> (Node, Node)
readPair str = (x, y)
  where
    (Just (x, str')) = B8.readInt str
    (Just (y, _)) = B8.readInt (B8.tail str')

topo :: List Edge -> IO B8.ByteString
topo inp = do
  let (numNodes, _) = head inp
      edges = tail inp
  indegrees <- newArray (1, numNodes) 0 :: IO (IOUArray Node Indegree)
  neighbours <- newArray (1, numNodes) [] :: IO (IOArray Node (List Node))
  -- setup
  forM_ edges $ \(from, to) -> do
    update indegrees to (+1)
    update neighbours from (to:)
  zeroes <- collectIndegreeZero [] indegrees =<< getBounds indegrees
  processQueue (Set.fromList zeroes) [] numNodes indegrees neighbours
  where
    collectIndegreeZero acc indegrees (lb, ub)
      | lb > ub = return acc
      | otherwise = do
          indegr <- readArray indegrees lb
          let acc' = if indegr == 0 then (lb : acc) else acc
          collectIndegreeZero acc' indegrees (lb + 1, ub)
    processQueue queue result numNodes indegrees neighbours = do
      if null queue
        then if numNodes == 0
          then return . toLazyByteString . foldMap whitespace . reverse $ result
          else return "Sandro fails."
        else do
          (node, queue) <- return $ Set.deleteFindMin queue
          ns <- readArray neighbours node
          queue <- foldM decrIndegrees queue ns
          processQueue queue (node : result) (numNodes - 1) indegrees neighbours
      where
        decrIndegrees :: Set Node -> Node -> IO (Set Node)
        decrIndegrees q n = do
          i <- readArray indegrees n
          writeArray indegrees n (i - 1)
          return $ if i == 1 then Set.insert n q else q
        whitespace x = intDec x <> char7 ' '

{-# INLINE update #-}
update a i f = do
  x <- readArray a i
  writeArray a i (f x)
Haskell version (1.03s):
module Main where

import qualified Data.Text as T
import qualified Data.Text.IO as TIO
import Control.Monad
import Control.Applicative ((<$>))
import Data.Vector.Unboxed (Vector, (!))
import qualified Data.Vector.Unboxed as V

-- Pair each price with the maximum price to its right; the answer is the
-- sum of those differences.
solve :: Vector Int -> Int
solve ar =
  V.foldl' go 0 ar' where
    ar' = V.zip ar (V.postscanr' max 0 ar)
    go sr (p, m) = sr + m - p

main = do
  t <- fmap (read . T.unpack) TIO.getLine -- With Data.Text, the example finishes 15% faster.
  T.unlines . map (T.pack . show . solve . V.fromList . map (read . T.unpack) . T.words)
    <$> replicateM t (TIO.getLine >> TIO.getLine) >>= TIO.putStr
F# version (0.17s):
open System

let solve (ar : uint64[]) =
    let ar' =
        let t = Array.scanBack max ar 0UL |> fun x -> Array.take (x.Length-1) x
        Array.zip ar t
    let go sr (p,m) = sr + m - p
    Array.fold go 0UL ar'

let getIntLine() =
    Console.In.ReadLine().Split [|' '|]
    |> Array.choose (fun x -> if x <> "" then uint64 x |> Some else None)

let getInt() = getIntLine().[0]

let t = getInt()
for i = 1 to int t do
    getInt() |> ignore
    let ar = getIntLine()
    printfn "%i" (solve ar)
The above two programs are the solutions for the Stock Maximize problem, and the times are for the first test case of the Run Code button.
For some reason the F# version is roughly 6x faster, but I am pretty sure that if I replaced the slow library functions with imperative loops, I could speed it up by at least 3x and more likely 10x.
Could the Haskell version be similarly improved?
I am doing the above for learning purposes and in general I am finding it difficult to figure out how to write efficient Haskell code.
If you switch to ByteString and stick with plain Haskell lists (instead of vectors) you will get a more efficient solution. You may also rewrite the solve function with a single left fold and bypass zip and right scan (1).
Overall, on my machine, I get 20 times performance improvement compared to your Haskell solution (2).
The Haskell code below performs faster than the F# code:
import Data.List (unfoldr)
import Control.Applicative ((<$>))
import Control.Monad (replicateM_)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as C

parse :: ByteString -> [Int]
parse = unfoldr $ C.readInt . C.dropWhile (== ' ')

-- The fold builds a continuation; its argument s plays the role of the
-- running maximum of the prices to the right, seeded with minBound.
solve :: [Int] -> Int
solve xs = foldl go (const 0) xs minBound
  where go f x s = if s < x then f x else s - x + f s

main = do
  [n] <- parse <$> B.getLine
  replicateM_ n $ B.getLine >> B.getLine >>= print . solve . parse
1. See edits for an earlier version of this answer which implements solve using zip and scanr.
2. The HackerRank website shows an even larger performance improvement.
If I wanted to do that quickly in F# I would avoid all of the higher-order functions inside solve and just write a C-style imperative loop:
let solve (ar : uint64[]) =
    let mutable sr, m = 0UL, 0UL
    for i in ar.Length-1 .. -1 .. 0 do
        let p = ar.[i]
        m <- max p m
        sr <- sr + m - p
    sr
According to my measurements, this is 11x faster than your F#.
Then the performance is limited by the I/O layer (Unicode parsing) and string splitting. This can be optimised by reading into a byte buffer and writing the lexer by hand:
let buf = Array.create 65536 0uy
let mutable idx = 0
let mutable length = 0

do
    use stream = System.Console.OpenStandardInput()
    let rec read m =
        let c =
            if idx < length then
                idx <- idx + 1
            else
                length <- stream.Read(buf, 0, buf.Length)
                idx <- 1
            buf.[idx-1]
        if length > 0 && '0'B <= c && c <= '9'B then
            read (10UL * m + uint64(c - '0'B))
        else
            m
    let read() = read 0UL
    for _ in 1UL .. read() do
        Array.init (read() |> int) (fun _ -> read())
        |> solve
        |> System.Console.WriteLine
Just for the record, the F# version is also not optimal. I don't think it really matters at this point, but if people wanted to compare the performance, then it is worth noting that it can be made faster.
I have not tried very hard (you can certainly make it even faster by using restricted mutation, which would not be against the nature of F#), but a simple change to use Seq instead of Array in the right places (to avoid allocating temporary arrays) makes the code about 2x to 3x faster:
let solve (ar : uint64[]) =
    let ar' = Seq.zip ar (Array.scanBack max ar 0UL)
    let go sr (p,m) = sr + m - p
    Seq.fold go 0UL ar'
If you use Seq.zip, you can also drop the take call (because Seq.zip truncates the sequence automatically). Measured with #time using the following snippet:
let rnd = Random()
let inp = Array.init 100000 (fun _ -> uint64 (rnd.Next()))
for a in 0 .. 10 do ignore (solve inp) // Measure this line
I get around 150ms for the original code and something between 50-75ms using the new version.
Say I have two pure but unsafe functions that do the same thing, but one of them works on batches and is asymptotically faster:
f :: Int -> Result -- takes O(1) time
f = unsafePerformIO ...
g :: [Int] -> [Result] -- takes O(log n) time
g = unsafePerformIO ...
A naive implementation:
getUntil :: Int -> [Result]
getUntil 0 = [f 0]
getUntil n = f n : getUntil (n - 1)
switch is the n value where g gets cheaper than f.
getUntil will in practice be called with ever-increasing n, but it might not start at 0. So, since the Haskell runtime can memoize getUntil, performance will be optimal as long as successive calls are less than switch apart. But once the interval gets larger, this implementation is slow.
In an imperative program, I guess I would make a TreeMap (which could quickly be checked for gaps) for caching all calls. On cache misses, it would be filled with the results of g if the gap was longer than switch, and with f otherwise.
How can this be optimized in Haskell?
I think I am just looking for:
an ordered map filled on demand, using a fill function that computes all values up to the requested index: one function if the missing range is small, another if it is large
a get operation on the map which returns a list of all lower values up to the requested index. This would result in a function similar to getUntil above (a rough sketch follows).
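For concreteness, such a cache could look roughly like this. This is only a sketch of mine: Result, switch, f and g are stand-ins for the question's definitions, and the policy is the one described above (batch with g when the missing range is long, call f pointwise otherwise):

import Data.IORef
import qualified Data.Map as M

type Result = Int       -- stand-in

switch :: Int
switch = 1000           -- stand-in

f :: Int -> Result      -- stand-in for the cheap single-value function
f = (* 2)

g :: [Int] -> [Result]  -- stand-in for the batched function
g = map (* 2)

-- Fill the cache up to index n, then return all values with keys <= n.
getUntilCached :: IORef (M.Map Int Result) -> Int -> IO [Result]
getUntilCached ref n = do
  cache <- readIORef ref
  let top = if M.null cache then -1 else fst (M.findMax cache)
      missing = [top + 1 .. n]
      filled
        | length missing > switch = zip missing (g missing)  -- one big gap: batch
        | otherwise = [(i, f i) | i <- missing]              -- small gap: pointwise
      cache' = M.union cache (M.fromList filled)
  writeIORef ref cache'
  return (M.elems (fst (M.split (n + 1) cache')))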
I'll elaborate on my proposal for using map, based on some tests I just ran.
import System.IO
import System.IO.Unsafe
import Control.Concurrent
import Control.Monad

switch :: Int
switch = 1000

f :: Int -> Int
f x = unsafePerformIO $ do
  threadDelay $ 500 * x
  putStrLn $ "Calculated from scratch: f(" ++ show x ++ ")"
  return $ 500 * x

g :: Int -> Int
g x = unsafePerformIO $ do
  threadDelay $ x * x `div` 2
  putStrLn $ "Calculated from scratch: g(" ++ show x ++ ")"
  return $ x * x `div` 2

cachedFG :: [Int]
cachedFG = map g [0 .. switch] ++ map f [switch+1 ..]

main :: IO ()
main = forever $ getLine >>= print . (cachedFG !!) . read
… where f, g and switch have the same meaning indicated in the question.
The above program can be compiled as is using GHC. When executed, positive integers can be entered, followed by a newline, and the application will print a value based on the number entered, plus an extra indication of which values are being calculated from scratch.
A short session with this program is:
User: 10000
Program: Calculated from scratch: f(10000)
Program: 5000000
User: 10001
Program: Calculated from scratch: f(10001)
Program: 5000500
User: 10000
Program: 5000000
^C
The program has to be killed/terminated manually.
Notice that the last value entered doesn't show a "calculated from scratch" message. This indicates that the program has the value cached/memoized somewhere. You can try executing this program yourself, but take into account that threadDelay's lag is proportional to the value entered.
The getUntil function then could be implemented using:
getUntil :: Int -> [Int]
getUntil n = take n cachedFG
or:
getUntil :: Int -> [Int]
getUntil = flip take cachedFG
If you don't know the value for switch, you can try evaluating f and g in parallel and use the fastest result, but that's another show.
Suppose someone makes a program to play chess, or to solve Sudoku. In this kind of program it makes sense to have a tree structure representing game states.
This tree would be very large, "practically infinite". That isn't by itself a problem, as Haskell supports infinite data structures.
A familiar example of an infinite data structure:
fibs = 0 : 1 : zipWith (+) fibs (tail fibs)
Nodes are only allocated when first used, so the list takes finite memory. One may also iterate over an infinite list as long as no reference to its head is kept, allowing the garbage collector to collect the parts that are no longer needed.
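For instance, this sum runs in constant space, because the list is consumed as it is produced and nothing holds on to its head (a small illustration of mine, not from the original post):

import Data.List (foldl')

main :: IO ()
main = print (foldl' (+) 0 [1 .. 10000000 :: Int])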
Back to the tree example: suppose one does some iteration over the tree. The nodes already visited may not be freed if the root of the tree is still needed (for example, in an iterative-deepening search the tree would be iterated over several times, so the root needs to be kept).
One possible solution for this problem that I thought of is using an "unmemo-monad".
I'll try to demonstrate what this monad is supposed to do using monadic lists:
import Control.Monad.ListT (ListT) -- cabal install List
import Data.Copointed -- cabal install pointed
import Data.List.Class
import Prelude hiding (enumFromTo)
nums :: ListT Unmemo Int -- What is Unmemo?
nums = enumFromTo 0 1000000
main = print $ div (copoint (foldlL (+) 0 nums)) (copoint (lengthL nums))
Using nums :: [Int], the program would take a lot of memory, as a reference to nums is needed by lengthL nums while the list is being folded over by foldlL (+) 0 nums.
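The plain-list program being described is essentially this (my transcription): sum and length both traverse nums, and since the second consumer still needs the list after the first has forced it, the whole million-element list stays live.

nums :: [Int]
nums = [0 .. 1000000]

main :: IO ()
main = print $ div (sum nums) (length nums)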
The purpose of Unmemo is to make the runtime not keep the nodes iterated over.
I attempted using ((->) ()) as Unmemo, but it yields the same results as nums :: [Int] does: the program uses a lot of memory, as is evident when running it with +RTS -s.
Is there any way to implement Unmemo that does what I want?
Same trick as with a stream -- don't capture the remainder directly, but instead capture a value and a function which yields a remainder. You can add memoization on top of this as necessary.
data UTree a = Leaf a | Branch a (a -> [UTree a])
I'm not in the mood to figure it out precisely at the moment, but this structure arises, I'm sure, naturally as the cofree comonad over a fairly straightforward functor.
Edit
Found it: http://hackage.haskell.org/packages/archive/comonad-transformers/1.6.3/doc/html/Control-Comonad-Trans-Stream.html
Or this is perhaps simpler to understand: http://hackage.haskell.org/packages/archive/streams/0.7.2/doc/html/Data-Stream-Branching.html
In either case, the trick is that your f can be chosen to be something like data N s a = N (s -> (s,[a])) for an appropriate s (s being the type of your state parameter of the stream -- the seed of your unfold, if you will). That might not be exactly correct, but something close should do...
But of course for real work, you can scrap all this and just write the datatype directly as above.
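For instance, a tree of this shape can be produced with a small unfold helper (a sketch of mine, not from the original answer; it repeats the data declaration to stay self-contained). Because the children are rebuilt from the node's value on every traversal, they are not shared between traversals and can be garbage-collected in between:

data UTree a = Leaf a | Branch a (a -> [UTree a])

-- Rebuild the children from the node's value on each call.
unfoldU :: (a -> [a]) -> a -> UTree a
unfoldU step seed = Branch seed (map (unfoldU step) . step)

-- Example consumer: sum the node values down to a given depth.
sumTo :: Int -> UTree Int -> Int
sumTo _ (Leaf x) = x
sumTo 0 (Branch x _) = x
sumTo d (Branch x kids) = x + sum (map (sumTo (d - 1)) (kids x))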
Edit 2
The below code illustrates how this can prevent sharing. Note that even in the version without sharing, there are humps in the profile indicating that the sum and length calls aren't running in constant space. I'd imagine that we'd need an explicit strict accumulation to knock those down.
{-# LANGUAGE DeriveFunctor #-}

import Data.Stream.Branching (Stream(..))
import qualified Data.Stream.Branching as S
import Control.Arrow
import Control.Applicative
import Data.List

data UM s a = UM (s -> Maybe a) deriving Functor
type UStream s a = Stream (UM s) a

runUM s (UM f) = f s
liftUM x = UM $ const (Just x)
nullUM = UM $ const Nothing

buildUStream :: Int -> Int -> Stream (UM ()) Int
buildUStream start end = S.unfold (\x -> (x, go x)) start
  where go x
          | x < end = liftUM (x + 1)
          | otherwise = nullUM

sumUS :: Stream (UM ()) Int -> Int
sumUS x = S.head $ S.scanr (\x us -> maybe 0 id (runUM () us) + x) x

lengthUS :: Stream (UM ()) Int -> Int
lengthUS x = S.head $ S.scanr (\x us -> maybe 0 id (runUM () us) + 1) x

sumUS' :: Stream (UM ()) Int -> Int
sumUS' x = last $ usToList $ liftUM $ S.scanl (+) 0 x

lengthUS' :: Stream (UM ()) Int -> Int
lengthUS' x = last $ usToList $ liftUM $ S.scanl (\acc _ -> acc + 1) 0 x

usToList x = unfoldr (\um -> (S.head &&& S.tail) <$> runUM () um) x

maxNum = 1000000

nums = buildUStream 0 maxNum

numsL :: [Int]
numsL = [0..maxNum]

-- All these need to be run with increased stack to avoid an overflow.

-- This generates an hp file with two humps (i.e. the list is not shared)
main = print $ div (fromIntegral $ sumUS' nums) (fromIntegral $ lengthUS' nums)

-- This generates an hp file as above, and uses somewhat less memory, at the cost of
-- an increased number of GCs. -H helps a lot with that.
-- main = print $ div (fromIntegral $ sumUS nums) (fromIntegral $ lengthUS nums)

-- This generates an hp file with one hump (i.e. the list is shared)
-- main = print $ div (fromIntegral $ sum $ numsL) (fromIntegral $ length $ numsL)
I'm trying to learn Haskell, and after an article on Reddit about Markov text chains I decided to implement Markov text generation, first in Python and now in Haskell. However, I noticed that my Python implementation is way faster than the Haskell version, even though Haskell is compiled to native code. I am wondering what I should do to make the Haskell code run faster; for now I believe it's so much slower because of using Data.Map instead of hash maps, but I'm not sure.
I'll post the Python code and the Haskell as well. With the same data, Python takes around 3 seconds and Haskell is closer to 16 seconds.
It goes without saying that I'll take any constructive criticism :).
import random
import re
import cPickle

class Markov:
    def __init__(self, filenames):
        self.filenames = filenames
        self.cache = self.train(self.readfiles())
        picklefd = open("dump", "w")
        cPickle.dump(self.cache, picklefd)
        picklefd.close()

    def train(self, text):
        splitted = re.findall(r"(\w+|[.!?',])", text)
        print "Total of %d splitted words" % (len(splitted))
        cache = {}
        for i in xrange(len(splitted)-2):
            pair = (splitted[i], splitted[i+1])
            followup = splitted[i+2]
            if pair in cache:
                if followup not in cache[pair]:
                    cache[pair][followup] = 1
                else:
                    cache[pair][followup] += 1
            else:
                cache[pair] = {followup: 1}
        return cache

    def readfiles(self):
        data = ""
        for filename in self.filenames:
            fd = open(filename)
            data += fd.read()
            fd.close()
        return data

    def concat(self, words):
        sentence = ""
        for word in words:
            if word in "'\",?!:;.":
                sentence = sentence[0:-1] + word + " "
            else:
                sentence += word + " "
        return sentence

    def pickword(self, words):
        temp = [(k, words[k]) for k in words]
        results = []
        for (word, n) in temp:
            results.append(word)
            if n > 1:
                for i in xrange(n-1):
                    results.append(word)
        return random.choice(results)

    def gentext(self, words):
        allwords = [k for k in self.cache]
        (first, second) = random.choice(filter(lambda (a,b): a.istitle(), [k for k in self.cache]))
        sentence = [first, second]
        while len(sentence) < words or sentence[-1] is not ".":
            current = (sentence[-2], sentence[-1])
            if current in self.cache:
                followup = self.pickword(self.cache[current])
                sentence.append(followup)
            else:
                print "Wasn't able to. Breaking"
                break
        print self.concat(sentence)

Markov(["76.txt"])
--
module Markov
  ( train
  , fox
  ) where

import Debug.Trace
import qualified Data.Map as M
import qualified System.Random as R
import qualified Data.ByteString.Char8 as B

type Database = M.Map (B.ByteString, B.ByteString) (M.Map B.ByteString Int)

train :: [B.ByteString] -> Database
train (x:y:[]) = M.empty
train (x:y:z:xs) =
  let l = train (y:z:xs)
  in M.insertWith' (\new old -> M.insertWith' (+) z 1 old) (x, y) (M.singleton z 1) `seq` l

main = do
  contents <- B.readFile "76.txt"
  print $ train $ B.words contents

fox = "The quick brown fox jumps over the brown fox who is slow jumps over the brown fox who is dead."
a) How are you compiling it? (ghc -O2 ?)
b) Which version of GHC?
c) Data.Map is pretty efficient, but you can be tricked into lazy updates -- use insertWith', not insertWithKey (see the illustration below).
d) Don't convert bytestrings to String. Keep them as bytestrings, and store those in the Map.
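To illustrate point (c), a minimal example of mine, using the containers API of that era (today one would reach for Data.Map.Strict instead): with the lazy insertWith, each repeated key accumulates a chain of unevaluated (+) thunks, while insertWith' forces the combined value as it is inserted.

import Data.List (foldl')
import qualified Data.Map as M

lazyCounts, strictCounts :: [Int] -> M.Map Int Int
lazyCounts   = foldl' (\m k -> M.insertWith  (+) k 1 m) M.empty -- builds thunk chains
strictCounts = foldl' (\m k -> M.insertWith' (+) k 1 m) M.empty -- forces each value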
Data.Map is designed under the assumption that the class Ord comparisons take constant time. For string keys this may not be the case—and when the strings are equal it is never the case. You may or may not be hitting this problem depending on how large your corpus is and how many words have common prefixes.
I'd be tempted to try a data structure that is designed to operate with sequence keys, such as, for example, the bytestring-trie package kindly suggested by Don Stewart.
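A rough sketch of what that substitution might look like (my sketch; it assumes only the basic Data.Trie operations empty, lookup, and insert from bytestring-trie):

import qualified Data.ByteString.Char8 as B
import qualified Data.Trie as T

type Database = T.Trie (T.Trie Int)

-- Encode the word pair as one key; a space cannot occur inside a word.
key :: B.ByteString -> B.ByteString -> B.ByteString
key x y = B.concat [x, B.singleton ' ', y]

-- Record one observation of z following the pair (x, y).
bump :: B.ByteString -> B.ByteString -> B.ByteString -> Database -> Database
bump x y z db = T.insert k inner' db
  where
    k = key x y
    inner = maybe T.empty id (T.lookup k db)
    inner' = T.insert z (maybe 1 (+ 1) (T.lookup z inner)) inner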
I tried to avoid doing anything fancy or subtle. These are just two approaches to doing the grouping; the first emphasizes pattern matching, the second doesn't.
import Data.List (foldl')
import qualified Data.Map as M
import qualified Data.ByteString.Char8 as B

type Database2 = M.Map (B.ByteString, B.ByteString) (M.Map B.ByteString Int)

train2 :: [B.ByteString] -> Database2
train2 words = go words M.empty
  where go (x:y:[]) m = m
        go (x:y:z:xs) m = let addWord Nothing   = Just $ M.singleton z 1
                              addWord (Just m') = Just $ M.alter inc z m'
                              inc Nothing    = Just 1
                              inc (Just cnt) = Just $ cnt + 1
                          in go (y:z:xs) $ M.alter addWord (x,y) m

train3 :: [B.ByteString] -> Database2
train3 words = foldl' update M.empty (zip3 words (drop 1 words) (drop 2 words))
  where update m (x,y,z) = M.alter (addWord z) (x,y) m
        addWord word = Just . maybe (M.singleton word 1) (M.alter inc word)
        inc = Just . maybe 1 (+1)

main = do contents <- B.readFile "76.txt"
          let db = train3 $ B.words contents
          print $ "Built a DB of " ++ show (M.size db) ++ " words"
I think they are both faster than the original version, but admittedly I only tried them against the first reasonable corpus I found.
EDIT
As per Travis Brown's very valid point,
train4 :: [B.ByteString] -> Database2
train4 words = foldl' update M.empty (zip3 words (drop 1 words) (drop 2 words))
  where update m (x,y,z) = M.insertWith (inc z) (x,y) (M.singleton z 1) m
        inc k _ = M.insertWith (+) k 1
Here's a foldl'-based version that seems to be about twice as fast as your train:
train' :: [B.ByteString] -> Database
train' xs = foldl' (flip f) M.empty $ zip3 xs (tail xs) (tail $ tail xs)
  where
    -- When the pair is already present, merge the inner maps by adding counts.
    f (a, b, c) = M.insertWith (M.unionWith (+)) (a, b) (M.singleton c 1)
I tried it on the Project Gutenberg Huckleberry Finn (which I assume is your 76.txt), and it produces the same output as your function. My timing comparison was very unscientific, but this approach is probably worth a look.
1) I'm not clear on your code.
a) You define "fox" but don't use it. Were you meaning for us to try to help you using "fox" instead of reading the file?
b) You declare this as "module Markov" then have a 'main' in the module.
c) System.Random isn't needed. It does help us help you if you clean code a bit before posting.
2) Use ByteStrings and some strict operations as Don said.
3) Compile with -O2 and use -fforce-recomp to be sure you actually recompiled the code.
4) Try this slight transformation; it runs very fast (0.005 seconds). Obviously the input is absurdly small, so you'd need to provide your file or just test it yourself.
{-# LANGUAGE OverloadedStrings, BangPatterns #-}
module Main where

import qualified Data.Map as M
import qualified Data.ByteString.Lazy.Char8 as B

type Database = M.Map (B.ByteString, B.ByteString) (M.Map B.ByteString Int)

train :: [B.ByteString] -> Database
train xs = go xs M.empty
  where
    go :: [B.ByteString] -> Database -> Database
    go (x:y:[]) !m = m
    go (x:y:z:xs) !m =
      -- The inner combining function adds the new count to the old one.
      let m' = M.insertWithKey' (\key new old -> M.insertWithKey' (\_ n o -> n + o) z 1 old)
                                (x, y) (M.singleton z 1) m
      in go (y:z:xs) m'

main = print $ train $ B.words fox

fox = "The quick brown fox jumps over the brown fox who is slow jumps over the brown fox who is dead."
As Don suggested, look into using the stricter versions of your functions: insertWithKey' (and M.insertWith', since you ignore the key param the second time anyway).
It looks like your code probably builds up a lot of thunks until it gets to the end of your [String].
Check out: http://book.realworldhaskell.org/read/profiling-and-optimization.html
...especially try graphing the heap (about halfway through the chapter). Interested to see what you figure out.
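For the heap graph specifically, the usual incantation is something like this (the binary name is a placeholder; the program must be compiled with profiling enabled):

ghc -O2 -prof -fprof-auto markov.hs
./markov +RTS -hc -RTS
hp2ps -c markov.hp    # renders markov.ps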