Jax: RuntimeError: Internal: libdevice not found at ./libdevice.10.bc

Created on 5 Oct 2020 · 5 Comments · Source: google/jax

I installed the GPU version of JAX and got the following error the first time I ran a program that uses JAX's automatic differentiation (AD).

Here is the error output:

2020-10-05 16:30:40.039587: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:70] Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may result in compilation or runtime failures, if the program we try to run uses routines from libdevice.
2020-10-05 16:30:40.039604: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:71] Searched for CUDA in the following directories:
2020-10-05 16:30:40.039608: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74]   ./cuda_sdk_lib
2020-10-05 16:30:40.039611: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74]   /usr/local/cuda-10.1
2020-10-05 16:30:40.039613: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:74]   .
2020-10-05 16:30:40.039616: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:76] You can choose the search directory by setting xla_gpu_cuda_data_dir in HloModule's DebugOptions.  For most apps, setting the environment variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.
i: 0
2020-10-05 16:30:40.769544: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:320] libdevice is required by this HLO module but was not found at ./libdevice.10.bc
Traceback (most recent call last):
  File "adpath.py", line 252, in <module>
...........
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 347, in fn
    return lax_fn(x)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/lax/lax.py", line 285, in sqrt
    return sqrt_p.bind(x)
jax.traceback_util.FilteredStackTrace: RuntimeError: Internal: libdevice not found at ./libdevice.10.bc
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "adpath.py", line 252, in <module>
.......
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/api.py", line 900, in jacfun
    y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/traceback_util.py", line 137, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/api.py", line 1217, in batched_fun
    out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/interpreters/batching.py", line 36, in batch
    return batched_fun.call_wrapped(*in_vals)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/linear_util.py", line 151, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/api.py", line 1681, in _jvp
    out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/linear_util.py", line 151, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  ..........
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/numpy/lax_numpy.py", line 347, in fn
    return lax_fn(x)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/lax/lax.py", line 285, in sqrt
    return sqrt_p.bind(x)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/core.py", line 266, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/core.py", line 574, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 224, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 264, in xla_primitive_callable
    compiled = backend_compile(backend, built_c, options)
  File "/home/tie/anaconda3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 325, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: Internal: libdevice not found at ./libdevice.10.bc
====================================================================================

Maybe I didn't install the correct CUDA version; I'm not sure. Can anyone help me figure this out?

All 5 comments

Check out the instructions near the bottom of https://github.com/google/jax#pip-installation, starting with "Note that some GPU functionality expects the CUDA installation to be at /usr/local/cuda-X.X". Do any of the suggestions there help?
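The warnings at the top of the log list the directories XLA searched (./cuda_sdk_lib, /usr/local/cuda-10.1 and .) and the layout it expects, ${CUDA_DIR}/nvvm/libdevice. A minimal Python sketch for checking which candidate directory actually contains libdevice.10.bc is below. The helper name find_cuda_data_dir is hypothetical, and /usr/lib/cuda is added to the candidates only because it is the location that turned out to work later in this thread.

```python
from pathlib import Path
from typing import Iterable, Optional

# Candidate CUDA roots: the first three are the ones XLA reported searching
# in the log above; /usr/lib/cuda is where the toolkit lived in this thread.
CANDIDATES = ["./cuda_sdk_lib", "/usr/local/cuda-10.1", ".", "/usr/lib/cuda"]


def find_cuda_data_dir(candidates: Iterable[str]) -> Optional[Path]:
    """Return the first candidate containing nvvm/libdevice/libdevice.10.bc.

    Hypothetical helper: it only checks the file layout described by XLA's
    warning (${CUDA_DIR}/nvvm/libdevice); it does not load CUDA itself.
    """
    for root in candidates:
        libdevice = Path(root) / "nvvm" / "libdevice" / "libdevice.10.bc"
        if libdevice.is_file():
            return Path(root)
    return None


if __name__ == "__main__":
    found = find_cuda_data_dir(CANDIDATES)
    if found is None:
        print("libdevice.10.bc not found in any candidate directory")
    else:
        print(f"Found libdevice under {found}")
```

Whichever directory this reports is the one to symlink to /usr/local/cuda-X.X or to pass via --xla_gpu_cuda_data_dir.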

I tried creating a symlink:

sudo ln -s /path/to/cuda /usr/local/cuda-X.X

but I still get the error above.

Update: this command solved my problem:

sudo ln -s /usr/lib/cuda /usr/local/cuda-10.1

Thanks very much @skye!
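The fix works because it makes a directory XLA already searches (/usr/local/cuda-10.1, per the log) point at the real toolkit in /usr/lib/cuda. A rough Python equivalent of that ln -s, with a sanity check first, might look like the following. The function name is hypothetical, and on a real system the call needs the same root privileges as sudo ln -s. The check guards against linking to a directory that does not actually contain libdevice, which is one possible reason a symlink attempt can still fail.

```python
import os
from pathlib import Path


def link_cuda_dir(real_cuda: str, expected_path: str) -> Path:
    """Create a symlink expected_path -> real_cuda, like `ln -s real expected`.

    Hypothetical helper. It refuses to link unless real_cuda contains
    nvvm/libdevice/libdevice.10.bc, and it never overwrites an existing path.
    """
    src = Path(real_cuda)
    dst = Path(expected_path)
    if not (src / "nvvm" / "libdevice" / "libdevice.10.bc").is_file():
        raise FileNotFoundError(f"{src} has no nvvm/libdevice/libdevice.10.bc")
    if dst.exists() or dst.is_symlink():
        raise FileExistsError(f"{dst} already exists; not overwriting it")
    os.symlink(src, dst, target_is_directory=True)
    return dst


# Example (needs root on a real system):
# link_cuda_dir("/usr/lib/cuda", "/usr/local/cuda-10.1")
```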

Thanks for letting us know!

If you can't ln -s your CUDA installation (for example, because you lack the permissions on a shared cluster), you can instead set XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/lib/cuda before each command that needs CUDA:

XLA_FLAGS=--xla_gpu_cuda_data_dir=/usr/lib/cuda python3 your_script_with_cuda.py
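The same flag can also be set from inside a Python script, as long as it is in the environment before JAX initializes its GPU backend; doing it before the first import jax is the safe choice. The sketch below assumes XLA_FLAGS holds space-separated flags and appends to any existing ones rather than overwriting them. The helper name is hypothetical, and /usr/lib/cuda is simply the path from this thread.

```python
import os


def add_cuda_data_dir_flag(cuda_dir: str) -> str:
    """Append --xla_gpu_cuda_data_dir to XLA_FLAGS, keeping any other flags.

    Hypothetical helper; it must run before JAX creates its GPU backend.
    """
    flag = f"--xla_gpu_cuda_data_dir={cuda_dir}"
    existing = os.environ.get("XLA_FLAGS", "").split()
    # Drop any earlier value of this flag so the new directory wins.
    kept = [f for f in existing if not f.startswith("--xla_gpu_cuda_data_dir=")]
    os.environ["XLA_FLAGS"] = " ".join(kept + [flag])
    return os.environ["XLA_FLAGS"]


add_cuda_data_dir_flag("/usr/lib/cuda")
# import jax  # import JAX only after XLA_FLAGS has been set
```

Appending instead of assigning keeps unrelated XLA flags you may already export (for example, dump or debug options) intact.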

Thank you very much for that post.
