I am assembling a single-core CPU performance benchmark for various HPC libraries from the modern Python ecosystem. I would like to include JAX, and first results seem very promising, but I'm failing to restrict it to a single thread. As far as I can tell, there is no corresponding setting in `jax.config`, and it doesn't respect any of the popular environment variables (e.g. `OMP_NUM_THREADS`).
I installed JAX and jaxlib from PyPI on OSX.
Is there any way to pull this off?
Hey @dionhaefner ! I had a similar issue: https://github.com/google/jax/issues/743 HTH
We should document the solution we figured out in #743.
Hmm, unfortunately that doesn't seem to do it.
I am using this script for testing:
```python
import jax


@jax.jit
def bench(sa, ct, p):
    return sa + ct * p


def run(sa, ct, p):
    return bench(sa, ct, p).block_until_ready()


if __name__ == '__main__':
    import numpy

    size = 10_000_000
    s = numpy.random.rand(size)
    t = numpy.random.rand(size)
    p = numpy.random.rand(size)

    for _ in range(100):
        run(s, t, p)
```
Running this gives:
XLA_FLAGS="--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1 inter_op_parallelism_threads=1" time python bench_jax.py
<snip>/lib/python3.7/site-packages/jax/lib/xla_bridge.py:115: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
6.64 real 10.59 user 1.98 sys
Note that real time is lower than user time. Looking at the resource monitor, I can see ~180% CPU usage.
I'm having the same issue trying to run jax on an HPC with multiple CPUs. The solution from #743 doesn't work for me either.
I tried setting
```python
import os

os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ["NUM_INTER_THREADS"] = "1"
os.environ["NUM_INTRA_THREADS"] = "1"
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=1")
```
I'm getting `RuntimeError: Resource temporarily unavailable` on the HPC, and on my laptop I can see that the code uses more than one CPU.
@dionhaefner did you manage to get around this problem?
And here's the jax mnist example, where I tried to disable multithreading by putting the above flags at the beginning of the Python script, but it doesn't work :(
Maybe you could try the threadpoolctl library to set the number of threads for all backends (BLAS, OpenMP, MKL, and the like; I don't mean JAX backends):
```python
from threadpoolctl import threadpool_limits

[...]

with threadpool_limits(limits=1):
    for _ in range(10000):
        run(s, t, p)
```
I ended up setting processor affinity (only works on Unix systems though, and might require special permissions):
```
$ taskset -c 0 python myscript.py
```
But won't that limit the script to one specific CPU? And then you have to manually manage which instance runs on which CPU.
Yes. It's just a workaround, not really a solution.
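If you want the same workaround programmatically, here's a minimal sketch using `os.sched_setaffinity` (an assumption on my part that this fits your setup; it's Linux-only, so it won't help on OSX, and the core index `0` is just an example):

```python
import os

# Pin the current process (pid 0) to a single core *before* importing
# jax, so its threadpool is sized from the affinity mask.
# Linux-only: os.sched_setaffinity is not available on macOS.
os.sched_setaffinity(0, {0})

import jax
```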
I wonder if it's jax or xla (or some other internal package) that spawns the extra threads.
In any case, it would be very useful to have a clear example (maybe even as part of the official JAX examples) of how to get this under control; otherwise, allocating resources on HPC clusters becomes a tedious trial-and-error task.
Afaik the threads are spawned by the CPU backends, namely OpenMP, BLAS, MKL, and the like. Getting that under control is precisely the purpose of threadpoolctl. If I understand your problem/use case correctly, it's exactly what you need!
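To see which threadpools are actually loaded in your process, threadpoolctl can also enumerate them; a minimal sketch (the exact entries depend on which libraries your interpreter has loaded):

```python
from pprint import pprint

from threadpoolctl import threadpool_info

# Lists every BLAS/OpenMP threadpool found in the current process,
# with its library path and current number of threads.
pprint(threadpool_info())
```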
I guess I'm doing something wrong, but threadpoolctl doesn't seem to work for me on OSX. Could you take a quick look at the updated script? This code uses 22 threads and up to 500% CPU (as reported by top).
I believe the code XLA generates on CPU doesn't use MKL, OpenBLAS, or the system BLAS, so environment variables related to those libraries are unlikely to have an effect; for BLAS and related operations (e.g. convolutions), it uses an embedded copy of Eigen (really Eigen's Tensor sub-library) and for everything else it generates its own code with LLVM.
`--xla_cpu_multi_thread_eigen=false` should ensure that the BLAS library calls use Eigen in single-threaded mode, but non-BLAS code seems to be multi-threaded based on the XLA configuration option `intra_op_parallelism_threads` (which might not be wired through as a flag, or at least I can't find it?)
CC @hawkinsp
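For what it's worth, here's a minimal sketch of how I'd pass those flags (the key assumption being that `XLA_FLAGS` must be set before jax is first imported; whether `intra_op_parallelism_threads` actually takes effect is exactly what's in question here):

```python
import os

# XLA reads XLA_FLAGS when the backend is initialized, so this must
# happen before the first `import jax`.
os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=false "
    "intra_op_parallelism_threads=1"
)

import jax
import jax.numpy as jnp

print(jax.jit(lambda x: x * 2)(jnp.arange(4)))
```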
Oh no, I think I mixed up my anecdotal evidence from numba with this issue. I'm sorry @mgbukov that I led you down the wrong track.
@clemisch no worries, ideas are always welcome!
@jekbradbury I did set both `--xla_cpu_multi_thread_eigen=false` and `intra_op_parallelism_threads` in mnist_single_thread.txt above, but there's still some residual multithreading. Are you suggesting that there's something wrong with the flags or the way I set them in the script, or that there are more/other flags I need to set?
It's going to be almost impossible to completely eliminate threading from JAX; the runtime internally uses multiple threads, e.g., to overlap the Python interpreter with XLA compilation and interpretation.
That said, we can probably avoid threading in the compute-intensive XLA-generated code; XLA itself has support for this, although I'm unsure if it's plumbed through to the API surface in a usable way.
A quick note: setting the task affinity map is the correct way to limit JAX's CPU usage at the moment. JAX sizes its main threadpool using this logic:
https://github.com/tensorflow/tensorflow/blob/4b2cb67756009dda843c6b56a8b320c8a54373e0/tensorflow/core/platform/default/port.cc#L67
If launching from `mpirun`, `mpirun` knows how to set task affinities correctly.
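For example, with Open MPI the binding would look something like this (a sketch; the exact spelling of the binding options differs between MPI implementations):

```
$ mpirun --bind-to core -np 4 python myscript.py
```

Each rank then sees only its bound core(s), and JAX sizes its threadpool from that affinity mask.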
I have a use case where I want to run a JIT-compiled function on only a single CPU, using separate threading to handle parallelism (with Dask). At the moment, it looks like there's no way to do this?
Similar to @shoyer, I wish to orchestrate many multithreaded XLA ops on CPU, except using pmap instead of Dask. Each XLA op is a jitted function containing some BLAS and non-BLAS code. I'd like the BLAS code to be limited to 1 or 2 threads per job.
I've noticed some behaviour that doesn't quite make sense to me:
```python
import os

ncpu = 2
os.environ['XLA_FLAGS'] = f"--xla_force_host_platform_device_count={ncpu}"
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=1")

from jax import local_device_count

print(local_device_count())  # 1 instead of 2
```
The desire is that one can create `ncpu` virtual devices and turn off intra-op parallelism. However, the method suggested in this issue seems to do neither of these things.
In your example you immediately overwrite the `XLA_FLAGS` env var. Did you mean to append instead?
Ah that's right! Thanks @dionhaefner
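For reference, here's a corrected version of the snippet above that sets everything in a single `XLA_FLAGS` assignment (a sketch; whether `intra_op_parallelism_threads` actually takes effect is still the open question in this thread):

```python
import os

ncpu = 2
# Set all XLA options in one assignment so nothing gets overwritten;
# this must happen before jax is first imported.
os.environ["XLA_FLAGS"] = (
    f"--xla_force_host_platform_device_count={ncpu} "
    "--xla_cpu_multi_thread_eigen=false "
    "intra_op_parallelism_threads=1"
)

from jax import local_device_count

print(local_device_count())  # now reports 2
```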