Jax: Limit number of threads when running on CPU

Created on 21 Oct 2019 · 22 Comments · Source: google/jax

I am assembling a single-core CPU performance benchmark for various HPC libraries from the modern Python ecosystem. I would like to include JAX, and first results seem very promising, but I'm failing to restrict it to a single thread. As far as I can tell, there is no corresponding setting in jax.config, and it doesn't listen to any of the popular flags (e.g. OMP_NUM_THREADS).

I installed JAX and jaxlib from PyPI on OSX.

Is there any way to pull this off?

documentation

All 22 comments

Hey @dionhaefner ! I had a similar issue: https://github.com/google/jax/issues/743 HTH

We should document the solution we figured out in #743.

Hmm, unfortunately that doesn't seem to do it.

I am using this script for testing:

import jax


@jax.jit
def bench(sa, ct, p):
    return sa + ct * p


def run(sa, ct, p):
    return bench(sa, ct, p).block_until_ready()


if __name__ == '__main__':
    import numpy
    size = 10_000_000
    s = numpy.random.rand(size)
    t = numpy.random.rand(size)
    p = numpy.random.rand(size)

    for _ in range(100):
        run(s, t, p)

Running this gives:

XLA_FLAGS="--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1 inter_op_parallelism_threads=1" time python bench_jax.py
<snip>/lib/python3.7/site-packages/jax/lib/xla_bridge.py:115: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
        6.64 real        10.59 user         1.98 sys

Note the lower real than user time. Looking at the resource monitor, I can see ~180% CPU usage.
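The same real-versus-user comparison can be made from inside Python. This is a minimal sketch using only the standard library. The helper name `cpu_to_wall_ratio` is made up for illustration and is not part of JAX. Since `time.process_time()` sums CPU time over all threads of the process, a ratio well above 1 suggests that more than one core was busy.

```python
import time


def cpu_to_wall_ratio(fn, *args, **kwargs):
    """Run fn once and return (process CPU time) / (wall-clock time).

    Hypothetical helper, not part of JAX.
    """
    wall_start = time.perf_counter()
    cpu_start = time.process_time()
    fn(*args, **kwargs)
    cpu = time.process_time() - cpu_start
    wall = time.perf_counter() - wall_start
    return cpu / wall


if __name__ == "__main__":
    # A pure-Python busy loop is single-threaded, so the ratio should not
    # noticeably exceed 1. Replace the lambda with the function under test.
    print(cpu_to_wall_ratio(lambda: sum(i * i for i in range(2_000_000))))
```

Passing the `run` function from the script above, with its arguments, would give a rough in-process equivalent of the `time` output.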

I'm having the same issue trying to run jax on an HPC with multiple CPUs. The solution from #743 doesn't work for me either.

I tried setting

os.environ['MKL_NUM_THREADS']='1' 
os.environ['OPENBLAS_NUM_THREADS']='1'

os.environ["NUM_INTER_THREADS"]="1"
os.environ["NUM_INTRA_THREADS"]="1"

os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=1")

I'm getting RuntimeError: Resource temporarily unavailable on the HPC, and on my laptop I can see that the code uses more than one CPU.

@dionhaefner did you manage to get around this problem?

And here's the jax mnist example, where I tried to disable multithreading by putting the above flags at the beginning of the Python script, but it doesn't work :(

mnist_single_thread.txt

Maybe you could try the threadpoolctl library to set the number of threads for all backends. (BLAS, openmp, MKL, ... that is. I don't mean JAX backends.)

```python
from threadpoolctl import threadpool_limits

[...]

with threadpool_limits(limits=1):
    for _ in range(10000):
        run(s, t, p)
```

I ended up setting processor affinity (only works on Unix systems though, and might require special permissions):

$ taskset -c 0 python myscript.py

But won't that limit the script to one specific CPU? And you have to manage manually which instance runs on which CPU.

Yes. It's just a workaround, not really a solution.

I wonder if it's jax or xla (or some other internal package) which spawns the extra threads.

In any case it'd be very useful to have a clear example (maybe even as part of the official jax list of examples) on how to get this under control; otherwise allocating resources on HPC clusters becomes a tedious trial-and-error task.

Afaik the threads are spawned by the CPU backend(s), namely openmp, BLAS, MKL and the like. To get that under control is specifically the purpose of threadpoolctl. If I understand your problem/usecase correctly, it's exactly what you need!

I guess I'm doing something wrong, but threadpoolctl doesn't seem to work for me on osx. Could you take a quick look at the updated

mnist_single_thread.txt

This code uses 22 threads and up to 500 %CPU (as given in top).

I believe the code XLA generates on CPU doesn't use MKL, OpenBLAS, or the system BLAS, so environment variables related to those libraries are unlikely to have an effect; for BLAS and related operations (e.g. convolutions), it uses an embedded copy of Eigen (really Eigen's Tensor sub-library) and for everything else it generates its own code with LLVM.

--xla_cpu_multi_thread_eigen=false should ensure that the BLAS library calls use Eigen in single-threaded mode, but non-BLAS code seems to be multi-threaded based on the XLA configuration option intra_op_parallelism_threads (which might not be wired through as a flag, or at least I can't find it?)

CC @hawkinsp

Oh no, I think I mixed up my anecdotal evidence from numba with this issue. I'm sorry mgbukov that I led you on the wrong track.

@clemisch no worries, ideas are always welcome!

@jekbradbury I did set both --xla_cpu_multi_thread_eigen=false and intra_op_parallelism_threads in mnist_single_thread.txt above, but there's still some residual multithreading. Are you suggesting that there's something wrong with the flags/the way I set them in the script, or that there are more/other flags I need to set?

It's going to be almost impossible to completely eliminate threading from JAX; the runtime internally uses multiple threads, e.g., to overlap the Python interpreter with XLA compilation and interpretation.

That said, we can probably avoid threading in the compute-intensive XLA-generated code; XLA itself has support for this although I'm unsure if it's plumbed through to the API surface in a usable way.
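To see how many threads the runtime actually creates, the OS-level thread count can be compared with what Python itself reports. This is a rough, Linux-only sketch that reads `/proc/self/status`. The helper name `native_thread_count` is hypothetical, and on macOS or Windows it simply returns `None`.

```python
import threading


def native_thread_count():
    """Return the OS-level thread count of this process, or None.

    Linux-only sketch: reads the "Threads:" field of /proc/self/status.
    """
    try:
        with open("/proc/self/status") as status:
            for line in status:
                if line.startswith("Threads:"):
                    return int(line.split()[1])
    except OSError:
        pass
    return None


if __name__ == "__main__":
    # threading.active_count() only sees threads started through Python;
    # threads created by native libraries appear only in the OS-level count.
    print("Python threads:", threading.active_count())
    print("OS threads:", native_thread_count())
```

Calling the helper before and after the first jitted function runs should show how many extra native threads appear, even when the compute-heavy code itself is single-threaded.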

A quick note: setting the task affinity map is the correct way to limit JAX's CPU usage at the moment. JAX sizes its main threadpool using this logic:
https://github.com/tensorflow/tensorflow/blob/4b2cb67756009dda843c6b56a8b320c8a54373e0/tensorflow/core/platform/default/port.cc#L67

If launching from mpirun, mpirun knows how to set task affinities correctly.
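The affinity approach can also be applied from within the script instead of through `taskset`. This is a minimal sketch, assuming Linux, where `os.sched_setaffinity` is available. The helper name `pin_to_cpus` is made up for illustration. That JAX must be imported only after pinning is an assumption drawn from the comment above: the thread pool is sized from the affinity mask, so the mask has to be set before the pool is created.

```python
import os


def pin_to_cpus(cpus):
    """Restrict the current process to the given CPU ids (Linux only).

    Hypothetical helper; returns the resulting affinity set, or None on
    platforms without os.sched_setaffinity (e.g. macOS, Windows).
    """
    if not hasattr(os, "sched_setaffinity"):
        return None
    os.sched_setaffinity(0, cpus)  # pid 0 means "this process"
    return os.sched_getaffinity(0)


if __name__ == "__main__":
    # Pin to one CPU that the process is already allowed to use.
    if hasattr(os, "sched_getaffinity"):
        first_cpu = min(os.sched_getaffinity(0))
        print(pin_to_cpus({first_cpu}))
    # Import jax only after this point (assumption: its thread pool is
    # sized from the affinity mask in effect when the backend starts).
```

Unlike `taskset -c 0`, picking a CPU from the current affinity set keeps the script inside whatever mask a scheduler or `mpirun` has already assigned.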

I have a use-case where I want to run a JIT compiled function with only a single CPU, using separate threading to handle parallelism (with Dask). At the moment, it looks like there's no way to do this?

Similar to @shoyer I wish to orchestrate many multithreading XLA ops on CPU, except using pmap instead of dask. Each XLA op is a jitted function containing some BLAS and non-BLAS code. I'd like the BLAS code to be limited to 1 or 2 threads per job.

I've noted this behaviour that doesn't exactly make sense to me.

    ncpu=2
    os.environ['XLA_FLAGS'] = f"--xla_force_host_platform_device_count={ncpu}"
    os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                               "intra_op_parallelism_threads=1")
    from jax import local_device_count
    print(local_device_count()) #1 instead of 2

The desire is that one can create ncpu virtual devices and turn off the intra_op parallelism. However, the suggested method in this issue seems to do neither of these things.

In your example you immediately overwrite the XLA_FLAGS env var. Did you mean to append instead?

Ah that's right! Thanks @dionhaefner
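For reference, appending rather than overwriting could look like the sketch below. The helper `append_xla_flags` is hypothetical and not part of JAX. Only the two `--xla_...` flags quoted in this thread are used. `intra_op_parallelism_threads` is left out because, as noted earlier in the thread, it may not be wired through as a flag.

```python
import os


def append_xla_flags(*flags):
    """Append flags to XLA_FLAGS instead of overwriting it.

    Hypothetical helper, not part of JAX. Returns the new value.
    """
    existing = os.environ.get("XLA_FLAGS", "").split()
    os.environ["XLA_FLAGS"] = " ".join(existing + list(flags))
    return os.environ["XLA_FLAGS"]


if __name__ == "__main__":
    ncpu = 2
    append_xla_flags(f"--xla_force_host_platform_device_count={ncpu}")
    append_xla_flags("--xla_cpu_multi_thread_eigen=false")
    print(os.environ["XLA_FLAGS"])
    # Import jax only after the flags are set, then check
    # jax.local_device_count().
```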
