When trying to run JAX with jaxlib==0.1.44, I run into a segmentation fault on my machine (Python 3.8, CUDA 10.2) when running on GPU. The issue no longer occurs if I downgrade jaxlib to 0.1.43.
I installed jaxlib for both versions using the installation instructions in the README, and in both cases I pointed the XLA CUDA directory at the same location. As far as I can tell, changing only jaxlib is enough to trigger the segfault.
I tried to do some digging, and it seems like the segfault is coming from jaxlib/xla_extension.so; in particular, here is what gdb produces:
0x00007fffd6f991e8 in absl::lts_2020_02_25::Mutex::ReaderLock() () from /home/ziyadedher/research/.venv/lib/python3.8/site-packages/jaxlib/xla_extension.so
Reverting to jaxlib==0.1.43 fixes the issue.
>>> jax.__version__
'0.1.63'
>>> jaxlib.__version__
'0.1.44'
>>> tensorflow.__version__
'2.2.0-rc3'
Some system information, truncated to the important bits:
$ nvcc --version
Cuda compilation tools, release 10.2, V10.2.89
$ python --version
Python 3.8.2
$ modinfo nvidia
filename: /lib/modules/5.6.4-arch1-1/extramodules/nvidia.ko.xz
version: 440.82
Could you help us reproduce this issue on our end? A first attempt to do so didn't trigger the segfault, but this may be due to differences in our setups.
A set of instructions that recreate the problem, and any minimal Python program that results in the segfault, would be very useful.
@froystig
Yep, I have been trying to find a minimal Python program to reproduce this but I have been running into trouble. The conditions for the segfault seem incredibly strange. It seems to somehow be related to constructing a matplotlib figure while using a global variable to store a JAX PRNG key.
Here is some code that reproduces the error on both of my machines, one running Arch and the other running Ubuntu 16.04; I have highlighted two lines of interest:
import matplotlib.pyplot as plt
import jax

key = jax.random.PRNGKey(0)
# key, new_key = jax.random.split(key)  # [1]

def get_new_key():
    global key
    key, new_key = jax.random.split(key)
    return new_key

fig = plt.figure()  # [2]
new_key = get_new_key()
Uncommenting [1] or commenting out [2] causes the segfault to stop occurring on both of my machines. I am still trying to whittle this program down to something simpler, but this is what I have so far.
Does this reproduce the error on your end?
Hmm, actually it does not seem to depend on the fact that I am mutating a global variable. The segfault arises when I call a function that splits a PRNG key defined in an outer scope, after plt.figure has been called and before that key has been split at the top level:
import matplotlib.pyplot as plt
import jax

key = jax.random.PRNGKey(0)
# key, new_key = jax.random.split(key)  # [1]

def new_key():
    # key = jax.random.PRNGKey(0)  # [2]
    _, new_key = jax.random.split(key)

# new_key()  # [3]
plt.figure()  # [4]
new_key()
Toggling the comment on any of [1]-[4] makes the issue disappear.
Obviously this is a testament to why using global Python variables is not a good idea, but it is still kinda crazy. The only thing I can think of is that plt.figure is somehow overwriting the global key in a way that causes this. That still doesn't explain why splitting the key before calling plt.figure makes it work afterwards, though...
I will add some more system information to the main issue to try and help out with reproduction.
Also, in case this is just a weird one-off thing, I have tried restarting my system and reloading my Nvidia drivers, to no avail. I am able to hit this issue consistently across boots and across devices, albeit all with the same versions of everything mentioned in the main issue. The only (relevant?) difference between the machines besides hardware is that my main machine runs Arch and the secondary one runs Ubuntu 16.04.
I'm unable to reproduce this running CUDA 10.2, jax 0.1.63, jaxlib 0.1.44, and Python 3.8.2 on a Debian VM with four K80s (although I'm running using CUDA_VISIBLE_DEVICES=0). My Nvidia driver version is 440.33.01, I could try upgrading to 440.82 to match your setup.
What kind of GPU do you have? I also wonder if having a display is somehow important.
Also just tried on an Ubuntu 18.04 VM with V100s and Nvidia driver version 440.64.00, still no dice.
Not sure if my issue is exactly the same, since I am unable to reproduce the segfault with @ziyadedher's example, but I also get consistent segfaults with jaxlib 0.1.44 that are fixed by reverting to 0.1.43.
from functools import partial

import jax
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
from jax.experimental import optimizers, stax
from jax.experimental.stax import Dense, Relu

M = 1024
N = 100000  # problem seems to not happen with much smaller N and M

net_init, net_apply = stax.serial(
    Dense(M), Relu,
    Dense(M), Relu,
    Dense(10)
)

rng = random.PRNGKey(10)
x_train = np.zeros((N, 100))
x_test = np.zeros((N + 1, 100))  # no segfault if train/test are same size
y_train = np.zeros((N, 10))
y_test = np.zeros((N + 1, 10))

in_shape = (-1, x_train.shape[-1])
out_shape, net_params = net_init(rng, in_shape)
rng, new_rng = random.split(rng)

@jax.partial(jit, static_argnums=3)  # doesn't segfault without jitting
def binary_loss(params, inputs, labels, net_apply):
    logits = net_apply(params, inputs)
    predicted_labels = np.argmax(logits, axis=-1)
    targets = np.argmax(labels, axis=-1)
    return np.mean(predicted_labels == targets)

test_acc = binary_loss(net_params, x_test, y_test, net_apply)
train_acc = binary_loss(net_params, x_train, y_train, net_apply)  # also doesn't segfault if only called once
I'm on Ubuntu 16.04, Python 3.7.4, CUDA 10.2, Nvidia driver 440.64.00 with a GTX 1080 Ti, using jax 0.1.63 with tf 2.1.0.
The same issue also occurs for me with the cuda100 and cuda101 builds, as well as with tf 2.2.0-rc3.
@skye I am running a 1080 Ti. It might be display related, given that it only occurs for me when I construct a matplotlib figure. I'll try reproducing in a VM.
Honestly, the entire thing is very strange, and the error itself is somewhat flaky. Given @azhou42's experience, could it be that some weird memory allocation is causing different parts to segfault on different machines? I was able to reproduce @azhou42's segfault on my machine, and it seems to occur at the same Mutex::ReaderLock() call.
Is xla_extension.so provided by tensorflow? I might go digging and try bisecting to find the offending commit.
Very strange.
@ziyadedher I'm guessing you're not getting any more of a stack trace from gdb? xla_extension.so is indeed provided by tensorflow. If you'd like to dig, you can follow the directions for building jaxlib from source. You'll probably want to use a local tensorflow checkout; you can enable that in the WORKSPACE file: https://github.com/google/jax/blob/master/WORKSPACE#L38. Alternatively, you can update the commit hash in the tensorflow http_archive above (I suggest commenting out the sha256 line then, so you don't have to deal with it).
I'll try @azhou42's repro next week.
Also seeing this segfault when running a simple ReformerLM in trax, at model init. Using tensorflow-gpu==1.15.2 (since trax doesn't work with TF 2+). Switching to jaxlib 0.1.43 resolves the issue.
Ubuntu 19.10, Python 3.7.4, CUDA 10.0, Nvidia driver 440.64.00 with a GTX 1070
@azhou42's repro works for me, thanks! I will attempt to find the root cause.
Not sure if this is related, but I got a Colab notebook to crash for an unknown reason only when using jaxlib 0.1.44.
Ubuntu 18.04, Python 3.6.9, CUDA 10.1, Nvidia driver 440.64.00 (I've seen the issue on both Tesla P100 and K80)
Well I didn't make as much progress today as I hoped, but I at least got some stacktraces from @azhou42's repro. The segfault is happening in at least two places.
Stack 1:
Traceback (most recent call first):
File "/home/skyewm/jax/jax/interpreters/xla.py", line 584, in <listcomp>
else h(device_put(x, device)) for h, x in zip(handlers, outs)]
File "/home/skyewm/jax/jax/interpreters/xla.py", line 584, in _execute_trivial
else h(device_put(x, device)) for h, x in zip(handlers, outs)]
File "/home/skyewm/jax/jax/interpreters/xla.py", line 459, in _xla_call_impl
return compiled_fun(*args)
File "/home/skyewm/jax/jax/core.py", line 978, in call_bind
outs = primitive.impl(f, *args, **params)
File "/home/skyewm/jax/jax/interpreters/partial_eval.py", line 177, in process_call
out_flat = call_primitive.bind(fun, *in_consts, **params)
File "/home/skyewm/jax/jax/core.py", line 981, in call_bind
outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
File "/home/skyewm/jax/jax/api.py", line 153, in f_jitted
name=flat_fun.__name__)
File "/home/skyewm/jax/jax/numpy/lax_numpy.py", line 1251, in where
return _where(condition, x, y)
File "/home/skyewm/jax/jax/numpy/lax_numpy.py", line 2824, in _argminmax
mask_idxs = where(lax._eq_meet(a, op(a, axis, keepdims=True)), idxs, maxval)
File "/home/skyewm/jax/jax/numpy/lax_numpy.py", line 2806, in argmax
return _argminmax(max, a, axis)
File "segfault.py", line 35, in binary_loss
predicted_labels = np.argmax(logits, axis=-1)
File "/home/skyewm/jax/jax/linear_util.py", line 150, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/skyewm/jax/jax/interpreters/partial_eval.py", line 430, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/skyewm/jax/jax/interpreters/xla.py", line 474, in _xla_callable
fun, pvals, instantiate=False, stage_out=True, bottom=True)
File "/home/skyewm/jax/jax/linear_util.py", line 221, in memoized_fun
ans = call(fun, *args)
File "/home/skyewm/jax/jax/interpreters/xla.py", line 457, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
File "/home/skyewm/jax/jax/core.py", line 978, in call_bind
outs = primitive.impl(f, *args, **params)
File "/home/skyewm/jax/jax/api.py", line 153, in f_jitted
name=flat_fun.__name__)
File "segfault.py", line 42, in <module>
train_acc = binary_loss(net_params, x_train, y_train, net_apply) # also doesn't segfault if only called once
I narrowed this down and it's happening in the device_put(x, device).
Stack 2:
Traceback (most recent call first):
File "/home/skyewm/jax/jax/interpreters/xla.py", line 459, in _xla_call_impl
return compiled_fun(*args)
File "/home/skyewm/jax/jax/core.py", line 978, in call_bind
outs = primitive.impl(f, *args, **params)
File "/home/skyewm/jax/jax/interpreters/partial_eval.py", line 177, in process_call
out_flat = call_primitive.bind(fun, *in_consts, **params)
File "/home/skyewm/jax/jax/core.py", line 981, in call_bind
outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
File "/home/skyewm/jax/jax/api.py", line 153, in f_jitted
name=flat_fun.__name__)
File "/home/skyewm/jax/jax/numpy/lax_numpy.py", line 1251, in where
return _where(condition, x, y)
File "/home/skyewm/jax/jax/numpy/lax_numpy.py", line 2824, in _argminmax
mask_idxs = where(lax._eq_meet(a, op(a, axis, keepdims=True)), idxs, maxval)
File "/home/skyewm/jax/jax/numpy/lax_numpy.py", line 2806, in argmax
return _argminmax(max, a, axis)
File "segfault.py", line 35, in binary_loss
predicted_labels = np.argmax(logits, axis=-1)
File "/home/skyewm/jax/jax/linear_util.py", line 150, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/skyewm/jax/jax/interpreters/partial_eval.py", line 430, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/skyewm/jax/jax/interpreters/xla.py", line 474, in _xla_callable
fun, pvals, instantiate=False, stage_out=True, bottom=True)
File "/home/skyewm/jax/jax/linear_util.py", line 221, in memoized_fun
ans = call(fun, *args)
File "/home/skyewm/jax/jax/interpreters/xla.py", line 457, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
File "/home/skyewm/jax/jax/core.py", line 978, in call_bind
outs = primitive.impl(f, *args, **params)
File "/home/skyewm/jax/jax/api.py", line 153, in f_jitted
name=flat_fun.__name__)
File "segfault.py", line 43, in <module>
train_acc = binary_loss(net_params, x_train, y_train, net_apply) # also doesn't segfault if only called once
I'll continue debugging tomorrow.
Oh also got the C++ stack, forgot to share that:
#0 0x00007fff4a5a7fb8 in absl::lts_2020_02_25::Mutex::ReaderLock() ()
from /home/skyewm/.local/lib/python3.6/site-packages/jaxlib/xla_extension.so
#1 0x00007fff4a51be0a in stream_executor::Stream::ThenDoHostCallback(std::function<void ()>) ()
from /home/skyewm/.local/lib/python3.6/site-packages/jaxlib/xla_extension.so
#2 0x00007fff47a91875 in xla::LocalDeviceState::ThenExecuteOnCallbackThread(stream_executor::Stream*, std::function<void ()>) const ()
from /home/skyewm/.local/lib/python3.6/site-packages/jaxlib/xla_extension.so
#3 0x00007fff47a7aced in xla::PyLocalBuffer::Release(bool) ()
from /home/skyewm/.local/lib/python3.6/site-packages/jaxlib/xla_extension.so
#4 0x00007fff47a7d4e1 in xla::PyLocalBuffer::Delete() ()
from /home/skyewm/.local/lib/python3.6/site-packages/jaxlib/xla_extension.so
#5 0x00007fff47a7d617 in xla::PyLocalBuffer::~PyLocalBuffer() ()
from /home/skyewm/.local/lib/python3.6/site-packages/jaxlib/xla_extension.so
#6 0x00007fff479e37ca in pybind11::class_<xla::PyLocalBuffer, xla::ClientAndUniquePtr<xla::PyLocalBuffer> >::dealloc(pybind11::detail::value_and_holder&) ()
from /home/skyewm/.local/lib/python3.6/site-packages/jaxlib/xla_extension.so
#7 0x00007fff47a17403 in pybind11::detail::clear_instance(_object*) ()
from /home/skyewm/.local/lib/python3.6/site-packages/jaxlib/xla_extension.so
#8 0x00007fff47a17adf in pybind11_object_dealloc ()
from /home/skyewm/.local/lib/python3.6/site-packages/jaxlib/xla_extension.so
#9 0x00000000005085a7 in frame_dealloc (
f=Frame 0x1db6288, for file /home/skyewm/jax/jax/interpreters/xla.py, line 599, in _execute_trivial ()) at ../Objects/frameobject.c:462
We have a strong suspicion that the bug is here:
https://github.com/tensorflow/tensorflow/blob/05991352f7fdb12ed774561269609fd908e7f95e/tensorflow/compiler/xla/python/local_client.cc#L778
.release() and .get() are called on the same std::unique_ptr in different arguments of a single function call. C++ leaves the evaluation order of function arguments unspecified, and it differs between compilers (e.g., clang vs. gcc). We tend to test with clang internally (and have never seen this bug), but our external builds use gcc, which evaluates the arguments in the opposite order. @skye is preparing a fix.
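For anyone unfamiliar with this class of bug, here is a minimal, self-contained C++ sketch of the hazard (hypothetical names, not the actual local_client.cc code): the pointer the second parameter receives depends entirely on which argument the compiler happens to evaluate first.

#include <iostream>
#include <memory>

// Takes ownership of `owned`; expects `borrowed` to alias the same object.
void Consume(int* owned, const int* borrowed) {
  // If release() happened to be evaluated before get(), `borrowed` is nullptr here.
  std::cout << "owned=" << static_cast<const void*>(owned)
            << " borrowed=" << static_cast<const void*>(borrowed) << "\n";
  delete owned;
}

int main() {
  auto ptr = std::make_unique<int>(42);
  // C++ does not specify which of these two arguments is evaluated first;
  // clang and gcc happen to pick opposite orders, so this same source line
  // behaves differently under the two compilers.
  Consume(ptr.release(), ptr.get());
}

The usual way to make this kind of call safe is to read the .get() value into a named local before calling .release(), which pins down the order explicitly; the commit linked in the next comment contains the actual change.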
This should be fixed in jaxlib 0.1.45, hot off the press! I'm gonna close this, but please let us know if you're still experiencing segfaults. (Here's the fix for anyone interested: https://github.com/tensorflow/tensorflow/commit/78edbb6403b73d6c79bd58e23e08dc21b5c33847)