@ziatdinovmax observed that memory keeps increasing until the process crashes for some NumPyro models (specifically, https://github.com/pyro-ppl/numpyro/issues/447) when they are run on the GPU. Here is an isolated repro, which only uses JAX ops:
```python
from jax import random, lax

def model(key):
    D_X, D_H, D_Y = 3, 5, 1
    # sample first layer (we put unit normal priors on all weights)
    key, subkey = random.split(key)
    w1 = random.normal(subkey, shape=(D_X, D_H))
    key, subkey = random.split(key)
    w1 = random.uniform(subkey, shape=w1.shape, minval=-2, maxval=2)
    # sample second layer
    key, subkey = random.split(key)
    w2 = random.normal(subkey, shape=(D_H, D_H))
    key, subkey = random.split(key)
    w2 = random.uniform(subkey, shape=w2.shape, minval=-2, maxval=2)
    # sample final layer of weights and neural network output
    key, subkey = random.split(key)
    w3 = random.normal(subkey, shape=(D_H, D_Y))
    key, subkey = random.split(key)
    w3 = random.uniform(subkey, shape=w3.shape, minval=-2, maxval=2)
    # we put a prior on the observation noise
    key, subkey = random.split(key)
    prec_obs = random.normal(subkey)
    key, subkey = random.split(key)
    prec_obs = random.uniform(subkey, shape=prec_obs.shape, minval=-2, maxval=2)
    return w1, w2, w3, prec_obs

def cond_fn(state):
    return state[0] < 100

def body_fn(state):
    i, key, _ = state
    key, subkey = random.split(key)
    return i + 1, key, model(subkey)

init_state = (0, random.PRNGKey(0), model(random.PRNGKey(2019)))
i, key, params = lax.while_loop(cond_fn, body_fn, init_state)
```
Interestingly, if I use a rolled loop in threefry_2x32, the issue does not happen. Could that be a hint?
cc @mattjj :)
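For context, here is a minimal sketch of what a rolled versus unrolled round loop looks like. This is not JAX's actual threefry_2x32 code; `apply_round` is a hypothetical stand-in for the per-round mixing function, and the carry is simplified.

```python
from jax import lax

def rounds_unrolled(x, apply_round):
    # Python-level loop: XLA receives five separate copies of the round body
    for r in range(5):
        x = apply_round(x, r)
    return x

def rounds_rolled(x, apply_round):
    # lax.fori_loop: XLA receives a single loop body plus a trip count,
    # which keeps the emitted HLO much smaller
    return lax.fori_loop(0, 5, lambda r, x: apply_round(x, r), x)
```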
We also observed a memory leak in https://github.com/pyro-ppl/brmp/issues/67 on the CPU, which may be related. When I turn on jax.core.check_leaks=True, I get the following stacktrace, which also seems to point to line 150 in threefry_2x32.
```
Traceback (most recent call last):
  File "/Users/npradhan/workspace/pyro_dev/brmp/tests/ex2.py", line 38, in <module>
    mcmc.run(random.PRNGKey(i), X, X)
  File "/Users/npradhan/workspace/pyro_dev/numpyro/numpyro/infer/mcmc.py", line 639, in run
    args, kwargs)
  File "/Users/npradhan/workspace/pyro_dev/numpyro/numpyro/infer/mcmc.py", line 582, in _single_chain_mcmc
    model_args=args, model_kwargs=kwargs)
  File "/Users/npradhan/workspace/pyro_dev/numpyro/numpyro/infer/mcmc.py", line 404, in init
    rng_key, rng_key_init_model = random.split(rng_key)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/random.py", line 194, in split
    return _split(key, num)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/api.py", line 149, in f_jitted
    out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/core.py", line 591, in call_bind
    outs = primitive.impl(f, *args, **params)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/interpreters/xla.py", line 347, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, *map(abstractify, args))
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/linear_util.py", line 209, in memoized_fun
    ans = call(fun, *args)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/interpreters/xla.py", line 361, in _xla_callable
    jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/linear_util.py", line 153, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/random.py", line 199, in _split
    return lax.reshape(threefry_2x32(key, counts), (num, 2))
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/api.py", line 149, in f_jitted
    out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/core.py", line 594, in call_bind
    outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 111, in process_call
    out_flat = call_primitive.bind(fun, *in_consts, **params)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/core.py", line 591, in call_bind
    outs = primitive.impl(f, *args, **params)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/interpreters/xla.py", line 347, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, *map(abstractify, args))
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/linear_util.py", line 209, in memoized_fun
    ans = call(fun, *args)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/interpreters/xla.py", line 361, in _xla_callable
    jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/linear_util.py", line 153, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/random.py", line 150, in threefry_2x32
    x, _, _ = lax.fori_loop(0, 5, step, (x, rotate_list(ks), rotations))
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 132, in fori_loop
Trace stack
[] ---
[' MasterTrace(-1,JaxprTrace)\n', ' MasterTrace(-2,JaxprTrace)\n']
    (lower, upper, init_val))
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 174, in while_loop
    cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 59, in _initial_style_jaxpr
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 345, in trace_to_jaxpr
    del master
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/contextlib.py", line 119, in __exit__
    next(self.gen)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/core.py", line 460, in new_master
    raise Exception('Leaked trace {}'.format(t()))
Exception: Leaked trace MasterTrace(0,JaxprTrace)
```
@fehiepsi I think this is simply that XLA is taking a lot of memory to compile that computation. I raised an issue with the XLA team, although I suspect we are going to need to move the PRNG computation out of XLA and into a custom CUDA kernel as a short-term fix. PRNG compilation time/space is a constant pain point.
@neerajprad The experimental check_leaks support has bitrotted, so I wouldn't trust the results. But if you can provide a small self-contained reproduction of the CPU memory problem, we can look into it.
Thanks, Peter!
@hawkinsp - Regarding the second case, @fehiepsi was able to create a small JAX-only example that I think captures at least part of the issue. Here is the code snippet:
```python
import gc
from collections import Counter

import jax; jax.config.update('jax_platform_name', 'cpu')
from jax import numpy as np
from jax import lax, random

class ABC:
    def body_fn(self, i, key):
        key, _ = random.split(key)
        return key

for i in range(3):
    abc = ABC()
    value = lax.fori_loop(0, 10, abc.body_fn, random.PRNGKey(0))
    print("\nGC OBJECTS:")
    cnt = Counter()
    # force gc collection
    gc.collect()
    for x in gc.get_objects():
        if isinstance(x, list):
            if len(x) > 1:
                cnt[type(x[0])] += 1
    print(cnt.most_common(10))
```
We find that the count of jax.core.Var objects keeps increasing per iteration. If we pull out the body_fn function and use it directly, the count of all object types remains the same per iteration, as expected. Is this difference in behavior expected?
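For reference, here is a minimal sketch of the variant we compared against, with body_fn pulled out to module level (the gc-counting part is omitted):

```python
from jax import lax, random

def body_fn(i, key):
    # same body as ABC.body_fn above, but defined once as a plain function
    key, _ = random.split(key)
    return key

for i in range(3):
    # the same function object is passed to fori_loop on every iteration
    value = lax.fori_loop(0, 10, body_fn, random.PRNGKey(0))
```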
@neerajprad thanks for finding this. Coincidentally I think someone brought this up in our Google chat room just a few hours ago.
I haven't thought about it yet but the first suspect that jumps to my mind is this line. After our detuplification rewrite in #1224, we create cycles in the JAX tracer graph (which ultimately gets serialized into a jaxpr). Maybe we need a weakref there...
EDIT: and thanks for the repro script, that's incredibly helpful.
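For readers unfamiliar with the weakref idea, here is a generic Python sketch (not JAX's tracer code) of how holding a back-reference weakly keeps a parent/child pair from keeping each other alive:

```python
import weakref

class Child:
    def __init__(self, parent):
        # hold the back-reference weakly so parent -> child -> parent
        # is not a strong reference cycle
        self._parent_ref = weakref.ref(parent)

    @property
    def parent(self):
        # returns None once the parent has been garbage collected
        return self._parent_ref()

class Parent:
    def __init__(self):
        self.child = Child(self)
```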
Actually, sorry, I read this too quickly and got confused.
The reason the count of Vars is increasing here is just that:

1. Var objects appear in jaxprs,
2. jaxprs get baked into the while_p primitive (underlying fori_loop), and
3. we cache on while_p so as not to recompile this lax.fori_loop many times.

So the script here is actually just showing us that the cache is working.
I checked that by just removing the caching from xla.xla_primitive_callable and from lax_control_flow._initial_style_jaxpr. Then the counts didn't increase.
There is a leak in the JaxprTracer graph though, exactly where Dougal thought there might be.
I'll follow up with a PR.
@mattjj - Thanks for looking into this! That makes sense. Any idea why this doesn't happen when body_fn is defined outside (as a regular function)?
We are trying to debug https://github.com/pyro-ppl/brmp/issues/67, and isolated it to this snippet (though it could be some other issue, or just be something odd that we are doing in numpyro). Your comments have been very helpful, and I will try to debug further based on what you said.
> There is a leak in the JaxprTracer graph though, exactly where Dougal thought there might be.
Looking forward to your PR. I'll run it against the original issue from @null-a, and see if that helps with that.
I can confirm that we do not see this issue with the original example if we turn off all caching, so this just seems to be expected caching behavior. The difference in behavior that I raised earlier is simply because we are creating new instances in the loop, so the self arg is different, which creates a new entry in the cache. The elements in the cache won't be removed until the cache reaches its max size.
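To make that concrete, here is a minimal sketch of the kind of change on our side that avoids the cache growth: create the instance once outside the loop, so the bound method (and hence the cache key) stays the same across iterations.

```python
from jax import lax, random

class ABC:
    def body_fn(self, i, key):
        key, _ = random.split(key)
        return key

abc = ABC()  # create the instance once, outside the loop
for i in range(3):
    # abc.body_fn is bound to the same instance each time, so the cached
    # computation for this fori_loop is reused instead of gaining a new entry
    value = lax.fori_loop(0, 10, abc.body_fn, random.PRNGKey(0))
```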
Sorry about the red herring here, @mattjj. Please feel free to close this issue if the other leak issue is being tracked elsewhere.
No worries, please always eagerly report _possible_ bugs, so we can work through them just like this! (We love when you guys raise issues.)
Also, let us know if the caches are growing too large in your use cases. We can make the caching logic smarter.
@hawkinsp should we leave this issue open pending some XLA:GPU compilation follow-up?
Yes. As I mentioned above, the immediate fix is adding a threefry CUDA kernel to jaxlib.
> Also, let us know if the caches are growing too large in your use cases. We can make the caching logic smarter.
Thanks for the help, we will let you know. I think there are a few optimizations that we can make in numpyro first that should result in better cache hits and fewer entries overall.
PR #1756 makes the example in this issue compile quickly and without the memory blowup. Note the fix requires a rebuild of jaxlib (or requires waiting for us to release new jaxlib wheels.) Hope that helps!
Whoa!!! Thanks so much for the quick fix, Peter!