jax minimization with stochastically estimated gradients - random

I'm trying to use the bfgs optimizer from tensorflow_probability.substrates.jax and from jax.scipy.optimize.minimize to minimize a function f which is estimated from pseudo-random samples and has a jax.random.PRNGKey as argument. To use this function with the jax/tfp bfgs minimizer, I wrap the function inside a lambda function
seed = 100
key = jax.random.PRNGKey(seed)
fun = lambda x: f(x, key)
result = jax.scipy.optimize.minimize(fun = fun, ...)
What is the best way to update the key when the minimization routine calls the function to be minimized so that I use different pseudo-random numbers in a reproducible way? Maybe a global key variable? If yes, is there an example I could follow?
Secondly, is there a way to make the optimization stop after a certain amount of time, as one could do with a callback in scipy? I could directly use the scipy implementation of bfgs/ l-bfgs-b/ etc and use jax only for the estimation of the function and of its gradients, which seems to work. Is there a difference between the scipy, jax.scipy and tfp.jax bfgs implementations?
Finally, is there a way to print the values of the arguments of fun during the bfgs optimization in jax.scipy or tfp, given that f is jitted?
Thank you!

There is no way to do what you're asking with jax.scipy.optimize.minimize, because the minimizer does not offer any means to track changing state between function calls, and does not provide for any inbuilt stochasticity in the optimizer.
If you're interested in stochastic optimization in JAX, you might try stochastic optimization in JAXOpt, which provides a much more flexible set of optimization routines.
Regarding your second question, if you'd like to print values during the course of a jit-compiled optimization or other loop, you can use jax.debug.print.

Related

matlab: inefficient "mysum"

in order to estimate some non-linear models, i need to derive a quite huge function numerically. one part of the target function includes a polynomial (which is created by some sums).
there are quite a lot of iterations in this process and my computer takes way too much time to compute (though it yields reasonable estimates). the profiler claims that my handwritten sum-function is by far the most time-consuming part of the algorithm. this is my first matlab project, so i'm pretty new to it. maybe you can help to optimize it :)
function [output] = mysum(a,b,inputfun)
output=0;
for i=a:b
output=output+inputfun(i);
end
if you want to know, how i use it. this is the polynomial:
function [ weights ] = wexpo(theta)
global lag;
for i=1:lag
weights(i) = exp(mysum(1,length(theta),@(k) theta(k)*(i-1)^k))...
/mysum(0,lag-1,@(j)...
exp(mysum(1,length(theta),@(k) theta(k)*j^k)));
end
If you can use Matlab functions:
function [output] = mysum(a,b,inputfun)
output = sum(inputfun(a:b)); % inputfun must accept a vector, so use .* and .^ inside it

Fast check if element is in MATLAB matrix

I would like to verify whether an element is present in a MATLAB matrix.
At the beginning, I implemented as follows:
if ~isempty(find(matrix(:) == element))
which is obviously slow. Thus, I changed to:
if sum(matrix(:) == element) ~= 0
but this is again slow: I am calling a lot of times the function that contains this instruction, and I lose 14 seconds each time!
Is there a way to further optimize this instruction?
Thanks.
If you just need to know if a value exists in a matrix, using the second argument of find to specify that you just want one value will be slightly faster (25-50%) and even a bit faster than using sum, at least on my machine. An example:
matrix = randi(100,1e4,1e4);
element = 50;
~isempty(find(matrix(:)==element,1))
However, in recent versions of Matlab (I'm using R2014b), nnz is finally faster for this operation, so:
matrix = randi(100,1e4,1e4);
element = 50;
nnz(matrix==element)~=0
On my machine this is about 2.8 times faster than any other approach (including using any, strangely) for the example provided. To my mind, this solution also has the benefit of being the most readable.
In my opinion, there are several things you could try to improve performance:
Following your initial idea, I would go for the function any to test if any of the equality tests succeeded:
if any(matrix(:) == element)
I tested this on a 1000 by 1000 matrix and it is faster than the solutions you have tested.
I do not think that the unfolding matrix(:) is penalizing since it is equivalent to a reshape and Matlab does this in a smart way where it does not actually allocate and move memory since you are not modifying the temporary object matrix(:)
If your matrix does not change between the calls to the function, or changes rarely, you could simply keep another vector containing all the elements of your matrix, but sorted. This way you could use a more efficient O(log(N)) search algorithm to test for the presence of your element.
I personally like the ismember function for this kind of problem. It might not be the fastest, but for non-critical parts of the code it greatly improves readability and code maintenance (and I prefer to spend one hour coding something that will take a day to run than spending one day to code something that will run in one hour; this of course depends on how often you use this program, but it is something one should never forget).
If you can have a sorted copy of the elements of your matrix, you could consider using the undocumented Matlab function ismembc but remember that inputs must be sorted non-sparse non-NaN values.
If performance really is critical you might want to write your own mex file and for this task you could even include some simple parallelization using openmp.
Hope this helps,
Adrien.
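The sorted-copy idea from the answer above can be sketched as follows. Python is used here only for illustration, since the thread itself is about MATLAB; the helper name contains_sorted and the sample matrix are assumptions made for this sketch, not part of the original answer.

```python
from bisect import bisect_left


def contains_sorted(sorted_values, element):
    """Return True if element occurs in sorted_values (ascending, no NaN)."""
    i = bisect_left(sorted_values, element)  # O(log N) binary search
    return i < len(sorted_values) and sorted_values[i] == element


# Flatten and sort the matrix once, then reuse the sorted copy for every lookup.
matrix = [[7, 3, 9], [1, 8, 3], [5, 2, 6]]
sorted_values = sorted(v for row in matrix for v in row)

found = contains_sorted(sorted_values, 8)
missing = contains_sorted(sorted_values, 4)
```

The sort costs O(N log N) once, so this presumably only pays off when the same matrix is queried many times between modifications.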

Derivative of a program

Let us assume you can represent a program as a mathematical function, which is possible. What does the program representation of the first derivative of that function look like? Is there a way to transform a program to its "derivative" form, and does this make sense at all?
Yes it does make sense, it's known as Automatic Differentiation. There are one or two experimental compilers which can do this, for example NAGware's Differentiation Enabled Fortran Compiler Technology. And there are a lot of research papers on the topic. I suggest you get Googling.
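As a rough illustration of what automatic differentiation does, here is a minimal forward-mode sketch using dual numbers. It is not how the compilers mentioned above work internally; the class name Dual and the helper derivative are hypothetical names chosen for this example, and only addition and multiplication are supported.

```python
class Dual:
    """Number of the form value + deriv * eps, with eps**2 == 0."""

    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    @staticmethod
    def _lift(other):
        # Treat plain numbers as constants (derivative 0).
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        other = Dual._lift(other)
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual._lift(other)
        # Product rule: (u*v)' = u'*v + u*v'
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

    __rmul__ = __mul__


def derivative(f):
    """Return a function that evaluates f'(x) by forward-mode AD."""
    def f_prime(x):
        return f(Dual(x, 1.0)).deriv
    return f_prime


def f(x):
    return x * x + 3 * x  # ordinary program code; works on floats and Duals


f_prime = derivative(f)
slope = f_prime(2.0)  # derivative 2*x + 3 evaluated at x = 2
```

The program f is left unchanged; the derivative falls out of running it on a different number type, which is one way to read "transforming a program into its derivative".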
First, it only makes sense to try to get the derivative of a pure function (one that does not affect external state and returns the exact same output for every input). Second, the type system of many programming languages involves a lot of step functions (e.g. integers), meaning you'd have to get your program to work in terms of continuous functions in order to get a valid first derivative. Third, getting the derivative of any function involves breaking it down and manipulating it symbolically. Thus, you can't get the derivative of a function without knowing what operations it is made of. This could be achieved with reflection.
You could create a derivative approximation function if your programming language supports closures (that is, nested functions and the ability to put functions into variables and return them). Here is a JavaScript example taken from http://en.wikipedia.org/wiki/Closure_%28computer_science%29 :
function derivative(f, dx) {
return function(x) {
return (f(x + dx) - f(x)) / dx;
};
}
Thus, you could say:
function f(x) { return x*x; }
f_prime = derivative(f, 0.0001);
Here, f_prime will approximate function(x) {return 2*x;}
If a programming language implemented higher-order functions and enough algebra, one could implement a real derivative function in it. That would be really cool.
See Lambda the Ultimate discussions on Derivatives and dissections of data types and Derivatives of Regular Expressions
How do you define the mathematical function of a program?
A derivative represents the rate of change of a function. If your function isn't continuous, its derivative will be undefined over most of the domain.
I'm just gonna say that this doesn't make a lot of sense, as a program is much more abstract and "ruleless" than a mathematical function. As a derivative is a measure of the change in output as the input changes, there are certainly some programs where this could apply. However, you'd need to be able to quantify your input/output both in numerical terms.
Since input/output would both be numerical, it's reasonable to assume that your program represents or operates similarly to a mathematical function, or series of functions. Hence, you can easily represent a derivative, but it would be no different than converting the mathematical derivative of a function to a computer program.
If the program is denoted as a distribution (Schwartz) then you have some notion of derivative, assuming that test functions model your postcondition (you can still take the limit to get a characteristic function). For instance, the assignment x:=x+1 is associated to the Dirac distribution \delta_{x_0+1} where x_0 is the initial value of the variable x. However, I have no idea what the computational meaning of \delta_{x_0+1}' is.
I am wondering, what if the program you're trying to "derive" uses some form of heuristics? How can it be derived then?
Half-jokingly, we all know that all real programs use at least a rand().

Optimizing the computation of a recursive sequence

What is the fastest way in R to compute a recursive sequence defined as
x[1] <- x1
x[n] <- f(x[n-1])
I am assuming that the vector x of proper length is preallocated. Is there a smarter way than just looping?
Variant: extend this to vectors:
x[,1] <- x1
x[,n] <- f(x[,n-1])
Solve the recurrence relationship ;)
In terms of the question of whether this can be fully "vectorized" in any way, I think the answer is probably "no". The fundamental idea behind array programming is that operations apply to an entire set of values at the same time. Similarly for questions of "embarrassingly parallel" computation. In this case, since your recursive algorithm depends on each prior state, there would be no way to gain speed from parallel processing: it must be run serially.
That being said, the usual advice for speeding up your program applies. For instance, do as much of the calculation outside of your recursive function as possible. Sort everything. Predefine your array lengths so they don't have to grow during the looping. Etc. See this question for a similar discussion. There is also a pseudocode example in Tim Hesterberg's article on efficient S-Plus Programming.
You could consider writing it in C / C++ / Fortran and use the handy inline package to deal with the compiling, linking and loading for you.
Of course, your function f() may be a real constraint if that one needs to remain an R function. There is a callback-from-C++-to-R example in Rcpp but this requires a bit more work than just using inline.
Well, if you need the entire sequence, how fast can it be? Assuming that the function is O(1), you cannot do better than O(n), and looping through will give you just that.
In general, the syntax x$y <- f(z) will have to reallocate x every time, which would be very slow if x is a large object. But, it turns out that R has some tricks so that the list replacement function [[<- doesn't reallocate the whole list every time. So I think you can reasonably efficiently do:
x[[1]] <- x1
for (m in seq(2, n))
x[[m]] <- f(x[[m-1]])
The only wasteful aspect here is that you have to generate an array of length n-1 for the for loop, which isn't ideal, but it's probably not a giant issue. You could replace it by a while loop if you preferred. The usual vectorization tricks (lapply, etc.) won't work here...
(The double brackets give you a list element, which is what you probably want, rather than a singleton list.)
For more details, see Chambers (2008). Software for Data Analysis. p. 473-474.

Is there a way to predict unknown function value based on its previous values

I have values returned by unknown function like for example
# this is an easy case - parabolic function
# but in my case function is really unknown as it is connected to process execution time
[0, 1, 4, 9]
is there a way to predict next value?
Not necessarily. Your "parabolic function" might be implemented like this:
def mindscrew
@nums ||= [0, 1, 4, 9, "cat", "dog", "cheese"]
@nums.pop
end
You can take a guess, but to predict with certainty is impossible.
You can try using neural networks approach. There are pretty many articles you can find by Google query "neural network function approximation". Many books are also available, e.g. this one.
If you just want data points
Extrapolation of data outside of known points can be estimated, but you need to accept the potential differences are much larger than with interpolation of data between known points. Strictly, both can be arbitrarily inaccurate, as the function could do anything crazy between the known points, even if it is a well-behaved continuous function. And if it isn't well-behaved, all bets are already off ;-p
There are a number of mathematical approaches to this (that have direct application to computer science) - anything from simple linear algebra to things like cubic splines; and everything in between.
If you want the function
Getting esoteric; another interesting model here is genetic programming; by evolving an expression over the known data points it is possible to find a suitably-close approximation. Sometimes it works; sometimes it doesn't. Not the language you were looking for, but Jason Bock shows some C# code that does this in .NET 3.5, here: Evolving LINQ Expressions.
I happen to have his code "to hand" (I've used it in some presentations); with something like a => a * a it will find it almost instantly, but it should (in theory) be able to find virtually any method - but without any defined maximum run length ;-p It is also possible to get into a dead end (evolutionary speaking) where you simply never recover...
Use the Wolfram Alpha API :)
Yes. Maybe.
If you have some input and output values, i.e. in your case [0,1,2,3] and [0,1,4,9], you could use response surfaces (basically function fitting, I believe) to 'guess' the actual function (in your case f(x)=x^2). If you let your guessing function be f(x)=c1*x+c2*x^2+c3, there are algorithms that will determine that c1=0, c2=1 and c3=0 given your input and output, and given the resulting function you can predict the next value.
Note that most other answers to this question are valid as well. I am just assuming that you want to fit some function to data. In other words, I find your question quite vague, please try to pose your questions as complete as possible!
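The fitting step described in this answer can be sketched with an ordinary least-squares fit solved through the normal equations. This is only one possible algorithm, not necessarily the one the answer had in mind; the function name fit_quadratic is hypothetical, and exact rational arithmetic is used purely so the coefficients come out without rounding noise.

```python
from fractions import Fraction


def fit_quadratic(xs, ys):
    """Least-squares fit of f(x) = c1*x + c2*x**2 + c3; returns (c1, c2, c3)."""
    # Design-matrix rows follow the answer's ordering: [x, x**2, 1].
    rows = [[Fraction(x), Fraction(x) ** 2, Fraction(1)] for x in xs]
    n = 3
    # Normal equations (A^T A) c = A^T y, stored as an augmented matrix.
    aug = [[sum(r[i] * r[j] for r in rows) for j in range(n)]
           + [sum(r[i] * Fraction(y) for r, y in zip(rows, ys))]
           for i in range(n)]
    # Gauss-Jordan elimination; assumes at least three distinct x values.
    for col in range(n):
        pivot = next(r for r in range(col, n) if aug[r][col] != 0)
        aug[col], aug[pivot] = aug[pivot], aug[col]
        aug[col] = [v / aug[col][col] for v in aug[col]]
        for r in range(n):
            if r != col:
                factor = aug[r][col]
                aug[r] = [a - factor * b for a, b in zip(aug[r], aug[col])]
    return tuple(aug[i][n] for i in range(n))


c1, c2, c3 = fit_quadratic([0, 1, 2, 3], [0, 1, 4, 9])
next_value = c1 * 4 + c2 * 4 ** 2 + c3  # prediction for x = 4
```

The prediction is only as good as the assumed form: the fit recovers x^2 here because the data really were generated by a quadratic.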
In general, no... unless you know it's a function of a particular form (e.g. polynomial of some degree N) and there is enough information to constrain the function.
e.g. for a more "ordinary" counterexample (see Chuck's answer) for why you can't necessarily assume n^2 w/o knowing it's a quadratic equation, you could have f(n) = n^4 - 6n^3 + 12n^2 - 6n, which has for n=0,1,2,3,4,5 f(n) = 0,1,4,9,40,145.
If you do know it's a particular form, there are some options... if the form is a linear combination of basis functions (e.g. f(x) = a + b*cos(x) + c*sqrt(x)) then using least-squares can get you the unknown coefficients for the best fit using those basis functions.
See also this question.
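The quartic counterexample in this answer is easy to check numerically. The short sketch below (variable names are just for illustration) evaluates it next to n^2; the two agree on the four known points because the quartic equals n^2 + n(n-1)(n-2)(n-3), and the extra term vanishes at n = 0, 1, 2, 3.

```python
def f(n):
    # Quartic from the answer above; it matches n**2 only for n = 0..3.
    return n**4 - 6 * n**3 + 12 * n**2 - 6 * n


values = [f(n) for n in range(6)]
squares = [n**2 for n in range(6)]
agree = [v == s for v, s in zip(values, squares)]
```

Any number of such polynomials can be built the same way, which is why four samples alone cannot pin down the fifth value.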
You can apply statistical methods to try and guess the next answer, but that might not work very well if the function is like this one (c):
int evil(void){
static int e = 0;
if(50 == e++){
e = e * 100;
}
return e;
}
This function will return nice simple increasing numbers then ... BAM.
That's a hard problem.
You could look into recurrence relations for the special cases where such a task is possible.
