Hi,
When I try to use convnets, for example by running examples/resnet50.py from this repository, I get this error:
Falling back to default algorithm.
Convolution performance may be suboptimal.
2020-11-12 12:54:30.861420: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:336] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2020-11-12 12:54:30.861673: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_client.cc:1780] Execution of replica 0 failed: Unimplemented: DNN library is not found.
Traceback (most recent call last):
File "resnet50.py", line 123, in <module>
opt_state = update(i, opt_state, next(batches))
jax.traceback_util.FilteredStackTrace: RuntimeError: Unimplemented: DNN library is not found.
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "resnet50.py", line 123, in <module>
opt_state = update(i, opt_state, next(batches))
File "/h/farzaneh/.conda/envs/myjax3/lib/python3.8/site-packages/jax/traceback_util.py", line 133, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/h/farzaneh/.conda/envs/myjax3/lib/python3.8/site-packages/jax/api.py", line 217, in f_jitted
out = xla.xla_call(
File "/h/farzaneh/.conda/envs/myjax3/lib/python3.8/site-packages/jax/core.py", line 1177, in bind
return call_bind(self, fun, *args, **params)
File "/h/farzaneh/.conda/envs/myjax3/lib/python3.8/site-packages/jax/core.py", line 1168, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/h/farzaneh/.conda/envs/myjax3/lib/python3.8/site-packages/jax/core.py", line 1180, in process
return trace.process_call(self, fun, tracers, params)
File "/h/farzaneh/.conda/envs/myjax3/lib/python3.8/site-packages/jax/core.py", line 579, in process_call
return primitive.impl(f, *tracers, **params)
File "/h/farzaneh/.conda/envs/myjax3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 559, in _xla_call_impl
return compiled_fun(*args)
File "/h/farzaneh/.conda/envs/myjax3/lib/python3.8/site-packages/jax/interpreters/xla.py", line 807, in _execute_compiled
out_bufs = compiled.execute(input_bufs)
RuntimeError: Unimplemented: DNN library is not found.
Any suggestions on how to solve this? I am using CUDA 11.0.
Thanks for the question! Which version of jaxlib are you using?
I'm using 0.1.56.
This is the one I installed:
pip install --upgrade jax jaxlib==0.1.56+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
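To confirm which jaxlib build is actually installed in the active environment, a small Python sketch like the one below can help. It is only an illustration, not a JAX tool: `installed_version` and `split_local_label` are hypothetical helpers, and it assumes the wheel's metadata keeps the `+cuda110` local version label used in the pip command, as pip wheels normally do.

```python
from importlib import metadata  # stdlib in Python 3.8+
from typing import Optional, Tuple


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def split_local_label(version: str) -> Tuple[str, Optional[str]]:
    """Split a PEP 440 version like '0.1.56+cuda110' into ('0.1.56', 'cuda110').

    The local label (after '+') is None when the version has no '+' part.
    """
    public, sep, local = version.partition("+")
    return public, (local if sep else None)


if __name__ == "__main__":
    version = installed_version("jaxlib")
    if version is None:
        print("jaxlib is not installed in this environment")
    else:
        public, local = split_local_label(version)
        # Assumption: a CUDA build carries a label such as 'cuda110'.
        print(f"jaxlib {public}, local build label: {local}")
```

If the label is missing or names a different CUDA version than the one on the system, that mismatch is worth ruling out before looking at cuDNN.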
Can you try that with jaxlib==0.1.57+cuda110 and see if the issue goes away?
I suspect this means that JAX can't locate your cuDNN installation. What is the path to your cuDNN libraries on your system?
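One way to test the "can't locate cuDNN" hypothesis is to ask the dynamic loader directly from Python. This is a minimal sketch, assuming cuDNN 8 on Linux, where the main library is named `libcudnn.so.8` (adjust the candidate names for other versions). `try_load` is a hypothetical helper, and a successful load here only suggests, rather than guarantees, that XLA will find the same copy.

```python
import ctypes
from typing import Iterable, Optional

# Assumption: cuDNN 8 on Linux; other versions/platforms use other names.
CUDNN_CANDIDATES = ("libcudnn.so.8", "libcudnn.so")


def try_load(names: Iterable[str]) -> Optional[str]:
    """Return the first library name that dlopen() can resolve, else None.

    ctypes.CDLL calls dlopen(), so it follows the dynamic loader's normal
    search order (LD_LIBRARY_PATH as set when this process started, the
    ldconfig cache, default system directories).
    """
    for name in names:
        try:
            ctypes.CDLL(name)
            return name
        except OSError:
            continue  # not resolvable by the loader; try the next name
    return None


if __name__ == "__main__":
    found = try_load(CUDNN_CANDIDATES)
    if found:
        print(f"dynamic loader resolved {found}")
    else:
        print("cuDNN not found by the dynamic loader; check LD_LIBRARY_PATH")
```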
I tried jaxlib==0.1.57+cuda110, but I still get the same error.
For cuDNN, this is what I have:
echo $CUDNN_PATH
/pkgs/cudnn-11.0-v8.0.4.30/
Can you try setting the LD_LIBRARY_PATH environment variable to include the path to the .so files of your cuDNN installation?
I exported LD_LIBRARY_PATH as follows:
export LD_LIBRARY_PATH=/pkgs/cudnn-11.0-v8.0.4.30/lib64/:/pkgs/cuda-11.1/targets/x86_64-linux/lib/
and it worked.
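This most likely works because the dynamic loader searches the directories listed in LD_LIBRARY_PATH for shared libraries, and the cuDNN .so files live in the lib64 subdirectory rather than in the directory that CUDNN_PATH points to. A small sketch for checking which entries actually contain cuDNN files, assuming the `libcudnn*.so*` naming that cuDNN uses on Linux; `cudnn_libs_on_path` is a hypothetical helper, not part of JAX.

```python
import glob
import os
from typing import Dict, List


def cudnn_libs_on_path(ld_library_path: str) -> Dict[str, List[str]]:
    """Map each LD_LIBRARY_PATH entry to the libcudnn*.so* files it holds."""
    hits: Dict[str, List[str]] = {}
    for entry in ld_library_path.split(os.pathsep):
        if not entry:  # empty entries (e.g. a trailing ':') are skipped here
            continue
        matches = sorted(glob.glob(os.path.join(entry, "libcudnn*.so*")))
        if matches:
            hits[entry] = matches
    return hits


if __name__ == "__main__":
    # Assumption (glibc): the loader reads LD_LIBRARY_PATH at process start,
    # so export it in the shell before launching Python, not via os.environ.
    found = cudnn_libs_on_path(os.environ.get("LD_LIBRARY_PATH", ""))
    if not found:
        print("no libcudnn*.so* files in any LD_LIBRARY_PATH entry")
    for directory, libs in found.items():
        print(directory, "->", [os.path.basename(p) for p in libs])
```

With glibc, setting os.environ["LD_LIBRARY_PATH"] inside an already-running interpreter generally does not change where that process loads libraries from, which is why the export has to happen in the shell first.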
Thanks!
Thanks, @hawkinsp!