Jax: Named tensors

Created on 30 Nov 2020 · 3 comments · Source: google/jax

PyTorch has experimental support for named tensors that achieves some compelling design goals while keeping existing code compatible. For example, broadcasting in binary ops is still based on dimension order (unlike in xarray), consistent with standard NumPy/JAX semantics, but it checks that the names of aligned dimensions match.
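
A minimal sketch of that rule in plain Python, operating only on tuples of dimension names. It is a simplified illustration, not PyTorch's actual implementation: the function name `unify_names` is hypothetical, `None` stands for an unnamed dimension that matches any name, and dimensions are aligned from the right exactly as in NumPy broadcasting. Names are only checked, never used to reorder axes.

```python
from itertools import zip_longest
from typing import Optional, Sequence, Tuple

Name = Optional[str]  # None marks an unnamed dimension; it matches any name


def unify_names(a: Sequence[Name], b: Sequence[Name]) -> Tuple[Name, ...]:
    """Return the output names of a broadcasting binop on inputs named a and b.

    Hypothetical helper: dimensions are aligned positionally from the right,
    as in NumPy broadcasting, and aligned names must agree.
    """
    out = []
    # zip_longest pads the shorter operand with None, so its missing
    # leading dimensions behave like unnamed ones.
    for x, y in zip_longest(reversed(a), reversed(b)):
        if x is not None and y is not None and x != y:
            raise ValueError(f"aligned dimensions have different names: {x!r} vs {y!r}")
        out.append(x if x is not None else y)
    result = tuple(reversed(out))
    named = [n for n in result if n is not None]
    if len(named) != len(set(named)):
        # e.g. ('batch', None) + (None, 'batch') would put 'batch' on two axes.
        raise ValueError(f"duplicate dimension names in result: {result}")
    return result


print(unify_names(("batch", "feature"), ("feature",)))  # ('batch', 'feature')
```

With these semantics, `("batch", "feature")` combined with `("feature", "batch")` raises an error instead of being transposed to match, which is where this approach differs from xarray's broadcasting by name.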

It would be great to have named tensors that work both in op-by-op and under function transformations in JAX.

@shoyer In https://github.com/google/jax/issues/1565 you mentioned that this could be done by wrapping JAX based on https://github.com/google/jax/pull/611. According to my current understanding, this means:

  • Add name rules for lax primitives, returning the output dimension names for given input dimension names.
  • Add a corresponding eval_names transform.
  • Add a NamedDeviceArray subtype of DeviceArray that adds a names property.
  • We want names to be propagated in op-by-op mode on NamedDeviceArrays. For that:
      • Make jitted functions propagate names when applied to NamedDeviceArrays.
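
The first two steps of this plan (per-primitive name rules plus an eval_names pass) can be sketched without touching any array data. Everything below is hypothetical and written in plain Python rather than against JAX internals: `NAME_RULES`, `Eqn` and `eval_names` are illustrative names, and the toy equation list stands in for a jaxpr. `add`, `mul` and `reduce_sum` mirror the names of real lax primitives, while `matmul` is a simplified 2-D stand-in for `dot_general`. A real version would instead interpret the equations of a traced jaxpr.

```python
from dataclasses import dataclass, field
from itertools import zip_longest
from typing import Callable, Dict, List, Optional, Tuple

Names = Tuple[Optional[str], ...]  # None marks an unnamed dimension


def broadcast_rule(a: Names, b: Names) -> Names:
    # Elementwise binops: align from the right; aligned names must agree.
    out = []
    for x, y in zip_longest(reversed(a), reversed(b)):
        if x is not None and y is not None and x != y:
            raise ValueError(f"name mismatch: {x!r} vs {y!r}")
        out.append(x if x is not None else y)
    return tuple(reversed(out))


def reduce_sum_rule(x: Names, *, axes: Tuple[int, ...]) -> Names:
    # Reduced dimensions disappear, and so do their names.
    return tuple(n for i, n in enumerate(x) if i not in axes)


def matmul_rule(a: Names, b: Names) -> Names:
    # 2-D case only: (m, k) @ (k, n) -> (m, n); contracted names must agree.
    if len(a) != 2 or len(b) != 2:
        raise ValueError("this sketch only handles 2-D matmul")
    if a[1] is not None and b[0] is not None and a[1] != b[0]:
        raise ValueError(f"contracting {a[1]!r} with {b[0]!r}")
    return (a[0], b[1])


# Hypothetical registry: one name rule per primitive.
NAME_RULES: Dict[str, Callable[..., Names]] = {
    "add": broadcast_rule,
    "mul": broadcast_rule,
    "reduce_sum": reduce_sum_rule,
    "matmul": matmul_rule,
}


@dataclass
class Eqn:
    """One step of a toy program: output = primitive(*inputs, **params)."""
    primitive: str
    inputs: List[str]
    output: str
    params: Dict[str, object] = field(default_factory=dict)


def eval_names(eqns: List[Eqn], env: Dict[str, Names]) -> Dict[str, Names]:
    """Propagate dimension names through eqns, one rule application per step."""
    env = dict(env)  # do not mutate the caller's mapping
    for eqn in eqns:
        rule = NAME_RULES[eqn.primitive]
        env[eqn.output] = rule(*(env[v] for v in eqn.inputs), **eqn.params)
    return env


# Usage: logits = x @ w + b, then sum over the batch axis.
program = [
    Eqn("matmul", ["x", "w"], "h"),
    Eqn("add", ["h", "b"], "logits"),
    Eqn("reduce_sum", ["logits"], "total", {"axes": (0,)}),
]
names = eval_names(program, {"x": ("batch", "in"), "w": ("in", "out"), "b": ("out",)})
print(names["logits"], names["total"])  # ('batch', 'out') ('out',)
```

Keeping the rules as plain functions of input names (plus primitive parameters) is what would let the same table serve both op-by-op execution on a NamedDeviceArray and a whole-program pass under jit.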

Is this plan sound? @shoyer @mattjj Would you update (and merge, if successful) https://github.com/google/jax/pull/611 just for this application? In that case, I'd be interested in prototyping a named tensor library for JAX, with a good amount of passion, in line with https://github.com/google/jax/issues/1565. (:


All 3 comments

Have you started working on this, @JuliusKunze?

We are actually working on something that will pretty much realize the plan that @JuliusKunze has outlined here, with some additional benefits too (e.g. making it very easy to shard those programs with named axes over multiple accelerators).

@Jeevesh8 No, and now I won't anymore. (: @apaszke That's great to hear! Will this go into the JAX repo?
