import jax.numpy as np
import numpy as onp
onp.array(0).dtype  # works: onp.array(0) returns a 0-dim array
np.array(0).dtype   # fails: np.array(0) returns a Python int, which has no .dtype
This came up when calling lax.dynamic_update_slice(array, value, np.array(0)).
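For reference, here is a minimal sketch of the kind of call this issue is about, assuming a current JAX release where `jnp.array(1)` returns a real 0-d `jax.Array` (the successor of `DeviceArray`). Current `lax.dynamic_update_slice` takes one scalar start index per dimension of the operand, so the 0-d index is passed inside a tuple. The variable names are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.zeros(5)
update = jnp.ones(2)

# One scalar start index per dimension of `operand`; here a 0-d array.
start = jnp.array(1)
assert start.shape == ()  # a genuine 0-d array, not a Python int

# Write `update` into `operand` starting at index 1.
result = lax.dynamic_update_slice(operand, update, (start,))
# result holds [0., 1., 1., 0., 0.]
```

If `jnp.array(1)` returned a plain Python int, as at the commit reported below, the `start.shape` check would raise an `AttributeError` just like `.dtype` does.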
In [1]: import numpy as onp
In [2]: import jax.numpy as np
In [3]: onp.array(0).dtype
Out[3]: dtype('int64')
In [4]: np.array(0).dtype
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-4-05f7ef02f43f> in <module>
----> 1 np.array(0).dtype
AttributeError: 'int' object has no attribute 'dtype'
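The behavior this issue asks for can be written down as a small check: `jnp.array(0)` should behave like `onp.array(0)`, i.e. be a 0-d array with a `dtype` and an empty shape. This is a sketch assuming a current JAX release; `check_scalar_array` is a hypothetical helper, not part of JAX or NumPy. The exact integer width differs (JAX uses int32 unless 64-bit mode is enabled, NumPy's default is platform-dependent), so only the integer kind is compared.

```python
import numpy as onp
import jax.numpy as jnp


def check_scalar_array(x):
    """Hypothetical helper: assert that `x` is a 0-d array, not a Python scalar."""
    assert hasattr(x, "dtype"), f"{type(x).__name__} has no dtype"
    assert x.shape == ()
    assert x.ndim == 0


ref = onp.array(0)
out = jnp.array(0)

check_scalar_array(ref)
check_scalar_array(out)

# Both dtypes are integer kinds; the bit width may differ between libraries.
assert onp.issubdtype(ref.dtype, onp.integer)
assert onp.issubdtype(out.dtype, onp.integer)
print(ref.dtype, out.dtype)
```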
I'm on commit 6476d5ffc6bee0e8d9d7b54914565d798b8dda29.
I am not sure whether this is a regression, but we noticed it recently as well, when we found that scan does not work well with scalar arrays (https://github.com/pyro-ppl/numpyro/pull/91). The reason is that, unlike NumPy, jax.numpy.array(1.) returns a Python float instead of a DeviceArray.
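As an illustration of the scan case, the sketch below uses a 0-d array as the `lax.scan` carry to compute a running sum. It assumes a current JAX release in which `jnp.array(0.0, dtype=jnp.float32)` is a genuine 0-d array; the function `step` and the inputs are hypothetical examples, not code from the linked PR. Explicit `float32` dtypes keep the carry's type identical on every iteration, which `lax.scan` requires.

```python
import jax.numpy as jnp
from jax import lax


def step(carry, x):
    # carry is a 0-d float32 array; carry + x keeps the same shape and dtype.
    new = carry + x
    return new, new  # (next carry, per-step output)


init = jnp.array(0.0, dtype=jnp.float32)  # 0-d array used as the scan carry
xs = jnp.arange(3, dtype=jnp.float32)     # [0., 1., 2.]

final, running = lax.scan(step, init, xs)
# final == 3.0 (a 0-d array); running == [0., 1., 3.]
```

With a Python float as the initial carry instead of a 0-d array, code that inspects `init.shape` or `init.dtype` before calling `scan` would fail in the same way as the `.dtype` access above.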
cc @mattjj in case this needs to be reopened.
@kroq-gar78 that PR was closed but not merged, so this bug still isn't fixed and remains open.
@neerajprad we're going to check in a from-scratch rewrite of scan in the next week or two, though that new version will need testing and could have similar issues at first. Thanks for pointing that out!