Jax: 2nd order derivs wants insane amount of ram

Created on 30 May 2019 · 8 Comments · Source: google/jax

import jax
import jax.numpy as np
import numpy as onp

def E_fn(conf):
    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dxdydz = np.power(ri - rj, 2)
    dij = np.sqrt(np.sum(dxdydz, axis=-1))
    return np.sum(dij)

dE_dx_fn = jax.jacrev(E_fn, argnums=(0,))
d2E_dx2_fn = jax.jacfwd(dE_dx_fn, argnums=(0,))

d2E_dx2_fn(onp.random.rand(2483, 3))

Results in:

RuntimeError: Resource exhausted: Out of memory while trying to allocate 551102853132 bytes.

This happens on both CPU and GPU.

There's no reason this calculation should require 551 GB of RAM. The explicit Hessian is "only" (2483 * 3)^2 * 4 bytes, or about 222 MB.

question

All 8 comments

Well, there is a pretty good reason for requiring a lot of RAM here: the memory required isn't only a function of the final output size, but also of the size of the intermediates. The intermediate dxdydz has shape (2483, 2483, 3), so a single tangent vector for it would have the same shape, but to evaluate a jacfwd we need to push forward an entire standard basis of tangent vectors at once; since the input to E_fn has shape (2483, 3), that means to evaluate jacfwd(E_fn) we need (just for the tangent basis for this one intermediate) to allocate an array of size (2483 * 3, 2483, 2483, 3). That's a lot of bytes! And doing a jacrev as well can only make things worse (EDIT: though this is a scalar-output function, so no worries there).
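For reference, the arithmetic above can be checked directly. This is a small sketch assuming 4-byte float32 elements (JAX's default); the variable names are only for illustration. The tangent basis for dxdydz alone comes out to exactly the byte count in the error message.

```python
n_atoms, n_dims = 2483, 3
bytes_per_float32 = 4  # assumption: default float32 arrays

n_inputs = n_atoms * n_dims  # 7449 scalar inputs

# Dense Hessian: one entry per pair of scalar inputs.
hessian_bytes = n_inputs ** 2 * bytes_per_float32

# Tangent basis for the (2483, 2483, 3) intermediate:
# one full-size tangent per scalar input.
tangent_basis_bytes = n_inputs * n_atoms * n_atoms * n_dims * bytes_per_float32

print(hessian_bytes)        # 221950404 (about 222 MB)
print(tangent_basis_bytes)  # 551102853132 (about 551 GB)
```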

Using jit here might allow XLA to do some memory-saving optimizations, but I doubt it'd be anything so dramatic as to make this computation feasible. (I didn't actually test this though!)

What we really need is a different evaluation strategy for the Hessian. The most memory-efficient thing to do is not to use jacfwd and jacrev, but instead to use hvp like this (from the autodiff cookbook):

from jax import jvp, grad

def hvp(f, primals, tangents):
  return jvp(grad(f), primals, tangents)[1]

Then you can evaluate the Hessian one column at a time by applying this function to one-hot tangent vectors in a loop.
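A minimal sketch of that loop, for illustration only. The names toy_energy and hessian_by_columns are made up here, and toy_energy is a small smooth stand-in rather than the E_fn from this thread. Each iteration only holds one tangent the size of the input.

```python
import jax
import jax.numpy as jnp
from jax import jvp, grad

def hvp(f, primals, tangents):
    return jvp(grad(f), primals, tangents)[1]

def toy_energy(x):
    # Hypothetical smooth test function, not the E_fn from this thread.
    return jnp.sum(jnp.sin(x)) * jnp.sum(x ** 2)

def hessian_by_columns(f, x):
    n = x.size
    columns = []
    for i in range(n):
        # One-hot tangent with the same shape as x.
        one_hot = jnp.zeros(n, dtype=x.dtype).at[i].set(1.0).reshape(x.shape)
        columns.append(hvp(f, (x,), (one_hot,)))
    # Row i of the stack is H @ e_i, i.e. column i of H. The Hessian is
    # symmetric, so stacking columns as rows gives the same matrix.
    return jnp.stack(columns).reshape(x.shape + x.shape)

x = jnp.arange(6.0).reshape(3, 2) / 5.0
H = hessian_by_columns(toy_energy, x)
print(H.shape)  # (3, 2, 3, 2)
```

On a small input like this the result can be compared against jax.hessian(toy_energy)(x) as a sanity check.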

A slightly better thing to do, to avoid re-doing the linearization work on every call, is to replace jvp with linearize, which basically caches the linearization work. Here's what it might look like in your code:

import jax
import jax.numpy as np
import numpy as onp

def E_fn(conf):
    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dxdydz = np.power(ri - rj, 2)
    dij = np.sqrt(np.sum(dxdydz, axis=-1))
    return np.sum(dij)

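def hessian(f, x):
  # Linearize grad(f) once at x; hvp maps a tangent v to the Hessian-vector product.
  _, hvp = jax.linearize(jax.grad(f), x)
  hvp = jax.jit(hvp)  # seems like a substantial speedup to do this
  # One one-hot tangent per input element, each shaped like x.
  basis = np.eye(x.size).reshape(-1, *x.shape)
  return np.stack([hvp(e) for e in basis]).reshape(x.shape + x.shape)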


x = onp.random.randn(2483, 3)  # doesn't memory error but might require 30+ seconds
H = hessian(E_fn, x)

In some cases you could use vmap instead of that outer loop with calls to hvp, but here we can't afford the memory.

Actually, the best evaluation strategy would be to use vmap to push through as many standard basis elements as you can afford given the memory in your machine, and have an outer loop over that. I didn't explore that option here.
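Here is a rough sketch of what that chunked strategy might look like. It was not part of the original exploration: the names hessian_chunked, chunk_size and gaussian_pair_energy are hypothetical, and it uses jvp inside vmap rather than linearize to keep the example simple. The test function is a smooth pair potential chosen only so the result is easy to check.

```python
import jax
import jax.numpy as jnp

def gaussian_pair_energy(x):
    # Hypothetical smooth pair potential, used only to exercise the sketch.
    diff = x[None, :, :] - x[:, None, :]
    return jnp.sum(jnp.exp(-jnp.sum(diff ** 2, axis=-1)))

def hessian_chunked(f, x, chunk_size):
    n = x.size
    grad_f = jax.grad(f)

    @jax.jit
    def hvp_batch(tangents):
        # vmap over the leading axis; each slice is one tangent shaped like x.
        return jax.vmap(lambda t: jax.jvp(grad_f, (x,), (t,))[1])(tangents)

    # Standard basis, one one-hot tangent per scalar input.
    basis = jnp.eye(n, dtype=x.dtype).reshape((n,) + x.shape)
    # Outer Python loop over chunks; only chunk_size tangents are live at once.
    blocks = [hvp_batch(basis[start:start + chunk_size])
              for start in range(0, n, chunk_size)]
    return jnp.concatenate(blocks).reshape(x.shape + x.shape)

x = jnp.arange(12.0).reshape(4, 3) / 7.0
H = hessian_chunked(gaussian_pair_energy, x, chunk_size=5)
print(H.shape)  # (4, 3, 4, 3)
```

The chunk size trades memory for the number of calls: each call pushes chunk_size tangents through every intermediate at once. A ragged final chunk triggers one extra compilation, and the basis array here is itself as large as the Hessian, so a real version would probably build each chunk of one-hot vectors on the fly.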

WDYT?

Hi Matt, thanks for the detailed reply as always.

I was afraid that this was what happened underneath the hood (with regard to the size of tangent basis of the dxdydz intermediate). Unfortunately it looks like your proposed solution (even with linearization) is also prohibitively expensive in time, as the complexity becomes O(N^3) where N in the example is 2483.

As a point of reference, a hand-written kernel for this computation can be run in less than 3 milliseconds and maintains O(N^2) complexity. The trick we physicists use for these kinds of kernels is to push the derivative down to the inner loop, i.e. differentiate only the element-wise dxdydz computation, not the double outer loop (ri = np.expand_dims(conf, 0), rj = np.expand_dims(conf, 1)), and then scatter the results back into the Hessian array. I was talking to @sschoenholz about this before and I thought he mentioned there was a special kind of operator that automagically did something like this.
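To make the idea concrete, here is a sketch in plain jax.numpy. It is not the hand-written kernel mentioned above, and the function names are made up for this example. It uses the closed-form 3x3 Hessian of a single distance r = |d| with d = x_i - x_j, which is B = (I - d d^T / r^2) / r, and scatters those blocks into the full array. Because the energy sums over ordered pairs, each pair counts twice.

```python
import jax
import jax.numpy as jnp

def pair_blocks(x):
    # Per-pair 3x3 Hessian of |x_i - x_j| with respect to the displacement.
    n, dim = x.shape
    d = x[:, None, :] - x[None, :, :]                  # (N, N, 3)
    r2 = jnp.sum(d * d, axis=-1)                       # (N, N)
    off_diag = ~jnp.eye(n, dtype=bool)
    safe_r2 = jnp.where(off_diag, r2, 1.0)             # avoid dividing by zero for i == j
    r = jnp.sqrt(safe_r2)
    outer = d[..., :, None] * d[..., None, :]          # (N, N, 3, 3)
    eye = jnp.eye(dim, dtype=x.dtype)
    B = (eye - outer / safe_r2[..., None, None]) / r[..., None, None]
    return jnp.where(off_diag[..., None, None], B, 0.0)

def analytic_hessian(x):
    n = x.shape[0]
    B = pair_blocks(x)
    H = -2.0 * B                                       # off-diagonal blocks (i != j)
    diag = 2.0 * jnp.sum(B, axis=1)                    # diagonal blocks: sum over partners
    idx = jnp.arange(n)
    H = H.at[idx, idx].set(diag)                       # scatter into block (i, i)
    return jnp.transpose(H, (0, 2, 1, 3))              # (N, 3, N, 3)

def energy_masked(x):
    # Same sum of pairwise distances, with the i == j terms masked out so that
    # autodiff does not hit the infinite slope of sqrt at zero. Used only as a check.
    n = x.shape[0]
    d = x[:, None, :] - x[None, :, :]
    off_diag = ~jnp.eye(n, dtype=bool)
    r = jnp.sqrt(jnp.where(off_diag, jnp.sum(d * d, axis=-1), 1.0))
    return jnp.sum(jnp.where(off_diag, r, 0.0))

x = jnp.array([[0.0, 0.1, 0.2],
               [1.0, 0.0, 0.3],
               [0.2, 1.1, 0.0],
               [0.1, 0.3, 1.2],
               [0.9, 1.0, 1.1]])
H = analytic_hessian(x)
print(H.shape)  # (5, 3, 5, 3)
```

The largest intermediate here is (N, N, 3, 3), the same order as the Hessian itself, so memory stays O(N^2). The price is that the derivative of the pair term has to be worked out by hand for each potential.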

XLA (ie jit) is the thing that can avoid materializing large intermediates by fusing operations. You should try applying a jit to the original code to see if it helps. But it might not generate the same code as your handwritten code because ultimately there are tradeoffs in deciding where to fuse, and so those decisions are made with compiler heuristics.

If applying a jit in this case doesn't avoid materializing the large intermediates, maybe the XLA folks would consider that a bug and would improve their compiler cost model for us. This behavior may be backend-dependent too: generally the TPU backend (which you don't have access to yet) is the smartest and the CPU backend is the dumbest.

Without using jit, though, there's no hope: the way this computation is expressed, in op-by-op mode those intermediates must be generated.

Thanks for the reply - it looks like we'll have to punt to XLA/JIT team then. As is, JITing is sort of a no-go:

import jax
import jax.numpy as np
import numpy as onp

def E_fn(conf):
    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dxdydz = np.power(ri - rj, 2)
    dij = np.sqrt(np.sum(dxdydz, axis=-1))
    return np.sum(dij)

dE_dx_fn = jax.jacrev(E_fn, argnums=(0,))
d2E_dx2_fn = jax.jacfwd(dE_dx_fn, argnums=(0,))

print("start jit")
d2E_dx2_fn = jax.jit(d2E_dx2_fn)
print("end jit")

d2E_dx2_fn(onp.random.rand(2483, 3)) # OOMs still

Correct me if I'm wrong, but doesn't the code have to run at least once in op-by-op mode before the trace becomes available for JITing? This seems like a catch-22 to me.

Actually under a jit we never run things op-by-op (though a version of Jax from 2017 did things that way); the values propagated through the python code are abstract ones, basically representing a set of possible arrays, and those aren't backed by any float values. The abstract values just store shape and dtype, and so they don't take much memory or cause any FLOPs.
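One way to see this abstract tracing in action is jax.eval_shape, which runs the Python function on shape/dtype placeholders only. The snippet below is just an illustration, with the E_fn from this thread rewritten against jax.numpy imported as jnp. It traces the gradient for the full (2483, 3) input without ever building the (2483, 2483, 3) intermediate.

```python
import jax
import jax.numpy as jnp

def E_fn(conf):
    ri = jnp.expand_dims(conf, 0)
    rj = jnp.expand_dims(conf, 1)
    dxdydz = jnp.power(ri - rj, 2)
    dij = jnp.sqrt(jnp.sum(dxdydz, axis=-1))
    return jnp.sum(dij)

# A placeholder carrying only shape and dtype; no array data is allocated.
spec = jax.ShapeDtypeStruct((2483, 3), jnp.float32)

# Trace grad(E_fn) on abstract values and report what the output would look like.
out = jax.eval_shape(jax.grad(E_fn), spec)
print(out.shape, out.dtype)  # (2483, 3) float32
```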

The more I think about it, the more I think XLA should be doing this rewrite optimization for us (and by extension jit should be doing it for you). I'll raise it with the XLA folks and see what they think. In the worst case, if for some unforeseen reason XLA can't do this optimization, this is a place where custom ops could help.

Thanks a ton!

Another thing we can do for this particular code (not sure how representative it is of your real code) is use the law of cosines to compute the Euclidean Gram matrix:

def E_fn(conf):
  norms = np.sum(conf ** 2, -1)
  dij = np.sqrt(norms[..., None] + norms - 2 * np.dot(conf, conf.T))
  return np.sum(dij)

With 64-bit values enabled, that produces the same result as the original code on an onp.random.randn(2483, 3) input array.
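A quick way to reproduce that comparison, as a sketch. The names E_pairwise and E_gram are made up here, the input is smaller than the one in the thread to keep the check cheap, and the jnp.maximum clamp is an extra guard added for this example (it is not in the code above) so that round-off on the diagonal cannot produce a slightly negative value under the square root.

```python
import jax
jax.config.update("jax_enable_x64", True)  # enable 64-bit values before creating arrays
import jax.numpy as jnp

def E_pairwise(conf):
    # Original formulation: explicit (N, N, 3) differences.
    ri = jnp.expand_dims(conf, 0)
    rj = jnp.expand_dims(conf, 1)
    dxdydz = jnp.power(ri - rj, 2)
    return jnp.sum(jnp.sqrt(jnp.sum(dxdydz, axis=-1)))

def E_gram(conf):
    # Law of cosines: |a - b|^2 = |a|^2 + |b|^2 - 2 a.b
    norms = jnp.sum(conf ** 2, -1)
    sq = norms[..., None] + norms - 2 * jnp.dot(conf, conf.T)
    # Assumption added for this sketch: clamp tiny negative round-off to zero.
    return jnp.sum(jnp.sqrt(jnp.maximum(sq, 0.0)))

conf = jax.random.normal(jax.random.PRNGKey(0), (200, 3), dtype=jnp.float64)
e_pairwise = E_pairwise(conf)
e_gram = E_gram(conf)
print(e_pairwise, e_gram)
```

Note that both versions take a square root of zero (or nearly zero) on the diagonal, so the values agree but the derivatives at those entries need separate care.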

This should be faster (fewer FLOPs) and have smaller intermediates, but it's still not good enough to avoid the OOM on my machine. (Is there an even better way to compute the sum of all Euclidean distances? I'm not coming up with anything...)

This is an extremely simplified (but representative) repro of a much more complicated set of non-truncated potentials that can't always be reduced using the law of cosines (but it's a nice trick to compute the Gramian)
