Jax: Scalars passed into np.array should return 0-dim arrays

Created on 16 Dec 2018 · 3 Comments · Source: google/jax

import jax.numpy as np
import numpy as onp

onp.array(0).dtype # works, because onp.array(0) returns a 0-dim array
np.array(0).dtype # doesn't work, because np.array(0) returns an int

This came up when calling lax.dynamic_update_slice(array, value, np.array(0)).
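The call above relies on the start index behaving like an array. The sketch below is a rough NumPy emulation of XLA's DynamicUpdateSlice semantics, which lax.dynamic_update_slice exposes. It is an illustration, not JAX code, and `dynamic_update_slice_ref` is a hypothetical name. It assumes one integer scalar start index per operand dimension and clamps each index so that the update fits inside the operand. Each start index is passed through onp.asarray, so a Python int and a 0-dim array are handled the same way before their .shape and .dtype are read.

```python
import numpy as onp


def dynamic_update_slice_ref(operand, update, start_indices):
    """Hypothetical NumPy emulation of XLA DynamicUpdateSlice (not JAX code).

    Returns a copy of `operand` with `update` written at `start_indices`,
    which must hold one integer scalar per dimension of `operand`.
    """
    operand = onp.asarray(operand)
    update = onp.asarray(update, dtype=operand.dtype)
    if update.ndim != operand.ndim or len(start_indices) != operand.ndim:
        raise ValueError("update rank and number of start indices must match operand")
    if any(u > o for u, o in zip(update.shape, operand.shape)):
        raise ValueError("update does not fit inside operand")
    slices = []
    for axis, (idx, dim, size) in enumerate(zip(start_indices, operand.shape, update.shape)):
        # Wrap Python ints and 0-dim arrays alike as 0-dim arrays, so the
        # .shape/.dtype checks below work for either kind of input.
        idx = onp.asarray(idx)
        if idx.shape != () or not onp.issubdtype(idx.dtype, onp.integer):
            raise TypeError(f"start index for axis {axis} must be an integer scalar")
        # Clamp so the update always lies fully inside the operand.
        start = int(onp.clip(idx, 0, dim - size))
        slices.append(slice(start, start + size))
    out = operand.copy()
    out[tuple(slices)] = update
    return out


x = onp.zeros(5, dtype=onp.int32)
print(dynamic_update_slice_ref(x, onp.array([7, 8]), [onp.array(1)]))  # [0 7 8 0 0]
print(dynamic_update_slice_ref(x, onp.array([7, 8]), [10]))            # [0 0 0 7 8]
```

In the second call the out-of-range start index 10 is clamped to 3, the last position where a length-2 update still fits.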

bug


All 3 comments

PR #124 says this was fixed, but I still get an error:

In [1]: import numpy as onp

In [2]: import jax.numpy as np

In [3]: onp.array(0).dtype
Out[3]: dtype('int64')

In [4]: np.array(0).dtype
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-4-05f7ef02f43f> in <module>
----> 1 np.array(0).dtype

AttributeError: 'int' object has no attribute 'dtype'

I'm on commit 6476d5ffc6bee0e8d9d7b54914565d798b8dda29.
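For comparison, the NumPy behavior that this issue asks jax.numpy to match can be checked with NumPy alone. This is a minimal sketch. The exact integer dtype (int64 in the session above) is platform-dependent, so the sketch only checks that it is an integer type.

```python
import numpy as onp

x = onp.array(0)  # wrap a Python scalar in an array
# A 0-dim array has no axes, exactly one element, and a concrete dtype.
print(type(x).__name__, x.shape, x.ndim, x.size)  # ndarray () 0 1
print(onp.issubdtype(x.dtype, onp.integer))       # True

# A bare Python int, which is what np.array(0) returned in the session above,
# has no .dtype attribute; hence the AttributeError.
print(hasattr(0, "dtype"))                        # False

# onp.asarray wraps scalars and returns existing arrays unchanged, so it is
# a common way to normalize inputs before reading .dtype.
print(onp.asarray(1.5).dtype)                     # float64
```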

I am not sure whether this is a regression, but we noticed it recently as well, when we found that scan does not work well with scalar arrays (https://github.com/pyro-ppl/numpyro/pull/91). The reason is that, unlike numpy.array, jax.numpy.array(1.) returns a Python float instead of a DeviceArray.

cc @mattjj, in case this needs to be reopened.
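The scan problem described in the previous comment can be illustrated with a pure-Python sketch of scan's semantics. `scan_ref` is a hypothetical stand-in, not JAX's implementation. Its shape/dtype check imitates the requirement that scan's carry keeps the same type on every iteration. Because the check reads .shape and .dtype, a 0-dim array works as the initial carry. A bare Python float, which is what jax.numpy.array(1.) returned at the time, does not.

```python
import numpy as onp


def scan_ref(f, init, xs):
    """Hypothetical pure-Python sketch of scan semantics (not JAX's code).

    f maps (carry, x) -> (new_carry, y). The carry must keep the same
    shape and dtype on every step; the outputs y are stacked at the end.
    """
    carry = init
    ys = []
    for x in xs:
        new_carry, y = f(carry, x)
        # This type check needs .shape/.dtype, which a Python float lacks.
        if (new_carry.shape, new_carry.dtype) != (carry.shape, carry.dtype):
            raise TypeError("scan carry changed shape or dtype")
        carry = new_carry
        ys.append(y)
    return carry, onp.stack(ys)


# Running sum with a 0-dim array as the initial carry.
carry, ys = scan_ref(lambda c, x: (c + x, c + x), onp.array(0.0), onp.arange(4.0))
print(carry, ys)  # 6.0 [0. 1. 3. 6.]

# The same loop with a bare Python float as the initial carry fails.
try:
    scan_ref(lambda c, x: (c + x, c + x), 0.0, onp.arange(4.0))
except AttributeError as err:
    print("Python float carry fails:", err)
```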

@kroq-gar78 that PR was closed but not merged, so this bug still isn't fixed and remains open.

@neerajprad we're going to check in a from-scratch rewrite of scan in the next week or two, though that new version will need testing and could have similar issues at first. Thanks for pointing that out!
