Possibly a silly question, but - is there a way to save a jitted function? The current compilation time for a specific function I'm using is more than a couple of hours, and I'd be very happy to save it for later use (later on the same machine, and ideally - on a different machine altogether; I guess this is problematic, since compilation depends on the machine?). I've tried to pickle it, and got: AttributeError: Can't pickle local object '_jit.<locals>.f_jitted'. Is this somehow possible?
In principle this is possible, but it would be a bunch of engineering work.
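For readers finding this thread later: more recent JAX releases ship a persistent compilation cache, which covers the "same machine, later run" case. The sketch below assumes a recent JAX version. The cache directory is a hypothetical path, and backend support and default thresholds vary by version. A second process that sets the same directory and compiles the same function with the same shapes can then load the executable from disk instead of recompiling.

```python
import jax
import jax.numpy as jnp

# Assumption: a recent JAX release with the persistent compilation cache.
# "/tmp/jax_cache" is a hypothetical directory; any writable path works.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# By default only compilations slower than a threshold are cached;
# set it to zero so even this tiny example is eligible.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0)

@jax.jit
def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

x = jnp.arange(4.0)
y = f(x)  # first call compiles; the executable may be written to the cache dir
print(y)
```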
I suspect our first step should be to try to identify why the compilation time is so long. "hours" is unreasonable. Would you be able to share a reproduction of the problem?
I'm also interested in getting those compilation times down. I'm always impressed at how patient early users are; if a compile time is longer than a minute, I'd give up!
Can you say more about what your program looks like? Are there likely some large unrolled loops? Would it be feasible to apply @jit only to smaller subfunctions without sacrificing performance? (With the upcoming asynchronous execution feature, it should get even easier to maximize performance while only jit-ing smaller functions.)
EDIT: somehow I missed the second part of Peter's comment, which says the same thing!
Sure! I'm trying to write a Metropolis-Hastings sampler that samples from (unnormalised) maximum entropy distributions with arbitrary constraints (or observables, if you prefer the physics jargon) over binary vectors.
A simple reproducible example:
import jax.numpy as np
import jax.random as random
from jax import jit,vmap
from jax.ops import index_update
import numpy as onp
from functools import partial
import itertools as it
N = 30
marg_1 = lambda i,x:x[i]
marg_2 = lambda i,j,x:x[i]*x[j]
marg_1s = [jit(partial(marg_1,i)) for i in range(N)]
marg_2s = [jit(partial(marg_2,i,j)) for i,j in list(it.combinations(range(N),r=2))]
funcs = marg_1s+marg_2s
@jit
def calc_e(factors,word):
    return np.sum(factors*np.array([func(word) for func in funcs]))
factors = np.array(onp.random.randn(len(funcs)))
def sample(key,n_samps):
    state = random.randint(key,minval=0,maxval=2, shape=(N,))
    unifs = random.uniform(key, shape=(n_samps*N,))
    all_states = np.zeros((n_samps,N))
    for j in range(n_samps*N):
        if j%N==0:
            all_states = index_update(all_states,j//N,state)
        state_flipped = index_update(state,j%N,1-state[j%N])
        #calc energy difference
        dE = calc_e(factors,state_flipped)-calc_e(factors,state)
        accept = ((dE < 0) | (unifs[j] < np.exp(-dE)))
        state = np.where(accept, state_flipped, state)
    return all_states
sample = jit(sample, static_argnums=(1,))
key = random.PRNGKey(1)
n_samps=10
%time sample(key,n_samps)
Here, N is the size of the binary vectors (word above), and factors are the Lagrange multipliers in the MaxEnt distribution. The above sampling procedure takes a couple of seconds on my laptop with N=3 and n_samps=10, but blows up to hours when either of these becomes not-that-much-larger. This example is an Ising model; eventually I'm interested in supporting arbitrary functions and larger N.
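Read directly off the code above, the target distribution and acceptance rule are as follows. Here the λ_k are the entries of factors, the f_k are the entries of funcs, and m = j % N is the bit proposed for flipping. The test `(dE < 0) | (u < exp(-dE))` is equivalent to accepting when u < min(1, e^{-ΔE}).

```latex
E(x) = \sum_k \lambda_k f_k(x)
     = \sum_{i=1}^{N} \lambda_i x_i + \sum_{i<j} \lambda_{ij}\, x_i x_j,
\qquad p(x) \propto e^{-E(x)}, \qquad x \in \{0,1\}^N

x' = x \text{ with bit } m \text{ flipped}, \qquad
\Delta E = E(x') - E(x), \qquad
P(\text{accept}) = \min\!\left(1,\, e^{-\Delta E}\right)
```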
Very cool! Thanks for the example.
We can make the compile time invariant to n_samps by rolling a loop:
import jax.numpy as np
import jax.random as random
from jax import jit,vmap
from jax.ops import index_update
from jax.lax import fori_loop
import numpy as onp
from functools import partial
import itertools as it
N = 30
marg_1 = lambda i,x:x[i]
marg_2 = lambda i,j,x:x[i]*x[j]
marg_1s = [jit(partial(marg_1,i)) for i in range(N)]
marg_2s = [jit(partial(marg_2,i,j)) for i,j in list(it.combinations(range(N),r=2))]
funcs = marg_1s+marg_2s
@jit
def calc_e(factors,word):
    return np.sum(factors*np.array([func(word) for func in funcs]))
factors = np.array(onp.random.randn(len(funcs)))
def sample(key, n_samps):
    state = random.randint(key,minval=0,maxval=2, shape=(N,))
    unifs = random.uniform(key, shape=(n_samps*N,))
    def run_mh(j, loop_carry):
        state, all_states = loop_carry
        all_states = index_update(all_states,j//N,state)  # a bit wasteful
        state_flipped = index_update(state,j%N,1-state[j%N])
        dE = calc_e(factors,state_flipped)-calc_e(factors,state)
        accept = ((dE < 0) | (unifs[j] < np.exp(-dE)))
        state = np.where(accept, state_flipped, state)
        return state, all_states
    all_states = np.zeros((n_samps,N))
    # fori_loop returns the final loop carry, i.e. the (state, all_states) tuple
    _, all_states = fori_loop(0, n_samps * N, run_mh, (state, all_states))
    return all_states
sample = jit(sample, static_argnums=(1,))
You could imagine alternatives like only rolling up an inner loop that performs one sweep, or keeping that unrolled and rolling an outer loop over n_samps, but here I kept both flattened into one rolled loop. We could avoid doing the wasteful all_states = index_update(all_states, ...) on every iteration, either by using a lax.cond or else by keeping the outer loop unrolled and un-jitted (i.e. just roll one sweep, or a fixed number of sweeps, into a loop).
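The second option above (roll one sweep into a jitted loop, keep the loop over samples in plain Python) can be sketched as follows. This is a sketch under stated assumptions. It uses a hypothetical Ising-style energy with fields h and couplings J in place of factors/funcs, uses the current `.at[i].set(...)` syntax in place of index_update, and draws a fresh key per sweep rather than reusing one key. Each state is stored once per sweep, and the compiled code depends only on N, not on n_samps.

```python
import jax
import jax.numpy as jnp
from jax import lax

N = 8  # hypothetical system size for illustration

def energy(h, J, x):
    # Hypothetical Ising-style energy: sum_i h_i x_i + sum_{i<j} J_ij x_i x_j
    return jnp.dot(h, x) + jnp.dot(x, jnp.triu(J, k=1) @ x)

@jax.jit
def sweep(h, J, state, unifs):
    # One Metropolis-Hastings sweep over all N bits, rolled into a single loop
    def body(i, state):
        flipped = state.at[i].set(1 - state[i])
        dE = energy(h, J, flipped) - energy(h, J, state)
        accept = (dE < 0) | (unifs[i] < jnp.exp(-dE))
        return jnp.where(accept, flipped, state)
    return lax.fori_loop(0, N, body, state)

def sample(key, h, J, n_samps):
    key, init_key = jax.random.split(key)
    state = jax.random.randint(init_key, (N,), 0, 2)
    samples = []
    for _ in range(n_samps):  # outer loop stays in Python and is not jitted
        samples.append(state)  # store the state once per sweep
        key, sub = jax.random.split(key)
        state = sweep(h, J, state, jax.random.uniform(sub, (N,)))
    return jnp.stack(samples)

kh, kj = jax.random.split(jax.random.PRNGKey(0))
h = jax.random.normal(kh, (N,))
J = jax.random.normal(kj, (N, N))
out = sample(jax.random.PRNGKey(1), h, J, n_samps=5)
```

Changing n_samps here only changes how many times the already-compiled sweep is called; it does not trigger recompilation.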
I'm afraid this doesn't much help how the compilation time scales with N though, since the loop over funcs is still unrolled in the energy calculation. To improve that we might need XLA to stop inlining all functions...
What do you think?
What kind of magic is this? 😄 Compilation time is now ~4 seconds with N=30, and after compilation, generating 1000 samples takes ~25ms. This is incredible - thank you @mattjj !
Regarding the scaling of the compilation time with N - I've been trying to vectorize the energy calculation (#673), but this seems impractical for arbitrary functions (but might provide a speed-up in specific cases, eg Ising model). Anyway, this is less of a problem at the moment.
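For the Ising special case mentioned here, the energy can be written with gather-and-multiply array operations instead of a Python list of len(funcs) jitted closures. The traced program then no longer grows with the number of observables. This is a sketch that assumes the observables are exactly the x[i] and x[i]*x[j] terms from the example. It orders the pairs the same way as it.combinations so that factors lines up with marg_1s + marg_2s, and it does not cover arbitrary functions.

```python
import itertools as it
import numpy as onp
import jax
import jax.numpy as jnp

N = 30
# (i, j) pairs with i < j, in the same order as marg_2s in the example above
pairs = onp.array(list(it.combinations(range(N), 2)))
factors = jnp.asarray(onp.random.RandomState(0).randn(N + len(pairs)))

@jax.jit
def calc_e_vectorized(factors, word):
    first = word                                     # x[i] terms (marg_1s)
    second = word[pairs[:, 0]] * word[pairs[:, 1]]   # x[i]*x[j] terms (marg_2s)
    return jnp.sum(factors * jnp.concatenate([first, second]))
```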
Regarding all_states = index_update(all_states, ...) - just to make sure I understand, this is wasteful because we're updating the j//N row of all_states N times instead of 1 time, right? I've tried replacing it with:
all_states = lax.cond(j%N==0,state,lambda x:index_update(all_states,j//N,x),all_states,lambda x:x)
But this increased sampling times by a factor of 10, so I'm probably not using cond properly...
What kind of magic is this? :smile:
The mostly-undocumented kind, of course!
I'm really glad it helped.
There are some notes on control flow primitives, namely lax.cond and lax.while_loop, in the gotchas notebook. The main purpose of lax.while_loop (for which lax.fori_loop is a convenience wrapper) is to reduce compile times.
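One way to see why rolling a loop helps compile time is to compare the staged-out programs with `jax.make_jaxpr`. In this sketch, `step` is a hypothetical loop body. A Python `for` loop is unrolled into one copy of the body per iteration. `lax.fori_loop` instead stages out a single loop primitive whose body appears once, so the program handed to XLA does not grow with the trip count.

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(x):
    # Hypothetical loop body used only for illustration
    return jnp.sin(x) * 0.5 + x

def unrolled(x, n=100):
    for _ in range(n):  # Python loop: traced n times, body copied n times
        x = step(x)
    return x

def rolled(x, n=100):
    return lax.fori_loop(0, n, lambda i, x: step(x), x)  # body traced once

x = jnp.ones(3)
n_unrolled = len(jax.make_jaxpr(unrolled)(x).jaxpr.eqns)
n_rolled = len(jax.make_jaxpr(rolled)(x).jaxpr.eqns)
print(n_unrolled, n_rolled)
```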
Regarding "all_states = index_update(all_states, ...) - just to make sure I understand, this is wasteful because we're updating the j//N row of all_states N times instead of 1 time, right?"
Yes, that's right.
I think lax.cond (which directly translates to the XLA Conditional HLO) might be a bit of a performance cliff. In XLA its main purpose is to gate side-effecting operations (of which there are a couple). In principle a functionally pure language like (most of) XLA could decide to evaluate the subexpressions of something like np.where(j % N, all_states, index_update(all_states, j//N, x)) lazily and not do any extra work, so there's no in-principle need for lax.cond unless you have side-effects, though I'm told it's unlikely XLA will add this behavior. (JAX doesn't expose any way to stage out side-effecting operations yet, so the only purpose of lax.cond is to save compute when there is a really expensive but rare branch on one side.)
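The two formulations discussed here can be sketched side by side. N is a hypothetical sweep length, and the code uses the current `.at[...].set(...)` syntax and the `lax.cond(pred, true_fun, false_fun, operand)` calling form documented in recent releases, which differs from the five-argument form used earlier in this thread. Both return the same array. `jnp.where` computes both candidate arrays and selects one. `lax.cond` lowers to an XLA Conditional that runs only the taken branch (unless it is batched under vmap, where it becomes a select).

```python
import jax
import jax.numpy as jnp
from jax import lax

N = 4  # hypothetical sweep length

def update_where(all_states, state, j):
    # Both operands are built; where picks one of them
    return jnp.where(j % N == 0, all_states.at[j // N].set(state), all_states)

def update_cond(all_states, state, j):
    # Only the selected branch function is executed
    return lax.cond(j % N == 0,
                    lambda a: a.at[j // N].set(state),
                    lambda a: a,
                    all_states)

update_where_jit = jax.jit(update_where)
update_cond_jit = jax.jit(update_cond)

all_states = jnp.zeros((3, N))
state = jnp.ones(N)
print(update_cond_jit(all_states, state, 4))  # row 1 is overwritten with ones
```

For a cheap update like this one, the extra work done by `where` is small, which is consistent with the timings reported above.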
We just had a meeting with XLA:{C,G}PU teams yesterday, and the two main JAX priorities we communicated to them were (1) adding support for calling custom kernels on GPU and (2) reducing compile times for programs like these (which may mean avoiding function inlining). The prospects for tackling 2 sounded surprisingly good, though a general solution may still be months away.
That's super interesting, thanks for sharing!
I've tried switching the order of true/false functions in cond, and also using np.where, but index_update was faster than both (compilation and runtime). I still need to make sure this doesn't introduce MCMC issues or biases, but that's a different issue. 😄
Regarding the original subject - I think an interesting use case for saving compiled functions would be post-installation testing of a package that uses JAX, since testing time is (usually) less expensive than run time. I'm sure there are many problems here I can't even think of (perhaps various version changes after compilation?), but it might be a useful feature.
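Later JAX releases added a `jax.export` module that serializes a lowered (StableHLO) program rather than a machine-specific binary, which addresses part of the cross-machine concern raised at the top of the thread. This sketch assumes a recent JAX version with `jax.export`, and `f` is a hypothetical function. The deserialized program is compiled again for the local backend when called, so it saves tracing and lowering but not XLA compilation.

```python
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    # Hypothetical function to be saved
    return jnp.tanh(x) * 2.0

x = jnp.arange(3.0)
exported = export.export(jax.jit(f))(x)  # trace + lower for this shape/dtype
blob = exported.serialize()              # bytes-like; could be written to a file
restored = export.deserialize(blob)      # e.g. in a later process or on another machine
y = restored.call(x)                     # compiled for the local backend on call
print(y)
```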
I ran into a similar "Can't pickle local object" issue when trying to parallelize the jit compilation of a large number of sub-functions (the function I'm trying to jit is too large for the compiler to handle). Is there any other way to use all my CPU cores to do the jit compilation in parallel, aside from standard Python multiprocessing, which relies on pickling? My current code spends about 500 s on jit compilation followed by 1 s of actual computation. Unfortunately the compiler only uses one thread, and there is no flag to speed up compilation.
Any chance you could provide a code example of what you're trying to do? It's hard to make concrete suggestions without code to discuss. Even better, we can pass slow-to-compile examples along to the XLA team for improvement.
One generic strategy, though, is just not to use @jit on your top-level function. You can use it on smaller sub-functions instead. You should still get good performance, and we've been bringing dispatch overheads down significantly in the last few weeks, so the speed of calling multiple smaller @jit computations is only getting better.
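The pattern described here looks like the sketch below, which uses a hypothetical `layer` sub-function. Only the small repeated piece is jitted, and the top-level `model` is plain Python. Because every call to `layer` has the same argument shapes and dtypes, it is traced and compiled once and then reused, so compile cost scales with the size of `layer` rather than with the number of times it appears. The counter is a Python side effect, which runs only while JAX is tracing.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def layer(w, x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return jnp.tanh(w @ x)

def model(ws, x):  # top-level function deliberately NOT jitted
    for w in ws:
        x = layer(w, x)
    return x

ws = [jnp.eye(4) * 0.5 for _ in range(10)]
out = model(ws, jnp.ones(4))
print(trace_count)  # 1: traced and compiled once, reused for all 10 calls
```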
(My comment about the smaller sub-functions was more about #679, though from the context here it sounds like you may already have broken things down.)
Edit: The 'solution' I suggested below does not actually do anything helpful. See more details in https://github.com/google/jax/issues/679#issuecomment-707627988
I also came up against the problem of trying to pickle a jitted function when trying to run parallel MCMC chains using JAX to calculate the model functions and derivatives. In case it helps anyone else having this problem who comes across this issue, I found the following workaround.
While the inbuilt multiprocessing module has problems with pickling any transformed JAX functions due to the use of nested / lambda functions, the multiprocess package which uses dill to perform the pickling seems to be able to be used without problems with non-jitted JAX functions or functions that call not-jitted JAX functions. However when for example using a jitted JAX function as the func argument of multiprocessing.Pool.map we seem to get a deadlock.
To get around this I found that if you only apply the JIT transformation within the child process, things work fine. Provided you can ensure the function to be parallelised / jitted is not called in the parent process first, you can simply replace any jit decorators with the following 'delayed' version that only applies the JIT transform on the first call to the function:
from jax import api
def delayed_jit(func, *jit_args, **jit_kwargs):
    jitted_func = None
    def wrapped(*args, **kwargs):
        nonlocal jitted_func
        if jitted_func is None:
            jitted_func = api.jit(func, *jit_args, **jit_kwargs)
        return jitted_func(*args, **kwargs)
    return wrapped
We can then use this delayed_jit decorator in place of jit as in the following simple example
from multiprocess import Pool
import numpy as onp
import jax.numpy as np
@delayed_jit
def norm(x):
    return np.sum(x**2)**0.5
grad_norm = delayed_jit(api.grad(norm))
rng = onp.random.RandomState(1234)
vectors = rng.standard_normal((100, 10))
pool = Pool(4)
norm_vectors = pool.map(norm, vectors)
grad_norm_vectors = pool.map(grad_norm, vectors)
Hi! There are many open issues related to long compile times, and I assume many people would welcome being able to store and load compiled functions. However, in some use cases even seconds are a lot (e.g. with a large number of different functions), and figuring out what is wrong is not easy. In my case I have this:
import jax
import jax.numpy as j_numpy
import jax.scipy.stats as js_stats
import time
def f(response, parameters, control_parameters):
    _e8 = j_numpy.exp(-parameters[0]*control_parameters['x'])
    _e10 = -parameters[1]
    _e12 = j_numpy.exp(_e10*control_parameters['x'])
    _e15 = parameters[1] - parameters[0]
    _e16 = (_e8 - _e12)*parameters[0]/_e15
    _e20 = js_stats.norm.ppf(1 - js_stats.norm.cdf(response[1], loc=_e16, scale=parameters[3]))
    return _e20
    # _e26 = js_stats.norm.ppf(1 - js_stats.norm.cdf(response[0], loc=_e8, scale=parameters[2]))
    # _e27 = _e26**2
    # _e33 = parameters[5]**2
    # _e38 = (1 - _e33)**0.5
    # _e47 = 1 + (_e10*_e8 + parameters[0]*_e12)/_e15
    # _e51 = js_stats.norm.ppf(1 - js_stats.norm.cdf(response[2], loc=_e47, scale=parameters[4]))
    # _e58 = parameters[6]**2
    # _e63 = (1 - _e58)**0.5
    # _e70 = (_e51 - parameters[6]*_e26)/_e63
    # _e74 = (_e20 - parameters[5]*_e26)/_e38
    # _e81 = parameters[7]**2
    # result = j_numpy.exp(parameters[5]*(0.5*parameters[5]*(_e20**2 +
    #     _e27) - _e20*_e26)/(_e33 - 1))/_e38 * (j_numpy.exp(parameters[6]*
    #     (0.5*parameters[6]*(_e51**2 + _e27) - _e51*_e26)/(_e58 -
    #     1))/_e63)*(j_numpy.exp(parameters[7]*(0.5*parameters[7]*
    #     (_e70**2 + _e74**2) - _e70*_e74)/(_e81 - 1))/(1 - _e81)**0.5)*\
    #     js_stats.norm.pdf(response[0], loc=_e8, scale=parameters[2])*\
    #     js_stats.norm.pdf(response[1], loc=_e16, scale=parameters[3])*\
    #     js_stats.norm.pdf(response[2], loc=_e47, scale=parameters[4])
    # return result
parameters = j_numpy.array([0.7, 0.2, 1., 2., 3., 1/2, 1/3, 0.11])
control_parameters = dict(x=1.23)
response = [0.12, 0.23, 0.34]
print('f:', f(response, parameters, control_parameters))
get_hessian = jax.hessian(lambda response, parameters, control_parameters, f=f: j_numpy.log(f(response, parameters, control_parameters)), argnums=1)
get_hessian = jax.jit(get_hessian)
begin = time.time()
print('d2f:', get_hessian(response, parameters, control_parameters))
end = time.time()
print('elapsed: ', end - begin)
begin = time.time()
get_hessian(response, parameters, control_parameters)
end = time.time()
print('elapsed: ', end - begin)
On my machine, this takes 15 s to compile (CPU backend) and 1 ms to evaluate. With the full version of f, compilation takes 3 minutes, while evaluation is just as fast. I really don't know what to change to reduce the compile time.
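A first diagnostic step for cases like this is to separate Python tracing and lowering from XLA compilation, and to look at how large the lowered program is. The sketch below uses the ahead-of-time `.lower(...)` / `.compile()` methods available on jitted functions in recent JAX versions. Here `f` is a small hypothetical stand-in, not the model above. If most of the time is spent in `compile()` and the lowered text is very long, the program handed to XLA is the thing to shrink.

```python
import time
import jax
import jax.numpy as jnp

def f(params, x):
    # Hypothetical stand-in for the model above
    return jnp.log(params[1] * jnp.exp(-params[0] * x) + 1.0)

hess = jax.jit(jax.hessian(f, argnums=0))
params = jnp.array([0.7, 0.2])
x = jnp.asarray(1.23)

t0 = time.perf_counter()
lowered = hess.lower(params, x)   # Python tracing + lowering to StableHLO
t1 = time.perf_counter()
compiled = lowered.compile()      # XLA compilation
t2 = time.perf_counter()
result = compiled(params, x).block_until_ready()
t3 = time.perf_counter()

n_lines = len(lowered.as_text().splitlines())  # rough size of the lowered program
print(f"trace+lower {t1 - t0:.3f}s, compile {t2 - t1:.3f}s, "
      f"run {t3 - t2:.6f}s, {n_lines} lines of StableHLO")
```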
Btw, thanks for your efforts and useful support!
@matt-graham I'm trying to follow your suggestion and it complains "Can't pickle local object delayed_jit". Any idea?
@gnool Are you using dill / multiprocess as opposed to the built-in pickle / multiprocessing? The former can serialize a much wider range of types, including local / nested functions, and the error you are getting may be due to using the built-in pickle.
More generally with regards to my suggestion: now that I understand a little better about how jit works, I think the delayed_jit decorator I suggest above doesn't actually do anything useful. jit already by construction returns a wrapped version of the decorated function which only performs the tracing / compilation on its first call. The same example as I used above also works if we just directly use jit:
from multiprocess import Pool
import numpy as onp
import jax.numpy as np
from jax import api
@api.jit
def norm(x):
    return np.sum(x**2)**0.5
grad_norm = api.jit(api.grad(norm))
rng = onp.random.RandomState(1234)
vectors = rng.standard_normal((100, 10))
pool = Pool(4)
norm_vectors = pool.map(norm, vectors)
grad_norm_vectors = pool.map(grad_norm, vectors)
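When the work is only mapping a function over a batch of inputs, as in this toy example, a single-process alternative is `jax.vmap`. It compiles one batched program and sidesteps pickling and fork entirely. This is a swapped-in technique rather than a fix for the multiprocessing deadlock, and it may not fit the MCMC use case. It is written against the public `jax.jit` / `jax.grad` / `jax.vmap` names rather than `jax.api`.

```python
import numpy as onp
import jax
import jax.numpy as jnp

@jax.jit
def norm(x):
    return jnp.sum(x**2) ** 0.5

grad_norm = jax.jit(jax.grad(norm))

rng = onp.random.RandomState(1234)
vectors = rng.standard_normal((100, 10))

# One vectorized call per function instead of 100 separate calls in worker processes
norm_vectors = jax.vmap(norm)(vectors)            # shape (100,)
grad_norm_vectors = jax.vmap(grad_norm)(vectors)  # shape (100, 10)
```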
However, there is still a brittleness, in that if we execute any JAX operation in the main process before serializing the jitted function(s) to send to a new process, we get a deadlock, e.g. as in the following:
from multiprocess import Pool
import numpy as onp
import jax.numpy as np
from jax import api
@api.jit
def norm(x):
    return np.sum(x**2)**0.5
# Execute arbitrary JAX operation in main process
y = np.sin(0)
rng = onp.random.RandomState(1234)
vectors = rng.standard_normal((100, 10))
pool = Pool(4)
# The following call now deadlocks
norm_vectors = pool.map(norm, vectors)
@matt-graham Thanks a lot for your clarification. I have indeed used the native package, sorry for getting confused with multiprocess vs multiprocessing!
Regarding your comment on the strict requirement that we should avoid performing any JAX operation in the main process before serializing the jitted function, I have a follow-up question. I have N unique functions to be jitted (each takes roughly the same time to compile), and I am wondering whether I can use multiprocess (perhaps its Process class?) to JIT each of these functions. Would this violate the condition you mentioned above?
Perhaps a more general question is whether, since you posted your workaround, you have discovered other ways to work around this initial compilation time (e.g. using multiple cores for compilation, or saving the compilation results). I am using this for an application that is unfortunately very sensitive to overall run time, and currently the compilation step is at least 10,000 times slower than the actual JIT'ed function evaluation.