Not sure if this is the place to post this, so let me know if I should be putting this elsewhere.
I'm trying to implement gradient-enhanced kriging with JAX. Ideally, the user should be able to define a kernel and have JAX automatically compute its gradient/Hessian and generate the appropriate covariance matrix. Here's what I've tried so far:
Kernel Implementation
```python
import jax.numpy as np
from jax import jit, jacrev, jacfwd, vmap
from abc import ABC, abstractmethod


class Kernel(ABC):
    def __init__(self, name: str, debug=True):
        self.name = name

    def __call__(self, x1: np.ndarray, x2: np.ndarray):
        raise NotImplementedError()

    @abstractmethod
    def forward(self, x1: np.ndarray, x2: np.ndarray, thetas: np.ndarray):
        raise NotImplementedError()


class RBF(Kernel):
    def __init__(self, length_scale=1.0):
        self.length_scale = length_scale
        super(RBF, self).__init__('RBF')

    def __call__(self, x1: np.ndarray, x2: np.ndarray):
        return self.forward(x1, x2, np.array([self.length_scale]))

    def forward(self, x1: np.ndarray, x2: np.ndarray, thetas: np.ndarray):
        assert thetas.shape == (1,)
        length_scale = thetas[0]
        # d = np.linalg.norm(
        #     np.expand_dims(x1, 0) - np.expand_dims(x2, 1),
        #     axis=-1
        # ) / length_scale
        # return np.exp(-0.5 * (d * d))
        # Broadcast to shape (N2, N1, d), where N1 = len(x1) and N2 = len(x2)
        d = (np.expand_dims(x1, 0) - np.expand_dims(x2, 1)) / length_scale
        d = d * d
        d = np.sum(d, axis=-1)
        # K[i, j] = k(x2[i], x1[j]), shape (N2, N1)
        return np.exp(-0.5 * d)


class RBFGrad(RBF):
    def __init__(self, length_scale=1.0):
        super(RBFGrad, self).__init__(length_scale)
        self.dkdx1 = jit(jacfwd(super(RBFGrad, self).forward, argnums=0))
        self.dkdx2 = jit(jacfwd(super(RBFGrad, self).forward, argnums=1))
        self.dk2dx1dx2 = jit(jacfwd(jacrev(super(RBFGrad, self).forward, argnums=0), argnums=1))

    def forward(self, x1: np.ndarray, x2: np.ndarray, thetas: np.ndarray):
        K = super().forward(x1, x2, thetas)
        # Jacobian shape (N2, N1, N2, d); summing axis -2 gives (N2, N1, d)
        dx2 = self.dkdx2(x1, x2, thetas).sum(-2)
        upper = np.concatenate([dx2[:, :, i] for i in range(dx2.shape[-1])], axis=1)
        # Jacobian shape (N2, N1, N1, d); summing axis -2 gives (N2, N1, d)
        dx1 = self.dkdx1(x1, x2, thetas).sum(-2)
        left = np.concatenate([dx1[:, :, i] for i in range(dx1.shape[-1])], axis=0)
        # Shape (N2, N1, N1, d, N2, d); the two sums give (N2, N1, d, d)
        dx1dx2 = self.dk2dx1dx2(x1, x2, thetas).sum(2).sum(-2)
        dx2_concatenated = np.concatenate([
            dx1dx2[:, :, :, i]
            for i in range(dx1dx2.shape[-1])
        ], axis=1)
        hess = np.concatenate([
            dx2_concatenated[:, :, i]
            for i in range(dx2_concatenated.shape[-1])
        ], axis=0)
        # form the overall covariance matrix
        # [
        #     [K,      dK/dx2      ],
        #     [dK/dx1, dK^2/dx1dx2 ]
        # ]
        return np.concatenate([
            np.concatenate([K, upper], axis=1),
            np.concatenate([left, hess], axis=1)],
            axis=0
        )
```
Test
```python
import pytest
from gpgrad.kernels import *
import jax.numpy as np
import numpy as onp
import jax


def test_rbf():
    a = np.linspace(0, 10).reshape(-1, 1)
    k = RBF(1.0)
    result = k(a, a)
    assert result.shape == (len(a), len(a))
    assert np.allclose(np.diag(result), 1.0)
    assert np.allclose(k(a, a), k(a.reshape(-1, 1), a.reshape(-1, 1)))


def test_rbfgrad():
    a = np.linspace(0, 10).reshape(-1, 1)
    b = np.linspace(0, 10).reshape(-1, 1)
    k = RBFGrad(1.0)
    result = k(a, b)
    assert result.shape == (2*len(a), 2*len(a))
    # diagonals of kernel matrix should be 1.0
    assert np.allclose(np.diag(result[:len(a), :len(a)]), 1.0)
    # diagonals of derivative matrices wrt. x1 or x2 should be 0.0
    assert np.allclose(np.diag(result[:len(a), len(a):]), 0.0)
    assert np.allclose(np.diag(result[len(a):, :len(a)]), 0.0)
    # diagonals of hessian wrt x1 and x2 should be 1.0
    assert np.allclose(np.diag(result[len(a):, len(a):]), 1.0)
```
I _think_ this gets me what I want, though I'm admittedly a little confused by the resulting shapes of the derivatives (dx1 and dx2 in RBFGrad). For inputs of shape (N, d), dx1 and dx2 have shape (N, N, N, d). I think this is because JAX is taking the derivative of every element of the N x N matrix returned by RBF.forward() with respect to every row of x1/x2. This results in lots of 0.0's in the derivatives (see dx1 or dx2 in RBFGrad.forward() before .sum(-2)), which suggests a factor of O(N) extra work spent on entries that are identically 0.0. The problem is exacerbated when I evaluate the second derivative matrix, making the code quite slow to run for even a modest number of rows of x.
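A quick way to confirm that reading of the shapes is a small standalone check (the function `rbf_batched` below is a hypothetical copy of `RBF.forward` with a scalar length scale). Since `K[i, j]` depends only on row `j` of `x1`, every slice `jac[i, j, k, :]` with `k != j` should be identically zero, so `.sum(-2)` just recovers the diagonal `k == j` entries, and at most a fraction 1/N of the Jacobian is nonzero.

```python
import jax
import jax.numpy as jnp


def rbf_batched(x1, x2, length_scale=1.0):
    # Same batched RBF as RBF.forward: K[i, j] = k(x2[i], x1[j]), shape (N2, N1)
    d = (jnp.expand_dims(x1, 0) - jnp.expand_dims(x2, 1)) / length_scale
    return jnp.exp(-0.5 * jnp.sum(d * d, axis=-1))


N, d = 5, 2
x1 = jax.random.normal(jax.random.PRNGKey(0), (N, d))
x2 = jax.random.normal(jax.random.PRNGKey(1), (N, d))

# Full Jacobian wrt x1: shape (N2, N1, N1, d), one slice per row of x1
jac = jax.jacfwd(rbf_batched, argnums=0)(x1, x2)

# Pick out the k == j entries, the only ones that can be nonzero: shape (N2, N1, d)
idx = jnp.arange(N)
diag = jac[:, idx, idx, :]

# Whatever .sum(-2) adds beyond the diagonal entries
off_diag = jac.sum(-2) - diag

# Fraction of Jacobian entries that are actually nonzero
nonzero_fraction = jnp.mean(jac != 0)
```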
On my laptop, test_rbf() runs in about 600ms without JIT compilation. test_rbfgrad() runs in about 15s even with the derivative functions JIT-compiled. The main bottleneck there is the second derivative function.
Is there a faster way to implement this in JAX, perhaps by eliminating the numerous operations that result in 0's in the derivatives?
If you don't want to calculate a bunch of off-diagonal terms, you might try using vmap on top of jacfwd or jacrev.
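A minimal sketch of that suggestion, assuming the kernel is rewritten for a single pair of points (the names `k` and `dk_dx1_batched` are hypothetical). `jacfwd` then differentiates a scalar, and two nested `vmap` calls add the batch axes back, so the zero-filled blocks are never built.

```python
import jax
import jax.numpy as jnp


def k(x1, x2, length_scale=1.0):
    # Kernel on a single pair of points: x1, x2 have shape (d,), output is a scalar
    diff = (x1 - x2) / length_scale
    return jnp.exp(-0.5 * jnp.dot(diff, diff))


# Jacobian of the scalar kernel wrt x1 has shape (d,)
dk_dx1 = jax.jacfwd(k, argnums=0)

# Inner vmap maps over rows of x1, outer vmap over rows of x2: result shape (N2, N1, d)
dk_dx1_batched = jax.vmap(jax.vmap(dk_dx1, in_axes=(0, None)), in_axes=(None, 0))

x1 = jnp.linspace(0.0, 1.0, 6).reshape(3, 2)
x2 = jnp.linspace(0.5, 1.5, 8).reshape(4, 2)
G = dk_dx1_batched(x1, x2)
```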
Thanks for the question! I agree with what @shoyer said.
To say a bit more, if f: R^n -> R^m then we expect its Jacobian to be an m x n matrix. The function being differentiated here, namely RBF.forward with respect to its first argument, takes an input of shape (50, 1) and produces an output of shape (50, 50), which is why we expect to see a Jacobian of shape (50, 50, 50, 1). (If we want to think in flattened terms, we'd say that the input has dimension n=50 and the output has dimension m=2500, so the Jacobian matrix is 2500 x 50.)
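The shape bookkeeping can be checked directly. `jax.jacfwd` returns an array of shape `output.shape + input.shape`, and reshaping it recovers the flattened m x n matrix. This sketch uses a standalone copy of the batched RBF (`forward` here is a free function, not the method).

```python
import jax
import jax.numpy as jnp


def forward(x1, x2, length_scale=1.0):
    # Batched RBF as in the original RBF.forward: output shape (len(x2), len(x1))
    d = (x1[None, :, :] - x2[:, None, :]) / length_scale
    return jnp.exp(-0.5 * jnp.sum(d * d, axis=-1))


a = jnp.linspace(0.0, 10.0, 50).reshape(-1, 1)   # shape (50, 1), as in test_rbfgrad
J = jax.jacfwd(forward, argnums=0)(a, a)          # shape (50, 50) + (50, 1)
J_flat = J.reshape(50 * 50, 50 * 1)               # the m x n matrix, m = 2500, n = 50
```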
The trouble here is we just want to see the data axes of size 50 as batch axes. That is, we really want to think of the kernel as taking two vectors and outputting a scalar. We only carry along the batch axis in the implementation for vectorization efficiency.
But with vmap, we don't need to batch by hand, and as a result it can be easier to express the functions and derivatives we want. Concretely, take a look at the forward_ and dkdx2_ methods below (note the underscores on the end), and the assertion:
```python
import jax.numpy as np
from jax import grad, jit, jacfwd, jacrev, vmap


class RBF(Kernel):
    # ...

    def forward_(self, x1: np.ndarray, x2: np.ndarray, thetas: np.ndarray):
        # Kernel for a single pair of points: x1 and x2 have shape (d,)
        assert thetas.shape == (1,)
        length_scale = thetas[0]
        dist_sq = np.vdot(x1, x1) + np.vdot(x2, x2) - 2 * np.vdot(x1, x2)
        return np.exp(-0.5 * dist_sq / length_scale**2)


class RBFGrad(RBF):
    def __init__(self, length_scale=1.0):
        super(RBFGrad, self).__init__(length_scale)
        self.dkdx1 = jit(jacfwd(super(RBFGrad, self).forward, argnums=0))
        self.dkdx2 = jit(jacfwd(super(RBFGrad, self).forward, argnums=1))
        self.dk2dx1dx2 = jit(jacfwd(jacrev(super(RBFGrad, self).forward, argnums=0), argnums=1))
        # Inner vmap maps over rows of x1, outer vmap over rows of x2: shape (N2, N1, d)
        self.dkdx2_ = jit(vmap(vmap(grad(super().forward_, argnums=1), (0, None, None)), (None, 0, None)))

    def forward(self, x1: np.ndarray, x2: np.ndarray, thetas: np.ndarray):
        K = super().forward(x1, x2, thetas)
        dx2 = self.dkdx2(x1, x2, thetas).sum(-2)
        dx2_ = self.dkdx2_(x1, x2, thetas)
        assert np.allclose(dx2, dx2_)
        # ...
```
I rewrote the forward_ method so that it more obviously applies to a single pair of vectors at a time, and also to use the polarization identity ||u - v||^2 = ||u||^2 + ||v||^2 - 2 u · v, which gives us more matrix multiplies and I think is often faster (though one should benchmark it against computing ||u - v||^2 directly as before). I also added the dkdx2_ method, which uses vmap to do all the batching, and compared it against the old calculation (it may be faster to use jvp than grad, but grad was convenient). That second bit, with the vmaps, is the main thing I wanted to illustrate, since it's basically @shoyer's advice.
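For the batched case, the same identity turns the cross term into a single matrix product. Here is a sketch comparing it with the direct broadcasting computation (function names are hypothetical). In float32 the expanded form can lose some precision for nearly coincident points through cancellation, which is another reason to check and benchmark both.

```python
import jax
import jax.numpy as jnp


def sq_dists_direct(x1, x2):
    # ||x2[i] - x1[j]||^2 via broadcasting; materializes an (N2, N1, d) array
    diff = x2[:, None, :] - x1[None, :, :]
    return jnp.sum(diff * diff, axis=-1)


def sq_dists_polarization(x1, x2):
    # ||u||^2 + ||v||^2 - 2 u.v, with the cross term as one matrix multiply
    sq1 = jnp.sum(x1 * x1, axis=-1)   # (N1,)
    sq2 = jnp.sum(x2 * x2, axis=-1)   # (N2,)
    cross = x2 @ x1.T                 # (N2, N1)
    return sq2[:, None] + sq1[None, :] - 2.0 * cross


x1 = jax.random.normal(jax.random.PRNGKey(0), (40, 3))
x2 = jax.random.normal(jax.random.PRNGKey(1), (30, 3))
D_direct = sq_dists_direct(x1, x2)
D_polar = sq_dists_polarization(x1, x2)
```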
One other piece of advice would be to put jit on RBFGrad.forward, because the more code XLA can see the more it can optimize things.
WDYT?
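Putting the per-pair derivatives and a single `jit` together, the whole covariance assembly might look like the sketch below. It assumes the same block layout as `RBFGrad.forward` in the question (`K[i, j] = k(x1[j], x2[i])`, one block per input dimension). `k`, `pairwise` and `gek_covariance` are hypothetical names, and `jnp.block` with transposes stands in for the per-dimension `concatenate` loops.

```python
import jax
import jax.numpy as jnp


def k(x1, x2, length_scale):
    # Single-pair RBF kernel; x1 and x2 have shape (d,), output is a scalar
    diff = (x1 - x2) / length_scale
    return jnp.exp(-0.5 * jnp.dot(diff, diff))


def pairwise(f):
    # result[i, j] = f(x1[j], x2[i], theta): inner vmap over x1 rows, outer over x2 rows
    return jax.vmap(jax.vmap(f, in_axes=(0, None, None)), in_axes=(None, 0, None))


@jax.jit
def gek_covariance(x1, x2, length_scale):
    K = pairwise(k)(x1, x2, length_scale)                          # (N2, N1)
    dx1 = pairwise(jax.grad(k, argnums=0))(x1, x2, length_scale)   # (N2, N1, d)
    dx2 = pairwise(jax.grad(k, argnums=1))(x1, x2, length_scale)   # (N2, N1, d)
    d2k = jax.jacfwd(jax.grad(k, argnums=0), argnums=1)            # (d, d) per pair
    hess = pairwise(d2k)(x1, x2, length_scale)                     # (N2, N1, d, d)
    n2, n1, d = dx1.shape
    # Same block layout as RBFGrad.forward: one block per input dimension
    upper = dx2.transpose(0, 2, 1).reshape(n2, d * n1)
    left = dx1.transpose(2, 0, 1).reshape(d * n2, n1)
    hess = hess.transpose(2, 0, 3, 1).reshape(d * n2, d * n1)
    return jnp.block([[K, upper], [left, hess]])


X = jax.random.normal(jax.random.PRNGKey(0), (6, 2))
C = gek_covariance(X, X, 1.0)
```

Because XLA sees the kernel, all derivative blocks and the final assembly in one computation, it is free to fuse them, which is the point of jitting the whole `forward`.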
There may be a nicer way to write the nested vmap using the vectorize wrapper.
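`jax.numpy.vectorize` accepts a NumPy-style gufunc `signature`, so one possible version is the following sketch (the kernel `k` here is a hypothetical single-pair RBF with unit length scale). Broadcasting `x1[None]` against `x2[:, None]` produces the same `(N2, N1, d)` result as the nested `vmap`.

```python
import jax
import jax.numpy as jnp


def k(x1, x2):
    # Single-pair RBF kernel with unit length scale; x1, x2 have shape (d,)
    diff = x1 - x2
    return jnp.exp(-0.5 * jnp.dot(diff, diff))


# The signature marks d as a core dimension; all leading axes are broadcast
dk_dx2 = jnp.vectorize(jax.grad(k, argnums=1), signature='(d),(d)->(d)')

x1 = jnp.linspace(0.0, 1.0, 6).reshape(3, 2)
x2 = jnp.linspace(0.5, 1.5, 8).reshape(4, 2)

# (1, N1, d) broadcast against (N2, 1, d) gives an (N2, N1, d) result
G = dk_dx2(x1[None, :, :], x2[:, None, :])

# The explicit nested vmap, for comparison
G_vmap = jax.vmap(jax.vmap(jax.grad(k, argnums=1), in_axes=(0, None)), in_axes=(None, 0))(x1, x2)
```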
One more thought: to get the second derivatives, you just need to write grad(grad(...)) inside the double-vmap. Getting rid of batch dimensions makes things so much easier!
It might be worth noting that this vmap + auto-diff strategy is actually rather fundamental to JAX. Even jacfwd and jacrev themselves are mere thin wrappers over vmap + jvp/vjp: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Composing-VJPs,-JVPs,-and-vmap
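A stripped-down illustration of that point for a function with a 1-D input (the helper `jacfwd_via_vmap` is hypothetical; the linked cookbook covers the general case): pushing every standard basis vector through `jvp` under `vmap` reproduces `jax.jacfwd`.

```python
import jax
import jax.numpy as jnp


def f(x):
    # A small R^3 -> R^2 function to differentiate
    return jnp.array([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])


def jacfwd_via_vmap(f, x):
    # Push each standard basis vector through jvp; vmap batches the pushes
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    pushfwd = lambda v: jax.jvp(f, (x,), (v,))[1]   # one Jacobian column per v
    cols = jax.vmap(pushfwd)(basis)                 # shape (n, m): columns stacked as rows
    return cols.T                                   # shape (m, n), like jax.jacfwd


x = jnp.array([0.3, -1.2, 2.0])
J_manual = jacfwd_via_vmap(f, x)
J_jax = jax.jacfwd(f)(x)
```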
Oh vmap, of course! Thanks @shoyer and @mattjj, I'll check in again if I run into more trouble.
This works great, and is much faster than my original implementation! Though the second derivative will need jacrev(grad(...)) as grad returns a vector.
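Here is a sketch of that second-derivative block, assuming the single-pair RBF with unit length scale (names are hypothetical). `grad(k, argnums=0)` returns a `(d,)` vector, and `grad` only accepts scalar-output functions, so the outer derivative is taken with `jacrev`. The result is checked against the closed form (I - r r^T) k with r = x1 - x2.

```python
import jax
import jax.numpy as jnp


def k(x1, x2, length_scale=1.0):
    # Single-pair RBF kernel; x1 and x2 have shape (d,)
    diff = (x1 - x2) / length_scale
    return jnp.exp(-0.5 * jnp.dot(diff, diff))


# grad wrt x1 is a (d,) vector, so the outer derivative must be a Jacobian: (d, d) per pair
d2k_dx1dx2 = jax.jacrev(jax.grad(k, argnums=0), argnums=1)

# Inner vmap over rows of x1, outer over rows of x2: shape (N2, N1, d, d)
d2k_batched = jax.jit(jax.vmap(jax.vmap(d2k_dx1dx2, in_axes=(0, None)), in_axes=(None, 0)))

X = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
H = d2k_batched(X, X)


def d2k_analytic(x1, x2):
    # Closed form for the unit-length-scale RBF: (I - r r^T) k(x1, x2), r = x1 - x2
    r = x1 - x2
    return (jnp.eye(x1.shape[0]) - jnp.outer(r, r)) * k(x1, x2)
```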