jax tensor and numpy array conversion

Created on 8 Jan 2020 · 3 Comments · Source: google/jax

How do I convert a numpy array to a jax tensor, or a jax tensor back to a numpy array?

documentation question


Thanks for the question!

import numpy as onp
import jax.numpy as jnp

numpy_array = onp.array([1, 2, 3])
jax_array = jnp.array(numpy_array)          # numpy -> jax (copies the data)
numpy_array_again = onp.array(jax_array)    # jax -> numpy (copies back to host memory)

Does that answer your question?
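A slightly fuller sketch of the same round trip, assuming a recent JAX release (0.4 or later, where jax.Array is the public array type). It also shows jax.device_put and jax.device_get, which move data to and from the default device explicitly. With JAX's default 32-bit mode, an int64 numpy input typically comes back as int32, so compare values rather than dtypes.

```python
import numpy as onp
import jax
import jax.numpy as jnp

numpy_array = onp.array([1, 2, 3])

# numpy -> jax: jnp.array copies the data into a JAX array on the default device.
jax_array = jnp.array(numpy_array)

# jax.device_put is an explicit alternative for moving host data to a device.
jax_array_2 = jax.device_put(numpy_array)

# jax -> numpy: onp.array copies the data back into a host-side ndarray;
# jax.device_get performs the device-to-host transfer explicitly.
numpy_array_again = onp.array(jax_array)
numpy_array_from_get = jax.device_get(jax_array_2)

print(type(jax_array), jax_array.dtype)
print(type(numpy_array_again), numpy_array_again.dtype)
```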


I recommend using asarray (jnp.asarray / onp.asarray) instead of array (jnp.array / onp.array) to avoid unnecessary memory copies.
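A minimal sketch of the asarray suggestion, assuming a recent JAX release (0.4 or later). jnp.array defaults to copy=True, while jnp.asarray only copies when it has to; whether a copy is actually avoided depends on the backend (data still has to move to a GPU, for example), so treat the savings as best-effort rather than guaranteed.

```python
import numpy as onp
import jax
import jax.numpy as jnp

numpy_array = onp.arange(6, dtype=onp.float32)

# numpy -> jax: jnp.asarray skips the forced copy that jnp.array makes by
# default, although a host-to-device transfer may still be needed.
jax_array = jnp.asarray(numpy_array)

# jax -> numpy: onp.asarray goes through the array's __array__ protocol;
# on CPU this may avoid a copy, depending on backend and JAX version.
numpy_view = onp.asarray(jax_array)

# onp.asarray on an ndarray that already has the right dtype returns it as-is.
same_object = onp.asarray(numpy_array) is numpy_array
print(type(jax_array), type(numpy_view), same_object)
```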

Thank you very much!^_^ @mattjj @shoyer
