Jax: Installing from source using Conda and CUDA could be improved

Created on 1 Feb 2019 · 22 comments · Source: google/jax

Thanks to all contributors for their efforts in creating and open sourcing the library.
I would like to add my two cents on the installation process when building from source, for whatever that's worth.

I always like to install things in conda envs so that there are no clashes between different software versions or dependency libraries.

MWE:

conda create -n jax python scipy cudnn cudatoolkit
conda list

(screenshot of the resulting conda list output)

Now the installation process:

python build/build.py --enable_cuda --cuda_path  ~/miniconda3/envs/jax/lib/ --cudnn_path ~/miniconda3/envs/jax/include

Two problems arise:

1. nvcc cannot be found under ~/miniconda3/envs/jax/lib/bin.
Actually the path is wrong; it should have been ~/miniconda3/envs/jax/bin.
Anyway, I copied nvcc from the system-wide installation at /opt/cuda/bin/nvcc into ~/miniconda3/envs/jax/lib/bin.
So far so good.
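In shell terms, what I did was roughly the following (a rough sketch of the copy described above; clearly a hack, since it mixes a system CUDA binary into the conda env):

# make build.py's expected <cuda_path>/bin/nvcc exist inside the env
mkdir -p ~/miniconda3/envs/jax/lib/bin
cp /opt/cuda/bin/nvcc ~/miniconda3/envs/jax/lib/bin/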

2. Re-running the build, it complains about cuda.h:
Cuda Configuration Error: Cannot find cuda.h under ~/miniconda3/envs/jax/lib 
FAILED: Build did NOT complete successfully (4 packages loaded, 16 targets

OK, let's copy /opt/cuda/include/cuda.h into ~/miniconda3/envs/jax/lib.
Re-running the build after completely removing the Bazel cache (rm -rf ~/.cache/bazel)
gives the same error again about not being able to find cuda.h.
At this point I am out of ideas.
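For reference, the exact failing sequence (same commands and paths as above) was:

cp /opt/cuda/include/cuda.h ~/miniconda3/envs/jax/lib/
rm -rf ~/.cache/bazel    # wipe Bazel's cache completely
python build/build.py --enable_cuda --cuda_path ~/miniconda3/envs/jax/lib/ --cudnn_path ~/miniconda3/envs/jax/include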

Anyone else having other ideas on how to resolve this?

build

Most helpful comment

@hawkinsp thanks a lot, especially for your patience.

Would providing prebuilt Conda packages solve your problem adequately?

I think that would be useful for a lot of people, not just me, and it would actually make jax more popular, since people could give it a try and report back on what worked and what didn't. That would facilitate better development of jax, IMHO.

All 22 comments

Looks like we need to do something to make the build.py script work better with conda. I think most of us have been installing CUDA, etc. manually rather than via conda. You might try pointing --cuda_path to a system-wide installation of CUDA in the meantime.

For now, the easiest workaround would be to install the prebuilt pip wheel of jaxlib. Another option would be not to enable GPU support, although that will be slower.

Note you can still use jax from source, even if you use a binary jaxlib. jax is all pure Python code, whereas jaxlib is the binary part (essentially just XLA and Python bindings around XLA.)
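Concretely, that might look something like this (just a sketch; the jaxlib wheel URL is the CUDA 10 / Python 3.7 one used later in this thread, so pick whichever matches your setup):

pip install --upgrade https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.6-cp37-none-linux_x86_64.whl
git clone https://github.com/google/jax.git
cd jax
pip install -e .    # installs the pure-Python jax package from source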

@hawkinsp thanks for the suggestion.

You might try pointing --cuda_path to a system-wide installation of CUDA in the meantime.

Although that's fine temporarily, I would prefer to avoid it, especially in my case with a rolling-release OS. The above solution only works until my next upgrade, when CUDA will be updated and I'll have to build again. That's why I prefer the conda solution: the environments are isolated snapshots that aren't affected by OS or library releases over time.

@hawkinsp

For now, the easiest workaround to do would be to install the prebuilt pip wheel of jaxlib

That doesn't work either:

pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda100/jaxlib-0.1.6-cp37-none-linux_x86_64.whl
pip install --upgrade -q jax
conda list
absl-py                   0.7.0                    pypi_0    pypi
blas                      1.0                         mkl  
ca-certificates           2019.1.23                     0  
certifi                   2018.11.29               py37_0  
cudatoolkit               10.0.130                      0  
cudnn                     7.3.1                cuda10.0_0  
intel-openmp              2019.1                      144  
jax                       0.1.18                   pypi_0    pypi
jaxlib                    0.1.6                    pypi_0    pypi
libedit                   3.1.20181209         hc058e9b_0  
libffi                    3.2.1                hd88cf55_4  
libgcc-ng                 8.2.0                hdf63c60_1  
libgfortran-ng            7.3.0                hdf63c60_0  
libstdcxx-ng              8.2.0                hdf63c60_1  
mkl                       2019.1                      144  
mkl_fft                   1.0.10           py37ha843d7b_0  
mkl_random                1.0.2            py37hd81dba3_0  
ncurses                   6.1                  he6710b0_1  
numpy                     1.15.4           py37h7e9f1db_0  
numpy-base                1.15.4           py37hde5b4d6_0  
openssl                   1.1.1a               h7b6447c_0  
opt-einsum                2.3.2                    pypi_0    pypi
pip                       19.0.1                   py37_0  
protobuf                  3.6.1                    pypi_0    pypi
python                    3.7.2                h0371630_0  
readline                  7.0                  h7b6447c_5  
scipy                     1.2.0            py37h7c811a0_0  
setuptools                40.7.3                   py37_0  
six                       1.12.0                   pypi_0    pypi
sqlite                    3.26.0               h7b6447c_0  
tk                        8.6.8                hbc83047_0  
wheel                     0.32.3                   py37_0  
xz                        5.2.4                h14c3975_4  
zlib                      1.2.11               h7b6447c_3 

Testing jax:

import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
wkey, bkey = random.split(random.PRNGKey(0))

Error:

2019-02-08 14:40:41.685537: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:142] Unable to find libdevice dir. Using '.'
2019-02-08 14:40:41.694622: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:846] Failed to compile ptx to cubin.  Will attempt to let GPU driver compile the ptx. Not found: /usr/local/cuda-10.0/bin/ptxas not found
>>> 

It seems to me like the CUDA paths are hard-coded (Not found: /usr/local/cuda-10.0/bin/ptxas). The same issue has also been raised for tensorflow; are you by any chance reusing some of the tensorflow code base? It looks like that's the case here.

This particular error comes from XLA (which is included as part of TF) so this is indeed a common issue across TF and JAX, most likely.

I think you should be able to work around it by setting the environment variable:
export XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/my/cuda/installation

Out of curiosity, where is your CUDA installation? One thing we could do is add more paths to the list of search paths XLA uses to find ptxas. Are there any CUDA... environment variables set that point to it?

@hawkinsp thanks for the tip

I think you should be able to work around it by setting the environment variable:
export XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/my/cuda/installation

This did the trick!

Out of curiosity, where is your CUDA installation?

---> /opt/cuda/

One thing we could do is add more paths to the list of search paths XLA uses to find ptxas

What would be very useful is to break out of those dependencies if possible. Conda lets you install cudatoolkit and cudnn, but those are lightweight packages providing things like cuda/include and cuda/lib so that any framework can use them; they don't include ptxas or nvcc. What strikes me as odd is that, since jax is still using tensorflow underneath (if my understanding is correct), how come it still requires ptxas and nvcc (especially for binaries such as wheels), while I can easily install tensorflow in any environment where I have pre-installed cudatoolkit and cudnn?

Are there any CUDA... environment variables set that point to it?

I am afraid not, and that is by choice, since I like to keep things separated in their own environments. This is especially helpful because I have a rolling-release distro: whenever my distro's package manager updates CUDA, I would have to recompile any software that relies on /opt/cuda, but if I keep everything in its own conda environment then I don't have to recompile anything and can keep working with cuda.x even if it is outdated!
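One way to keep that XLA_FLAGS setting scoped to the environment rather than the shell would be a conda activation hook, something like this sketch (using the /opt/cuda path from above):

mkdir -p ~/miniconda3/envs/jax/etc/conda/activate.d
echo 'export XLA_FLAGS=--xla_gpu_cuda_data_dir=/opt/cuda' > ~/miniconda3/envs/jax/etc/conda/activate.d/xla_flags.sh
# the variable is then set automatically on every conda activate jax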

I looked into this a bit more.

Unfortunately I don't think we can use Conda's cudatoolkit without doing a significant amount of work, because it's incomplete and missing some standard parts of the CUDA toolkit. As the comment (https://github.com/google/jax/issues/302#issue-405702005) above says, it lacks nvcc, and it's going to be difficult to build XLA without it.

(Aside: nvcc isn't strictly necessary to build XLA at the moment, which is the only part of TensorFlow that JAX needs, but it would be a bunch of work to teach the common TensorFlow build how not to look for nvcc for the XLA-only case. If someone wants to teach TF's build scripts that CUDA without nvcc is a thing, I'd happily review the PR! And it's not out of the question that someone will want to add nvcc-built code to XLA in the future, at which point we'd have to undo all this work.)

I note also that, for example, the PyTorch build also requires the user to install CUDA system-wide, not from cudatoolkit (https://pytorch.org/get-started/locally/#linux-from-source). So this isn't really a problem specific to JAX.

I think for now, the best I can do is recommend something like this:

Step 1

Install CUDA somewhere (e.g. use NVidia's runfile-based installer). You can put it anywhere you like, system-wide or in your home directory. I installed it in /usr/local/cuda-10.0.

Step 2

conda create -n jax python scipy cudnn
conda list
git clone https://github.com/google/jax.git
python build/build.py --enable_cuda --cuda_path  /usr/local/cuda-10.0 --cudnn_path ~/anaconda3/envs/jax

I hope that helps? I'm sorry I don't have a better suggestion; what gets redistributed in cudatoolkit isn't up to me :-(
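(For completeness, a quick smoke test once jaxlib is built and jax itself is installed from the source checkout, reusing the snippet from earlier in this thread:)

python -c "from jax import random; print(random.split(random.PRNGKey(0)))"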

@hawkinsp Thanks once again for the comprehensive reply.

note also that, for example, the PyTorch build also requires the user to install CUDA system-wide, not from cudatoolkit

I don't think this statement is correct though.
Here's an example:

nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2018 NVIDIA Corporation
Built on Sat_Aug_25_21:08:01_CDT_2018
Cuda compilation tools, release 10.0, V10.0.130

I have CUDA 10 installed system-wide; that doesn't prevent me from installing pytorch with CUDA 9 and using that:

conda list
blas                      1.0                         mkl  
ca-certificates           2019.1.23                     0  
certifi                   2018.11.29               py37_0  
cffi                      1.11.5           py37he75722e_1  
cudatoolkit               9.0                  h13b8566_0  
freetype                  2.9.1                h8a8886c_1  
intel-openmp              2019.1                      144  
jpeg                      9b                   h024ee3a_2  
libedit                   3.1.20181209         hc058e9b_0  
libffi                    3.2.1                hd88cf55_4  
libgcc-ng                 8.2.0                hdf63c60_1  
libgfortran-ng            7.3.0                hdf63c60_0  
libpng                    1.6.36               hbc83047_0  
libstdcxx-ng              8.2.0                hdf63c60_1  
libtiff                   4.0.10               h2733197_2  
mkl                       2019.1                      144  
mkl_fft                   1.0.10           py37ha843d7b_0  
mkl_random                1.0.2            py37hd81dba3_0  
ncurses                   6.1                  he6710b0_1  
ninja                     1.8.2            py37h6bb024c_1  
numpy                     1.15.4           py37h7e9f1db_0  
numpy-base                1.15.4           py37hde5b4d6_0  
olefile                   0.46                     py37_0  
openssl                   1.1.1a               h7b6447c_0  
pillow                    5.4.1            py37h34e0f95_0  
pip                       19.0.1                   py37_0  
pycparser                 2.19                     py37_0  
python                    3.7.2                h0371630_0  
pytorch                   1.0.1           py3.7_cuda9.0.176_cudnn7.4.2_2    pytorch
readline                  7.0                  h7b6447c_5  
scipy                     1.2.0            py37h7c811a0_0  
setuptools                40.7.3                   py37_0  
six                       1.12.0                   py37_0  
sqlite                    3.26.0               h7b6447c_0  
tk                        8.6.8                hbc83047_0  
torchvision               0.2.1                      py_2    pytorch
wheel                     0.32.3                   py37_0  
xz                        5.2.4                h14c3975_4  
zlib                      1.2.11               h7b6447c_3  
zstd                      1.3.7                h0b5b093_0  

Example:

import torch
torch.randn(3, 5)
tensor([[-0.8300,  0.8224, -0.4590,  0.3709,  0.0423],
        [ 0.5900, -2.2393,  0.0046, -0.2045, -0.4696],
        [ 0.2258, -0.8005, -1.2405, -0.6236,  1.9128]])
>>> 

The point is that you still need some system-wide install to get nvcc at build time. cudatoolkit isn't enough. (I guess that's what is happening for PyTorch.) TF doesn't really support mixing two different installs, I think. I'm sure PRs would be welcome.

I think it would be better to either:

  • remove the nvcc dependency from XLA, as outlined above, or
  • even better, provide prebuilt Conda packages, so you can then install just cudatoolkit to run JAX, even if that's not sufficient to build it.

(But note there's a minor hurdle to the latter, too. XLA needs two unusual things from the CUDA runtime: libdevice, which happily the cudatoolkit package does have, and ptxas, which it does not. It is possible to fall back to the NVidia kernel driver's copy of ptxas, and XLA will do just that, but we've found the driver-hosted copy is often buggy and deadlocks from time to time.)
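(A quick way to check this in a given env, sketched with the conda env path used earlier in this thread:)

find ~/miniconda3/envs/jax -name 'libdevice*'   # cudatoolkit does ship this
find ~/miniconda3/envs/jax -name ptxas          # this comes back empty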

Would providing prebuilt Conda packages solve your problem adequately?

@hawkinsp thanks a lot, especially for your patience.

Would providing prebuilt Conda packages solve your problem adequately?

I think that would be useful for a lot of people, not just me, and it would actually make jax more popular, since people could give it a try and report back on what worked and what didn't. That would facilitate better development of jax, IMHO.

I think you should be able to work around it by setting the environment variable:
export XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/my/cuda/installation

Thank you, this lets me install JAX with GPU support in my conda environment!

Would providing prebuilt Conda packages solve your problem adequately?

Yes please, especially if there are pip-related regressions like this one.

How to delete cuda and cudnn from Conda?

Since the conda-installed CUDA gives the error below, I want to remove the conda-installed cuda and cudnn and install standalone CUDA and cudnn from NVIDIA.

Error : Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.

However, the following commands do not remove them:
conda remove --name cuda --all
conda remove --name cudnn --all

I can see the two packages, cudatoolkit-10.0.130-0 and cudnn-7.3.1-cuda10.0.0_0, in the following paths:

/home/anaconda3/pkgs/cudatoolkit-10.0.130-0
/home/anaconda3/pkgs/cudnn-7.3.1-cuda10.0.0_0

How can I delete the cuda and cudnn embedded in Anaconda?

Thanks in advance,

Mike

That question might be better for a Conda mailing list or issue tracker; I don't think we know much about Conda.

(Quoting @hawkinsp's recommendation from above: Step 1, install CUDA somewhere, system-wide or in your home directory, e.g. /usr/local/cuda-10.0; Step 2, conda create -n jax python scipy cudnn, clone https://github.com/google/jax.git, and run python build/build.py --enable_cuda --cuda_path /usr/local/cuda-10.0 --cudnn_path ~/anaconda3/envs/jax.)

These steps broadly worked for me with a few tweaks:

  • needed to specify your system's cudnn version (conda doesn't work that out automatically)
  • needed to install future
  • needed to install tensorflow-gpu
  • needed to specify python binary path
  • needed to finish building from source with pip
  • needed to reboot

Altogether, try modifying step 2 as:

conda create -n jax python tensorflow-gpu scipy future cudnn=YOUR_VERSION
conda list
git clone https://github.com/google/jax.git
python build/build.py --enable_cuda --cuda_path /usr/local/cuda-x.x --cudnn_path ~/anaconda3/envs/jax --python_bin_path ~/anaconda3/envs/jax/bin/python
pip install -e build
pip install -e .

Then reboot

Did this work for anyone else? It took about an hour to build. Is this normal?

I tried this (with 2 simple additional steps, highlighted) but got the errors below

conda create -n jax python tensorflow-gpu scipy future cudnn=7 python=3.7
**conda activate jax**
git clone https://github.com/google/jax.git
**cd jax**
python build/build.py --enable_cuda --cuda_path /usr --cudnn_path ~/miniconda3/envs/jax --python_bin_path ~/miniconda3/envs/jax/bin/python

...

Bazel binary path: ./bazel-0.24.1-linux-x86_64
Python binary path: /home/murphyk/miniconda3/envs/jax/bin/python
MKL-DNN enabled: yes
-march=native: no
CUDA enabled: yes
CUDA toolkit path: /usr
CUDNN library path: /home/murphyk/miniconda3/envs/jax

Building XLA and installing it in the jaxlib source tree...
INFO: Build options --action_env and --python_path have changed, discarding analysis cache.
ERROR: /home/murphyk/jax/build/BUILD.bazel:21:1: error loading package 'jaxlib': in /home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/tensorflow/core/platform/default/build_config.bzl: Encountered error while reading extension file 'cuda/build_defs.bzl': no such package '@local_config_cuda//cuda': Traceback (most recent call last):
    File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 1266
        _create_local_cuda_repository(repository_ctx)
    File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 988, in _create_local_cuda_repository
        _get_cuda_config(repository_ctx)
    File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 714, in _get_cuda_config
        find_cuda_config(repository_ctx, ["cuda", "cudnn"])
    File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 694, in find_cuda_config
        auto_configure_fail(("Failed to run find_cuda_config...))
    File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 325, in auto_configure_fail
        fail(("\n%sCuda Configuration Error:%...)))

Cuda Configuration Error: Failed to run find_cuda_config.py: Inconsistent CUDA toolkit path: /usr vs /usr/lib
 and referenced by '//build:install_xla_in_source_tree'
ERROR: /home/murphyk/jax/build/BUILD.bazel:21:1: error loading package 'jaxlib': in /home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/tensorflow/core/platform/default/build_config.bzl: Encountered error while reading extension file 'cuda/build_defs.bzl': no such package '@local_config_cuda//cuda': Traceback (most recent call last):
    File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 1266
        _create_local_cuda_repository(repository_ctx)
    File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 988, in _create_local_cuda_repository
        _get_cuda_config(repository_ctx)
    File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 714, in _get_cuda_config
        find_cuda_config(repository_ctx, ["cuda", "cudnn"])
    File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 694, in find_cuda_config
        auto_configure_fail(("Failed to run find_cuda_config...))
    File "/home/murphyk/.cache/bazel/_bazel_murphyk/dd8a6ab338402747dc013d7665a15b3c/external/org_tensorflow/third_party/gpus/cuda_configure.bzl", line 325, in auto_configure_fail
        fail(("\n%sCuda Configuration Error:%...)))

However, maybe I should somehow use the locations below?

locate cuda | grep /cuda*
/usr/include/cuda.h
...

locate cudnn.h 
/usr/lib/x86_64-linux-gnu/libcudnn.so
/usr/include/cudnn.h
...

I want to upgrade tensorflow-gpu=1.9 to tensorflow-gpu=1.14 in a conda environment, but the problem is that when I try to upgrade, it automatically upgrades cuda and cudnn, which I don't want. Can anyone guide me on how to upgrade tensorflow-gpu without disturbing CUDA and cudnn? Thanks.

Not sure how you ended up on this issue tracker, but TensorFlow binaries are tied closely to specific CUDA and cuDNN versions. If you want to use e.g. TF 1.14 with CUDA 9, I believe you’ll need to build it from source.

I think you can close this issue :)
Sorry for the noise.

If you install cudatoolkit-dev from the conda-forge channel, nvcc is there. I can get jax to work from the binary wheel with the 10.1 version...
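For the install step itself, that would be something like the following sketch (package and channel as mentioned above):

conda activate jax
conda install -c conda-forge cudatoolkit-dev
which nvcc    # per the comment above, nvcc should now resolve inside the env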

If you install cudatoolkit-dev from the conda-forge channel, nvcc is there. I can get jax to work from the binary wheel with the 10.1 version...

Could you maybe tell us how you did it? (I'm guessing environment vars, but I can't get it to work without a system-wide CUDA installation...)

@rahuldave I was able to get nvcc from cudatoolkit-dev as you said, but did you run into the following at all where it cannot find the runtime file?
(screenshot of an nvcc error about being unable to find cuda_runtime.h)

Figured it out: I just had to include the directory containing cuda_runtime.h in my environment during the call to nvcc. I found that directory using locate.
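In case it helps anyone else, a rough sketch of that kind of workaround (the include path is hypothetical; use whatever locate reports, and note the comment above doesn't say exactly which environment variable was used; CPATH is one that GCC-based host compilation respects):

locate cuda_runtime.h
# e.g. if it turns up under the conda env's include directory:
export CPATH=$HOME/miniconda3/envs/jax/include${CPATH:+:$CPATH}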

Closing since this bug looks stale. Feel free to reopen if there's some action for us to take!

