Currently, `mask` can only support ops with output sizes expressible as polynomials of the input sizes. This excludes:

- `np.where(x)`

This is a hassle. @j-towns @mattjj @dougalm What are your thoughts on the following plan to allow these cases?
- `mask` propagates the logical shapes using corresponding shape rules. `vmap(mask)` will then propagate batches of logical shapes using batched shape rules.
- `mask` will be parameter-free. A (batched) masked function will take padded inputs and (batches of) logical input shapes, and return padded outputs and (batches of) logical output shapes (see the sketch after this list).
- We decorate shape rules with `@numpy_eval()` from https://github.com/google/jax/pull/3923. This ensures XLA dispatch is never wasted on constant shape calculations, even during `jit` compilation (+ no staging occurs). If the NumPy backend turns out to be too slow for some rules, we keep additional optimized NumPy rules for those.
- We remove the `Poly` class and all other things polymorphic. This removes the need for polymorphic special cases in shape rules. Instead of `jnp.sum(x) / shape_as_value(x.shape)[0]`, users can write `jnp.sum(x) / logical_shape(x)[0]`. `logical_shape` will retrieve the logical shape from the underlying `MaskTrace`. We remove the `masking.shape_envs` variable.
- We raise inconsistencies in dynamic shapes using the validation rules.
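To make the second and fourth points concrete, here is a rough sketch of what the parameter-free interface could look like. The names `mask` and `logical_shape` and the calling convention below stand for the *proposed* API, not anything that exists today:

```python
import jax.numpy as jnp

# Hypothetical user function: reads the logical shape off the MaskTrace
# via logical_shape(x) instead of shape_as_value(x.shape).
def mean(x):
    return jnp.sum(x) / logical_shape(x)[0]

# Proposed: mask takes no shape-spec parameters.
masked_mean = mask(mean)

# A masked function takes padded inputs plus logical input shapes, and
# returns padded outputs plus logical output shapes (calling convention
# is illustrative only).
padded_x = jnp.array([1., 2., 3., 0., 0.])          # physical size 5
out_padded, out_shape = masked_mean(padded_x, (3,))  # logical size 3

# Under vmap(mask), the logical shapes would be batched alongside the data.
```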
We might be able to use the host_callback stuff to do this from inside compiled code.
> - We decorate shape rules with `@numpy_eval()` from #3923. This ensures XLA dispatch is never wasted on constant shape calculations, even during `jit` compilation (+ no staging occurs). If the NumPy backend turns out to be too slow for some rules, we keep additional optimized NumPy rules for those.
I'm not sure I understand this. How does XLA dispatch get invoked on shape calculations at present? I would guess that most cases where this is currently happening are bugs that would be surfaced by turning on omnistaging.
@shoyer Currently there is no XLA dispatch for shapes, since shape rules are written directly in NumPy. In order to allow dynamic/batched shape propagation for masking, we would need traceable shape rules. `@numpy_eval()` would be useful to calculate static, single shapes (i.e. in standard use cases, outside of masking) as fast as in the current implementation, without XLA dispatch, using the same code (i.e. not having two versions of each shape rule). We should be able to get very close to the original performance with #4117.
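For illustration, a shape rule written once in traceable code might look like the sketch below. The `numpy_eval` decorator and its import path are taken on faith from #3923, so treat the decoration as an assumption rather than an existing API:

```python
import jax.numpy as jnp
# from jax.experimental import numpy_eval  # assumed import path, per #3923

# Shape rule for concatenation along `axis`, written once in traceable code.
# With @numpy_eval() it would evaluate on the NumPy backend for concrete
# (static) shapes, while remaining traceable for dynamic/batched shapes
# under vmap(mask).
# @numpy_eval()
def concatenate_shape_rule(shape_a, shape_b, axis=0):
    # shape_a, shape_b: 1-D integer shape vectors (possibly tracers).
    idx = jnp.arange(shape_a.shape[0])
    return jnp.where(idx == axis, shape_a + shape_b[axis], shape_a)
```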
I think the main problem with this proposal is that we can't do shape checking and gracefully raise a shape error for incompatible dynamic shapes inside a jit. As I mentioned above, it is technically possible to raise an exception inside jit but I don't think you'd get a nice traceback from user code and the functionality is experimental.
@j-towns I agree, this is probably the main hurdle. How about we record the stack trace of potential shape errors during compilation and attach it to errors raised via host callback?
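Something along these lines, as a very rough, untested sketch. `host_callback.id_tap` is the experimental API, and whether an exception raised in the callback surfaces cleanly to the user is exactly the open question above:

```python
import traceback
from jax.experimental import host_callback

def check_dynamic_shapes(ok, logical_shapes):
    # Capture the user-level stack trace now, at trace/compile time, so it can
    # be attached to an error raised later from inside the compiled computation.
    trace_time_stack = "".join(traceback.format_stack())

    def _raise_if_invalid(args, transforms=None):
        ok_value, shapes = args
        if not ok_value:
            raise ValueError(
                "Inconsistent dynamic shapes: {}\nTraced at:\n{}".format(
                    shapes, trace_time_stack))

    # id_tap ships the values to the host; the callback runs outside XLA.
    host_callback.id_tap(_raise_if_invalid, (ok, logical_shapes))
```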