Why is filter in front of foldLeft slow in Scala? - performance

I wrote an answer to the first Project Euler question:
Add all the natural numbers below one thousand that are multiples of 3 or 5.
The first thing that came to me was:
(1 until 1000).filter(i => (i % 3 == 0 || i % 5 == 0)).foldLeft(0)(_ + _)
but it's slow (it takes 125 ms), so I rewrote it, simply thinking of 'another way' versus 'the faster way'
(1 until 1000).foldLeft(0){
(total, x) =>
x match {
case i if (i % 3 == 0 || i % 5 ==0) => i + total // Add
case _ => total //skip
}
}
This is much faster (only 2 ms). Why? I'm guessing the second version uses only the Range generator and never materializes a fully realized collection, doing it all in one pass, which is both faster and uses less memory. Am I right?
Here is the code on IdeOne: http://ideone.com/GbKlP

The problem, as others have said, is that filter creates a new collection. The alternative withFilter doesn't, but that doesn't have a foldLeft. Also, using .view, .iterator or .toStream would all avoid creating the new collection in various ways, but they are all slower here than the first method you used, which I thought somewhat strange at first.
But, then... See, 1 until 1000 is a Range, whose size is actually very small, because it doesn't store each element. Also, Range's foreach is extremely optimized, and is even specialized, which is not the case of any of the other collections. Since foldLeft is implemented as a foreach, as long as you stay with a Range you get to enjoy its optimized methods.
(_: Range).foreach:
@inline final override def foreach[@specialized(Unit) U](f: Int => U) {
if (length > 0) {
val last = this.last
var i = start
while (i != last) {
f(i)
i += step
}
f(i)
}
}
(_: Range).view.foreach
def foreach[U](f: A => U): Unit =
iterator.foreach(f)
(_: Range).view.iterator
override def iterator: Iterator[A] = new Elements(0, length)
protected class Elements(start: Int, end: Int) extends BufferedIterator[A] with Serializable {
private var i = start
def hasNext: Boolean = i < end
def next: A =
if (i < end) {
val x = self(i)
i += 1
x
} else Iterator.empty.next
def head =
if (i < end) self(i) else Iterator.empty.next
/** $super
* '''Note:''' `drop` is overridden to enable fast searching in the middle of indexed sequences.
*/
override def drop(n: Int): Iterator[A] =
if (n > 0) new Elements(i + n, end) else this
/** $super
* '''Note:''' `take` is overridden to be symmetric to `drop`.
*/
override def take(n: Int): Iterator[A] =
if (n <= 0) Iterator.empty.buffered
else if (i + n < end) new Elements(i, i + n)
else this
}
(_: Range).view.iterator.foreach
def foreach[U](f: A => U) { while (hasNext) f(next()) }
And that, of course, doesn't even count the filter between view and foldLeft:
override def filter(p: A => Boolean): This = newFiltered(p).asInstanceOf[This]
protected def newFiltered(p: A => Boolean): Transformed[A] = new Filtered { val pred = p }
trait Filtered extends Transformed[A] {
protected[this] val pred: A => Boolean
override def foreach[U](f: A => U) {
for (x <- self)
if (pred(x)) f(x)
}
override def stringPrefix = self.stringPrefix+"F"
}
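To make the point about foldLeft riding on foreach concrete, here is a minimal sketch. foldLeftViaForeach is a hypothetical helper written for illustration, not the library's actual implementation; it only shows how a fold over a Range can be driven entirely by the Range's own foreach.

```scala
// Hypothetical helper: a left fold expressed through Range.foreach,
// so the traversal itself is the Range's optimized while loop.
def foldLeftViaForeach[B](r: Range, z: B)(op: (B, Int) => B): B = {
  var acc = z
  r.foreach(x => acc = op(acc, x))
  acc
}

// Same computation as the single-pass version in the question.
val total = foldLeftViaForeach(1 until 1000, 0) { (t, x) =>
  if (x % 3 == 0 || x % 5 == 0) t + x else t
}
```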

Try making the collection lazy first, so
(1 until 1000).view.filter...
instead of
(1 until 1000).filter...
That should avoid the cost of building an intermediate collection. You might also get better performance from using sum instead of foldLeft(0)(_ + _), it's always possible that some collection type might have a more efficient way to sum numbers. If not, it's still cleaner and more declarative...
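As a rough sketch of the variants discussed here (the name isMultiple is introduced only for this example), the strict, lazy, and single-pass forms all compute the same value. No timings are claimed; measure on your own machine.

```scala
// Predicate shared by all three variants (hypothetical name).
val isMultiple = (i: Int) => i % 3 == 0 || i % 5 == 0

// Strict: filter builds an intermediate collection before folding.
val strict = (1 until 1000).filter(isMultiple).foldLeft(0)(_ + _)

// Lazy: the view wraps the Range, so no intermediate collection is built.
val viaView = (1 until 1000).view.filter(isMultiple).sum

// Single pass: the filtering happens inside the fold itself.
val onePass = (1 until 1000).foldLeft(0)((total, i) => if (isMultiple(i)) total + i else total)
```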

Looking through the code, it looks like filter does build a new Seq on which the foldLeft is called. The second skips that bit. It's not so much the memory, although that can't but help, but that the filtered collection is never built at all. All that work is never done.
Range uses TraversableLike.filter, which looks like this:
def filter(p: A => Boolean): Repr = {
val b = newBuilder
for (x <- this)
if (p(x)) b += x
b.result
}
I think it's the += on line 4 that's the difference. Filtering in foldLeft eliminates it.

filter creates a whole new sequence on which then foldLeft is called. Try:
(1 until 1000).view.filter(i => (i % 3 == 0 || i % 5 == 0)).reduceLeft(_+_)
This will prevent said effect, merely wrapping the original thing. Exchanging foldLeft with reduceLeft is only cosmetic (in this case).

Now the challenge is, can you think of a yet more efficient way? Not that your solution is too slow in this case, but how well does it scale? What if instead of 1000, it was 1000000000? There is a solution that could compute the latter case just as quickly as the former.
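One possible reading of this hint, offered as a sketch rather than as the solution the poster had in mind: the multiples of k below n form an arithmetic series, so their sum has a closed form, and inclusion-exclusion removes the multiples of 15 that would otherwise be counted twice. The names sumOfMultiplesBelow and euler1 are made up for this example.

```scala
// Sum of all positive multiples of k that are strictly below n:
// k * (1 + 2 + ... + m) with m = (n - 1) / k.
def sumOfMultiplesBelow(k: Long, n: Long): Long = {
  val m = (n - 1) / k
  k * m * (m + 1) / 2
}

// Multiples of 3 or 5 below n; multiples of 15 are counted twice, so subtract them once.
def euler1(n: Long): Long =
  sumOfMultiplesBelow(3, n) + sumOfMultiplesBelow(5, n) - sumOfMultiplesBelow(15, n)
```

The work is constant regardless of n, so euler1(1000000000L) costs the same as euler1(1000).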

Related

Given: aabcdddeabb => Expected: [(a,2),(b,1),(c,1),(d,3),(e,1),(a,1),(b,2)] in Scala

I'm really interested in how this algorithm can be implemented. If possible, it would be great to see an implementation with and without recursion. I am new to the language so I would be very grateful for help. All I could come up with was this code and it goes no further:
print(counterOccur("aabcdddeabb"))
def counterOccur(string: String) =
string.toCharArray.toList.map(char => {
if (!char.charValue().equals(char.charValue() + 1)) (char, counter)
else (char, counter + 1)
})
I realize that it's not even close to the truth, I just don't even have a clue what else could be used.
The first solution uses recursion. I take the string Char by Char and check whether the last element in the Vector has the same character as the current one. If it does, I update the last element by increasing its count (the first case). If it does not, I append a new element to the Vector (the second case). Once all Chars from the string have been consumed, I return the result.
import scala.annotation.tailrec
def counterOccur(string: String): Vector[(Char, Int)] = {
@tailrec
def loop(str: List[Char], result: Vector[(Char, Int)]): Vector[(Char, Int)] = {
str match {
case x :: xs if result.lastOption.exists(_._1.equals(x)) =>
val count = result(result.size - 1)._2
loop(xs, result.updated(result.size - 1, (x, count + 1)))
case x :: xs =>
loop(xs, result :+ (x, 1))
case Nil => result
}
}
loop(string.toList, Vector.empty[(Char, Int)])
}
println(counterOccur("aabcdddeabb"))
The second solution does not use recursion. It works the same way, but uses foldLeft instead of recursion.
def counterOccur2(string: String): Vector[(Char, Int)] = {
string.foldLeft(Vector.empty[(Char, Int)])((r, v) => {
val lastElementIndex = r.size - 1
if (r.lastOption.exists(lv => lv._1.equals(v))) {
r.updated(lastElementIndex, (v, r(lastElementIndex)._2 + 1))
} else {
r :+ (v, 1)
}
})
}
println(counterOccur2("aabcdddeabb"))
You can use a very simple foldLeft to accumulate. You also don't need toCharArray and toList because strings are implicitly convertible to Seq[Char]:
"aabcdddeabb".foldLeft(collection.mutable.ListBuffer[(Char,Int)]()){ (acc, elm) =>
acc.lastOption match {
case Some((c, i)) if c == elm =>
acc.dropRightInPlace(1).addOne((elm, i+1))
case _ =>
acc.addOne((elm, 1))
}
}
Here is a solution using foldLeft and a custom State case class:
def countConsecutives[A](data: List[A]): List[(A, Int)] = {
final case class State(currentElem: A, currentCount: Int, acc: List[(A, Int)]) {
def result: List[(A, Int)] =
((currentElem -> currentCount) :: acc).reverse
def nextState(newElem: A): State =
if (newElem == currentElem)
this.copy(currentCount = this.currentCount + 1)
else
State(
currentElem = newElem,
currentCount = 1,
acc = (this.currentElem -> this.currentCount) :: this.acc
)
}
object State {
def initial(a: A): State =
State(
currentElem = a,
currentCount = 1,
acc = List.empty
)
}
data match {
case a :: tail =>
tail.foldLeft(State.initial(a)) {
case (state, newElem) =>
state.nextState(newElem)
}.result
case Nil =>
List.empty
}
}
You can see the code running here.
One possibility is to use the unfold method. This method is defined for several collection types, here I'm using it to produce an Iterator (documented here for version 2.13.8):
def spans[A](as: Seq[A]): Iterator[Seq[A]] =
Iterator.unfold(as) {
case head +: tail =>
val (span, rest) = tail.span(_ == head)
Some((head +: span, rest))
case _ =>
None
}
unfold starts from a state and applies a function that returns, either:
None if we want to signal that the collection ended
Some of a pair that contains the next item of the collection we want to produce and the "remaining" state that will be fed to the next iteration.
In this example in particular, we start from a sequence of A called as (which can be a sequence of characters) and at each iteration:
if there's at least one item
we split head and tail
we further split the tail into the longest prefix that contains items equal to the head and the rest
we return the head and the prefix we got above as the next item
we return the rest of the collection as the state for the following iteration
otherwise, we return None as there's nothing more to be done
The result is a fairly flexible function that can be used to group together spans of equal items. You can then define the function you wanted initially in terms of this:
def spanLengths[A](as: Seq[A]): Iterator[(A, Int)] =
spans(as).map(a => a.head -> a.length)
This can probably be made more generic and its performance improved, but I hope it is a helpful example of another possible approach. While folding a collection is a recursive approach, unfolding is referred to as a corecursive one (Wikipedia article).
You can play around with this code here on Scastie.
For
str = "aabcdddeabb"
you could extract matches of the regular expression
rgx = /(.)\1*/
to obtain the array
["aa", "b", "c", "ddd", "e", "a", "bb"]
and then map each element of the array to the desired (character, count) pair (see note 1).
def counterOccur(str: String): List[(Char, Int)] = {
"""(.)\1*""".r
.findAllIn(str)
.map(m => (m.charAt(0), m.length)).toList
}
counterOccur("aabcdddeabb")
#=> res0: List[(Char, Int)] = List((a,2), (b,1), (c,1), (d,3), (e,1), (a,1), (b,2))
The regular expression reads, "match any character and save it to capture group 1 ((.)), then match the content of capture group 1 zero or more times (\1*)".
1. Scala code kindly provided by @Thefourthbird.

Trial division for primes with immutable collections in Scala

I am trying to learn Scala and the functional programming mindset by rewriting basic exercises. Currently I am having trouble with the naive "trial division" approach to generating primes.
The trouble, described below, is that I could not rewrite a well-known algorithm in a functional style while preserving its efficiency, because I have no suitable immutable data structure: something like a List, but with fast operations not only on the head but also on the very end.
I started by writing Java code that tests every odd number for divisibility by the primes already found (limited by the square root of the value being tested) and adds it to the end of the list if no divisor is found.
http://ideone.com/QE8U0I
List<Integer> primes = new ArrayList<>();
primes.add(2);
int cur = 3;
while (primes.size() < 100000) {
for (Integer x : primes) {
if (x * x > cur) {
primes.add(cur);
break;
}
if (cur % x == 0) {
break;
}
}
cur += 2;
}
Then I tried to rewrite it in a "functional way". There was no problem with using recursion instead of loops, but I got stuck on immutable collections. The core idea is as follows:
http://ideone.com/4DQ6mi
def primes(n: Int) = {
@tailrec
def divisibleByAny(x: Int, list: List[Int]): Boolean = {
if (list.isEmpty) false else {
val h = list.head
h * h <= x && (x % h == 0 || divisibleByAny(x, list.tail))
}
}
@tailrec
def morePrimes(from: Int, prev: List[Int]): List[Int] = {
if (prev.size == n) prev else
morePrimes(from + 2, if (divisibleByAny(from, prev)) prev else prev :+ from)
}
morePrimes(3, List(2))
}
But it is slow. If I understand correctly, that is because adding to the end of an immutable list requires creating a new copy of the whole list.
I searched the documentation for a more suitable data structure and tried replacing the list with an immutable Queue, for which it is said:
Adding items to the queue always has cost O(1) ... Removing an item is on average O(1).
But it is still even slower:
http://ideone.com/v8BsuQ
def primes(n: Int) = {
@tailrec
def divisibleByAny(x: Int, list: Queue[Int]): Boolean = {
if (list.isEmpty) false else {
val (h, t) = list.dequeue
h * h <= x && (x % h == 0 || divisibleByAny(x, t))
}
}
@tailrec
def morePrimes(from: Int, prev: Queue[Int]): Queue[Int] = {
if (prev.size == n) prev else
morePrimes(from + 2, if (divisibleByAny(from, prev)) prev else prev.enqueue(from))
}
morePrimes(3, Queue(2))
}
What is going wrong or am I missing something?
P.S. I believe there are other algorithms for generating primes that are more suitable for a functional style. I think I've seen a paper on it. But for now I'm interested in this one or, more precisely, in whether a suitable data structure exists.
According to http://docs.scala-lang.org/overviews/collections/performance-characteristics.html Vectors have an amortised constant cost for appending, prepending and seeking. Indeed, using vectors instead of lists in your solution is much faster:
def primes(n: Int) = {
@tailrec
def divisibleByAny(x: Int, list: Vector[Int]): Boolean = {
if (list.isEmpty) false else {
val (h +: t) = list
h * h <= x && (x % h == 0 || divisibleByAny(x, t))
}
}
@tailrec
def morePrimes(from: Int, prev: Vector[Int]): Vector[Int] = {
if (prev.length == n) prev else
morePrimes(from + 2, if (divisibleByAny(from, prev)) prev else prev :+ from)
}
morePrimes(3, Vector(2))
}
http://ideone.com/x3k4A3
I think you have 2 main options
Use a Vector, which is better than a List for appending. It is a bitmapped trie data structure (http://en.wikipedia.org/wiki/Trie). It's "effectively" O(1) for appending (i.e. O(1) on average).
Or...possibly the answer you're not looking for
Use a mutable data structure like ListBuffer. Immutability is great to try to achieve, and immutable collections should be your go-to, but sometimes, for efficiency reasons, you may use mutable structures. The key is to make sure they do not "leak out" of your classes. If you look at the List.scala implementation, you'll see ListBuffer used a lot internally. However, it's converted back to a List just before it leaves the class. If it's good enough for the core Scala libraries, it's probably OK for you to use in exceptional cases that warrant it.
Besides using Vector, also consider using higher-order functions instead of recursion. That's also a completely valid functional style. On my machine the following implementation of divisibleByAny is about 8x faster than @Pyetras's tailrec implementation when running primes(1000000):
def divisibleByAny(x: Int, list: Vector[Int]): Boolean =
list.view.takeWhile(el => el * el <= x).exists(x % _ == 0)
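For context, here is a minimal self-contained sketch of how that higher-order divisibleByAny slots into the Vector-based generator from the earlier answer. The require guard and the parameter name known are additions for this example; the rest follows the code already shown in this thread.

```scala
import scala.annotation.tailrec

// Only primes whose square does not exceed x need to be tested.
def divisibleByAny(x: Int, known: Vector[Int]): Boolean =
  known.view.takeWhile(p => p * p <= x).exists(x % _ == 0)

// Returns the first n primes, appending each new prime to the end of a Vector.
def primes(n: Int): Vector[Int] = {
  require(n >= 1, "n must be at least 1") // added guard: n = 0 would never terminate
  @tailrec
  def morePrimes(from: Int, prev: Vector[Int]): Vector[Int] =
    if (prev.length == n) prev
    else morePrimes(from + 2, if (divisibleByAny(from, prev)) prev else prev :+ from)
  morePrimes(3, Vector(2))
}
```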

How does Scala's .min avoid the penalty of boxing and unboxing?

Vector.min is implemented as
def min[B >: A](implicit cmp: Ordering[B]): A = {
if (isEmpty)
throw new UnsupportedOperationException("empty.min")
reduceLeft((x, y) => if (cmp.lteq(x, y)) x else y)
}
and when you profile
Vector.fill(1000000)(scala.util.Random.nextLong).min
it's fast and there's no boxing or unboxing going on. If, however, you write the apparently equivalent
val cmp = implicitly[Ordering[Long]]
Vector.fill(1000000)(scala.util.Random.nextLong).reduceLeft((x, y) => if (cmp.lteq(x, y)) x else y)
it runs approximately 10 times slower (ignoring the time in Random, which otherwise dominates this, and yes, I warmed my benchmarks up...).
How is the first version avoiding the performance penalty of the boxing?
Edit: here's my profiling code:
val cmp = implicitly[Ordering[Long]]
def randomLongs = Vector.fill(1000000)(scala.util.Random.nextLong)
def timing[R](f: => R): (Long, R) = {
val startTime = System.nanoTime
val result = f
((System.nanoTime - startTime) / 1000000, result)
}
def minTiming = { val r = randomLongs; timing(r.min)._1 }
def reduceLeftTiming = { val r = randomLongs; timing(r.reduceLeft((x, y) => if (cmp.lteq(x, y)) x else y))._1 }
while(true) {
println((minTiming, reduceLeftTiming))
}
and I see times using min of around 20ms, using reduceLeft of ~200ms. I've profiled this code using YourKit; here's a screen grab of the call tree showing that min doesn't result in any boxing.
I think the first version infers java.lang.Long for B. So there still is boxing going on, but only while filling the vector, and afterwards all comparisons are between boxed objects.
In the second version, since type of cmp is given as Ordering[Long], java.lang.Longs in the vector have to be unboxed before being passed to cmp.lteq(x, y).

Find prime numbers using Scala. Help me to improve

I wrote this code to find the prime numbers less than the given number i in scala.
def findPrime(i : Int) : List[Int] = i match {
case 2 => List(2)
case _ => {
val primeList = findPrime(i-1)
if(isPrime(i, primeList)) i :: primeList else primeList
}
}
def isPrime(num : Int, prePrimes : List[Int]) : Boolean = prePrimes.forall(num % _ != 0)
But, I got a feeling the findPrime function, especially this part:
case _ => {
val primeList = findPrime(i-1)
if(isPrime(i, primeList)) i :: primeList else primeList
}
is not quite in the functional style.
I am still learning functional programming. Can anyone please help me improve this code to make it more functional.
Many thanks.
Here's a functional implementation of the Sieve of Eratosthenes, as presented in Odersky's "Functional Programming Principles in Scala" Coursera course:
// Sieving integral numbers
def sieve(s: Stream[Int]): Stream[Int] = {
s.head #:: sieve(s.tail.filter(_ % s.head != 0))
}
// All primes as a lazy sequence
val primes = sieve(Stream.from(2))
// Dumping the first five primes
print(primes.take(5).toList) // List(2, 3, 5, 7, 11)
The style looks fine to me. Although the Sieve of Eratosthenes is a very efficient way to find prime numbers, your approach works well too, since you are only testing for division against known primes. You need to watch out however--your recursive function is not tail recursive. A tail recursive function does not modify the result of the recursive call--in your example you prepend to the result of the recursive call. This means that you will have a long call stack and so findPrime will not work for large i. Here is a tail-recursive solution.
def primesUnder(n: Int): List[Int] = {
require(n >= 2)
def rec(i: Int, primes: List[Int]): List[Int] = {
if (i >= n) primes
else if (prime(i, primes)) rec(i + 1, i :: primes)
else rec(i + 1, primes)
}
rec(2, List()).reverse
}
def prime(num: Int, factors: List[Int]): Boolean = factors.forall(num % _ != 0)
This solution isn't prettier; it's more of a detail to get your solution to work for large arguments. Since the list is built up backwards to take advantage of fast prepends, it needs to be reversed at the end. As an alternative, you could use an Array, Vector or ListBuffer to append the results. With the Array, however, you would need to estimate how much memory to allocate for it. Fortunately we know that pi(n) is approximately n / ln(n), so you can choose a reasonable size. Array and ListBuffer are also mutable data types, which goes against your desire for a functional style.
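As a sketch of the Array alternative mentioned above: the buffer can be sized from an upper bound on pi(n). This example assumes the Rosser and Schoenfeld bound pi(n) < 1.25506 * n / ln(n) for n > 1; the names primeCountBound and primesUnderArray are hypothetical, and the loop is deliberately imperative.

```scala
// Upper bound on the number of primes <= n (assumed bound: 1.25506 * n / ln n, for n > 1).
def primeCountBound(n: Int): Int =
  math.ceil(1.25506 * n / math.log(n)).toInt

// Primes strictly below n, appended to a pre-sized Array in ascending order.
def primesUnderArray(n: Int): Array[Int] = {
  require(n >= 2)
  val buf = new Array[Int](primeCountBound(n))
  var count = 0
  var i = 2
  while (i < n) {
    var isPrime = true
    var j = 0
    // Trial-divide only by stored primes whose square does not exceed i.
    while (isPrime && j < count && buf(j).toLong * buf(j) <= i) {
      if (i % buf(j) == 0) isPrime = false
      j += 1
    }
    if (isPrime) { buf(count) = i; count += 1 }
    i += 1
  }
  buf.take(count) // drop the unused tail of the buffer
}
```

Because the primes are stored in ascending order, this version can also stop at the square root, which the prepend-based List version cannot do easily.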
Update: to get good performance out of the Sieve of Eratosthenes I think you'll need to store data in a native array, which also goes against your desire for style in functional programming. There might be a creative functional implementation though!
Update: oops! Missed it! This approach works well too if you only divide by primes less than the square root of the number you are testing! I missed this, and unfortunately it's not easy to adjust my solution to do this because I'm storing the primes backwards.
Update: here's a very non-functional solution that at least only checks up to the square root.
import scala.collection.mutable.ListBuffer
def primesUnder(n: Int): List[Int] = {
require(n >= 2)
val primes = ListBuffer(2)
for (i <- 3 to n) {
if (prime(i, primes.iterator)) {
primes += i
}
}
primes.toList
}
// factors must be in sorted order
def prime(num: Int, factors: Iterator[Int]): Boolean =
factors.takeWhile(_ <= math.sqrt(num).toInt) forall(num % _ != 0)
Or I could use Vectors with my original approach. Vectors are probably not the best solution because their append is not the fastest O(1), even though it's amortized O(1).
As schmmd mentions, you want it to be tail recursive, and you also want it to be lazy. Fortunately there is a perfect data-structure for this: Stream.
This is a very efficient prime calculator implemented as a Stream, with a few optimisations:
object Prime {
def is(i: Long): Boolean =
if (i == 2) true
else if ((i & 1) == 0) false // efficient div by 2
else prime(i)
def primes: Stream[Long] = 2 #:: prime3
private val prime3: Stream[Long] = {
@annotation.tailrec
def nextPrime(i: Long): Long =
if (prime(i)) i else nextPrime(i + 2) // tail
def next(i: Long): Stream[Long] =
i #:: next(nextPrime(i + 2))
3 #:: next(5)
}
// assumes not even, check evenness before calling - perf note: must pass partially applied >= method
def prime(i: Long): Boolean =
prime3 takeWhile (math.sqrt(i).>= _) forall { i % _ != 0 }
}
Prime.is is the prime check predicate, and Prime.primes returns a Stream of all prime numbers. prime3 is where the Stream is computed, using the prime predicate to check for all prime divisors less than the square root of i.
import scala.collection.mutable
/**
 * @return BitSet p such that p(x) is true iff x is prime
 */
def sieveOfEratosthenes(n: Int) = {
  val isPrime = mutable.BitSet(2 to n: _*)
  // Only primes up to sqrt(n) need to be sieved; culling starts at p*p.
  for (p <- 2 to Math.sqrt(n).toInt if isPrime(p)) {
    isPrime --= p*p to n by p
  }
  isPrime.toImmutable
}
A sieve method is your best bet for small lists of numbers (up to 10-100 million or so).
see: Sieve of Eratosthenes
Even if you want to find much larger numbers, you can use the list you generate with this method as divisors for testing numbers up to n^2, where n is the limit of your list.
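A minimal sketch of that last idea, under the assumption that the sieve is run once up to some limit and the resulting primes are reused as trial divisors for any number up to limit squared. The helper names sievePrimes and isPrimeUsing are invented for this example.

```scala
import scala.collection.mutable

// All primes <= n in ascending order, via a BitSet sieve like the one above.
def sievePrimes(n: Int): Vector[Int] = {
  val candidates = mutable.BitSet((2 to n): _*)
  for (p <- 2 to math.sqrt(n).toInt if candidates(p)) candidates --= (p * p to n by p)
  candidates.toVector
}

// Tests x for primality using precomputed primes up to `limit`; valid for x <= limit^2.
def isPrimeUsing(primes: Vector[Int], limit: Int)(x: Long): Boolean = {
  require(x <= limit.toLong * limit, "x exceeds the range covered by the sieve")
  x >= 2 && primes.iterator.takeWhile(p => p.toLong * p <= x).forall(x % _ != 0)
}

val small = sievePrimes(100)          // 25 primes
val check = isPrimeUsing(small, 100) _ // covers numbers up to 10000
```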
@mfa has mentioned using a Sieve of Eratosthenes (SoE), and @Luigi Plinge has mentioned that this should be done using functional code, so @netzwerg has posted a non-SoE version. Here I post an "almost" functional version of the SoE, using completely immutable state except for the contents of a mutable BitSet (mutable rather than immutable for performance), which I posted as an answer to another question:
object SoE {
def makeSoE_Primes(top: Int): Iterator[Int] = {
val topndx = (top - 3) / 2
val nonprms = new scala.collection.mutable.BitSet(topndx + 1)
def cullp(i: Int) = {
import scala.annotation.tailrec; val p = i + i + 3
@tailrec def cull(c: Int): Unit = if (c <= topndx) { nonprms += c; cull(c + p) }
cull((p * p - 3) >>> 1)
}
(0 to (Math.sqrt(top).toInt - 3) >>> 1).filterNot { nonprms }.foreach { cullp }
Iterator.single(2) ++ (0 to topndx).filterNot { nonprms }.map { i: Int => i + i + 3 }
}
}
How about this.
def getPrimeUnder(n: Int) = {
require(n >= 2)
val ol = 3 to n by 2 toList // oddList
def pn(ol: List[Int], pl: List[Int]): List[Int] = ol match {
case Nil => pl
case _ if pl.exists(ol.head % _ == 0) => pn(ol.tail, pl)
case _ => pn(ol.tail, ol.head :: pl)
}
pn(ol, List(2)).reverse
}
It's pretty fast for me: on my Mac, getting all primes under 100k takes around 2.5 sec.
A Scala FP approach
// returns the list of primes below `number`
def primes(number: Int): List[Int] =
  (2 until number).filter(b => isPrime(b)).toList
// checks if a number is prime (numbers below 2, including 1, are not prime)
def isPrime(number: Int): Boolean = {
  number match {
    case x if x < 2 => false
    // prime iff no divisor exists between 2 and sqrt(x), inclusive
    case x => (2 to math.sqrt(x).toInt).forall(y => x % y != 0)
  }
}
def primeNumber(range: Int): Unit = {
  val primeNumbers: IndexedSeq[Int] =
    for {
      number <- 2 to range
      // the divisor range must include sqrt(number), otherwise squares such as 4 and 9 slip through
      if !(2 to Math.sqrt(number).toInt).exists(x => number % x == 0)
    } yield number
  for (prime <- primeNumbers) println(prime)
}
object Primes {
private lazy val notDivisibleBy2: Stream[Long] = 3L #:: notDivisibleBy2.map(_ + 2)
private lazy val notDivisibleBy2Or3: Stream[Long] = notDivisibleBy2
.grouped(3)
.map(_.slice(1, 3))
.flatten
.toStream
private lazy val notDivisibleBy2Or3Or5: Stream[Long] = notDivisibleBy2Or3
.grouped(10)
.map { g =>
g.slice(1, 7) ++ g.slice(8, 10)
}
.flatten
.toStream
lazy val primes: Stream[Long] = 2L #::
notDivisibleBy2.head #::
notDivisibleBy2Or3.head #::
notDivisibleBy2Or3Or5.filter { i =>
i < 49 || primes.takeWhile(_ <= Math.sqrt(i).toLong).forall(i % _ != 0)
}
def apply(n: Long): Stream[Long] = primes.takeWhile(_ <= n)
def getPrimeUnder(n: Long): Long = Primes(n).last
}

Scala PriorityQueue on Array[Int] performance issue with complex comparison function (caching is needed)

The problem involves the Scala PriorityQueue[Array[Int]] performance on large data set. The following operations are needed: enqueue, dequeue, and filter. Currently, my implementation is as follows:
For every element of type Array[Int], there is a complex evaluation function: (I'm not sure how to write it in a more efficient way, because it excludes the position 0)
def eval_fun(a : Array[Int]) =
if(a.size < 2) 3
else {
var ret = 0
var i = 1
while(i < a.size) {
if((a(i) & 0x3) == 1) ret += 1
else if((a(i) & 0x3) == 3) ret += 3
i += 1
}
ret / a.size
}
The ordering, with its comparison function, is based on the evaluation function (reversed, descending order):
val arr_ord = new Ordering[Array[Int]] {
def compare(a : Array[Int], b : Array[Int]) = eval_fun(b) compare eval_fun(a) }
The PriorityQueue is defined as:
val pq: scala.collection.mutable.PriorityQueue[Array[Int]] = PriorityQueue()
Question:
Is there a more elegant and efficient way to write such an evaluation function? I'm thinking of using fold, but fold cannot exclude the position 0.
Is there a data structure to generate a priorityqueue with unique elements? Applying filter operation after each enqueue operation is not efficient.
Is there a cache method to reduce the evaluation computation? Since when adding a new element to the queue, every element may need to be evaluated by eval_fun again, which is not necessary if evaluated value of every element can be cached. Also, I should mention that two distinct element may have the same evaluated value.
Is there a more efficient data structure without using generic type? Because if the size of elements reaches 10,000 and the size of size reaches 1,000, the performance is terribly slow.
Thank you.
(1) If you want maximum performance here, I would stick to the while loop, even if it is not terribly elegant. Otherwise, if you use a view on Array, you can easily drop the first element before going into the fold:
a.view.drop(1).foldLeft(0)( (sum, a) => sum + ((a & 0x03) match {
case 0x01 => 1
case 0x03 => 3
case _ => 0
})) / a.size
(2) You can maintain two structures, the priority queue and a set. Both combined give you a sorted set... So you could use collection.immutable.SortedSet, but there is no mutable variant in the standard library. Do you want equality based on the priority function, or on the actual array contents? Because in the latter case, you won't get around comparing arrays element by element for each insertion, undoing the effect of caching the priority function value.
(3) Just put the calculated priority along with the array in the queue. I.e.
implicit val ord = Ordering.by[(Int, Array[Int]), Int](_._1)
val pq = new collection.mutable.PriorityQueue[(Int, Array[Int])]
pq += eval_fun(a) -> a
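Points (2) and (3) can be combined. The following is only a sketch under stated assumptions: UniquePriorityQueue is a hypothetical wrapper, uniqueness is based on array contents (tracked as Vector[Int] keys in a HashSet), and the element with the highest cached priority is dequeued first, so the ordering would need to be reversed to match the question's descending comparison.

```scala
import scala.collection.mutable

// Hypothetical wrapper: a priority queue that rejects arrays it already holds
// and evaluates the priority of each array exactly once, on insertion.
final class UniquePriorityQueue(eval: Array[Int] => Int) {
  // Arrays compare by reference, so contents are tracked as immutable Vectors.
  private val seen = mutable.HashSet.empty[Vector[Int]]
  private val queue =
    mutable.PriorityQueue.empty[(Int, Array[Int])](Ordering.by[(Int, Array[Int]), Int](_._1))

  // Returns false (and does nothing) if an array with the same contents is already queued.
  def enqueue(a: Array[Int]): Boolean = {
    val fresh = seen.add(a.toVector)
    if (fresh) queue.enqueue(eval(a) -> a)
    fresh
  }

  // Removes and returns the array with the highest cached priority.
  def dequeue(): Array[Int] = {
    val (_, a) = queue.dequeue()
    seen.remove(a.toVector)
    a
  }

  def size: Int = queue.size
}
```

The set lookup still hashes the array contents on every insertion, which is the element-by-element cost mentioned in (2); what the cached priority saves is re-running the evaluation function during heap comparisons.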
Well, you can use a tail-recursive loop (generally these are more "idiomatic"):
def eval(a: Array[Int]): Int =
if (a.size < 2) 3
else {
@annotation.tailrec
def loop(ret: Int = 0, i: Int = 1): Int =
if (i >= a.size) ret / a.size
else {
val mod3 = (a(i) & 0x3)
if (mod3 == 1) loop(ret + 1, i + 1)
else if (mod3 == 3) loop(ret + 3, i + 1)
else loop(ret, i + 1)
}
loop()
}
Then you can use that to initialise a cached priority value:
case class PriorityArray(a: Array[Int]) {
lazy val priority = if (a.size < 2) 3 else {
@annotation.tailrec
def loop(ret: Int = 0, i: Int = 1): Int =
if (i >= a.size) ret / a.size
else {
val mod3 = (a(i) & 0x3)
if (mod3 == 2) loop(ret, i + 1)
else loop(ret + mod3, i + 1)
}
loop()
}
}
You may note also that I removed a redundant & op and have only the single conditional (for when it equals 2, rather than two checks for 1 && 3) – these should have some minimal effect.
There is not a huge difference from 0__'s proposal that just came through.
My answers:
If evaluation is critical, keep it as it is. You might get better performance with recursion (not sure why, but it happens), but you'll certainly get worse performance with pretty much any other approach.
No, there isn't, but you can come pretty close to it just modifying the dequeue operation:
def distinctDequeue[T](q: PriorityQueue[T]): T = {
val result = q.dequeue
// skip duplicates of the element just removed; stop once the queue is empty
while (q.nonEmpty && q.head == result) q.dequeue()
result
}
Otherwise, you'd have to keep a second data structure just to keep track of whether an element has been added or not. Either way, that equals sign is pretty heavy, but I have a suggestion to make it faster in the next item.
Note, however, that this requires that ties on the cost function get resolved in some other way.
Like 0__ suggested, put the cost on the priority queue. But you can also keep a cache on the function if that would be helpful. I'd try something like this:
import scala.collection.mutable.WrappedArray
val evalMap = scala.collection.mutable.HashMap[WrappedArray[Int], Int]()
def eval_fun(a : Array[Int]) =
if(a.size < 2) 3
else evalMap.getOrElseUpdate(a, {
var ret = 0
var i = 1
while(i < a.size) {
if((a(i) & 0x3) == 1) ret += 1
else if((a(i) & 0x3) == 3) ret += 3
i += 1
}
ret / a.size
})
import scala.math.Ordering.Implicits._
val pq = new collection.mutable.PriorityQueue[(Int, WrappedArray[Int])]
pq += eval_fun(a) -> (a : WrappedArray[Int])
Note that I did not create a special Ordering -- I'm using the standard Ordering so that the WrappedArray will break the ties. There's little cost to wrap the Array, and you get it back with .array, but, on the other hand, you'll get the following:
Ties will be broken by comparing the arrays themselves. If there aren't many ties in the cost, this should be good enough. If there are, add something else to the tuple to help break ties without comparing the arrays.
That means all equal elements will be kept together, which will enable you to dequeue all of them at the same time, giving the impression of having kept only one.
And that equals will actually work, because WrappedArray compare like Scala sequences do.
I don't understand what you mean by that fourth point.

Resources