How do I convert a NumPy array to a JAX array, and a JAX array back to a NumPy array?
Thanks for the question!
```python
import numpy as onp
import jax.numpy as jnp

numpy_array = onp.array([1, 2, 3])
jax_array = jnp.array(numpy_array)        # NumPy -> JAX
numpy_array_again = onp.array(jax_array)  # JAX -> NumPy
```
Does that answer your question?
I recommend using `asarray` (`jnp.asarray` and `onp.asarray`) instead of `array` to avoid unnecessary memory copies.
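A minimal sketch of the `asarray` variant in both directions. Assumptions: a recent JAX release; whether a copy is actually avoided depends on the backend (on CPU, `onp.asarray` may return a read-only view, while moving data to a GPU/TPU always involves a transfer). `jax.device_get` is shown as an alternative way to pull a JAX array back to host memory as NumPy; the variable names are illustrative only.

```python
import numpy as onp
import jax
import jax.numpy as jnp

# float32 so the dtype survives even with JAX's default 32-bit mode
numpy_array = onp.arange(6, dtype=onp.float32).reshape(2, 3)

# NumPy -> JAX: asarray does not request an extra copy beyond placing
# the data on JAX's default device
jax_array = jnp.asarray(numpy_array)

# JAX -> NumPy: asarray avoids a second copy where possible; the result
# may be read-only, so call onp.array(...) if you need to mutate it
numpy_again = onp.asarray(jax_array)

# Alternative: explicitly fetch the array from the device to host memory
numpy_via_device_get = jax.device_get(jax_array)

print(type(jax_array), type(numpy_again), type(numpy_via_device_get))
print(onp.array_equal(numpy_array, numpy_again))
```

The round trip preserves values and dtype here; with 64-bit NumPy inputs, JAX will downcast to 32-bit unless x64 mode is enabled.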
Thank you very much!^_^ @mattjj @shoyer