PyTorch has experimental support for named tensors that achieves some compelling design goals while keeping existing code compatible. For example, binop broadcasting is still based on dimension order (unlike in xarray), consistent with standard NumPy/JAX/... semantics, but it checks that aligned dimension names match.
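For concreteness, a small example of that behavior with PyTorch's experimental named tensor API (just an illustration of the semantics described above, not part of the proposal below):

```python
import torch

# Alignment is still positional, but names on aligned dimensions are checked.
x = torch.randn(2, 3, names=('N', 'C'))
y = torch.randn(3, names=('C',))
z = x + y                    # shapes broadcast by position; names ('N', 'C') / ('C',) unify
print(z.names)               # ('N', 'C')

w = torch.randn(3, names=('W',))
try:
    x + w                    # same shapes, but aligned names 'C' and 'W' do not match
except RuntimeError as e:
    print(e)
```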
It would be great to have named tensors in JAX that work both in op-by-op execution and under function transformations.
@shoyer In https://github.com/google/jax/issues/1565 you mentioned that this could be done by wrapping JAX based on https://github.com/google/jax/pull/611. According to my current understanding, this means:
1. Implement an eval_names transform.
2. Implement a NamedDeviceArray subtype of DeviceArray that adds a names property.
3. Make ops return NamedDeviceArrays. For that, wrap jax.numpy, wrapping each op with the named transform and returning a NamedDeviceArray using https://github.com/google/jax/pull/611 (+1 for merging). Alternatively, one could rewrite jax.numpy using numpy_dispatch.get_array_module from https://github.com/google/jax/pull/4076 (appears cumbersome). (A rough sketch of this wrapping follows at the end of this comment.)
4. Make jitted functions propagate names when applied to NamedDeviceArrays.

Is this plan sound? @shoyer @mattjj Would you update (and merge, if successful) https://github.com/google/jax/pull/611 just for this application? In that case, I'd be interested in prototyping a named tensor library for JAX, with a good amount of passion, in accordance with https://github.com/google/jax/issues/1565. (:
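To make the intent of steps 1–3 concrete, here is a rough, hypothetical sketch. NamedArray, the named wrapper, and the name rule are made-up stand-ins for illustration only; the actual proposal would subclass DeviceArray (via https://github.com/google/jax/pull/611) and propagate names through a proper eval_names transform rather than per-op rules.

```python
# Hypothetical sketch only; none of these names exist in JAX.
import jax.numpy as jnp


class NamedArray:
    """An array paired with a tuple of dimension names (one per axis)."""

    def __init__(self, data, names):
        assert data.ndim == len(names)
        self.data = data
        self.names = tuple(names)


def named(op, name_rule):
    """Wrap a jax.numpy op so it consumes and returns NamedArrays.

    name_rule maps the input name tuples to the output name tuple; a full
    eval_names transform would derive this from tracing instead of
    requiring a hand-written rule per op.
    """
    def wrapped(*args):
        out = op(*(a.data for a in args))
        return NamedArray(out, name_rule(*(a.names for a in args)))
    return wrapped


def elementwise_names(*names):
    # Positional broadcasting, but aligned names must match (as in PyTorch).
    assert len(set(names)) == 1, f"dimension names do not match: {names}"
    return names[0]


add = named(jnp.add, elementwise_names)

x = NamedArray(jnp.ones((2, 3)), ("batch", "channel"))
y = NamedArray(jnp.zeros((2, 3)), ("batch", "channel"))
print(add(x, y).names)  # ('batch', 'channel')
```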
Have you started working on this, @JuliusKunze?
We are actually working on something that will pretty much realize the plan that @JuliusKunze has outlined here, with some additional benefits too (e.g. making it very easy to shard those programs with named axes over multiple accelerators).
@Jeevesh8 No, and now I won't anymore. (: @apaszke That's great to hear! Will this go into the JAX repo?