Jax: GPU determinism flag

Created on 3 Apr 2019 · 4 comments · Source: google/jax

Some ops seem to be non-deterministic on GPU. This can be annoying when debugging instability issues, and there are also some niche applications where deterministic functions are absolutely necessary.
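A quick way to observe this kind of nondeterminism is to run the same function on the same input several times and compare the outputs bitwise. A minimal sketch (using NumPy as a stand-in so it runs anywhere; the hypothetical `is_bitwise_deterministic` helper is not part of JAX, and on GPU the analogous `jnp` reduction may return different bits across runs because floating-point summation order can vary):

```python
import numpy as np

def is_bitwise_deterministic(fn, x, trials=3):
    """Run fn on the same input several times and check that all
    outputs are bitwise identical to the first one."""
    reference = fn(x)
    return all(np.array_equal(fn(x), reference) for _ in range(trials))

# A plain NumPy reduction on CPU is deterministic, so this prints True;
# a GPU reduction with nondeterministic accumulation order might not be.
x = np.random.RandomState(0).standard_normal(10000).astype(np.float32)
print(is_bitwise_deterministic(lambda a: a.sum(), x))  # True on CPU
```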

For ops where determinism affects performance (and therefore isn't a desirable default), it would be great to have a flag that allows users to enable determinism.

This issue has recently been addressed in TF; for the cuDNN work, see here and here. I can't find the relevant PR for reductions, but I believe they were made deterministic recently. For a list of things that may still be nondeterministic in TF, see here.

Labels: documentation, question

All 4 comments

Great idea. There's a chance TF_CUDNN_DETERMINISTIC will already work with XLA (and hence JAX) too, though @jlebar could provide more insight.

TF_CUDNN_DETERMINISTIC should work with XLA:GPU.

XLA:GPU reductions are nondeterministic, though. Changing this would be a lot of work. If you all wanted us to prioritize it, we should talk to understand the costs/benefits.

@j-towns since we now know we're automatically at parity with TF/XLA cuDNN determinism using the same flag, I think it makes sense to close this issue, though please re-open if we should follow up more here. As Justin points out, getting total determinism is a bigger ask, and (imo) might be better discussed on the TF issue tracker, where a larger audience can weigh in.

It seems that setting the environment variable XLA_FLAGS=--xla_gpu_deterministic_reductions is also a good idea. I'm not sure whether TF_CUDNN_DETERMINISTIC is also needed. See also #4823 and google/flax#33.
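In practice, these variables can be set from Python, but only before JAX (and its XLA backend) is first imported, or they may have no effect. A hedged sketch combining both flags mentioned in this thread:

```python
import os

# Set determinism-related flags before importing jax.
# --xla_gpu_deterministic_reductions requests deterministic XLA:GPU
# reductions; TF_CUDNN_DETERMINISTIC=1 requests deterministic cuDNN
# algorithm selection. Whether both are needed together is unclear.
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_reductions"
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"

# import jax  # only import after the flags above are in place
```

Alternatively, the same variables can be exported in the shell before launching the Python process.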
