The following code illustrates this point:
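```python
import numpy as np
from jax import vmap

x = np.array([1, 2, 3])

def g(y, yp):
    # `is` tests object identity, not equality of values.
    if y is yp:
        return 1
    else:
        return 0

print(g(x, x))              # plain call: both arguments are the same object
print(vmap(g, 0, 0)(x, x))  # in_axes=0, out_axes=0: g now receives tracers
```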
While it's not totally clear how to preserve the semantics of `is`, I do feel it would be nice if there were consistency. An alternative would be to use `np.allclose`, but that seems slower and semantically a bit different from `is`.
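For comparison, a value-based check gives consistent answers with and without `vmap`. This is a minimal sketch, not something proposed in the thread: `g_val` is a hypothetical name, and it uses exact equality (`==` plus `jnp.all`) instead of `np.allclose`. As noted above, it answers a different question than `is`: two distinct arrays with equal contents also match.

```python
import jax.numpy as jnp
from jax import vmap

def g_val(y, yp):
    # Compare by value, not identity: 1 if every element matches, else 0.
    return jnp.all(y == yp).astype(jnp.int32)

x = jnp.array([1, 2, 3])
other = jnp.array([1, 2, 3])  # a distinct object holding the same values

print(g_val(x, x))        # whole arrays compared at once
print(vmap(g_val)(x, x))  # compared element by element across the batch
print(g_val(x, other))    # equal values match even though `x is other` is False
```

The trade-off is the one the thread describes: this survives transformations because it only looks at values, but it no longer distinguishes "the same object" from "an equal object".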
I would guess `is` isn't preserved under `jit` either, and fixing that seems even harder.
I'm not sure what can be done about this besides discouraging it.
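One way to see what happens under `jit` is to log the result of `is` while the function is traced. The sketch below is illustrative: `g_traced` is a hypothetical variant of `g` and `seen` is a hypothetical log list. In it, the Python body runs once at trace time with tracer objects as arguments, so `y is yp` compares two separate tracers even though the same array was passed twice, and a second call with the same shapes and dtypes reuses the cached compiled function without running the Python body again.

```python
import numpy as np
import jax

seen = []  # hypothetical log, appended to only while JAX traces the function

def g_traced(y, yp):
    # This Python code runs at trace time; `y` and `yp` are tracers here.
    seen.append(y is yp)
    return y + yp

f = jax.jit(g_traced)
x = np.array([1, 2, 3])

print(f(x, x))  # first call: traces g_traced, then runs the compiled code
print(f(x, x))  # same shapes/dtypes: cache hit, g_traced's body is skipped
print(seen)     # a single entry, recorded during the one trace
```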
The semantics of `is` are Python object/reference identity, while the semantics of JAX's embedded language are purely based on values. I suppose this issue is a corollary of "`jit` will silently ignore Python side effects in your code": it will also ignore aspects of Python semantics having to do with memory aliasing or object identity of traced values.
(And since `is` is defined for all Python values and can't be overloaded, it's not possible to error out on this the way we can for `if dynamic_value:`.)
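The contrast in that last point can be sketched as follows; both function names are hypothetical. Branching on a traced value calls the tracer's `__bool__`, which JAX controls and uses to raise a `ConcretizationTypeError` (newer versions raise the subclass `TracerBoolConversionError`). Branching on `is` never reaches any JAX hook, so under `jit` it silently takes the other branch. Called without `jit`, `branch_on_identity(x, x)` takes the first branch, since both names refer to the same array.

```python
import jax
import jax.numpy as jnp

def branch_on_value(y):
    # `if` needs a concrete Python bool, which a tracer cannot provide under jit.
    if y > 0:
        return y
    return -y

def branch_on_identity(y, yp):
    # `is` cannot be overloaded, so JAX has no hook to raise an error here;
    # it compares the two tracer objects, which are distinct.
    if y is yp:
        return y
    return -y

x = jnp.array(1.0)

try:
    jax.jit(branch_on_value)(x)
    value_branch_failed = False
except jax.errors.ConcretizationTypeError:
    value_branch_failed = True

print("value-based branch raised an error:", value_branch_failed)
print("identity branch, plain call:", branch_on_identity(x, x))
print("identity branch, under jit:", jax.jit(branch_on_identity)(x, x))
```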
Thanks for the clarification, James! It does seem like not much can be done. However, it would probably be good to add some documentation warning people not to use `is` with JAX transforms. Of course, as always, Peter is one step ahead!