The following repro script, taken from #1239,
import jax
from jax.config import config; config.update('jax_platform_name', 'cpu')
import jax.numpy as np
from jax import random, lax, jit
def welford_covariance():
    def init_fn(size):
        return np.zeros(size), np.zeros(size), 0

    def update_fn(sample, state):
        mean, m2, n = state
        n = n + 1
        delta_pre = sample - mean
        mean = mean + delta_pre / n
        delta_post = sample - mean
        m2 = m2 + delta_pre * delta_post
        return mean, m2, n

    def final_fn(state):
        mean, m2, n = state
        cov = m2 / (n - 1)
        cov_inv_sqrt = np.sqrt(np.reciprocal(cov))
        return cov, cov_inv_sqrt

    return init_fn, update_fn, final_fn

def warmup_adapter():
    mm_init, mm_update, mm_final = welford_covariance()

    def init_fn(z, rng, mass_matrix_size):
        inverse_mass_matrix = np.ones(mass_matrix_size)
        mass_matrix_sqrt = inverse_mass_matrix
        mm_state = mm_init(mass_matrix_size)
        return (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)

    def _update_at_window_end(z, rng_ss, state):
        inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng = state
        inverse_mass_matrix, mass_matrix_sqrt = mm_final(mm_state)
        mm_state = mm_init(inverse_mass_matrix.shape[-1])
        return (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)

    def update_fn(t, accept_prob, z, state):
        inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng = state
        rng, rng_ss = random.split(rng)
        state = (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)
        state = lax.cond(t < 10,
                         (z, rng_ss, state), lambda args: _update_at_window_end(*args),
                         state, lambda x: x)
        return state

    return init_fn, update_fn

wa_init, wa_update = warmup_adapter()
# wa_update = jit(wa_update)  # uncommenting this makes it fast
z = np.ones(3)
wa_state = wa_init(z, random.PRNGKey(0), mass_matrix_size=3)
import time
for t in range(10):
    tic = time.time()
    wa_state = wa_update(t, 0.1 * t, z, wa_state)
    print(time.time() - tic)
causes
[mutex.cc : 419] RAW: Lock blocking 0x7fbae0006640 @ 0x7fbbc9577eaf 0x7fbbc95785a6 0x7fbbc6a81236 0x7fbbc9576913 0x7fbbc9463f12 0x7fbbf463c421
[mutex.cc : 419] RAW: Unlock 0x7fbae0006640 @ 0x7fbbc6a80ae2 0x7fbbc9578088 0x7fbbc95785a6 0x7fbbc6a81236 0x7fbbc9576913 0x7fbbc9463f12 0x7fbbf463c421
when using https://storage.googleapis.com/jax-releases/cuda100/jaxlib-0.1.23-cp36-none-linux_x86_64.whl.
This issue only started happening after the recent refactoring of JAX (the "no more tuple" change). The code runs fine if we change config.update('jax_platform_name', 'cpu') to config.update('jax_platform_name', 'gpu'), or if we use the jaxlib from PyPI.
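For reference, one way to confirm which platform JAX actually selected when trying that change is the following sketch; jax.devices() is the modern spelling, while the 0.1.x releases discussed in this thread exposed the same information via jax.lib.xla_bridge.get_backend().platform.

import jax

# The default backend's devices reflect whatever jax_platform_name resolved to,
# e.g. [CpuDevice(id=0)] on CPU or a CUDA device on GPU.
print(jax.devices())
print(jax.devices()[0].platform)  # 'cpu' or 'gpu'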
cc @mattjj @skye
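As a side note on the script itself: welford_covariance implements Welford's streaming variance update, so feeding a batch of samples through update_fn should roughly reproduce the ordinary per-dimension sample variance. A minimal check of just that helper, reusing the functions defined above and assuming a working JAX install:

import numpy as onp  # plain NumPy for the reference computation

init_fn, update_fn, final_fn = welford_covariance()
samples = onp.random.RandomState(0).normal(size=(1000, 3))
state = init_fn(3)
for s in samples:
    state = update_fn(s, state)
cov, _ = final_fn(state)
# Expect True, up to float32 precision.
print(onp.allclose(cov, samples.var(axis=0, ddof=1), rtol=1e-3))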
One of the amazing XLA:{C,G}PU team members, @cheshire, tried to repro this internally (meaning against the most up-to-date XLA code) and the bug didn't appear. The optimistic explanation is that this bug is already fixed and we just need to push out an updated jaxlib. The pessimistic explanation is that it's just hard to repro.
We'll update jaxlib wheels today and hope that fixes things.
Thanks @mattjj, that's great news!
@fehiepsi we just uploaded jaxlib 0.1.25 wheels. Does it still repro?
@mattjj The issue still happens, though I don't see the mutex.cc messages anymore. When I uncomment the line wa_update = jit(wa_update), the iteration runs for the first step and gets stuck there (all GPU memory is allocated). I tried restarting my computer and creating a new conda environment to test, but the issue still happens. Are you able to replicate the issue on your machine?
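A side note on the "all GPU memory is allocated" observation: the jaxlib GPU client preallocates most of the device memory by default, so that on its own is expected behavior; the hang is the actual bug. If it helps isolate the symptom, preallocation can be disabled (at least in recent jaxlib releases) before importing jax, e.g.:

import os
# Ask the XLA Python client not to preallocate GPU memory up front.
# This must be set before jax/jaxlib is imported.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
import jax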
It turns out that we accidentally updated to an XLA version that is still a few days out of date, so we're going to update it one more time just to rule out that category of explanations. (I haven't tried to repro externally on the new jaxlib wheels yet, but will with 0.1.26.)
This is at the top of our priority list, but just so I understand how many fires we should try to light under people, is this currently blocking your work or your users'?
Hi @mattjj, please don't worry about it. We currently pin the jax version to 0.1.41, so this issue isn't blocking our work. :)
We just pushed jaxlib 0.1.26 wheels, but unfortunately the bug is still there: I was able to repro using your script on a fresh cloud VM.
We'll follow up with the XLA team and see if we can track this down!
I think I spotted it! I had been thinking this was a jaxlib issue, but I tried rolling back jax a bit and the segfault went away. I realized it must have been xla.py accidentally mixing memory pointers on one device with those on another, or something like that, so I spotted a bug fixed in 434d175 (another of mine, I'm afraid, from the big rewrite!).
I'm not sure why we weren't able to repro this internally. I'm working on a CI test we can use for this to avoid a regression, but right now I'm able to verify that this fixes the issue on a GPU cloud VM.
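Purely as an illustration of the bug class described above (hypothetical names, not the actual xla.py code or the contents of 434d175): a constant baked into a compiled computation has to be materialized on the same backend the computation was built for, not on the process-wide default backend.

# Hypothetical sketch, not JAX internals: default_backend might be the CUDA
# client bundled with a cuda jaxlib, while computation_backend is the CPU
# client selected via jax_platform_name='cpu'.
def make_device_constant_buggy(const, default_backend, computation_backend):
    # Bug: always places the constant on the default backend, so a CPU
    # executable can be handed a GPU buffer (or vice versa).
    return default_backend.buffer_from_pyval(const)

def make_device_constant_fixed(const, default_backend, computation_backend):
    # Fix: honor the backend the computation targets.
    return computation_backend.buffer_from_pyval(const)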
This test can catch it but only if run in isolation (i.e. not with all the other tests, because of how platform stuff is cached I think):
def test_instantiate_device_constant_set_platform(self):
    # cf. issue 1241
    try:
        prev_platform = FLAGS.jax_platform_name
        FLAGS.jax_platform_name = "cpu"
        jit(lambda x: x)(np.zeros(2))  # doesn't segfault
    finally:
        FLAGS.jax_platform_name = prev_platform
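To run just this test by itself, something like the following works (a sketch using pytest's Python entry point; "tests/api_test.py" is a placeholder for whichever test module the method lands in):

import pytest

# Select only this test by name; the file path below is a placeholder.
pytest.main(['-k', 'instantiate_device_constant_set_platform', 'tests/api_test.py'])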
Thanks @mattjj! That should be the reason.
Uploaded jax 0.1.43 to PyPI with this fix.