Jax: Add support for AMD GPUs

Created on 16 Jan 2020 · 35 comments · Source: google/jax

Is it possible to run JAX on other GPU architectures other than NVIDIA (ex.: Intel, AMD)?

contributions welcome


All 35 comments

In principle, sure! All we need is XLA to support that architecture.

In practice, that means we currently support CPU, NVidia GPU, and TPU.

Happily AMD has been contributing support for AMD GPUs to XLA. We haven't tried it out in JAX, but assuming the XLA support is complete, I see no good reason it wouldn't work with a few small JAX changes. If you are excited about AMD GPUs, we'd certainly welcome contributions enabling that functionality in JAX.

I don't think Intel GPUs have XLA support at the moment, but I wouldn't rule it out in the future as the various compiler toolchains (e.g., XLA, MLIR) progress.

The AMDGPU backend for XLA is being actively developed; these PRs probably have the most up-to-date status (it seems like many, but not all, tests pass?).

One thing to note is that the AMD integrations require that you rebuild XLA from source; there's no way to build a single TF or XLA binary that can use both NVIDIA CUDA and AMD ROCm.

For Intel hardware, I imagine we'd need something like MLIR translation from HLO dialect to nGraph dialect. I'm guessing nobody is actively working on that, but ccing @nmostafa in case Intel has plans in that area.

Glad to see this ROCm effort seems to be funded with full-time developers by AMD. Better late than never, I suppose. I hope they learned at least a little from their misadventures in GPGPU, with OpenCL being half-heartedly supported; in practice, if you wanted to get anything done, you had no choice but to go with the platform that didn't require you to, say, reinvent your FFT libraries from scratch. I hope this time around they realize there is some minimum investment in software they'd be smart to make if they want to offer a competitive ecosystem. It's crazy to see how much money NVIDIA has made off this; in the meanwhile Google has added a completely new, viable hardware and software alternative in the form of TPUs, and AMD is still working on getting compatibility with any of the software out there. It does not inspire much confidence, to be honest; it seems wise to bet against them ever getting out a robust, feature-complete alternative if they couldn't even get anything out four years ago. But I'd love to be wrong about this, and for there to be some genuine competition in desktop ML acceleration in the future.

Can someone help me figure out how to use JAX on AMD GPUs? Are there any code snippets we can start with?

Any update on the topic?
How can it be that TensorFlow supports AMD GPUs but JAX doesn't?
Isn't ROCm the CUDA of AMD GPUs, and aren't they drop-in replacements for each other?

There's no technical blocker to using JAX on AMD GPUs. We on the JAX team simply don't have access to any AMD GPUs at the moment to develop or test the necessary changes (which are probably not that large, given that most of the necessary work has been done in the context of TensorFlow).

Contributions are welcome!

The AMDGPU backend for XLA is being actively developed

That's good to know @jekbradbury, thanks

I just wanted to ask: when we are talking about AMD GPUs being supported, is it going to be on all platforms (i.e. including macOS), or are we talking Linux/Windows only?

I believe the AMDGPU backend support for XLA is based on ROCm, which doesn't support macOS.

I was able to build jax with initial support for ROCm (AMD GPUs) by compiling it using XLA from ROCmSoftwarePlatform/tensorflow-upstream (update: after https://github.com/tensorflow/tensorflow/pull/45344 you can use upstream TF) and adding a few options to the build scripts.

The code can be found here: inailuig/jax (update: after https://github.com/google/jax/pull/5114 you can use upstream jax)

Executing

import jax
print(jax.devices())
print(jax.devices()[0].device_kind)
x = jax.numpy.array([1.2, 3.4, 5.6])
y = jax.numpy.exp(x)
print(y)

on my RX480 outputs

[GpuDevice(id=0)]
Ellesmere [Radeon RX 470/480/570/570X/580/580X/590]
'+code-object-v3' is not a recognized feature for this target (ignoring feature)
'+code-object-v3' is not a recognized feature for this target (ignoring feature)
'+code-object-v3' is not a recognized feature for this target (ignoring feature)
[  3.3201168  29.964104  270.4264   ]
2020-11-22 20:40:04.841794: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-11-22 20:40:04.842168: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-11-22 20:40:04.842517: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-11-22 20:40:04.842866: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-11-22 20:40:04.844206: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work

which already looks very promising.
However, there are still things missing, such as the custom GPU kernels in jaxlib (cublas, cuda_prng, cusolver).

For those who want to build this:
I am running Ubuntu 20.04.1 with ROCm 3.9.0 installed using the official instructions.
It is also necessary to install these additional packages:
rocm-dev miopen-hip rocfft rocblas rccl hipsparse rocrand rocsolver hipblas
Then the whole thing can be built with
python3 build/build.py --enable_rocm --rocm_path /opt/rocm-3.9.0
Optionally, different amdgpu targets can be specified with --rocm_amdgpu_targets (see here). For now I put in some default targets; however, autodetection also works (by passing "" (an empty string), which overrides the default).
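To make that concrete, a hypothetical invocation with explicit targets might look like the following (the gfx target names and the comma-separated format are assumptions for illustration; pick the targets matching your GPU):

python3 build/build.py --enable_rocm --rocm_path /opt/rocm-3.9.0 --rocm_amdgpu_targets "gfx900,gfx906"
python3 build/build.py --enable_rocm --rocm_path /opt/rocm-3.9.0 --rocm_amdgpu_targets ""   # autodetect the local GPU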

@inailuig That's exciting progress! Nice work! (Sorry for the slow response, many of us were on vacation this last week.)

Technically speaking the cublas/cusolver and cuda_prng kernels are somewhat optional. The cuda_prng kernel is a compile-time optimization and can be safely omitted (at the cost of increased compile time), and cublas/cusolver are only needed for linear algebra support. So it might be possible to check things in even before those pieces work.
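As a rough illustration of that split (assuming the ROCm port mirrors the CUDA wiring, where linear-algebra ops lower to the jaxlib solver wrappers on GPU while elementwise ops are handled by XLA alone):

import jax.numpy as jnp

a = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])
print(jnp.exp(a))              # elementwise op: pure XLA, no custom kernels needed
print(jnp.linalg.cholesky(a))  # linear algebra: expected to need the cusolver/rocsolver wrappers on GPU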

I'm curious: is it possible to use upstream TF instead of the ROCm fork? We frequently update our TF (XLA) version, so any ROCm specific fork is likely to be stale.

@hawkinsp Turns out all that is missing in upstream TF is actually looking for devices with the right platform, i.e. some changes in tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc (from this commit: https://github.com/ROCmSoftwarePlatform/tensorflow-upstream/commit/0ba02369635a60dfbc28d5583e521999f519c9f1):

diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc
index 4863e5e8165..870007f1dca 100644
--- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc
+++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc
@@ -57,11 +57,19 @@ xla::StatusOr<xla::DeviceAssignment> GpuClient::GetDefaultDeviceAssignment(

 // Builds an xla::LocalClient for the GPU platform.
 StatusOr<LocalClient*> GetGpuXlaClient() {
+#if GOOGLE_CUDA
   TF_ASSIGN_OR_RETURN(se::Platform * platform,
                       PlatformUtil::GetPlatform("CUDA"));
   if (platform->VisibleDeviceCount() <= 0) {
     return FailedPrecondition("No visible NVidia GPU devices.");
   }
+#else
+  TF_ASSIGN_OR_RETURN(se::Platform * platform,
+                      PlatformUtil::GetPlatform("ROCm"));
+  if (platform->VisibleDeviceCount() <= 0) {
+    return FailedPrecondition("No visible AMD GPU devices.");
+  }
+#endif
   LocalClientOptions options;
   options.set_platform(platform);
   return ClientLibrary::GetOrCreateLocalClient(options);

Do you think we could get something like that upstreamed into TF?

I was also able to get cuda_prng and the cublas/cusolver kernels running (2 or 3 of the LAPACK functions (cusolver) are not yet implemented in rocsolver, but everything else is there; this also requires a few more changes to TF; I will post more once I have cleaned it up a bit).

We certainly can upstream something like that. That file is really part of JAX so we can change it as we see fit. You can send PRs to TensorFlow and assign me; I can review.

@hawkinsp @inailuig

Thank you for trying out JAX on AMD GPUs. I am on the TF framework team at AMD and would like to get a better understanding of the TF changes that are required to get JAX working. We would be more than happy to help out.

I also had a question for you. Does JAX have unit tests that run on GPUs, and if so, can you point me to the directions for running them? I would like to get them running internally on our platform.

thanks again

deven

@deven-amd We'll need to wait for @inailuig to send out their remaining changes to get things to build.

Once those changes are checked in, the best way to do this is probably something like this:

git clone https://github.com/google/jax.git
git clone https://github.com/tensorflow/tensorflow.git /mydir/tensorflow
cd jax
python build/build.py --bazel_options=--override_repository=org_tensorflow=/mydir/tensorflow --enable_rocm
pip install dist/*.whl
pip install -e .
XLA_PYTHON_CLIENT_ALLOCATOR=platform pytest -n auto tests examples

This builds and installs jaxlib with TF (XLA) from head (rather than whatever version we have pinned in our WORKSPACE file). (You can also achieve this by editing the WORKSPACE file; see the comments in that file.)

The XLA_PYTHON_CLIENT_ALLOCATOR=platform setting avoids the BFC allocator, which preallocates GPU memory; this means we should be able to run tests in parallel using multiple processes (-n auto enables this).
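To iterate on a smaller subset first, something like this should also work (the test file here is just an example):

XLA_PYTHON_CLIENT_ALLOCATOR=platform pytest -n auto tests/lax_numpy_test.py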

I should note there are probably a few tests that fail at head on Nvidia GPUs also (https://github.com/google/jax/issues/5067).

Hi Peter,

Thanks for the quick response.

I will try out the directions you have provided, plus the docs, to get the JAX unit tests working on the ROCm platform. I expect to work on this next week and will ping you if I run into any issues. In case I do, would you rather I email you directly or file an issue on the JAX GitHub repo?

Thanks

deven

On Fri, Dec 4, 2020 at 12:10 PM Peter Hawkins notifications@github.com wrote:

See also
https://jax.readthedocs.io/en/latest/developer.html#running-the-tests


@deven-amd

If there's no reason otherwise, we like to do development in the open so the community can be involved. So I'd file issues/PRs or use Github discussions. You can ping me in any issues or PRs if you want to make sure I take a look!

@deven-amd Thanks for reaching out, would be great if you could in particular help with fixing the tests which are still failing.

I just opened https://github.com/google/jax/pull/5114 for the remaining build-related stuff in jax.
In general things seem to be working.

However, there are still some tests failing because of bugs (e.g. stuff related to conv, dot_general, triangular solve, ...).
Other features are simply not implemented yet for ROCm in XLA (e.g. TRSM for complex args).
For the latter we will have to identify and skip them.

Also there is this error message

 E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work

which keeps popping up when the program terminates. @deven-amd would you be able to look into this?


For the BLAS/LAPACK wrappers (i.e. jaxlib/cusolver.py and the related pybind modules, but for ROCm),
I mostly followed what @hawkinsp did for CUDA here (since it's just lots of glue code around roc/cu BLAS/Solver routines). This can be found in https://github.com/google/jax/pull/5115.

For this to work we still need a few changes in TF:

  1. custom_call_thunk needs to be enabled and built for ROCm: https://github.com/inailuig/tensorflow/commit/44d3a233c6971344d595aacbad1459e1822264cd
  2. In xla_client.py, "CUDA" is hardcoded when you try to register a custom call target.
    I suggest we fix this like so: https://github.com/inailuig/tensorflow/commit/be9602a7666eb05edc33faae7825f8401968e885
    (This still keeps CUDA as the default when you pass 'gpu', unfortunately.)
    Then we can register functions for ROCm like this:
    xla_client.register_custom_call_target(_name, _value, platform="ROCM")
    Everywhere else in jax we can keep 'gpu'.
  3. We need to add rocSolver targets to the build scripts somewhere (I think we should add this to TF, although I guess it would also be possible to add them just to jax).
    For my attempt at this see: https://github.com/inailuig/tensorflow/commit/606d7933b39f4115f8aea61e25bceb906855b5bf
  4. Not strictly necessary but nice to have: rocm_library, see https://github.com/inailuig/tensorflow/commit/e08f34ca8fe49056407eeaa706556af891d6857d

All of this can be found in https://github.com/inailuig/tensorflow/tree/jax-rocm-gpukernels (there are 2 more commits which are useful for debugging, but not necessary)

@hawkinsp How should we proceed?

  1. Seems fine: I'd send that as a PR.
  2. Also looks fine to me. I might be tempted to change "gpu" to mean "register both CUDA and ROCM", which we could do by making xla_platform_names a dictionary whose values are a list of names and then registering all of them (see the sketch after this list).
  3. Seems plausible, and adding it to TF is probably the better place (that way, TF can share the build rules). I'm a bit surprised that TF doesn't have ROCSolver hooked up already.
  4. Also seems reasonable to me, but I'm not as sure about this.
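A minimal sketch of the idea in 2., in the context of xla_client.py where _xla is the pybind11 extension module (the dictionary contents and the 'Host' platform name are assumptions used only to illustrate the shape of the change):

# Hypothetical sketch: map the user-facing platform alias to one or more
# XLA platform names, then register the custom call target under each of them.
xla_platform_names = {
    'cpu': ['Host'],
    'gpu': ['CUDA', 'ROCM'],  # 'gpu' fans out to both GPU backends
}

def register_custom_call_target(name, fn, platform='cpu'):
  # Unknown platform strings are treated as literal XLA platform names.
  for xla_platform_name in xla_platform_names.get(platform, [platform]):
    _xla.register_custom_call_target(name, fn, xla_platform_name)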

Retitling this bug to focus on AMD GPUs only; we can open new bugs for other hardware vendors if needed.

@inailuig

@deven-amd would you be able to look into this?

How do I go about reproducing the error on my end?

Following the directions provided by @hawkinsp, I am able to build jax and run the tests on the CPU platform.
The next step is to reproduce the behaviour you see on ROCm. If I understand correctly, this requires changes to both TF and JAX.

I am building with

  • org_tensorflow pointing to https://github.com/inailuig/tensorflow/tree/jax-rocm-gpukernels AND
  • using the JAX source from https://github.com/inailuig/jax/tree/rocm
    and then running the tests... but all tests come back PASS
...
======================================================= 10060 passed, 1070 skipped in 249.80s (0:04:09) ========================================================
rocm-user@prj47-rack-15:~/jax$ git status
HEAD detached at inailuig/rocm
Untracked files:
...

Seems like I am not doing something right.

As for the changes on the TF side (1 through 4) in your post:
1, 3 & 4: I am in the process of making those changes in the ROCm fork of TF. Let me know if you plan on creating PRs to get those changes into the TF repo, or if you want me to push them out once they are in the ROCm fork (i.e. ROCm fork ---> upstream TF).
2: will changing the mapping from 'gpu': 'CUDA' to 'gpu': 'GPU' also achieve the same effect?

@deven-amd I suspect your gpu is not detected for some reason, so the tests run on the cpu. (did you compile with --enable_rocm?)

See my initial example from above: https://github.com/google/jax/issues/2012#issuecomment-731843904
for how to check if the gpu is detected and used. Also you could try to reproduce the error with it.
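For quick reference, that check boils down to (expected output based on the runs shown earlier in this thread):

import jax
print(jax.devices())                 # expect [GpuDevice(id=0)] rather than [CpuDevice(id=0)]
print(jax.devices()[0].device_kind)  # e.g. 'Ellesmere [...]' or 'Vega 20' on ROCm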

If you want to test my rocblas/cublas wrappers as well you should use https://github.com/inailuig/tensorflow/tree/jax-rocm-gpukernels and https://github.com/inailuig/jax/tree/rocm-gpukernels

Otherwise you can use upstream TF and upstream jax (everything necessary is merged now; edit: except rng, which won't work unless you remove cuda_prng.py).

this is what I get with the simple example

rocm-user@prj47-rack-15:/common/JAX$ python3 simple.py 
2020-12-07 22:25:07.839411: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_driver.cc:982] could not retrieve ROCM device count: HIP_ERROR_NoDevice
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
cpu
[  3.320117  29.964104 270.4264  ]

let me debug this further

@inailuig

I am now able to reproduce the error you get with the simple testcase

rocm-user@rocm-framework-14:/common/JAX$ python3 simple.py 
[GpuDevice(id=0)]
Vega 20
'+code-object-v3' is not a recognized feature for this target (ignoring feature)
'+code-object-v3' is not a recognized feature for this target (ignoring feature)
'+code-object-v3' is not a recognized feature for this target (ignoring feature)
[  3.320117  29.964104 270.42636 ]
2020-12-08 13:12:36.643101: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-12-08 13:12:36.643758: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-12-08 13:12:36.644247: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-12-08 13:12:36.646026: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-12-08 13:12:36.650082: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-12-08 13:12:36.654172: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work

The '+code-object-v3' is not a recognized feature for this target (ignoring feature) warning is a known issue and will be gone soon (the fix is in the ROCm TF fork and will be upstreamed soon).

looking into the error messages now

@inailuig @hawkinsp

I have added all the TF-side changes to the ROCm TF fork:

https://github.com/ROCmSoftwarePlatform/tensorflow-upstream/pull/1198
https://github.com/ROCmSoftwarePlatform/tensorflow-upstream/pull/1201

I will file a PR soon to push those changes from ROCm fork --> upstream TF and cc you guys on it.
Until that PR is accepted, pointing org_tensorflow to the develop-upstream branch in the ROCm fork will help with testing for ROCm support.

The Deallocating stream with pending work error message can be ignored for now. It is being incorrectly issued in this case, and we are looking into isolating the cause and fixing it.

@deven-amd That's great! We do need these upstream changes though, because as I mentioned earlier, we often track upstream XLA changes pretty closely. So any fork will rapidly become stale.

Out of curiosity, are there docker containers for building with ROCm? We do our NVidia release builds inside a Docker container (https://cs.opensource.google/jax/jax/+/master:build/build_jaxlib_wheels.sh). No promises, but we might also be able to build AMD linux wheels together with our NVidia wheels, although we would have no way to test them.

The Deallocating stream with pending work error message can be ignored for now. It is being incorrectly issued in this case, and we are looking into isolating the cause and fixing it.

alright, thanks!

@deven-amd If you want to open a combined PR for everything, then that is also fine by me; we can have a discussion there, especially about 2. (Otherwise I also would not mind submitting 1 and/or 2 myself.)

@deven-amd I have another question:
currently some tests fail because of "rocBLAS does not currently support the TRSM operation for ... complex ..." (here: https://github.com/tensorflow/tensorflow/blob/6859f52a3fba6714b5360262f190c9649613ac5c/tensorflow/stream_executor/rocm/rocm_blas.cc#L2432, and also the next one).

Afaik rocBLAS does support ctrsm and ztrsm now, so would you be able to add and upstream them to TF?
For the meantime I have a workaround on the jax side: https://github.com/inailuig/jax/commit/6a9b49d3ee1045986a6100ee48017333ea927ee1

Out of curiosity, are there docker containers for building with ROCm?

@hawkinsp : docker pull rocm/rocm-terminal may work.
See https://github.com/RadeonOpenCompute/ROCm-docker

@hawkinsp , you can use containers from the following docker repo, as base containers for building with ROCm

https://hub.docker.com/r/rocm/dev-ubuntu-18.04/tags?page=1&ordering=last_updated

/cc @sunway513 for awareness

We do need these upstream changes though

@deven-amd If you want to open a combined PR for everything then that is also fine by me

just filed a combined PR - https://github.com/tensorflow/tensorflow/pull/45583.


btw, when I tried to build using the tip of TF, I am getting the following error

ERROR: /home/rocm-user/.cache/bazel/_bazel_root/15199cbabc3b1eef2a9e7002ce358bc9/external/org_tensorflow/tensorflow/compiler/xla/python/BUILD:162:1: C++ compilation of rule '@org_tensorflow//tensorflow/compiler/xla/python:py_client' failed (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command 
...
external/org_tensorflow/tensorflow/compiler/xla/python/py_buffer.cc:90:16: error: 'bit_cast' is not a member of 'absl'
   return absl::bit_cast<std::uintptr_t>(ptr);
                ^~~~~~~~
...

hoping that this error is transient, and will be resolved on the TF side.


Afaik rocBLAS does support ctrsm and ztrsm now, so would you be able to add and upstream them to TF?

will look into this next.

btw, when I tried to build using the tip of TF, I am getting the following error

ERROR: /home/rocm-user/.cache/bazel/_bazel_root/15199cbabc3b1eef2a9e7002ce358bc9/external/org_tensorflow/tensorflow/compiler/xla/python/BUILD:162:1: C++ compilation of rule '@org_tensorflow//tensorflow/compiler/xla/python:py_client' failed (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command 
...
external/org_tensorflow/tensorflow/compiler/xla/python/py_buffer.cc:90:16: error: 'bit_cast' is not a member of 'absl'
   return absl::bit_cast<std::uintptr_t>(ptr);
                ^~~~~~~~
...

hoping that this error is transient, and will be resolved on the TF side.

Also for me. Have to use the version pinned in the jax WORKSPACE for now.

@zhangqiaorjc has a fix coming for that error.

Just a quick status update:

Currently there are only ~60 tests left that fail: failed.txt examples.txt

and some tests that we should skip on ROCm for now: not_implemented.txt other.txt
