JAX: memory leak of jax.random on GPU

Created on 18 Nov 2019 · 16 comments · Source: google/jax

@ziatdinovmax observed that memory keeps increasing until the process crashes for some NumPyro models (specifically, https://github.com/pyro-ppl/numpyro/issues/447) when they run on a GPU. Here is an isolated repro that uses only JAX ops:

from jax import random, lax

def model(key):
    D_X, D_H, D_Y = 3, 5, 1
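    # NB: each normal sample below is immediately overwritten by a uniform
    # sample drawn from a fresh subkey; the repro apparently just needs to
    # issue many PRNG calls per loop iteration.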
    # sample first layer (we put unit normal priors on all weights)
    key, subkey = random.split(key)
    w1 = random.normal(subkey, shape=(D_X, D_H))
    key, subkey = random.split(key)
    w1 = random.uniform(subkey, shape=w1.shape, minval=-2, maxval=2)

    # sample second layer
    key, subkey = random.split(key)
    w2 = random.normal(subkey, shape=(D_H, D_H))
    key, subkey = random.split(key)
    w2 = random.uniform(subkey, shape=w2.shape, minval=-2, maxval=2)

    # sample final layer of weights and neural network output
    key, subkey = random.split(key)
    w3 = random.normal(subkey, shape=(D_H, D_Y))
    key, subkey = random.split(key)
    w3 = random.uniform(subkey, shape=w3.shape, minval=-2, maxval=2)

    # we put a prior on the observation noise
    key, subkey = random.split(key)
    prec_obs = random.normal(subkey)
    key, subkey = random.split(key)
    prec_obs = random.uniform(subkey, shape=prec_obs.shape, minval=-2, maxval=2)
    return w1, w2, w3, prec_obs

def cond_fn(state):
    return state[0] < 100

def body_fn(state):
    i, key, _ = state
    key, subkey = random.split(key)
    return i + 1, key, model(subkey)

init_state = (0, random.PRNGKey(0), model(random.PRNGKey(2019)))
i, key, params = lax.while_loop(cond_fn, body_fn, init_state)

Interestingly, if I use the rolled loop in threefry_2x32, the issue does not happen. Could that be a hint?
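For reference, here is a schematic of what "rolled" means here. This is not the actual jax.random implementation, just a sketch assuming the standard Threefry round structure (add, rotate, xor); the helper names unrolled_rounds/rolled_rounds are illustrative:

import jax.numpy as jnp
from jax import lax

def rotate_left(x, d):
    # 32-bit rotate-left, as in Threefry
    return (x << d) | (x >> (32 - d))

def unrolled_rounds(x0, x1, rotations):
    # Unrolled: a Python loop traces every round into the XLA
    # computation, so the HLO graph grows with len(rotations).
    for d in rotations:
        x0 = x0 + x1
        x1 = rotate_left(x1, d) ^ x0
    return x0, x1

def rolled_rounds(x0, x1, rotations):
    # Rolled: lax.fori_loop compiles a single While-loop body, so the
    # HLO graph size is independent of the number of rounds.
    rots = jnp.asarray(rotations, dtype=jnp.uint32)
    def body(i, carry):
        a, b = carry
        a = a + b
        b = rotate_left(b, rots[i]) ^ a
        return a, b
    return lax.fori_loop(0, len(rotations), body, (x0, x1))

x0, x1 = jnp.uint32(0x1234), jnp.uint32(0x5678)
print(unrolled_rounds(x0, x1, [13, 15, 26, 6]))
print(rolled_rounds(x0, x1, [13, 15, 26, 6]))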

cc @mattjj :)

Most helpful comment

PR #1756 makes the example in this issue compile quickly and without the memory blowup. Note the fix requires a rebuild of jaxlib (or requires waiting for us to release new jaxlib wheels.) Hope that helps!

All 16 comments

We also observed a memory leak in https://github.com/pyro-ppl/brmp/issues/67 on the CPU, which may be related. When I turn on jax.core.check_leaks=True, I get the following stacktrace, which also seems to point to line 150 in threefry_2x32.
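For anyone reproducing, the flag is just a module-level boolean (though, as noted below, the check has bitrotted, so its output may be unreliable):

import jax.core

jax.core.check_leaks = True  # enable the experimental leak checker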

Traceback:

Traceback (most recent call last):
  File "/Users/npradhan/workspace/pyro_dev/brmp/tests/ex2.py", line 38, in <module>
    mcmc.run(random.PRNGKey(i), X, X)
  File "/Users/npradhan/workspace/pyro_dev/numpyro/numpyro/infer/mcmc.py", line 639, in run
    args, kwargs)
  File "/Users/npradhan/workspace/pyro_dev/numpyro/numpyro/infer/mcmc.py", line 582, in _single_chain_mcmc
    model_args=args, model_kwargs=kwargs)
  File "/Users/npradhan/workspace/pyro_dev/numpyro/numpyro/infer/mcmc.py", line 404, in init
    rng_key, rng_key_init_model = random.split(rng_key)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/random.py", line 194, in split
    return _split(key, num)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/api.py", line 149, in f_jitted
    out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/core.py", line 591, in call_bind
    outs = primitive.impl(f, *args, **params)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/interpreters/xla.py", line 347, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, *map(abstractify, args))
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/linear_util.py", line 209, in memoized_fun
    ans = call(fun, *args)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/interpreters/xla.py", line 361, in _xla_callable
    jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/linear_util.py", line 153, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/random.py", line 199, in _split
    return lax.reshape(threefry_2x32(key, counts), (num, 2))
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/api.py", line 149, in f_jitted
    out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/core.py", line 594, in call_bind
    outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 111, in process_call
    out_flat = call_primitive.bind(fun, *in_consts, **params)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/core.py", line 591, in call_bind
    outs = primitive.impl(f, *args, **params)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/interpreters/xla.py", line 347, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, *map(abstractify, args))
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/linear_util.py", line 209, in memoized_fun
    ans = call(fun, *args)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/interpreters/xla.py", line 361, in _xla_callable
    jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/linear_util.py", line 153, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/random.py", line 150, in threefry_2x32
    x, _, _ = lax.fori_loop(0, 5, step, (x, rotate_list(ks), rotations))
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 132, in fori_loop
Trace stack
[] ---
['  MasterTrace(-1,JaxprTrace)\n', '  MasterTrace(-2,JaxprTrace)\n']
    (lower, upper, init_val))
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 174, in while_loop
    cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/lax/lax_control_flow.py", line 59, in _initial_style_jaxpr
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 345, in trace_to_jaxpr
    del master
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/contextlib.py", line 119, in __exit__
    next(self.gen)
  File "/Users/npradhan/miniconda3/envs/brmp/lib/python3.7/site-packages/jax/core.py", line 460, in new_master
    raise Exception('Leaked trace {}'.format(t()))
Exception: Leaked trace MasterTrace(0,JaxprTrace)

@fehiepsi I think this is simply that XLA is taking a lot of memory to compile that computation. I raised an issue with the XLA team, although I suspect we are going to need to move the PRNG computation out of XLA and into a custom CUDA kernel as a short-term fix. PRNG compilation time/space is a constant pain point.

@neerajprad The experimental check_leaks support has bitrotted, so I wouldn't trust the results. But if you can provide a small self-contained reproduction of the CPU memory problem, we can look into it.

Thanks, Peter!

@hawkinsp - Regarding the second case, @fehiepsi was able to create a small JAX-only example that captures at least part of the issue, I think. Here is the code snippet:

import gc
from collections import Counter

import jax; jax.config.update('jax_platform_name', 'cpu')
from jax import numpy as np
from jax import lax, random

class ABC:
    def body_fn(self, i, key):
        key, _ = random.split(key)
        return key

for i in range(3):
    abc = ABC()
    value = lax.fori_loop(0, 10, abc.body_fn, random.PRNGKey(0))

    print("\nGC OBJECTS:")
    cnt = Counter()
    # force gc collection
    gc.collect()
    for x in gc.get_objects():
        if isinstance(x, list):
            if len(x) > 1:
                cnt[type(x[0])] += 1
    print(cnt.most_common(10))

We find that the count of jax.core.Var objects keeps increasing on every iteration. If we pull body_fn out as a regular module-level function and use it directly, the counts of all object types remain the same across iterations, as expected. Is this difference in behavior expected?
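For reference, here is that module-level variant (a sketch mirroring the snippet above):

import gc
from collections import Counter

from jax import lax, random

def body_fn(i, key):
    # same body as ABC.body_fn, but a plain module-level function: the
    # loop below passes the identical function object every iteration
    key, _ = random.split(key)
    return key

for i in range(3):
    value = lax.fori_loop(0, 10, body_fn, random.PRNGKey(0))
    gc.collect()
    cnt = Counter()
    for x in gc.get_objects():
        if isinstance(x, list) and len(x) > 1:
            cnt[type(x[0])] += 1
    print(cnt.most_common(10))  # counts stay flat across iterations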

@neerajprad thanks for finding this. Coincidentally I think someone brought this up in our Google chat room just a few hours ago.

I haven't thought about it yet but the first suspect that jumps to my mind is this line. After our detuplification rewrite in #1224, we create cycles in the JAX tracer graph (which ultimately gets serialized into a jaxpr). Maybe we need a weakref there...

EDIT: and thanks for the repro script, that's incredibly helpful.

Actually, sorry, I read this too quickly and got confused.

The reason the count of Vars is increasing here is just that

  1. Var objects appear in jaxprs
  2. there are jaxprs representing the cond and body functions present in the parameters of the while_p primitive (underlying fori_loop)
  3. we memoize the eval rule of while_p so as not to recompile this lax.fori_loop many times

So the script here is actually just showing us that the cache is working.
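To make points 1 and 2 concrete, printing the jaxpr for a small fori_loop shows the cond and body jaxprs stored in the while equation's parameters (a minimal sketch; the exact printed form depends on the JAX version):

import jax
from jax import lax, random

def loop(key):
    return lax.fori_loop(0, 10, lambda i, k: random.split(k)[0], key)

# The printed jaxpr contains a `while` equation whose parameters hold
# the cond and body jaxprs; every variable appearing in them is a
# jax.core.Var, i.e. the objects the gc counter above was counting.
print(jax.make_jaxpr(loop)(random.PRNGKey(0)))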

I checked that by just removing the caching from xla.xla_primitive_callable and from lax_control_flow._initial_style_jaxpr. Then the counts didn't increase.

There is a leak in the JaxprTracer graph though, exactly where Dougal thought there might be.

I'll follow up with a PR.

@mattjj - Thanks for looking into this! That makes sense. Any idea why this doesn't happen when body_fn is defined outside the class (as a regular function)?

We are trying to debug https://github.com/pyro-ppl/brmp/issues/67 and have isolated it to this snippet (though it could be some other issue, or just something odd we are doing in numpyro). Your comments have been very helpful, and I will try to debug further based on them.

There is a leak in the JaxprTracer graph though, exactly where Dougal thought there might be.

Looking forward to your PR. I'll run it against the original issue from @null-a and see if that helps.

I can confirm that we do not see this issue with the original example if we turn off all caching, so this just seems to be expected caching behavior. The difference in behavior I raised earlier is simply because we create new instances in the loop: the self arg differs each time, which creates a new cache entry. Entries won't be evicted until the cache reaches its maximum size.
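Concretely, hoisting the instance out of the loop makes self (and hence the bound method used as the cache key) identical across iterations, so only one entry is created; a sketch:

from jax import lax, random

class ABC:
    def body_fn(self, i, key):
        key, _ = random.split(key)
        return key

abc = ABC()  # created once: the same `self` on every iteration
for i in range(3):
    # equal bound methods hit the same cache entry, so no new
    # jaxprs (or their Vars) accumulate across iterations
    value = lax.fori_loop(0, 10, abc.body_fn, random.PRNGKey(0))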

Sorry about the red herring here, @mattjj. Please feel free to close this issue if the other leak issue is being tracked elsewhere.

No worries, please always eagerly report _possible_ bugs, so we can work through them just like this! (We love when you guys raise issues.)

Also, let us know if the caches are growing too large in your use cases. We can make the caching logic smarter.

@hawkinsp should we leave this issue open pending some XLA:GPU compilation follow-up?

Yes. As I mentioned above, the immediate fix is adding a threefry CUDA kernel to jaxlib.

Also, let us know if the caches are growing too large in your use cases. We can make the caching logic smarter.

Thanks for the help, we will let you know. I think there are a few optimizations we can make in numpyro first that should result in better cache hits and fewer entries overall.

PR #1756 makes the example in this issue compile quickly and without the memory blowup. Note the fix requires a rebuild of jaxlib (or requires waiting for us to release new jaxlib wheels.) Hope that helps!

Whoa!!! Thanks so much for the quick fix, Peter!
