The issue can be reproduced with the following code:
import jax
x = jax.pmap(lambda x: x)(jax.numpy.ones((2, 2, 2)))
print(x.reshape((-1, 2)).reshape(-1))
The error message is:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-129-21f9347dfcf4> in <module>
1 import jax
2 x = jax.pmap(lambda x: x)(jax.numpy.ones((2, 2, 2)))
----> 3 print(x.reshape((-1, 2)).reshape(-1))
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in _forward_method(attrname, self, fun, *args)
969
970 def _forward_method(attrname, self, fun, *args):
--> 971 return fun(getattr(self, attrname), *args)
972 _forward_to_value = partial(_forward_method, "_value")
973
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/pxla.py in _value(self)
587 def _value(self):
588 if self._npy_value is None:
--> 589 self.copy_to_host_async()
590 npy_value = onp.empty(self.aval.shape, self.aval.dtype)
591 for i in self.one_replica_buffer_indices:
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/pxla.py in copy_to_host_async(self)
566 def copy_to_host_async(self):
567 for buffer_index in self.one_replica_buffer_indices:
--> 568 self.device_buffers[buffer_index].copy_to_host_async()
569
570 def delete(self):
IndexError: list index out of range
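For readers without multiple accelerators: the repro needs at least two devices, since pmap maps over the leading axis of size 2. A sketch of reproducing this on a CPU-only machine is to ask XLA to expose several host devices via the XLA_FLAGS environment variable, which must be set before jax is first imported (this assumes a CPU backend; on a machine with GPUs the host-platform flag has no effect):

```python
import os
# Must run before jax is imported; splits the host CPU into 2 XLA devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax

# Same repro as above: pmap over a leading axis of size 2.
x = jax.pmap(lambda x: x)(jax.numpy.ones((2, 2, 2)))
print(x.shape)  # (2, 2, 2)
```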
Actually, we can't do anything with a doubly-reshaped ShardedDeviceArray. Something like x.reshape((-1, 2)).reshape(-1) * 1 produces a better error message:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-136-b3cc850a373d> in <module>
1 import jax
2 x = jax.pmap(lambda x: x)(jax.numpy.ones((2, 2, 2)))
----> 3 y = x.reshape((-1, 2)).reshape(-1) * 1
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in deferring_binary_op(self, other)
4257 if not isinstance(other, _scalar_types + _arraylike_types + (core.Tracer,)):
4258 return NotImplemented
-> 4259 return binary_op(self, other)
4260 return deferring_binary_op
4261
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in fn(x1, x2)
338 def fn(x1, x2):
339 x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
--> 340 return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)
341 return _wraps(numpy_fn)(fn)
342
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/lax/lax.py in mul(x, y)
306 def mul(x: Array, y: Array) -> Array:
307 r"""Elementwise multiplication: :math:`x \times y`."""
--> 308 return mul_p.bind(x, y)
309
310 def div(x: Array, y: Array) -> Array:
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/core.py in bind(self, *args, **kwargs)
271 top_trace = find_top_trace(args)
272 if top_trace is None:
--> 273 return self.impl(*args, **kwargs)
274
275 tracers = map(top_trace.full_raise, args)
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
227 """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
228 compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
--> 229 return compiled_fun(*args)
230
231 @cache()
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in _execute_compiled_primitive(prim, compiled, result_handler, *args)
326 def _execute_compiled_primitive(prim, compiled, result_handler, *args):
327 device, = compiled.local_devices()
--> 328 input_bufs = [device_put(x, device) for x in args if x is not token]
329 out_bufs = compiled.execute(input_bufs)
330 if FLAGS.jax_debug_nans:
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in <listcomp>(.0)
326 def _execute_compiled_primitive(prim, compiled, result_handler, *args):
327 device, = compiled.local_devices()
--> 328 input_bufs = [device_put(x, device) for x in args if x is not token]
329 out_bufs = compiled.execute(input_bufs)
330 if FLAGS.jax_debug_nans:
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in device_put(x, device)
115 x = canonicalize_dtype(x)
116 try:
--> 117 return device_put_handlers[type(x)](x, device)
118 except KeyError as err:
119 raise TypeError(f"No device_put handler for type: {type(x)}") from err
~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in _device_put_array(x, device)
121 def _device_put_array(x, device: Optional[Device]):
122 backend = xb.get_device_backend(device)
--> 123 return backend.buffer_from_pyval(x, device)
124
125 def _device_put_scalar(x, device):
RuntimeError: Invalid argument: from_python argument must be an array.
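Until the reshape rules are fixed, one possible workaround (a sketch, not verified against every JAX version) is to force a host copy with np.asarray before reshaping, so plain NumPy handles the reshape instead of the ShardedDeviceArray reshape rules. The pmap call is shown in a comment and replaced by a stand-in array here, since it needs at least two devices:

```python
import numpy as np

# With at least two devices one would write:
#   x = jax.pmap(lambda x: x)(jax.numpy.ones((2, 2, 2)))
#   x_host = np.asarray(x)  # forces a host copy, materializing the shards
x_host = np.ones((2, 2, 2))  # stand-in for the pmap output

# Reshaping the host ndarray sidesteps the ShardedDeviceArray reshape rules.
y = x_host.reshape((-1, 2)).reshape(-1)
print(y.shape)  # (8,)
```

The host copy does cost a device-to-host transfer, so this is only a stopgap, not a substitute for the fix.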
Thanks @fehiepsi!
I think it's a bug in our funky ShardedDeviceArray reshape rules, which are only there to make soft_pmap work (and soft_pmap itself is a prototype feature). #3370 fixes this, and even eliminates the possibility for such bugs, because it removes the weird ShardedDeviceArray reshape rules entirely (and re-implements soft_pmap). I'm tempted to wait for #3370 to land, which I'm trying to make happen in the next 7 days, instead of fixing this on master. What do you think? Is this a fix that can wait ~7 days, or is it super annoying?
Thanks Matt! Please take your time, it does not block my coding fun. :D
Hi @mattjj, we will release NumPyro 0.3 when this issue is fixed. Do you have an estimate of when #3370 will land? I am happy to test for any regressions with your PR.
@fehiepsi , you're too kind. Thanks for the heads up and the offer.
For the last ~month, the only updates to #3370 have been rebasing it on master as I struggle to land it in Google's monorepo. The trouble is that with a monorepo I need to make sure all the existing code works with any updates to JAX, and #3370 breaks a lot of code by making some common performance bugs into hard errors. New code is being added fast enough that I'm not sure I've even made any progress.
I've been really stuck on this, so I hesitate to give any new projected date on when #3370 will land. I have one more idea to try this week. If it doesn't look like it's going to work by Wednesday evening, I'll try to fix this issue directly on master rather than blocking it on #3370. Does that work for you?
I've been really stuck on this
Haha, it always happens. I think you will find a way soon. ;)
Does that work for you?
Thanks a lot @mattjj! That really works for us.
I am, of course, behind schedule :) Maybe it's worth just fixing this issue directly, and not blocking on #3370...
I think this was fixed by #3370!