Is there an efficient way to load a JAX array into a torch tensor? A naive way of doing this would be
import numpy as np
import torch
# jax_array is an existing JAX array living on the GPU
np_array = np.asarray(jax_array)               # copies device -> host
torch_ten = torch.from_numpy(np_array).cuda()  # copies host -> device again
As far as I can see, this would be inefficient because the array is moved to the CPU and then back onto the GPU again.
Just to be clear: I am not interested in any gradient information being preserved. Only the entries of the array need to be loaded.
Many thanks for this amazing framework. I love it!
See also #1100.
I'm also interested. I've been taking a look at the issue that @jekbradbury linked, and there are other helpful links there. For example, https://github.com/VoVAllen/tf-dlpack, which uses DLPack to turn TensorFlow tensors into Torch ones. What follows is something of a progress report.
I think the way to go is to implement the same thing, but as a function that takes in a PyLocalBuffer from local_client.h. This object can be accessed from a Jax DeviceArray through its .device_buffer attribute, and it contains the information we need.
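For concreteness, here's a minimal way to poke at that buffer from Python, using only the attributes mentioned in this thread (.device_buffer and .unsafe_buffer_pointer()); attribute names may differ between jaxlib versions, so treat this as a sketch:
import jax.numpy as np
x = np.ones((3, 3), dtype=np.float32)
buf = x.device_buffer               # the PyLocalBuffer backing the DeviceArray
ptr = buf.unsafe_buffer_pointer()   # raw device pointer, as a Python int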
The main trouble I'm having is Bazel. I've been trying to compile something that #includes the local_client.h header above, but I haven't managed to. The main problem is that the target '@org_tensorflow//tensorflow/compiler/xla/python:local_client' is not visible from outside of TensorFlow.
My question for folks who know Bazel is: Do we need to fork TensorFlow to access '@org_tensorflow//tensorflow/compiler/xla/python:local_client' from the outside? Is there a way to force visibility?
My totally uninformed guess is that you'll have a not-great time accessing the local client APIs from outside the TF repository, and you might be best off either writing your __cuda_array_interface__/DLPack wrappers inside the XLA Python client or exposing the necessary C++ APIs through pybind and then using them from Python code.
CC @skye and @hawkinsp
I don't think there's a way to force visibility, and forking TF is also not an ideal solution here. This is likely gonna require a change to TF itself, but I can help you through that process. What do you need to access from the PyLocalBuffer? Implementing __cuda_array_interface__ seems like a good option IMO.
I've taken a second look at it; I was a bit hasty. Using the Python interface in xla.cc, it seems possible to implement __cuda_array_interface__ fully in Python. More relevant for the Jax->Pytorch use case, it is possible to access all the information needed to create a DLArray. The information needed is almost the same as for __cuda_array_interface__, and it is (cribbed from Pytorch's DLConvertor):
- data pointer: accessible via .unsafe_buffer_pointer()
- dtype, shape: easy to access, even from the DeviceArray
- deleter: a function to call once the consumer is done with the DLArray. Just pass the .delete Python method.
- strides: I can't find this anywhere. I deduce that all PyLocalBuffers have stride 1 in every direction, that is, no striding. This implies they return a copy of the underlying data every time they're transposed or sliced. This seems to be verified by poking around:
import jax.numpy as np
a = np.arange(25).reshape((5, 5)).astype(np.float32)
a[::2, ::2].device_buffer.unsafe_buffer_pointer() == a[::1, ::1].device_buffer.unsafe_buffer_pointer()
# False
Anyways, this makes it easy to implement __cuda_array_interface__ and to produce DLArrays.
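To make that concrete, here is a rough sketch of what a __cuda_array_interface__ dict built purely from the Python side might look like. This is not the actual jaxlib implementation, just an illustration following the array-interface spec, and the helper name is made up:
import numpy as onp  # plain NumPy, only for the dtype -> typestr conversion
import jax.numpy as np

def make_cuda_array_interface(x):
    # x is a DeviceArray on a GPU; expose its buffer as read-only.
    buf = x.device_buffer
    return {
        "shape": tuple(x.shape),
        "typestr": onp.dtype(x.dtype).str,            # e.g. "<f4" for float32
        "data": (buf.unsafe_buffer_pointer(), True),  # (device pointer, read-only flag)
        "strides": None,                              # buffers appear to be contiguous
        "version": 2,
    }
A consumer that understands the protocol could then read the buffer in place, subject to the PyTorch bug mentioned later in this thread.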
To support the Pytorch->Jax use case, we would need to consume DLArrays. This would require creating a PyLocalBuffer from a device pointer and a shape, one that calls a "deleter" function when it is deleted. This seems a lot harder: it would require being able to access the PyLocalBuffer constructor. It may also require adding another attribute that keeps track of the deleter function and modifying PyLocalBuffer::Delete to call it.
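For reference, here is what the Python-level usage could look like for that direction, assuming the consumer ends up exposed as something like jax.dlpack.from_dlpack (which is how later jaxlib releases spell it) and using PyTorch's existing torch.utils.dlpack exporter:
import torch
import torch.utils.dlpack
import jax.dlpack

t = torch.arange(10, dtype=torch.float32, device="cuda")
capsule = torch.utils.dlpack.to_dlpack(t)   # wrap the torch tensor as a DLPack capsule
x = jax.dlpack.from_dlpack(capsule)         # hand the capsule (and its deleter) over to JAX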
Unfortunately it's still a little buggy; we need another C++ change which should show up in jaxlib 0.1.39 whenever we release it.
I believe DLPack support is now fully working at GitHub head. You'll either need to build jaxlib from source at head or wait until we make another jaxlib wheel release (0.1.39) before you can use it.
I also implemented __cuda_array_interface__, although there's a PyTorch bug that means you can't directly import the result into PyTorch. See https://github.com/google/jax/issues/1100 .
Hope this helps!
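For anyone landing here later, the Jax->Pytorch direction looks roughly like this once you have a jaxlib with DLPack support (0.1.39 or a source build from head); exact module paths may shift between releases:
import jax.numpy as np
import jax.dlpack
import torch.utils.dlpack

x = np.arange(25, dtype=np.float32).reshape((5, 5))
capsule = jax.dlpack.to_dlpack(x)            # export the DeviceArray without copying
t = torch.utils.dlpack.from_dlpack(capsule)  # import into PyTorch on the same device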