Thanks to all contributors for their efforts in creating and open sourcing the library.
I would like to add my two cents on the installation process when building from source, for whatever that's worth.
I always like to install things in conda envs so that there is no clash between different software versions or required libraries.
MWE:
conda create -n jax python scipy cudnn cudatoolkit
conda list

Now the installation process:
python build/build.py --enable_cuda --cuda_path ~/miniconda3/envs/jax/lib/ --cudnn_path ~/miniconda3/envs/jax/include
2 Problems arise:
1. nvcc cannot be found in path ~/miniconda3/envs/jax/lib/bin
Actually the path is wrong; it should have been ~/miniconda3/envs/jax/bin.
Anyway, I copied nvcc from the system-wide installation /opt/cuda/bin/nvcc into ~/miniconda3/envs/jax/lib/bin.
So far so good.
2. Re-running the build, it complains about cuda.h:
Cuda Configuration Error: Cannot find cuda.h under ~/miniconda3/envs/jax/lib
FAILED: Build did NOT complete successfully (4 packages loaded, 16 targets
OK, let's copy /opt/cuda/include/cuda.h into ~/miniconda3/envs/jax/lib.
Re-running the build after completely removing the Bazel cache (rm -rf ~/.cache/bazel)
gives the same error again about not being able to find cuda.h.
At this point I am out of ideas.
Anyone else having other ideas on how to resolve this?
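Before re-running the build, it may help to check whether a candidate directory actually contains the two files the errors above mention. The snippet below is only a minimal sketch: the helper name missing_cuda_pieces is made up for illustration, and the layout it assumes (bin/nvcc and include/cuda.h under one root) is the usual layout of a full CUDA toolkit, not something build.py is confirmed to check in exactly this way.

```python
from pathlib import Path


def missing_cuda_pieces(cuda_root):
    """Return the expected files that are absent under cuda_root.

    Hypothetical helper: it assumes a full toolkit keeps nvcc in
    <root>/bin and cuda.h in <root>/include.
    """
    root = Path(cuda_root).expanduser()
    expected = [Path("bin") / "nvcc", Path("include") / "cuda.h"]
    # Report paths relative to the root so the output is easy to read.
    return [str(rel) for rel in expected if not (root / rel).is_file()]
```

Calling missing_cuda_pieces("~/miniconda3/envs/jax") and missing_cuda_pieces("/opt/cuda") side by side should show which of the two roots is usable as --cuda_path; an empty list means both files were found.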
Looks like we need to do something to make the build.py script work better with conda. I think most of us have been installing CUDA, etc. manually rather than via conda. You might try pointing --cuda_path to a system-wide installation of CUDA in the meantime.
For now, the easiest workaround to do would be to install the prebuilt pip wheel of jaxlib. Another option would be to not enable GPU support, although that will be slower.
Note you can still use jax from source, even if you use a binary jaxlib. jax is all pure Python code, whereas jaxlib is the binary part (essentially just XLA and Python bindings around XLA.)
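Since jax and jaxlib are separate packages, it can be useful to see which version of each is installed when mixing a source checkout of jax with a binary jaxlib. This is a small sketch using the standard library's importlib.metadata (Python 3.8 or newer, so not the Python 3.7 shown elsewhere in this thread); installed_version is an illustrative name, not part of jax.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of a distribution, or None."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# jax is the pure-Python part, jaxlib the binary part.
for name in ("jax", "jaxlib"):
    print(name, installed_version(name))
```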
@hawkinsp thanks for the suggestion.
You might try pointing --cuda_path to a system-wide installation of CUDA in the meantime.
Although that's fine temporarily, I would prefer to avoid it, especially in my case with a rolling-release OS. The above solution is fine until my next upgrade, when CUDA will be updated and I'll have to build again. That's why I prefer the conda solution: conda envs are isolated environment snapshots that are not affected by new releases of the OS or other libs.
@hawkinsp
For now, the easiest workaround to do would be to install the prebuilt pip wheel of jaxlib
That doesn't work either:
pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.6-cp37-none-linux_x86_64.whl
pip install --upgrade -q jax
conda list
absl-py 0.7.0 pypi_0 pypi
blas 1.0 mkl
ca-certificates 2019.1.23 0
certifi 2018.11.29 py37_0
cudatoolkit 10.0.130 0
cudnn 7.3.1 cuda10.0_0
intel-openmp 2019.1 144
jax 0.1.18 pypi_0 pypi
jaxlib 0.1.6 pypi_0 pypi
libedit 3.1.20181209 hc058e9b_0
libffi 3.2.1 hd88cf55_4
libgcc-ng 8.2.0 hdf63c60_1
libgfortran-ng 7.3.0 hdf63c60_0
libstdcxx-ng 8.2.0 hdf63c60_1
mkl 2019.1 144
mkl_fft 1.0.10 py37ha843d7b_0
mkl_random 1.0.2 py37hd81dba3_0
ncurses 6.1 he6710b0_1
numpy 1.15.4 py37h7e9f1db_0
numpy-base 1.15.4 py37hde5b4d6_0
openssl 1.1.1a h7b6447c_0
opt-einsum 2.3.2 pypi_0 pypi
pip 19.0.1 py37_0
protobuf 3.6.1 pypi_0 pypi
python 3.7.2 h0371630_0
readline 7.0 h7b6447c_5
scipy 1.2.0 py37h7c811a0_0
setuptools 40.7.3 py37_0
six 1.12.0 pypi_0 pypi
sqlite 3.26.0 h7b6447c_0
tk 8.6.8 hbc83047_0
wheel 0.32.3 py37_0
xz 5.2.4 h14c3975_4
zlib 1.2.11 h7b6447c_3
Testing jax:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
wkey, bkey = random.split(random.PRNGKey(0))
Error:
2019-02-08 14:40:41.685537: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:142] Unable to find libdevice dir. Using '.'
2019-02-08 14:40:41.694622: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:846] Failed to compile ptx to cubin. Will attempt to let GPU driver compile the ptx. Not found: /usr/local/cuda-10.0/bin/ptxas not found
>>>
It seems to me that the CUDA paths are hard-coded (Not found: /usr/local/cuda-10.0/bin/ptxas). The same issues have also been raised for TensorFlow; are you guys by any chance re-using some of the code base from TensorFlow? Looks like that's the case here?
This particular error comes from XLA (which is included as part of TF) so this is indeed a common issue across TF and JAX, most likely.
I think you should be able to work around it by setting the environment variable:
export XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/my/cuda/installation
Out of curiosity, where is your CUDA installation? One thing we could do is add more paths to the list of search paths XLA uses to find ptxas. Are there any CUDA... environment variables set that point to it?
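The same workaround can be applied from inside Python, which is handy in notebooks or scripts where editing the shell environment is awkward. A minimal sketch follows; with_cuda_data_dir is a made-up helper name, /opt/cuda is just the path reported in this thread, and the sketch assumes XLA_FLAGS is a whitespace-separated list of flags. To be safe, set the variable before importing jax.

```python
import os


def with_cuda_data_dir(cuda_dir, env=None):
    """Return a copy of env whose XLA_FLAGS points XLA at cuda_dir."""
    env = dict(os.environ if env is None else env)
    prefix = "--xla_gpu_cuda_data_dir="
    # Keep any other flags already set, dropping an older copy of this one.
    kept = [f for f in env.get("XLA_FLAGS", "").split()
            if not f.startswith(prefix)]
    env["XLA_FLAGS"] = " ".join(kept + [prefix + cuda_dir])
    return env


# Apply to the current process before `import jax`.
os.environ["XLA_FLAGS"] = with_cuda_data_dir("/opt/cuda")["XLA_FLAGS"]
```

Merging into the existing value rather than overwriting it keeps any other XLA flags that were already exported.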
@hawkinsp thanks for the tip
I think you should be able to work around it by setting the environment variable:
export XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/my/cuda/installation
This did the trick!
Out of curiosity, where is your CUDA installation?
---> /opt/cuda/
One thing we could do is add more paths to the list of search paths XLA uses to find ptxas
What would be very useful is to break out of those dependencies if possible. Conda gives the ability to install cudatoolkit and cudnn, but those are lightweight packages that provide files such as /cuda/include & /cuda/lib so that any framework can use them; they don't install ptxas or nvcc, for that matter. What strikes me as odd is that, since jax underneath is still using tensorflow (if my understanding is correct?), how come it still requires ptxas & nvcc (especially for binary files such as wheels), while I can easily install tensorflow in any environment where I have pre-installed cudatoolkit & cudnn?
Are there any CUDA... environment variables set that point to it?
I am afraid not, and that is by choice, since I like to keep things separated in their own environments. This is especially helpful because I have a rolling-release distro: whenever my distro's package manager updates CUDA I would have to recompile any software that relies on /opt/cuda, whereas if I keep everything in its own conda environment I don't have to recompile anything and can keep working with cuda.x even if it is outdated!
I looked into this a bit more.
Unfortunately I don't think we can use Conda's cudatoolkit without doing a significant amount of work, because it's incomplete and missing some standard parts of the CUDA toolkit. As the comment (https://github.com/google/jax/issues/302#issue-405702005) above says, it lacks nvcc, and it's going to be difficult to build XLA without it.
(Aside: nvcc isn't strictly necessary to build XLA at the moment, which is the only part of TensorFlow that JAX needs, but it would be a bunch of work to teach the common TensorFlow build how not to look for nvcc for the XLA-only case. If someone wants to teach TF's build scripts that CUDA without nvcc is a thing, I'd happily review the PR! And it's not out of the question that someone will want to add nvcc-built code to XLA in the future, at which point we'd have to undo all this work.)
I note also that, for example, the PyTorch build also requires the user to install CUDA system-wide, not from cudatoolkit (https://pytorch.org/get-started/locally/#linux-from-source). So this isn't really a problem specific to JAX.
I think for now, the best I can do is recommend something like this:
Install CUDA somewhere (e.g. use NVidia's runfile-based installer). You can put it anywhere you like, system-wide or in your home directory. I installed it in /usr/local/cuda-10.0.
conda create -n jax python scipy cudnn
conda list
git clone https://github.com/google/jax.git
python build/build.py --enable_cuda --cuda_path /usr/local/cuda-10.0 --cudnn_path ~/anaconda3/envs/jax
I hope that helps? I'm sorry I don't have a better suggestion; what gets redistributed in cudatoolkit isn't up to me :-(
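For anyone scripting these steps, the build invocation can be assembled as an argument list so that the two paths are easy to swap. This is only a sketch: build_command is a hypothetical helper, and it uses just the three flags quoted above (--enable_cuda, --cuda_path, --cudnn_path), expanding ~ itself because no shell is involved.

```python
import os
import shlex


def build_command(cuda_path, cudnn_path):
    """Assemble the build.py invocation from this thread as an argv list."""
    return [
        "python", "build/build.py", "--enable_cuda",
        "--cuda_path", os.path.expanduser(cuda_path),
        "--cudnn_path", os.path.expanduser(cudnn_path),
    ]


argv = build_command("/usr/local/cuda-10.0", "~/anaconda3/envs/jax")
# Print a copy-pasteable command line; pass argv to subprocess.run to execute.
print(" ".join(shlex.quote(arg) for arg in argv))
```

Run it from the root of the jax checkout, since build/build.py is a relative path.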
@hawkinsp Thanks once again for the comprehensive reply.
note also that, for example, the PyTorch build also requires the user to install CUDA system-wide, not from cudatoolkit
I don't think this statement is correct though.
Here's an example:
nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2018 NVIDIA Corporation
Built on Sat_Aug_25_21:08:01_CDT_2018
Cuda compilation tools, release 10.0, V10.0.130
I have CUDA 10 installed system-wide, but that doesn't prevent me from installing PyTorch with CUDA 9 and using it:
conda list
blas 1.0 mkl
ca-certificates 2019.1.23 0
certifi 2018.11.29 py37_0
cffi 1.11.5 py37he75722e_1
cudatoolkit 9.0 h13b8566_0
freetype 2.9.1 h8a8886c_1
intel-openmp 2019.1 144
jpeg 9b h024ee3a_2
libedit 3.1.20181209 hc058e9b_0
libffi 3.2.1 hd88cf55_4
libgcc-ng 8.2.0 hdf63c60_1
libgfortran-ng 7.3.0 hdf63c60_0
libpng 1.6.36 hbc83047_0
libstdcxx-ng 8.2.0 hdf63c60_1
libtiff 4.0.10 h2733197_2
mkl 2019.1 144
mkl_fft 1.0.10 py37ha843d7b_0
mkl_random 1.0.2 py37hd81dba3_0
ncurses 6.1 he6710b0_1
ninja 1.8.2 py37h6bb024c_1
numpy 1.15.4 py37h7e9f1db_0
numpy-base 1.15.4 py37hde5b4d6_0
olefile 0.46 py37_0
openssl 1.1.1a h7b6447c_0
pillow 5.4.1 py37h34e0f95_0
pip 19.0.1 py37_0
pycparser 2.19 py37_0
python 3.7.2 h0371630_0
pytorch 1.0.1 py3.7_cuda9.0.176_cudnn7.4.2_2 pytorch
readline 7.0 h7b6447c_5
scipy 1.2.0 py37h7c811a0_0
setuptools 40.7.3 py37_0
six 1.12.0 py37_0
sqlite 3.26.0 h7b6447c_0
tk 8.6.8 hbc83047_0
torchvision 0.2.1 py_2 pytorch
wheel 0.32.3 py37_0
xz 5.2.4 h14c3975_4
zlib 1.2.11 h7b6447c_3
zstd 1.3.7 h0b5b093_0
Example:
import torch
torch.randn(3, 5)
tensor([[-0.8300, 0.8224, -0.4590, 0.3709, 0.0423],
[ 0.5900, -2.2393, 0.0046, -0.2045, -0.4696],
[ 0.2258, -0.8005, -1.2405, -0.6236, 1.9128]])
>>>
The point is that you still need some system-wide install to get nvcc at build time. cudatoolkit isn't enough. (I guess that's what is happening for PyTorch.) TF doesn't really support mixing two different installs, I think. I'm sure PRs would be welcome.
I think it would be better to either:
1. remove the nvcc dependency from XLA, as outlined above, or
2. make it possible to use cudatoolkit to run JAX, even if it's not sufficient to build it.
(But note there's a minor hurdle to the latter, too. XLA needs two unusual things from the CUDA runtime: libdevice, which happily the cudatoolkit package does have, and ptxas, which it does not. It is possible to fall back to using the NVidia kernel driver's version of ptxas, and XLA will do just that, but we've found it is often buggy and the driver-hosted copy deadlocks from time to time.)
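To see which of the two runtime pieces a given CUDA directory provides, something like the following can be used. It is a rough sketch under stated assumptions: xla_runtime_pieces is an illustrative name, bin/ptxas matches the path in the warning earlier in this thread, and the nvvm/libdevice location with libdevice*.bc files is the usual layout of NVIDIA's own toolkit; a conda package may place libdevice elsewhere.

```python
from pathlib import Path


def xla_runtime_pieces(cuda_dir):
    """Report whether ptxas and libdevice are present under cuda_dir."""
    root = Path(cuda_dir).expanduser()
    has_ptxas = (root / "bin" / "ptxas").is_file()
    # Assumed location of the libdevice bitcode files in a stock toolkit.
    libdevice_dir = root / "nvvm" / "libdevice"
    has_libdevice = libdevice_dir.is_dir() and any(
        libdevice_dir.glob("libdevice*.bc"))
    return {"ptxas": has_ptxas, "libdevice": has_libdevice}
```

A directory that reports both as True is a reasonable candidate for --xla_gpu_cuda_data_dir.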
Would providing prebuilt Conda packages solve your problem adequately?
@hawkinsp thanks a lot, especially for your patience.
Would providing prebuilt Conda packages solve your problem adequately?
I think that would be useful for a lot of ppl, not just me and would actually make jax more popular so that ppl could give it a try and report back on things that worked and things that didn't work. That would facilitate better development of jax IMHO.
This particular error comes from XLA (which is included as part of TF) so this is indeed a common issue across TF and JAX, most likely.
I think you should be able to work around it by setting the environment variable:
export XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/my/cuda/installation
Out of curiosity, where is your CUDA installation? One thing we could do is add more paths to the list of search paths XLA uses to find ptxas. Are there any CUDA... environment variables set that point to it?
Thank you, this lets me install JAX with GPU support on my conda environment!
Would providing prebuilt Conda packages solve your problem adequately?
yes please, esp. if there are pip related regressions like here
How to delete cuda and cudnn from Conda?
Since the Conda-installed CUDA produces the error below, I want to remove the conda-installed CUDA and cuDNN and install standalone CUDA and cuDNN from Nvidia.
Error : Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.
However, the commands below do not remove them.
conda remove --name cuda --all
conda remove --name cudnn --all
I see two directories, cudatoolkit-10.0.130-0 and cudnn-7.3.1-cuda10.0.0_0, at the following paths:
/home/anaconda3/pkgs/cudatoolkit-10.0.130-0
/home/anaconda3/pkgs/cudnn-7.3.1-cuda10.0.0_0
How can I delete the CUDA and cuDNN packages embedded in Anaconda?
Thanks in advance,
Mike
That question might be better for a Conda mailing list or issue tracker; I don't think we know much about Conda.
I looked into this a bit more.
Unfortunately I don't think we can use Conda's cudatoolkit without doing a significant amount of work, because it's incomplete and missing some standard parts of the CUDA toolkit. As the comment (#302 (comment)) above says, it lacks nvcc, and it's going to be difficult to build XLA without it.
(Aside: nvcc isn't strictly necessary to build XLA at the moment, which is the only part of TensorFlow that JAX needs, but it would be a bunch of work to teach the common TensorFlow build how not to look for nvcc for the XLA-only case. If someone wants to teach TF's build scripts that CUDA without nvcc is a thing, I'd happily review the PR! And it's not out of the question that someone will want to add nvcc-built code to XLA in the future, at which point we'd have to undo all this work.)
I note also that, for example, the PyTorch build also requires the user to install CUDA system-wide, not from cudatoolkit (https://pytorch.org/get-started/locally/#linux-from-source). So this isn't really a problem specific to JAX.
I think for now, the best I can do is recommend something like this:
Step 1
Install CUDA somewhere (e.g. use NVidia's runfile-based installer). You can put it anywhere you like, system-wide or in your home directory. I installed it in /usr/local/cuda-10.0.
Step 2
conda create -n jax python scipy cudnn
conda list
git clone https://github.com/google/jax.git
python build/build.py --enable_cuda --cuda_path /usr/local/cuda-10.0 --cudnn_path ~/anaconda3/envs/jax
I hope that helps? I'm sorry I don't have a better suggestion; what gets redistributed in cudatoolkit isn't up to me :-(
These steps broadly worked for me with a few tweaks:
Altogether, try modifying step 2 as:
conda create -n jax python tensorflow-gpu scipy future cudnn=YOUR_VERSION
conda list
git clone https://github.com/google/jax.git
python build/build.py --enable_cuda --cuda_path /usr/local/cuda-x.x --cudnn_path ~/anaconda3/envs/jax --python_bin_path ~/anaconda3/envs/jax/bin/python
pip install -e build
pip install -e .
Then reboot
Did this work for anyone else? It took about an hour to build. Is this normal?
I tried this (with 2 simple additional steps, highlighted) but got the errors below
conda create -n jax python tensorflow-gpu scipy future cudnn=7 python=3.7
**conda activate jax**
git clone https://github.com/google/jax.git
**cd jax**
python build/build.py --enable_cuda --cuda_path /usr --cudnn_path ~/miniconda3/envs/jax --python_bin_path ~/miniconda3/envs/jax/bin/python
...
Bazel binary path: ./bazel-0.24.1-linux-x86_64
Python binary path: /home/murphyk/miniconda3/envs/jax/bin/python
MKL-DNN enabled: yes
-march=native: no
CUDA enabled: yes
CUDA toolkit path: /usr
CUDNN library path: /home/murphyk/miniconda3/envs/jax
Building XLA and installing it in the jaxlib source tree...
INFO: Build options --action_env and --python_path have changed, discarding analysis cache.
ERROR: /home/murphyk/jax/build/BUILD.bazel:21:1: error loading package 'jaxlib': in /home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/tensorflow/core/platform/default/build_config.bzl: Encountered error while reading extension file 'cuda/build_defs.bzl': no such package '@local_config_cuda//cuda': Traceback (most recent call last):
File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 1266
_create_local_cuda_repository(repository_ctx)
File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 988, in _create_local_cuda_repository
_get_cuda_config(repository_ctx)
File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 714, in _get_cuda_config
find_cuda_config(repository_ctx, ["cuda", "cudnn"])
File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 694, in find_cuda_config
auto_configure_fail(("Failed to run find_cuda_config...))
File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 325, in auto_configure_fail
fail(("\n%sCuda Configuration Error:%...)))
Cuda Configuration Error: Failed to run find_cuda_config.py: Inconsistent CUDA toolkit path: /usr vs /usr/lib
and referenced by '//build:install_xla_in_source_tree'
ERROR: /home/murphyk/jax/build/BUILD.bazel:21:1: error loading package 'jaxlib': in /home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/tensorflow/core/platform/default/build_config.bzl: Encountered error while reading extension file 'cuda/build_defs.bzl': no such package '@local_config_cuda//cuda': Traceback (most recent call last):
File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 1266
_create_local_cuda_repository(repository_ctx)
File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 988, in _create_local_cuda_repository
_get_cuda_config(repository_ctx)
File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 714, in _get_cuda_config
find_cuda_config(repository_ctx, ["cuda", "cudnn"])
File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 694, in find_cuda_config
auto_configure_fail(("Failed to run find_cuda_config...))
File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 325, in auto_configure_fail
fail(("\n%sCuda Configuration Error:%...)))
However, maybe I should somehow use the locations below?
locate cuda | grep /cuda*
/usr/include/cuda.h
...
locate cudnn.h
/usr/lib/x86_64-linux-gnu/libcudnn.so
/usr/include/cudnn.h
...
I want to upgrade tensorflow-gpu=1.9 to tf-gpu=1.14 in a conda environment, but when I try to upgrade, it automatically upgrades CUDA and cuDNN, which I don't want. Can anyone guide me on how to upgrade tf-gpu without disturbing CUDA and cuDNN? Thanks.
Not sure how you ended up on this issue tracker, but TensorFlow binaries are tied closely to specific CUDA and cuDNN versions. If you want to use e.g. TF 1.14 with CUDA 9, I believe you’ll need to build it from source.
I think you can close this issue :)
Sorry for the noise.
If you install cudatoolkit-dev from the conda-forge channel, nvcc is there. I can get jax to work from the binary wheel with the 10.1 version...
If you install cudatoolkit-dev from the conda-forge channel, nvcc is there. I can get jax to work from the binary wheel with the 10.1 version...
Could you maybe tell us how you did it? (I'm guessing environment vars, but I can't get it to work without a system-wide CUDA installation...)
@rahuldave I was able to get nvcc from cudatoolkit-dev as you said, but did you run into the following at all where it cannot find the runtime file?

Figured it out - just had to include the directory in my environment with cuda_runtime.h during the call to nvcc. Found that directory using locate
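For others hitting the same missing-header problem, the directory lookup can be done without locate. The sketch below walks a root directory (for example the conda env) for cuda_runtime.h and turns the result into an include flag for nvcc; find_header_dir and nvcc_include_flag are hypothetical helper names, and -I is nvcc's standard include-path option.

```python
import os


def find_header_dir(root, header="cuda_runtime.h"):
    """Return the first directory under root that contains header, or None."""
    for dirpath, _dirnames, filenames in os.walk(os.path.expanduser(root)):
        if header in filenames:
            return dirpath
    return None


def nvcc_include_flag(root):
    """Build an -I<dir> flag for nvcc, or None if the header is not found."""
    found = find_header_dir(root)
    return None if found is None else "-I" + found
```

For example, nvcc_include_flag("~/miniconda3/envs/jax") gives the flag to append to the nvcc command line if the header exists somewhere in that env.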
Closing since this bug looks stale. Feel free to reopen if there's some action for us to take!