Jax: vmap unexpected behavior: NaNs in gradients

Created on 13 Aug 2020 · 4 Comments · Source: google/jax

Hi,
I am observing unexpected behavior from jax.vmap.
I isolated the issue in the code snippet below. When "distance_matrices" is computed without vmap, everything works fine, but when I use jax.vmap, the gradients become NaN.

I am using JAX version 0.1.75.

I appreciate any help that you can provide.

import jax
import jax.numpy as np
print(jax.__version__)
@jax.custom_transforms
def safe_sqrt(x):
  return np.sqrt(x)
jax.defjvp(safe_sqrt, lambda g, ans, x: 0.5 * g / np.where(x > 0, ans, np.inf) )

def create_distance_matrices_single(tiled_positions, shift):

    tiled_positions_trans = tiled_positions.swapaxes(0,1)
    shifted_tiled_positions_trans = tiled_positions_trans + shift
    diff = tiled_positions - shifted_tiled_positions_trans
    distance_matrix = safe_sqrt(np.square(diff).sum(axis=2))
    return distance_matrix


def create_distance_matrices(positions, shifts):
    count = len(positions)
    atom_pos = positions.reshape((count,1,3))
    tiled_positions = np.tile(atom_pos,(1,count,1))

    # if jax.vmap is used, gradients become NaNs
    distance_matrices = jax.vmap(create_distance_matrices_single, in_axes=(None,0))(tiled_positions, shifts)
    #distance_matrices = create_distance_matrices_single(tiled_positions, shifts)
    return np.sum(distance_matrices.flatten())

grad_func = jax.value_and_grad(create_distance_matrices)

position = np.ones((2,3))
shift = np.zeros((1,3))

val, grads = grad_func(position,shift)
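For context, here is a minimal sketch of why a plain sqrt is not usable in this distance computation, which is presumably why the snippet above defines safe_sqrt in the first place. The names plain_distance and zero_diff are illustrative and not part of the original report. The derivative of sqrt at 0 is infinite, and the derivative of the square at 0 is zero, so the chain rule multiplies inf by 0.

```python
import jax
import jax.numpy as jnp


def plain_distance(diff):
    # Euclidean norm written with an ordinary sqrt (no custom derivative).
    return jnp.sqrt(jnp.square(diff).sum())


# Two identical positions give a zero difference vector, as happens on the
# diagonal of the distance matrix in the report.
zero_diff = jnp.zeros(3)

# The derivative of sqrt itself is infinite at 0.
sqrt_grad_at_zero = jax.grad(jnp.sqrt)(0.0)

# Backpropagating through square then multiplies that inf by 2 * 0.
plain_grad = jax.grad(plain_distance)(zero_diff)

print(sqrt_grad_at_zero)
print(plain_grad)
```

A custom derivative rule that returns 0 at x == 0 avoids this product, which is the role safe_sqrt plays in the snippet.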

All 4 comments

The issue is likely caused by jax.custom_transforms, which is deprecated in large part because of its unfavorable interactions with vmap. Could you try using jax.custom_jvp instead? Here's the tutorial.

There's also some commentary about this issue in this section of the custom_jvp design note.

Thank you for the response, @mattjj.
Yes, using jax.custom_jvp has resolved the issue.

import jax
import jax.numpy as np
print(jax.__version__)

@jax.custom_jvp
def safe_sqrt(x):
  return np.sqrt(x)

def safe_sqrt_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return np.sqrt(x), np.where(x==0,0.0, 0.5 * np.power(x, -0.5)) * t

safe_sqrt.defjvp(safe_sqrt_jvp)
def create_distance_matrices_single(tiled_positions, shift):

    tiled_positions_trans = tiled_positions.swapaxes(0,1)
    shifted_tiled_positions_trans = tiled_positions_trans + shift
    diff = tiled_positions - shifted_tiled_positions_trans
    distance_matrix = safe_sqrt(np.square(diff).sum(axis=2))
    return distance_matrix


def create_distance_matrices(positions, shifts):
    count = len(positions)
    atom_pos = positions.reshape((count,1,3))
    tiled_positions = np.tile(atom_pos,(1,count,1))

    # with jax.custom_jvp, the gradients are no longer NaN when jax.vmap is used
    distance_matrices = jax.vmap(create_distance_matrices_single, in_axes=(None,0))(tiled_positions, shifts)
    #distance_matrices = create_distance_matrices_single(tiled_positions, shifts)
    return np.sum(distance_matrices.flatten())

grad_func = jax.value_and_grad(create_distance_matrices)

position = np.ones((2,3))
shift = np.zeros((1,3))

val, grads = grad_func(position,shift)
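One way to confirm the result without inspecting the output by eye is to test the gradients for NaNs explicitly. The sketch below is a condensed, self-contained variant of the code above, not the original: it uses broadcasting instead of np.tile, the decorator form of defjvp, and the hypothetical names pairwise_distances, total_distance, and has_nan. It assumes a JAX version that provides jax.custom_jvp.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def safe_sqrt(x):
    return jnp.sqrt(x)


@safe_sqrt.defjvp
def safe_sqrt_jvp(primals, tangents):
    x, = primals
    t, = tangents
    # Same rule as above: use a zero derivative where x == 0 instead of inf.
    return jnp.sqrt(x), jnp.where(x == 0, 0.0, 0.5 * jnp.power(x, -0.5)) * t


def pairwise_distances(positions, shift):
    # positions: (n, 3), shift: (3,). Entry [i, j] is |p_i - (p_j + shift)|.
    diff = positions[:, None, :] - (positions[None, :, :] + shift)
    return safe_sqrt(jnp.square(diff).sum(axis=-1))


def total_distance(positions, shifts):
    # Map over the leading axis of shifts only; positions are shared.
    per_shift = jax.vmap(pairwise_distances, in_axes=(None, 0))(positions, shifts)
    return per_shift.sum()


positions = jnp.ones((2, 3))
shifts = jnp.zeros((1, 3))

val, grads = jax.value_and_grad(total_distance)(positions, shifts)

# True if any gradient entry is NaN.
has_nan = bool(jnp.isnan(grads).any())
print(val, has_nan)
```

With identical positions and a zero shift, every distance is zero, so this input exercises exactly the x == 0 branch of the derivative rule.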

Woohoo, glad to hear it!
