Jax: Gradients with `odeint` slow on GPU

Created on 24 Nov 2020 · 5 Comments · Source: google/jax

The following MWE trains a simple neural ODE model with gradient descent to match a 2-D dynamical system (Van der Pol oscillator) using data sampled along a single trajectory. Each iteration of the training loop runs much more slowly on my GPU than on my CPU (roughly estimated with tqdm: about 17 iterations/sec on GPU vs. upwards of 800 iterations/sec on CPU).

Any first impressions about what might be going on? I can look into doing better profiling if need be.

Versions: jax 0.2.6, jaxlib 0.1.57+cuda102, cuda 10.2

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x: x

# Uncomment this line to force using the CPU
# jax.config.update('jax_platform_name', 'cpu')

# Some utilities for dealing with PyTrees of parameters
def tree_axpy(a, x_tree, y_tree):
    """Compute `a*x + y` for two PyTrees `(x, y)` and a scalar `a`."""
    ax = jax.tree_util.tree_map(lambda x: a * x, x_tree)
    axpy = jax.tree_util.tree_multimap(lambda x, y: x + y, ax, y_tree)
    return axpy

def tree_normsq(x_tree):
    """Compute sum of squared norms across a PyTree."""
    normsq = jax.tree_util.tree_reduce(lambda x, y: x + jnp.sum(y**2), x_tree, 0.)
    return normsq

# Define true ODE, our approximator, and the loss function
def f(x, t):
    """Compute state derivative of a Van der Pol oscillator."""
    mu = 1.
    dx = jnp.hstack([
        mu*(x[0] - x[0]**3/3 - x[1]),
        x[0]/mu
    ])
    return dx

def f_est(x, t, params):
    """Estimate state derivative with a two-layer neural network."""
    W = params['W']
    b = params['b']
    y = W[0]@x + b[0]
    y = W[1]@jnp.tanh(y) + b[1]
    return y

def loss(params, x, t, reg_coeff):
    """Compute the sum of squared losses along a queried trajectory."""
    x_hat = odeint(f_est, x[0], t, params)
    error = jnp.sum((x - x_hat)**2)
    loss_value = error + reg_coeff*tree_normsq(params)
    return loss_value

# Generate data along a trajectory of the true system
x0 = jnp.array([1., 0.])
t0, tf = (0., 5.)
dt = 0.1
num_steps = int((tf - t0) / dt) + 1
t = jnp.linspace(t0, tf, num_steps)
x = odeint(f, x0, t)

# Initialize neural network parameters
n = 2
hdim = 32  # size of hidden layer
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 4)  # independent subkey for each parameter array
params = {
    'W': [
        0.1*jax.random.normal(keys[0], (hdim, n)),
        0.1*jax.random.normal(keys[1], (n, hdim)),
    ],
    'b': [
        0.1*jax.random.normal(keys[2], (hdim,)),
        0.1*jax.random.normal(keys[3], (n,)),
    ]
}

# Training
loss_buffer = []
step_size = 1e-4
reg_coeff = 1e-6
value_and_grad = jax.jit(jax.value_and_grad(loss))
for _ in tqdm(range(5000)):
    value, grad = value_and_grad(params, x, t, reg_coeff)
    loss_buffer.append(value)
    params = tree_axpy(-step_size, grad, params)  # gradient descent step
print('Regularized fit loss:', loss_buffer[-1])

# Plotting (optional)
try:
    import matplotlib.pyplot as plt

    x_est = odeint(f_est, x0, t, params)

    fig, axes = plt.subplots(1, 2, figsize=(15,5))
    axes[0].plot(x[:,0], x[:,1], '--x')
    axes[0].plot(x_est[:,0], x_est[:,1], '-')
    axes[1].plot(loss_buffer)
    axes[1].set_yscale('log')
    plt.show()
except ImportError:
    print('Package `matplotlib` not found! Skipping plots.')
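For more careful CPU vs. GPU numbers than a tqdm rate, it may help to exclude compilation (the first call) and to wait for queued work with `block_until_ready()`, since JAX dispatches asynchronously. The sketch below is a minimal, hypothetical timing harness, assuming a recent JAX release in which `jax.tree_util.tree_multimap` has been removed and `jax.tree_util.tree_map` accepts several trees. The names `make_params`, `train_step`, `time_steps`, and `STEP_SIZE`, and the circular placeholder trajectory, are illustrative assumptions. It also fuses the parameter update into the same jitted function as the gradient, which avoids the op-by-op dispatches of `tree_axpy` in the MWE; whether that changes the GPU gap has to be measured.

```python
import time

import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

STEP_SIZE = 1e-4  # same gradient-descent step size as the MWE


def f_est(x, t, params):
    """Two-layer network estimating the state derivative (as in the MWE)."""
    W, b = params['W'], params['b']
    return W[1] @ jnp.tanh(W[0] @ x + b[0]) + b[1]


def loss(params, x, t):
    """Sum of squared errors along the integrated trajectory."""
    x_hat = odeint(f_est, x[0], t, params)
    return jnp.sum((x - x_hat) ** 2)


def make_params(key, n=2, hdim=32):
    """Hypothetical helper: one independent subkey per parameter array."""
    k1, k2, k3, k4 = jax.random.split(key, 4)
    return {
        'W': [0.1 * jax.random.normal(k1, (hdim, n)),
              0.1 * jax.random.normal(k2, (n, hdim))],
        'b': [0.1 * jax.random.normal(k3, (hdim,)),
              0.1 * jax.random.normal(k4, (n,))],
    }


@jax.jit
def train_step(params, x, t):
    """Loss, gradient, and parameter update compiled into one function."""
    value, grads = jax.value_and_grad(loss)(params, x, t)
    # tree_map over two trees (params, grads) replaces the old tree_multimap.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - STEP_SIZE * g, params, grads)
    return new_params, value


def time_steps(device, num_steps=50):
    """Return (average seconds per step excluding compilation, last loss) on `device`."""
    t = jnp.linspace(0., 5., 51)
    x = jnp.stack([jnp.cos(t), jnp.sin(t)], axis=1)  # hypothetical placeholder data
    params = make_params(jax.random.PRNGKey(0))
    params, x, t = jax.device_put((params, x, t), device)
    params, value = train_step(params, x, t)  # warm-up call triggers compilation
    value.block_until_ready()
    start = time.perf_counter()
    for _ in range(num_steps):
        params, value = train_step(params, x, t)
    value.block_until_ready()  # wait for all queued steps before stopping the clock
    return (time.perf_counter() - start) / num_steps, float(value)
```

Calling `time_steps(jax.devices('cpu')[0])` and, when a GPU is available, `time_steps(jax.devices('gpu')[0])` gives per-step times measured the same way on both backends.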

All 5 comments

The short answer is that, unfortunately, XLA's GPU backend is currently not good at generating code for tight loops like the ones in `odeint`. The body of the `while_loop` is compiled into one or more GPU kernels, and every iteration pays significant kernel-launch overhead because control flow returns to the CPU after each iteration.
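One workaround sometimes tried for this kind of per-iteration overhead (not something proposed in this thread) is to replace the adaptive-step `odeint`, whose `while_loop` has a data-dependent number of iterations, with a fixed-step integrator built on `jax.lax.scan`. The `unroll` argument of `scan` places several steps inside each loop iteration, so a trajectory needs fewer loop round trips. The sketch below is a minimal fixed-step RK4 under the assumption that a fixed number of substeps per output interval is accurate enough; `rk4_odeint`, `substeps`, and the default values are hypothetical. It gives up `odeint`'s adaptive error control and adjoint-based gradients (here gradients come from differentiating through the loop, so memory grows with the number of steps), and any speedup on a given GPU has to be measured.

```python
import jax
import jax.numpy as jnp


def rk4_odeint(func, y0, t, *args, substeps=10, unroll=10):
    """Fixed-step RK4 returning the state at each time in `t` (hypothetical helper).

    `func(y, t, *args)` follows the same convention as `odeint`. Each interval
    [t[i], t[i+1]] is split into `substeps` equal RK4 steps.
    """
    def rk4_step(y, t0, dt):
        k1 = func(y, t0, *args)
        k2 = func(y + 0.5 * dt * k1, t0 + 0.5 * dt, *args)
        k3 = func(y + 0.5 * dt * k2, t0 + 0.5 * dt, *args)
        k4 = func(y + dt * k3, t0 + dt, *args)
        return y + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)

    def interval(y, t_pair):
        t_start, t_end = t_pair
        dt = (t_end - t_start) / substeps

        def body(y, i):
            return rk4_step(y, t_start + i * dt, dt), None

        # `unroll` steps are emitted per iteration of the compiled loop.
        y, _ = jax.lax.scan(body, y, jnp.arange(substeps), unroll=unroll)
        return y, y  # carry the state and also record it as an output

    _, ys = jax.lax.scan(interval, y0, (t[:-1], t[1:]))
    return jnp.concatenate([y0[None], ys], axis=0)
```

In the MWE's `loss`, `rk4_odeint(f_est, x[0], t, params)` could stand in for the `odeint` call, since it takes the same positional arguments.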

@shoyer, is anyone working on improving `while_loop`?

Yes, there are several ongoing streams of work to improve while_loop.

@shoyer, apologies for the somewhat off-topic question. Could you share links to the PRs or branches for the ongoing work, if they are publicly available?

Sorry, I don't have any details that I can share at this time.

On Wed, Dec 9, 2020 at 12:37 PM Artem Artemev <notifications@github.com> wrote:

> @shoyer https://github.com/shoyer, apologies for unrelated with the topic questions. Could you share the links to PRs, branches to ongoing work if it's publicly available?

