numpy.ndarray instances currently return True if you run an isinstance check against jax.numpy.ndarray. I guess I see how this happens: I think JAX doesn't actually use that type, so it's maybe the actual one from numpy? It's a bit of a hassle when you're checking where an array came from, though.
def test_numpy_ndarray_is_not_instance_of_jax_numpy_ndarray():
    assert not isinstance(numpy.zeros(1), jax.numpy.ndarray)
Btw, what's the preferred way to convert data from Jax to numpy? I've found jax.device_get() by poking around, but I don't think it's documented.
Thanks for the great project!
Thanks for the questions, and the positivity!
> Btw, what's the preferred way to convert data from Jax to numpy?
I usually import numpy as onp and then just use onp.array(jax_array) (or asarray), i.e. the usual way you use NumPy to turn something into an ndarray. That's as efficient as possible.
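For concreteness, here is a minimal sketch of that conversion. It assumes jax and numpy are both installed; the variable names are only illustrative, and the onp alias follows the convention mentioned above.

```python
import numpy as onp        # "original" NumPy, aliased as described above
import jax
import jax.numpy as jnp

# An array created through jax.numpy (illustrative data).
jax_array = jnp.arange(3.0)

# The usual NumPy route: ask NumPy for an ndarray view/copy of the value.
host_array = onp.asarray(jax_array)

# jax.device_get also brings the value back to the host as a NumPy array.
fetched = jax.device_get(jax_array)

print(type(host_array))
print(type(fetched))
```

Either route should leave you with a plain numpy.ndarray that downstream NumPy code can treat like any other array.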
> numpy.ndarray instances currently return True if you run an isinstance check against jax.numpy.ndarray.
This was a conscious choice, since we thought we wanted jax.numpy to be as close as possible to a drop-in replacement for numpy, where you just import jax.numpy instead of numpy and everything still works. But maybe it's more common, and more in line with our philosophy of explicitness, to imagine that users will import both numpy and jax.numpy and want to keep the two straight, e.g. with the kinds of isinstance checks you mention.
I think this line, together with the fact that onp.ndarray is included in _arraylike_types, controls this behavior. (See also the comment above those lines.) Maybe we should consider changing this.
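To make the mechanism concrete, here is a rough sketch of how a metaclass with `__instancecheck__` and a tuple of accepted types can produce either behavior. This is not the actual JAX source: `FakeDeviceArray`, `make_ndarray_type`, and the two resulting types are hypothetical names invented for illustration, and only the general idea of an accepted-types tuple comes from the discussion above.

```python
import numpy as onp


class FakeDeviceArray:
    """Hypothetical stand-in for a device-backed array type."""


def make_ndarray_type(accepted_types):
    """Build a placeholder type whose isinstance check accepts `accepted_types`."""

    class _ArrayMeta(type):
        def __instancecheck__(cls, instance):
            # isinstance(x, ndarray) is routed here, so the result depends
            # only on the accepted tuple, not on real inheritance.
            return isinstance(instance, accepted_types)

    class ndarray(metaclass=_ArrayMeta):
        def __init__(self, *args, **kwargs):
            raise TypeError("placeholder type; not meant to be instantiated")

    return ndarray


# Permissive: onp.ndarray is in the accepted tuple, so NumPy arrays pass.
permissive_ndarray = make_ndarray_type((FakeDeviceArray, onp.ndarray))
# Strict: only the device-backed stand-in passes.
strict_ndarray = make_ndarray_type((FakeDeviceArray,))

host = onp.zeros(1)
device = FakeDeviceArray()

print(isinstance(host, permissive_ndarray))
print(isinstance(host, strict_ndarray))
print(isinstance(device, strict_ndarray))
```

Under this sketch, dropping onp.ndarray from the accepted tuple is the whole change needed to make the check in the original test pass, which is why the tuple's contents are the thing to decide on.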
Thoughts?
@shoyer @hawkinsp I'm keen to get your thoughts in particular!
I agree, as I wrote over in https://github.com/google/jax/pull/1081#issuecomment-517099099, this behavior surprised me. This sort of dynamic inheritance is rarely used; my guess is that it could lead to bugs. My vote would definitely be for encouraging separate/explicit jax.numpy and numpy imports.
@mattjj Thanks for the quick reply!
I would've expected the following:
jax_arr = jax.numpy.zeros(1)
assert isinstance(jax_arr, numpy.ndarray)
np_arr = numpy.ndarray(1)
assert not isinstance(np_arr, jax.numpy.ndarray)
This is the normal inheritance-like behaviour: jax.numpy.ndarray is the new impostor, and it can claim to be a type of numpy.ndarray (even when that's not literally true). But it's kind of weird to trick numpy.ndarray into believing it's a type of jax.numpy.ndarray, since that's really not at all true.
Finally, I think it can be convenient to do something like from jax import numpy in a quick script, and it's nice for that to work fine when the program only has to deal with jax arrays or numpy arrays, but not both. But if a user writes that import in a context where they'll have a mixture of the two types, their code will have all sorts of bugs, and I think that's not really jax's fault? So I think having isinstance(arr, numpy.ndarray) return False will be the least of their problems. Like, yeah, they might have written code expecting that to return True, but the actual truth is False, and they're better off knowing it.
These arguments are pretty convincing IMO. I'll ping our internal chat room to see if there's any dissent, and if not we should fix this.
@honnibal did you mean for the last line to have a not in it?
I think pretending that inheritance goes either way is surprising.
One implementation detail, not necessarily relevant to the question of how things _should_ behave but maybe useful as an explanation, is that jax.numpy.ndarray is not actually our array type; our array type is jax.interpreters.xla.DeviceArray. The jax.numpy.ndarray value is there only for isinstance checks (which we can configure to act however we want, as we should figure out in this thread).
@mattjj Oops, yes! Fixed.