By default JAX appears to multithread most operations, e.g.
import jax.random as jr
jrkey = jr.PRNGKey(0)
x = jr.normal(jrkey, shape=(50000, 50000))
x @ x
will run across all available cores. This is great in general, and matches NumPy's behavior. But it presents problems when trying to run many small operations in parallel, e.g. running the same script initialized with 4 different random seeds on a 4-core machine.
Is there any option in jax to cap the number of threads that it uses? Something like https://stackoverflow.com/questions/17053671/python-how-do-you-stop-numpy-from-multithreading?
Great question!
I'm not sure, but I think we'd want to set the xla_cpu_multi_thread_eigen (also defined as a flag) and intra_op_parallelism_threads XLA options to be False. We could expose those in JAX somehow...
My Hail Mary attempt of setting the OPENBLAS_NUM_THREADS, MKL_NUM_THREADS, and OMP_NUM_THREADS environment variables didn't work.
Those are the right environment variables to fiddle with for PyTorch, but they won't have an effect in JAX (or any TensorFlow or XLA-based codebase on CPU) because the BLAS library used is Eigen Tensor (not OpenBLAS or MKL) and the threading mechanism used is Eigen threadpools (not OpenMP).
This is now a blocking issue for me. I'd be happy to put together a PR, but I'm not really sure how to start... Also not sure how you guys would like this to be exposed to the user.
For my use case I'd like to run many jax experiments in parallel. Ideally these could all be managed as tasks in a multiprocessing Pool, with each task restricted to some subset of CPUs. When using numpy I accomplish this with
import multiprocessing
import os

os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
...

pool = multiprocessing.Pool()
pool.imap_unordered(my_task, range(num_random_seeds))
But if thread-based throttling isn't possible that's not a dealbreaker. I can always kick off jobs as separate python processes.
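For the separate-processes route, the same XLA_FLAGS setting composes naturally with a shell loop, since each process inherits its own copy of the environment. A minimal sketch (the inline python3 -c body is a stand-in for a real training script):

```shell
# Launch one independent Python process per seed; each inherits the
# single-threaded XLA settings from its own environment.
for seed in 0 1 2 3; do
  XLA_FLAGS="--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1" \
    python3 -c 'import os; print(os.environ["XLA_FLAGS"])' &
done
wait  # block until all four background jobs finish
```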
Did you try setting these environment variables? (My comment didn't explain this very well.)
XLA_FLAGS="--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1" python my_file.py
That seems to work for me in a test, at least for a big matmul.
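One caveat worth flagging (my reading of how XLA flag parsing works, so treat it as an assumption): XLA_FLAGS is read when the XLA backend initializes, so if you set it from inside Python rather than on the command line, do it before the first jax import anywhere in the process:

```python
import os

# Set the flags before jax is imported anywhere in this process; XLA reads
# XLA_FLAGS at backend initialization, so setting it later has no effect.
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=1")

# import jax  # safe to import now; the flags above will be picked up

print(os.environ["XLA_FLAGS"])
```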
Awesome, that seems to do the trick! Thank you so much!
For my own (and others') future googling, my current approach looks like
from multiprocessing import get_context
import os

# Limit ourselves to single-threaded jax/xla operations to avoid thrashing. See
# https://github.com/google/jax/issues/743.
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=1")

def job(random_seed: int):
    # jax jax jax
    ...

if __name__ == "__main__":
    # See https://codewithoutrules.com/2018/09/04/python-multiprocessing/.
    with get_context("spawn").Pool() as pool:
        # imap_unordered is lazy, so consume the iterator to run every job.
        list(pool.imap_unordered(job, range(100)))
There may be a way better way, but it seems to work 🤷♀️
@mattjj I don't see intra_op_parallelism_threads among the XLA_FLAGS options. Also, I still see the multi-threaded behaviour when trying your suggested combination of flags. Any chance it would be possible to set the intra-op thread limit specifically somewhere?