I installed the GPU version of JAX and hit the following error the first time I ran a program that uses JAX's automatic differentiation (AD) support.
Here is the error output:
2020-10-05 16:30:40.039587: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:70] Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may result in compilation or runtime failures, if the program we try to run uses routines from libdevice.
2020-10-05 16:30:40.039604: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:71] Searched for CUDA in the following directories:
2020-10-05 16:30:40.039608: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74] ./cuda_sdk_lib
2020-10-05 16:30:40.039611: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74] /usr/local/cuda-10.1
2020-10-05 16:30:40.039613: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74] .
2020-10-05 16:30:40.039616: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:76] You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions. For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.
i: 0
2020-10-05 16:30:40.769544: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:320] libdevice is required by this HLO module but was not found at ./libdevice.10.bc
Traceback (most recent call last):
  File "adpath.py", line 252, in <module>
    ...........
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 347, in fn
    return lax_fn(x)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/lax/lax.py", line 285, in sqrt
    return sqrt_p.bind(x)
jax.traceback_util.FilteredStackTrace: RuntimeError: Internal: libdevice not found at ./libdevice.10.bc
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "adpath.py", line 252, in <module>
    .......
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/api.py", line 900, in jacfun
    y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/traceback_util.py", line 137, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/api.py", line 1217, in batched_fun
    out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/interpreters/batching.py", line 36, in batch
    return batched_fun.call_wrapped(*in_vals)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/linear_util.py", line 151, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/api.py", line 1681, in _jvp
    out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/linear_util.py", line 151, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
    ..........
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 347, in fn
    return lax_fn(x)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/lax/lax.py", line 285, in sqrt
    return sqrt_p.bind(x)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/core.py", line 266, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/core.py", line 574, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 224, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 264, in xla_primitive_callable
    compiled = backend_compile(backend, built_c, options)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 325, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: Internal: libdevice not found at ./libdevice.10.bc
====================================================================================
Maybe I didn't install the correct CUDA version; I'm not sure. Can anyone help me figure it out?
Check out the instructions near the bottom of https://github.com/google/jax#pip-installation, starting with "Note that some GPU functionality expects the CUDA installation to be at /usr/local/cuda-X.X". Do any of the suggestions there help?
I tried creating a symlink:
sudo ln -s /path/to/cuda /usr/local/cuda-X.X
But it still gives the error above.
sudo ln -s /usr/lib/cuda /usr/local/cuda-10.1
This actually solves my problem. Thanks very much @skye
Thanks for letting us know!
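For anyone unsure which directory to point the symlink at: the warning in the log says XLA looks for ${CUDA_DIR}/nvvm/libdevice, so a quick check is whether a candidate directory contains that subdirectory with a libdevice .bc file in it. Below is a minimal sketch of such a check. The helper name has_libdevice is made up for illustration, and the candidate paths are just the two mentioned in this thread; your system may use others.

```python
from pathlib import Path


def has_libdevice(cuda_dir):
    """Return True if cuda_dir/nvvm/libdevice contains a libdevice*.bc file.

    Hypothetical helper for illustration; it only mirrors the directory
    layout named in the XLA warning, it does not call into JAX or XLA.
    """
    libdevice_dir = Path(cuda_dir) / "nvvm" / "libdevice"
    return libdevice_dir.is_dir() and any(libdevice_dir.glob("libdevice*.bc"))


# Candidate locations taken from this thread; adjust them for your machine.
for candidate in ["/usr/local/cuda-10.1", "/usr/lib/cuda"]:
    status = "found" if has_libdevice(candidate) else "not found"
    print(f"{candidate}: libdevice {status}")
```

A directory reported as "found" is a reasonable target for the symlink or for xla_gpu_cuda_data_dir.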
If you are not able to symlink your CUDA directory (for example, because of permissions on a shared cluster), you can always set XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/lib/cuda in front of each command that needs CUDA:
XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/lib/cuda python3 your_script_with_cuda.py
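If prefixing every command is inconvenient, the same variable can probably be set from inside the script, as long as that happens before JAX is imported. This is a sketch under that assumption, not an official recipe. The helper set_cuda_data_dir is a made-up name; it keeps any other flags already present in XLA_FLAGS and replaces only a previous xla_gpu_cuda_data_dir entry. The path /usr/lib/cuda is the one from this thread and may differ on your system.

```python
import os


def set_cuda_data_dir(cuda_dir, env=os.environ):
    """Put --xla_gpu_cuda_data_dir=<cuda_dir> into XLA_FLAGS.

    Hypothetical helper for illustration. Other flags already present in
    XLA_FLAGS are preserved; an earlier xla_gpu_cuda_data_dir is replaced.
    """
    prefix = "--xla_gpu_cuda_data_dir="
    existing = env.get("XLA_FLAGS", "").split()
    kept = [flag for flag in existing if not flag.startswith(prefix)]
    env["XLA_FLAGS"] = " ".join(kept + [prefix + cuda_dir])
    return env["XLA_FLAGS"]


# Assumption: this runs before the first `import jax` in the process.
set_cuda_data_dir("/usr/lib/cuda")
# import jax  # import JAX only after the environment variable is set
```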
Thank you very much for that post.