Jax: Unexpectedly high grad-of-scan memory usage

Created on 22 May 2020 · 4 comments · Source: google/jax

Consider the following function that sums x * (y + z) over all y in ys and then averages over the resulting matrix of sums:

import jax.lax
import jax.numpy as jnp

def f(x, ys):
    z = jnp.ones((3000, 3000))

    def scanned(carry, y):
        return carry + x * (y + z), None

    summed, _ = jax.lax.scan(scanned, jnp.zeros_like(z), ys)
    return summed.mean()

Because I use lax.scan (instead of, e.g., vmap or lax.map followed by a sum over the first axis), memory usage doesn't significantly scale with the number of ys. The following code uses ~203MB regardless of whether n = 5 or n = 10:

import resource

n = 5  # or n = 10; peak usage is ~203MB either way

print(f(1.0, jnp.ones(n)))
print(f"{1e-3 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}MB")

But the gradient uses 557MB for n = 5 and 908MB for n = 10:

import jax

print(jax.grad(f)(1.0, jnp.ones(n)))
print(f"{1e-3 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}MB")

The story is similar when these functions are jitted.
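
For reference, the jitted variants I mean here (just a sketch):

import jax

# jit the forward pass and the gradient separately; peak memory behaves as above
f_jit = jax.jit(f)
grad_f_jit = jax.jit(jax.grad(f))
print(f_jit(1.0, jnp.ones(n)))
print(grad_f_jit(1.0, jnp.ones(n)))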

My best guess about what's going on here is that grad is storing every (y + z) in memory. Is this intended? And is there some way to tell grad to be more economical about what it stores, so that the gradient computation gets a memory reduction similar to the one lax.scan gives the forward pass?

question

All 4 comments

You're right that grad causes every (y + z) to be stored. Since the result of f is computed using x * (y + z), it needs to save the (y + z) values to compute the gradient. You can try using the new jax.remat, which causes values needed by the gradient computation to be recomputed instead of stored, saving memory. This probably makes sense for a scan like this, where you're creating a large number of easy-to-compute values. See https://github.com/google/jax/pull/1749 for examples of using remat. I think doing scan(remat(scanned), ...) should work in this case.
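
Concretely, that suggestion looks like this (a sketch; the rest of f is unchanged from your original post):

def f(x, ys):
    z = jnp.ones((3000, 3000))

    def scanned(carry, y):
        return carry + x * (y + z), None

    # remat the scanned function so its intermediates (like y + z) are
    # recomputed on the backward pass instead of being stored
    summed, _ = jax.lax.scan(jax.remat(scanned), jnp.zeros_like(z), ys)
    return summed.mean()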

cc @mattjj who created remat

This is perfect, thanks so much! I hadn't seen remat before -- looks like it's tailor-made for this type of problem.

For some reason rematifying scanned directly didn't seem to work; I found that I had to rematify the actual computation within the scan to get the desired memory reduction:

def f(x, ys):
    z = jnp.ones((3000, 3000))

    @jax.remat
    def inner(y):
        return x * (y + z)

    def scanned(carry, y):
        return carry + inner(y), None

    summed, _ = jax.lax.scan(scanned, jnp.zeros_like(z), ys)
    return summed.mean()

By the way, we're working on some other improvements that should make this work well even without remat by never instantiating the large ones((3000, 3000)) array. We'd still need remat in general, but in this case the memory savings can be had by avoiding the large constant.
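
In the meantime, the same saving can be hand-rolled for this particular example (a sketch that only works because z happens to be all ones): let broadcasting stand in for the large constant.

def f(x, ys):
    # z was jnp.ones((3000, 3000)), so x * (y + z) equals x * (y + 1.0)
    # broadcast against the carry; the large constant is never materialized,
    # and the per-step values saved for the gradient are scalars instead of
    # (3000, 3000) arrays
    def scanned(carry, y):
        return carry + x * (y + 1.0), None

    summed, _ = jax.lax.scan(scanned, jnp.zeros((3000, 3000)), ys)
    return summed.mean()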

Very cool, I'll keep my eyes peeled and keep updating the package. The work you all are doing here is really great.
