Here is a repro script:

```python
import jax
import jax.numpy as np
from jax import random, lax, jit


def welford_covariance():
    # Welford's online algorithm for a running mean and (diagonal) covariance estimate.
    def init_fn(size):
        return np.zeros(size), np.zeros(size), 0

    def update_fn(sample, state):
        mean, m2, n = state
        n = n + 1
        delta_pre = sample - mean
        mean = mean + delta_pre / n
        delta_post = sample - mean
        m2 = m2 + delta_pre * delta_post
        return mean, m2, n

    def final_fn(state):
        mean, m2, n = state
        cov = m2 / (n - 1)
        cov_inv_sqrt = np.sqrt(np.reciprocal(cov))
        return cov, cov_inv_sqrt

    return init_fn, update_fn, final_fn


def warmup_adapter():
    mm_init, mm_update, mm_final = welford_covariance()

    def init_fn(z, rng, mass_matrix_size):
        inverse_mass_matrix = np.ones(mass_matrix_size)
        mass_matrix_sqrt = inverse_mass_matrix
        mm_state = mm_init(mass_matrix_size)
        return (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)

    def _update_at_window_end(z, rng_ss, state):
        inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng = state
        inverse_mass_matrix, mass_matrix_sqrt = mm_final(mm_state)
        mm_state = mm_init(inverse_mass_matrix.shape[-1])
        return (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)

    def update_fn(t, accept_prob, z, state):
        inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng = state
        rng, rng_ss = random.split(rng)
        state = (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)
        # old-style lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
        state = lax.cond(t < 10,
                         (z, rng_ss, state), lambda args: _update_at_window_end(*args),
                         state, lambda x: x)
        return state

    return init_fn, update_fn


wa_init, wa_update = warmup_adapter()
wa_update = jit(wa_update)  # comment this line out to get the fast, non-jit timings below

z = np.ones(3)
wa_state = wa_init(z, random.PRNGKey(0), mass_matrix_size=3)

import time
for t in range(10):
    tic = time.time()
    wa_state = wa_update(t, 0.1 * t, z, wa_state)
    print(time.time() - tic)
```
which prints

```
0.0958707332611084
0.08851313591003418
0.08699154853820801
0.09005379676818848
0.08801078796386719
0.08790731430053711
0.09052276611328125
0.08893036842346191
0.0877068042755127
0.08900189399719238
```
while using the non-jit `wa_update`, we get

```
0.12209296226501465
0.002318859100341797
0.0018045902252197266
0.002086162567138672
0.0018031597137451172
0.0017747879028320312
0.0020046234130859375
0.002245187759399414
0.0020575523376464844
0.0024704933166503906
```
I think this is just a typo somewhere (similar to #1237).
cc @neerajprad
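A side note on the timing loop: JAX dispatches work asynchronously, so blocking on the result before reading the clock gives more reliable numbers. A minimal sketch of the same loop with that change, assuming the first element of `wa_state` is a device array (which it is here):

```python
import time

for t in range(10):
    tic = time.time()
    wa_state = wa_update(t, 0.1 * t, z, wa_state)
    wa_state[0].block_until_ready()  # wait for the computation to finish before timing
    print(time.time() - tic)
```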
Thanks for flagging this. As a stop-gap, you might need to tell your users to `pip install jax==0.1.41`.
We'll get this squashed ASAP though.
It's hitting a compile every time (not sure why yet). I added a `print("COMPILING")` before this line and it's printing for every iteration. Pretty much all of these bad performance issues come from missing the compilation cache.
I hereby promise that in the PR closing this bug I'll also add at least a print statement there that can be enabled with a flag! That will make it super easy to check when a performance problem is a recompilation issue.
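In the meantime, a quick way to check for this yourself is to put a Python-side `print` inside the function you `jit`: the print runs only while JAX traces the function (i.e. on a compilation-cache miss), so seeing it on every call means the cache is being missed. A minimal sketch, with a toy function `f` standing in for `wa_update`:

```python
import jax.numpy as np
from jax import jit

@jit
def f(x):
    # Python side effect: fires only while the function is being traced for
    # compilation, not on calls that hit the compilation cache.
    print("tracing/compiling f")
    return x + 1

x = np.ones(3)
f(x)      # prints once: the first call triggers compilation
f(x)      # silent if the cache is working
f(2 * x)  # still silent: same shape/dtype, same cached executable
```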
Commit 7539325, included in #1240, will squash this. The fix needed was this line.
I made a mistake in compilation caching when I merged a change in #1224 with a change made on master that affected how `xla.abstractify` works on DeviceArrays.
The real lesson here is that we need tests to avoid performance regressions.
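For context, here is a rough sketch of the intended behavior (not the buggy one): `jit`'s compilation cache is keyed on the abstract signature of the arguments, essentially their shapes and dtypes, which is what `xla.abstractify` produces for DeviceArrays. When that works correctly, only a genuinely new signature triggers a recompile:

```python
import jax.numpy as np
from jax import jit

@jit
def g(x):
    print("compiling for", x.shape, x.dtype)  # fires only on a cache miss
    return x * 2

g(np.zeros(3))       # compiles: first time this (shape, dtype) is seen
g(np.ones(3))        # cache hit: same signature, no retrace
g(np.zeros((2, 3)))  # compiles again: different shape, different cache key
```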
> It's hitting a compile every time
Thanks @mattjj, I think so too.
It seems related to an issue I hit a few days ago: the code uses all GPU memory even though I only run CPU code (with jaxlib 0.1.23 installed from https://storage.googleapis.com/jax-releases). Things are back to normal when I install the jaxlib version from PyPI. I'll check whether #1240 fixes it too.
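If that is the usual GPU preallocation behavior, it can be worked around with an environment variable. A hedged sketch, assuming a GPU build of jaxlib where `XLA_PYTHON_CLIENT_PREALLOCATE` controls whether XLA grabs most of the GPU memory up front:

```python
import os

# Must be set before jax initializes its backend.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate GPU memory on demand

import jax
print(jax.devices())  # check which backend/devices are actually in use
```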