Some linear constraints seem to be ignored in function NMinimize with Mathematica 8 - wolfram-mathematica

I'm trying to minimize a non-linear function of four variables with some linear constraints. Mathematica 8 is unable to find a good solution, giving complex values of the function at some point in the iteration. This implies that one or more constraints are not being enforced in the process. Is this a bug or a limitation of the optimization function?
Function to minimize is
ff[lxw_, lwz_, c_, d_] := - J1 (lxw + lwz) - 2 J2 c +
T (-Log[2] - 1/2 (1 - lxw) Log[(1 - lxw)/4] -
1/2 (1 + lxw) Log[(1 + lxw)/4] -
1/2 (1 - lwz) Log[(1 - lwz)/4] -
1/2 (1 + lwz) Log[(1 + lwz)/4] + 1/2 (1 - d) Log[(1 - d)/16] +
1/8 (1 + 2 c + d - 2 lwz - 2 lxw) Log[
1/16 (1 + 2 c + d - 2 lwz - 2 lxw)])
where
T = 10;
J1 = 1;
J2 = -0.2;
are constant parameters. Then I try
NMinimize[{ff[lxw, lwz, c, d],
2 c + d - 2 lwz - 2 lxw >= -0.999 &&
-0.999 <= lxw <= 0.999 &&
-0.999 <= lwz <= 0.999 &&
-0.999 <= c <= 0.999 &&
d <= 0.9999}, {lxw, lwz, c, d}]
with the result
NMinimize::nrnum: The function value 5.87777 - 4.87764 I
is not a real number at {c,d,lwz,lxw} = {-0.718817,-1.28595,0.69171,-0.932461}.
I would appreciate if someone can give a hint at what is happening here.

Try this:
Clear[ff];
ff[lxw_, lwz_, c_, d_] /; 2 c + d - 2 lwz - 2 lxw >= -0.999 :=
< your function def >
This will cause the function to remain unevaluated if NMinimize takes an excursion out of bounds. Sorry, I can't test this from here. If that doesn't work, try asking on mathematica.stackexchange.com
Aside, why use <=.999 instead of simply < 1 ?
It just might help if you fix that too ( use integer 1, not 1. )

The warning is appearing because at the values given in the warning the last term in ff is complex, due to taking the log of a negative number, i.e.
{c, d, lwz, lxw} = {
-0.7188174745559741`,
-1.2859482844800894`,
0.6917100913968041`,
-0.9324611085040573`};
Log[1/16 (1 + 2 c + d - 2 lwz - 2 lxw)]
-2.5558 + 3.14159 i
1/16 (1 + 2 c + d - 2 lwz - 2 lxw)
-0.0776301
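The same check can be reproduced outside Mathematica. The following is a minimal Python sketch (the variable names lhs and log_arg are illustrative, not from the question) that plugs in the point from the warning: it evaluates the left-hand side of the first constraint and the argument of the last logarithm, and shows that the point lies outside the constrained region.

```python
import cmath

# Point reported in the NMinimize::nrnum warning.
c, d = -0.7188174745559741, -1.2859482844800894
lwz, lxw = 0.6917100913968041, -0.9324611085040573

lhs = 2 * c + d - 2 * lwz - 2 * lxw  # left-hand side of the first constraint
log_arg = (1 + lhs) / 16             # argument of the last Log in ff

print(lhs >= -0.999)       # False: the first constraint is violated at this point
print(log_arg)             # negative, so a real logarithm does not exist
print(cmath.log(log_arg))  # complex logarithm, imaginary part equal to pi
```

So the complex value appears at a trial point that does not satisfy the constraint, which suggests the optimizer evaluated the function outside the feasible region at that step.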
In Mathematica 9 a result is produced in addition to the warning :-
{-4.90045, {c -> 0.94425, d -> -0.315633, lwz -> 0.900231, lxw -> -0.191476}}
I.e.
{c, d, lwz, lxw} = {
0.9442497691706085`,
-0.31563295950647885`,
0.900230825707721`,
-0.1914760216875171`};
ff[lxw, lwz, c, d]
-4.90045


Code performance optimization : speeding up population of an array in Matlab

Part of a longer algorithm, populating the column array b takes 80% of the total run time. The line (see code below) where S1 is calculated in the for loop is the bottleneck because it's called an enormously large amount of times due to the nested loops. I have tried vectorizing the loops and/or using the sum built-in function but I always end up with an even higher computation time.
For the sake of simplicity we can look at one single iteration, namely n = 1 and n + 1 = 2. So the 3D array containing the quantity R at each iteration has only two elements : R(:,:,1) and R(:,:,2).
Where:
I have posted here a minimum working sample that you can yourself run:
tic
Nz = 3636 ; % Number of rows
Nx = 910 ; % Number of columns
R = rand(Nz, Nx, 2) ;
alpha = 3 ;
for ii = Nx - 1 : -1 : 2
b = zeros(Nz, 1) ;
for jj = 2 : Nz - 1
S = 0 ;
S1 = 0 ;
for mm = Nx - 1 : -1 : ii % Summation
S1 = S1 + R(2, mm, 2) + R(2, mm, 1) - (R(1, mm, 2) + R(1, mm, 1)) ;
S = S + (R(jj + 1, mm, 2) + R(jj + 1, mm, 1)) - 2 * (R(jj, mm, 2) + R(jj, mm, 1)) + (R(jj - 1 , mm, 2) + R(jj - 1, mm, 1)) ;
end
b(1) = 2 *(R(1, ii, 1) + 2 * alpha * S1) ;
b(jj) = 2 *(R(jj, ii, 1) + alpha * S) ;
end
end
toc
The current elapsed time is about 28-30 [s] on my laptop (i7-9850H, 2.6 GHz, 16 GB RAM).
I am aware that b is reset to zero at each ii iteration, but to reduce the sample to bare minimum I omitted a line just before the last end: q = D\b. It's a linear system solved column by column starting from the right-most column of the grid (most external loop). That's why at every ii (column) iteration I compute a new b and thus a new solution q.
How can I speed up this computation?
EDIT:
I tried to vectorize the sum, by eliminating the inner for loop and using the function sum, but the time increases dramatically.
tic
Nz = 3636 ; % Number of rows
Nx = 910 ; % Number of columns
R = rand(Nz, Nx, 2) ;
alpha = 3 ;
for ii = Nx - 1 : -1 : 2
b = zeros(Nz, 1) ;
for jj = 2 : Nz - 1
S1 = sum([R(2, Nx - 1 : -1 : ii, 2), R(2, Nx - 1 : -1 : ii, 1), - R(1, Nx - 1 : -1 : ii, 2), R(1,Nx - 1 : -1 : ii, 1)]) ;
S = sum([R(jj + 1, Nx - 1 : -1 : ii, 2), R(jj + 1, Nx - 1 : -1 : ii, 1), - 2 * R(jj, Nx - 1 : -1 : ii, 2), 2 * R(jj, Nx - 1 : -1 : ii, 1), R(jj - 1 , Nx - 1 : -1 : ii, 2), R(jj - 1, Nx - 1 : -1 : ii, 1)]) ;
b(1) = 2 *(R(1, ii, 1) + 2 * alpha * S1) ;
b(jj) = 2 *(R(jj, ii, 1) + alpha * S) ;
end
end
toc
Most of the computations are repeated summations that can be converted to cumsum. The result b is a matrix that you simply can use D\b to compute q for all columns at once.
R = R(:, 2:end-1, :);
sm = sum(R, 3);
S1 = cumsum(sm(2, :) - sm(1, :), 2, 'reverse')
S = cumsum(conv2(sm, [1;-2;1], 'valid'), 2, 'reverse')
b = [2 .* (R(1, :, 1) + 2 * alpha * S1);
2 .* (R(2:end-1, :, 1) + alpha * S);
zeros(1, size(R, 2))];
q = D\b;
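To see why a reverse cumulative sum can replace the inner summation loop: for each ii the loop adds the column terms from ii up to Nx-1, which is a suffix sum. The sketch below uses plain Python lists rather than MATLAB arrays, and both function names are illustrative; it only demonstrates that one pass from the right yields every suffix sum at once.

```python
from itertools import accumulate

def suffix_sums_naive(values):
    """Recompute sum(values[i:]) from scratch for every i, like the inner mm-loop."""
    return [sum(values[i:]) for i in range(len(values))]

def suffix_sums_cumulative(values):
    """Single pass from the right, the analogue of cumsum(x, 2, 'reverse')."""
    return list(accumulate(reversed(values)))[::-1]

column_terms = [3, 1, 4, 1, 5]
print(suffix_sums_naive(column_terms))       # [14, 11, 10, 6, 5]
print(suffix_sums_cumulative(column_terms))  # [14, 11, 10, 6, 5]
```

The naive version does a quadratic amount of work in the number of columns, the cumulative one linear, which is where most of the saving comes from.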
On my laptop (MATLAB R2020a, i7-5500U, 2.4 GHz, 8 GB RAM), your original code runs in about 45 s.
Elapsed time is 45.239229 seconds.
Your attempt using sum runs in about 191 s.
Elapsed time is 191.424418 seconds.
@Ander Biguri's suggestion runs in about 174 s, which is faster, but not as fast as using for loops.
Elapsed time is 174.288989 seconds.
It appears that MATLAB built-in sum function runs slower than the for loop, as briefly discussed here. The array initialization b=zeros(Nz,1) is fast in newer versions of MATLAB (at least since R2019a), but you can initialize it differently if you are using an older version.
@Ander's suggestions, however, are noteworthy, and we can follow them but use for loops for the summations:
tic
Nz = 3636 ; % Number of rows
Nx = 910 ; % Number of columns
R = rand(Nz, Nx, 2) ;
alpha = 3 ;
for ii = Nx - 1 : -1 : 2
b = zeros(Nz, 1) ;
S1 = 0 ;
for mm = Nx - 1 : -1 : ii % Summation
S1 = S1 + R(2, mm, 2) + R(2, mm, 1) - (R(1, mm, 2) + R(1, mm, 1)) ;
end
for jj = 2 : Nz - 1
S = 0 ;
for mm = Nx - 1 : -1 : ii % Summation
S = S + (R(jj + 1, mm, 2) + R(jj + 1, mm, 1)) - 2 * (R(jj, mm, 2) + R(jj, mm, 1)) + (R(jj - 1 , mm, 2) + R(jj - 1, mm, 1)) ;
end
b(1) = 2 *(R(1, ii, 1) + 2 * alpha * S1) ;
b(jj) = 2 *(R(jj, ii, 1) + alpha * S) ;
end
end
toc
Elapsed time is 26.911663 seconds.
If computing time is a big concern, you can create some temporary variables before the for loop to avoid some summations/multiplications. This makes the code slightly worse to read, but it runs faster:
tic
Nz = 3636 ; % Number of rows
Nx = 910 ; % Number of columns
R = rand(Nz, Nx, 2) ;
R12 = R(:,:,1)+R(:,:,2);
alpha = 3 ;
Nx1 = Nx-1;
Nz1 = Nz-1;
two_alpha = 2*alpha;
for ii = Nx1:-1:2
b = zeros(Nz, 1) ;
S1 = 0 ;
for mm = ii:Nx1 % Summation
S1 = S1 + R12(2, mm) - R12(1, mm) ;
end
for jj = 2:Nz1
S = 0 ;
for mm = ii:Nx1 % Summation
S = S + R12(jj + 1, mm) -2*R12(jj, mm) + R12(jj - 1 , mm) ;
end
b(1) = 2 *(R(1, ii,1) + two_alpha*S1) ;
b(jj) = 2 *(R(jj, ii,1) + alpha*S) ;
end
end
toc
Elapsed time is 12.202103 seconds.
Since the limits of summation extend up to Nx-1 and Nz-1, creating the Nx1 and Nz1 variables avoids three subtractions on each ii-loop. The two_alpha variable also avoids one multiplication per iteration. Creating R12 also speeds up the calculation, since in the original code this sum over the 3rd dimension appears on every line of the summations.
I tried to also create a two_R12=2*R12 variable for the mm-loop, but it runs slower since the for loop must access elements on two different variables.
You also stated that the b variable is used to solve a linear system q = D\b. This will also slow down computations. If the D matrix is constant, consider decomposing it once, in order to also speed up calculations.
I'm just working with your optimized code. Mostly because your optimized code does not do the same as the non optimized one, there are more things that make no sense (b is replaced in all of the ii loops, so ii is superfluous). I give you thus some tips to improve code in general, but I can't fix a broken question.
Tips:
1- Make sure your code is easy to read by you.
R(2, Nx - 1 : -1 : ii, 2), R(2, Nx - 1 : -1 : ii, 1) is just R(2, Nx - 1 : -1 : ii, 2:-1:1), and in fact, as you are adding them, R(2, Nx - 1 : -1 : ii, :).
2- Indexing an array is much faster if you do it as expected, i.e., in order. You are adding array elements, but reading them in reverse. Just read them normally, i.e. use ii:Nx-1 instead of Nx-1:-1:ii. As a bonus, it's much more readable: the above line is now simply R(2, ii:Nx-1, :).
3- Don't create temporary arrays if it can be avoided, particularly if in the end these will be the size of R! You are just indexing it, and you don't care about the ordering because you will add it together. S can be just:
S = sum(R(jj-1:jj+1, ii : Nx-1, :).*cat(3,[1;2;1],[1; -2; 1]),'all') ;
% R(jj-1:jj+1, ii : Nx-1, :) % grabs the needed values
% cat(3,[1;2;1],[1; -2; 1]) % small trick to multiply the arrays (implicit broadcast)
% sum( ,'all') % add them regardless of shape.
This already speeds up the code by a lot.
4- Don't overcompute. S1 does not depend on jj; it is the same for every jj at a given ii. So why compute it jj times for each ii?
So, you can change your optimized code to the following for a x1.7 speedup:
tic
Nz = 1000 ; % Number of rows
Nx = 200 ; % Number of columns
R = rand(Nz, Nx, 2) ;
alpha = 3 ;
for ii = Nx - 1 : -1 : 2
b1 = zeros(Nz, 1) ;
% we can optimize this one too, but I don't think its the main culprit
% here.
S1 = sum([R(2, ii : Nx-1, 2), R(2, ii : Nx-1, 1),...
- R(1, ii : Nx-1, 2), R(1,ii : Nx-1, 1)]) ;
b1(1) = 2 *(R(1, ii, 1) + 2 * alpha * S1) ;
for jj = 2 : Nz - 1
S = sum(R(jj-1:jj+1, ii : Nx-1, :).*cat(3,[1;2;1],[1; -2; 1]),'all') ;
b1(jj) = 2 *(R(jj, ii, 1) + alpha * S) ;
end
end
toc

Mathematica Code with Module and If statement

Can I simply ask the logical flow of the below Mathematica code? What are the variables arg and abs doing? I have been searching for answers online and used ToMatlab but still cannot get the answer. Thank you.
Code:
PositiveCubicRoot[p_, q_, r_] :=
Module[{po3 = p/3, a, b, det, abs, arg},
b = ( po3^3 - po3 q/2 + r/2);
a = (-po3^2 + q/3);
det = a^3 + b^2;
If[det >= 0,
det = Power[Sqrt[det] - b, 1/3];
-po3 - a/det + det
,
(* evaluate real part, imaginary parts cancel anyway *)
abs = Sqrt[-a^3];
arg = ArcCos[-b/abs];
abs = Power[abs, 1/3];
abs = (abs - a/abs);
arg = -po3 + abs*Cos[arg/3]
]
]
abs and arg are being reused multiple times in the algorithm.
In a case where det < 0 (the second branch of the If) the steps are
po3 = p/3;
b = (po3^3 - po3 q/2 + r/2);
a = (-po3^2 + q/3);
abs1 = Sqrt[-a^3];
arg1 = ArcCos[-b/abs1];
abs2 = Power[abs1, 1/3];
abs3 = (abs2 - a/abs2);
arg2 = -po3 + abs3*Cos[arg1/3]
abs3 can be identified as A in this answer: Using trig identity to a solve cubic equation
That is the most salient point of this answer.
Evaluating symbolically and numerically may provide some other insights.
Using demo inputs
{p, q, r} = {-2.52111798, -71.424692, -129.51520};
Copyable version of trig identity notes - NB a, b, p & q are used differently in this post
Plot[x^3 - 2.52111798 x^2 - 71.424692 x - 129.51520, {x, 0, 15}]
a = 1;
b = -2.52111798;
c = -71.424692;
d = -129.51520;
p = (3 a c - b^2)/(3 a^2);
q = (2 b^3 - 9 a b c + 27 a^2 d)/(27 a^3);
A = 2 Sqrt[-p/3]
A == abs3
-(b/3) + A Cos[1/3 ArcCos[
-((b/3)^3 - (b/3) c/2 + d/2)/Sqrt[-(-(b^2/9) + c/3)^3]]]
Edit
There is also a solution shown here
TRIGONOMETRIC SOLUTION TO THE CUBIC EQUATION, by Alvaro H. Salas
Clear[a, b, c]
1/3 (-a + 2 Sqrt[a^2 - 3 b] Cos[1/3 ArcCos[
(-2 a^3 + 9 a b - 27 c)/(2 (a^2 - 3 b)^(3/2))]]) /.
{a -> -2.52111798, b -> -71.424692, c -> -129.51520}
10.499
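For readers who want to trace the logical flow outside Mathematica, here is a rough Python sketch of PositiveCubicRoot. It assumes the function is meant to return one real root of x^3 + p x^2 + q x + r = 0, which is consistent with the demo inputs above. The names real_cbrt, radius, angle and amplitude are illustrative stand-ins for the successive values held by det, abs and arg, and the real cube root is an assumption of this sketch (Mathematica's Power[x, 1/3] returns a complex value for negative x).

```python
import math

def real_cbrt(x):
    """Real cube root (assumption of this sketch; differs from Power[x, 1/3] for x < 0)."""
    return math.copysign(abs(x) ** (1.0 / 3.0), x)

def positive_cubic_root(p, q, r):
    """One real root of x^3 + p x^2 + q x + r = 0, following the Module step by step."""
    po3 = p / 3
    b = po3**3 - po3 * q / 2 + r / 2
    a = -po3**2 + q / 3
    det = a**3 + b**2
    if det >= 0:
        # Cardano branch: u is one cube root, the other one equals -a / u.
        u = real_cbrt(math.sqrt(det) - b)
        return -po3 - a / u + u
    # Trigonometric branch (det < 0).
    radius = math.sqrt(-a**3)       # first value stored in abs
    angle = math.acos(-b / radius)  # first value stored in arg
    amplitude = 2 * math.sqrt(-a)   # final abs: abs^(1/3) - a / abs^(1/3)
    return -po3 + amplitude * math.cos(angle / 3)

print(positive_cubic_root(-2.52111798, -71.424692, -129.51520))
```

With the demo inputs the trigonometric branch is taken, and the printed value should agree with the 10.499 shown above. The degenerate case where the cube root u is zero is not handled here.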

Using conditions to find imaginary and real part

I have used Solve to find the solution of an equation in Mathematica. (The reason I am posting here is that no one could answer my question on the Mathematica Stack Exchange.) The solution is called s and it is a function of two variables called v and ro. I want to find the imaginary and real parts of s, using the information that v and ro are real and lie in the intervals below:
0.02 < ro < 1, 40 < v < 100
The code I used is as below:
ClearAll["Global`*"]
d = 1; l = 100; k = 0.001; kk = 0.001;ke = 0.0014;dd = 0.5 ; dr = 0.06; dc = 1000; p = Sqrt[8 (ro l /2 - 1)]/l^2;
m = (4 dr + ke^2 (d + dd)/2) (-k^2 + kk^2) (1 - l ro/2) (d - dd)/4 -
I v p k l (4 dr + ke^2 (d + dd)/2)/4 - v^2 ke^2/4 + I v k dr l p/4;
xr = 0.06/n;
tr = d/n;
dp = (x (v I kk/2 (4 dr + ke^2 (d + dd)/2) - I v kk ke^2 (d - dd)/8 - dr l p k kk (d - dd)/4) + y ((xr I kk (ro - 1/l) (4 dr + ke^2 (d + dd)/2)) - I v kk tr ke^2 (1/l - ro/2) + I dr xr 4 kk (1/l - ro/2)))/m;
a = -I v k dp/4 - I xr y kk p/2 + l ke^2 dp p (d + dd)/8 + (-d + dd)/4 k kk x + dr l p dp;
aa = -v I kk dp/4 + xr I y k p/2 - tr y ke^2 (1/l - ro/2) - (d - dd) x kk^2/4 + ke^2 x (d - dd)/8;
ca = CoefficientArrays[{x (s + ke^2 (d + dd)/2) +
dp (v I kk - l (d - dd) k p kk/2) + y (tr ro ke^2) - (d -
dd) ((-kk^2 + k^2) aa - 2 k kk a)/(4 dr + ke^2 (d + dd)/2) == 0, y (s + dc ke^2) + n x == 0}, {x, y}];
mat = Normal[ca];
matt = Last@mat;
sha = Solve[Det[matt] == 0, s];
shaa = Assuming[v < 100 && v > 40 && ro < 1 && ro > 0.03,Simplify[%]];
reals = Re[shaa];
ims = Im[shaa];
Solve[reals == 0, ro]
but it gives no answer. Could anyone help? I really appreciate any solution to this problem.
I run your code down to this point
mat = Normal[ca]
and look at the result.
There are lots of very tiny floating point coefficients, so small that I suspect most of them are just floating point noise now. Mathematica thinks 0.1 is only known to 1 significant digit of precision and your mat result is perhaps nothing more than zero correct digits now.
I continue down to this point
sha = Solve[Det[matt] == 0, s]
If you look at the value of sha you will see it is s->stuff and I don't think that is at all what you think it is. Mathematica returns "rules" from Solve, not just expressions.
If I change that line to
sha = s/.Solve[Det[matt] == 0, s]
then I am guessing that is closer to what you are imagining you want.
I continue to
shaa = Assuming[40<v<100 && .03<ro<1, Simplify[sha]];
reals = Re[shaa]
And I instead use, because you are assuming v and ro to be Real and because ComplexExpand has often been very helpful in getting Re to provide desired results,
reals=Re[ComplexExpand[shaa]]
and I click on Show ALL to see the full expanded value of that. That is about 32 large screens full of your expression.
In that are hundreds of
Arg[-1. + 50. ro]
and if I understand your intention I believe all those simplify to 0. If that is correct then
reals=reals/.Arg[-1. + 50. ro]->0
reduces the size of reals down to about 20 large screen fulls.
But there are still hundreds of examples of Sqrt[(-1.+50. ro)^2] and ((-1.+50. ro)^2)^(1/4) making up your reals. Unfortunately I'm expecting your enormous expression is too large and will take too long for Simplify with assumptions to be able to be practically effective.
Perhaps additional replacements could coax it into dramatically simplifying your reals without making any mistakes about Real versus Complex. You have to be extremely careful with such things, because it is very common to make mistakes when dealing with complex numbers, roots, powers and functions and end up with an incorrect result. That might get your problem down to the point where it is feasible for
Solve[reals == 0, ro]
to give you a meaningful answer.
This should give you some ideas of what you need to think carefully about and work on.

Can integer division be rounded up, instead of down?

Is there a way to round the result of integer division up to the nearest integer, rather than down?
For example, I would like to change the default behavior:
irb(main):001:0> 5 / 2
=> 2
To the following behavior:
irb(main):001:0> 5 / 2
=> 3
The function you are looking for is ceil.
Ceil returns the nearest integer, rounded upwards, for a floating point number.
4/3 = 1
4.0/3.0 = 1.3333...3
(4.0/3.0).ceil = 2
Also, note that this rounds in the positive direction, so
(-4.0/3.0).ceil = -1, NOT -2
Also, there is the corresponding floor function which rounds downwards.
This is more of an algorithm question than a Ruby-specific question.
Try (a + b - 1) / b. For example
(5 + 2 - 1) / 2 #=> 3
(10 + 3 - 1) / 3 #=> 4
(6 + 3 - 1) / 3 #=> 2
You can define an instance method, say divide_by, in the Integer class (monkey patch):
class Integer
def divide_by(divisor)
(self + divisor - 1) / divisor
end
end
According to my benchmark result, it's about 1/2 times faster than the to_f then ceil solution.
CORRECTION
The method shown above gives wrong result when both the dividend and the divisor are negative.
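Here's a method that also handles negative operands: (a * 2 + b) / (b * 2). Be aware that this computes floor(a/b + 1/2), i.e. it rounds to the nearest integer with halves rounded up. It matches rounding up only when a/b is an integer or has a fractional part of at least 1/2, which is always true for divisors of ±1 and ±2 as in the examples below; 4 and 3, for instance, give (8 + 3) / 6 = 1 rather than 2.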
a = 5
b = 2
(a * 2 + b) / (b * 2) #=> 3
a = 6
b = 2
(a * 2 + b) / (b * 2) #=> 3
a = 5
b = 1
(a * 2 + b) / (b * 2) #=> 5
a = -5
b = 2
(a * 2 + b) / (b * 2) #=> -2 (-2.5 rounded up to -2)
a = 5
b = -2
(a * 2 + b) / (b * 2) #=> -2 (-2.5 rounded up to -2)
a = -5
b = -2
(a * 2 + b) / (b * 2) #=> 3
a = 10
b = 0
(a * 2 + b) / (b * 2) #=> raises ZeroDivisionError
The monkey patch should be
class Integer
def divide_by(divisor)
(self * 2 + divisor) / (divisor * 2)
end
end
Mathematical Proof:
The dividend a and the divisor b meets the equation a = kb + m where a, b, k, m are all integers, b is not zero, and m is between b and 0 (can be 0).
For example, when a is 5 and b is 2, then a = 2b + 1, thus in this case, k = 2 and m = 1.
Another example for negative divisor, a is 5, b is -2, then a = -3b + (-1), thus k = -3 and m = -1.
(2a + b) / 2b
= (2(kb + m) + b) / 2b
= (2kb + b + 2m) / 2b
When m = 0
(2kb + b + 2m) / 2b
= (2k + 1)b / 2b
= k + (1 / 2)
= k + 0 # in Ruby
= k # in Ruby
and since k = a / b, we got the correct answer.
When m is not 0,
(2kb + b + 2m) / 2b
= ((2k + 2)b - b + 2m) / 2b
= (k + 1) + (2m - b) / 2b
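If b > 0, then -b < 2m - b < b, so -1/2 < (2m - b) / 2b < 1/2. Under Ruby's floor division the second term is therefore 0 when 2m >= b and -1 when 2m < b.
If b < 0, then b < 2m - b < -b and the same bounds hold, so the second term is 0 when 2m <= b and -1 otherwise.
In either case, (2a + b) / 2b evaluates to k + 1 when the remainder m is at least half of b in magnitude, and to k for smaller non-zero remainders.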
If what you actually want is to Integer div and round up if there's any left over, just do it straightforward as the logic would dictate on paper using a second line of modular operation (%) to check the remainders of the division:
a = 5
b = 2
result = a / b #=> 2
result += 1 if (a % b).positive?
#=> 3
a = 6
b = 3
result = a / b #=> 2
result += 1 if (a % b).positive?
#=> 2
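The formulas in this thread can be compared mechanically. The sketch below is in Python, whose // operator floors toward negative infinity just like Ruby's Integer#/, so the arithmetic carries over unchanged. The function names are illustrative, and ceil_div uses the identity ceil(x) = -floor(-x), which is not one of the formulas proposed above.

```python
import math
from fractions import Fraction

def ceil_div(a, b):
    """Ceiling of a / b using only floor division: ceil(x) == -floor(-x)."""
    return -(-a // b)

def add_then_divide(a, b):
    """The (a + b - 1) / b formula."""
    return (a + b - 1) // b

def doubled(a, b):
    """The (a * 2 + b) / (b * 2) formula."""
    return (a * 2 + b) // (b * 2)

def mismatches(formula, limit=20):
    """Operand pairs in [-limit, limit] where formula differs from the exact ceiling."""
    return [(a, b)
            for a in range(-limit, limit + 1)
            for b in range(-limit, limit + 1)
            if b != 0 and formula(a, b) != math.ceil(Fraction(a, b))]

print(mismatches(ceil_div))                    # []
print((-5, -2) in mismatches(add_then_divide)) # True
print((4, 3) in mismatches(doubled))           # True
```

In Ruby the first function would read -(-a / b). Recent Ruby versions also ship Integer#ceildiv, which may be preferable where available.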

How to find all possible values of four variables when squared sum to N?

A^2+B^2+C^2+D^2 = N Given an integer N, print out all possible combinations of integer values of ABCD which solve the equation.
I am guessing we can do better than brute force.
Naive brute force would be something like:
n = 3200724;
lim = sqrt (n) + 1;
for (a = 0; a <= lim; a++)
for (b = 0; b <= lim; b++)
for (c = 0; c <= lim; c++)
for (d = 0; d <= lim; d++)
if (a * a + b * b + c * c + d * d == n)
printf ("%d %d %d %d\n", a, b, c, d);
Unfortunately, this will result in over a trillion loops, not overly efficient.
You can actually do substantially better than that by discounting huge numbers of impossibilities at each level, with something like:
#include <stdio.h>
#include <stdlib.h>
int main(int argc, char *argv[]) {
int n = atoi (argv[1]);
int a, b, c, d, na, nb, nc, nd;
int count = 0;
for (a = 0, na = n; a * a <= na; a++) {
for (b = 0, nb = na - a * a; b * b <= nb; b++) {
for (c = 0, nc = nb - b * b; c * c <= nc; c++) {
for (d = 0, nd = nc - c * c; d * d <= nd; d++) {
if (d * d == nd) {
printf ("%d %d %d %d\n", a, b, c, d);
count++;
}
}
}
}
}
printf ("Found %d solutions\n", count);
return 0;
}
It's still brute force, but not quite as brutish inasmuch as it understands when to stop each level of looping as early as possible.
On my (relatively) modest box, that takes under a second (a) to get all solutions for numbers up to 50,000. Beyond that, it starts taking more time:
n time taken
---------- ----------
100,000 3.7s
1,000,000 6m, 18.7s
For n = ten million, it had been going about an hour and a half before I killed it.
So, I would say brute force is perfectly acceptable up to a point. Beyond that, more mathematical solutions would be needed.
For even more efficiency, you could only check those solutions where d >= c >= b >= a. That's because you could then build up all the solutions from those combinations into permutations (with potential duplicate removal where the values of two or more of a, b, c, or d are identical).
In addition, the body of the d loop doesn't need to check every value of d, just the last possible one.
Getting the results for 1,000,000 in that case takes under ten seconds rather than over six minutes:
0 0 0 1000
0 0 280 960
0 0 352 936
0 0 600 800
0 24 640 768
: : : :
424 512 512 544
428 460 500 596
432 440 480 624
436 476 532 548
444 468 468 604
448 464 520 560
452 452 476 604
452 484 484 572
500 500 500 500
Found 1302 solutions
real 0m9.517s
user 0m9.505s
sys 0m0.012s
That code follows:
#include <stdio.h>
#include <stdlib.h>
int main(int argc, char *argv[]) {
int n = atoi (argv[1]);
int a, b, c, d, na, nb, nc, nd;
int count = 0;
for (a = 0, na = n; a * a <= na; a++) {
for (b = a, nb = na - a * a; b * b <= nb; b++) {
for (c = b, nc = nb - b * b; c * c <= nc; c++) {
for (d = c, nd = nc - c * c; d * d < nd; d++);
if (d * d == nd) {
printf ("%4d %4d %4d %4d\n", a, b, c, d);
count++;
}
}
}
}
printf ("Found %d solutions\n", count);
return 0;
}
And, as per a suggestion by DSM, the d loop can disappear altogether (since there's only one possible value of d (discounting negative numbers) and it can be calculated), which brings the one million case down to two seconds for me, and the ten million case to a far more manageable 68 seconds.
That version is as follows:
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
int main(int argc, char *argv[]) {
int n = atoi (argv[1]);
int a, b, c, d, na, nb, nc, nd;
int count = 0;
for (a = 0, na = n; a * a <= na; a++) {
for (b = a, nb = na - a * a; b * b <= nb; b++) {
for (c = b, nc = nb - b * b; c * c <= nc; c++) {
nd = nc - c * c;
d = sqrt (nd);
if (d >= c && d * d == nd) {
printf ("%d %d %d %d\n", a, b, c, d);
count++;
}
}
}
}
printf ("Found %d solutions\n", count);
return 0;
}
(a): All timings are done with the inner printf commented out so that I/O doesn't skew the figures.
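As a compact illustration of the last variant (ordered tuples with d computed rather than searched), here is a Python sketch. It is not a port of the timing code: it uses math.isqrt, an exact integer square root, instead of a floating-point sqrt, and the function name four_squares is illustrative. The explicit d >= c test keeps only ordered tuples, so no permutation is reported twice.

```python
from math import isqrt

def four_squares(n):
    """All (a, b, c, d) with 0 <= a <= b <= c <= d and a^2 + b^2 + c^2 + d^2 == n."""
    solutions = []
    a = 0
    while a * a <= n:
        nb = n - a * a
        b = a
        while b * b <= nb:
            nc = nb - b * b
            c = b
            while c * c <= nc:
                nd = nc - c * c
                d = isqrt(nd)  # the only non-negative candidate for d
                if d >= c and d * d == nd:
                    solutions.append((a, b, c, d))
                c += 1
            b += 1
        a += 1
    return solutions

print(four_squares(12))  # [(0, 2, 2, 2), (1, 1, 1, 3)]
```

Permutations and sign changes of each tuple can be generated afterwards if all integer solutions are wanted.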
The Wikipedia page has some interesting background information, but Lagrange's four-square theorem (or, more correctly, Bachet's Theorem - Lagrange only proved it) doesn't really go into detail on how to find said squares.
As I said in my comment, the solution is going to be nontrivial. This paper discusses the solvability of four-square sums. The paper alleges that:
There is no convenient algorithm (beyond the simple one mentioned in
the second paragraph of this paper) for finding additional solutions
that are indicated by the calculation of representations, but perhaps
this will streamline the search by giving an idea of what kinds of
solutions do and do not exist.
There are a few other interesting facts related to this topic. There
exist other theorems that state that every integer can be written as a
sum of four particular multiples of squares. For example, every
integer can be written as N = a^2 + 2b^2 + 4c^2 + 14d^2. There are 54
cases like this that are true for all integers, and Ramanujan provided
the complete list in the year 1917.
For more information, see Modular Forms. This is not easy to understand unless you have some background in number theory. If you could generalize Ramanujan's 54 forms, you may have an easier time with this. With that said, in the first paper I cite, there is a small snippet which discusses an algorithm that may find every solution (even though I find it a bit hard to follow):
For example, it was reported in 1911 that the calculator Gottfried
Ruckle was asked to reduce N = 15663 as a sum of four squares. He
produced a solution of 125^2 + 6^2 + 1^2 + 1^2 in 8 seconds, followed
immediately by 125^2 + 5^2 + 3^2 + 2^2. A more difficult problem
(reflected by a first term that is farther from the original number,
with correspondingly larger later terms) took 56 seconds: 11339 = 105^2
+ 15^2 + 8^2 + 5^2. In general, the strategy is to begin by setting the first term to be the largest square below N and try to represent the
smaller remainder as a sum of three squares. Then the first term is
set to the next largest square below N, and so forth. Over time a
lightning calculator would become familiar with expressing small
numbers as sums of squares, which would speed up the process.
(Emphasis mine.)
The algorithm is described as being recursive, but it could easily be implemented iteratively.
It seems as though all integers can be made by such a combination:
0 = 0^2 + 0^2 + 0^2 + 0^2
1 = 1^2 + 0^2 + 0^2 + 0^2
2 = 1^2 + 1^2 + 0^2 + 0^2
3 = 1^2 + 1^2 + 1^2 + 0^2
4 = 2^2 + 0^2 + 0^2 + 0^2, 1^2 + 1^2 + 1^2 + 1^2
5 = 2^2 + 1^2 + 0^2 + 0^2
6 = 2^2 + 1^2 + 1^2 + 0^2
7 = 2^2 + 1^2 + 1^2 + 1^2
8 = 2^2 + 2^2 + 0^2 + 0^2
9 = 3^2 + 0^2 + 0^2 + 0^2, 2^2 + 2^2 + 1^2 + 0^2
10 = 3^2 + 1^2 + 0^2 + 0^2, 2^2 + 2^2 + 1^2 + 1^2
11 = 3^2 + 1^2 + 1^2 + 0^2
12 = 3^2 + 1^2 + 1^2 + 1^2, 2^2 + 2^2 + 2^2 + 0^2
.
.
.
and so forth
As I did some initial working in my head, I thought that it would be only the perfect squares that had more than 1 possible solution. However after listing them out it seems to me there is no obvious order to them. However, I thought of an algorithm I think is most appropriate for this situation:
The important thing is to use a 4-tuple (a, b, c, d). In any given 4-tuple which is a solution to a^2 + b^2 + c^2 + d^2 = n, we will set ourselves a constraint that a is always the largest of the 4, b is next, and so on and so forth like:
a >= b >= c >= d
Also note that a^2 cannot be less than n/4, otherwise the sum of the squares will have to be less than n.
Then the algorithm is:
1a. Obtain floor(square_root(n)) # this is the maximum value of a - call it max_a
1b. Obtain the first value of a such that a^2 >= n/4 - call it min_a
2. For a in a range (min_a, max_a)
At this point we have selected a particular a, and are now looking at bridging the gap from a^2 to n - i.e. (n - a^2)
3. Repeat steps 1a through 2 to select a value of b. This time instead of finding
floor(square_root(n)) we find floor(square_root(n - a^2))
and so on and so forth. So the entire algorithm would look something like:
1a. Obtain floor(square_root(n)) # this is the maximum value of a - call it max_a
1b. Obtain the first value of a such that a^2 >= n/4 - call it min_a
2. For a in a range (min_a, max_a)
3a. Obtain floor(square_root(n - a^2))
3b. Obtain the first value of b such that b^2 >= (n - a^2)/3
4. For b in a range (min_b, max_b)
5a. Obtain floor(square_root(n - a^2 - b^2))
5b. Obtain the first value of c such that c^2 >= (n - a^2 - b^2)/2
6. For c in a range (min_c, max_c)
7. We now look at (n - a^2 - b^2 - c^2). If its square root is an integer, this is d.
Otherwise, this tuple will not form a solution
At steps 3b and 5b I use (n - a^2)/3, (n - a^2 - b^2)/2. We divide by 3 or 2, respectively, because of the number of values in the tuple not yet 'fixed'.
An example:
doing this on n = 12:
1a. max_a = 3
1b. min_a = 2
2. for a in range(2, 3):
use a = 2
3a. we now look at (12 - 2^2) = 8
max_b = 2
3b. min_b = 2
4. b must be 2
5a. we now look at (12 - 2^2 - 2^2) = 4
max_c = 2
5b. min_c = 2
6. c must be 2
7. (n - a^2 - b^2 - c^2) = 0, hence d = 0
so a possible tuple is (2, 2, 2, 0)
2. use a = 3
3a. we now look at (12 - 3^2) = 3
max_b = 1
3b. min_b = 1
4. b must be 1
5a. we now look at (12 - 3^2 - 1^2) = 2
max_c = 1
5b. min_c = 1
6. c must be 1
7. (n - a^2 - b^2 - c^2) = 1, hence d = 1
so a possible tuple is (3, 1, 1, 1)
These are the only two possible tuples - hey presto!
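The numbered steps can be sketched in Python as follows. This is one possible reading of the algorithm, not a reference implementation: the names ceil_sqrt and descending_four_squares are illustrative, and the upper bounds are additionally capped at the previous variable (b <= a, c <= b) so that the a >= b >= c >= d ordering is actually enforced.

```python
from math import isqrt

def ceil_sqrt(x):
    """Smallest integer r >= 0 with r * r >= x, for a non-negative integer x."""
    r = isqrt(x)
    return r if r * r == x else r + 1

def descending_four_squares(n):
    """All (a, b, c, d) with a >= b >= c >= d >= 0 and a^2 + b^2 + c^2 + d^2 == n."""
    solutions = []
    # Steps 1a/1b: a^2 >= n/4, written with the integer ceiling -(-n // 4).
    for a in range(ceil_sqrt(-(-n // 4)), isqrt(n) + 1):
        rem_a = n - a * a
        # Steps 3a/3b: b^2 >= rem_a / 3, and b may not exceed a.
        for b in range(ceil_sqrt(-(-rem_a // 3)), min(a, isqrt(rem_a)) + 1):
            rem_b = rem_a - b * b
            # Steps 5a/5b: c^2 >= rem_b / 2, and c may not exceed b.
            for c in range(ceil_sqrt(-(-rem_b // 2)), min(b, isqrt(rem_b)) + 1):
                rem_c = rem_b - c * c
                d = isqrt(rem_c)  # step 7: d exists only if rem_c is a perfect square
                if d * d == rem_c:
                    solutions.append((a, b, c, d))
    return solutions

print(descending_four_squares(12))  # [(2, 2, 2, 0), (3, 1, 1, 1)]
```

For n = 12 this reproduces the two tuples of the worked example. The lower bound on c already guarantees d <= c, so no extra ordering check is needed for d.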
nebffa has a great answer. One suggestion:
step 3a: max_b = min(a, floor(square_root(n - a^2))) // since b <= a
max_c and max_d can be improved in the same way too.
Here is another try:
1. Generate the array S = {0, 1, 2^2, 3^2, ..., nr^2}, where nr = floor(square_root(N)).
The problem is now to find four elements of S whose sum is N.
2. According to nebffa's post (steps 1a and 1b), a (the largest of the four numbers) lies in [nr/2 .. nr].
We can loop a from nr down to nr/2 and compute r = N - S[a];
the problem is then to find three elements of S whose sum is r = N - S[a].
Here is the code:
nr = square_root(N);
S = {0, 1, 2^2, 3^2, 4^2,.... nr^2};
for (a = nr; a >= nr/2; a--)
{
r = N - S[a];
// it is now a 3SUM problem
for(b = a; b >= 0; b--)
{
r1 = r - S[b];
if (r1 < 0)
continue;
if (r1 > N/2) // because (a^2 + b^2) >= (c^2 + d^2)
break;
for (c = 0, d = b; c <= d;)
{
sum = S[c] + S[d];
if (sum == r1)
{
print a, b, c, d;
c++; d--;
}
else if (sum < r1)
c++;
else
d--;
}
}
}
The runtime is O(square_root(N)^3).
Here is the test result running java on my VM (time in milliseconds, result# is total num of valid combination, time 1 with printout, time2 without printout):
N result# time1 time2
----------- -------- -------- -----------
1,000,000 1302 859 281
10,000,000 6262 16109 7938
100,000,000 30912 442469 344359
