When working with an electrostatic pairwise potential, I'm seeing NaNs from `jax.jacrev` but not from `jax.jacfwd`:
import jax
import jax.numpy as np
from jax.scipy.special import erf, erfc
def delta_r(ri, rj, box=None):
    """Return the displacement ``ri - rj``, optionally wrapped by a periodic box.

    Args:
        ri: Coordinates whose last axis holds the three Cartesian components.
        rj: Coordinates broadcastable against ``ri``.
        box: Optional box array indexed as ``box[k]`` (row) and ``box[k][k]``
            (diagonal entry). When ``None`` the raw difference is returned.

    Returns:
        The (possibly wrapped) displacement, same shape as ``ri - rj``.
    """
    diff = ri - rj  # this can be either N,N,3 or B,3
    if box is not None:
        # Each step subtracts a whole box row, which can change the
        # components read by the following steps, so the 2, 1, 0 order
        # must be kept.
        for axis in (2, 1, 0):
            component = np.expand_dims(diff[..., axis], axis=-1)
            diff -= box[axis] * np.floor(component / box[axis][axis] + 0.5)
    return diff
def distance(ri, rj, box=None):
    """Return the Euclidean length of ``delta_r(ri, rj, box)`` over the last axis.

    Args:
        ri: Coordinates, forwarded to ``delta_r``.
        rj: Coordinates, forwarded to ``delta_r``.
        box: Optional periodic box, forwarded to ``delta_r``.

    Returns:
        Array of distances with the last (component) axis reduced away.
    """
    dxdydz = np.power(delta_r(ri, rj, box), 2)
    # np.linalg.norm nans but this doesn't (author's observation when
    # differentiating through zero distances).
    dij = np.sqrt(np.sum(dxdydz, axis=-1))
    return dij
def pairwise(conf, charges, box=None):
    """Sum a truncated pairwise charge interaction over all atom pairs.

    This is the reproduction case for the thread: the self-pair division is
    deliberately left unguarded so the reported NaN behaviour is preserved.

    Args:
        conf: Atom coordinates; ``conf.shape[0]`` is the number of atoms.
        charges: One charge per atom.
        box: Optional periodic box forwarded to ``distance``.

    Returns:
        Scalar: half of the sum of the masked pair terms (each pair appears
        twice in the N x N matrix).
    """
    num_atoms = conf.shape[0]
    qi = np.expand_dims(charges, 0)  # (1, N)
    qj = np.expand_dims(charges, 1)  # (N, 1)
    qij = np.multiply(qi, qj)
    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dij = distance(ri, rj, box)
    # ri and rj come from the same conf, so the diagonal of dij is the
    # distance of each atom to itself. The division therefore runs on those
    # entries before np.where discards them; this is the pattern the thread
    # reports as producing NaN under jax.jacrev.
    eij = qij / dij
    eij = np.where(np.eye(num_atoms), np.zeros_like(eij), eij)  # zero out diagonals
    alpha_ewald = 1.0
    # Drop pairs farther apart than the 2.0 cutoff.
    eij_direct = np.where(dij > 2.0, np.zeros_like(eij), eij)
    # NOTE(review): erfc is applied to alpha * eij_direct (the pair term),
    # not to a distance -- confirm this is intended.
    eij_direct *= erfc(alpha_ewald * eij_direct)
    return np.sum(eij_direct) / 2
if __name__ == "__main__":
    # NOTE(review): float64 is requested here, but the printed results show
    # single-precision digits -- presumably JAX's x64 mode is not enabled;
    # confirm.
    charges = np.array([1.3, 0.3, 0.3, 0.3, 0.3], dtype=np.float64)
    conf = np.array([
        [0.0637, 0.0126, 0.2203],
        [1.0573, -0.2011, 1.2864],
        [2.3928, 1.2209, -0.2230],
        [-0.6891, 1.6983, 0.0780],
        [-0.6312, -1.6261, -0.2601],
    ], dtype=np.float64)
    box = np.array([
        [2.0, 0.0, 0.0],
        [0.6, 1.6, 0.0],
        [0.4, 0.7, 1.1],
    ], dtype=np.float64)
    # Forward-mode derivatives w.r.t. coordinates (argnum 0) and charges
    # (argnum 1), then the same two in reverse mode for comparison.
    print(jax.jacfwd(pairwise, 0)(conf, charges, box))
    print(jax.jacfwd(pairwise, 1)(conf, charges, box))
    print(jax.jacrev(pairwise, 0)(conf, charges, box))
    print(jax.jacrev(pairwise, 1)(conf, charges, box))
results
[[-0.01737787 0.11672211 0.39640915]
[-0.01465135 -0.13763744 -0.05796134]
[ 0.20212802 0.2141647 -0.0337759 ]
[-0.02587864 0.03783851 -0.05620446]
[-0.14422016 -0.23108788 -0.24846745]]
[-0.12515384 1.015857 0.6813821 0.6616855 0.3078022 ]
[[-0.01737785 0.11672208 0.39640918]
[-0.01465137 -0.13763744 -0.05796135]
[ 0.20212804 0.21416475 -0.03377588]
[-0.02587865 0.03783851 -0.05620446]
[-0.14422016 -0.23108791 -0.24846748]]
[nan nan nan nan nan]
Another fun, simplified example: here `jacfwd` produces NaNs but `jacrev` doesn't.
import jax
import jax.numpy as np
from jax.scipy.special import erf, erfc
def distance(ri, rj, box=None):
    """Return the Euclidean length of ``ri - rj`` over the last axis.

    Args:
        ri: Coordinates whose last axis holds the Cartesian components.
        rj: Coordinates broadcastable against ``ri``.
        box: Accepted for signature compatibility; not used by this variant.

    Returns:
        Array of distances with the last axis reduced away.
    """
    # np.linalg.norm nans but this doesn't; the alternative the author
    # compared against is np.linalg.norm(ri - rj, axis=-1).
    dij = np.sqrt(np.sum(np.power(ri - rj, 2), axis=-1))
    return dij
def pairwise(conf):
    """Return the sum of inverse pairwise distances between all rows of ``conf``.

    Minimal reproduction: no masking is applied, so the self-pair entries on
    the diagonal of the distance matrix take part in the ``1 / dij`` sum.

    Args:
        conf: Atom coordinates, one row per atom.

    Returns:
        Scalar sum of ``1 / dij`` over the full N x N matrix.
    """
    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dij = distance(ri, rj)
    return np.sum(1 / dij)
if __name__ == "__main__":
    # charges is not used by this simplified pairwise(conf); it is kept so
    # the script mirrors the full example.
    charges = np.array([1.3, 0.3, 0.3, 0.3, 0.3], dtype=np.float64)
    conf = np.array([
        [0.0637, 0.0126, 0.2203],
        [1.0573, -0.2011, 1.2864],
        [2.3928, 1.2209, -0.2230],
        [-0.6891, 1.6983, 0.0780],
        [-0.6312, -1.6261, -0.2601],
    ], dtype=np.float64)
    # Forward mode first, then reverse mode, both w.r.t. the coordinates.
    print(jax.jacfwd(pairwise, 0)(conf))
    print(jax.jacrev(pairwise, 0)(conf))
[[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]]
[[ 0.4101447 0.00255716 0.42205697]
[-0.76706415 0.34222 -1.1305999 ]
[-0.71000564 -0.36618775 0.26675546]
[ 0.5916799 -0.90446496 0.11163328]
[ 0.47524524 0.92587554 0.33015415]]
If you switch the distance call explicitly to `np.linalg.norm`, then both produce NaNs (as expected):
[[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]]
[[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]
[nan nan nan]]
I can temporarily work around this by adding two extra `np.where(keep_mask, ...)` calls:
import jax
import jax.numpy as np
from jax.scipy.special import erf, erfc
def distance(ri, rj, box=None):
    """Return ``np.linalg.norm(ri - rj)`` taken over the last axis.

    Args:
        ri: Coordinates whose last axis holds the Cartesian components.
        rj: Coordinates broadcastable against ``ri``.
        box: Accepted for signature compatibility; not used by this variant.

    Returns:
        Array of distances with the last axis reduced away.
    """
    # This variant intentionally uses np.linalg.norm (the form the author
    # reports as producing NaNs) to show that the masking in the caller is
    # what removes them.
    dij = np.linalg.norm(ri - rj, axis=-1)
    return dij
def pairwise(conf, charges):
    """Sum ``q_i * q_j / d_ij`` over all off-diagonal atom pairs.

    Workaround version: in addition to masking the final quotient, both
    operands of the division are masked with ``np.where`` first. The thread
    reports that these two extra masks remove the NaNs from the derivatives.

    Args:
        conf: Atom coordinates; ``conf.shape[0]`` is the number of atoms.
        charges: One charge per atom.

    Returns:
        Scalar sum over the full N x N matrix (each pair is counted twice).
    """
    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dij = distance(ri, rj)
    qi = np.expand_dims(charges, 0)  # (1, N)
    qj = np.expand_dims(charges, 1)  # (N, 1)
    qij = np.multiply(qi, qj)
    # 1 off the diagonal, 0 on it: selects every pair except self-pairs.
    keep_mask = 1 - np.eye(conf.shape[0])
    # The two extra masks: apply keep_mask to both inputs of the division.
    qij = np.where(keep_mask, qij, np.zeros_like(qij))
    dij = np.where(keep_mask, dij, np.zeros_like(dij))
    eij = np.where(keep_mask, qij / dij, np.zeros_like(dij))  # zero out diagonals
    return np.sum(eij)
if __name__ == "__main__":
    charges = np.array([1.3, 0.3, 0.3, 0.3, 0.3], dtype=np.float64)
    conf = np.array([
        [0.0637, 0.0126, 0.2203],
        [1.0573, -0.2011, 1.2864],
        [2.3928, 1.2209, -0.2230],
        [-0.6891, 1.6983, 0.0780],
        [-0.6312, -1.6261, -0.2601],
    ], dtype=np.float64)
    # Forward-mode derivatives w.r.t. coordinates (argnum 0) and charges
    # (argnum 1), then the same two in reverse mode for comparison.
    print(jax.jacfwd(pairwise, 0)(conf, charges))
    print(jax.jacfwd(pairwise, 1)(conf, charges))
    print(jax.jacrev(pairwise, 0)(conf, charges))
    print(jax.jacrev(pairwise, 1)(conf, charges))
[[ 0.15995641 0.00099726 0.16460222]
[-0.2556007 0.07092553 -0.301932 ]
[-0.13806196 -0.07143067 0.03812321]
[ 0.12439986 -0.24072078 0.02349606]
[ 0.10930635 0.24022864 0.07571051]]
[1.2823117 2.4415243 1.5562676 1.9857843 1.9567187]
[[ 0.15995646 0.00099726 0.1646022 ]
[-0.2556007 0.07092553 -0.30193198]
[-0.13806194 -0.07143066 0.03812321]
[ 0.12439984 -0.24072078 0.02349606]
[ 0.10930634 0.24022865 0.07571051]]
[1.2823117 2.4415243 1.5562676 1.985784 1.9567186]
Funny that you run across this now; if my guess at what's going on is correct, we were just talking about this at the end of last week. Basically np.where is hard to handle correctly when one side produces nans. JAX, Autograd, and TF (and possibly others) all have this bug. We think we know how to solve it in JAX (in a way we couldn't with Autograd), but it'd take a surprising amount of work.
Here's an even simpler repro:
In [1]: from jax import grad, jvp
In [2]: import jax.numpy as np
In [3]: grad(lambda x: np.where(True, x, np.log(x)))(0.)
Out[3]: array(nan, dtype=float32)
In [4]: jvp(lambda x: np.where(True, x, np.log(x)), (0.,), (1.,))
Out[4]: (array(0., dtype=float32), array(1., dtype=float32))
As you figured out, the workaround is to add more np.where calls. The more complete solution will take some time to describe.
For the time being, I'm glad this issue is open!
+1, a fix would be much appreciated. AFAIK I run into this a lot when I need to clip inputs before functions with infinite derivatives: `0 * np.inf` is treated as NaN in cases where it needs to be 0.
For instance these methods will NaN in jacrev for x out of the function domain:
# !!! NaNs
def _arccos(x):
    """Arccosine of ``x`` after clamping it to [-1, 1] with ``np.clip``.

    Kept as the failing example: the thread reports that this form yields
    NaN under ``jax.jacrev`` when ``x`` lies outside the function's domain.
    """
    x = np.clip(x, -1, 1)
    return np.arccos(x)
def _sqrt(x):
    """Square root of ``x`` after clamping negatives to 0 with ``np.maximum``.

    Kept as the failing example: the thread reports that this form yields
    NaN under ``jax.jacrev`` when ``x`` lies outside the function's domain.
    """
    x = np.maximum(x, 0)
    return np.sqrt(x)
These will work correctly:
def _arccos(x):
    """Arccosine of ``x`` with out-of-range inputs replaced via ``np.where``.

    Values with ``|x| >= 1`` are replaced by ``sign(x)`` (i.e. -1 or 1)
    before ``np.arccos`` is applied. The thread reports that this form
    differentiates correctly, unlike the ``np.clip`` version.
    """
    x = np.where(np.abs(x) >= 1, np.sign(x), x)
    return np.arccos(x)
def _sqrt(x):
    """Square root of ``x`` with non-positive inputs replaced by 0 via ``np.where``.

    The thread reports that this form differentiates correctly, unlike the
    ``np.maximum`` version.
    """
    x = np.where(x <= 0, 0, x)
    return np.sqrt(x)
Check out the discussion in #1052, especially this comment. The current executive summary is that this isn't something easy to fix without introducing other problems; a short-term solution might be to provide an alternative to np.where.
Most helpful comment
Funny that you run across this now; if my guess at what's going on is correct, we were just talking about this at the end of last week. Basically `np.where` is hard to handle correctly when one side produces NaNs. JAX, Autograd, and TF (and possibly others) all have this bug. We think we know how to solve it in JAX (in a way we couldn't with Autograd), but it'd take a surprising amount of work.
Here's an even simpler repro:
As you figured out, the workaround is to add more `np.where` calls. The more complete solution will take some time to describe.
For the time being, I'm glad this issue is open!