To be honest, I'm not sure whether the problem is caused by CUDA.
My JAX code worked fine with CUDA 11.0.
But I also wanted to use TensorFlow, and it had some issues with CUDA 11.0, so I downgraded CUDA to 10.2.
Then TensorFlow worked fine.
However, my original JAX code started to crash, and it still crashes after I reinstalled JAX for CUDA 10.2.
This is the error message:
2020-10-16 20:13:56.970918: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_blas.cc:225] failed to create cublas handle: CUBLAS_STATUS_NOT_INITIALIZED
2020-10-16 20:13:56.970956: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc:113] Check failed: stream->parent()->GetBlasGemmAlgorithms(&algorithms)
Fatal Python error: Aborted
Current thread 0x00007faf958a9740 (most recent call first):
File "/home/keunwoo/.virtualenvs/hacas/lib/python3.6/site-packages/jax/interpreters/xla.py", line 344 in backend_compile
File "/home/keunwoo/.virtualenvs/hacas/lib/python3.6/site-packages/jax/interpreters/xla.py", line 703 in _xla_callable
File "/home/keunwoo/.virtualenvs/hacas/lib/python3.6/site-packages/jax/linear_util.py", line 247 in memoized_fun
File "/home/keunwoo/.virtualenvs/hacas/lib/python3.6/site-packages/jax/interpreters/xla.py", line 557 in _xla_call_impl
File "/home/keunwoo/.virtualenvs/hacas/lib/python3.6/site-packages/jax/core.py", line 575 in process_call
File "/home/keunwoo/.virtualenvs/hacas/lib/python3.6/site-packages/jax/core.py", line 1165 in process
File "/home/keunwoo/.virtualenvs/hacas/lib/python3.6/site-packages/jax/core.py", line 1153 in call_bind
File "/home/keunwoo/.virtualenvs/hacas/lib/python3.6/site-packages/jax/core.py", line 1162 in bind
File "/home/keunwoo/.virtualenvs/hacas/lib/python3.6/site-packages/jax/api.py", line 217 in f_jitted
It is hard for me to interpret this message. Since the paths mention external/org_tensorflow, it seems like the problem is related to TensorFlow.
Any kind of help would be appreciated.
Since you are mixing both TF and JAX on the same GPU, please read https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
Do the environment variables described there help?
@hawkinsp Thanks for your help. I will try setting XLA_PYTHON_CLIENT_PREALLOCATE=false.
But the document doesn't say where (or how) I can set that parameter. Do I have to set it as an environment variable, or can I set it in code?
Setting XLA_PYTHON_CLIENT_PREALLOCATE=false works fine.
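For reference, here is a minimal sketch of how I understand the flag can be set, either in the shell before launching the script or from Python before jax is imported for the first time (the script name and the small test at the end are only illustrative):

# Option 1: set the variable in the shell before launching the script, e.g.
#   XLA_PYTHON_CLIENT_PREALLOCATE=false python my_script.py   (my_script.py is a placeholder name)

# Option 2: set it from Python *before* jax is imported for the first time,
# so the GPU backend is created without preallocating most of the GPU memory.
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp

print(jax.devices())  # confirm the GPU device is visible
# A small matrix multiply exercises cuBLAS, which is where the original failure occurred.
print(jnp.dot(jnp.ones((4, 4)), jnp.ones((4, 4))))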
I don't use TensorFlow for the JAX project; they are totally separate.
It is interesting that they affect each other even though they are not mixed in the same project.
The relevant thing here is the amount of free GPU memory; try running nvidia-smi to see the current status.
For example, if you run two GPU-using processes at the same time (and both try to reserve most of the GPU memory, which is the default behavior for both JAX and TensorFlow), you can easily trigger this error.
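If you'd rather check from Python, a quick sketch along these lines should work (it assumes nvidia-smi is on your PATH):

# Print per-GPU memory usage by shelling out to nvidia-smi.
import subprocess

out = subprocess.check_output(
    ["nvidia-smi", "--query-gpu=index,memory.used,memory.free", "--format=csv"]
)
print(out.decode())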
Closing since it seems your issue is resolved; feel free to reopen if needed.