Jax: jacfwd gives different results from jacrev

Created on 30 Apr 2019  ·  5 Comments  ·  Source: google/jax

When working with an electrostatic pairwise potential, I'm seeing NaNs from jacrev but not from jacfwd:

import jax
import jax.numpy as np

from jax.scipy.special import erf, erfc
def delta_r(ri, rj, box=None):
    diff = ri - rj # this can be either N,N,3 or B,3
    if box is not None:
        diff -= box[2]*np.floor(np.expand_dims(diff[...,2], axis=-1)/box[2][2]+0.5)
        diff -= box[1]*np.floor(np.expand_dims(diff[...,1], axis=-1)/box[1][1]+0.5)
        diff -= box[0]*np.floor(np.expand_dims(diff[...,0], axis=-1)/box[0][0]+0.5)
    return diff

def distance(ri, rj, box=None):
    dxdydz = np.power(delta_r(ri, rj, box), 2)
    # np.linalg.norm nans but this doesn't
    dij = np.sqrt(np.sum(dxdydz, axis=-1))
    return dij


def pairwise(conf, charges, box=None):
    num_atoms = conf.shape[0]
    qi = np.expand_dims(charges, 0) # (1, N)
    qj = np.expand_dims(charges, 1) # (N, 1)
    qij = np.multiply(qi, qj)
    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dij = distance(ri, rj, box)
    eij = qij/dij
    eij = np.where(np.eye(num_atoms), np.zeros_like(eij), eij) # zero out diagonals

    # print(dij)
    alphaEwald = 1.0
    eij_direct = np.where(dij > 2.0, np.zeros_like(eij), eij)
    eij_direct *= erfc(alphaEwald*eij_direct)
    eij_direct = np.sum(eij_direct)/2

    return np.sum(eij_direct)

if __name__ == "__main__":
    charges = np.array([1.3, 0.3, 0.3, 0.3, 0.3], dtype=np.float64)
    conf = np.array([
        [ 0.0637,   0.0126,   0.2203],
        [ 1.0573,  -0.2011,   1.2864],
        [ 2.3928,   1.2209,  -0.2230],
        [-0.6891,   1.6983,   0.0780],
        [-0.6312,  -1.6261,  -0.2601]
    ], dtype=np.float64)


    box = np.array([
        [2.0, 0.0, 0.0],
        [0.6, 1.6, 0.0],
        [0.4, 0.7, 1.1]
    ], dtype=np.float64)

    print(jax.jacfwd(pairwise, 0)(conf, charges, box))
    print(jax.jacfwd(pairwise, 1)(conf, charges, box))

    print(jax.jacrev(pairwise, 0)(conf, charges, box))
    print(jax.jacrev(pairwise, 1)(conf, charges, box))

Results:

[[-0.01737787  0.11672211  0.39640915]
 [-0.01465135 -0.13763744 -0.05796134]
 [ 0.20212802  0.2141647  -0.0337759 ]
 [-0.02587864  0.03783851 -0.05620446]
 [-0.14422016 -0.23108788 -0.24846745]]
[-0.12515384  1.015857    0.6813821   0.6616855   0.3078022 ]
[[-0.01737785  0.11672208  0.39640918]
 [-0.01465137 -0.13763744 -0.05796135]
 [ 0.20212804  0.21416475 -0.03377588]
 [-0.02587865  0.03783851 -0.05620446]
 [-0.14422016 -0.23108791 -0.24846748]]
[nan nan nan nan nan]

Most helpful comment

Funny that you ran across this now; if my guess at what's going on is correct, we were just talking about this at the end of last week. Basically, np.where is hard to handle correctly when one side produces NaNs. JAX, Autograd, and TF (and possibly others) all have this bug. We think we know how to solve it in JAX (in a way we couldn't with Autograd), but it'd take a surprising amount of work.

Here's an even simpler repro:

In [1]: from jax import grad, jvp

In [2]: import jax.numpy as np

In [3]: grad(lambda x: np.where(True, x, np.log(x)))(0.)
Out[3]: array(nan, dtype=float32)

In [4]: jvp(lambda x: np.where(True, x, np.log(x)), (0.,), (1.,))
Out[4]: (array(0., dtype=float32), array(1., dtype=float32))

As you figured out, the workaround is to add more np.where calls. The more complete solution will take some time to describe.
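
To make the "more np.where calls" idea concrete, here is a minimal sketch of the pattern (sometimes called a "double where"). The names unsafe, safe, and safe_x are made up for illustration and are not part of JAX; the assumption is that the NaN comes from differentiating the unselected log branch at 0, as in the repro above.

```python
import jax
import jax.numpy as jnp

def unsafe(x):
    # log(x) is evaluated and differentiated even where it is not selected.
    return jnp.where(x > 0.0, jnp.log(x), x)

def safe(x):
    # Inner where: feed log a harmless input (1.0) wherever its result is discarded.
    safe_x = jnp.where(x > 0.0, x, 1.0)
    # Outer where: select between the two branches as before.
    return jnp.where(x > 0.0, jnp.log(safe_x), x)

g_unsafe = jax.grad(unsafe)(0.0)  # reverse mode hits 0 / 0 in the log branch
g_safe = jax.grad(safe)(0.0)      # the log branch now only ever sees 1.0
print(g_unsafe, g_safe)
```

The inner where keeps the problematic function away from inputs where its derivative is infinite, so the zero cotangent sent to the unselected branch is never multiplied by inf.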

For the time being, I'm glad this issue is open!

All 5 comments

Another fun, simplified example: here jacfwd gives NaNs but jacrev doesn't.

import jax
import jax.numpy as np

from jax.scipy.special import erf, erfc

def distance(ri, rj, box=None):
    # np.linalg.norm nans but this doesn't
    dij = np.sqrt(np.sum(np.power(ri - rj, 2), axis=-1))
    # dij = np.linalg.norm(ri-rj, axis=-1)
    return dij


def pairwise(conf):
    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dij = distance(ri, rj)
    return np.sum(1/dij)

if __name__ == "__main__":
    charges = np.array([1.3, 0.3, 0.3, 0.3, 0.3], dtype=np.float64)
    conf = np.array([
        [ 0.0637,   0.0126,   0.2203],
        [ 1.0573,  -0.2011,   1.2864],
        [ 2.3928,   1.2209,  -0.2230],
        [-0.6891,   1.6983,   0.0780],
        [-0.6312,  -1.6261,  -0.2601]
    ], dtype=np.float64)

    print(jax.jacfwd(pairwise, 0)(conf))
    print(jax.jacrev(pairwise, 0)(conf))

[[nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]]
[[ 0.4101447   0.00255716  0.42205697]
 [-0.76706415  0.34222    -1.1305999 ]
 [-0.71000564 -0.36618775  0.26675546]
 [ 0.5916799  -0.90446496  0.11163328]
 [ 0.47524524  0.92587554  0.33015415]]

If you switch the distance call explicitly to np.linalg.norm, then both NaN out (as expected):

[[nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]]
[[nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]
 [nan nan nan]]

I can temporarily work around this by applying keep_mask in two extra np.where calls (on qij and dij, in addition to the one on eij):

import jax
import jax.numpy as np

from jax.scipy.special import erf, erfc

def distance(ri, rj, box=None):
    # np.linalg.norm nans but this doesn't
    # dij = np.sqrt(np.sum(np.power(ri - rj, 2), axis=-1))
    dij = np.linalg.norm(ri-rj, axis=-1)
    return dij


def pairwise(conf, charges):
    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dij = distance(ri, rj)

    qi = np.expand_dims(charges, 0) # (1, N)
    qj = np.expand_dims(charges, 1) # (N, 1)
    qij = np.multiply(qi, qj)

    keep_mask = 1 - np.eye(conf.shape[0])

    qij = np.where(keep_mask, qij, np.zeros_like(qij))
    dij = np.where(keep_mask, dij, np.zeros_like(dij))
    eij = np.where(keep_mask, qij/dij, np.zeros_like(dij)) # zero out diagonals

    return np.sum(eij)

if __name__ == "__main__":
    charges = np.array([1.3, 0.3, 0.3, 0.3, 0.3], dtype=np.float64)
    conf = np.array([
        [ 0.0637,   0.0126,   0.2203],
        [ 1.0573,  -0.2011,   1.2864],
        [ 2.3928,   1.2209,  -0.2230],
        [-0.6891,   1.6983,   0.0780],
        [-0.6312,  -1.6261,  -0.2601]
    ], dtype=np.float64)

    print(jax.jacfwd(pairwise, 0)(conf, charges))
    print(jax.jacfwd(pairwise, 1)(conf, charges))

    print(jax.jacrev(pairwise, 0)(conf, charges))
    print(jax.jacrev(pairwise, 1)(conf, charges))

[[ 0.15995641  0.00099726  0.16460222]
 [-0.2556007   0.07092553 -0.301932  ]
 [-0.13806196 -0.07143067  0.03812321]
 [ 0.12439986 -0.24072078  0.02349606]
 [ 0.10930635  0.24022864  0.07571051]]
[1.2823117 2.4415243 1.5562676 1.9857843 1.9567187]
[[ 0.15995646  0.00099726  0.1646022 ]
 [-0.2556007   0.07092553 -0.30193198]
 [-0.13806194 -0.07143066  0.03812321]
 [ 0.12439984 -0.24072078  0.02349606]
 [ 0.10930634  0.24022865  0.07571051]]
[1.2823117 2.4415243 1.5562676 1.985784  1.9567186]
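
A variant of the same workaround, as a rough sketch rather than a recommended implementation: mask the squared distances before the square root, so sqrt never sees the zero diagonal. The name pairwise_energy and the fill value 1.0 are arbitrary choices for this example; the energy is the same unhalved sum over i != j as in the script above.

```python
import jax
import jax.numpy as jnp

def pairwise_energy(conf, charges):
    n = conf.shape[0]
    off_diag = ~jnp.eye(n, dtype=bool)           # True for i != j
    diff = conf[None, :, :] - conf[:, None, :]   # (N, N, 3)
    d2 = jnp.sum(diff ** 2, axis=-1)             # squared distances, zero on the diagonal
    # Replace the zero diagonal with a dummy value before sqrt, whose derivative is infinite at 0.
    safe_d2 = jnp.where(off_diag, d2, 1.0)
    dij = jnp.sqrt(safe_d2)
    qij = charges[None, :] * charges[:, None]
    # Zero out the diagonal terms of the energy.
    eij = jnp.where(off_diag, qij / dij, 0.0)
    return jnp.sum(eij)

charges = jnp.array([1.3, 0.3, 0.3, 0.3, 0.3])
conf = jnp.array([
    [ 0.0637,  0.0126,  0.2203],
    [ 1.0573, -0.2011,  1.2864],
    [ 2.3928,  1.2209, -0.2230],
    [-0.6891,  1.6983,  0.0780],
    [-0.6312, -1.6261, -0.2601],
])

fwd_conf = jax.jacfwd(pairwise_energy, 0)(conf, charges)
rev_conf = jax.jacrev(pairwise_energy, 0)(conf, charges)
fwd_q = jax.jacfwd(pairwise_energy, 1)(conf, charges)
rev_q = jax.jacrev(pairwise_energy, 1)(conf, charges)
print(fwd_conf, rev_conf, fwd_q, rev_q)
```

With the diagonal masked before sqrt and again after the division, forward and reverse mode should both stay finite and agree up to floating-point noise.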

+1, a fix would be much appreciated. AFAIK I run into this a lot when I need to clip inputs before functions with infinite derivatives: 0 * np.inf is treated as NaN in cases where it needs to be 0.

For instance, these functions will NaN in jacrev for x outside the function's domain:

# !!! NaNs

def _arccos(x):
  x = np.clip(x, -1, 1)
  return np.arccos(x)


def _sqrt(x):
  x = np.maximum(x, 0)
  return np.sqrt(x)

These will work correctly:

def _arccos(x):
  x = np.where(np.abs(x) >= 1, np.sign(x), x)
  return np.arccos(x)


def _sqrt(x):
  x = np.where(x <= 0, 0, x)
  return np.sqrt(x)
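
A quick way to check the where-based versions, sketched with jax.grad on scalar inputs (the test points -1.0, 4.0, and 2.0 are arbitrary, and jnp is used here as the alias for jax.numpy):

```python
import jax
import jax.numpy as jnp

def _arccos(x):
    # Snap out-of-domain inputs to the nearest endpoint with where instead of clip.
    x = jnp.where(jnp.abs(x) >= 1, jnp.sign(x), x)
    return jnp.arccos(x)

def _sqrt(x):
    # Replace non-positive inputs with 0 using where instead of maximum.
    x = jnp.where(x <= 0, 0.0, x)
    return jnp.sqrt(x)

g_sqrt_outside = jax.grad(_sqrt)(-1.0)     # outside the domain
g_sqrt_inside = jax.grad(_sqrt)(4.0)       # inside the domain: 1 / (2 * sqrt(4))
g_arccos_outside = jax.grad(_arccos)(2.0)  # outside the domain
print(g_sqrt_outside, g_sqrt_inside, g_arccos_outside)
```

The where-based versions pass an exact zero gradient back to the input when it is out of the domain, so the infinite derivative at the boundary is never multiplied by 0.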

Check out the discussion in #1052, especially this comment. The current executive summary is that this isn't something easy to fix without introducing other problems; a short-term solution might be to provide an alternative to np.where.
