Jax: Support __cuda_array_interface__ on GPU

Created on 2 Aug 2019 · 10 comments · Source: google/jax

https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html

It would not be hard to make DeviceArray implement this interface on GPU.

It would be slightly harder to support wrapping a DeviceArray around an existing CUDA array, but not that hard.
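
For reference, the producer side of the protocol is small. A minimal sketch of what an implementation has to expose, per the Numba spec linked above (the class and attribute names here are illustrative, not JAX's actual internals):

```python
class MyDeviceArray:
    def __init__(self, ptr, shape, typestr):
        self._ptr = ptr          # device pointer, as a Python int
        self._shape = shape      # e.g. (4, 4)
        self._typestr = typestr  # e.g. "<f4" for little-endian float32

    @property
    def __cuda_array_interface__(self):
        return {
            "shape": self._shape,
            "typestr": self._typestr,
            "data": (self._ptr, False),  # (device pointer, read_only flag)
            "version": 2,
        }
```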

enhancement

All 10 comments

I think even just supporting one of the directions (i.e. making DeviceArray implement this interface on GPU) would already be a great addition.
I would be happy to help, but I am not sure where to find the pointer to the GPU memory, or what else to pay attention to.

TensorFlow now supports dlpack: https://github.com/VoVAllen/tf-dlpack/issues/3

PR #2133 added __cuda_array_interface__ export. You'll need a jaxlib built from GitHub head or you'll need to wait for us to make another jaxlib wheel release.
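
For example (untested sketch; assumes a CUDA-enabled jaxlib new enough to include that PR, and CuPy as the consumer):

```python
import jax.numpy as jnp
import cupy as cp

x = jnp.arange(8.0)            # lives on the GPU under a CUDA-enabled jaxlib
x.__cuda_array_interface__     # dict with 'shape', 'typestr', 'data', ...

# Any consumer of the protocol (e.g. CuPy or Numba) can now wrap the
# buffer without copying:
y = cp.asarray(x)
```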

Because of https://github.com/pytorch/pytorch/issues/32868 you can't directly import the resulting arrays to PyTorch. But because of https://github.com/cupy/cupy/issues/2616 you can "launder" the array via CuPy and into PyTorch if you want.
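
Concretely, the laundering route looks something like this (untested; assumes CUDA builds of jaxlib, CuPy, and PyTorch):

```python
import jax.numpy as jnp
import cupy as cp
import torch

x = jnp.ones((4, 4))
x_cp = cp.asarray(x)  # CuPy wraps JAX's buffer via __cuda_array_interface__
x_pt = torch.as_tensor(x_cp, device="cuda")  # PyTorch consumes CuPy's interface
```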

(Another option for interoperability is DLPack, which JAX supports at Github head, in both directions.)
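
For the DLPack route, a rough sketch of the round trip, assuming a JAX recent enough to ship jax.dlpack and CuPy as the peer library:

```python
import jax.dlpack
import jax.numpy as jnp
import cupy as cp

x = jnp.arange(4.0)
# to_dlpack hands ownership of the buffer to the capsule, so don't keep
# using x afterwards.
x_cp = cp.fromDlpack(jax.dlpack.to_dlpack(x))  # JAX -> CuPy
y = jax.dlpack.from_dlpack(x_cp.toDlpack())    # CuPy -> JAX
```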

Could this be reopened until import support is added as well?

I don't follow. We support both directions, I believe?

Edit: I am wrong, apparently we don't support imports.

> Edit: I am wrong, apparently we don't support imports.

Yeah, this need came up again recently (cc @leofang @quasiben).

Although note that DLPack imports should work, so that's an option if the exporter supports DLPack.

Thanks John! Yeah we just finished a GPU Hackathon, and a few of our teams evaluating JAX asked us why JAX can't work with other libraries like CuPy and PyTorch _bidirectionally_. It'd be very useful, say, to do autograd in JAX, postprocess in CuPy, then bring it back to JAX.

Also: I haven't tried this, but since CuPy supports both __cuda_array_interface__ and DLPack, you can most likely "launder" an array via CuPy into JAX:

  • export the array via __cuda_array_interface__ to CuPy.
  • export the array via DLPack from CuPy.
  • import the DLPack into JAX.

(Obviously this isn't ideal, but it might unblock you; a sketch of the three steps follows.)
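
Something like this (untested; assumes CUDA builds of PyTorch, CuPy, and jaxlib, with PyTorch standing in as the example source):

```python
import jax.dlpack
import cupy as cp
import torch

t = torch.ones(4, device="cuda")

x_cp = cp.asarray(t)           # 1. export via __cuda_array_interface__ to CuPy
capsule = x_cp.toDlpack()      # 2. export via DLPack from CuPy
x_jax = jax.dlpack.from_dlpack(capsule)  # 3. import the capsule into JAX
```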

Hi @hawkinsp I recently pm'd @apaszke on an occasion where this support was mentioned. It'd be nice if JAX could prioritize bidirectional support for the CUDA Array Interface (and update to the latest v3 protocol, in which the synchronization semantics are specified).

As you pointed out in a DLPack issue (https://github.com/dmlc/dlpack/issues/50), DLPack lacks support for complex numbers, and this is unlikely to be resolved in the foreseeable future. For array libraries that is simply not an acceptable workaround, and it is actually a blocker for several applications that I am aware of.

Thanks, and happy new year!
