Jax: Allow mask of strided ops

Created on 29 Jul 2020 · 5 comments · Source: google/jax

Currently, mask can only support ops with output sizes expressible as polynomials of the input sizes. This excludes:

  • Strided slicing and convolutions (only supported in the edge case where the stride perfectly divides the input size)
  • Operations with variable output shape, such as single-argument np.where(x)

This is a hassle. @j-towns @mattjj @dougalm What are your thoughts on the following plan to allow these cases?

  • We split shape rules into two traceable JAX functions: a shape validation rule indicating whether the input shapes to an op are valid, and the actual rule that calculates the output shape (both are sketched below, after this list).
  • We remove polymorphic shape propagation. We make mask propagate the logical shapes using corresponding shape rules. vmap(mask) will then propagate batches of logical shapes using batched shape rules. We raise inconsistencies in dynamic shapes using the validation rules.
  • We remove polymorphic input specs. mask will be parameter-free. A (batched) masked function will take padded inputs and (batches of) logical input shapes, and return padded outputs and (batches of) logical output shapes (the calling convention is sketched below, after this list).
  • We decorate shape rules with @numpy_eval() from https://github.com/google/jax/pull/3923. This ensures XLA dispatch is never wasted on constant shape calculations, even during jit compilation (+ no staging occurs). If the NumPy backend turns out to be too slow for some rules, we keep additional optimized NumPy rules for those.
  • We remove the Poly class and all other things polymorphic. This will remove the need for polymorphic special cases in shape rules.
  • Instead of jnp.sum(x) / shape_as_value(x.shape)[0] users can write jnp.sum(x) / logical_shape(x)[0]. logical_shape will retrieve the logical shape from the underlying MaskTrace. We remove the masking.shape_envs variable.
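
As a minimal sketch of the two-part rule, here is what it could look like for a 1-D strided slice. The rule names are hypothetical, and both functions are written with jnp so they stay traceable (and hence vmap-able over batches of logical sizes):

```python
import jax.numpy as jnp

# Hypothetical split shape rules for a 1-D strided slice x[start:stop:stride].
# Both are ordinary traceable JAX functions, so vmap can push batches of
# logical input sizes through them.

def strided_slice_valid_rule(in_size, start, stop, stride):
  # Validation rule: do these parameters make sense for this logical size?
  return (0 <= start) & (stop <= in_size) & (stride > 0)

def strided_slice_shape_rule(in_size, start, stop, stride):
  # Output-shape rule: ceil((stop - start) / stride), clamped at zero.
  # This is not a polynomial in in_size, which is exactly why the current
  # Poly-based mask cannot express it.
  return jnp.maximum(0, -(-(stop - start) // stride))
```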
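And a toy stand-in for the proposed parameter-free calling convention. None of these names exist in JAX today, and a real implementation would propagate logical shapes through a MaskTrace and the ops' shape rules rather than by hand as done here:

```python
import jax
import jax.numpy as jnp

def masked_cumsum(padded_x, logical_len):
  # Toy stand-in for the proposed convention: padded input plus its logical
  # length in, padded output plus its logical length out.
  idx = jnp.arange(padded_x.shape[0])
  valid = idx < logical_len
  out = jnp.cumsum(jnp.where(valid, padded_x, 0.0))
  out = jnp.where(valid, out, 0.0)   # keep the padding region zeroed
  return out, logical_len            # cumsum preserves the logical length

# Batches of padded inputs and batches of logical shapes then go straight
# through vmap, which is the point of dropping polymorphic input specs.
xs = jnp.array([[1.0, 2.0, 3.0, 0.0],
                [4.0, 5.0, 0.0, 0.0]])
lens = jnp.array([3, 2])
outs, out_lens = jax.vmap(masked_cumsum)(xs, lens)
```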
Labels: enhancement, question

All 5 comments

We raise inconsistencies in dynamic shapes using the validation rules.

We might be able to use the host_callback stuff to do this from inside compiled code.
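
A rough, purely illustrative sketch of that idea; host_callback is experimental, and the exact tap-callback signature and error-propagation behaviour vary across versions, but the shape check itself runs as ordinary Python on the host while the jitted computation executes:

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb

def _check_sizes(sizes, transforms=None):
  # Runs on the host during execution of the compiled computation.
  a_len, b_len = sizes
  if a_len != b_len:
    raise ValueError(f"inconsistent logical sizes: {a_len} vs {b_len}")

@jax.jit
def masked_add(a, b, a_len, b_len):
  # Tap the dynamic logical sizes out to the host for validation, threading
  # the result through `a` so the tap is not dead-code-eliminated.
  a = hcb.id_tap(_check_sizes, (a_len, b_len), result=a)
  return a + b
```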

  • We decorate shape rules with @numpy_eval() from #3923. This ensures XLA dispatch is never wasted on constant shape calculations, even during jit compilation (+ no staging occurs). If the NumPy backend turns out to be too slow for some rules, we keep additional optimized NumPy rules for those.

I'm not sure I understand this. How does XLA dispatch get invoked on shape calculations at present? I would guess that most cases where this is currently happening are bugs that would be surfaced by turning on omnistaging.

@shoyer Currently there is no XLA dispatch for shapes since shape rules are written directly in NumPy. In order to allow dynamic/batched shape propagation for masking, we would need traceable shape rules. @numpy_eval() would be useful to calculate static, single shapes (i.e. in standard use cases, outside of masking) as fast as in the current implementation, without XLA dispatch, using the same code (i.e. not having two versions of each shape rule). We should be able to get very close to the original performance with #4117.
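
For concreteness, an illustrative (not actual JAX internals) pair of shape rules for a 1-D "valid" strided convolution: the current style is plain Python arithmetic on concrete integers, while the traceable style uses jnp and is the kind of rule that would be wrapped with @numpy_eval() from #3923 to keep the static case off the XLA dispatch path:

```python
import jax.numpy as jnp

def conv_out_size_current(in_size, kernel_size, stride):
  # Current style: plain Python arithmetic on concrete ints. Fast, but max()
  # forces a concrete boolean, so this cannot trace dynamic or batched sizes.
  return max(0, (in_size - kernel_size) // stride + 1)

def conv_out_size_traceable(in_size, kernel_size, stride):
  # Traceable style: the same arithmetic with jnp, so in_size may be a tracer
  # or a batch of logical sizes under vmap. Calling it on plain ints dispatches
  # tiny ops to XLA, which is the overhead @numpy_eval() (#3923) would avoid
  # by evaluating the same code with the NumPy backend.
  return jnp.maximum(0, (in_size - kernel_size) // stride + 1)
```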

I think the main problem with this proposal is that we can't do shape checking and gracefully raise a shape error for incompatible dynamic shapes inside a jit. As I mentioned above, it is technically possible to raise an exception inside jit but I don't think you'd get a nice traceback from user code and the functionality is experimental.

@j-towns I agree, this is probably the main hurdle. How about we record the stack trace of potential shape errors during compilation and attach it to errors raised via host callback?
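
Very roughly, and with every name here hypothetical, that might look like capturing a traceback at trace time and closing over it in the host-side check:

```python
import traceback
from jax.experimental import host_callback as hcb

def check_equal_sizes(a_len, b_len, result):
  # Runs at trace/compile time: remember where in user code this check
  # was emitted from.
  trace_site = "".join(traceback.format_stack(limit=8))

  def _on_host(sizes, transforms=None):
    x, y = sizes
    if x != y:
      raise ValueError(
          f"inconsistent logical sizes {x} vs {y}; check emitted at:\n"
          f"{trace_site}")

  # The check itself fires on the host at run time, carrying the recorded
  # trace-time context with it.
  return hcb.id_tap(_on_host, (a_len, b_len), result=result)
```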
