Jax: gradients through np.where when one of the branches is nan, #347 except with arctan2

Created on 23 Jul 2019 · 7 comments · Source: google/jax

I have come across an issue that I think is almost identical to #347, except with arctan2. Here's a repro:

from jax import grad
import jax.numpy as np

def test(a):
  b = np.arctan2(a, a)
  print(b)

  temp = np.array([0., 1., 0.])
  c = np.where(temp > 0.5, 0., b)
  print(c)

  return np.sum(c)

aa = np.array([3., 0., -5.])
print(test(aa))
print(grad(test)(aa))
question

All 7 comments

Thanks for raising this, and for the clear repro.

It's actually a surprisingly thorny issue, as we've recently realized, and I think the change I made in #383 to fix #347 was misguided.

The fundamental trouble, as @dougalm explained to me, is including nan (and inf too, but let's focus on nan) in a system that relies on properties of linear maps. For example, the field / vector space axioms require that 0 * x = x * 0 = 0 for any x. But nan also behaves like nan * x = x * nan = nan for any x. So what value should we assign to an expression like 0 * nan? Should it be 0 * nan = nan or 0 * nan = 0? Unfortunately either choice leads to problems...
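
For concreteness, ordinary IEEE-754 arithmetic (what plain Python floats give you) picks the first convention:

# Under IEEE-754 rules, multiplying zero by nan (or inf) gives nan:
print(0.0 * float("nan"))  # nan
print(0.0 * float("inf"))  # nan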

This comes up in automatic differentiation because of how nan Jacobians can arise, as in your example, and interact with zero (co)vectors. The Jacobian of lambda x: np.arctan2(x, x) evaluated at 0 is lambda x: np.nan * x. That means that if we choose the convention 0 * nan = nan, then grad(lambda x: np.arctan2(x, x))(0.) results in nan. That seems like a sensible outcome, since the mathematical function clearly denoted by that program isn't differentiable at 0.
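
As a minimal sketch of what that looks like under the 0 * nan = nan convention (i.e. once #383 is reverted), assuming the usual imports:

from jax import grad
import jax.numpy as np

# The mathematical function isn't differentiable at 0, and under the
# 0 * nan = nan convention the gradient comes out as nan.
print(grad(lambda x: np.arctan2(x, x))(0.))  # nan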

So far so good with the 0 * nan = nan convention. How about a program like this one?

def f(x):
  if x > 0.5:
    return np.arctan2(x, x)
  else:
    return 0.

# grad(f)(0.) ==> 0.

That also works like you might expect, with grad(f)(0.) evaluating to 0. So what goes wrong in your example? Or in this next one, which seems like it should mean the same thing as the one just above?

def f(x):
  return np.where(x > 0.5, np.arctan2(x, x), 0.)

# grad(f)(0.) ==> nan

Or even...

def f(x):
  return np.where(False, np.arctan2(x, x), 0.)

# grad(f)(0.) ==> nan

And before you start thinking (like I did) that this is a problem only with np.where specifically, here's another instantiation of the same problem without using np.where:

grad(lambda x: np.array([0., 1./x])[0])(0.)  # ==> nan

That last one is just a funny denotation of the zero function (i.e. a constant function), and we're still getting nan!

The trouble with all these, both with np.where and the indexing example, is that in some path through the program (e.g. one side of the np.where) we're generating Jacobians like lambda x: x * np.nan. These paths aren't "taken" in that they're selected off by some array-level primitive like np.where or indexing, and ultimately that means zero covectors are propagated along them. If no nans were involved that'd work great because those branches end up contributing zeros to the final sum-total result. But with the convention that 0 * nan = nan those zero covectors can be turned into nans, and including those nan covectors in a sum gives us a nan result.
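
Here's a sketch of that covector story in code, using jax.vjp (again assuming the 0 * nan = nan convention, i.e. #383 reverted):

import jax
import jax.numpy as np

f = lambda x: np.where(x > 0.5, np.arctan2(x, x), 0.)
y, f_vjp = jax.vjp(f, 0.0)
# The cotangent routed to the selected-off arctan2 branch is zero, but that
# branch's Jacobian entry is nan, and 0 * nan = nan infects the final sum.
print(f_vjp(np.ones_like(y)))  # the single cotangent is nan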

From that perspective, it's almost surprising that the example with the explicit if didn't have any issue. But that's because in that case the program we differentiate doesn't represent both paths through the program: we specialize away the if entirely and only see the "good" path. The trouble with these other examples is we're writing programs that explicitly represent both sides when specialized out for autodiff, and even though we're only selecting a "good" side, as we propagate zero covectors through the bad side on the backward pass we start to generate nans.

Okay, so if we choose 0 * nan = nan we'll end up getting nan values for gradients of programs that we think denote differentiable mathematical functions. What if we just choose 0 * nan = 0? That's what #383 did (as an attempted fix for #347), introducing lax._safe_mul for which 0 * x = 0 for any x value and using it in some differentiation rules. But it turns out the consequences of this choice are even worse, as @alextp showed me with this example:

grad(lambda x: np.sqrt(x)**2)(0.)  # 0. !!!

That should be a nan. If it weren't a nan then the only defensible value is 1, since that's the directional derivative on the right (and there isn't one on the left). But we're giving 0, and that's a silently incorrect derivative, the worst sin a system can commit. And that incorrect behavior comes from choosing 0 * nan = 0 (i.e. from lax._safe_mul and #383).
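
To spell out where that comes from, here's the chain rule at 0 written as plain arithmetic (a sketch, using jax.numpy just for the IEEE behavior):

import jax.numpy as np

# d/dx (sqrt(x))**2 = 2*sqrt(x) * (0.5 / sqrt(x)); at x = 0 that's 0 * inf.
outer = 2.0 * np.sqrt(0.0)   # 0.0
inner = 0.5 / np.sqrt(0.0)   # inf
print(outer * inner)         # nan under IEEE arithmetic; 0. under 0 * x = 0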

So I plan to revert #383 and go back to producing nan values as the lesser of two evils.

The only solution we've come up with so far to achieve both criteria (i.e. to produce non-nan derivatives for programs that involve selecting off non-differentiable branches, and not to produce incorrect zero derivatives for non-differentiable programs where we should instead get nan) is pretty heavyweight, and is something like tracking a symbolic zero mask potentially through the entire backward pass of differentiation. (These issues don't come up in forward-mode.) That solution sounds heavy both in terms of implementation and in terms of runtime work to be done.

Once we revert #383, if you have programs (like the one in the OP) that you want to express in a differentiable way, one workaround might be to write things in terms of a vectorized_cond function instead of np.where, maybe something like this:

def vectorized_cond(pred, true_fun, false_fun, operand):
  # true_fun and false_fun must act elementwise (i.e. be vectorized)
  # The extra wheres mask the operand, so on the backward pass each branch's
  # cotangent is *selected* away (rather than multiplied by zero) where the
  # branch isn't taken, which keeps nan cotangents out of the final sum.
  true_op = np.where(pred, operand, 0)
  false_op = np.where(pred, 0, operand)
  return np.where(pred, true_fun(true_op), false_fun(false_op))

# no nans, even after reverting #383
grad(lambda x: vectorized_cond(x > 0.5, lambda x: np.arctan2(x, x), lambda x: 0., x))(0.)
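
For instance, the program from the OP could be restructured along these lines (a sketch; np.zeros_like stands in for the constant 0. branch):

def test2(a):
  temp = np.array([0., 1., 0.])
  c = vectorized_cond(temp > 0.5,
                      lambda x: np.zeros_like(x),   # the 0. branch
                      lambda x: np.arctan2(x, x),   # the arctan2 branch
                      a)
  return np.sum(c)

print(grad(test2)(np.array([3., 0., -5.])))  # no nans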

But that's clumsy, and doesn't solve the indexing version of the problem (i.e. grad(lambda x: np.array([0., 1./x])[0])(0.)). Only the heavyweight solution would handle that, as far as we know. Maybe we could implement it and make it opt-in, like --differentiate_more_programs_but_slowly...

WDYT?

Just to chime in, I would support a vectorized_cond function! I think np.where (at least for me) is often used as a poor man's differentiable cond, so this seems very reasonable to me. Of course, as you said, this doesn't solve the indexed version; naively it seems like the indexed case might not be so hard to avoid by restructuring code (index and then operate?). But I've never come across it myself, so I can't really comment.

Collecting more notes, here's something of a worked example for differentiating f = lambda x: np.where(True, x, np.log(x)) at 0.0.

Here's the jaxpr for the JVP of that function as it would be applied to tangent values:

{ lambda  ;  ; a.
  let b = select True a 0.0
      c = div a 0.0
      d = select True 0.0 c
      e = add_any b d
  in e }

Notice the value c looks dangerous, but it's ignored on the d = select True 0.0 c line. (select is just the XLA name for where.)

But here's the transpose of that linear function:

{ lambda  ;  ; a.
  let b = select True 0.0 a
      c = div b 0.0
      d = select True a 0.0
      e = add_any c d
  in e }

We're feeding a zero covector b into the right branch's VJP, leading to c being a nan covector. That infects the output when we sum the covectors for each branch in the last line.

This example shows how the problem only comes up in reverse-mode and not forward-mode.
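
For reference, here's a sketch of how one might poke at this example directly (assuming the jax.jvp / jax.grad / jax.make_jaxpr APIs; the exact jaxpr printout varies across JAX versions):

import jax
import jax.numpy as np

f = lambda x: np.where(True, x, np.log(x))

# Forward mode: the dangerous div is computed but then selected away.
print(jax.jvp(f, (0.0,), (1.0,)))  # (0.0, 1.0)

# Reverse mode: the zero covector fed through the log branch comes back nan.
print(jax.grad(f)(0.0))            # nan

# The linearized (JVP) computation, roughly the first jaxpr above:
print(jax.make_jaxpr(lambda t: jax.jvp(f, (0.0,), (t,))[1])(1.0))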

Thank you all for the incredibly detailed comments, and for the workarounds (which work great for me). Much appreciated!

One question I have that perhaps @mattjj can help me understand better: in trying to make, e.g., the sqrt function safe, the current strategy seems to be to use vectorized_cond. However, naively it seems like this introduces several extra np.where calls that are superfluous, especially on the forward pass.

Does anyone have any opinions about a solution like this:

from jax import custom_transforms, defjvp  # old (circa-2019) custom-derivative API

@custom_transforms
def safe_sqrt(x):
  return np.sqrt(x)

defjvp(safe_sqrt, lambda g, ans, x: 0.5 * g / np.where(x > 0, ans, np.inf))

Of course a safe version of each op in question would have to be produced separately.

Yes, a safe version of sqrt like this will work, at the expense of potentially hiding real NaNs.
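
A quick usage sketch of that safe_sqrt, showing both the benefit and the caveat (the values follow from the JVP rule above):

print(grad(safe_sqrt)(4.0))    # 0.25, same as the ordinary sqrt gradient
print(grad(safe_sqrt)(0.0))    # 0.0 instead of inf
print(grad(safe_sqrt)(-1.0))   # also 0.0 -- a "real" nan is hidden here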
