>>> import numpy as np
>>> import jax.numpy as jnp
>>> np.tile(np.array([0, 1, 2]), (1, 1, 2))
array([[[0, 1, 2, 0, 1, 2]]])
>>> jnp.tile(jnp.array([0, 1, 2]), (1, 1, 2))
DeviceArray([[[0, 1, 2, 0, 1, 2]]], dtype=int64)
>>> np.tile(np.array([0, 1, 2]), (1, 0, 2))
array([], shape=(1, 0, 6), dtype=int64)
>>> jnp.tile(jnp.array([0, 1, 2]), (1, 0, 2))
...
ValueError: Need at least one array to concatenate
A local workaround is to wrap `jax.numpy.tile` like so:
if 0 in repeats:
return jnp.array([]).reshape(np.array(tensor_in.shape) * np.array(repeats))
return jnp.tile(tensor_in, repeats)
Wow that was fast @jakevdp! Thanks so much! :)
Most helpful comment
Wow that was fast @jakevdp! Thanks so much! :)