Jax: Cannot print a doubly-reshaped ShardedDeviceArray

Created on 29 Jun 2020 · 7 comments · Source: google/jax

The issue can be reproduced with the following code:

import jax
x = jax.pmap(lambda x: x)(jax.numpy.ones((2, 2, 2)))
print(x.reshape((-1, 2)).reshape(-1))

The error message is:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-129-21f9347dfcf4> in <module>
      1 import jax
      2 x = jax.pmap(lambda x: x)(jax.numpy.ones((2, 2, 2)))
----> 3 print(x.reshape((-1, 2)).reshape(-1))

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in _forward_method(attrname, self, fun, *args)
    969 
    970 def _forward_method(attrname, self, fun, *args):
--> 971   return fun(getattr(self, attrname), *args)
    972 _forward_to_value = partial(_forward_method, "_value")
    973 

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/pxla.py in _value(self)
    587   def _value(self):
    588     if self._npy_value is None:
--> 589       self.copy_to_host_async()
    590       npy_value = onp.empty(self.aval.shape, self.aval.dtype)
    591       for i in self.one_replica_buffer_indices:

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/pxla.py in copy_to_host_async(self)
    566   def copy_to_host_async(self):
    567     for buffer_index in self.one_replica_buffer_indices:
--> 568       self.device_buffers[buffer_index].copy_to_host_async()
    569 
    570   def delete(self):

IndexError: list index out of range

In fact, we can't do anything with a doubly-reshaped ShardedDeviceArray. Something like x.reshape((-1, 2)).reshape(-1) * 1 produces a better error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-136-b3cc850a373d> in <module>
      1 import jax
      2 x = jax.pmap(lambda x: x)(jax.numpy.ones((2, 2, 2)))
----> 3 y = x.reshape((-1, 2)).reshape(-1) * 1

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in deferring_binary_op(self, other)
   4257     if not isinstance(other, _scalar_types + _arraylike_types + (core.Tracer,)):
   4258       return NotImplemented
-> 4259     return binary_op(self, other)
   4260   return deferring_binary_op
   4261 

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in fn(x1, x2)
    338   def fn(x1, x2):
    339     x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
--> 340     return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)
    341   return _wraps(numpy_fn)(fn)
    342 

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/lax/lax.py in mul(x, y)
    306 def mul(x: Array, y: Array) -> Array:
    307   r"""Elementwise multiplication: :math:`x \times y`."""
--> 308   return mul_p.bind(x, y)
    309 
    310 def div(x: Array, y: Array) -> Array:

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/core.py in bind(self, *args, **kwargs)
    271     top_trace = find_top_trace(args)
    272     if top_trace is None:
--> 273       return self.impl(*args, **kwargs)
    274 
    275     tracers = map(top_trace.full_raise, args)

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
    227   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
    228   compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
--> 229   return compiled_fun(*args)
    230 
    231 @cache()

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in _execute_compiled_primitive(prim, compiled, result_handler, *args)
    326 def _execute_compiled_primitive(prim, compiled, result_handler, *args):
    327   device, = compiled.local_devices()
--> 328   input_bufs = [device_put(x, device) for x in args if x is not token]
    329   out_bufs = compiled.execute(input_bufs)
    330   if FLAGS.jax_debug_nans:

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in <listcomp>(.0)
    326 def _execute_compiled_primitive(prim, compiled, result_handler, *args):
    327   device, = compiled.local_devices()
--> 328   input_bufs = [device_put(x, device) for x in args if x is not token]
    329   out_bufs = compiled.execute(input_bufs)
    330   if FLAGS.jax_debug_nans:

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in device_put(x, device)
    115   x = canonicalize_dtype(x)
    116   try:
--> 117     return device_put_handlers[type(x)](x, device)
    118   except KeyError as err:
    119     raise TypeError(f"No device_put handler for type: {type(x)}") from err

~/miniconda3/envs/pydata/lib/python3.8/site-packages/jax/interpreters/xla.py in _device_put_array(x, device)
    121 def _device_put_array(x, device: Optional[Device]):
    122   backend = xb.get_device_backend(device)
--> 123   return backend.buffer_from_pyval(x, device)
    124 
    125 def _device_put_scalar(x, device):

RuntimeError: Invalid argument: from_python argument must be an array.
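
Until those reshape rules are fixed, one workaround (not from the thread; a sketch assuming only the standard jax.device_get host-transfer API) is to materialize the sharded array on the host first, so both reshapes run on a plain NumPy array:

import jax

x = jax.pmap(lambda x: x)(jax.numpy.ones((2, 2, 2)))

# Pull the ShardedDeviceArray back to the host as a regular NumPy array;
# NumPy's reshape then behaves normally and never hits the special
# ShardedDeviceArray reshape rule.
x_host = jax.device_get(x)  # np.ndarray of shape (2, 2, 2)
print(x_host.reshape((-1, 2)).reshape(-1))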

All 7 comments

Thanks @fehiepsi !

I think it's a bug in our funky ShardedDeviceArray reshape rules, which are only there to make soft_pmap work (and soft_pmap itself is a prototype feature).
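
For context, soft_pmap lets the mapped axis exceed the local device count by sharding what it can across devices and vectorizing the rest. A minimal usage sketch, assuming the 2020-era jax.soft_pmap export (the API was experimental and has since changed):

import jax
import jax.numpy as jnp

# soft_pmap maps over a leading axis that may be larger than the number of
# devices: it shards across devices like pmap and vectorizes the remainder
# like vmap, which is what the special reshape rules existed to support.
n = jax.local_device_count()
out = jax.soft_pmap(lambda v: v * 2)(jnp.ones((4 * n, 3)))
print(out.shape)  # (4 * n, 3)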

#3370 fixes this, and even eliminates the possibility of such bugs, because it removes the weird ShardedDeviceArray reshape rules entirely (and re-implements soft_pmap). I'm tempted to wait for #3370 to land, which I'm trying to make happen in the next 7 days, instead of fixing this on master.

What do you think? Is this a fix that can wait ~7 days, or is it super annoying?

Thanks Matt! Please take your time, it does not block my coding fun. :D


Hi @mattjj, we will release NumPyro 0.3 when this issue is fixed. Do you have an estimate of when #3370 will land? I am happy to test for any regressions with your PR.

@fehiepsi , you're too kind. Thanks for the heads up and the offer.

For the last ~month, the only updates to #3370 have been rebasing it on master as I struggle to land it in Google's monorepo. The trouble is that with a monorepo I need to make sure all the existing code works with any updates to JAX, and #3370 breaks a lot of code by making some common performance bugs into hard errors. New code is being added fast enough that I'm not sure I've even made any progress.

I've been really stuck on this, so I hesitate to give any new projected date on when #3370 will land. I have one more idea to try this week. If it doesn't look like it's going to work by Wednesday evening, I'll try to fix this issue directly on master rather than blocking it on #3370. Does that work for you?

> I've been really stuck on this

Haha, it always happens. I think you will find a way soon. ;)

> Does that work for you?

Thanks a lot @mattjj! That really works for us.

I am, of course, behind schedule :) Maybe it's worth just fixing this issue directly, and not blocking on #3370...

I think this was fixed by #3370!
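
A quick way to confirm: re-running the original repro on a build that includes #3370 should print the flattened array instead of raising (expected output inferred from the repro's all-ones (2, 2, 2) input):

import jax

x = jax.pmap(lambda x: x)(jax.numpy.ones((2, 2, 2)))
# With the fix from #3370 this prints eight ones: [1. 1. 1. 1. 1. 1. 1. 1.]
print(x.reshape((-1, 2)).reshape(-1))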
