Consider the following function that sums x * (y + z) over all y in ys and then averages over the resulting matrix of sums:
import jax.lax
import jax.numpy as jnp
def f(x, ys):
    z = jnp.ones((3000, 3000))

    def scanned(carry, y):
        return carry + x * (y + z), None

    summed, _ = jax.lax.scan(scanned, jnp.zeros_like(z), ys)
    return summed.mean()
Because I use lax.scan (instead of, e.g., vmap or lax.map followed by a sum over the first axis), memory usage doesn't significantly scale with the number of ys. The following code uses ~203MB regardless of whether n = 5 or n = 10:
import resource

n = 10  # peak memory is about the same for n = 5
print(f(1.0, jnp.ones(n)))
# ru_maxrss is reported in kilobytes on Linux, so this prints MB
print(f"{1e-3 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}MB")
But the gradient uses 557MB for n = 5 and 908MB for n = 10:
import jax
print(jax.grad(f)(1.0, jnp.ones(n)))
print(f"{1e-3 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}MB")
The story is similar when these functions are jitted.
My best guess about what's going on here is that grad is storing every (y + z) in memory. Is this intended? And is there some way to tell grad to be more economical about what it stores, so that the gradient computation gets the same kind of memory reduction that lax.scan gives the forward pass?
You're right that grad causes every (y + z) to be stored. Since f multiplies x by (y + z), the backward pass for that multiplication needs the (y + z) values, so they are saved during the forward pass. You can try the new jax.remat, which causes values needed by the gradient computation to be recomputed instead of stored, saving memory at the cost of some extra compute. This probably makes sense for a scan like this, where you're creating a large number of easy-to-compute values. See https://github.com/google/jax/pull/1749 for examples of using remat. I think doing scan(remat(scanned), ...) should work in this case.
cc @mattjj who created remat
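Concretely, a rough (untested) sketch of what I mean, applied to your f:

def f(x, ys):
    z = jnp.ones((3000, 3000))

    def scanned(carry, y):
        return carry + x * (y + z), None

    # Wrap the scan body in remat so its intermediates (like y + z) are
    # recomputed on the backward pass instead of being stored.
    summed, _ = jax.lax.scan(jax.remat(scanned), jnp.zeros_like(z), ys)
    return summed.mean()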
This is perfect, thanks so much! I hadn't seen remat before -- looks like it's tailor-made for this type of problem.
For some reason rematifying scanned directly didn't seem to work; I found that I had to rematify the actual computation within the scan to get the desired memory reduction:
def f(x, ys):
    z = jnp.ones((3000, 3000))

    @jax.remat
    def inner(y):
        return x * (y + z)

    def scanned(carry, y):
        return carry + inner(y), None

    summed, _ = jax.lax.scan(scanned, jnp.zeros_like(z), ys)
    return summed.mean()
By the way, we're working on some other improvements that should make this work well even without remat by never instantiating the large ones((3000, 3000)) array. We'd still need remat in general, but in this case the memory savings can be had by avoiding the large constant.
Very cool, I'll keep my eyes peeled and keep updating the package. The work you all are doing here is really great.