JAX: jit cache leaks memory

Created on 27 Jan 2019 · 4 comments · Source: google/jax

Running the following example with --jit leaks memory (duplicates x?); GPU memory utilization increases by 1 GiB per iteration. Without --jit it works fine.

#!/usr/bin/env python3
import argparse
import jax
import jax.numpy as np
import numpy as onp


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--jit', action='store_true')
    args = parser.parse_args()

    x = onp.random.rand(1024, 1024).astype(onp.float32)
    x = onp.repeat(x, 256, axis=0)
    x = jax.device_put(x)
    # x is 1 GiB large

    v = onp.random.rand(1024).astype(onp.float32)
    v = jax.device_put(v)

    for i in range(16):
        print(f'{i + 1}. iteration')

        def f(v):
            return np.dot(x, v)

        if args.jit:
            f = jax.jit(f)

        f(v)


if __name__ == '__main__':
    main()

While this particular example can be fixed with the following code, that approach does not work in general, because I might need to redefine (and re-jit) a function inside a loop.

#!/usr/bin/env python3
import jax
import jax.numpy as np
import numpy as onp


def main():
    x = onp.random.rand(1024, 1024).astype(onp.float32)
    x = onp.repeat(x, 256, axis=0)
    x = jax.device_put(x)
    # x is 1 GiB large

    v = onp.random.rand(1024).astype(onp.float32)
    v = jax.device_put(v)

    def f(x, v):
        return np.dot(x, v)

    f = jax.jit(f, static_argnums=(0,))

    for i in range(16):
        print(f'{i + 1}. iteration')

        f(x, v)


if __name__ == '__main__':
    main()

Note that the following does not solve the problem and has the same _leak_ as the original example:

#!/usr/bin/env python3
import jax
import jax.numpy as np
import numpy as onp


def main():
    x = onp.random.rand(1024, 1024).astype(onp.float32)
    x = onp.repeat(x, 256, axis=0)
    x = jax.device_put(x)
    # x is 1 GiB large

    v = onp.random.rand(1024).astype(onp.float32)
    v = jax.device_put(v)

    for i in range(16):
        print(f'{i + 1}. iteration')

        def f(x, v):
            return np.dot(x, v)

        f = jax.jit(f, static_argnums=(0,))

        f(x, v)


if __name__ == '__main__':
    main()

So I guess my question is: how can I define (and jit) many functions that all use the same large array as a constant without duplicating the constant? I think this should be possible in the case above where the functions overwrite each other, and in principle it should even be possible with all functions existing simultaneously.
Of course I can always create the functions with an explicit and non-static x argument, but then jit cannot make use of the fact that x is a constant.

Maybe I am missing a simple trick?

P.S. JAX is absolutely awesome!

bug

All 4 comments

I could be wrong, but I think the only way that jit can meaningfully make use of the fact that x is a constant is by hard-coding it into the computation sent to the XLA compiler (which can then perform constant-prop or, more likely, just leave it as a large constant in the resulting binary).

If you want multiple JITted computations to share device memory for x, you would have to create it as a device-backed array and pass it as a formal argument to each of the computations. I would be surprised if, in practice, this hurts your performance relative to the situation with hard-coded constants.
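
For illustration, a minimal sketch of that suggestion (matvec and scaled_matvec are made-up names): x lives in device memory once, and each jitted computation receives it as a traced formal argument:

import jax
import jax.numpy as np
import numpy as onp

x = jax.device_put(onp.random.rand(1024, 1024).astype(onp.float32))
v = jax.device_put(onp.random.rand(1024).astype(onp.float32))

# x is a traced (non-static) argument, so both compiled computations
# reference the same device buffer instead of baking in their own copy.
matvec = jax.jit(lambda x, v: np.dot(x, v))
scaled_matvec = jax.jit(lambda x, v: 2.0 * np.dot(x, v))

matvec(x, v)
scaled_matvec(x, v)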

In practice this means one is forced to pass data via explicit arguments: one cannot, for example, use partial to bind some data to an inner function early when several outer functions that use this inner function will be jitted. Instead, each outer function must know about the data and pass it along to the inner function.
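
A hypothetical sketch of that limitation (inner, outer_a, and outer_b are made-up names):

from functools import partial

import jax
import jax.numpy as np
import numpy as onp

x = jax.device_put(onp.random.rand(1024, 1024).astype(onp.float32))

def inner(x, v):
    return np.dot(x, v)

# Convenient but wasteful: partial closes over x, so each jitted outer
# function bakes its own copy of x into its compiled computation.
inner_with_x = partial(inner, x)
outer_a = jax.jit(lambda v: inner_with_x(v) + 1.0)
outer_b = jax.jit(lambda v: inner_with_x(v) * 2.0)

# The workaround: every outer function must accept x and thread it through.
outer_a2 = jax.jit(lambda x, v: inner(x, v) + 1.0)
outer_b2 = jax.jit(lambda x, v: inner(x, v) * 2.0)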

I like the "bug" label here, because we can do better at this than we do now!

We've talked before about automatically lifting large closed-over constants into formal arguments rather than baking them into the XLA computation. Then if we cache those values in device memory, we might want some kind of cache eviction policy. We can still persist the host memory representations in case the constants are needed again. (More generally, for all device values we might want some kind of caching policy like this.)

Here is a more realistic example that's closer to my actual problem. It demonstrates why all currently possible workarounds are suboptimal afaik (except for a new vjp implementation).

Assume we have many identically shaped arrays in many_xs. The same goes for many_vs.
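
For concreteness, a possible setup matching that description (this f is only a stand-in; the real one is more involved):

import jax
import jax.numpy as np
import numpy as onp

def f(x):
    return np.tanh(np.dot(x, x))  # placeholder forward pass

many_xs = [jax.device_put(onp.random.rand(128, 128).astype(onp.float32))
           for _ in range(10)]
# the cotangents must have the same shape as f's output
many_vs = [jax.device_put(onp.random.rand(128, 128).astype(onp.float32))
           for _ in range(50)]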

The following works and the forward pass through f is done only once per x:

for x in many_xs:
    # do one forward pass for each x
    fx, vjp_fun = jax.vjp(f, x)
    for v in many_vs:
        # do one backward pass for each x and v
        vjp_fun(v)

But now we want to JIT this because there are a lot of small operations involved ...

First attempt: this is bad, because vjp_fun needs to be recompiled for each x even though the code is the same (only the data cached inside vjp_fun changed):

for x in many_xs:
    fx, vjp_fun = jax.vjp(f, x)
    vjp_fun = jax.jit(vjp_fun)
    for v in many_vs:
        vjp_fun(v)

Okay, so we need to have a single fixed function outside of the loop:

Second attempt: this solves the repeated compilation, but now the forward pass through f happens len(many_vs) times for each x: not good!

@jax.jit
def new_vjp_fun(x, v):
    fx, vjp_fun = jax.vjp(f, x)
    return vjp_fun(v)

for x in many_xs:
    for v in many_vs:
        new_vjp_fun(x, v)

In other words, when data is baked into functions, as naturally happens when using e.g. vjp, there does not seem to be a way to write efficient code.
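
One partial mitigation, assuming the cotangents in many_vs can be stacked into a single array, is to vmap the backward pass inside a single jitted function: compilation then happens once, and the forward pass runs only once per x (a sketch):

@jax.jit
def batched_vjp(x, vs):
    fx, vjp_fun = jax.vjp(f, x)
    # vmap the pullback over the leading axis of the stacked cotangents;
    # the forward pass above runs only once per call, i.e. once per x.
    return fx, jax.vmap(vjp_fun)(vs)

vs = np.stack(many_vs)
for x in many_xs:
    fx, (dxs,) = batched_vjp(x, vs)

But this requires all the vs up front and extra memory for the stacked array, so it is not a general solution.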

I guess with the way things currently work, the only _perfectly efficient_ solution would be a new vjp implementation (say vjp2) that returns code and data separately from each other. Then something like this could work:

# we could now be static on the code for vjp and generate it only once
@partial(jax.jit, static_argnums=(0,))  # partial is functools.partial
def new_vjp_fun(vjp_fun_code, vjp_fun_data, v):
    return vjp_fun_code(vjp_fun_data, v)

# call it once to get the code that we will reuse (ignore the data)
fx, vjp_fun_code, vjp_fun_data = jax.vjp2(f, get_x(0))
for i in range(100):
    x = get_x(i)
    fx, _, vjp_fun_data = jax.vjp2(f, x)  # ignore the returned code; reuse the one from above to avoid recompilation
    for j in range(50):
        v = get_v(j)
        new_vjp_fun(vjp_fun_code, vjp_fun_data, v)

Not sure if such a vjp2 can easily be implemented at the user level. From what I've seen, ad.vjp already returns a function with data baked in, not only api.jvp ...
