JAX: Matrix multiplication precision API

Created on 4 Feb 2020  ·  6 Comments  ·  Source: google/jax

Operations that perform matrix multiplication in JAX accept a precision argument, which controls the precision used when they run on TPUs.

The current API seems non-ideal to me:

  1. You have to pass an enum value (e.g., np.dot(x, y, precision=lax.Precision.HIGHEST)). This is a little cumbersome and inconsistent with most NumPy/SciPy APIs, which use strings (e.g., np.dot(x, y, precision='highest')); see the sketch after this list.
  2. The current names for precision levels ("highest", "high", and "default") are not very descriptive. In my ideal world we would use some direct indication of the corresponding precision (e.g., bfloat16 multiplication with float32 accumulation), but at the very least can we switch "default" to "low"?
  3. The default low precision is a bit of a footgun, at least when doing anything that isn't implementing a neural net layer. In my opinion, it would be much safer to default to "highest" precision on float32 data (which isn't _that_ much slower). Neural net libraries, of course, can default to lower precision, so this really only affects users who directly use NumPy APIs or the @ infix operator.
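
For concreteness, here is a minimal sketch contrasting the two spellings. The string form is the proposal here, not the existing API:

import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 4), dtype=jnp.float32)
y = jnp.ones((4, 4), dtype=jnp.float32)

# Current API: pass an enum value.
z = jnp.dot(x, y, precision=lax.Precision.HIGHEST)

# Proposed spelling (hypothetical): pass a string, as most NumPy/SciPy APIs do.
# z = jnp.dot(x, y, precision='highest')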

All 6 comments

On TPUs, "HIGH" precision corresponds to 3 passes of bfloat16 and "HIGHEST" precision corresponds to 6 passes, which is effectively full float32 (dropping denormals), as explained in this Intel paper.
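
To make the pass counting concrete, here is an illustrative sketch of the 3-pass idea. This is a model of the technique, not JAX's actual TPU kernel, and the function name is made up:

import jax.numpy as jnp
from jax import lax

def matmul_bf16_3pass(x, y):
    # Split each float32 operand into a bfloat16 "high" part plus a
    # bfloat16 residual; the bfloat16 -> float32 conversion is exact.
    x_hi = x.astype(jnp.bfloat16)
    x_lo = (x - x_hi.astype(jnp.float32)).astype(jnp.bfloat16)
    y_hi = y.astype(jnp.bfloat16)
    y_lo = (y - y_hi.astype(jnp.float32)).astype(jnp.bfloat16)

    def mm(a, b):
        # One "pass": bfloat16 inputs with float32 accumulation.
        return lax.dot(a, b, preferred_element_type=jnp.float32)

    # Three passes, dropping the small x_lo @ y_lo term; HIGHEST uses a
    # finer splitting with six such products.
    return mm(x_hi, y_hi) + mm(x_hi, y_lo) + mm(x_lo, y_hi)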

With that in mind, and considering that ideally we would retain some flexibility for alternative matmul optimizations that might appear on other platforms, what more descriptive naming scheme makes sense for values of the precision argument?

Some ideas:

  1. 'low', 'high', 'highest': description of precision level
  2. 'fastest', 'fast', 'slow': description of speed
  3. 'fastest', 'fast', 'accurate': mixed description, using only positive words
  4. 'bfloat16', 'float24', 'float32': rough precision of the underlying arithmetic (but what is "float24"??)

I think I lean towards option 3?

Notes from offline discussion:

  • This is really a "minimum precision" configuration, so perhaps a name like min_precision would be more appropriate.
  • The other (maybe more obvious) way to configure matmul precision is by explicitly setting the dtype. XLA will use lowest precision on bfloat16 data regardless of the precision option.
  • We want an API that can also support new matmul precision options as they arise on different platforms (GPU, CPU, etc.), e.g., precision={'tpu': 'bfloat16', 'gpu': 'float16'} (see the sketch after this list).
  • Another option would be to specify precision numerically, e.g., precision=1e-2 or precision=1e-6. But this mixes together precision in the significand and the exponent, which misses important nuances like bfloat16 vs float16.
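
One way the platform-keyed idea could be resolved, as a hypothetical sketch (this helper does not exist in JAX):

def resolve_platform_precision(precision, platform):
    # Hypothetical helper: a dict maps platform names to precision
    # strings; non-dict values apply to every platform.
    if isinstance(precision, dict):
        return precision.get(platform)  # None -> platform default
    return precision

resolve_platform_precision({'tpu': 'bfloat16', 'gpu': 'float16'}, 'gpu')  # -> 'float16'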

Given that we want to support platform-specific options, descriptive names seem like the best bet.

The main remaining concern is what to call "3 pass bfloat16" precision on TPUs, which gives roughly 16 bits of significand precision. "intermediate" precision would be OK for TPUs, but seems very vague in general. Maybe bfloat24 or bfloat16_3x would be appropriate? (We could also support bfloat16_6x as a more precise description of float32.)

"3 pass bfloat16" is coincidentally very close to (slightly higher than?) the precision of Nvidia's new "tensorfloat32", so that could also be a good name for this intermediate precision on TPUs.

Users have also requested a way to set a more "global" default precision.

One possible mechanism to do this is via a scope, e.g.:

with jax.precision("highest"):
  ...

I would suggest that it should override only operations with default precision.

> I would suggest that it should override only operations with default precision.

I assume you mean only for precision=None, rather than the confusingly named precision=lax.Precision.DEFAULT (aka bfloat16)?

Yes, I meant None, not what we are currently calling DEFAULT.
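
A minimal sketch of how such a scope could work, assuming a thread-local default that is consulted only when an operation is called with precision=None (the helper names are hypothetical):

import contextlib
import threading

_thread_state = threading.local()

@contextlib.contextmanager
def default_precision(value):
    # Install a scoped default and restore the previous one on exit.
    previous = getattr(_thread_state, 'precision', None)
    _thread_state.precision = value
    try:
        yield
    finally:
        _thread_state.precision = previous

def resolve_precision(precision):
    # Explicit settings always win; only precision=None consults the scope.
    if precision is not None:
        return precision
    return getattr(_thread_state, 'precision', None)

An operation like dot would then call resolve_precision before lowering, so explicitly requested precisions, including lax.Precision.DEFAULT, are left untouched.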
