I have come across an issue that I think is almost identical to #347, except with arctan2. Here's a repro:
```python
import jax.numpy as np
from jax import grad

def test(a):
    b = np.arctan2(a, a)
    print(b)
    temp = np.array([0., 1., 0.])
    c = np.where(temp > 0.5, 0., b)
    print(c)
    return np.sum(c)

aa = np.array([3., 0., -5.])
print(test(aa))
print(grad(test)(aa))
```
Thanks for raising this, and for the clear repro.
It's actually a surprisingly thorny issue, as we've recently realized, and I think the change I made in #383 to fix #347 was misguided.
The fundamental trouble, as @dougalm explained to me, is including nan (and inf too, but let's focus on nan) in a system that relies on properties of linear maps. For example, the field / vector space axioms require that 0 * x = x * 0 = 0 for any x. But nan also behaves like nan * x = x * nan = nan for any x. So what value should we assign to an expression like 0 * nan? Should it be 0 * nan = nan or 0 * nan = 0? Unfortunately either choice leads to problems...
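For what it's worth, IEEE-754 floating point has already picked a side here: nan propagates through multiplication even against an exact zero. A quick numpy check (just a sketch of the arithmetic, not anything jax-specific):

```python
import math

import numpy as np

# IEEE-754 multiplication propagates nan even when the other operand
# is exactly zero, i.e. hardware arithmetic uses the 0 * nan = nan convention.
print(0.0 * math.nan)            # nan
print(np.float64(0.0) * np.nan)  # nan
```

The linear-map axioms would instead demand 0 * x = 0, which is exactly the tension described above.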
This comes up in automatic differentiation because of how nan Jacobians can arise, as in your example, and interact with zero (co)vectors. The Jacobian of lambda x: np.arctan2(x, x) evaluated at 0 is lambda x: np.nan * x. That means if we choose the convention that 0 * nan = nan then grad(lambda x: np.arctan2(x, x))(0.) results in nan. That seems like a sensible outcome, since the mathematical function clearly denoted by that program isn't differentiable at 0.
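To see where that nan Jacobian comes from: the partial derivatives of arctan2(y, x) are x / (x² + y²) and -y / (x² + y²), so at y = x = 0 both are 0/0. A numpy sketch of that arithmetic:

```python
import numpy as np

x = np.float64(0.0)
with np.errstate(invalid="ignore"):
    # d/dy arctan2(y, x) = x / (x**2 + y**2); at the origin this is 0/0
    d_dy = x / (x**2 + x**2)
    # d/dx arctan2(y, x) = -y / (x**2 + y**2); also 0/0 at the origin
    d_dx = -x / (x**2 + x**2)

print(d_dy, d_dx)  # nan nan
```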
So far so good with the 0 * nan = nan convention. How about a program like this one?
```python
def f(x):
    if x > 0.5:
        return np.arctan2(x, x)
    else:
        return 0.

# grad(f)(0.) ==> 0.
```
That also works like you might expect, with grad(f)(0.) == 0.. So what goes wrong in your example? Or in this next one, which seems like it should mean the same thing as the one just above?
```python
def f(x):
    return np.where(x > 0.5, np.arctan2(x, x), 0.)

# grad(f)(0.) ==> nan
```
Or even...
```python
def f(x):
    return np.where(False, np.arctan2(x, x), 0.)

# grad(f)(0.) ==> nan
```
And before you start thinking (like I did) that this is a problem only with np.where specifically, here's another instantiation of the same problem without using np.where:
```python
grad(lambda x: np.array([0., 1./x])[0])(0.)  # ==> nan
```
That last one is just a funny denotation of the zero function (i.e. a constant function), and we're still getting nan!
The trouble with all these, both with np.where and the indexing example, is that in some path through the program (e.g. one side of the np.where) we're generating Jacobians like lambda x: x * np.nan. These paths aren't "taken" in that they're selected off by some array-level primitive like np.where or indexing, and ultimately that means zero covectors are propagated along them. If no nans were involved that'd work great because those branches end up contributing zeros to the final sum-total result. But with the convention that 0 * nan = nan those zero covectors can be turned into nans, and including those nan covectors in a sum gives us a nan result.
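That mechanism can be mimicked directly in numpy (a hypothetical one-element sketch, not jax internals): the unselected branch receives a zero covector, but its Jacobian entry is nan, and under the 0 * nan = nan convention the product poisons the final sum:

```python
import numpy as np

pred = np.array([False])         # this branch is never selected
branch_jac = np.array([np.nan])  # Jacobian of the unselected branch at this point

# the outer select routes a zero covector into the unselected branch...
covector = np.where(pred, 1.0, 0.0)

# ...but multiplying it by the nan Jacobian gives nan, not zero,
# so summing the branch contributions yields nan
contribution = covector * branch_jac
print(contribution)  # [nan]
```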
From that perspective, it's almost surprising that the example with the explicit if didn't have any issue. But that's because in that case the program we differentiate doesn't represent both paths through the program: we specialize away the if entirely and only see the "good" path. The trouble with these other examples is we're writing programs that explicitly represent both sides when specialized out for autodiff, and even though we're only selecting a "good" side, as we propagate zero covectors through the bad side on the backward pass we start to generate nans.
Okay, so if we choose 0 * nan = nan we'll end up getting nan values for gradients of programs that we think denote differentiable mathematical functions. What if we just choose 0 * nan = 0? That's what #383 did (as an attempted fix for #347), introducing lax._safe_mul for which 0 * x = 0 for any x value and using it in some differentiation rules. But it turns out the consequences of this choice are even worse, as @alextp showed me with this example:
```python
grad(lambda x: np.sqrt(x)**2)(0.)  # ==> 0. !!!
```
That should be a nan. If it weren't a nan then the only defensible value is 1, since that's the directional derivative on the right (and there isn't one on the left). But we're giving 0, and that's a silently incorrect derivative, the worst sin a system can commit. And that incorrect behavior comes from choosing 0 * nan = 0 (i.e. from lax._safe_mul and #383).
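The chain rule for lambda x: np.sqrt(x)**2 at 0 makes this concrete: the inner derivative is 0.5 / sqrt(0) = inf, the outer is 2 * sqrt(0) = 0, and the honest IEEE product is 0 * inf = nan, whereas a safe-mul rule would silently report 0. A numpy sketch of that arithmetic:

```python
import numpy as np

x = np.float64(0.0)
with np.errstate(divide="ignore", invalid="ignore"):
    inner = 0.5 / np.sqrt(x)  # d/dx sqrt(x) at 0 -> inf
    outer = 2.0 * np.sqrt(x)  # d/dy y**2 at y = sqrt(0) -> 0
    chain = outer * inner     # 0 * inf = nan under IEEE rules

print(inner, outer, chain)  # inf 0.0 nan
```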
So I plan to revert #383 and go back to producing nan values as the lesser of two evils.
The only solution we've come up with so far to achieve both criteria (i.e. to produce non-nan derivatives for programs that involve selecting off non-differentiable branches, and not to produce incorrect zero derivatives for non-differentiable programs where we should instead get nan) is pretty heavyweight, and is something like tracking a symbolic zero mask potentially through the entire backward pass of differentiation. (These issues don't come up in forward-mode.) That solution sounds heavy both in terms of implementation and in terms of runtime work to be done.
Once we revert #383, if you have programs (like the one in the OP) that you want to express in a differentiable way, one workaround might be to write things in terms of a vectorized_cond function instead of np.where, maybe something like this:
```python
def vectorized_cond(pred, true_fun, false_fun, operand):
    # true_fun and false_fun must act elementwise (i.e. be vectorized)
    true_op = np.where(pred, operand, 0)
    false_op = np.where(pred, 0, operand)
    return np.where(pred, true_fun(true_op), false_fun(false_op))

# no nans, even after reverting #383
grad(lambda x: vectorized_cond(x > 0.5, lambda x: np.arctan2(x, x), lambda x: 0., x))(0.)
```
But that's clumsy, and doesn't solve the indexing version of the problem (i.e. grad(lambda x: np.array([0., 1./x])[0])(0.)). Only the heavyweight solution would handle that, as far as we know. Maybe we could implement it and make it opt-in, like --differentiate_more_programs_but_slowly...
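The reason the extra inner np.where in vectorized_cond helps: the transpose of b = np.where(pred, x, 0) with respect to x routes the cotangent through another select rather than a multiply, so a nan cotangent at an unselected position is replaced by 0 instead of being multiplied into the result. A numpy sketch of that transpose rule:

```python
import numpy as np

pred = np.array([False])
ct = np.array([np.nan])  # nan cotangent arriving from the unselected branch

# transpose of b = np.where(pred, x, 0) with respect to x:
# select again, rather than multiply, so the nan is dropped, not propagated
ct_x = np.where(pred, ct, 0.0)
print(ct_x)  # [0.]
```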
WDYT?
Just to chime in, I would support a vectorized_cond function! I think np.where (at least for me) is often used as a poor man's differentiable cond, so this seems very reasonable to me. Of course, as you said, this doesn't solve the indexed version; naively it seems like the indexed case might not be so hard to avoid by restructuring code (index and then operate?), but I've never come across it myself so I can't really comment.
FWIW, there's also some discussion here: https://github.com/tensorflow/probability/blob/master/discussion/where-nan.pdf
Collecting more notes, here's something of a worked example for differentiating f = lambda x: np.where(True, x, np.log(x)) at 0.0.
Here's the jaxpr for the JVP of that function as it would be applied to tangent values:
```
{ lambda  ;  ; a.
  let b = select True a 0.0
      c = div a 0.0
      d = select True 0.0 c
      e = add_any b d
  in e }
```
Notice the value c looks dangerous, but it's ignored on the d = select True 0.0 c line. (select is just the XLA name for where.)
But here's the transpose of that linear function:
{ lambda ; ; a.
let b = select True 0.0 a
c = div b 0.0
d = select True a 0.0
e = add_any c d
in e }
We're feeding a zero covector b into the right branch's VJP, leading to c being a nan covector. That infects the output when we sum the covectors for each branch in the last line.
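Evaluating that transposed program by hand in numpy, at a hypothetical input a = 1.0, shows exactly where the nan enters:

```python
import numpy as np

a = np.float64(1.0)
with np.errstate(invalid="ignore"):
    b = np.where(True, 0.0, a)  # zero covector fed to the log branch
    c = b / 0.0                 # 0 / 0 = nan covector
    d = np.where(True, a, 0.0)  # covector for the selected branch
    e = c + d                   # summing branch contributions -> nan

print(e)  # nan
```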
This example shows how the problem only comes up in reverse-mode and not forward-mode.
Thank you all for the incredibly detailed comments, and for the workarounds (which work great for me). Much appreciated!
One question I have that perhaps @mattjj can help me understand better. In trying to make e.g. the sqrt function safe the current strategy seems to be to use vectorized_cond. However, naively it seems like this introduces several extra np.where's that are superfluous, especially on the forward pass.
Does anyone have any opinions about a solution like this,
```python
@custom_transforms
def safe_sqrt(x):
    return np.sqrt(x)

defjvp(safe_sqrt, lambda g, ans, x: 0.5 * g / np.where(x > 0, ans, np.inf))
```
Of course a safe version of each op in question would have to be produced separately.
Yes, a safe version of sqrt like this will work, at the expense of potentially hiding real NaNs.
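The hiding effect is easy to demonstrate with the jvp rule's arithmetic alone (a numpy sketch, leaving out the jax machinery): at a genuinely invalid input like x = -1.0, where the true derivative is nan, the rule still returns 0:

```python
import numpy as np

def safe_sqrt_jvp(g, ans, x):
    # the proposed rule: 0.5 * g / where(x > 0, ans, inf)
    return 0.5 * g / np.where(x > 0, ans, np.inf)

with np.errstate(invalid="ignore"):
    ans = np.sqrt(-1.0)                    # nan: sqrt of a negative number
    tangent = safe_sqrt_jvp(1.0, ans, -1.0)

print(ans, tangent)  # the primal is nan, but the "derivative" is 0.0
```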