Scala PriorityQueue on Array[Int] performance issue with complex comparison function (caching is needed) - performance

The problem involves the performance of a Scala PriorityQueue[Array[Int]] on a large data set. The following operations are needed: enqueue, dequeue, and filter. Currently, my implementation is as follows:
For every element of type Array[Int], there is a complex evaluation function: (I'm not sure how to write it in a more efficient way, because it excludes the position 0)
def eval_fun(a : Array[Int]) =
if(a.size < 2) 3
else {
var ret = 0
var i = 1
while(i < a.size) {
if((a(i) & 0x3) == 1) ret += 1
else if((a(i) & 0x3) == 3) ret += 3
i += 1
}
ret / a.size
}
The ordering's comparison function is based on the evaluation function (reversed, i.e. descending order):
val arr_ord = new Ordering[Array[Int]] {
def compare(a : Array[Int], b : Array[Int]) = eval_fun(b) compare eval_fun(a) }
The PriorityQueue is defined as:
val pq = scala.collection.mutable.PriorityQueue[Array[Int]]()(arr_ord)
Question:
1. Is there a more elegant and efficient way to write such an evaluation function? I'm thinking of using fold, but fold cannot exclude position 0.
2. Is there a data structure that gives a priority queue with unique elements? Applying a filter operation after each enqueue operation is not efficient.
3. Is there a caching method to reduce the evaluation computation? When a new element is added to the queue, every element may need to be evaluated by eval_fun again, which is unnecessary if the evaluated value of every element can be cached. I should also mention that two distinct elements may have the same evaluated value.
4. Is there a more efficient data structure that does not use a generic type? If the number of elements reaches 10,000 and the size of each element reaches 1,000, the performance is terribly slow.
Thank you.

(1) If you want maximum performance here, I would stick to the while loop, even if it is not terribly elegant. Otherwise, if you use a view on Array, you can easily drop the first element before going into the fold:
a.view.drop(1).foldLeft(0)( (sum, a) => sum + ((a & 0x03) match {
case 0x01 => 1
case 0x03 => 3
case _ => 0
})) / a.size
(2) You can maintain two structures, the priority queue and a set. Combined, they give you a sorted set. You could use collection.immutable.SortedSet, but there is no mutable variant in the standard library. Do you want equality based on the priority function, or on the actual array contents? In the latter case, you won't get around comparing arrays element by element for each insertion, which undoes the effect of caching the priority function value.
(3) Just put the calculated priority along with the array in the queue. I.e.
implicit val ord = Ordering.by[(Int, Array[Int]), Int](_._1)
val pq = new collection.mutable.PriorityQueue[(Int, Array[Int])]
pq += eval_fun(a) -> a
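To illustrate points (2) and (3) together, here is a minimal sketch in Python rather than Scala, so treat it as an outline of the idea and not as a drop-in replacement. The class name UniquePriorityQueue and its methods are hypothetical. The priority is computed once on insertion and stored next to the element, and a set tracks which arrays are currently queued. Uniqueness here is based on array contents, so every insertion still hashes the whole array.

```python
import heapq
from itertools import count


def eval_fun(a):
    """Python port of the question's evaluation function (index 0 is skipped)."""
    if len(a) < 2:
        return 3
    ret = 0
    for x in a[1:]:
        low = x & 0x3
        if low == 1:
            ret += 1
        elif low == 3:
            ret += 3
    return ret // len(a)


class UniquePriorityQueue:
    """Hypothetical helper: highest priority first, duplicate arrays rejected."""

    def __init__(self):
        self._heap = []        # entries: (-priority, insertion index, key)
        self._members = set()  # keys currently stored in the queue
        self._tick = count()   # tie-breaker, so array contents are never compared

    def enqueue(self, arr):
        key = tuple(arr)       # tuples are hashable, lists are not
        if key in self._members:
            return False
        self._members.add(key)
        # The priority is evaluated exactly once and cached in the heap entry.
        heapq.heappush(self._heap, (-eval_fun(key), next(self._tick), key))
        return True

    def dequeue(self):
        _, _, key = heapq.heappop(self._heap)
        self._members.discard(key)
        return key

    def __len__(self):
        return len(self._heap)
```

The insertion index in the middle of each entry means ties on priority are resolved by insertion order, so the heap never has to compare two arrays element by element.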

Well, you can use a tail recursive loop (generally these are more "idiomatic"):
def eval(a: Array[Int]): Int =
if (a.size < 2) 3
else {
@annotation.tailrec
def loop(ret: Int = 0, i: Int = 1): Int =
if (i >= a.size) ret / a.size
else {
val mod3 = (a(i) & 0x3)
if (mod3 == 1) loop(ret + 1, i + 1)
else if (mod3 == 3) loop(ret + 3, i + 1)
else loop(ret, i + 1)
}
loop()
}
Then you can use that to initialise a cached priority value:
case class PriorityArray(a: Array[Int]) {
lazy val priority = if (a.size < 2) 3 else {
@annotation.tailrec
def loop(ret: Int = 0, i: Int = 1): Int =
if (i >= a.size) ret / a.size
else {
val mod3 = (a(i) & 0x3)
if (mod3 == 2) loop(ret, i + 1)
else loop(ret + mod3, i + 1)
}
loop()
}
}
You may note also that I removed a redundant & op and have only the single conditional (for when it equals 2, rather than two checks for 1 && 3) – these should have some minimal effect.
There is not a huge difference from 0__'s proposal that just came through.

My answers:
If evaluation is critical, keep it as it is. You might get better performance with recursion (not sure why, but it happens), but you'll certainly get worse performance with pretty much any other approach.
No, there isn't, but you can come pretty close to it just modifying the dequeue operation:
def distinctDequeue[T](q: PriorityQueue[T]): T = {
val result = q.dequeue
while (q.nonEmpty && q.head == result) q.dequeue() // guard against an empty queue
result
}
Otherwise, you'd have to keep a second data structure just to keep track of whether an element has been added or not. Either way, that equals sign is pretty heavy, but I have a suggestion to make it faster in the next item.
Note, however, that this requires that ties on the cost function get resolved in some other way.
Like 0__ suggested, put the cost on the priority queue. But you can also keep a cache on the function if that would be helpful. I'd try something like this:
import scala.collection.mutable.WrappedArray
val evalMap = scala.collection.mutable.HashMap[WrappedArray[Int], Int]()
def eval_fun(a : Array[Int]) =
if(a.size < 2) 3
else evalMap.getOrElseUpdate(a, {
var ret = 0
var i = 1
while(i < a.size) {
if((a(i) & 0x3) == 1) ret += 1
else if((a(i) & 0x3) == 3) ret += 3
i += 1
}
ret / a.size
})
import scala.math.Ordering.Implicits._
val pq = new collection.mutable.PriorityQueue[(Int, WrappedArray[Int])]
pq += eval_fun(a) -> (a : WrappedArray[Int])
Note that I did not create a special Ordering -- I'm using the standard Ordering so that the WrappedArray will break the ties. There's little cost to wrap the Array, and you get it back with .array, but, on the other hand, you'll get the following:
Ties will be broken by comparing the arrays themselves. If there aren't many ties in the cost, this should be good enough. If there are, add something else to the tuple to help break ties without comparing the arrays.
That means all equal elements will be kept together, which will enable you to dequeue all of them at the same time, giving the impression of having kept only one.
And that equals will actually work, because WrappedArray compares like Scala sequences do.
I don't understand what you mean by that fourth point.

Related

Trial division for primes with immutable collections in Scala

I am trying to learn Scala and the functional programming ideology by rewriting basic exercises. Currently I am having trouble with the naive "trial division" approach to generating primes.
The trouble described below is that I could not rewrite a well-known algorithm in functional style while preserving efficiency, because I have no suitable immutable data structure: something like a List, but with fast operations not only on the head but also on the very end.
I started by writing Java code which, for every odd number, tests its divisibility by the primes already found (limited by the square root of the value being tested) and adds it to the end of the list if no divisor is found.
http://ideone.com/QE8U0I
List<Integer> primes = new ArrayList<>();
primes.add(2);
int cur = 3;
while (primes.size() < 100000) {
for (Integer x : primes) {
if (x * x > cur) {
primes.add(cur);
break;
}
if (cur % x == 0) {
break;
}
}
cur += 2;
}
Then I tried to rewrite it in a "functional way". There was no problem with using recursion instead of loops, but I got stuck with immutable collections. The core idea is as follows:
http://ideone.com/4DQ6mi
def primes(n: Int) = {
@tailrec
def divisibleByAny(x: Int, list: List[Int]): Boolean = {
if (list.isEmpty) false else {
val h = list.head
h * h <= x && (x % h == 0 || divisibleByAny(x, list.tail))
}
}
@tailrec
def morePrimes(from: Int, prev: List[Int]): List[Int] = {
if (prev.size == n) prev else
morePrimes(from + 2, if (divisibleByAny(from, prev)) prev else prev :+ from)
}
morePrimes(3, List(2))
}
But it is slow, if I understand correctly, because adding to the end of an immutable list requires creating a new copy of the whole list.
I searched over documentation to find more suitable data structure and tried to substitute list with immutable Queue, for it is said:
Adding items to the queue always has cost O(1) ... Removing an item is on average O(1).
But it is even slower:
http://ideone.com/v8BsuQ
def primes(n: Int) = {
@tailrec
def divisibleByAny(x: Int, list: Queue[Int]): Boolean = {
if (list.isEmpty) false else {
val (h, t) = list.dequeue
h * h <= x && (x % h == 0 || divisibleByAny(x, t))
}
}
@tailrec
def morePrimes(from: Int, prev: Queue[Int]): Queue[Int] = {
if (prev.size == n) prev else
morePrimes(from + 2, if (divisibleByAny(from, prev)) prev else prev.enqueue(from))
}
morePrimes(3, Queue(2))
}
What is going wrong or am I missing something?
P.S. I believe there are other algorithms for generating primes which are more suitable for functional style. I think I've seen some paper. But now I'm interested in this one, or more precisely in existence of suitable data structure.
According to http://docs.scala-lang.org/overviews/collections/performance-characteristics.html, Vectors have an effectively constant cost for appending, prepending and seeking. Indeed, using vectors instead of lists in your solution is much faster:
def primes(n: Int) = {
@tailrec
def divisibleByAny(x: Int, list: Vector[Int]): Boolean = {
if (list.isEmpty) false else {
val (h +: t) = list
h * h <= x && (x % h == 0 || divisibleByAny(x, t))
}
}
@tailrec
def morePrimes(from: Int, prev: Vector[Int]): Vector[Int] = {
if (prev.length == n) prev else
morePrimes(from + 2, if (divisibleByAny(from, prev)) prev else prev :+ from)
}
morePrimes(3, Vector(2))
}
http://ideone.com/x3k4A3
I think you have 2 main options:
1. Use a Vector, which is better than a List for appending. It is a bitmapped trie data structure (http://en.wikipedia.org/wiki/Trie). It’s “effectively” O(1) for appending (i.e. O(1) on average).
Or... possibly the answer you're not looking for:
2. Use a mutable data structure like ListBuffer. Immutability is great to try to achieve, and immutable collections should be your go-to, but sometimes, for efficiency reasons, you may use mutable structures. The key is to make sure they do not “leak out” of your classes. If you look at the List.scala implementation, you’ll see ListBuffer used a lot internally. However, it's converted back to a List just before it leaves the class. If it's good enough for the core Scala libraries, it's probably OK for you to use in exceptional cases that warrant it.
Besides using Vector, also consider using higher-order functions instead of recursion. That's also a completely valid functional style. On my machine the following implementation of divisibleByAny is about 8x faster than @Pyetras's tailrec implementation when running primes(1000000):
def divisibleByAny(x: Int, list: Vector[Int]): Boolean =
list.view.takeWhile(el => el * el <= x).exists(x % _ == 0)

How to design an algorithm to calculate countdown style maths number puzzle

I have always wanted to do this but every time I start thinking about the problem it blows my mind because of its exponential nature.
The problem solver I want to be able to understand and code is for the countdown maths problem:
Given set of number X1 to X5 calculate how they can be combined using mathematical operations to make Y.
You can apply multiplication, division, addition and subtraction.
So how does 1,3,7,6,8,3 make 348?
Answer: (((8 * 7) + 3) -1) *6 = 348.
How to write an algorithm that can solve this problem? Where do you begin when trying to solve a problem like this? What important considerations do you have to think about when designing such an algorithm?
Very quick and dirty solution in Java:
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class JavaApplication1
{
public static void main(String[] args)
{
List<Integer> list = Arrays.asList(1, 3, 7, 6, 8, 3);
for (Integer integer : list) {
List<Integer> runList = new ArrayList<>(list);
runList.remove(integer);
Result result = getOperations(runList, integer, 348);
if (result.success) {
System.out.println(integer + result.output);
return;
}
}
}
public static class Result
{
public String output;
public boolean success;
}
public static Result getOperations(List<Integer> numbers, int midNumber, int target)
{
Result midResult = new Result();
if (midNumber == target) {
midResult.success = true;
midResult.output = "";
return midResult;
}
for (Integer number : numbers) {
List<Integer> newList = new ArrayList<Integer>(numbers);
newList.remove(number);
if (newList.isEmpty()) {
if (midNumber - number == target) {
midResult.success = true;
midResult.output = "-" + number;
return midResult;
}
if (midNumber + number == target) {
midResult.success = true;
midResult.output = "+" + number;
return midResult;
}
if (midNumber * number == target) {
midResult.success = true;
midResult.output = "*" + number;
return midResult;
}
if (midNumber / number == target) {
midResult.success = true;
midResult.output = "/" + number;
return midResult;
}
midResult.success = false;
midResult.output = "f" + number;
return midResult;
} else {
midResult = getOperations(newList, midNumber - number, target);
if (midResult.success) {
midResult.output = "-" + number + midResult.output;
return midResult;
}
midResult = getOperations(newList, midNumber + number, target);
if (midResult.success) {
midResult.output = "+" + number + midResult.output;
return midResult;
}
midResult = getOperations(newList, midNumber * number, target);
if (midResult.success) {
midResult.output = "*" + number + midResult.output;
return midResult;
}
midResult = getOperations(newList, midNumber / number, target);
if (midResult.success) {
midResult.output = "/" + number + midResult.output;
return midResult;
}
}
}
return midResult;
}
}
UPDATE
It's basically just simple brute force algorithm with exponential complexity.
However, you can gain some improvements by using a heuristic function to order the sequence of numbers and/or operations processed at each level of the getOperations() recursion.
An example of such a heuristic function is the difference between the mid result and the target result.
This way however only best-case and average-case complexities get improved. Worst case complexity remains untouched.
Worst case complexity can be improved by some kind of branch cutting. I'm not sure if it's possible in this case.
Sure it's exponential but it's tiny so a good (enough) naive implementation would be a good start. I suggest you drop the usual infix notation with bracketing, and use postfix, it's easier to program. You can always prettify the outputs as a separate stage.
Start by listing and evaluating all the (valid) sequences of numbers and operators. For example (in postfix):
1 3 7 6 8 3 + + + + + -> 28
1 3 7 6 8 3 + + + + - -> -26
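For what it's worth, evaluating one postfix sequence takes only a few lines. The sketch below is in Python, and the function name eval_postfix is made up for illustration. It assumes whitespace-separated tokens and, as a further assumption, treats any division that is not exact as a failure (returning None), since the game only allows whole-number results.

```python
OPERATORS = {"+", "-", "*", "/"}


def eval_postfix(expression):
    """Evaluate a whitespace-separated postfix expression over integers.

    Returns None when a division is by zero or not exact, or when the
    sequence does not leave exactly one value on the stack.
    Raises IndexError if an operator appears before it has two operands.
    """
    stack = []
    for token in expression.split():
        if token not in OPERATORS:
            stack.append(int(token))
            continue
        right = stack.pop()   # the most recent value is the right operand
        left = stack.pop()
        if token == "+":
            stack.append(left + right)
        elif token == "-":
            stack.append(left - right)
        elif token == "*":
            stack.append(left * right)
        else:
            if right == 0 or left % right != 0:
                return None
            stack.append(left // right)
    return stack.pop() if len(stack) == 1 else None
```

With this, eval_postfix("1 3 7 6 8 3 + + + + +") gives 28, and the solution from the question written in postfix, eval_postfix("8 7 * 3 + 1 - 6 *"), gives 348.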
My Java is laughable, I don't come here to be laughed at so I'll leave coding this up to you.
To all the smart people reading this: yes, I know that for even a small problem like this there are smarter approaches which are likely to be faster, I'm just pointing OP towards an initial working solution. Someone else can write the answer with the smarter solution(s).
So, to answer your questions:
I begin with an algorithm that I think will lead me quickly to a working solution. In this case the obvious (to me) choice is exhaustive enumeration and testing of all possible calculations.
If the obvious algorithm looks unappealing for performance reasons I'll start thinking more deeply about it, recalling other algorithms that I know about which are likely to deliver better performance. I may start coding one of those first instead.
If I stick with the exhaustive algorithm and find that the run-time is, in practice, too long, then I might go back to the previous step and code again. But it has to be worth my while, there's a cost/benefit assessment to be made -- as long as my code can outperform Rachel Riley I'd be satisfied.
Important considerations include my time vs computer time, mine costs a helluva lot more.
A working solution in C++11 below.
The basic idea is to use a stack-based evaluation (see RPN) and convert the viable solutions to infix notation for display purposes only.
If we have N input digits, we'll use (N-1) operators, as each operator is binary.
First we create valid permutations of operands and operators (the selector_ array). A valid permutation is one that can be evaluated without stack underflow and which ends with exactly one value (the result) on the stack. Thus 1 1 + is valid, but 1 + 1 is not.
We test each such operand-operator permutation with every permutation of operands (the values_ array) and every combination of operators (the ops_ array). Matching results are pretty-printed.
Arguments are taken from command line as [-s] <target> <digit>[ <digit>...]. The -s switch prevents exhaustive search, only the first matching result is printed.
(use ./mathpuzzle 348 1 3 7 6 8 3 to get the answer for the original question)
This solution doesn't allow concatenating the input digits to form numbers. That could be added as an additional outer loop.
The working code can be downloaded from here. (Note: I updated that code with support for concatenating input digits to form a solution)
See code comments for additional explanation.
#include <iostream>
#include <vector>
#include <algorithm>
#include <stack>
#include <iterator>
#include <string>
#include <limits>   // std::numeric_limits
#include <cstdlib>  // std::exit
namespace {
enum class Op {
Add,
Sub,
Mul,
Div,
};
const std::size_t NumOps = static_cast<std::size_t>(Op::Div) + 1;
const Op FirstOp = Op::Add;
using Number = int;
class Evaluator {
std::vector<Number> values_; // stores our digits/number we can use
std::vector<Op> ops_; // stores the operators
std::vector<char> selector_; // used to select digit (0) or operator (1) when evaluating. should be std::vector<bool>, but that's broken
template <typename T>
using Stack = std::stack<T, std::vector<T>>;
// checks if a given number/operator order can be evaluated or not
bool isSelectorValid() const {
int numValues = 0;
for (auto s : selector_) {
if (s) {
if (--numValues <= 0) {
return false;
}
}
else {
++numValues;
}
}
return (numValues == 1);
}
// evaluates the current values_ and ops_ based on selector_
Number eval(Stack<Number> &stack) const {
auto vi = values_.cbegin();
auto oi = ops_.cbegin();
for (auto s : selector_) {
if (!s) {
stack.push(*(vi++));
continue;
}
Number top = stack.top();
stack.pop();
switch (*(oi++)) {
case Op::Add:
stack.top() += top;
break;
case Op::Sub:
stack.top() -= top;
break;
case Op::Mul:
stack.top() *= top;
break;
case Op::Div:
if (top == 0) {
return std::numeric_limits<Number>::max();
}
Number res = stack.top() / top;
if (res * top != stack.top()) {
return std::numeric_limits<Number>::max();
}
stack.top() = res;
break;
}
}
Number res = stack.top();
stack.pop();
return res;
}
bool nextValuesPermutation() {
return std::next_permutation(values_.begin(), values_.end());
}
bool nextOps() {
for (auto i = ops_.rbegin(), end = ops_.rend(); i != end; ++i) {
std::size_t next = static_cast<std::size_t>(*i) + 1;
if (next < NumOps) {
*i = static_cast<Op>(next);
return true;
}
*i = FirstOp;
}
return false;
}
bool nextSelectorPermutation() {
// the start permutation is always valid
do {
if (!std::next_permutation(selector_.begin(), selector_.end())) {
return false;
}
} while (!isSelectorValid());
return true;
}
static std::string buildExpr(const std::string& left, char op, const std::string &right) {
return std::string("(") + left + ' ' + op + ' ' + right + ')';
}
std::string toString() const {
Stack<std::string> stack;
auto vi = values_.cbegin();
auto oi = ops_.cbegin();
for (auto s : selector_) {
if (!s) {
stack.push(std::to_string(*(vi++)));
continue;
}
std::string top = stack.top();
stack.pop();
switch (*(oi++)) {
case Op::Add:
stack.top() = buildExpr(stack.top(), '+', top);
break;
case Op::Sub:
stack.top() = buildExpr(stack.top(), '-', top);
break;
case Op::Mul:
stack.top() = buildExpr(stack.top(), '*', top);
break;
case Op::Div:
stack.top() = buildExpr(stack.top(), '/', top);
break;
}
}
return stack.top();
}
public:
Evaluator(const std::vector<Number>& values) :
values_(values),
ops_(values.size() - 1, FirstOp),
selector_(2 * values.size() - 1, 0) {
std::fill(selector_.begin() + values_.size(), selector_.end(), 1);
std::sort(values_.begin(), values_.end());
}
// check for solutions
// 1) we create valid permutations of our selector_ array (eg: "1 1 + 1 +",
// "1 1 1 + +", but skip "1 + 1 1 +" as that cannot be evaluated
// 2) for each evaluation order, we permutate our values
// 3) for each value permutation we check with each combination of
// operators
//
// In the first version I used a local stack in eval() (see toString()) but
// it turned out to be a performance bottleneck, so now I use a cached
// stack. Reusing the stack gives an order of magnitude speed-up (from
// 4.3sec to 0.7sec) due to avoiding repeated allocations. Using
// std::vector as a backing store also gives a slight performance boost
// over the default std::deque.
std::size_t check(Number target, bool singleResult = false) {
Stack<Number> stack;
std::size_t res = 0;
do {
do {
do {
Number value = eval(stack);
if (value == target) {
++res;
std::cout << target << " = " << toString() << "\n";
if (singleResult) {
return res;
}
}
} while (nextOps());
} while (nextValuesPermutation());
} while (nextSelectorPermutation());
return res;
}
};
} // namespace
int main(int argc, const char **argv) {
int i = 1;
bool singleResult = false;
if (argc > 1 && std::string("-s") == argv[1]) {
singleResult = true;
++i;
}
if (argc < i + 2) {
std::cerr << argv[0] << " [-s] <target> <digit>[ <digit>]...\n";
std::exit(1);
}
Number target = std::stoi(argv[i]);
std::vector<Number> values;
while (++i < argc) {
values.push_back(std::stoi(argv[i]));
}
Evaluator evaluator{values};
std::size_t res = evaluator.check(target, singleResult);
if (!singleResult) {
std::cout << "Number of solutions: " << res << "\n";
}
return 0;
}
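The validity rule for the selector array may be easier to see in isolation. Below is a small Python sketch of the same check (names are illustrative, not taken from the C++ code): an operand pushes one value, a binary operator pops two and pushes one, and a layout is valid if the stack never underflows and ends with exactly one value.

```python
from itertools import combinations


def is_valid_layout(selector):
    """selector holds False for an operand slot and True for an operator slot."""
    depth = 0
    for is_operator in selector:
        if is_operator:
            depth -= 1          # pops two values, pushes one result
            if depth < 1:
                return False    # the operator did not have two operands
        else:
            depth += 1
    return depth == 1


def valid_layouts(num_operands):
    """Enumerate every valid operand/operator layout for num_operands operands."""
    length = 2 * num_operands - 1
    layouts = []
    for op_positions in combinations(range(length), num_operands - 1):
        selector = [i in op_positions for i in range(length)]
        if is_valid_layout(selector):
            layouts.append(selector)
    return layouts
```

So "1 1 +" corresponds to [False, False, True] and is accepted, while "1 + 1" corresponds to [False, True, False] and is rejected. For 6 operands only 42 of the 462 possible layouts are valid.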
Input is obviously a set of digits and operators: D={1,3,7,6,8,3} and Op={+,-,*,/}. The most straightforward algorithm would be a brute force solver, which enumerates all possible combinations of these sets, where the elements of set Op can be used as often as wanted, but elements from set D are used exactly once. Pseudo code:
D={1,3,7,6,8,3}
Op={+,-,*,/}
Solution=348
for each permutation D_ of D:
for each binary tree T with D_ as its leafs:
for each sequence of operators Op_ from Op with length |D_|-1:
label each inner tree node with operators from Op_
result = compute T using infix traversal
if result==Solution
return T
return nil
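A rough Python rendering of that pseudo code, as a sketch only: the names expressions and solve are made up, every number is used exactly once, and as an extra assumption division is only allowed when it is exact. Each binary tree over an ordered tuple is produced by choosing a split point and recursing on both halves.

```python
from itertools import permutations


def expressions(nums):
    """Yield (value, text) for every binary tree whose leaves are nums, in order."""
    if len(nums) == 1:
        yield nums[0], str(nums[0])
        return
    for split in range(1, len(nums)):
        for left_value, left_text in expressions(nums[:split]):
            for right_value, right_text in expressions(nums[split:]):
                yield left_value + right_value, f"({left_text} + {right_text})"
                yield left_value - right_value, f"({left_text} - {right_text})"
                yield left_value * right_value, f"({left_text} * {right_text})"
                # Assumption: only exact divisions count as valid steps.
                if right_value != 0 and left_value % right_value == 0:
                    yield left_value // right_value, f"({left_text} / {right_text})"


def solve(numbers, target):
    """Return the text of the first expression found that hits target, else None."""
    # set() drops identical orderings caused by repeated numbers.
    for perm in sorted(set(permutations(numbers))):
        for value, text in expressions(perm):
            if value == target:
                return text
    return None
```

For example, solve([1, 3, 7, 8], 58) returns some expression that evaluates to 58. For the full six-number puzzle this plain-Python version has tens of millions of candidates to go through, so expect it to be slow.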
Other than that: read jedrus07's and HPM's answers.
By far the easiest approach is to intelligently brute force it. There is only a finite number of expressions you can build out of 6 numbers and 4 operators; simply go through all of them.
How many? Since you don't have to use all numbers and may use the same operator multiple times, this problem is equivalent to "how many labeled strictly binary trees (aka full binary trees) can you make with at most 6 leaves, and four possible labels for each non-leaf node?".
The amount of full binary trees with n leaves is equal to catalan(n-1). You can see this as follows:
Every full binary tree with n leaves has n-1 internal nodes and corresponds to a non-full binary tree with n-1 nodes in a unique way (just delete all the leaves from the full one to get it). There happen to be catalan(n) possible binary trees with n nodes, so we can say that a strictly binary tree with n leaves has catalan(n-1) possible different structures.
There are 4 possible operators for each non-leaf node: 4^(n-1) possibilities
The leaves can be numbered in n! * (6 choose n) = 6!/(6-n)! different ways. (Divide this by k! for each number that occurs k times, or just make sure all numbers are different)
So for 6 different numbers and 4 possible operators you get Sum(n=1...6) [ Catalan(n-1) * 6!/(6-n)! * 4^(n-1) ] possible expressions for a total of 33,665,406. Not a lot.
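The sum is easy to check numerically. A short Python sketch (function names are illustrative; math.comb needs Python 3.8 or later):

```python
from math import comb, factorial


def catalan(n):
    """n-th Catalan number: 1, 1, 2, 5, 14, 42, ..."""
    return comb(2 * n, n) // (n + 1)


def expression_count(total_numbers=6, operators=4):
    """Sum over n leaves of: tree shapes * ways to fill leaves * operator choices."""
    return sum(
        catalan(n - 1)
        * (factorial(total_numbers) // factorial(total_numbers - n))
        * operators ** (n - 1)
        for n in range(1, total_numbers + 1)
    )
```

Calling expression_count() reproduces the 33,665,406 figure, with the n = 6 term alone contributing 42 * 720 * 1024 = 30,965,760.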
How do you enumerate these trees?
Given a collection of all trees with n-1 or less nodes, you can create all trees with n nodes by systematically pairing all of the n-1 trees with the empty tree, all n-2 trees with the 1 node tree, all n-3 trees with all 2 node tree etc. and using them as the left and right sub trees of a newly formed tree.
So starting with an empty set you first generate the tree that has just a root node, then from a new root you can use that either as a left or right sub tree which yields the two trees that look like this: / and \. And so on.
You can turn them into a set of expressions on the fly (just loop over the operators and numbers) and evaluate them as you go until one yields the target number.
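A minimal Python sketch of that enumeration. Note it counts leaves rather than internal nodes, which is the same idea seen from the other side, and the nested-tuple representation is just an assumption made for illustration.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def tree_shapes(leaves):
    """Return every full binary tree shape with the given number of leaves.

    A leaf is the string "leaf"; an internal node is a (left, right) tuple.
    """
    if leaves == 1:
        return ("leaf",)
    shapes = []
    # Pair every smaller left subtree with every matching right subtree.
    for left_leaves in range(1, leaves):
        for left in tree_shapes(left_leaves):
            for right in tree_shapes(leaves - left_leaves):
                shapes.append((left, right))
    return tuple(shapes)
```

The number of shapes for 1 to 6 leaves comes out as 1, 1, 2, 5, 14, 42, i.e. catalan(n-1), in line with the count given earlier.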
I've written my own countdown solver, in Python.
Here's the code; it is also available on GitHub:
#!/usr/bin/env python3
import sys
from itertools import combinations, product, zip_longest
from functools import lru_cache
assert sys.version_info >= (3, 6)
class Solutions:
    def __init__(self, numbers):
        self.all_numbers = numbers
        self.size = len(numbers)
        self.all_groups = self.unique_groups()
    def unique_groups(self):
        all_groups = {}
        all_numbers, size = self.all_numbers, self.size
        for m in range(1, size+1):
            for numbers in combinations(all_numbers, m):
                if numbers in all_groups:
                    continue
                all_groups[numbers] = Group(numbers, all_groups)
        return all_groups
    def walk(self):
        for group in self.all_groups.values():
            yield from group.calculations
class Group:
    def __init__(self, numbers, all_groups):
        self.numbers = numbers
        self.size = len(numbers)
        self.partitions = list(self.partition_into_unique_pairs(all_groups))
        self.calculations = list(self.perform_calculations())
    def __repr__(self):
        return str(self.numbers)
    def partition_into_unique_pairs(self, all_groups):
        # The pairs are unordered: a pair (a, b) is equivalent to (b, a).
        # Therefore, for pairs of equal length only half of all combinations
        # need to be generated to obtain all pairs; this is set by the limit.
        if self.size == 1:
            return
        numbers, size = self.numbers, self.size
        limits = (self.halfbinom(size, size//2), )
        unique_numbers = set()
        for m, limit in zip_longest(range((size+1)//2, size), limits):
            for numbers1, numbers2 in self.paired_combinations(numbers, m, limit):
                if numbers1 in unique_numbers:
                    continue
                unique_numbers.add(numbers1)
                group1, group2 = all_groups[numbers1], all_groups[numbers2]
                yield (group1, group2)
    def perform_calculations(self):
        if self.size == 1:
            yield Calculation.singleton(self.numbers[0])
            return
        for group1, group2 in self.partitions:
            for calc1, calc2 in product(group1.calculations, group2.calculations):
                yield from Calculation.generate(calc1, calc2)
    @classmethod
    def paired_combinations(cls, numbers, m, limit):
        for cnt, numbers1 in enumerate(combinations(numbers, m), 1):
            numbers2 = tuple(cls.filtering(numbers, numbers1))
            yield (numbers1, numbers2)
            if cnt == limit:
                return
    @staticmethod
    def filtering(iterable, elements):
        # filter elements out of an iterable, return the remaining elements
        elems = iter(elements)
        k = next(elems, None)
        for n in iterable:
            if n == k:
                k = next(elems, None)
            else:
                yield n
    @staticmethod
    @lru_cache()
    def halfbinom(n, k):
        if n % 2 == 1:
            return None
        prod = 1
        for m, l in zip(reversed(range(n+1-k, n+1)), range(1, k+1)):
            prod = (prod*m)//l
        return prod//2
class Calculation:
    def __init__(self, expression, result, is_singleton=False):
        self.expr = expression
        self.result = result
        self.is_singleton = is_singleton
    def __repr__(self):
        return self.expr
    @classmethod
    def singleton(cls, n):
        return cls(f"{n}", n, is_singleton=True)
    @classmethod
    def generate(cls, calca, calcb):
        if calca.result < calcb.result:
            calca, calcb = calcb, calca
        for result, op in cls.operations(calca.result, calcb.result):
            expr1 = f"{calca.expr}" if calca.is_singleton else f"({calca.expr})"
            expr2 = f"{calcb.expr}" if calcb.is_singleton else f"({calcb.expr})"
            yield cls(f"{expr1} {op} {expr2}", result)
    @staticmethod
    def operations(x, y):
        yield (x + y, '+')
        if x > y: # exclude non-positive results
            yield (x - y, '-')
        if y > 1 and x > 1: # exclude trivial results
            yield (x * y, 'x')
        if y > 1 and x % y == 0: # exclude trivial and non-integer results
            yield (x // y, '/')
def countdown_solver():
    # input: target and numbers. If you want to play with more or less than
    # 6 numbers, use the second version of 'unsorted_numbers'.
    try:
        target = int(sys.argv[1])
        unsorted_numbers = (int(sys.argv[n+2]) for n in range(6)) # for 6 numbers
        # unsorted_numbers = (int(n) for n in sys.argv[2:]) # for any numbers
        numbers = tuple(sorted(unsorted_numbers, reverse=True))
    except (IndexError, ValueError):
        print("You must provide a target and numbers!")
        return
    solutions = Solutions(numbers)
    smallest_difference = target
    bestresults = []
    for calculation in solutions.walk():
        diff = abs(calculation.result - target)
        if diff <= smallest_difference:
            if diff < smallest_difference:
                bestresults = [calculation]
                smallest_difference = diff
            else:
                bestresults.append(calculation)
    output(target, smallest_difference, bestresults)
def output(target, diff, results):
    print(f"\nThe closest results differ from {target} by {diff}. They are:\n")
    for calculation in results:
        print(f"{calculation.result} = {calculation.expr}")
if __name__ == "__main__":
    countdown_solver()
The algorithm works as follows:
The numbers are put into a tuple of length 6 in descending order. Then, all unique subgroups of lengths 1 to 6 are created, the smallest groups first.
Example: (75, 50, 5, 9, 1, 1) -> {(75), (50), (9), (5), (1), (75, 50), (75, 9), (75, 5), ..., (75, 50, 9, 5, 1, 1)}.
Next, the groups are organised into a hierarchical tree: every group is partitioned into all unique unordered pairs of its non-empty subgroups.
Example: (9, 5, 1, 1) -> [(9, 5, 1) + (1), (9, 1, 1) + (5), (5, 1, 1) + (9), (9, 5) + (1, 1), (9, 1) + (5, 1)].
Within each group of numbers, the calculations are performed and the results are stored. For groups of length 1, the result is simply the number itself. For larger groups, the calculations are carried out on every pair of subgroups: in each pair, all results of the first subgroup are combined with all results of the second subgroup using +, -, x and /, and the valid outcomes are stored.
Example: (75, 5) consists of the pair ((75), (5)). The result of (75) is 75; the result of (5) is 5; the results of (75, 5) are [75+5=80, 75-5=70, 75*5=375, 75/5=15].
In this manner, all results are generated, from the smallest groups to the largest. Finally, the algorithm iterates through all results and selects the ones that are the closest match to the target number.
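The partitioning step can be shown on its own with a much simpler (and slower) sketch than the code above. The function name unordered_partitions is hypothetical, and it de-duplicates with a set instead of the limit trick used in the solver.

```python
from itertools import combinations


def unordered_partitions(numbers):
    """Return the unique unordered pairs of non-empty subgroups of numbers."""
    n = len(numbers)
    seen = set()
    result = []
    # Only the smaller half needs choosing; the rest forms the other subgroup.
    for m in range(1, n // 2 + 1):
        for indices in combinations(range(n), m):
            chosen = set(indices)
            part1 = tuple(numbers[i] for i in range(n) if i not in chosen)
            part2 = tuple(numbers[i] for i in indices)
            key = tuple(sorted((part1, part2)))  # (a, b) and (b, a) are the same pair
            if key in seen:
                continue
            seen.add(key)
            result.append((part1, part2))
    return result
```

For (9, 5, 1, 1) this yields the five pairs from the example: three splits into 3 + 1 numbers and two splits into 2 + 2 numbers. The duplicates caused by the repeated 1 are dropped.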
For a group of m numbers, the maximum number of arithmetic computations is
comps[m] = 4*sum(binom(m, k)*comps[k]*comps[m-k]//(1 + (2*k)//m) for k in range(1, m//2+1))
For all groups of length 1 to 6, the maximum total number of computations is then
total = sum(binom(n, m)*comps[m] for m in range(1, n+1))
which is 1144386. In practice, it will be much less, because the algorithm reuses the results of duplicate groups, ignores trivial operations (adding 0, multiplying by 1, etc), and because the rules of the game dictate that intermediate results must be positive integers (which limits the use of the division operator).
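Both formulas can be evaluated directly. The sketch below assumes comps[1] = 1 (a single number counts as one result), which is not stated explicitly above but is the value that reproduces the quoted total. The function name max_computations is made up, and math.comb needs Python 3.8 or later.

```python
from math import comb


def max_computations(n=6):
    """Return (comps, total) for groups of up to n numbers."""
    comps = {1: 1}  # assumption: a group of one number contributes one result
    for m in range(2, n + 1):
        comps[m] = 4 * sum(
            comb(m, k) * comps[k] * comps[m - k] // (1 + (2 * k) // m)
            for k in range(1, m // 2 + 1)
        )
    total = sum(comb(n, m) * comps[m] for m in range(1, n + 1))
    return comps, total
```

For n = 6 the per-group maxima are 1, 4, 48, 960, 26880 and 967680, and the total is 1144386.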
I think you need to strictly define the problem first: what you are allowed to do and what you are not. You can start by making it simple and only allowing multiplication, division, subtraction and addition.
Now you know your problem space: the set of inputs, the set of available operations and the desired output. If you have only 4 operations and x inputs, the number of combinations is less than:
The number of orders in which you can carry out operations (x!) times the possible choices of operations at every step (4^x). As you can see, for 6 numbers this gives a reasonable 2949120 operations. This means that this may be your limit for a brute force algorithm.
Once you have brute force and you know it works, you can start improving your algorithm with some sort of A* algorithm which would require you to define heuristic functions.
In my opinion the best way to think about it is as the search problem. The main difficulty will be finding good heuristics, or ways to reduce your problem space (if you have numbers that do not add up to the answer, you will need at least one multiplication etc.). Start small, build on that and ask follow up questions once you have some code.
I wrote a terminal application to do this:
https://github.com/pg328/CountdownNumbersGame/tree/main
Inside, I've included an illustration of how the size of the solution space is calculated (it's n*((n-1)!^2)*2^(n-1), so n=6 -> 2,764,800; I know, gross), and more importantly why that is. My implementation is there if you care to check it out, but in case you don't I'll explain here.
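The quoted figure can be reproduced in a couple of lines. Note the assumption in this sketch: the last factor is read as 2^(n-1), because that reading gives 2,764,800 for n=6, whereas (2^n)-1 gives a different number. The function name is made up for illustration.

```python
from math import factorial

def solution_space(n):
    # Assumption: the last factor is 2**(n-1), which matches the quoted 2,764,800.
    return n * factorial(n - 1) ** 2 * 2 ** (n - 1)

print(solution_space(6))  # 2764800
```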
Essentially, at worst it is brute force, because as far as I know it's impossible to determine whether any specific branch will result in a valid answer without explicitly checking. Having said that, the average case is some fraction of that: it's {that number} divided by the number of valid solutions (I tend to see around 1000 in my program, where 10 or so are unique and the rest are permutations of those 10). If I handwaved a number, I'd say roughly 2,765 branches to check, which takes like no time. (Yes, even in Python.)
TL;DR: Even though the solution space is huge and it takes a couple million operations to find all solutions, only one answer is needed. Best route is brute force til you find one and spit it out.
I wrote a slightly simpler version:
for every combination of 2 (distinct) elements from the list, combine them using +, -, *, / (note that since a > b, only a-b is needed, and a/b only if a % b == 0)
if the result of the combination is the target, then record the solution
otherwise recursively call on the reduced list
import sys
def driver():
    try:
        target = int(sys.argv[1])
        nums = list((int(sys.argv[i+2]) for i in range(6)))
    except (IndexError, ValueError):
        print("Provide a list of 7 numbers")
        return
    solutions = list()
    solve(target, nums, list(), solutions)
    unique = set()
    final = list()
    for s in solutions:
        a = '-'.join(sorted(s))
        if not a in unique:
            unique.add(a)
            final.append(s)
    for s in final: #print them out
        print(s)
def solve(target, nums, path, solutions):
    if len(nums) == 1:
        return
    distinct = sorted(list(set(nums)), reverse = True)
    rem1 = list(distinct)
    for n1 in distinct: #reduce list by combining a pair
        rem1.remove(n1)
        for n2 in rem1:
            rem2 = list(nums) # in case of duplicates we need to start with full list and take out the n1,n2 pair of elements
            rem2.remove(n1)
            rem2.remove(n2)
            combine(target, solutions, path, rem2, n1, n2, '+')
            combine(target, solutions, path, rem2, n1, n2, '-')
            if n2 > 1:
                combine(target, solutions, path, rem2, n1, n2, '*')
                if not n1 % n2:
                    combine(target, solutions, path, rem2, n1, n2, '//')
def combine(target, solutions, path, rem2, n1, n2, symb):
    lst = list(rem2)
    ans = eval("{0}{2}{1}".format(n1, n2, symb))
    newpath = path + ["{0}{3}{1}={2}".format(n1, n2, ans, symb[0])]
    if ans == target:
        solutions += [newpath]
    else:
        lst.append(ans)
        solve(target, lst, newpath, solutions)
if __name__ == "__main__":
    driver()

Why is my Scala tail-recursion faster than the while loop?

Here are two solutions to exercise 4.9 in Cay Horstmann's Scala for the Impatient: "Write a function lteqgt(values: Array[Int], v: Int) that returns a triple containing the counts of values less than v, equal to v, and greater than v." One uses tail recursion, the other uses a while loop. I thought that both would compile to similar bytecode but the while loop is slower than the tail recursion by a factor of almost 2. This suggests to me that my while method is badly written.
import scala.annotation.tailrec
import scala.util.Random
object PerformanceTest {
def main(args: Array[String]): Unit = {
val bigArray:Array[Int] = fillArray(new Array[Int](100000000))
println(time(lteqgt(bigArray, 25)))
println(time(lteqgt2(bigArray, 25)))
}
def time[T](block : => T):T = {
val start = System.nanoTime : Double
val result = block
val end = System.nanoTime : Double
println("Time = " + (end - start) / 1000000.0 + " millis")
result
}
@tailrec def fillArray(a:Array[Int], pos:Int=0):Array[Int] = {
if (pos == a.length)
a
else {
a(pos) = Random.nextInt(50)
fillArray(a, pos+1)
}
}
@tailrec def lteqgt(values: Array[Int], v:Int, lt:Int=0, eq:Int=0, gt:Int=0, pos:Int=0):(Int, Int, Int) = {
if (pos == values.length)
(lt, eq, gt)
else
lteqgt(values, v, lt + (if (values(pos) < v) 1 else 0), eq + (if (values(pos) == v) 1 else 0), gt + (if (values(pos) > v) 1 else 0), pos+1)
}
def lteqgt2(values:Array[Int], v:Int):(Int, Int, Int) = {
var lt = 0
var eq = 0
var gt = 0
var pos = 0
val limit = values.length
while (pos < limit) {
if (values(pos) > v)
gt += 1
else if (values(pos) < v)
lt += 1
else
eq += 1
pos += 1
}
(lt, eq, gt)
}
}
Adjust the size of bigArray according to your heap size. Here is some sample output:
Time = 245.110899 millis
(50004367,2003090,47992543)
Time = 465.836894 millis
(50004367,2003090,47992543)
Why is the while method so much slower than the tailrec? Naively the tailrec version looks to be at a slight disadvantage, as it must always perform 3 "if" checks for every iteration, whereas the while version will often only perform 1 or 2 tests due to the else construct. (NB reversing the order I perform the two methods does not affect the outcome).
Test results (after reducing array size to 20000000)
Under Java 1.6.22 I get 151 and 122 ms for tail-recursion and while-loop respectively.
Under Java 1.7.0 I get 55 and 101 ms
So under Java 6 your while-loop is actually faster; both have improved in performance under Java 7, but the tail-recursive version has overtaken the loop.
Explanation
The performance difference is due to the fact that in your loop, you conditionally add 1 to the totals, while for recursion you always add either 1 or 0. So they are not equivalent. The equivalent while-loop to your recursive method is:
def lteqgt2(values:Array[Int], v:Int):(Int, Int, Int) = {
var lt = 0
var eq = 0
var gt = 0
var pos = 0
val limit = values.length
while (pos < limit) {
gt += (if (values(pos) > v) 1 else 0)
lt += (if (values(pos) < v) 1 else 0)
eq += (if (values(pos) == v) 1 else 0)
pos += 1
}
(lt, eq, gt)
}
and this gives exactly the same execution time as the recursive method (regardless of Java version).
Discussion
I'm not an expert on why the Java 7 VM (HotSpot) can optimize this better than your first version, but I'd guess it's because it's taking the same path through the code each time (rather than branching along the if / else if paths), so the bytecode can be inlined more efficiently.
But remember that this is not the case in Java 6. Why one while-loop outperforms the other is a question of JVM internals. Happily for the Scala programmer, the version produced from idiomatic tail-recursion is the faster one in the latest version of the JVM.
The difference could also be occurring at the processor level. See this question, which explains how code slows down if it contains unpredictable branching.
The two constructs are not identical. In particular, in the first case you don't need any jumps (on x86, you can use cmp and setle and add, instead of having to use cmp and jb and (if you don't jump) add). Not jumping is faster than jumping on pretty much every modern architecture.
So, if you have code that looks like
if (a < b) x += 1
where you may add or you may jump instead, vs.
x += (a < b)
(which only makes sense in C/C++ where 1 = true and 0 = false), the latter tends to be faster as it can be turned into more compact assembly code. In Scala/Java, you can't do this, but you can do
x += if (a < b) 1 else 0
which a smart JVM should recognize is the same as x += (a < b), which has a jump-free machine code translation, which is usually faster than jumping. An even smarter JVM would recognize that
if (a < b) x += 1
is the same yet again (because adding zero doesn't do anything).
C/C++ compilers routinely perform optimizations like this. Being unable to apply any of these optimizations was not a mark in the JIT compiler's favor; apparently it can as of 1.7, but only partially (i.e. it doesn't recognize that adding zero is the same as a conditional adding one, but it does at least convert x += if (a<b) 1 else 0 into fast machine code).
Now, none of this has anything to do with tail recursion or while loops per se. With tail recursion it's more natural to write the if (a < b) 1 else 0 form, but you can do either; and with while loops you can also do either. It just so happened that you picked one form for tail recursion and the other for the while loop, making it look like recursion vs. looping was the change instead of the two different ways to do the conditionals.

Why is filter in front of foldLeft slow in Scala?

I wrote an answer to the first Project Euler question:
Add all the natural numbers below one thousand that are multiples of 3 or 5.
The first thing that came to me was:
(1 until 1000).filter(i => (i % 3 == 0 || i % 5 == 0)).foldLeft(0)(_ + _)
but it's slow (it takes 125 ms), so I rewrote it, simply thinking of 'another way' versus 'the faster way'
(1 until 1000).foldLeft(0){
(total, x) =>
x match {
case i if (i % 3 == 0 || i % 5 ==0) => i + total // Add
case _ => total //skip
}
}
This is much faster (only 2 ms). Why? I'm guessing the second version uses only the Range generator and doesn't manifest a fully realized collection in any way, doing it all in one pass, both faster and with less memory. Am I right?
Here is the code on IdeOne: http://ideone.com/GbKlP
The problem, as others have said, is that filter creates a new collection. The alternative withFilter doesn't, but that doesn't have a foldLeft. Also, using .view, .iterator or .toStream would all avoid creating the new collection in various ways, but they are all slower here than the first method you used, which I thought somewhat strange at first.
But, then... See, 1 until 1000 is a Range, whose size is actually very small, because it doesn't store each element. Also, Range's foreach is extremely optimized, and is even specialized, which is not the case of any of the other collections. Since foldLeft is implemented as a foreach, as long as you stay with a Range you get to enjoy its optimized methods.
(_: Range).foreach:
@inline final override def foreach[@specialized(Unit) U](f: Int => U) {
if (length > 0) {
val last = this.last
var i = start
while (i != last) {
f(i)
i += step
}
f(i)
}
}
(_: Range).view.foreach
def foreach[U](f: A => U): Unit =
iterator.foreach(f)
(_: Range).view.iterator
override def iterator: Iterator[A] = new Elements(0, length)
protected class Elements(start: Int, end: Int) extends BufferedIterator[A] with Serializable {
private var i = start
def hasNext: Boolean = i < end
def next: A =
if (i < end) {
val x = self(i)
i += 1
x
} else Iterator.empty.next
def head =
if (i < end) self(i) else Iterator.empty.next
/** $super
* '''Note:''' `drop` is overridden to enable fast searching in the middle of indexed sequences.
*/
override def drop(n: Int): Iterator[A] =
if (n > 0) new Elements(i + n, end) else this
/** $super
* '''Note:''' `take` is overridden to be symmetric to `drop`.
*/
override def take(n: Int): Iterator[A] =
if (n <= 0) Iterator.empty.buffered
else if (i + n < end) new Elements(i, i + n)
else this
}
(_: Range).view.iterator.foreach
def foreach[U](f: A => U) { while (hasNext) f(next()) }
And that, of course, doesn't even count the filter between view and foldLeft:
override def filter(p: A => Boolean): This = newFiltered(p).asInstanceOf[This]
protected def newFiltered(p: A => Boolean): Transformed[A] = new Filtered { val pred = p }
trait Filtered extends Transformed[A] {
protected[this] val pred: A => Boolean
override def foreach[U](f: A => U) {
for (x <- self)
if (pred(x)) f(x)
}
override def stringPrefix = self.stringPrefix+"F"
}
Try making the collection lazy first, so
(1 until 1000).view.filter...
instead of
(1 until 1000).filter...
That should avoid the cost of building an intermediate collection. You might also get better performance from using sum instead of foldLeft(0)(_ + _), it's always possible that some collection type might have a more efficient way to sum numbers. If not, it's still cleaner and more declarative...
Looking through the code, it looks like filter does build a new Seq on which the foldLeft is called. The second skips that bit. It's not so much the memory, although that can't but help, but that the filtered collection is never built at all. All that work is never done.
Range uses TraversableLike.filter, which looks like this:
def filter(p: A => Boolean): Repr = {
val b = newBuilder
for (x <- this)
if (p(x)) b += x
b.result
}
I think it's the += on line 4 that's the difference. Filtering in foldLeft eliminates it.
filter creates a whole new sequence on which then foldLeft is called. Try:
(1 until 1000).view.filter(i => (i % 3 == 0 || i % 5 == 0)).reduceLeft(_+_)
This will prevent said effect, merely wrapping the original thing. Exchanging foldLeft with reduceLeft is only cosmetic (in this case).
Now the challenge is, can you think of a yet more efficient way? Not that your solution is too slow in this case, but how well does it scale? What if instead of 1000, it was 1000000000? There is a solution that could compute the latter case just as quickly as the former.
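One possible answer to that challenge, sketched in Python rather than Scala to keep it short, is to drop the loop entirely: sum the multiples of 3 and of 5 with the arithmetic-series formula and subtract the multiples of 15, which were counted twice. The function and helper names are made up for illustration, and the inclusion-exclusion step assumes the two divisors are coprime.

```python
def sum_multiples_below(limit, a=3, b=5):
    """Sum of the naturals below `limit` divisible by a or b (a, b coprime)."""
    def series(k):
        # k + 2k + ... + (largest multiple of k below limit)
        n = (limit - 1) // k
        return k * n * (n + 1) // 2
    # Inclusion-exclusion: multiples of a*b appear in both series.
    return series(a) + series(b) - series(a * b)

print(sum_multiples_below(1000))  # 233168
```

The cost does not depend on the limit, so `sum_multiples_below(1000000000)` takes the same handful of arithmetic operations as the 1000 case.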

Distributed algorithm to compute the balance of the parentheses

This is an interview question: "How to build a distributed algorithm to compute the balance of the parentheses?"
Usually the balance algorithm scans a string from left to right and uses a stack to make sure that the number of open parentheses is always >= the number of close parentheses, and that at the end the number of open parentheses == the number of close parentheses.
How would you make it distributed?
You can break the string into chunks and process each separately, assuming you can read and send to the other machines in parallel. You need two numbers for each chunk.
The minimum nesting depth achieved relative to the start of the string.
The total gain or loss in nesting depth across the whole string.
With these values, you can compute the values for the concatenation of many chunks as follows:
minNest = 0
totGain = 0
for p in chunkResults
minNest = min(minNest, totGain + p.minNest)
totGain += p.totGain
return new ChunkResult(minNest, totGain)
The parentheses are matched if the final values of totGain and minNest are zero.
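A minimal runnable sketch of this scheme in Python follows. The function names and the default chunk size are assumptions made for illustration, and the built-in `map` merely stands in for shipping each chunk to a different machine.

```python
def chunk_result(chunk):
    """Return (min_nest, tot_gain) for one chunk, both relative to its start."""
    depth = 0
    min_nest = 0
    for ch in chunk:
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
            min_nest = min(min_nest, depth)
    return min_nest, depth

def combine(results):
    """Fold per-chunk results, left to right, into the result for the whole string."""
    min_nest = 0
    tot_gain = 0
    for chunk_min, chunk_gain in results:
        min_nest = min(min_nest, tot_gain + chunk_min)
        tot_gain += chunk_gain
    return min_nest, tot_gain

def is_balanced(text, chunk_size=4):
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
    # map() stands in for processing each chunk on a separate machine.
    return combine(map(chunk_result, chunks)) == (0, 0)

print(is_balanced("(()(()))()"))  # True
print(is_balanced("())(()"))      # False
```

Only the two integers per chunk have to travel between machines, whatever the chunk length.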
I would apply the map-reduce algorithm, in which the map function processes a part of the string and returns either an empty string, if the parentheses in that part are balanced, or a string with the unmatched parentheses that remain.
Then the reduce function would concatenate two of the strings returned by the map function and process the result again, returning the same kind of result as map. At the end of all computations, you'd obtain either an empty string or a string containing the unbalanced parentheses.
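This map-reduce idea can be sketched in Python as follows. It is one interpretation, not a definitive implementation: it assumes the map step keeps every unmatched parenthesis (so each chunk collapses to a string of the form `)))(((`), and all names are made up for illustration.

```python
from functools import reduce

def residue(s):
    """Cancel matched pairs; what is left looks like ')))(((' (possibly empty)."""
    unmatched_close = 0
    open_count = 0
    for ch in s:
        if ch == '(':
            open_count += 1
        elif ch == ')':
            if open_count > 0:
                open_count -= 1       # matches an earlier '(' in this piece
            else:
                unmatched_close += 1  # nothing to its left can match it here
    return ')' * unmatched_close + '(' * open_count

def balanced_by_residue(text, chunk_size=4):
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
    mapped = map(residue, chunks)  # map step, one call per chunk
    # Reduce step: concatenate two residues and cancel again.
    final = reduce(lambda left, right: residue(left + right), mapped, "")
    return final == ""

print(balanced_by_residue("(()(()))()"))  # True
print(balanced_by_residue("())(()"))      # False
```

Compared with sending two integers per chunk, the residue strings can be as long as the chunk itself, so this variant moves more data between workers.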
I'll try to explain @jonderry's answer in more detail. Code first, in Scala:
def parBalance(chars: Array[Char], chunkSize: Int): Boolean = {
  require(chunkSize > 0, "chunkSize must be greater than 0")
  // Sequential stand-in so that the snippet compiles on its own; replace it with a
  // real helper that evaluates both arguments concurrently.
  def parallel[A, B](a: => A, b: => B): (A, B) = (a, b)
  // Returns (minNest, totGain) for the characters in [from, until).
  def traverse(from: Int, until: Int): (Int, Int) = {
    var count = 0 // net gain in nesting depth
    var stack = 0 // '(' not yet matched inside this chunk
    var nest = 0  // minus the number of ')' with no match inside this chunk
    for (n <- from until until) {
      val cur = chars(n)
      if (cur == '(') {
        count += 1
        stack += 1
      }
      else if (cur == ')') {
        count -= 1
        if (stack > 0) stack -= 1
        else nest -= 1
      }
    }
    (nest, count)
  }
  def reduce(from: Int, until: Int): (Int, Int) = {
    val m = (until + from) / 2
    if (until - from <= chunkSize) {
      traverse(from, until)
    } else {
      parallel(reduce(from, m), reduce(m, until)) match {
        case ((minNestL, totGainL), (minNestR, totGainR)) => {
          ((minNestL min (minNestR + totGainL)), (totGainL + totGainR))
        }
      }
    }
  }
  reduce(0, chars.length) == (0,0)
}
Given a string, if we remove the balanced parentheses, what's left will have the form )))(((. Let n be minus the number of ) and m the number of (, so that m >= 0 and n <= 0 (the sign makes the calculation easier). Here n is minNest and m+n is totGain. For a truly balanced string, we need m+n == 0 && n == 0.
In a parallel operation, how do we derive those values for a node from its left and right children? For totGain we just need to add them up. When calculating n for the node, it is either n(left), if the right side does not contribute, or n(right) + left.totGain, whichever is smaller.

Resources