import jax.numpy as jnp
arr = jnp.arange(8)
print(type(arr))
arr_copy = arr.copy()
print(type(arr_copy))
Output:
<class 'jax.interpreters.xla.DeviceArray'>
<class 'numpy.ndarray'>
This is intentional, but maybe we should revise it? What do you think? (It's a pretty old convention and might be ripe for revision.)
One of the motivations is that because DeviceArrays are immutable there is little reason to copy them. (But now that buffer donation is supported, they are mutable in a certain sense, and so we might want to copy them without bouncing back to the host...)
@hawkinsp thoughts?
Even if DeviceArray is immutable, just from an API consistency perspective it would be nice if DeviceArray.copy() returned self. The general contract for copy() is that it returns an object of the same type.
It might also make sense to implement copy-on-write, if that's feasible. If somebody is using .copy() for defensive _duck typing_ purposes (which seems like the most likely use-case at the moment), then they probably aren't also using JAX's specialized buffer donation API.
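As a rough illustration of the "same type in, same type out" contract described above, a defensive copy can be sketched like this. The helper name `defensive_copy` is hypothetical and not part of JAX; the sketch only assumes that `jnp.array(x, copy=True)` builds a new JAX array instead of returning a host-side `numpy.ndarray`.

```python
import numpy as np
import jax.numpy as jnp


def defensive_copy(x):
    """Hypothetical helper: copy x while keeping its array family."""
    if isinstance(x, np.ndarray):
        # NumPy arrays are mutable, so a real copy is needed.
        return x.copy()
    # Assumption: for JAX arrays, jnp.array(..., copy=True) produces
    # a new JAX array rather than converting to numpy.ndarray.
    return jnp.array(x, copy=True)


arr = jnp.arange(8)
arr_copy = defensive_copy(arr)
print(type(arr))
print(type(arr_copy))
```

With a helper like this, duck-typed code does not need to care whether `.copy()` itself changes the array type.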
@shoyer you are right, as usual!
I dug in a little bit and, as I commented in our internal chat, I think this DeviceArray.copy behavior is ancient, from a time before we wanted to expose a jax.device_get or in general admit that users should think about DeviceArrays versus ndarrays. Any original motivation for this behavior has long been rendered irrelevant.
I'm going to try revising this behavior.
Related: would it be possible to expose array.flags.writeable to make it easy to write code that would work for both jax and numpy using duck typing?
Edit: actually, I am not so sure about this, as the semantics differ from numpy's anyway:
In [24]: a = np.arange(5).astype(np.float32)
In [25]: b = jnp.asarray(a)
In [26]: c = b
In [27]: b -= b.mean()
In [28]: b
Out[28]: DeviceArray([-2., -1., 0., 1., 2.], dtype=float32)
In [29]: c
Out[29]: DeviceArray([0., 1., 2., 3., 4.], dtype=float32)
In [30]: a
Out[30]: array([0., 1., 2., 3., 4.], dtype=float32)
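The session above can be contrasted with plain NumPy in a short sketch. The variable name `a_alias` is introduced here for illustration only. For a NumPy array, `-=` mutates the buffer in place, so an alias sees the change. A JAX array is immutable, so `-=` rebinds the name to a new array and the alias keeps the original values.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: in-place update, the alias observes the new values.
a = np.arange(5).astype(np.float32)
a_alias = a
a -= a.mean()
print(a_alias is a)

# JAX: `b -= ...` evaluates `b - ...` and rebinds `b`;
# `c` still refers to the original, unmodified array.
b = jnp.arange(5, dtype=jnp.float32)
c = b
b -= b.mean()
print(c is b)
print(b)
print(c)
```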
Also, it might be worth adding a dedicated method that explicitly returns a numpy array, e.g. array.numpy() as in TF and PyTorch. I am not sure whether it should make a copy when the buffer of the jax array already lives in host memory.
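For reference, here is a minimal sketch of two existing ways to get a numpy array explicitly, without any new `array.numpy()` method: `np.asarray` and `jax.device_get` (the latter is mentioned earlier in this thread). Whether either call copies when the buffer is already in host memory is not something this sketch settles.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.arange(5, dtype=jnp.float32)

# Conversion through NumPy's array-conversion protocol.
host_a = np.asarray(x)

# Explicit device-to-host transfer.
host_b = jax.device_get(x)

print(type(host_a))
print(type(host_b))
```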
> would it be possible to expose array.flags.writeable to make it easy to write code that would work both for jax and numpy using ducktyping?
This is a good idea. We'd have to think about what to do with some of the other flags, which aren't as relevant to JAX DeviceArrays.
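To make the duck-typing idea concrete, here is a small sketch of how calling code might use such a flag. The helpers `is_writeable` and `centered` are hypothetical. The sketch assumes that any array which does not expose `flags.writeable` should be treated as read-only, and it is only exercised with NumPy arrays here.

```python
import numpy as np


def is_writeable(x):
    """Hypothetical helper: True only if x reports flags.writeable as truthy."""
    flags = getattr(x, "flags", None)
    # Assumption: a missing `flags` or `writeable` attribute means read-only.
    return bool(getattr(flags, "writeable", False))


def centered(x):
    """Subtract the mean of a floating-point array.

    Updates x in place when it is writeable, otherwise returns a new array.
    """
    if is_writeable(x):
        x -= x.mean()
        return x
    return x - x.mean()


writable = np.arange(5, dtype=np.float32)
read_only = np.arange(5, dtype=np.float32)
read_only.setflags(write=False)

out_w = centered(writable)    # same object, modified in place
out_r = centered(read_only)   # new array, input left untouched
print(out_w is writable, out_r is read_only)
```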