JAX: convert JAX array to torch tensor

Created on 16 Dec 2019 · 7 comments · Source: google/jax

Is there an efficient way to load a JAX array into a torch tensor? A naive way of doing this would be

import numpy as np
import torch

np_array = np.asarray(jax_array)               # copies the JAX array to host memory
torch_ten = torch.from_numpy(np_array).cuda()  # then copies it back to the GPU

As far as I can see, this would be inefficient because the array is moved to the CPU and then back to the GPU again.

Just to be clear: I am not interested in any gradient information being preserved. Only the entries of the array need to be loaded.

Many thanks for this amazing framework. I love it!

enhancement


All 7 comments

See also #1100.

I'm also interested. I've been taking a look at the issue that @jekbradbury linked, and there are other helpful links there, for example https://github.com/VoVAllen/tf-dlpack, which implements a way to turn TensorFlow tensors into Torch ones using DLPack. What follows is something of a progress report.

I think the way to go is to implement the same thing, but as a function that takes a PyLocalBuffer from local_client.h. This object can be accessed from a JAX DeviceArray through the .device_buffer attribute, and it contains the information we need (see the sketch below).
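For concreteness, a minimal sketch of that access, assuming the .device_buffer attribute and .unsafe_buffer_pointer() method mentioned in this thread (exact names may differ across jaxlib versions):

import jax.numpy as jnp

x = jnp.ones((3, 3), dtype=jnp.float32)    # a DeviceArray
buf = x.device_buffer                      # the backing PyLocalBuffer
ptr = buf.unsafe_buffer_pointer()          # raw device pointer, as a Python int
print(hex(ptr), x.shape, x.dtype)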

The main trouble I'm having is with Bazel. I've been trying to compile something that #includes the local_client.h header above, but I haven't managed to: the target '@org_tensorflow//tensorflow/compiler/xla/python:local_client' is not visible from outside of TensorFlow.

My question for folks who know Bazel is: Do we need to fork TensorFlow to access '@org_tensorflow//tensorflow/compiler/xla/python:local_client' from the outside? Is there a way to force visibility?

My totally uninformed guess is that you'll have a not-great time accessing the local client APIs from outside the TF repository, and you might be best off either writing your __cuda_array_interface__/DLPack wrappers inside the XLA Python client or exposing the necessary C++ APIs through pybind and then using them from Python code.
CC @skye and @hawkinsp

I don't think there's a way to force visibility, and forking TF is also not an ideal solution here. This is likely gonna require a change to TF itself, but I can help you through that process. What do you need to access from the PyLocalBuffer? Implementing __cuda_array_interface__ seems like a good option IMO.

I've taken a second look at it; I was a bit hasty. Using the Python interface in xla.cc, it seems possible to implement __cuda_array_interface__ fully in Python. More relevant for the JAX->PyTorch use case, it is also possible to access all the information needed to create a DLArray. That information is almost the same as for __cuda_array_interface__, and it is (cribbed from PyTorch's DLConvertor):

  • Pointer to the data on the GPU (use .unsafe_buffer_pointer)
  • number of dimensions, dtype, shape (easy to access, even from the DeviceArray)
  • A "deleter" function, that deallocates the array (only needed for DLArray). Just pass the .delete Python method.
  • strides I can't find this anywhere. I deduce that all PyLocalBuffers have stride 1 in every direction, that is, no striding. This implies they return a copy of the underlying data every time they're transposed or sliced. This seems to be verified by poking around:
import jax.numpy as np
a = np.arange(25).reshape((5, 5)).astype(np.float32)
# If slicing produced a strided view, both expressions would point at the same buffer.
a[::2, ::2].device_buffer.unsafe_buffer_pointer() == a[::1, ::1].device_buffer.unsafe_buffer_pointer()
# False

Anyway, this makes it easy to implement __cuda_array_interface__ and to produce DLArrays.
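To make that concrete, here is a minimal sketch of building such a __cuda_array_interface__ dict for a GPU-backed DeviceArray. The helper name is hypothetical, and it assumes the .device_buffer / .unsafe_buffer_pointer() APIs discussed above:

import numpy as np

def cuda_array_interface(x):
    # Hypothetical helper: assemble the dict described by the
    # __cuda_array_interface__ spec from a GPU-backed DeviceArray.
    buf = x.device_buffer                              # backing PyLocalBuffer
    return {
        "shape": tuple(x.shape),
        "typestr": np.dtype(x.dtype).str,              # e.g. '<f4' for float32
        "data": (buf.unsafe_buffer_pointer(), False),  # (device pointer, read-only flag)
        "strides": None,                               # None means C-contiguous
        "version": 2,
    }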

To support the PyTorch->JAX use case, we would need to consume DLArrays. This would require creating a PyLocalBuffer from a device pointer and a shape that calls a "deleter" function when it is deleted. This seems a lot harder: it would require being able to access the PyLocalBuffer constructor. It may also require adding another attribute that keeps track of the deleter function, and modifying PyLocalBuffer::Delete to call it.

#2123 has an initial implementation of DLPack support, which allows JAX/PyTorch interoperability.

Unfortunately it's still a little buggy; we need another C++ change which should show up in jaxlib 0.1.39 whenever we release it.

I believe DLPack support is now fully working at GitHub head. You'll either need to build jaxlib from source at head or wait until we make another jaxlib wheel release (0.1.39) before you can use it.

I also implemented __cuda_array_interface__, although there's a PyTorch bug that means you can't directly import the result into PyTorch. See https://github.com/google/jax/issues/1100 .

Hope this helps!
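For anyone reading this later, a minimal round-trip sketch using the DLPack APIs as they appear in later jax and PyTorch releases (jax.dlpack and torch.utils.dlpack); module paths and signatures may not exactly match the jaxlib version discussed above:

import jax.dlpack
import jax.numpy as jnp
import torch.utils.dlpack

x = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)

# JAX -> PyTorch: export a DLPack capsule and import it without a host round trip.
t = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))

# PyTorch -> JAX: the reverse direction works the same way.
y = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))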
