The following repro script shows the slowness described in the title. This happens on CPU and might also happen on GPU (but I don't have two GPUs to test).
import jax.numpy as np
from jax import pmap, jit, random
# requires the flag: XLA_FLAGS=--xla_force_host_platform_device_count=12
from jax.config import config; config.update('jax_platform_name', 'cpu')
n = 10
@jit
def g(rng, x):
    key, subkey = random.split(rng)
    return key, x + 0.01 * random.normal(subkey)
def f(x):
    rng = random.PRNGKey(0)
    collection = []
    for i in range(n):
        rng, x = g(rng, x)
        collection.append(x)
    return np.stack(collection)
pmap(f)(np.array([0., 1.]))
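For reference, a minimal sketch of running the script above with that flag from a Unix-like shell (repro.py is just a placeholder filename):
XLA_FLAGS="--xla_force_host_platform_device_count=12" python repro.py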
cc @neerajprad
I haven't looked at this in detail yet, but one thing to mention is that on CPU telling XLA to treat each core as its own device sets up the possibility for a lot of thread contention unless you also tell it not to try to multithread within each device (by setting xla_cpu_multi_thread_eigen to be false and intra_op_parallelism_threads to be 1, though we might not even have flags / env vars for those yet).
More generally, manual parallelism on the CPU is pretty much totally untested at this point, both by us and (AFAIK) the XLA team.
That said, we'd like to make it work! So this is a great issue to push us forward. At the very least we can look into the issue and maybe plumb these flags if they're relevant.
manual parallelism on the CPU is pretty much totally untested at this point, both by us and (AFAIK) the XLA team.
I have tested it with a variety of operators (even involving lax.fori_loop) and pmap works pretty well! :)
A strange thing to me is that if I return some arbitrary value (instead of the result from g), then it is pretty fast. So I guess the execution work of pmap is fine and the slowness happens in the result-collection step of pmap. For example,
def f(x):
    rng = random.PRNGKey(0)
    collection = []
    x_init = x
    for i in range(n):
        print(i)
        rng, x = g(rng, x)
        collection.append(x_init + 1)
    return np.stack(collection)
setting xla_cpu_multi_thread_eigen to be false and intra_op_parallelism_threads to be 1
Thanks for your suggestion about those flags! I will try to see if I can set them and whether they work. If this issue only happens on CPU, then it is likely related to these flags.
By the way, can I ask a related question: on which device will print(i) run?
I have tested it with a variety of operators (even involving lax.fori_loop) and pmap works pretty well! :)
Wow, cool! I'm glad to hear that. And I appreciate your positivity :D
Thanks for your suggestion about those flags! I will try to see if I can set them and whether they work.
Just to clarify about those options, they live somewhere in XLA and they might not be easily settable from e.g. your shell (until we plumb them through). You might have to grep the XLA source code for them.
By the way, can I ask a related question: on which device will print(i) run?
The Python code isn't run in parallel under pmap; instead, it's specialized into a jaxpr (i.e. traced on abstract values) and then the function represented by that jaxpr is replicated across cores and executed in parallel on different data. The print call will execute at trace time, so it'll just happen on whatever core Python is running, but that's before any of the parallelism or indeed any of the FLOPs start happening. (As to which core Python decides to run on, and how that relates to what cores are associated with different logical XLA:CPU devices, I'm not sure, but I'd guess it's up to the OS to decide.)
Does that make sense?
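A minimal sketch (hypothetical names, not from the repro above) of that trace-time behavior:
import jax
import jax.numpy as np
from jax import pmap

def h(x):
    # This print runs once, at trace time, with an abstract tracer value;
    # it does not run per device when the compiled computation executes.
    print('tracing h, x =', x)
    return x * 2

# Replicates the traced computation across the local devices; the print
# above does not fire again here.
pmap(h)(np.arange(jax.device_count(), dtype=np.float32))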
Your explanation totally makes sense to me (the printed messages appeared quickly, but it took a while to get the results)! Thanks a lot!
Thanks @mattjj! I can make it fast now with
XLA_FLAGS="--xla_force_host_platform_device_count=12 --xla_cpu_multi_thread_eigen=False --intra_op_parallelism_threads=1"
Currently, it works well for some toy models, so I will close this issue. If I find a performance issue with larger models, I'll open a separate issue. Thanks again!
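For completeness, a hedged sketch of setting the same flags from Python rather than the shell; my assumption is that XLA reads XLA_FLAGS when its backend initializes, so the variable should be set before jax is imported:
import os
# Assumption: must run before importing jax so the flags take effect.
os.environ['XLA_FLAGS'] = ('--xla_force_host_platform_device_count=12 '
                           '--xla_cpu_multi_thread_eigen=False '
                           '--intra_op_parallelism_threads=1')
import jax.numpy as np
from jax import pmap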
Wow, that's great to hear! That was a lucky guess about contention. And I'm really happy those flags seem to be plumbed through to actually work!
By the way, AIUI none of the XLA collective operations (for cross-device communication, like all-reduces) are implemented on XLA:CPU, so a lot of pmap features (including nested pmaps) won't work on XLA:CPU yet. Several are being added to XLA:GPU now though, so there's a lot of forward progress.
If you end up wanting a collective primitive on XLA:CPU, just open an issue and we'll work with the XLA folks on it!
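For context, a hedged sketch of what using a collective under pmap looks like; lax.psum is an all-reduce sum over the named mapped axis, and per the comment above it would not be expected to work on XLA:CPU at the time of this thread:
import jax.numpy as np
from jax import pmap, lax

def normalize(x):
    # All-reduce sum across the devices mapped over axis 'i'.
    total = lax.psum(x, axis_name='i')
    return x / total

# Requires a working cross-device all-reduce on the backend.
pmap(normalize, axis_name='i')(np.array([1., 2.]))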
Thanks @mattjj! Sadly, I don't know yet whether we need that feature. Currently, my use case is to replicate a function (from beginning to end) across devices and collect the results. It seems that pmap does just that job beautifully without extra cost (except for intra-op parallelism, I guess). If I need those advanced features, I'll surely open an issue for it. :)