JAX tensor and NumPy array conversion

Created on 8 Jan 2020 · 3 Comments · Source: google/jax

How do I convert a NumPy array to a JAX tensor, or a JAX tensor back to a NumPy array?

documentation question

All 3 comments

Thanks for the question!

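import numpy as onp  # "onp" = original NumPy, a common alias in early JAX code
import jax.numpy as jnp

numpy_array = onp.array([1, 2, 3])        # NumPy array in host memory
jax_array = jnp.array(numpy_array)        # NumPy -> JAX array on the default device
numpy_array_again = onp.array(jax_array)  # JAX -> NumPy, copied back to host memory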

Does that answer your question?
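A slightly fuller sketch of the same round trip, assuming a recent JAX release where JAX's "tensors" are jax.Array instances. jax.device_put and jax.device_get are shown as alternative spellings of the two directions; they are not part of the original answer.

```python
import numpy as onp
import jax
import jax.numpy as jnp

numpy_array = onp.array([1, 2, 3])

# NumPy -> JAX: both calls produce a jax.Array on the default device.
jax_array = jnp.array(numpy_array)
jax_array_put = jax.device_put(numpy_array)

# JAX -> NumPy: both calls bring the data back to host memory as an onp.ndarray.
numpy_array_again = onp.array(jax_array)
numpy_array_got = jax.device_get(jax_array_put)

print(type(jax_array).__name__, type(numpy_array_again).__name__)
print(numpy_array_again.tolist())
```

One thing to watch: unless 64-bit mode is enabled (the jax_enable_x64 config option), JAX stores integers as int32, so an int64 NumPy array comes back from the round trip as int32.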

I recommend using asarray instead of array (jnp.asarray and onp.asarray) to avoid unnecessary memory copies.
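A minimal sketch of the asarray variant. It assumes a float32 input so the dtype survives the round trip unchanged; whether a copy is actually avoided depends on the backend (e.g. CPU vs. GPU) and the JAX version, so treat the comments as "may" rather than a guarantee.

```python
import numpy as onp
import jax.numpy as jnp

numpy_array = onp.arange(6, dtype=onp.float32)

# NumPy -> JAX. Moving data onto a JAX device may still involve a transfer;
# asarray simply does not request an extra copy on top of that.
jax_array = jnp.asarray(numpy_array)

# JAX -> NumPy. onp.asarray does not force a copy the way onp.array does by
# default, so on CPU the result may share memory with the JAX buffer.
host_array = onp.asarray(jax_array)

print(host_array.dtype, host_array.tolist())
```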

Thank you very much!^_^ @mattjj @shoyer
