I am starting in Haskell and am interested in how to get matching performance for simple code I would normally write in C or Python. Consider the following problem.
You are given a long string of 1s and 0s of length n. We want to output, for each substring of length m, the number of 1s in that window. That is, the output has n-m+1 values, each between 0 and m inclusive.
In C this is very simple to do in time proportional to n and using extra space (on top of the space needed to store the input) proportional to m bits. You just count the number of 1s in the first window of length m and then maintain two pointers, one to the start of the window and one to the end. You increment the count when the end pointer moves onto a 1 while the start pointer leaves a 0, and decrement it in the opposite case.
Is it possible to get the same theoretical performance in a purely functional way in Haskell?
Some terrible code:
chunkBits m = helper
  where helper [] = []
        helper xs = sum (take m xs) : helper (drop m xs)

main = print $ chunkBits 5 [0,1,1,0,1,0,0,1,0,1,0,1,1,1,0,0,0,1]
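Note that chunkBits sums non-overlapping blocks of length m rather than every window. For reference, a direct specification of the sliding-window counts can be sketched as follows (the name naiveWindows is hypothetical). It needs O(n*m) time, which is exactly what the answers below improve on, but it is handy for checking faster versions against.

```haskell
import Data.List (tails)

-- Hypothetical reference implementation: for every start position,
-- sum the next m bits. O(n*m) time; only full windows are kept.
naiveWindows :: Int -> [Int] -> [Int]
naiveWindows m xs =
  [ sum (take m t) | t <- tails xs, length (take m t) == m ]
```

For example, naiveWindows 2 [1,1,0,0,1,1,0,1] gives [2,1,0,1,2,1,1].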
C Code
Here is the C code you've described:
int sliding_window(const char * const str, const int n, const int m, int * result){
    const char * back = str;
    const char * front = str + m;
    int sum = 0;
    int i;

    for(i = 0; i < m; ++i){
        sum += str[i] == '1';
    }

    *result++ = sum;

    for(; i < n; ++i){
        sum += *front++ == '1';
        sum -= *back++ == '1';
        *result++ = sum;
    }

    return n - m + 1;
}
Algorithm
The code above is apparently O(n), since we have n iterations. But let's take a step back and have a look at the underlying algorithm:
1. Sum the first m elements. Keep this as sum. O(m)
   Our first window has sum 1s. O(1)
2. Until we've exhausted our original string: O(n)
   2.1. "Slide" the window. O(1)
        - add 1 to sum if we gain a '1' by sliding. O(1)
        - subtract 1 from sum if we lose a '1' by sliding. O(1)
   2.2. Push sum onto the results. O(1)
Since n >= m (otherwise there is no window), the total O(m) + O(n) is O(n).
Moulding a Haskell variant
That's basically a left scan (scanl) over the list of differences produced by the sliding step (2.1.). So all we need is a way to slide:
slide :: Int -> [Char] -> [Int]
slide m xs = zipWith f xs (drop m xs)
  where
    f '1' '0' = -1 -- we lose a one
    f '0' '1' =  1 -- we gain a one
    f _   _   =  0 -- nothing :/
That's O(n), where n is the length of our list.
slidingWindow :: Int -> [Char] -> [Int]
slidingWindow m xs = scanl (+) start (slide m xs)
  where
    start = length (filter (== '1') (take m xs))
That's O(n), same as in C, since both use the same algorithm.
Caveats
In a real-life application you would normally use Text or ByteString instead of String, since the latter is a list of Char with a lot of overhead. Since you only use a string of '1' and '0', you can use ByteString:
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as BS
import Data.List (scanl')

slide :: Int -> ByteString -> [Int]
slide m xs = BS.zipWith f xs (BS.drop m xs)
  where
    f '1' '0' = -1
    f '0' '1' =  1
    f _   _   =  0

slidingWindow :: Int -> ByteString -> [Int]
slidingWindow m xs = scanl' (+) start (slide m xs)
  where
    start = BS.count '1' (BS.take m xs)
Update
After reading the question more carefully I noticed that the
C program reads its input from an array.
So here is an equivalent Haskell "pure" function which performs the task.
import qualified Data.Vector as V
import Data.List (scanl')

-- number of 1s in each window of length m, for a vector of 0s and 1s
count :: Int -> V.Vector Int -> [Int]
count m v =
  let c0 = V.sum (V.take m v)
      n = V.length v
      results = scanl' go c0 [0 .. n-m-1]
        where go r i = r - (v V.! i) + (v V.! (i+m))
  in results

test1 :: IO ()
test1 = let v = V.fromList [0,0,1,1,1,1,1,0,0,0,0]
        in print $ count 3 v
Even though count returns a list, it will be generated lazily. Moreover, if it is consumed by another list operation, it could be optimized via one of the various fusion techniques.
Original Answer
This is a good exercise, but why does it have to be "purely functional" (and what does that mean anyway)?
You can write the C algorithm in Haskell - it's not as terse, but it will
generate essentially the same code.
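import qualified Data.Vector.Unboxed.Mutable as V
import System.IO (isEOF)

count :: Int -> IO ()
count m = do
  v <- V.replicate m '0'          -- circular buffer with the last m characters
  let toInt :: Char -> Int
      toInt ch = if ch == '1' then 1 else 0
      loop c i = do
        eof <- isEOF
        if eof
          then return ()          -- stop cleanly at end of input
          else do
            ch <- getChar
            oldch <- V.read v i   -- character leaving the window
            let c' = c + toInt ch - toInt oldch
            V.write v i ch
            let i' = mod (i+1) m
            putStrLn $ show c
            loop c' i'
  loop 0 0

main :: IO ()
main = count 3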
(For simplicity this generates n results.)
If you benchmark this, note that you are also including the performance of
getChar, putStrLn and show, so it might be difficult to make a fair
comparison with a C program. However, it has O(n) complexity and constant
memory usage, which is what I think you're asking for.
The most basic level is re-implementing the cool HOF-based algorithms with hand-written recursive functions to express the loops.
Bang patterns mark arguments as strict, so simple values are computed without unnecessary delay (scanl', for example, takes care of this implicitly). The code also shows that "pointers" are just names:
{-# LANGUAGE BangPatterns #-}

-- assumes xs has only 0s and 1s
counts :: Int -> [Int] -> [Int]
counts m xs = g 0 m xs
  where
    g !c 0 ys     = h c ys xs
    g !c _ []     = []          -- m > |xs|
    g !c m (y:ys) = g (c+y) (m-1) ys
    h !c [] _          = [c]
    h !c (y:ys) (x:xs) = c : h (c+y-x) ys xs
Testing:
> counts 2 [1,1,0,0,1,1,0,1]
[2,1,0,1,2,1,1]
> counts 3 [1,1,0,0,1,1,1,1]
[2,1,1,2,3,3]
You get an integer n and you need to find the index of its first appearance in Stern's Diatomic Sequence.
The sequence is defined like this:
a[0] = 0
a[1] = 1
a[2*i] = a[i]
a[2*i+1] = a[i] + a[i+1]
See MathWorld.
Because n can be up to 400000, it's not a good idea to brute-force it, especially since the time limit is 4000 ms.
The sequence is pretty odd: first occurrence of 8 is 21, but first occurrence of 6 is 33.
Any ideas how to solve this?
Maybe this might help: OEIS
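For small n, the two examples above (first occurrence of 8 at 21, of 6 at 33) can be checked by brute force. The sketch below uses hypothetical names: fuscPair computes (a[k], a[k+1]) from the recurrence in O(log k) steps, and firstOccurrence scans k = 0, 1, 2, ... for the first match. It is far too slow for n near 400000, where the answer is in the hundreds of millions, but it is useful for testing a faster solution.

```haskell
-- (a[k], a[k+1]) for Stern's diatomic sequence, using
-- a[2i] = a[i] and a[2i+1] = a[i] + a[i+1].
fuscPair :: Int -> (Int, Int)
fuscPair 0 = (0, 1)
fuscPair k =
  let (a, b) = fuscPair (k `div` 2)
  in if even k then (a, a + b) else (a + b, b)

-- Brute force: index of the first k with a[k] == n.
-- Does not terminate if n < 0.
firstOccurrence :: Int -> Int
firstOccurrence n = length (takeWhile ((/= n) . fst . fuscPair) [0 ..])
```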
We can easily solve for the first occurrence of a number in the range of 400000 in under four seconds:
Prelude Diatomic> firstDiatomic 400000
363490989
(0.03 secs, 26265328 bytes)
Prelude Diatomic> map firstDiatomic [400000 .. 400100]
[363490989,323659475,580472163,362981813,349334091,355685483,346478235,355707595
,291165867,346344083,347155797,316314293,576398643,315265835,313171245,355183267
,315444051,315970205,575509833,311741035,340569429,313223987,565355925,296441165
,361911645,312104147,557145429,317106853,323637939,324425077,610613547,311579309
,316037811,311744107,342436533,348992869,313382235,325406123,355818699,312128723
,347230875,324752171,313178421,312841811,313215645,321754459,576114987,325793195
,313148763,558545581,355294101,359224397,345462093,307583675,355677549,312120731
,341404245,316298389,581506779,345401947,312109779,316315061,315987123,313447771
,361540179,313878107,304788843,325765547,316036275,313731751,355635795,312035947
,346756533,313873883,349358379,357393763,559244877,313317739,325364139,312128107
,580201947,358182323,314944173,357403987,584291115,312158827,347448723,363246413
,315935571,349386085,315929427,312137323,357247725,313207657,320121429,356954923
,557139285,296392013,576042123,311726765,296408397]
(2.45 secs, 3201358192 bytes)
The key to it is the Calkin-Wilf tree.
Starting from the fraction 1/1, it is built by the rule that for a node with the fraction a/b, its left child carries the fraction a/(a+b), and its right child the fraction (a+b)/b.
1/1
/ \
/ \
/ \
1/2 2/1
/ \ / \
1/3 3/2 2/3 3/1
etc. The diatomic sequence (starting at index 1) is the sequence of numerators of the fractions in the Calkin-Wilf tree, when that is traversed level by level, each level from left to right.
If we look at the tree of indices
1
/ \
/ \
/ \
2 3
/ \ / \
4 5 6 7
/ \
8 9 ...
we can easily verify that the node at index k in the Calkin-Wilf tree carries the fraction a[k]/a[k+1] by induction.
That is obviously true for k = 1 (a[1] = a[2] = 1), and from then on,
for k = 2*j we have the left child of the node with index j, so the fraction is a[j]/(a[j]+a[j+1]) and a[k] = a[j] and a[k+1] = a[j] + a[j+1] are the defining equations of the sequence.
for k = 2*j+1 we have the right child of the node with index j, so the fraction is (a[j]+a[j+1])/a[j+1] and that is a[k]/a[k+1] again by the defining equations.
All positive reduced fractions occur exactly once in the Calkin-Wilf tree (left as an exercise for the reader), hence all positive integers occur in the diatomic sequence.
We can find the node in the Calkin-Wilf tree from the index by following the binary representation of the index from the most significant bit to the least: for a 1-bit we go to the right child and for a 0-bit to the left. (For that, it is convenient to augment the Calkin-Wilf tree with a node 0/1 whose right child is the 1/1 node, so that we also have a step for the most significant set bit of the index.)
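The bit walk can be sketched in a few lines (hypothetical names). Starting at the augmented 0/1 node, it reads the bits of k from the most significant end, moving to the right child (a+b)/b on a 1 and to the left child a/(a+b) on a 0. By the induction above, the result should be the pair (a[k], a[k+1]).

```haskell
-- Binary digits of k, most significant first (empty for k = 0).
bitsMSBFirst :: Int -> [Int]
bitsMSBFirst = reverse . go
  where
    go 0 = []
    go k = k `mod` 2 : go (k `div` 2)

-- Follow the bits from the 0/1 node; fractions are (numerator, denominator).
calkinWilfNode :: Int -> (Int, Int)
calkinWilfNode = foldl step (0, 1) . bitsMSBFirst
  where
    step (a, b) 1 = (a + b, b)   -- 1-bit: right child
    step (a, b) _ = (a, a + b)   -- 0-bit: left child
```

For instance, 21 = 10101b visits 1/1, 1/2, 3/2, 3/5 and ends at 8/5 = a[21]/a[22].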
Now, that doesn't yet help very much to solve the problem at hand.
But, let us first solve a related problem: For a reduced fraction p/q, determine its index.
Suppose that p > q. Then we know that p/q is a right child, and its parent is (p-q)/q. If also p-q > q, we have again a right child, whose parent is (p - 2*q)/q. Continuing, if
p = a*q + b, 1 <= b < q
then we reach the p/q node from the b/q node by going to the right child a times.
Now b/q has a numerator smaller than its denominator, so it is a left child, and its parent is b/(q-b). If
q = c*b + d, 1 <= d < b
we have to go to the left child c times from the node b/d to reach b/q.
And so on.
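The parent rules translate into a short recursive sketch (the name cwIndex is hypothetical) that climbs from p/q to the root one edge at a time: a right child (p > q) has index 2*j+1 and a left child has index 2*j, where j is the parent's index. It assumes p/q is reduced. Grouping the repeated steps with quotRem gives the "a times to the right, c times to the left" form described above; this simpler version takes one step per edge, i.e. as many steps as the sum of the partial quotients.

```haskell
-- Index of the reduced fraction p/q (p, q > 0) in the Calkin-Wilf tree.
cwIndex :: Int -> Int -> Int
cwIndex p q
  | p == q    = 1                        -- only 1/1, since p/q is reduced
  | p > q     = 2 * cwIndex (p - q) q + 1 -- right child of (p-q)/q
  | otherwise = 2 * cwIndex p (q - p)     -- left child of p/(q-p)
```

For example, cwIndex 8 5 climbs 8/5, 3/5, 3/2, 1/2, 1/1 and returns 21.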
We can find the way from the root (1/1) to the p/q node using the continued fraction (I consider only simple continued fractions here) expansion of p/q. Let p > q and
p/q = [a_0, a_1, ..., a_r,1]
the continued fraction expansion of p/q ending in 1.
If r is even, then go to the right child a_r times, then to the left a_(r-1) times, then to the right child ... then a_1 times to the left child, and finally a_0 times to the right.
If r is odd, then first go to the left child a_r times, then a_(r-1) times to the right ... then a_1 times to the left child, and finally a_0 times to the right.
For p < q, we must end going to the left, hence start going to the left for even r and start going to the right for odd r.
We have thus found a close connection between the binary representation of the index and the continued fraction expansion of the fraction carried by the node via the path from the root to the node.
Let the run-length-encoding of the index k be
[c_1, c_2, ..., c_j] (all c_i > 0)
i.e. the binary representation of k starts with c_1 ones, followed by c_2 zeros, then c_3 ones etc., and ending with c_j
ones, if k is odd - hence j is also odd;
zeros, if k is even - hence j is also even.
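Then [c_j, c_(j-1), ..., c_2, c_1] is the continued fraction expansion of a[k]/a[k+1] if k is odd, and of its reciprocal a[k+1]/a[k] if k is even (for example, k = 2 = 10b gives [1,1] = 2/1, while a[2]/a[3] = 1/2). In both cases its length has the same parity as k (every rational has exactly two continued fraction expansions, one with odd length, the other with even length).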
The RLE gives the path from the 0/1 node above 1/1 to a[k]/a[k+1]. The length of the path is
the number of bits necessary to represent k, and
the sum of the partial quotients in the continued fraction expansion.
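The correspondence between run lengths and continued fractions can be checked with a small sketch (hypothetical names). It computes the run-length encoding of k, reverses it, evaluates the continued fraction exactly as a Rational, and compares with the recurrence. For odd k the value is a[k]/a[k+1]; for even k it evaluates to the reciprocal a[k+1]/a[k], so the sketch swaps the pair in that case.

```haskell
import Data.List (group)
import Data.Ratio (denominator, numerator)

-- Run lengths of the binary digits of k > 0, most significant first,
-- e.g. 33 = 100001b gives [1,4,1].
rleBits :: Int -> [Int]
rleBits = map length . group . reverse . bits
  where
    bits 0 = []
    bits k = k `mod` 2 : bits (k `div` 2)

-- Value of the simple continued fraction [c0, c1, ..., cj].
cfValue :: [Int] -> Rational
cfValue [c]      = fromIntegral c
cfValue (c : cs) = fromIntegral c + 1 / cfValue cs
cfValue []       = error "cfValue: empty expansion"

-- (a[k], a[k+1]) read off the reversed run-length encoding of k.
fractionFromIndex :: Int -> (Integer, Integer)
fractionFromIndex k
  | odd k     = (numerator x, denominator x)
  | otherwise = (denominator x, numerator x)   -- even k: value is a[k+1]/a[k]
  where
    x = cfValue (reverse (rleBits k))
```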
Now, to find the index of the first occurrence of n > 0 in the diatomic sequence, we first observe that the smallest index must necessarily be odd, since a[k] = a[k/2] for even k. Let the smallest index be k = 2*j+1. Then
the length of the RLE of k is odd,
the fraction at the node with index k is a[2*j+1]/a[2*j+2] = (a[j] + a[j+1])/a[j+1], hence it is a right child.
So the smallest index k with a[k] = n corresponds to the left-most ending of all the shortest paths to a node with numerator n.
The shortest paths correspond to the continued fraction expansions of n/m, where 0 < m <= n is coprime to n [the fraction must be reduced] with the smallest sum of the partial quotients.
What kind of length do we need to expect? Given a continued fraction p/q = [a_0, a_1, ..., a_r] with a_0 > 0 and sum
s = a_0 + ... + a_r
the numerator p is bounded by F(s+1) and the denominator q by F(s), where F(j) is the j-th Fibonacci number. The bounds are sharp, for a_0 = a_1 = ... = a_r = 1 the fraction is F(s+1)/F(s).
So if F(t) < n <= F(t+1), the sum of the partial quotients of the continued fraction expansion (either of the two) is >= t. Often there is an m such that the sum of the partial quotients of the continued fraction expansion of n/m is exactly t, but not always:
F(5) = 5 < 6 <= F(6) = 8
and the continued fraction expansions of the two reduced fractions 6/m with 0 < m <= 6 are
6/1 = [6] (alternatively [5,1])
6/5 = [1,4,1] (alternatively [1,5])
with sum of the partial quotients 6. However, the smallest possible sum of partial quotients is never much larger (the largest I'm aware of is t+2).
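The claims about 6 can be checked with a brute-force sketch (hypothetical names): partialQuotients runs the Euclidean algorithm, and minQuotientSum takes the smallest sum over all reduced n/m with 0 < m <= n. Both expansions of a rational have the same sum, so the Euclidean one suffices. Since the path length to a node equals both the bit length of its index and this sum, the result should match the bit length of the first occurrence: 6 for n = 6 (33 = 100001b) and 5 for n = 8 (21 = 10101b).

```haskell
-- Partial quotients of n/m (n, m > 0) by the Euclidean algorithm;
-- the last quotient is > 1 unless the expansion is just [1].
partialQuotients :: Int -> Int -> [Int]
partialQuotients n m = case n `quotRem` m of
  (a, 0) -> [a]
  (a, r) -> a : partialQuotients m r

-- Smallest sum of partial quotients over the reduced fractions n/m, 0 < m <= n.
minQuotientSum :: Int -> Int
minQuotientSum n =
  minimum [sum (partialQuotients n m) | m <- [1 .. n], gcd n m == 1]
```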
The continued fraction expansions of n/m and n/(n-m) are closely related. Let's assume that m < n/2, and let
n/m = [a_0, a_1, ..., a_r]
Then a_0 >= 2,
(n-m)/m = [a_0 - 1, a_1, ..., a_r]
and since
n/(n-m) = 1 + m/(n-m) = 1 + 1/((n-m)/m)
the continued fraction expansion of n/(n-m) is
n/(n-m) = [1, a_0 - 1, a_1, ..., a_r]
In particular, the sum of the partial quotients is the same for both.
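The relation between the expansions of n/m and n/(n-m) is easy to test exhaustively for small n. A sketch, with hypothetical names, assuming 2*m < n as above:

```haskell
-- Partial quotients of n/m (n, m > 0) by the Euclidean algorithm.
partialQuotients :: Int -> Int -> [Int]
partialQuotients n m = case n `quotRem` m of
  (a, 0) -> [a]
  (a, r) -> a : partialQuotients m r

-- Does n/(n-m) expand to [1, a_0 - 1, a_1, ..., a_r]
-- when n/m expands to [a_0, a_1, ..., a_r]?
complementRelation :: Int -> Int -> Bool
complementRelation n m = case partialQuotients n m of
  (a0 : rest) -> partialQuotients n (n - m) == 1 : (a0 - 1) : rest
  []          -> False   -- unreachable: the expansion is never empty
```

For n = 8 and m = 3 this compares 8/3 = [2,1,2] with 8/5 = [1,1,1,2].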
Unfortunately, I'm not aware of a way to find the m with the smallest sum of partial quotients without brute force, so the algorithm is (I assume n > 2):
1. For 0 < m < n/2 coprime to n, find the continued fraction expansion of n/m, collecting the ones with the smallest sum of the partial quotients (the usual algorithm produces expansions whose last partial quotient is > 1; we assume that).
2. Adjust the found continued fraction expansions [those are not large in number] in the following way:
   - if the CF [a_0, a_1, ..., a_r] has even length, convert it to [a_0, a_1, ..., a_(r-1), a_r - 1, 1]
   - otherwise, use [1, a_0 - 1, a_1, ..., a_(r-1), a_r - 1, 1]
   (that chooses the one between n/m and n/(n-m) leading to the smaller index)
3. Reverse the continued fractions to obtain the run-length encodings of the corresponding indices.
4. Choose the smallest among them.
In step 1, it is useful to use the smallest sum found so far to short-cut.
Code (Haskell, since that's easiest):
module Diatomic (diatomic, firstDiatomic, fuscs) where

import Data.List

strip :: Int -> Int -> Int
strip p = go
  where
    go n = case n `quotRem` p of
      (q,r) | r == 0    -> go q
            | otherwise -> n

primeFactors :: Int -> [Int]
primeFactors n
  | n < 1 = error "primeFactors: non-positive argument"
  | n == 1 = []
  | n `rem` 2 == 0 = 2 : go (strip 2 (n `quot` 2)) 3
  | otherwise = go n 3
    where
      go 1 _ = []
      go m p
        | m < p*p   = [m]
        | r == 0    = p : go (strip p q) (p+2)
        | otherwise = go m (p+2)
          where
            (q,r) = m `quotRem` p

contFracLim :: Int -> Int -> Int -> Maybe [Int]
contFracLim = go []
  where
    go acc lim n d = case n `quotRem` d of
      (a,b) | lim < a   -> Nothing
            | b == 0    -> Just (a:acc)
            | otherwise -> go (a:acc) (lim - a) d b

fixUpCF :: [Int] -> [Int]
fixUpCF [a]
  | a < 3     = [a]
  | otherwise = [1,a-2,1]
fixUpCF xs
  | even (length xs) = case xs of
      (1:_)  -> fixEnd xs
      (a:bs) -> 1 : (a-1) : bs
  | otherwise = case xs of
      (1:_)  -> xs
      (a:bs) -> 1 : fixEnd ((a-1):bs)

fixEnd :: [Int] -> [Int]
fixEnd [a,1] = [a+1]
fixEnd [a] = [a-1,1]
fixEnd (a:bs) = a : fixEnd bs
fixEnd _ = error "Shouldn't have called fixEnd with an empty list"

-- Compares run-length encodings of equally long indices. The arguments are
-- swapped on purpose in the recursive call: runs alternate between 1s and 0s,
-- and a longer run of 0s means a smaller number, so the order flips each run.
cfCompare :: [Int] -> [Int] -> Ordering
cfCompare (a:bs) (c:ds) = case compare a c of
  EQ -> cfCompare ds bs
  cp -> cp

fibs :: [Integer]
fibs = 0 : 1 : zipWith (+) fibs (tail fibs)

toNumber :: [Int] -> Integer
toNumber = foldl' ((+) . (*2)) 0 . concat . (flip (zipWith replicate) $ cycle [1,0])

fuscs :: Integer -> (Integer, Integer)
fuscs 0 = (0,1)
fuscs 1 = (1,1)
fuscs n = case n `quotRem` 2 of
  (q,r) -> let (a,b) = fuscs q
           in if r == 0
                then (a,a+b)
                else (a+b,b)

diatomic :: Integer -> Integer
diatomic = fst . fuscs

firstDiatomic :: Int -> Integer
firstDiatomic n
  | n < 0 = error "Diatomic sequence has no negative terms"
  | n < 2 = fromIntegral n
  | n == 2 = 3
  | otherwise = toNumber $ bestCF n

bestCF :: Int -> [Int]
bestCF n = check [] estimate start
  where
    pfs = primeFactors n
    (step,ops) = case pfs of
                   (2:xs) -> (2,xs)
                   _      -> (1,pfs)
    start0 = (n-1) `quot` 2
    start | even n && even start0 = start0 - 1
          | otherwise             = start0
    eligible k = all ((/= 0) . (k `rem`)) ops
    estimate = length (takeWhile (<= fromIntegral n) fibs) + 2
    check candidates lim k
      | k < 1 || n `quot` k >= lim = if null candidates
                                       then check [] (2*lim) start
                                       else minimumBy cfCompare candidates
      | eligible k = case contFracLim lim n k of
          Nothing -> check candidates lim (k-step)
          Just cf -> let s = sum cf
                     in if s < lim
                          then check [fixUpCF cf] s (k - step)
                          else check (fixUpCF cf : candidates) lim (k-step)
      | otherwise = check candidates lim (k-step)
I would recommend you read this letter from Dijkstra which explains an alternative way of computing this function via:
n, a, b := N, 1, 0;
do n ≠ 0 and even(n) → a, n := a + b, n/2
[] odd(n)            → b, n := b + a, (n-1)/2
od {b = fusc(N)}
This starts with a,b=1,0 and effectively uses successive bits of N (from least to most significant) to increase a and b, the final result being the value of b.
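A direct Haskell transcription of Dijkstra's loop (a sketch; the name fuscDijkstra is hypothetical) consumes the bits of N from the least significant end and returns the final b. It should agree with the recursive definition, e.g. fusc(21) = 8 and fusc(33) = 6.

```haskell
-- Dijkstra's iteration: start with a, b = 1, 0; an even n adds b to a,
-- an odd n adds a to b; then drop the lowest bit of n.
fuscDijkstra :: Integer -> Integer
fuscDijkstra = go 1 0
  where
    go _ b 0 = b
    go a b n
      | even n    = go (a + b) b (n `div` 2)
      | otherwise = go a (b + a) ((n - 1) `div` 2)
```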
The index of the first appearance of a particular value for b can therefore be computed via finding the smallest n for which this iteration will result in that value of b.
One method for finding this smallest n is to use A* search where the cost is the value of n. The efficiency of the algorithm will be determined by your choice of heuristic.
For the heuristic, I would recommend noting that:
the final value will always be a multiple of the gcd(a,b) (this can be used to rule out some nodes that can never produce the target)
b never decreases
there is a maximum (exponential) rate at which b can increase (the rate depends on the current value of a)
EDIT
Here is some example Python code to illustrate the A* approach.
from heapq import heappop, heappush
from math import gcd

def heuristic(node, goal):
    """Estimate least n required to make b == goal"""
    n, a, b, k = node
    if b == goal:
        return n
    # Otherwise needs to have at least one more bit set
    # Improve this heuristic to make the algorithm faster
    return n + (1 << k)

def diatomic(goal):
    """Return index of first appearance of goal in Stern's Diatomic sequence"""
    # node = (n, a, b, k): bits of N chosen so far, Dijkstra's a and b,
    # and the position k of the next bit
    start = 0, 1, 0, 0
    f_score = []  # This is used as a heap
    heappush(f_score, (0, start))
    while True:
        s, node = heappop(f_score)
        n, a, b, k = node
        if b == goal:
            return n
        # bit k of N is either 0 (a += b) or 1 (b += a)
        for child in [(n, a + b, b, k + 1), (n + (1 << k), a, b + a, k + 1)]:
            n2, a2, b2, k2 = child
            if b2 <= goal and goal % gcd(a2, b2) == 0:
                heappush(f_score, (heuristic(child, goal), child))

print([diatomic(n) for n in range(1, 10)])
I currently have the following function to get the divisors of an integer:
-- All divisors of a number
divisors :: Integer -> [Integer]
divisors 1 = [1]
divisors n = firstHalf ++ secondHalf
  where firstHalf = filter (divides n) (candidates n)
        secondHalf = filter (\d -> n `div` d /= d) (map (n `div`) (reverse firstHalf))
        candidates n = takeWhile (\d -> d * d <= n) [1..n]

-- helper assumed by the code above: does d divide n?
divides :: Integer -> Integer -> Bool
divides n d = n `mod` d == 0
I ended up adding the filter to secondHalf because a divisor was repeating when n is a square of a prime number. This seems like a very inefficient way to solve this problem.
So I have two questions: How do I measure if this really is a bottleneck in my algorithm? And if it is, how do I go about finding a better way to avoid repetitions when n is a square of a prime?
To measure where the bottleneck is, put the three auxiliary definitions (firstHalf, secondHalf, candidates) at the top level and run your code with the profiler on (with a modern GHC, -fprof-auto-top gives every top-level definition its own cost centre):
ghc -prof -fprof-auto-top --make divisors.hs
./divisors 100 +RTS -p -RTS
Also, you know that the biggest candidate is sqrt n, so instead of doing all those multiplications d*d, you can just consider [1 .. floor (sqrt (fromIntegral n))] (for very large n, rounding in the Double square root can make this off by one).
For better algorithms you should consult a maths book, since this is not a Haskell-related question… Something you can consider: if a divides b, then every divisor d of a divides b as well.
You'll want to use memoization or dynamic programming to avoid checking multiple times if a given d divides b (for example, if 15 and 27 divide b, then you need to mathematically check only once that 3 divides b. The other times, you just see if 3 is in your table of divisors of b).
You needn't test all the elements of the reversed second half. You know that if the square root is present, it is the head element there:
secondHalf = let (r:ds) = [n `div` d | d <- reverse firstHalf]
             in  [r | n `div` r /= r] ++ ds
This assumes n is positive.
A simpler way is to handle the square root of the number separately:
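divs :: Int -> [Int]
divs n =
  let
    -- integer square root via Double (fine for moderate n)
    r = floor (sqrt (fromIntegral n :: Double))
    -- pairs (d, n `div` d) for every divisor d <= r; d runs up to r itself
    -- so that a pair (r, n `div` r) is not missed when r*r /= n
    (a, b) = unzip [(d, q) | d <- [1 .. r], let (q, m) = quotRem n d, m == 0]
  in
    if r*r == n
      then a ++ tail (reverse b)   -- r would appear in both halves; keep it once
      else a ++ reverse b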
That way we get the second half for free, as a part of producing the first half.
But this could hardly be a bottleneck in your application, because the algorithm itself is inefficient. It is usually much faster to generate the divisors from a number's prime factorization. Prime factorization by trial division can be much quicker, because we divide out each factor as it is found, which reduces the number being factorized and thus the number of candidate divisors that are tried (up to the reduced number's square root). For example, 12348 = 2*2*3*3*7*7*7, and no factor above 7 is tried in the process of factorization, whereas divs 12348 divides 12348 by every candidate up to its square root, about 111:
-- trial division; alternatively use (go n primes), where
--   primes = 2 : filter (null . tail . factorize) [3,5..]
factorize :: Integer -> [Integer]
factorize n = go n (2 : [3,5..])
  where
    go m ds@(d:t)
      | d*d > m   = [m]
      | r == 0    = d : go q ds
      | otherwise = go m t
      where (q, r) = quotRem m d
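The step from a factorization to the divisors is not shown above. One way to sketch it (hypothetical names; the standard construction, not necessarily what the answerer had in mind) is to group equal prime factors into powers p^e and, for every distinct prime, pick an exponent from 0 to e. For 12348 = 2^2 * 3^2 * 7^3 this gives (2+1)(2+1)(3+1) = 36 divisors.

```haskell
import Data.List (group, sort)

-- Prime factors by trial division, dividing out each factor as it is
-- found (same idea as factorize). Assumes n >= 1; gives [] for n = 1.
primeFactorsTD :: Integer -> [Integer]
primeFactorsTD n = go n (2 : [3, 5 ..])
  where
    go 1 _ = []
    go m ds@(d : t)
      | d * d > m = [m]
      | r == 0    = d : go q ds
      | otherwise = go m t
      where (q, r) = m `quotRem` d

-- All divisors: for each group of equal primes p^e, multiply every
-- divisor found so far by p^0, p^1, ..., p^e.
divisorsFromFactors :: Integer -> [Integer]
divisorsFromFactors n = sort (foldr combine [1] (group (primeFactorsTD n)))
  where
    combine (p : ps) ds = [d * p ^ k | d <- ds, k <- [0 .. length ps + 1]]
    combine []       ds = ds
```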