Operations that do matrix multiplication in JAX accept a precision argument for controlling precision when executed on TPUs.
The current API seems non-ideal to me:
- Precision is specified with enum values (e.g., np.dot(x, y, precision=lax.Precision.HIGHEST)). This is a little cumbersome and inconsistent with most NumPy/SciPy APIs, which use strings (e.g., np.dot(x, y, precision='highest')).
- There is no way to specify precision when using the @ infix operator.
- On TPUs, "HIGH" precision corresponds to 3 passes of bfloat16 and "HIGHEST" precision corresponds to 6 passes, which is effectively full float32 (dropping denormals), as explained in this Intel paper.
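For concreteness, here is the current enum-based call next to the string-based spelling discussed above. The string form is the proposal, not necessarily something a given JAX version accepts:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((8, 8), dtype=jnp.float32)
y = jnp.ones((8, 8), dtype=jnp.float32)

# Current API: precision is an enum value.
z_enum = jnp.dot(x, y, precision=lax.Precision.HIGHEST)

# Proposed spelling: plain strings, matching NumPy/SciPy conventions.
z_str = jnp.dot(x, y, precision='highest')
```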
With that in mind, and considering that ideally we would retain some flexibility for alternative matmul optimizations that might appear on other platforms, what more descriptive naming scheme makes sense for values of the precision argument?
Some ideas:
1. 'low', 'high', 'highest': description of precision level
2. 'fastest', 'fast', 'slow': description of speed
3. 'fastest', 'fast', 'accurate': mixed description, using only positive words
4. 'bfloat16', 'float24', 'float32': rough precision of the underlying arithmetic (but what is "float24"??)

I think I lean towards option 3?
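To make the idea concrete, here is a rough sketch (purely illustrative, not existing JAX code) of how option 3's strings could map onto the existing enum, so the string spelling could be layered on top without changing anything in lowering. The mapping and helper name are assumptions:

```python
from jax import lax

# Hypothetical mapping for option 3; TPU pass counts per the discussion above.
_PRECISION_NAMES = {
    'fastest': lax.Precision.DEFAULT,    # 1 pass of bfloat16 on TPU
    'fast': lax.Precision.HIGH,          # 3 passes of bfloat16
    'accurate': lax.Precision.HIGHEST,   # 6 passes, effectively float32
}

def canonicalize_precision(precision):
    """Accept either a string name or an existing lax.Precision value."""
    if isinstance(precision, str):
        return _PRECISION_NAMES[precision]
    return precision
```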
Notes from offline discussion:
- These values are really a lower bound on precision, so a name like min_precision would be more appropriate.
- Precision is also limited by dtype: XLA will use the lowest precision on bfloat16 data regardless of the precision option.
- We may want platform-specific options, e.g., precision={'tpu': 'bfloat16', 'gpu': 'float16'}.
- Numeric tolerances are another possibility, e.g., precision=1e-2 or precision=1e-6. But this mixes together precision in the significand and the exponent, which misses important nuances like bfloat16 vs float16.
- Given that we want to support platform-specific options, descriptive names seem like the best bet.
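As a sketch of the platform-keyed idea (everything below is hypothetical, including the helper name; JAX does not expose this today), a precision dict could be resolved against the active backend, falling back to the default when no entry matches:

```python
import jax

def resolve_precision(precision):
    """Hypothetical resolution of a platform-keyed precision dict."""
    if isinstance(precision, dict):
        platform = jax.default_backend()   # e.g. 'tpu', 'gpu', 'cpu'
        return precision.get(platform)     # None -> fall back to default precision
    return precision

# Proposed usage from the notes above:
# jnp.dot(x, y, precision={'tpu': 'bfloat16', 'gpu': 'float16'})
```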
The main remaining concern is what to call "3 pass bfloat16" precision on TPUs, which gives roughly 16 bits of precision in the significand. "intermediate" precision would be OK for TPUs, but seems very vague in general. Maybe bfloat24 or bfloat16_3x would be appropriate? (We could also support bfloat16_6x as a more precise description of float32.)
“3 pass bfloat16” is coincidentally very close to (slightly higher than?) the precision of Nvidia’s new “tensorfloat32”. So that could also be a good name for this intermediate precision on TPUs.
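To clarify what "3 pass bfloat16" buys numerically, here is an illustrative sketch (my own, not how XLA actually lowers the op) of splitting float32 operands into high and low bfloat16 pieces and summing three partial products, which is why the result recovers roughly 16 significand bits:

```python
import jax.numpy as jnp

def split_bf16(x):
    # High part keeps the top ~8 significand bits; low part keeps the next ~8.
    hi = x.astype(jnp.bfloat16).astype(jnp.float32)
    lo = (x - hi).astype(jnp.bfloat16).astype(jnp.float32)
    return hi, lo

def matmul_bf16_3pass(x, y):
    x_hi, x_lo = split_bf16(x)
    y_hi, y_lo = split_bf16(y)
    # Three passes over bfloat16-representable operands; the tiny lo*lo term is dropped.
    return x_hi @ y_hi + x_hi @ y_lo + x_lo @ y_hi
```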
Users have also requested a way to set a more "global" default precision.
One possible mechanism to do this is via a scope, e.g.:
```python
with jax.precision("highest"):
    ...
```
I would suggest that it should override only operations with default precision.
I assume you mean only for precision=None, rather than the confusingly named precision=lax.Precision.DEFAULT (aka bfloat16)?
Yes, I meant None, not what we are currently calling DEFAULT.
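A minimal sketch of how such a scope could behave under the semantics agreed above (only calls that leave precision=None pick up the scoped value). Everything here, including the precision name, is the proposal rather than an existing API:

```python
import contextlib
import threading

_scoped = threading.local()

@contextlib.contextmanager
def precision(value):
    """Proposed jax.precision scope: sets a thread-local default."""
    prev = getattr(_scoped, 'value', None)
    _scoped.value = value
    try:
        yield
    finally:
        _scoped.value = prev

def lookup_precision(precision_arg):
    # Explicit settings (including lax.Precision.DEFAULT) win; only None
    # defers to the enclosing scope.
    if precision_arg is None:
        return getattr(_scoped, 'value', None)
    return precision_arg
```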