Compute median of up to 5 in Scala - algorithm

So, while answering another question, I needed to compute the median of 5 values. There's a similar question for another language, but I want a Scala algorithm for it, and I'm not sure I'm happy with mine.

Here's an immutable Scala version that has the minimum number of compares (6) and doesn't look too ugly:
def med5(five: (Int, Int, Int, Int, Int)) = {
  // Return a sorted tuple (one compare)
  def order(a: Int, b: Int) = if (a < b) (a, b) else (b, a)
  // Given two self-sorted pairs, pick the 2nd of 4 (two compares)
  def pairs(p: (Int, Int), q: (Int, Int)) = {
    (if (p._1 < q._1) order(p._2, q._1) else order(q._2, p._1))._1
  }
  // Strategy is to throw away smallest or second smallest, leaving two self-sorted pairs
  val ltwo = order(five._1, five._2)
  val rtwo = order(five._4, five._5)
  if (ltwo._1 < rtwo._1) pairs(rtwo, order(ltwo._2, five._3))
  else pairs(ltwo, order(rtwo._2, five._3))
}
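A quick way to gain confidence in a hand-built comparison scheme like this is to run it over every permutation of 1..5 while counting comparisons. The sketch below re-implements the same strategy with every `<` routed through a counting helper; `Med5Check` and `lt` are illustrative names, not part of the original, and distinct inputs are assumed. It should report that the median is always correct and that exactly 6 comparisons are used.

```scala
object Med5Check {
  // Number of comparisons performed by the most recent call.
  var compares = 0
  private def lt(a: Int, b: Int): Boolean = { compares += 1; a < b }

  // Same strategy as med5, with each comparison counted.
  def med5(five: (Int, Int, Int, Int, Int)): Int = {
    def order(a: Int, b: Int) = if (lt(a, b)) (a, b) else (b, a)
    def pairs(p: (Int, Int), q: (Int, Int)) =
      (if (lt(p._1, q._1)) order(p._2, q._1) else order(q._2, p._1))._1
    val ltwo = order(five._1, five._2)
    val rtwo = order(five._4, five._5)
    if (lt(ltwo._1, rtwo._1)) pairs(rtwo, order(ltwo._2, five._3))
    else pairs(ltwo, order(rtwo._2, five._3))
  }

  // Checks all 120 permutations of 1..5: (every median correct?, max comparisons).
  def check(): (Boolean, Int) = {
    var maxCompares = 0
    val ok = (1 to 5).permutations.forall { p =>
      compares = 0
      val m = med5((p(0), p(1), p(2), p(3), p(4)))
      maxCompares = math.max(maxCompares, compares)
      m == 3
    }
    (ok, maxCompares)
  }
}
```

The count is the same for every input: two `order` calls for the outer pairs, one comparison to pick which smallest element to discard, one more `order`, and two comparisons inside `pairs`.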
Edit: As requested by Daniel, here's a modification that works with all sizes from 1 to 5 and uses arrays, so it should be efficient. I can't make it pretty, so efficiency is the next best thing: more than 200M medians/sec with a pre-allocated array of 5, which is slightly more than 100x faster than Daniel's version and 8x faster than my immutable version above (for length 5). Note that it reorders the input array in place, and for lengths 2 and 4 it returns the lower of the two middle values.
def med5b(five: Array[Int]): Int = {
  def order2(a: Array[Int], i: Int, j: Int) = {
    if (a(i) > a(j)) { val t = a(i); a(i) = a(j); a(j) = t }
  }
  def pairs(a: Array[Int], i: Int, j: Int, k: Int, l: Int) = {
    if (a(i) < a(k)) { order2(a, j, k); a(j) }
    else { order2(a, i, l); a(i) }
  }
  if (five.length < 2) return five(0)
  order2(five, 0, 1)
  if (five.length < 4) return (
    if (five.length == 2 || five(2) < five(0)) five(0)
    else if (five(2) > five(1)) five(1)
    else five(2)
  )
  order2(five, 2, 3)
  if (five.length < 5) pairs(five, 0, 1, 2, 3)
  else if (five(0) < five(2)) { order2(five, 1, 4); pairs(five, 1, 4, 2, 3) }
  else { order2(five, 3, 4); pairs(five, 0, 1, 3, 4) }
}

Jeez, way to over-think it, guys.
def med5(a: Int, b: Int, c: Int, d: Int, e: Int) =
  List(a, b, c, d, e).sortWith(_ > _)(2)

As suggested, here's my own algorithm:
def medianUpTo5(arr: Array[Double]): Double = {
  def oneAndOrderedPair(a: Double, smaller: Double, bigger: Double): Double =
    if (bigger < a) bigger
    else if (a < smaller) smaller else a
  def partialOrder(a: Double, b: Double, c: Double, d: Double) = {
    val (s1, b1) = if (a < b) (a, b) else (b, a)
    val (s2, b2) = if (c < d) (c, d) else (d, c)
    (s1, b1, s2, b2)
  }
  def medianOf4(a: Double, b: Double, c: Double, d: Double): Double = {
    val (s1, b1, s2, b2) = partialOrder(a, b, c, d)
    if (b1 < b2) oneAndOrderedPair(s2, s1, b1)
    else oneAndOrderedPair(s1, s2, b2)
  }
  arr match {
    case Array(a) => a
    case Array(a, b) => a min b
    case Array(a, b, c) =>
      if (a < b) oneAndOrderedPair(c, a, b)
      else oneAndOrderedPair(c, b, a)
    case Array(a, b, c, d) => medianOf4(a, b, c, d)
    case Array(a, b, c, d, e) =>
      val (s1, b1, s2, b2) = partialOrder(a, b, c, d)
      if (s1 < s2) medianOf4(e, b1, s2, b2)
      else medianOf4(e, b2, s1, b1)
  }
}
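Both `med5b` and `medianUpTo5` above return the lower of the two middle elements for even lengths (for example, `a min b` for two elements). When testing hand-tuned versions like these, a sort-based reference with the same convention is handy. A minimal sketch (`lowerMedian` is a hypothetical helper, not from the original):

```scala
object LowerMedian {
  // Lower median: for even lengths, the smaller of the two middle elements.
  def lowerMedian[T](xs: Seq[T])(implicit ord: Ordering[T]): T = {
    require(xs.nonEmpty, "median of an empty sequence is undefined")
    val sorted = xs.sorted(ord)
    sorted((xs.length - 1) / 2)
  }
}
```

It costs a full sort, so it is only meant as a correctness oracle, not a replacement for the compare-minimal versions.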

Related

Scala how to define an ordering for Rationals

I have to implement compareRationals as something like
(a, b) => {
  // the body goes here
}
To compare two fractions, I transform them so that they both have the same denominator and then order them by their numerators. To get the same denominator, I think I need the least common denominator. My code currently works for all the println statements except the one that sorts the rationals. I really need help defining compareRationals so that println(insertionSort2(rationals)) prints the equivalent of List(fourth, third, half).
object Main {
  def insertionSort2[A](xs: List[A])(implicit ord: Ordering[A]): List[A] = {
    def insert2(y: A, ys: List[A]): List[A] =
      ys match {
        case List() => y :: List()
        case z :: zs =>
          if (ord.lt(y, z)) y :: z :: zs
          else z :: insert2(y, zs)
      }
    xs match {
      case List() => List()
      case y :: ys => insert2(y, insertionSort2(ys))
    }
  }
  class Rational(x: Int, y: Int) {
    private def gcd(a: Int, b: Int): Int = if (b == 0) a else gcd(b, a % b)
    private val g = gcd(x, y)
    lazy val numer: Int = x / g
    lazy val denom: Int = y / g
  }
  val compareRationals: (Rational, Rational) => Int =
    (a, b) => ??? // the body goes here
  implicit val rationalOrder: Ordering[Rational] =
    new Ordering[Rational] {
      def compare(x: Rational, y: Rational): Int = compareRationals(x, y)
    }
  def main(args: Array[String]): Unit = {
    val half = new Rational(1, 2)
    val third = new Rational(1, 3)
    val fourth = new Rational(1, 4)
    val rationals = List(third, half, fourth)
    println(insertionSort2(List(4, 2, 9, 5, 8))(Ordering.Int))
    println(insertionSort2(List(4, 2, 9, 5, 8)))
    println(insertionSort2(rationals)) // should be List(fourth, third, half)
  }
}
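I think this is all you need. There is no need for a common denominator: when both denominators are positive, a/b < c/d exactly when a*d < c*b, so the sign of the cross-product difference gives the order.
val compareRationals: (Rational, Rational) => Int =
  (x, y) => x.numer * y.denom - y.numer * x.denom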
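The subtraction above can overflow `Int` for large numerators and denominators, and it relies on denominators being positive (with the original `gcd`, `new Rational(1, -2)` keeps a negative denominator). A hedged, self-contained sketch that addresses both: it adds a sign normalization and a `toString` (neither is in the original class), compares the cross products as `Long`, and uses the standard `sorted` in place of the question's `insertionSort2` for brevity. `RationalSketch` and `sortedDemo` are hypothetical names.

```scala
object RationalSketch {
  class Rational(x: Int, y: Int) {
    require(y != 0, "denominator must be non-zero")
    private def gcd(a: Int, b: Int): Int = if (b == 0) a else gcd(b, a % b)
    // Normalize so that the denominator is always positive.
    private val g = math.abs(gcd(x, y)) * math.signum(y)
    val numer: Int = x / g
    val denom: Int = y / g
    override def toString: String = s"$numer/$denom"
  }

  // a/b vs c/d  <=>  a*d vs c*b when b, d > 0; Long products cannot overflow here.
  val compareRationals: (Rational, Rational) => Int =
    (x, y) => java.lang.Long.compare(x.numer.toLong * y.denom, y.numer.toLong * x.denom)

  implicit val rationalOrder: Ordering[Rational] = new Ordering[Rational] {
    def compare(x: Rational, y: Rational): Int = compareRationals(x, y)
  }

  def sortedDemo(): List[String] = {
    val rationals = List(new Rational(1, 3), new Rational(1, 2), new Rational(1, 4))
    rationals.sorted.map(_.toString)
  }
}
```

`RationalSketch.sortedDemo()` returns `List("1/4", "1/3", "1/2")`, i.e. fourth, third, half.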

How to divide a set into two sets such that the difference of the average is minimum?

As I understand it, this is related to the partition problem.
But I would like to ask about a slightly different problem, in which I care not about the sum but about the average. In this case, two quantities (the sum and the number of items) have to be optimized at the same time. It seems to be a harder problem, and I cannot find any solutions online.
Are there any solutions for this variant? Or how does it relate to the partition problem?
Example:
input X = [1,1,1,1,1,6]
output based on sum: A = [1,1,1,1,1], B=[6]
output based on average: A = [1], B=[1,1,1,1,6]
On some inputs, a modification of the dynamic program for the usual partition problem will give a speedup. We have to classify each partial solution by its count and sum instead of just sum, which slows things down a bit. Python 3 below (note that the use of dictionaries implicitly collapses functionally identical partial solutions):
def children(ab, x):
    a, b = ab
    yield a + [x], b
    yield a, b + [x]
def proper(ab):
    a, b = ab
    return a and b
def avg(lst):
    return sum(lst) / len(lst)
def abs_diff_avg(ab):
    a, b = ab
    return abs(avg(a) - avg(b))
def min_abs_diff_avg(lst):
    solutions = {(0, 0): ([], [])}
    for x in lst:
        solutions = {
            (sum(a), len(a)): (a, b)
            for ab in solutions.values()
            for (a, b) in children(ab, x)
        }
    return min(filter(proper, solutions.values()), key=abs_diff_avg)
print(min_abs_diff_avg([1, 1, 1, 1, 1, 6]))
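For readers who prefer Scala, here is a minimal sketch of the same (sum, count)-keyed dynamic program. It is a translation for illustration, not the answerer's code, and `MinAvgDiff.split` is a hypothetical name. As in the Python version, only one representative partition is kept per (sum of A, size of A): since the total sum and size are fixed, that key also determines the sum and size of B, so partitions with the same key have the same averages.

```scala
object MinAvgDiff {
  // Returns (A, B) with both parts non-empty, minimizing |avg(A) - avg(B)|.
  def split(xs: List[Int]): (List[Int], List[Int]) = {
    require(xs.size >= 2, "need at least two elements to form two non-empty parts")
    val start = Map((0L, 0) -> (List.empty[Int], List.empty[Int]))
    val states = xs.foldLeft(start) { (acc, x) =>
      acc.values
        .flatMap { case (a, b) => Seq((x :: a, b), (a, x :: b)) } // put x in A or in B
        .map { case (a, b) => (a.map(_.toLong).sum, a.size) -> (a, b) } // collapse by key
        .toMap
    }
    def avg(l: List[Int]): Double = l.map(_.toDouble).sum / l.size
    states.values
      .filter { case (a, b) => a.nonEmpty && b.nonEmpty }
      .minBy { case (a, b) => math.abs(avg(a) - avg(b)) }
  }
}
```

For the example input, the best difference is 1, reached by splitting off a single 1 (or its mirror image, four 1s plus the 6 against a single 1).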
Let $S_i$ be the sum of a subset of $v$ of size $i$, let $S$ be the total sum of $v$, and let $n$ be the length of $v$.
The error to minimize is
$$\mathrm{err}_i = \left|\frac{S_i}{i} - \frac{S - S_i}{n - i}\right| = \left|\frac{n S_i - i S}{i\,(n - i)}\right|$$
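As a quick sanity check of the error formula, take the example input $v = [1, 1, 1, 1, 1, 6]$ (so $n = 6$, $S = 11$) and the subset $\{1\}$ of size $i = 1$ (so $S_1 = 1$):

$$\mathrm{err}_1 = \left|\frac{6 \cdot 1 - 1 \cdot 11}{1 \cdot (6 - 1)}\right| = \frac{5}{5} = 1,$$

which matches the direct computation $\left|1 - \frac{1 + 1 + 1 + 1 + 6}{5}\right| = |1 - 2| = 1$.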
The algorithm below does the following:
for each tuple size i in 1, ..., n/2:
- for each tuple t_{i-1} of size i-1
- generate every possible tuple of size i from t_{i-1} by adjoining one element of v
- track the best tuple with respect to err_i
The only pruning I found is:
for two tuples of size i with the same sum, keep the one whose last element has the smallest index.
E.g., given tuples A and B (where X marks an element taken from v):
A: [X,....,X....]
B: [.,X,.....,X..]
keep A, because its right-most element has the smaller index
(the idea being that at size 3, A will offer the same candidates as B plus some more).
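A minimal Scala sketch of one extension step with this pruning rule, as a hedged translation of the idea (the names `TupleStep`, `Partial` and `next` are hypothetical): every partial tuple is extended by one later index, and among candidates with the same sum only the one with the smallest last index survives.

```scala
object TupleStep {
  // A partial subset: chosen indices (ascending), their sum, and the last index used.
  final case class Partial(indices: Vector[Int], sum: Long, last: Int)

  // Extends every tuple by one later index; for equal sums, keeps the candidate
  // whose last index is smallest (ties keep the first one found).
  def next(v: IndexedSeq[Int], tuples: Iterable[Partial]): Map[Long, Partial] = {
    val candidates = for {
      t <- tuples.iterator
      l <- (t.last + 1) until v.length
    } yield Partial(t.indices :+ l, t.sum + v(l), l)
    candidates.foldLeft(Map.empty[Long, Partial]) { (best, c) =>
      best.get(c.sum) match {
        case Some(old) if old.last <= c.last => best
        case _ => best.updated(c.sum, c)
      }
    }
  }
}
```

Starting from `Partial(Vector.empty, 0L, -1)` and applying `next` repeatedly yields the candidates of size 1, 2, and so on, each of which can then be scored with err_i.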
function generateTuples (v, tuples) {
  const nextTuples = new Map()
  for (const [, t] of tuples) {
    for (let l = t.l + 1; l < v.length; ++l) {
      const s = t.s + v[l]
      // for equal sums, keep the tuple whose last index is smallest
      if (!nextTuples.has(s) || nextTuples.get(s).l > l) {
        const nextTuple = { v: t.v.concat(l), s, l }
        nextTuples.set(s, nextTuple)
      }
    }
  }
  return nextTuples
}
function processV (v) {
  const fErr = (() => {
    const n = v.length
    const S = v.reduce((s, x) => s + x, 0)
    return ({ s: S_i, v }) => {
      const i = v.length
      return Math.abs((n * S_i - i * S) / (i * (n - i)))
    }
  })()
  let tuples = new Map([[0, { v: [], s: 0, l: -1 }]])
  let best = null
  let err = Infinity // any finite error beats this
  for (let i = 0; i < Math.ceil(v.length / 2); ++i) {
    const nextTuples = generateTuples(v, tuples)
    for (const [, t] of nextTuples) {
      if (fErr(t) <= err) {
        best = t
        err = fErr(t)
      }
    }
    tuples = nextTuples
  }
  const s1Indices = new Set(best.v)
  return {
    sol: v.reduce(([v1, v2], x, i) => {
      (s1Indices.has(i) ? v1 : v2).push(x)
      return [v1, v2]
    }, [[], []]),
    err
  }
}
console.log('best: ', processV([1, 1, 1, 1, 1, 6]))
console.log('best: ', processV([1, 2, 3, 4, 5]))
console.log('best: ', processV([1, 3, 5, 7, 7, 8]))

Tail recursive solution in Scala for Linked-List chaining

I wanted to write a tail-recursive solution for the following problem on LeetCode:
You are given two non-empty linked lists representing two non-negative integers. The digits are stored in reverse order and each of their nodes contains a single digit. Add the two numbers and return it as a linked list.
You may assume the two numbers do not contain any leading zero, except the number 0 itself.
Example:
*Input: (2 -> 4 -> 3) + (5 -> 6 -> 4)*
*Output: 7 -> 0 -> 8*
*Explanation: 342 + 465 = 807.*
Link to the problem on Leetcode
I was not able to figure out a way to make the recursive call in the last line a tail call.
What I am trying to achieve here is to call the add function recursively: it adds the heads of the two lists plus a carry and returns a node, and the returned node is chained to the node in the calling stack frame.
I am pretty new to Scala, so I am guessing I may have missed some useful constructs.
/**
 * Definition for singly-linked list.
 * class ListNode(_x: Int = 0, _next: ListNode = null) {
 *   var next: ListNode = _next
 *   var x: Int = _x
 * }
 */
import scala.annotation.tailrec
object Solution {
  def addTwoNumbers(l1: ListNode, l2: ListNode): ListNode = {
    add(l1, l2, 0)
  }
  // @tailrec would be rejected here: the recursive call is an argument to new ListNode, not a tail call
  def add(l1: ListNode, l2: ListNode, carry: Int): ListNode = {
    var sum = 0;
    sum = (if (l1 != null) l1.x else 0) + (if (l2 != null) l2.x else 0) + carry;
    if (l1 != null || l2 != null || sum > 0)
      new ListNode(sum % 10, add(if (l1 != null) l1.next else null, if (l2 != null) l2.next else null, sum / 10))
    else null;
  }
}
Your code has a couple of problems, most of which come down to it not being idiomatic Scala.
Things like var and null are uncommon in Scala; usually, you would use a tail-recursive algorithm to avoid that kind of thing.
Finally, remember that a tail-recursive algorithm requires that the last expression is either a plain value or a recursive call. To achieve that, you usually keep track of the remaining work as well as an accumulator.
Here is a possible solution:
type Digit = Int // Refined [0..9]
type Number = List[Digit] // Refined NonEmpty.
def sum(n1: Number, n2: Number): Number = {
  def aux(d1: Digit, d2: Digit, carry: Digit): (Digit, Digit) = {
    val tmp = d1 + d2 + carry
    val d = tmp % 10
    val c = tmp / 10
    d -> c
  }
  @annotation.tailrec
  def loop(r1: Number, r2: Number, acc: Number, carry: Digit): Number =
    (r1, r2) match {
      case (d1 :: tail1, d2 :: tail2) =>
        val (d, c) = aux(d1, d2, carry)
        loop(r1 = tail1, r2 = tail2, acc = d :: acc, carry = c)
      case (Nil, d2 :: tail2) =>
        val (d, c) = aux(d1 = 0, d2 = d2, carry = carry)
        loop(r1 = Nil, r2 = tail2, acc = d :: acc, carry = c)
      case (d1 :: tail1, Nil) =>
        val (d, c) = aux(d1 = d1, d2 = 0, carry = carry)
        loop(r1 = tail1, r2 = Nil, acc = d :: acc, carry = c)
      case (Nil, Nil) =>
        // a leftover carry becomes one extra digit (e.g. 5 + 5 = 10)
        if (carry > 0) carry :: acc else acc
    }
  loop(r1 = n1, r2 = n2, acc = List.empty, carry = 0).reverse
}
Now, this kind of recursion tends to be very verbose.
Usually, the standard library provides ways to express the same algorithm more concisely:
// This solution does not require the numbers to be reversed beforehand, and the output is also in the usual (most-significant-first) order.
def sum(n1: Number, n2: Number): Number = {
  val (result, carry) = n1.reverseIterator.zipAll(n2.reverseIterator, 0, 0).foldLeft(List.empty[Digit] -> 0) {
    case ((acc, carry), (d1, d2)) =>
      val tmp = d1 + d2 + carry
      val d = tmp % 10
      val c = tmp / 10
      (d :: acc) -> c
  }
  if (carry > 0) carry :: result else result
}
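To answer the original question while keeping LeetCode's mutable `ListNode`, one option is to pass a "write pointer" (the last node built so far) down the recursion, so that the recursive call is the last thing the function does. This is a hedged sketch: the `ListNode` class below is a local stand-in that mirrors the commented definition in the question, and `fromList`/`toList` are hypothetical helpers for the example.

```scala
object TailRecAdd {
  // Stand-in for LeetCode's ListNode (same shape as the commented definition).
  class ListNode(_x: Int = 0, _next: ListNode = null) {
    var next: ListNode = _next
    var x: Int = _x
  }

  def addTwoNumbers(l1: ListNode, l2: ListNode): ListNode = {
    val dummy = new ListNode() // sentinel; the result starts at dummy.next
    @annotation.tailrec
    def loop(a: ListNode, b: ListNode, carry: Int, tail: ListNode): Unit =
      if (a != null || b != null || carry > 0) {
        val sum = (if (a != null) a.x else 0) + (if (b != null) b.x else 0) + carry
        tail.next = new ListNode(sum % 10) // append the digit, then recurse (tail call)
        loop(if (a != null) a.next else null, if (b != null) b.next else null, sum / 10, tail.next)
      }
    loop(l1, l2, 0, dummy)
    dummy.next
  }

  // Helpers for the example: List(2, 4, 3) <-> 2 -> 4 -> 3
  def fromList(ds: List[Int]): ListNode =
    ds.foldRight(null: ListNode)((d, next) => new ListNode(d, next))
  def toList(n: ListNode): List[Int] =
    Iterator.iterate(n)(_.next).takeWhile(_ != null).map(_.x).toList
}
```

With the question's example, `toList(addTwoNumbers(fromList(List(2, 4, 3)), fromList(List(5, 6, 4))))` gives `List(7, 0, 8)`. The mutation is confined to nodes created inside `addTwoNumbers`.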
Scala is less popular on LeetCode, but this solution (which is not the best) would be accepted by LeetCode's online judge:
import scala.collection.mutable._
object Solution {
  def addTwoNumbers(listA: ListNode, listB: ListNode): ListNode = {
    var tempBufferA: ListBuffer[Int] = ListBuffer.empty
    var tempBufferB: ListBuffer[Int] = ListBuffer.empty
    tempBufferA.clear()
    tempBufferB.clear()
    def listTraversalA(listA: ListNode): ListBuffer[Int] = {
      if (listA == null) {
        return tempBufferA
      } else {
        tempBufferA += listA.x
        listTraversalA(listA.next)
      }
    }
    def listTraversalB(listB: ListNode): ListBuffer[Int] = {
      if (listB == null) {
        return tempBufferB
      } else {
        tempBufferB += listB.x
        listTraversalB(listB.next)
      }
    }
    val resultA: ListBuffer[Int] = listTraversalA(listA)
    val resultB: ListBuffer[Int] = listTraversalB(listB)
    val resultSum: BigInt = BigInt(resultA.reverse.mkString) + BigInt(resultB.reverse.mkString)
    var listNodeResult: ListBuffer[ListNode] = ListBuffer.empty
    val resultList = resultSum.toString.toList
    var lastListNode: ListNode = null
    for (i <- 0 until resultList.size) {
      if (i == 0) {
        lastListNode = new ListNode(resultList(i).toString.toInt)
        listNodeResult += lastListNode
      } else {
        lastListNode = new ListNode(resultList(i).toString.toInt, lastListNode)
        listNodeResult += lastListNode
      }
    }
    return listNodeResult.reverse(0)
  }
}
References
For additional details, you can see the Discussion Board. There are plenty of accepted solutions, explanations, efficient algorithms with a variety of languages, and time/space complexity analysis in there.

Integer in an interval with maximized number of trailing zero bits

Sought is an efficient algorithm that finds the unique integer in an interval [a, b] which has the maximum number of trailing zeros in its binary representation (a and b are integers > 0):
def bruteForce(a: Int, b: Int): Int =
  (a to b).maxBy(Integer.numberOfTrailingZeros(_))
def binSplit(a: Int, b: Int): Int = {
  require(a > 0 && a <= b)
  val res = ???
  assert(res == bruteForce(a, b))
  res
}
Here are some examples:
bruteForce( 5, 7) == 6 // binary 110 (1 trailing zero)
bruteForce( 1, 255) == 128 // binary 10000000
bruteForce(129, 255) == 192 // binary 11000000
etc.
This one finds the number of zeros:
// Requires a>0
def mtz(a: Int, b: Int, mask: Int = 0xFFFFFFFE, n: Int = 0): Int = {
  if (a > (b & mask)) n
  else mtz(a, b, mask<<1, n+1)
}
This one returns the number with those zeros:
// Requires a > 0
def nmtz(a: Int, b: Int, mask: Int = 0xFFFFFFFE): Int = {
  if (a > (b & mask)) b & (mask>>1)
  else nmtz(a, b, mask<<1)
}
I doubt the log(log(n)) solution has a small enough constant term to beat this. (But you could do binary search on the number of zeros to get log(log(n)).)
I decided to take Rex's challenge and produce something faster. :-)
// requires a > 0
def mtz2(a: Int, b: Int, mask: Int = 0xffff0000, shift: Int = 8, n: Int = 16): Int = {
  if (shift == 0) if (a > (b & mask)) n - 1 else n
  else if (a > (b & mask)) mtz2(a, b, mask >> shift, shift / 2, n - shift)
  else mtz2(a, b, mask << shift, shift / 2, n + shift)
}
Benchmarked with the following (the helper f that drives the benchmark is not shown):
import System.{currentTimeMillis => now}
def time[T](f: => T): T = {
  val start = now
  try { f } finally { println("Elapsed: " + (now - start)/1000.0 + " s") }
}
val range = 1 to 200
time(f((a, b) => mtz(a, b)))
time(f((a, b) => mtz2(a, b)))
First see if there is a power of two that lies within your interval. If there is at least one, the largest one wins.
Otherwise, choose the largest power of two that is less than your minimum bound.
Does 1100000...0 lie in your bound? If yes, you've won. If it's still less than your minimum bound, try 1110000...0; otherwise, if it's greater than your maximum bound, try 1010000...0.
And so forth, until you win.
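The steps above can be sketched as a single loop over the bits from most to least significant: start from an empty prefix, try setting the next bit, return the candidate if it lies in [a, b], keep the bit if the candidate is still below a, and drop it if the candidate exceeds b. With an empty prefix, the first hit is exactly the "largest power of two in the interval" check. `GreedyTrailingZeros.maxTrailingZeros` is a hypothetical name, and the sketch assumes 0 < a <= b as in the question.

```scala
object GreedyTrailingZeros {
  def maxTrailingZeros(a: Int, b: Int): Int = {
    require(a > 0 && a <= b)
    var prefix = 0 // bits decided so far; always below a
    var bit = 30   // highest bit a positive Int can have
    while (bit >= 0) {
      val candidate = prefix | (1 << bit)
      if (candidate >= a && candidate <= b) return candidate // first hit has the most trailing zeros
      if (candidate < a) prefix = candidate                  // still too small: keep this bit
      bit -= 1                                               // too large: drop it and try a lower bit
    }
    throw new IllegalStateException("unreachable when 0 < a <= b")
  }
}
```

Each candidate has exactly `bit` trailing zeros, and every number in [a, b] lies strictly between `prefix` and `prefix + 2^(bit+1)`, so the first candidate found cannot be beaten. The loop runs at most 31 iterations.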
As a conclusion, here is my variant of Rex's answer, which gives both the center value and an 'extent': the smallest power-of-two distance from the center that covers a in one direction and b in the other.
import scala.annotation.tailrec
@tailrec def binSplit(a: Int, b: Int, mask: Int = 0xFFFFFFFF): (Int, Int) = {
  val mask2 = mask << 1
  if (a > (b & mask2)) (b & mask, -mask)
  else binSplit(a, b, mask2)
}
def test(): Unit = {
  val Seq(r1, r2) = Seq.fill(2)(util.Random.nextInt(0x3FFFFFFF) + 1)
  val (a, b) = if (r1 <= r2) (r1, r2) else (r2, r1)
  val (center, extent) = binSplit(a, b)
  assert((center >= a) && (center <= b) && (center - extent) <= a &&
    (center - extent) >= 0 && (center + extent) > b, (a, b, center, extent))
}
for (i <- 0 to 100000) { test() }

Simplest way to get the top n elements of a Scala Iterable

Is there a simple and efficient solution to determine the top n elements of a Scala Iterable? I mean something like
iter.toList.sortBy(_.myAttr).take(2)
but without having to sort all elements when only the top 2 are of interest. Ideally I'm looking for something like
iter.top(2, _.myAttr)
see also: Solution for the top element using an Ordering: In Scala, how to use Ordering[T] with List.min or List.max and keep code readable
Update:
Thank you all for your solutions. In the end, I took the original solution of user unknown and adapted it to use Iterable and the pimp-my-library pattern:
import scala.language.implicitConversions
import scala.language.reflectiveCalls
implicit def iterExt[A](iter: Iterable[A]) = new {
  def top[B](n: Int, f: A => B)(implicit ord: Ordering[B]): List[A] = {
    def updateSofar (sofar: List [A], el: A): List [A] = {
      //println (el + " - " + sofar)
      if (ord.compare(f(el), f(sofar.head)) > 0)
        (el :: sofar.tail).sortBy (f)
      else sofar
    }
    val (sofar, rest) = iter.splitAt(n)
    (sofar.toList.sortBy (f) /: rest) (updateSofar (_, _)).reverse
  }
}
case class A(s: String, i: Int)
val li = List (4, 3, 6, 7, 1, 2, 9, 5).map(i => A(i.toString(), i))
println(li.top(3, _.i))
My solution (bound to Int, but it should be easy to change it to use Ordered; give me a few minutes):
def top (n: Int, li: List [Int]) : List[Int] = {
  def updateSofar (sofar: List [Int], el: Int) : List [Int] = {
    // println (el + " - " + sofar)
    if (el < sofar.head)
      (el :: sofar.tail).sortWith (_ > _)
    else sofar
  }
  /* better readable:
  val sofar = li.take (n).sortWith (_ > _)
  val rest = li.drop (n)
  (sofar /: rest) (updateSofar (_, _)) */
  (li.take (n). sortWith (_ > _) /: li.drop (n)) (updateSofar (_, _))
}
usage:
val li = List (4, 3, 6, 7, 1, 2, 9, 5)
top (2, li)
For the list above, take the first 2 elements (4, 3) as the starting top ten (here: top two).
Sort them so that the first element is the bigger one (if any).
Then iterate through the rest of the list (li.drop(n)) and compare each element with the maximum of the list of minimums; replace it if necessary, and sort again. (With `<` as written, this keeps the n smallest elements.)
Improvements:
Throw away Int, and use ordered.
Throw away (_ > _) and use a user-Ordering to allow BottomTen. (Harder: pick the middle 10 :) )
Throw away List, and use Iterable instead
update (abstraction):
def extremeN [T](n: Int, li: List [T])
  (comp1: ((T, T) => Boolean), comp2: ((T, T) => Boolean)):
    List[T] = {
  def updateSofar (sofar: List [T], el: T) : List [T] =
    if (comp1 (el, sofar.head))
      (el :: sofar.tail).sortWith (comp2 (_, _))
    else sofar
  (li.take (n) .sortWith (comp2 (_, _)) /: li.drop (n)) (updateSofar (_, _))
}
/* still bound to Int:
def top (n: Int, li: List [Int]) : List[Int] = {
  extremeN (n, li) ((_ < _), (_ > _))
}
def bottom (n: Int, li: List [Int]) : List[Int] = {
  extremeN (n, li) ((_ > _), (_ < _))
}
*/
def top [T] (n: Int, li: List [T])
  (implicit ord: Ordering[T]): Iterable[T] = {
  extremeN (n, li) (ord.lt (_, _), ord.gt (_, _))
}
def bottom [T] (n: Int, li: List [T])
  (implicit ord: Ordering[T]): Iterable[T] = {
  extremeN (n, li) (ord.gt (_, _), ord.lt (_, _))
}
top (3, li)
bottom (3, li)
val sl = List ("Haus", "Garten", "Boot", "Sumpf", "X", "y", "xkcd", "x11")
bottom (2, sl)
To replace List with Iterable seems to be a bit harder.
As Daniel C. Sobral pointed out in the comments, a large n in topN can lead to a lot of sorting work, so it can be useful to do a manual insertion sort instead of repeatedly sorting the whole list of top-n elements:
def extremeN [T](n: Int, li: List [T])
  (comp1: ((T, T) => Boolean), comp2: ((T, T) => Boolean)):
    List[T] = {
  def sortedIns (el: T, list: List[T]): List[T] =
    if (list.isEmpty) List (el) else
    if (comp2 (el, list.head)) el :: list else
      list.head :: sortedIns (el, list.tail)
  def updateSofar (sofar: List [T], el: T) : List [T] =
    if (comp1 (el, sofar.head))
      sortedIns (el, sofar.tail)
    else sofar
  (li.take (n) .sortWith (comp2 (_, _)) /: li.drop (n)) (updateSofar (_, _))
}
The top/bottom methods and their usage are as above. For small groups of top/bottom elements, the sorting is rarely needed: a few times at the beginning, then less and less often over time. For example, it ran 70 times for top (10) of 10 000 elements and 90 times for top (10) of 100 000.
Here's another solution that is simple and has pretty good performance.
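def pickTopN[T](k: Int, iterable: Iterable[T])(implicit ord: Ordering[T]): Seq[T] = {
  val q = collection.mutable.PriorityQueue[T](iterable.toSeq:_*)
  val end = Math.min(k, q.size)
  (1 to end).map(_ => q.dequeue())
}
The Big O is O(n + k log n), where k <= n, provided the queue is built in linear time (heapify); if the elements are inserted one at a time, building the queue alone costs O(n log n). Extracting the k elements costs O(k log n), so the performance is close to linear for small k and O(n log n) at worst.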
The solution can also be optimized to be O(k) for memory but O(n log k) for performance. The idea is to use a MinHeap to track only the top k items at all times. Here's the solution.
def pickTopN[A, B](n: Int, iterable: Iterable[A], f: A => B)(implicit ord: Ordering[B]): Seq[A] = {
  val seq = iterable.toSeq
  val q = collection.mutable.PriorityQueue[A](seq.take(n):_*)(ord.on(f).reverse) // initialize with first n
  // invariant: keep the top k scanned so far
  seq.drop(n).foreach(v => {
    q += v
    q.dequeue()
  })
  q.dequeueAll.reverse
}
Yet another version:
val big = (1 to 100000)
def maxes[A](n:Int)(l:Traversable[A])(implicit o:Ordering[A]) =
  l.foldLeft(collection.immutable.SortedSet.empty[A]) { (xs,y) =>
    if (xs.size < n) xs + y
    else {
      import o._
      val first = xs.firstKey
      if (first < y) xs - first + y
      else xs
    }
  }
println(maxes(4)(big))
println(maxes(2)(List("a","ab","c","z")))
Using a Set forces the result to contain unique values; this variant uses a List instead:
def maxes2[A](n:Int)(l:Traversable[A])(implicit o:Ordering[A]) =
  l.foldLeft(List.empty[A]) { (xs,y) =>
    import o._
    if (xs.size < n) (y::xs).sorted(o)
    else {
      val first = xs.head // the smallest of the current top n
      if (first < y) (y::xs.tail).sorted(o)
      else xs
    }
  }
You don't need to sort the entire collection to determine the top N elements. However, I don't believe this functionality is provided by the standard library, so you would have to roll your own, possibly using the pimp-my-library pattern.
For example, you can get the nth element of a collection as follows:
// pre-2.13 collections API
import scala.collection.TraversableLike
class Pimp[A, Repr <% TraversableLike[A, Repr]](self : Repr) {
  def nth(n : Int)(implicit ord : Ordering[A]) : A = {
    val trav : TraversableLike[A, Repr] = self
    var ltp : List[A] = Nil
    var etp : List[A] = Nil
    var mtp : List[A] = Nil
    trav.headOption match {
      case None => sys.error("Cannot get " + n + " element of empty collection")
      case Some(piv) =>
        trav.foreach { a =>
          val cf = ord.compare(piv, a)
          if (cf == 0) etp ::= a
          else if (cf > 0) ltp ::= a
          else mtp ::= a
        }
        if (n < ltp.length)
          new Pimp[A, List[A]](ltp.reverse).nth(n)(ord)
        else if (n < (ltp.length + etp.length))
          piv
        else
          new Pimp[A, List[A]](mtp.reverse).nth(n - ltp.length - etp.length)(ord)
    }
  }
}
(This is not very functional; sorry)
It's then trivial to get the top n elements:
// a method of class Pimp; needs import scala.collection.generic.CanBuildFrom
def topN(n : Int)(implicit ord : Ordering[A], bf : CanBuildFrom[Repr, A, Repr]) = {
  val b = bf()
  val elem = new Pimp[A, Repr](self).nth(n)(ord)
  import util.control.Breaks._
  breakable {
    var soFar = 0
    self.foreach { tt =>
      if (ord.compare(tt, elem) < 0) {
        b += tt
        soFar += 1
      }
    }
    assert (soFar <= n)
    if (soFar < n) {
      self.foreach { tt =>
        if (ord.compare(tt, elem) == 0) {
          b += tt
          soFar += 1
        }
        if (soFar == n) break
      }
    }
  }
  b.result()
}
Unfortunately I'm having trouble getting this pimp to be discovered via this implicit:
implicit def t2n[A, Repr <% TraversableLike[A, Repr]](t : Repr) : Pimp[A, Repr]
= new Pimp[A, Repr](t)
I get this:
scala> List(4, 3, 6, 7, 1, 2, 8, 5).topN(4)
<console>:9: error: could not find implicit value for evidence parameter of type (List[Int]) => scala.collection.TraversableLike[A,List[Int]]
List(4, 3, 6, 7, 1, 2, 8, 5).topN(4)
^
However, the code actually works OK:
scala> new Pimp(List(4, 3, 6, 7, 1, 2, 8, 5)).topN(4)
res3: List[Int] = List(3, 1, 2, 4)
And
scala> new Pimp("ioanusdhpisjdmpsdsvfgewqw").topN(6)
res2: java.lang.String = adddfe
If the goal is to avoid sorting the whole list, you could do something like this (of course it could be optimized a bit so that we don't change the list when the number clearly shouldn't be in it):
List(1,6,3,7,3,2).foldLeft(List[Int]()){(l, n) => (n :: l).sorted.take(2)}
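The one-liner above keeps the two smallest values, since `sorted` is ascending. A hedged sketch of a variant that keeps the k largest instead and includes the optimization just mentioned (skip the re-sort when the list is already full and the new element cannot enter it); `TopKFold.topK` is a hypothetical name.

```scala
object TopKFold {
  // Keeps the k largest elements seen so far, in descending order.
  def topK(k: Int, xs: Iterable[Int]): List[Int] =
    xs.foldLeft(List.empty[Int]) { (acc, x) =>
      if (acc.size == k && acc.nonEmpty && x <= acc.last) acc // cannot enter: leave the list alone
      else (x :: acc).sorted(Ordering[Int].reverse).take(k)
    }
}
```

`TopKFold.topK(2, List(1, 6, 3, 7, 3, 2))` returns `List(7, 6)`. `acc.size` and `acc.last` are O(k) on a List, which is fine for small k.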
I recently implemented such a ranking algorithm in the Rank class of Apache Jackrabbit (in Java, though). See the take method for the gist of it. The basic idea is to quicksort but terminate early as soon as the top n elements have been found.
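The Jackrabbit class itself is not reproduced here. As a hedged, independent illustration of the same idea (a partial quicksort, usually called quickselect), the sketch below partitions a copy of the input around random pivots and recurses only into the side that still contains the boundary between the top n and the rest, which gives expected linear time. `QuickSelectTopN.topN` and its fixed-seed `Random` parameter are assumptions made for the example.

```scala
import scala.util.Random

object QuickSelectTopN {
  // Returns the n largest elements, in no particular order.
  def topN(n: Int, xs: Seq[Int], rnd: Random = new Random(42)): Seq[Int] = {
    require(n >= 0 && n <= xs.length)
    val a = xs.toArray
    def swap(i: Int, j: Int): Unit = { val t = a(i); a(i) = a(j); a(j) = t }
    // Lomuto partition of a(lo..hi), descending: elements > pivot go left.
    def partition(lo: Int, hi: Int): Int = {
      swap(lo + rnd.nextInt(hi - lo + 1), hi) // random pivot moved to the end
      val pivot = a(hi)
      var store = lo
      for (i <- lo until hi) if (a(i) > pivot) { swap(i, store); store += 1 }
      swap(store, hi)
      store // final pivot position
    }
    // Invariant: everything before lo >= everything in lo..hi >= everything after hi.
    @annotation.tailrec
    def select(lo: Int, hi: Int): Unit =
      if (lo < hi) {
        val p = partition(lo, hi)
        if (p > n) select(lo, p - 1)          // boundary lies left of the pivot
        else if (p < n - 1) select(p + 1, hi) // boundary lies right of the pivot
      }
    select(0, a.length - 1)
    a.take(n).toSeq
  }
}
```

For example, `QuickSelectTopN.topN(3, Seq(4, 3, 6, 7, 1, 2, 9, 5)).sorted` gives `Seq(6, 7, 9)`. Unlike the fold-based approaches, it needs a mutable copy of the whole input.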
Here is a randomized, quickselect-style solution that aims for O(n) expected time.
def top[T](data: List[T], n: Int)(implicit ord: Ordering[T]): List[T] = {
  require( n < data.size)
  def partition_inner(shuffledData: List[T], pivot: T): List[T] =
    shuffledData.partition( e => ord.compare(e, pivot) > 0 ) match {
      case (left, right) if left.size == n => left
      case (left, x :: rest) if left.size < n =>
        partition_inner(util.Random.shuffle(data), x)
      case (left @ y :: rest, right) if left.size > n =>
        partition_inner(util.Random.shuffle(data), y)
    }
  val shuffled = util.Random.shuffle(data)
  partition_inner(shuffled, shuffled.head)
}
scala> top(List.range(1,10000000), 5)
Due to the recursion, this solution takes longer than some of the non-linear solutions above and can cause java.lang.OutOfMemoryError: GC overhead limit exceeded.
But it is slightly more readable, IMHO, and in functional style. Just for job interviews ;).
More importantly, this solution can easily be parallelized:
// On Scala 2.13+, .par needs the scala-parallel-collections module and
// import scala.collection.parallel.CollectionConverters._
import scala.annotation.tailrec
def top[T](data: List[T], n: Int)(implicit ord: Ordering[T]): List[T] = {
  require( n < data.size)
  @tailrec
  def partition_inner(shuffledData: List[T], pivot: T): List[T] =
    shuffledData.par.partition( e => ord.compare(e, pivot) > 0 ) match {
      case (left, right) if left.size == n => left.toList
      case (left, right) if left.size < n =>
        partition_inner(util.Random.shuffle(data), right.head)
      case (left, right) if left.size > n =>
        partition_inner(util.Random.shuffle(data), left.head)
    }
  val shuffled = util.Random.shuffle(data)
  partition_inner(shuffled, shuffled.head)
}
For small values of n and large lists, getting the top n elements can be implemented by picking out the max element n times:
def top[T](n:Int, iter:Iterable[T])(implicit ord: Ordering[T]): Iterable[T] = {
  def partitionMax(acc: Iterable[T], it: Iterable[T]): Iterable[T] = {
    val max = it.max(ord)
    val (nextElems, rest) = it.partition(ord.gteq(_, max))
    val maxElems = acc ++ nextElems
    if (maxElems.size >= n || rest.isEmpty) maxElems.take(n)
    else partitionMax(maxElems, rest)
  }
  if (iter.isEmpty) iter.take(0)
  else partitionMax(iter.take(0), iter)
}
This does not sort the entire list, and it takes an Ordering. I believe every method called in partitionMax is O(size of the collection), and partitionMax is called at most n times, so for small n the overall cost is proportional to the size of the collection.
scala> top(5, List.range(1,1000000))
res13: Iterable[Int] = List(999999, 999998, 999997, 999996, 999995)
scala> top(5, List.range(1,1000000))(Ordering[Int].on(- _))
res14: Iterable[Int] = List(1, 2, 3, 4, 5)
You could also add a branch for when n gets close to size of the iterable, and switch to iter.toList.sortBy(_.myAttr).take(n).
It does not return the type of collection provided, but you can look at How do I apply the enrich-my-library pattern to Scala collections? if this is a requirement.
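An optimized solution using a PriorityQueue, with time complexity O(n log k). The approach in the update re-sorts the sofar list every time an element is replaced, which is not needed; below, a PriorityQueue is used instead. Note that this version keeps the n elements with the smallest key f(x) (the queue's head is the largest key kept), which is why the example negates the key with -_.i.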
import scala.language.implicitConversions
import scala.language.reflectiveCalls
import collection.mutable.PriorityQueue
implicit def iterExt[A](iter: Iterable[A]) = new {
  def top[B](n: Int, f: A => B)(implicit ord: Ordering[B]) : List[A] = {
    def updateSofar (sofar: PriorityQueue[A], el: A): PriorityQueue[A] = {
      if (ord.compare(f(el), f(sofar.head)) < 0){
        sofar.dequeue
        sofar.enqueue(el)
      }
      sofar
    }
    val (sofar, rest) = iter.splitAt(n)
    (PriorityQueue(sofar.toSeq:_*)( Ordering.by( (x :A) => f(x) ) ) /: rest) (updateSofar (_, _)).dequeueAll.toList.reverse
  }
}
case class A(s: String, i: Int)
val li = List (4, 3, 6, 7, 1, 2, 9, 5).map(i => A(i.toString(), i))
println(li.top(3, -_.i))
