Hi,
I am observing unexpected behavior from jax.vmap.
I have isolated the issue in the code snippet below. If `distance_matrices` is calculated without vmap, everything works fine, but when I use jax.vmap the gradients become NaN.
I am using JAX version 0.1.75.
I appreciate any help that you can provide.
import jax
import jax.numpy as np
# Print the installed JAX version so the behavior can be tied to a release.
print(jax.__version__)
# sqrt with a hand-written derivative, registered through the deprecated
# jax.custom_transforms / jax.defjvp API.
@jax.custom_transforms
def safe_sqrt(x):
    return np.sqrt(x)
# JVP rule: tangent = 0.5 * g / sqrt(x). Wherever `x > 0` is false the divisor
# is replaced by inf, so the tangent there is 0 rather than a division by 0.
jax.defjvp(safe_sqrt, lambda g, ans, x: 0.5 * g / np.where(x > 0, ans, np.inf) )
# Distance matrix between the atoms and the atoms displaced by `shift`: entry
# [i, j] is the norm of tiled_positions[i, j] - (tiled_positions[j, i] + shift).
def create_distance_matrices_single(tiled_positions, shift):
    # Swap the first two axes so that entry [i, j] holds the original [j, i].
    tiled_positions_trans = tiled_positions.swapaxes(0,1)
    shifted_tiled_positions_trans = tiled_positions_trans + shift
    diff = tiled_positions - shifted_tiled_positions_trans
    # Sum the squared components over axis 2, then take the square root via
    # safe_sqrt so the custom derivative is used.
    distance_matrix = safe_sqrt(np.square(diff).sum(axis=2))
    return distance_matrix
# Sum of all pairwise distances over every shift; this scalar is what gets
# differentiated below.
def create_distance_matrices(positions, shifts):
    count = len(positions)
    # One 3-vector per atom, repeated `count` times along axis 1, so that
    # tiled_positions[i, j] is atom i for every j.
    atom_pos = positions.reshape((count,1,3))
    tiled_positions = np.tile(atom_pos,(1,count,1))
    # if jax.vmap is used, gradients become NaNs
    # in_axes=(None, 0): tiled_positions is shared by every call, only
    # `shifts` is mapped over its leading axis.
    distance_matrices = jax.vmap(create_distance_matrices_single, in_axes=(None,0))(tiled_positions, shifts)
    # Non-vmap alternative, reported to work fine (no NaN gradients):
    #distance_matrices = create_distance_matrices_single(tiled_positions, shifts)
    return np.sum(distance_matrices.flatten())
# Value and gradient of the summed distances (value_and_grad differentiates
# with respect to the first argument, `positions`, by default).
grad_func = jax.value_and_grad(create_distance_matrices)
# Two atoms at the same point and a single zero shift: every pairwise
# difference is exactly 0, so sqrt is differentiated at 0.
position = np.ones((2,3))
shift = np.zeros((1,3))
val, grads = grad_func(position,shift)
The issue likely comes from using jax.custom_transforms, which is deprecated in large part because of unfavorable interactions with vmap. Could you try using jax.custom_jvp instead? Here's the tutorial.
There's also some commentary about this issue in this section of the custom_jvp design note.
Thank you for the response @mattjj
Yes, using jax.custom_jvp has resolved the issue. Here is the updated snippet:
import jax
import jax.numpy as np
# Print the installed JAX version so the result can be tied to a release.
print(jax.__version__)
@jax.custom_jvp
def safe_sqrt(x):
    """Return the element-wise square root of ``x``.

    The value is exactly ``np.sqrt(x)``; only the derivative is customised
    (see ``safe_sqrt_jvp``) so that it is 0 instead of ``inf`` at ``x == 0``.
    """
    return np.sqrt(x)


def safe_sqrt_jvp(primals, tangents):
    """JVP rule for ``safe_sqrt`` with the derivative defined as 0 at 0.

    Args:
        primals: One-element tuple ``(x,)`` holding the input of ``safe_sqrt``.
        tangents: One-element tuple ``(t,)`` holding the tangent of ``x``.

    Returns:
        The pair ``(np.sqrt(x), d * t)`` where ``d`` is ``0.5 * x ** -0.5``
        for ``x != 0`` and ``0.0`` where ``x == 0``. Negative inputs are not
        special-cased.
    """
    x, = primals
    t, = tangents
    is_zero = x == 0
    # Replace zeros with a harmless 1.0 *before* the power. With a single
    # `where`, the unselected branch still evaluates 0 ** -0.5 = inf; that is
    # invisible at first order but turns into 0 * inf = NaN as soon as this
    # rule is itself differentiated in reverse mode (e.g. grad of grad).
    safe_x = np.where(is_zero, 1.0, x)
    derivative = np.where(is_zero, 0.0, 0.5 * np.power(safe_x, -0.5))
    return np.sqrt(x), derivative * t


safe_sqrt.defjvp(safe_sqrt_jvp)
def create_distance_matrices_single(tiled_positions, shift):
    """Return the distance matrix between atoms and their shifted partners.

    Entry ``[i, j]`` of the result is the Euclidean norm, taken over axis 2,
    of ``tiled_positions[i, j] - (tiled_positions[j, i] + shift)``.
    """
    # Swapping the first two axes pairs every [i, j] entry with its [j, i]
    # counterpart; `shift` then displaces that counterpart.
    shifted_partner = np.swapaxes(tiled_positions, 0, 1) + shift
    offsets = tiled_positions - shifted_partner
    squared_norms = np.sum(np.square(offsets), axis=2)
    # safe_sqrt rather than np.sqrt so the custom derivative rule is used.
    return safe_sqrt(squared_norms)
def create_distance_matrices(positions, shifts):
    """Return the sum of all pairwise distances, accumulated over every shift.

    Args:
        positions: Atom coordinates; reshaped here to ``(count, 1, 3)`` with
            ``count = len(positions)``, i.e. one 3-vector per atom.
        shifts: Displacements, mapped over their leading axis so that one
            distance matrix is built per shift.

    Returns:
        The scalar sum of every entry of every per-shift distance matrix.
    """
    count = len(positions)
    # Repeat each atom `count` times along axis 1, so that
    # tiled_positions[i, j] is atom i for every j.
    atom_pos = positions.reshape((count, 1, 3))
    tiled_positions = np.tile(atom_pos, (1, count, 1))
    # in_axes=(None, 0): tiled_positions is shared by every call, only
    # `shifts` is mapped. With safe_sqrt defined through jax.custom_jvp this
    # vmap no longer produces NaN gradients.
    distance_matrices = jax.vmap(
        create_distance_matrices_single, in_axes=(None, 0)
    )(tiled_positions, shifts)
    # np.sum already reduces over all axes, so no flatten is needed.
    return np.sum(distance_matrices)
# Value and gradient of the summed distances (value_and_grad differentiates
# with respect to the first argument, `positions`, by default).
grad_func = jax.value_and_grad(create_distance_matrices)
# Two atoms at the same point and a single zero shift: every pairwise
# difference is exactly 0, so this exercises safe_sqrt's derivative at 0.
position = np.ones((2,3))
shift = np.zeros((1,3))
val, grads = grad_func(position,shift)
Woohoo, glad to hear it!