jax.lax.switch host memory leak

Created on 16 Dec 2020 · 11 comments · Source: google/jax

I'm encountering a memory leak in code that repeatedly calls jax.lax.switch. Here's a small example that reproduces the issue in Colab:

import resource, gc
import numpy as np
import jax, jax.numpy as jp

d = 100

def do_thing():
  i = jp.array(np.random.rand() > 0.5, dtype="int32")
  x = jp.array(np.random.randn(d))
  y = jp.array(np.random.randn(d))
  def fn(i, x, y):
    return jax.lax.switch(i, [(lambda _: x), (lambda _: y)], None)
  fn(i, x, y)

stats = np.zeros(1000)
for i in range(len(stats)):
  gc.collect()
  do_thing()
  gc.collect()
  stats[i] = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss # (kilobytes)

import matplotlib.pyplot as plt
plt.plot(stats)
plt.gca().ticklabel_format(style="plain", useOffset=False)

[Plot: jax_switch_memory_leak (ru_maxrss over the 1000 iterations)]

Label: performance

All 11 comments

It looks like we are repeatedly compiling a cond primitive and never getting cache hits for it:

WARNING:absl:Compiling cond for args ((ShapedArray(int32[]), None), (ShapedArray(float32[100]), None), (ShapedArray(float32[100]), None)).
  File "/Users/phawkins/p/jax/q.py", line 18, in <module>
    do_thing()
  File "/Users/phawkins/p/jax/q.py", line 13, in do_thing
    fn(i, x, y)
  File "/Users/phawkins/p/jax/q.py", line 12, in fn
    return jax.lax.switch(i, [(lambda _: x), (lambda _: y)], None)
  File "/Users/phawkins/p/jax/jax/_src/lax/control_flow.py", line 616, in switch
    out = cond_p.bind(
  File "/Users/phawkins/p/jax/jax/_src/lax/control_flow.py", line 1105, in cond_bind
    return core.Primitive.bind(cond_p, *args, branches=branches, linear=linear)
  File "/Users/phawkins/p/jax/jax/core.py", line 271, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/Users/phawkins/p/jax/jax/core.py", line 595, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/Users/phawkins/p/jax/jax/interpreters/xla.py", line 235, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
  File "/Users/phawkins/p/jax/jax/interpreters/xla.py", line 275, in xla_primitive_callable
    traceback.print_stack()

(We also seem to be missing compilation logging for primitives!)

I think always recompiling here is expected, because we're passing in fresh lambda objects as arguments to switch each time.

However, I wonder if we can try to mitigate this. @apaszke recently told me (I think) that f.__code__ is cached by CPython:

In [1]: hash(lambda x: x)
Out[1]: 8752827212549

In [2]: hash(lambda x: x)
Out[2]: 8752827192743

In [3]: hash((lambda x: x).__code__)
Out[3]: -5762213570434008460

In [4]: hash((lambda x: x).__code__)
Out[4]: -5762213570434008460

Maybe we can leverage this to get more cache hits, at least in some cases?

In the meantime, the fix is not to pass lambdas into switches, and instead reuse the same function object.

EDIT: sorry, while writing this comment, I forgot that the real issue is the memory growth. Even if we don't get cache hits, we shouldn't leak memory!

With __code__, we might have a runtime/memory tradeoff to consider: does the time to hash and compare a __code__ object scale with the code size? By contrast, my guess is that function objects are compared simply by object identity.
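For reference, here is a quick check in plain CPython (nothing jax-specific): function objects fall back to identity comparison, code objects compare and hash by value, and a lambda defined inside a function reuses a single code object across calls (which is also why the REPL session above shows equal hashes for distinct lambdas).

def make():
  return lambda x: x + 1  # a fresh function object on every call

f, g = make(), make()
print(f == g)                    # False: function objects compare by identity
print(f.__code__ is g.__code__)  # True: both wrap the single code object created
                                 # when make() itself was compiled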

But even if we keyed on f.__code__ rather than on f, we'd still miss the cache in this case because the branch lambdas close over fresh arrays. See:

>>> a, b = np.ones(4), np.ones(7)
>>> id(a), id(b)
(140197940693536, 140197940693616)
>>> f = lambda _: a
>>> g = lambda _: b
>>> hash(f.__code__) == hash(g.__code__)
False
>>> f.__code__ == g.__code__
False

This behavior is also consistent with our current extraction of closure-captured values as "consts" when we stage to a jaxpr.

So, keying on __code__ might still be an improvement for code that looks like:

lax.switch(..., [(lambda x: x + x), (lambda x: x * x)], x)

but not for the example given here.
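To make the "consts" point concrete, here is a small illustration (assuming a reasonably recent jax; the exact jaxpr printing format varies by version). An array that a branch closes over shows up as a constant of the staged jaxpr rather than as an input:

import jax, jax.numpy as jp

x = jp.ones(100)
branch = lambda _: x              # closes over a concrete array
print(jax.make_jaxpr(branch)(0))  # x appears as a jaxpr constant (constvar),
                                  # not as an argument of the jaxpr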

Relatedly, the memory growth is bound to be more noticeable because those size-100 arrays are stored alongside the lambdas, again since closed-over values are extracted as consts when staging to a jaxpr. @mattjj's suggestion to "reuse the same function object" would have forced the branches to accept those arrays as formal arguments, in which case the arrays wouldn't be held by the cache entries.

Concretely, does this mean that if I use a globally defined function and pass everything in through switch's third argument rather than by closure, the memory leak would go away? That's doable.

Would a functools.partial-wrapped version of that globally defined function still be okay? (The partial would only bind a plain Python int, no arrays or anything like that.)

Yeah, a rewrite along the following lines ought to work around the observed leak:

fst = lambda z: z[0]
snd = lambda z: z[1]

def do_thing():
  i = jp.array(np.random.rand() > 0.5, dtype="int32")
  x = jp.array(np.random.randn(d))
  y = jp.array(np.random.randn(d))
  return jax.lax.switch(i, [fst, snd], (x, y))

If the partial you have in mind is used once, up front, to set things up along the lines of:

from functools import partial

take = lambda i, z: z[i]
fst, snd = partial(take, 0), partial(take, 1)

then yes, that should not interfere with this workaround. But pushing the partial construction into the switch call, as in:

def do_thing():
  # ...
  return jax.lax.switch(i, [partial(take, 0), partial(take, 1)], (x, y))

will leak again, since it creates new callables every time. We could mitigate the latter internally by keying on __code__, but we don't do that today.

This latter "leak" would not be as expensive or noticeable as the one you were originally seeing, since it no longer stores a size-100 array in every cache entry. But the cache would still grow as it accumulates the fresh callables, and there would be a runtime cost to staging them out to a jaxpr every time switch is called.

Gotcha. My real use case is basically this: I have a sequence of heterogeneous things that I flattened into a common format so I can stack them into an array. I also have a sequence of Enum values indicating the type of each element; those are stacked into an int32 array. At some point I need to do computations that are structurally different depending on the element's type, so I use a vmap with a switch inside it to effectively turn each int32 back into its Enum value. Here's a contrived example:

import jax, jax.numpy as jp
from functools import partial
from enum import IntEnum

class Kind(IntEnum):
  LEFT = 0
  RIGHT = 1

def general_handler(kind, xy):
  [x, y] = xy
  if kind == Kind.LEFT: return x
  if kind == Kind.RIGHT: return y

# prepare partial applications once
handlers = dict()
for kind in Kind:
  handlers[kind] = partial(general_handler, kind)

@jax.vmap
def fn(dynamic_kind, x, y):
  return jax.lax.switch(dynamic_kind,
                        [handlers[static_kind] for static_kind in Kind],
                        [x, y])

Based on what you've said, I think this should avoid both the leak and the cache misses.
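For what it's worth, here is a hypothetical usage sketch for the example above (the shapes and names below are made up): stack the per-element kinds and data into arrays and let vmap apply the switch row by row.

import numpy as np

n, d = 8, 4
kinds = jp.array(np.random.randint(0, len(Kind), size=n), dtype="int32")
xs = jp.array(np.random.randn(n, d))
ys = jp.array(np.random.randn(n, d))
out = fn(kinds, xs, ys)  # shape (n, d); row i is xs[i] or ys[i] depending on kinds[i]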

I asked about partial because it doesn't create a function but a "partial object", a simple data structure containing, essentially, the function and its pre-bound arguments. I had assumed that partial(fn, 0) == partial(fn, 0), so that a newly created partial object for the same globally defined function with the same arguments would still match the cached one, but alas partial objects don't compare that way.
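A quick confirmation of that in plain Python: functools.partial defines no __eq__, so two partials built from the same function and arguments compare unequal (equality falls back to object identity), even though their pieces match.

from functools import partial

def take(i, z):
  return z[i]

p, q = partial(take, 0), partial(take, 0)
print(p == q)                              # False: no value-based equality on partial objects
print(p.func is q.func, p.args == q.args)  # True True: the underlying pieces do match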

It looks like the original issue is resolved. Still, I'd like to track the possibility of keying on f.__code__ rather than on id(f) in our cache.

From a few experiments, it appears that comparing __code__ alone can yield false positives. Accounting for __closure__ might make up the difference. But __closure__ isn't always available, e.g. on methods, and in fact even __code__ isn't always available, e.g. for builtin functions like hash. We can either keep things simple and avoid this change entirely, or key on __code__ and __closure__ when both are available and fall back to object identity otherwise.

I'm leaning towards avoiding the change for now because (i) I'm not entirely sure about its correctness, and (ii) it doesn't solve the entire problem originally presented in this issue. We can revisit it if the special-case problem (same code, same closure, fresh function object) comes up in practice. We could also tackle the broader issue of observed memory growth in a separate way altogether, for instance by amending our cache eviction policy.
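To sketch what that last idea could look like in the abstract (this is not jax's actual cache, just a generic illustration of a bounded eviction policy): an LRU cache caps memory growth even when keys never repeat, trading occasional recompiles of evicted entries for bounded residency.

from functools import lru_cache

@lru_cache(maxsize=128)
def expensive_compile(key):
  # stand-in for real compilation work, keyed on something hashable
  return ("compiled", key)

for k in range(1000):  # 1000 distinct keys, zero cache hits
  expensive_compile(k)
print(expensive_compile.cache_info())  # at most maxsize=128 entries are retained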

I'm still seeing the same problem if I modify my initial example as suggested:

import resource, gc
import numpy as np
import jax, jax.numpy as jp

def xbranch(xy): return xy[0]
def ybranch(xy): return xy[1]

d = 1
def do_thing():
  i = jp.array(np.random.rand() > 0.5, dtype="int32")
  x = jp.array(np.random.randn(d))
  y = jp.array(np.random.randn(d))
  jax.lax.switch(i, [xbranch, ybranch], [x, y])

stats = np.zeros(100)
for i in range(len(stats)):
  gc.collect()
  do_thing()
  gc.collect()
  stats[i] = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss # (kilobytes)

import matplotlib.pyplot as plt
plt.plot(stats)
plt.gca().ticklabel_format(style="plain", useOffset=False)

[Plot: jax_switch_memory_leak_2 (ru_maxrss over the 100 iterations)]

The rate of growth doesn't seem to depend on the size d of the data, so it's something else that's being leaked. It's always ~40MB over the 100 iterations.

Thanks! That helped me isolate a related downstream cache miss. #5294 should fix it.
