jax.numpy.tile inconsistent with np.tile (when 0 in repeats)

Created on 31 Jul 2020 · 1 Comment · Source: google/jax

>>> import numpy as np
>>> import jax.numpy as jnp
>>> np.tile(np.array([0, 1, 2]), (1, 1, 2))
array([[[0, 1, 2, 0, 1, 2]]])
>>> jnp.tile(jnp.array([0, 1, 2]), (1, 1, 2))
DeviceArray([[[0, 1, 2, 0, 1, 2]]], dtype=int64)
>>> np.tile(np.array([0, 1, 2]), (1, 0, 2))
array([], shape=(1, 0, 6), dtype=int64)
>>> jnp.tile(jnp.array([0, 1, 2]), (1, 0, 2))
...
ValueError: Need at least one array to concatenate

A local workaround is to wrap jax.numpy.tile like so:

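        if 0 in repeats:
            pad = len(repeats) - tensor_in.ndim  # align ranks as np.tile does
            shape = (1,) * pad + tuple(tensor_in.shape)
            reps = (1,) * -pad + tuple(repeats)
            return jnp.zeros(tuple(s * r for s, r in zip(shape, reps)), tensor_in.dtype)
        return jnp.tile(tensor_in, repeats)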
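
For reference, a self-contained version of this workaround might look like the sketch below. The name `tile_allow_zero` is hypothetical and not part of JAX. It assumes the goal is to reproduce the output shape np.tile gives when `repeats` contains 0, including the case where the array and `repeats` have different ranks, and it keeps the input's dtype rather than returning a default float array.

```python
import numpy as np
import jax.numpy as jnp


def tile_allow_zero(a, repeats):
    """Hypothetical wrapper: jnp.tile that also accepts 0 in `repeats`."""
    a = jnp.asarray(a)
    # np.tile accepts a scalar or a sequence for reps; normalise to a tuple.
    reps = tuple(int(r) for r in np.atleast_1d(repeats))
    if 0 not in reps:
        return jnp.tile(a, reps)
    # Left-pad the shorter of shape/reps with 1s so both have the same rank,
    # mirroring how np.tile aligns them.
    pad = len(reps) - a.ndim
    shape = (1,) * pad + a.shape
    reps = (1,) * -pad + reps
    out_shape = tuple(s * r for s, r in zip(shape, reps))
    # At least one dimension is 0, so the result has no elements.
    return jnp.zeros(out_shape, dtype=a.dtype)


x = np.array([0, 1, 2])
empty = tile_allow_zero(x, (1, 0, 2))
print(empty.shape)  # (1, 0, 6)
```

Building the empty result with `jnp.zeros` avoids the concatenation path that raises the ValueError above, and the left-padding step is what makes the shape come out as (1, 0, 6) for a 1-D input.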

Label: bug

All comments

Wow that was fast @jakevdp! Thanks so much! :)
