hi all,
thanks for making such a cool library! I was wondering how I should go about saving jax arrays to file, since np.save isn't yet implemented. Apologies if I missed something!
Rowan
Thanks for the question!
JAX's device-backed ndarray class, DeviceArray, is effectively a subclass of numpy.ndarray. You can turn it into a regular ndarray using numpy.asarray, like this:
import numpy as onp
numpy_array = onp.asarray(device_array)
DeviceArrays try to turn themselves into ndarrays automatically when appropriate, so you can also use NumPy's save on them directly:
onp.save('array.npy', device_array)
We could also override __getstate__ so that they work with pickle. (We did that in an older version of JAX, but it looks like we lost that code at some point.)
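Until then, pickling the NumPy-converted array already works; here's a minimal sketch (the filename is just for illustration):

import pickle
import numpy as onp

# Convert the DeviceArray to a plain ndarray before pickling.
with open('array.pkl', 'wb') as f:
    pickle.dump(onp.asarray(device_array), f)

# Load it back later as a regular NumPy array.
with open('array.pkl', 'rb') as f:
    restored = pickle.load(f)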
I think we should probably bring numpy.save into jax.numpy, so that you don't have to import raw NumPy (as onp here) to get the same saving behavior... Let's leave this issue open until we either add jax.numpy.save or else figure out there's some reason not to.
In the meantime, does import numpy as onp and onp.save work for you?
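For completeness, the full round trip might look like this (the filename is arbitrary, and jax.device_put is only needed if you want the loaded values back on the device):

import numpy as onp
import jax

onp.save('array.npy', device_array)    # the DeviceArray converts automatically
loaded = onp.load('array.npy')         # comes back as a regular numpy.ndarray
on_device = jax.device_put(loaded)     # optionally move it back onto the device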
Thanks for the response! Sorry, I'm still learning about the library and didn't know that you could convert DeviceArrays into regular numpy arrays like that -- that helps a lot and solves my problem (and will hopefully help others if they search for it) :)
I'm having issues saving intermediate results of my computations. In some cases, I need/want to save intermediate tensors so I can inspect them later. Inside a jitted function, jax.numpy.save seems to behave the same as the original numpy.save, in the sense that I get the same error message saying I cannot convert the tensor:
"Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced
This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using jnp together with import jax.numpy as jnp rather than using np via import numpy as np. If this error arises on a line that involves array indexing, like x[idx], it may be that the array being indexed x is a raw numpy.ndarray while the indices idx are a JAX Tracer instance; in that case, you can instead write jax.device_put(x)[idx]."
I'm blocked by this: my code takes a lifetime to run without jitting, but when I jit it, I can't get any intermediate information out of my simulations... What would be a solution or an alternative here?
@jwnys save is a side effect, and jit only works on pure functions; that is, any side effect will only run once during tracing and will not be visible at runtime. Your best option would be to have this as an additional output.
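For example, something along these lines (the function body and names here are just placeholders for your simulation):

import jax
import jax.numpy as jnp
import numpy as onp

@jax.jit
def step(x):
    intermediate = jnp.sin(x) * 2.0           # the tensor you want to inspect later
    result = jnp.sum(intermediate)
    # Return the intermediate alongside the result instead of saving inside jit.
    return result, intermediate

result, intermediate = step(jnp.arange(4.0))
onp.save('intermediate.npy', intermediate)    # save it outside the jitted function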
If you use Haiku, you can leverage (or rather abuse) hk.set_state to avoid having to manually propagate intermediate values all the way to the output.
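Roughly like this (an untested sketch; the function and state names are made up, and the exact key layout of the returned state may differ):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    intermediate = jnp.tanh(x)
    # Stash the intermediate in Haiku state instead of threading it through the return value.
    hk.set_state('intermediate', intermediate)
    return jnp.sum(intermediate)

forward_t = hk.transform_with_state(forward)
params, state = forward_t.init(jax.random.PRNGKey(0), jnp.ones(3))
out, state = forward_t.apply(params, state, None, jnp.ones(3))
# The stashed value comes back in `state` and can then be saved with onp.save.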