Jax: conda-based installation

Created on 3 Jan 2019 · 53 Comments · Source: google/jax

Putting this here and tagging myself @ericmjl so that I can remember this exists.

To get jax into the hands of data scientists and machine learning researchers, conda installation would be very useful. I will take a stab at this on conda-forge, and record my progress here.

Labels: build, contributions welcome, enhancement

That's fantastic, thanks Eric!

To start off, I tried getting jax onto my own personal channel on anaconda.org. If I can do this successfully, usually I am able to get it onto conda-forge with no problems.

Commands executed:

$ conda skeleton pypi jax
$ cd jax
$ conda build .

Everything builds correctly up till the point where the import tests run. jax imports jaxlib, and jaxlib needs to be on conda-forge and specified as a dependency of jax in order for the jax build process to work properly.

Unfortunately, I don't see a tarball for jaxlib on PyPI. Perhaps that needs to go up first?
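A recipe's import test boils down to importing the package in a fresh interpreter, which is why a missing jaxlib only surfaces at that stage. The sketch below approximates that check with `subprocess`. The helper name `import_test` is hypothetical, and this is not conda-build's own implementation.

```python
import subprocess
import sys


def import_test(module_name: str) -> bool:
    """Return True if `module_name` imports cleanly in a fresh interpreter.

    Running the import in a subprocess means modules already loaded in this
    interpreter cannot mask a missing dependency.
    """
    result = subprocess.run(
        [sys.executable, "-c", f"import {module_name}"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        # The last stderr line usually names the missing module (e.g. jaxlib).
        lines = result.stderr.strip().splitlines()
        print(f"import {module_name} failed: {lines[-1] if lines else 'unknown error'}")
    return result.returncode == 0


if __name__ == "__main__":
    # If jax is installed without jaxlib, this reports the failure.
    print(import_test("jax"))
```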

The tarball URLs can be found here:
https://github.com/google/jax/blob/master/README.md#pip-installation

@alexbw I think those are the wheel URLs, not the tarballs. Does jaxlib have tarballs, or do they have to be built from source? If it's the latter, I might have to rope in some help from friends who are maintaining conda-forge.

@ocefpaf, one question for you - can we pull down Python wheels from a URL and use that pre-compiled wheel as part of a conda-forge-based recipe? Having had a night to sleep over this issue, it seems to me that building from source is going to be a painful thing to do on conda-forge, while having pre-compiled Python wheels installed into the correct location would be easier.

Even though we prefer building from source, we do "repackaging" in cases like that.

Here is an example: https://github.com/conda-forge/flask-restplus-feedstock/blob/d41ecd6077ba51df75cb15a2b06e737bdc43f8d6/recipe/meta.yaml

@ocefpaf thanks for the response! Another dumb question, hope you don't mind - the jaxlib wheels are for macOS and Linux only: https://pypi.org/project/jaxlib/#files

I plan to "repackage" only the Python 3 wheels. Is there a way for us to specify which repackaged wheel to be downloaded, based on OS? Or is this out of scope for conda-forge?

Thanks for driving this, Eric!

Just a question: how do we know that building will be painful on conda-forge? Our build process and build script are pretty simple if we can install and run bazel and meet the compiler toolchain requirements of TensorFlow. Since there are already conda packages for TF, we should be able to follow that setup, since the only thing JAX needs to compile from source is a sub-target inside TF. In other words, if we knew how to build TF on conda-forge, then we'd already know how to build what JAX needs, as it's a subset of TF.

@mattjj the bazel portion of installation takes quite a long while to run, which is what drives most of the "pain" from installation. If I am not mistaken, it may be a drain on free community resources to have to build jaxlib from scratch each time. @ocefpaf, do you have any input on this? For reference, it takes over 10 minutes on my home GPU tower to build from source.

> I plan to "repackage" only the Python 3 wheels. Is there a way for us to specify which repackaged wheel to be downloaded, based on OS? Or is this out of scope for conda-forge?

Yep. Just use the pre-processor selectors like in this example.
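conda-build selectors such as `# [linux]` and `# [osx]` switch a line of meta.yaml on or off per platform. The same selection logic, written as a Python sketch, maps the current platform to a wheel filename laid out as PEP 427 describes. The platform tags, the `none` ABI tag, and the helper `wheel_filename` are assumptions for illustration, not the actual jaxlib release filenames.

```python
import sys

# Hypothetical mapping from sys.platform to wheel platform tags; the real tags
# depend on how each jaxlib release was built, so treat these as placeholders.
PLATFORM_TAGS = {
    "linux": "manylinux2010_x86_64",
    "darwin": "macosx_10_9_x86_64",
}


def wheel_filename(version: str, python_tag: str, platform: str = sys.platform) -> str:
    """Build a PEP 427-style wheel filename for the given platform."""
    try:
        platform_tag = PLATFORM_TAGS[platform]
    except KeyError:
        # Mirrors the recipe skipping Windows: there is no wheel to repackage.
        raise ValueError(f"no jaxlib wheel for platform {platform!r}") from None
    # PEP 427 layout: {name}-{version}-{python tag}-{abi tag}-{platform tag}.whl
    return f"jaxlib-{version}-{python_tag}-none-{platform_tag}.whl"


if __name__ == "__main__":
    print(wheel_filename("0.1.37", "cp37", "darwin"))
```

Raising on unknown platforms plays the same role as a `skip: true  # [win]` line in a recipe: unsupported platforms fail loudly instead of downloading the wrong file.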

> @ocefpaf, do you have any input on this? For reference, it takes over 10 minutes on my home GPU tower to build from source.

At the moment we, conda-forge, cannot afford long builds (>1 hour), and we do not have GPU support yet. However, we are experimenting with Azure Pipelines to be able to do long builds, and I believe we may even get some GPUs. More on this soon...

This is very helpful, thank you @ocefpaf!

Wanted to ensure that there was a cross-reference. jaxlib conda-forge PR is here: https://github.com/conda-forge/staged-recipes/pull/7529

I mimicked the TensorFlow build recipe. Each time there's an update to jaxlib, the recipe (specifically build.sh and meta.yaml) has to be updated.

No builds happen for Windows, as jaxlib is currently unavailable there. To encourage Py3k adoption, I also intentionally did not include the Python 2 wheels in the build.

Looping back here about conda, guys. I tried submitting a PR for just jaxlib: https://github.com/conda-forge/staged-recipes/pull/7529

However, it appears that there is an issue with the macOS build, which is only resolvable by building from source.

I think things will be cleaner if jaxlib is separated from jax. @mattjj, I remember this was on your roadmap before - am I remembering correctly, or am I mistaken about this?

Yes, we had a plan to separate out jaxlib and call it xlapy, though it hasn't been a high priority compared to other work because there hasn't been a clear upside.

Is it possible to have separate conda packages for jax and jaxlib, without splitting the git repository? I ask mainly because we might not have time to dig into this for a while.

@mattjj I can try something, e.g. downloading a zip or tar archive of the whole jax repository and then building from source.

Could you guys put up a tagged release on GitHub? That's generally better received by the conda-forge admins, and it'll give me a so-called "point release reference" that I can point the build recipe against, rather than always building against an ever-evolving master :smile:.

Hah, looks like I was one step too fast for the conda-forge admins.

I just updated the conda-forge PR: https://github.com/conda-forge/staged-recipes/pull/7529

Looks like there's a fix for the error being encountered before.

> Could you guys put up a tagged release on GitHub?

Is it viable to instead just clone a specific git commit? We might be able to do tagged releases and stuff, but I'd like to minimize the number of potential blockers for you.

Any updates? Anything we can help with?

It looks like other folks are eager for Conda packages too (https://github.com/google/jax/issues/302)

Thanks!

@hawkinsp I'm still working on it! :smile: But yes, I've run into conda build issues at the moment. I stepped away from the conda-forge build for a little while (kwargs="work"), but you can track progress here: https://github.com/conda-forge/staged-recipes/pull/7529

That PR, btw, is just a "copy the wheels over" PR to enable jax on CPU to be distributed by conda-forge. I think it will be more difficult to get jax + GPU over, unless there's a build process at Google you guys could use to release jax+jaxlib wheels targeted for various CUDAs? (If so, then I could re-attempt things.)

Well, we build wheels for Python {2.7, 3.6, 3.7} x CUDA {9.0, 9.2, 10.0} already, which are the ones linked here:
https://github.com/google/jax#pip-installation

Our build is open source: it is done by this script, using a Docker container: https://github.com/google/jax/blob/master/build/build_jaxlib_wheels.sh

Ideally we could somehow build one wheel and distribute it for both Conda and Pypi, but if needs be we could also build separate Conda packages as well. We just need to know how to build one...
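The Python x CUDA combinations listed above multiply quickly, which is part of why a single GPU-enabled conda recipe is hard to maintain. A small sketch enumerating that matrix (the version lists are copied from the comment above; nothing else is assumed):

```python
from itertools import product

PYTHON_VERSIONS = ["2.7", "3.6", "3.7"]
CUDA_VERSIONS = ["9.0", "9.2", "10.0"]


def build_matrix():
    """Every (Python, CUDA) pair a GPU wheel has to be built for."""
    return list(product(PYTHON_VERSIONS, CUDA_VERSIONS))


if __name__ == "__main__":
    matrix = build_matrix()
    print(len(matrix))  # 3 Python versions x 3 CUDA versions = 9 wheels
    for py, cuda in matrix:
        print(f"python {py}, cuda {cuda}")
```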

> Well, we build wheels for Python {2.7, 3.6, 3.7} x CUDA {9.0, 9.2, 10.0} already, which are the ones linked here:
> https://github.com/google/jax#pip-installation

Oh!!! I'm sorry, I missed that. I was somehow focused on just the CPU versions. My bad.

> Ideally we could somehow build one wheel and distribute it for both Conda and Pypi, but if needs be we could also build separate Conda packages as well. We just need to know how to build one...

conda-forge builds are specified either entirely by a YAML file, or by a YAML file plus some other scripts. The latter is what I tried in that PR to get jaxlib into conda-forge.

@ocefpaf, maybe you could provide some guidance to the jax team on where the docs for conda-forge recipes live? I've been doing packages for as long as conda-forge has been around, so it's kind of in my head now, but I know I've stumbled quite a few times because of the evolving infrastructure.

The part where jax's distribution doesn't match my mental model of packaging is that I need to build two things (jax and jaxlib) for jax to work (i.e. jax depends on jaxlib), but jaxlib lives in the same repository as jax. Usually dependencies are other packages maintained by other people, so I just have to worry about my own package at hand. In other words, I've usually seen some separation between package X and the dependencies of X. Though maybe I'm not seasoned enough as a software person and have only encountered the simple cases ^_^.

It's not out of the question we could split jaxlib (the mostly C++ part) into a separate repository from jax (the pure Python part). But it would be helpful to know what the constraints are before we start moving things around.

Just to add some input: Conda (more specifically Anaconda) supplies its own build environment so that the software can run in most places with relative ease. This includes its own glibc, allowing software built properly on top of conda to run anywhere conda can be installed. When this environment isn't honored, it can cause issues on systems Anaconda otherwise works just fine on, CentOS 7 for example.
We are starting to get requests to install this software on our computational cluster. Sure, we could do that from source, but if it can be reliably installed from conda-forge, which already manages the build env and can bring in the CUDA libraries on demand for GPU-based code, it is much easier to ingest and maintain.

@alexbw @dougalm @mattjj @hawkinsp thanks to @ocefpaf, there is a conda package for jax and jaxlib on conda-forge now! He kindly worked on packaging what is currently distributed on PyPI while at the SciPy 2019 sprints.

It is currently Python 3.7/CPU only, because the recipe simply pulls down the py37 wheel (hard-coded). To the best of my knowledge, there's no "elegant" way of distributing the CUDA-enabled packages using the same recipe, given the current way jax and jaxlib are distributed on PyPI. If we could build jaxlib on conda-forge, that would greatly simplify the conda-based distribution story. I think @ocefpaf has more details than I do, as he knows the issues he ran into trying to build jaxlib on conda-forge while at the sprints.
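The hard-coded `py37` could in principle be derived from the Python version being built. A minimal sketch, assuming the build script can read the `PY_VER` environment variable that conda-build sets (for example `3.7`) and falling back to the running interpreter otherwise; the helper `python_tag` is hypothetical and not part of the actual recipe:

```python
import os
import sys


def python_tag(py_ver=None) -> str:
    """Return a CPython wheel tag such as 'cp37'.

    `py_ver` is a "major.minor" string like "3.7". When omitted, the PY_VER
    environment variable is used, falling back to the running interpreter.
    """
    if py_ver is None:
        py_ver = os.environ.get("PY_VER", f"{sys.version_info[0]}.{sys.version_info[1]}")
    major, minor = py_ver.split(".")[:2]
    return f"cp{major}{minor}"


if __name__ == "__main__":
    print(python_tag("3.7"))  # cp37
```

The resulting tag would then slot into the wheel filename instead of a literal `cp37`.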

Sorry if this is the wrong place to bring this up; I'm a fairly naive user just trying to see if some of the numpy fitting code I've written for my dissertation can be accelerated by putting in some jax in the right spots.

I have found that when installing from conda the import for jax currently fails with a ModuleNotFoundError: No module named 'fastcache'. I was able to rectify this by simply calling conda install fastcache in the correct env, which makes me think that the conda recipe may not have the correct dependencies.

I'm on Ubuntu 18.04, and I'm using the most recent conda: 4.7.12. I can replicate the issue by creating a new env, installing jax from the conda-forge channel, then popping open the interpreter and trying to import jax. When I exit, use conda to install fastcache, then re-enter the interpreter, I can run blurbs of example code from the main GitHub page with no issue.

I'm happy to make this a separate issue, but it seemed intimately related to the conversation here. Thanks for going to the trouble of trying to get this on conda.
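A quick way to spot an undeclared dependency like fastcache, without waiting for `import jax` to fail, is to ask the import system whether each module can be found. A minimal sketch using `importlib.util.find_spec`; the module list is illustrative and is not jax's full dependency set:

```python
import importlib.util


def missing_modules(names):
    """Return the subset of top-level module `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # fastcache is the module that was missing from the conda environment.
    print(missing_modules(["numpy", "fastcache", "jaxlib"]))
```

Only top-level names are passed here; for dotted names, `find_spec` imports the parent package first and raises if that parent is missing.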

@lgsmith I'm one of the recipe maintainers, and have just put in a PR to add fastcache.

Woohoo thanks @ocefpaf and @ericmjl ! Somehow I missed your comment back in July.

What are your thoughts on what to do with the conda recipe going forward? Should we do things on the JAX side to support it better?

> Should we do things on the JAX side to support it better?

We, conda-forge, are not building it from source yet and we should. I recall facing a few difficulties when I tried that. (Sorry, I don't remember them now :grimacing:)

When I send a PR to do that I'll ping you and let you know what they were.

Thanks @ericmjl!

To the jax people and the conda forge people, again thanks for putting this up there; those of us in scientific computing would find having higher performance versions available as binaries as managed by conda completely wonderful. Or even conda based builds, as is done with psi4. I for one have cuda on my system for other reasons, and am worried that I'll start having other problems if I change my cuda version so it matches a cudnn version that allows me to build the cuda enabled jax from source on my machine. Having all of that insulated in a conda env, even if I still have to run a compile script once it's all downloaded, would be tremendous.

Thanks all for the great work with the package. I've been trying to get jax installed on my system, but have failed to get jax to use the GPU (i.e., when I run a jax command, a warning comes up saying that it is falling back to the CPU). To be clear, I receive the above when I install jax through conda.

FWIW, building from source fails because of some idiosyncrasies in the CUDA installation on the machine I'm trying to install on. Specifically, it errors when reading the cudnn.h file since cudnn.h has major number 6 listed and the libcudnn.so on my system seems to be from version 7; it seems the build uses the major version listed in cudnn.h to look for libcudnn.so.X where X is the major version number (and my libcudnn.so ends with 7). Unfortunately the administrator of the system is away for a couple weeks and won't be able to address the underlying issue.
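The mismatch described above can be checked before starting a long build. A minimal sketch of the lookup as the commenter describes it (read `CUDNN_MAJOR` from the header, expect `libcudnn.so.<major>`). This is not the actual JAX/TensorFlow configure logic, and it assumes the header defines `CUDNN_MAJOR` directly, as cudnn.h did in the cuDNN 6/7 era:

```python
import re


def expected_cudnn_soname(header_text: str) -> str:
    """Return the libcudnn soname implied by cudnn.h, e.g. 'libcudnn.so.7'."""
    match = re.search(r"^#define\s+CUDNN_MAJOR\s+(\d+)", header_text, re.MULTILINE)
    if match is None:
        raise ValueError("CUDNN_MAJOR not found in header")
    return f"libcudnn.so.{match.group(1)}"


if __name__ == "__main__":
    header = "#define CUDNN_MAJOR 6\n#define CUDNN_MINOR 0\n"
    # A header saying 6 next to a libcudnn.so.7 reproduces the failure above.
    print(expected_cudnn_soname(header))  # libcudnn.so.6
```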

But is the GPU not being detected a known issue with the conda build or is this just on my system? Perhaps due to my bad CUDA install. I am able to run pytorch GPU-enabled code, seemingly fine, on my system. But I know pytorch uses the CUDA toolkit in conda and jax does not.

Any guidance would be really appreciated and I can provide more details if needed. Apologies if this is in the wrong place; I see issue (#302), but that appears to be specific to building from source. I'd prefer to skip building from source if possible.
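The fallback to CPU can also be confirmed from Python. A minimal sketch, assuming a JAX version that provides `jax.devices()`: it lists the platform of each device JAX sees. A CPU-only jaxlib, such as the conda-forge build discussed here, reports only `cpu`, and JAX may print the same falling-back warning while doing so.

```python
def jax_platforms():
    """Return the platform names of the devices JAX sees, or None if JAX is missing."""
    try:
        import jax
    except ImportError:
        # Covers both jax itself and a missing jaxlib dependency.
        return None
    # A CPU-only jaxlib yields ['cpu']; a working CUDA build lists GPU devices.
    return [device.platform for device in jax.devices()]


if __name__ == "__main__":
    print(jax_platforms())
```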

@jcreinhold at the moment, only CPU builds are supported on the conda-forge recipe. @ocefpaf was the one who really got it going; for me, I just update the build recipe as jax and jaxlib get updated on PyPI.

The GPU builds are tough to manage via conda. @ocefpaf has more insider knowledge than I do, as I don’t have the requisite knowledge to elaborate further. We would definitely welcome help getting the GPU builds working.

As for your install issues on a local machine from source, have you tried installing in an isolated env?

From a broader perspective, as far as I can tell, I feel a lot like the dog who knows not what he is doing whenever I (re-)install CUDA on my home GPU tower. If only NVIDIA could make installation of their drivers and CUDA much easier, that would solve so many problems that we have! :smile:

Thanks for the response @ericmjl and the background info. For posterity, I tried to install it in an isolated env, but no luck. I ran into an error when building from source due to having GCC > v6 (the GCC version on the system is v7). I tried to install an earlier version of GCC locally, but I couldn't get the install to recognize the local installation.

Does the GCC version number problem sound familiar? Or am I just messing something up?

Hmmm, the gcc version number problem doesn’t sound familiar at all. @hawkinsp @mattjj @dougalm do you guys have any insight here?

> The GPU builds are tough to manage via conda. @ocefpaf has more insider knowledge than I do, as I don't have the requisite knowledge to elaborate further. We would definitely welcome help getting the GPU builds working.

There has been some progress on the GPU builds. @jakirkham knows more about the details and/or if we are ready to try it with jax. (I'm deep into the day job these days and not a lot of time for conda-forge, sorry.)

Hello! Any updates on this? Can we install jax via conda?

@cossio if you do conda search -c conda-forge jax, you'll see what versions and platforms are available.

Keep in mind it's not available for all versions of Python and platforms, so YMMV.

@ericmjl Thanks. It's weird that conda search -c conda-forge jax returns a list of versions including a 0.1.57, even though the latest Jax release (according to https://github.com/google/jax/releases) is jax-v0.1.55.

That's confusing :confused:

Which one should I install?

@cossio the conda-forge builds come from PyPI, which is (usually) the authoritative source of releases. The tagged releases on GitHub don't have a good automated process in place (unless the jax maintainers have something), so I'd advise using PyPI as the authoritative source of versions.

If you search there, 0.1.57 is the latest, and conda-forge simply copies the right files over to make them conda-distributable.
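When comparing version listings from `conda search`, PyPI, and GitHub, compare versions numerically rather than as strings. A stdlib-only sketch that handles plain dotted integers; the version list is hypothetical, and full release strings (pre-releases, suffixes) need a proper parser such as `packaging.version.Version`:

```python
def version_key(version: str):
    """Turn '0.1.57' into (0, 1, 57) so versions compare numerically."""
    return tuple(int(part) for part in version.split("."))


versions = ["0.1.9", "0.1.55", "0.1.57", "0.1.10"]
# Plain string comparison puts '0.1.9' above '0.1.57'.
print(max(versions))                   # 0.1.9
print(max(versions, key=version_key))  # 0.1.57
```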

@ericmjl Thanks for the explanation!

Ran into a conda install issue on Linux today. I think this is from a very recent change (about 2 hours ago, based on the conda-forge site). Installing jax runs and brings in jaxlib, but jaxlib for Linux is still 0.1.37 and needs to be 0.1.38, according to the errors I got when I imported it into my conda interpreter.

I tried rolling back the jax installation in a clean env (miniconda base, channel priority set to conda-forge strict, to minimize contamination) by doing conda create -n jt jax=0.1.58. This 'worked' in the sense that I no longer got a version error when I tried to import jax, but I still couldn't import numpy from jax. See the following error:

>>> from jax import numpy
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ImportError: cannot import name 'numpy' from 'jax' (unknown location)

Things are working fine on my mac, which I believe is related to the mac version of the repository having jaxlib 0.1.38. Things are also working in a different conda env on the same machine that had installed jax earlier this week (Sunday, IIRC). Is there a quick fix for this?

EDIT: I found that by doing conda create -n jaxtest jax=0.1.58 jaxlib=0.1.37 I was able to get a working version of the two. It seems one would need to specify versions for both if one wants to roll back?
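The import-time error above amounts to a version comparison between the installed jaxlib and the minimum the installed jax expects, which is why pinning both packages together works. A hypothetical check of the same shape (the function `check_jaxlib` is illustrative, not jax's code; the 0.1.38 minimum is taken from this thread):

```python
def check_jaxlib(installed: str, minimum=(0, 1, 38)) -> None:
    """Raise if the installed jaxlib is older than the minimum jax expects.

    Assumes plain dotted-integer versions such as '0.1.37'.
    """
    installed_tuple = tuple(int(part) for part in installed.split(".")[:3])
    if installed_tuple < minimum:
        wanted = ".".join(map(str, minimum))
        raise RuntimeError(
            f"jaxlib {installed} is too old; this jax needs jaxlib >= {wanted}. "
            "Pin jax and jaxlib to a compatible pair."
        )


if __name__ == "__main__":
    check_jaxlib("0.1.38")  # passes silently
    try:
        check_jaxlib("0.1.37")
    except RuntimeError as err:
        print(err)
```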

jaxlib 0.1.38 was missing from conda-forge due to an unnoticed CI failure. I just restarted CI, and it should be up soon.

Thanks for the timely response.

Perhaps I don't understand what's going on here, but it is still not working, although the anaconda website has now bumped the linux jaxlib version number to 0.1.38. Will it take longer than the website takes to update to actually be able to install the updated package?

> Will it take longer than the website takes to update to actually be able to install the updated package?

Yes. The CDN takes a while to sync. Did you try it again? It should be ok now.

Ah, that's funny. I'd have figured the website would be restricted to match whatever you actually pull down with the command. Rerunning things, I got a download progress meter for the new jax, and things work when I pop open an interpreter. Thanks for checking into this.

Are there also builds of jaxlib for Python versions other than 3.7?

A conda search -c conda-forge jaxlib shows this:

# Name                       Version           Build  Channel             
jaxlib                        0.1.21  py37h5ca1d4c_0  conda-forge         
jaxlib                        0.1.22  py37h5ca1d4c_0  conda-forge         
jaxlib                        0.1.27  py37h5ca1d4c_0  conda-forge         
jaxlib                        0.1.28  py37h5ca1d4c_0  conda-forge         
jaxlib                        0.1.29  py37h5ca1d4c_0  conda-forge         
jaxlib                        0.1.31  py37h5ca1d4c_0  conda-forge         
jaxlib                        0.1.32  py37h5ca1d4c_0  conda-forge         
jaxlib                        0.1.33  py37h5ca1d4c_0  conda-forge         
jaxlib                        0.1.35  py37h5ca1d4c_0  conda-forge         
jaxlib                        0.1.36  py37h5ca1d4c_0  conda-forge         
jaxlib                        0.1.37  py37h5ca1d4c_0  conda-forge         
jaxlib                        0.1.38  py37h5ca1d4c_0  conda-forge         
jaxlib                        0.1.39  py37h5ca1d4c_0  conda-forge  
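Every build string in that listing starts with `py37`, which is how the output answers the question. A small sketch that tallies versions per Python tag from `conda search` output (the sample text is an excerpt of the listing above; the helper name `python_builds` is hypothetical):

```python
import re
from collections import defaultdict

SEARCH_OUTPUT = """\
# Name                       Version           Build  Channel
jaxlib                        0.1.37  py37h5ca1d4c_0  conda-forge
jaxlib                        0.1.38  py37h5ca1d4c_0  conda-forge
jaxlib                        0.1.39  py37h5ca1d4c_0  conda-forge
"""


def python_builds(listing: str):
    """Map each pyXY tag in a `conda search` listing to the versions built for it."""
    builds = defaultdict(list)
    for line in listing.splitlines():
        if not line.strip() or line.startswith("#"):
            continue  # skip blank lines and the header row
        _name, version, build, _channel = line.split()[:4]
        match = re.match(r"py(\d+)", build)
        if match:
            builds[f"py{match.group(1)}"].append(version)
    return dict(builds)


print(python_builds(SEARCH_OUTPUT))  # {'py37': ['0.1.37', '0.1.38', '0.1.39']}
```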

Not at this moment. Help with getting the recipe to work with py38 is very welcome, as the conda recipe is basically maintained by myself and @ocefpaf on volunteer time.

What's the issue with the recipe? Just changing python version in the recipe seemed to work. At least I got a simple example to run.

Please open issues about the conda-forge package in https://github.com/conda-forge/jaxlib-feedstock/issues for visibility.

Also, https://github.com/conda-forge/jaxlib-feedstock/pull/16 should fix it.

why is this still not officially supported?

@dangpzanco come volunteer, and you can help make it happen! (Whatever it is you want, which wasn’t clear from your post.)

@ericmjl Oh, I'm sorry it wasn't clear. I meant that JAX doesn't get the same care for conda support as, for example, PyTorch does. I praise your volunteering, but I hoped JAX's main team would make an effort to make it a bit more accessible (and not just leave the community to do all the work by itself). And I know it isn't in the scope of this issue, but the lack of Windows support is also a problem for many potential users (me included)...

> why is this still not officially supported?

Well, we have a small team, and so even though there's a lot of work worth doing, we have to choose what to prioritize.

I believe JAX works great with WSL, so if that works on your Windows setup you might want to give it a try.

@ericmjl I love your attitude! We've benefitted a huge amount from open-source contributions, and I hope we can get a lot more over time! We're all on the same team here, doing our best to push things forward.

WSL GPU ops are 2 to 3 times slower than on native Linux (TF2 and jax) or on Windows 10 (TF2), so it would still be nice to have Windows 10 support :P
