Taichi: [Opt] [autodiff] Moving some statements outside "if" can cause NaN in gradients

Created on 1 Jul 2020 · 2 Comments · Source: taichi-dev/taichi

Background
Currently, we move some statements that appear to have no side effects out of if branches as an optimization. After this, some ifs can be flattened into a select statement (i.e., the conditional (or ternary) operator in many programming languages), which can be much faster in some cases. The optimization has existed for over a year, and it's not considered an "advanced optimization".

However, when fast_math=True, moving statements that can produce NaN is UB (undefined behavior). Such statements include binary and unary operations such as a / b, sqrt(a), log(a), asin(a)...

Bug Example
This usually doesn't cause problems, but I found a case that can be problematic (even when fast_math=False):

if x > 1:
  y = 1.0 / x

dy/dx should be -1 / x^2 when x > 1, and 0 otherwise.

However, if we move the division outside the if clause, we'll be doing autodiff on an IR like this:

$0 = load x
$1 = $0 > 1
$2 = 1.0 / $0
$3 = 0
$4 = select($1, $2, $3)

And we'll get the following ($0adj through $4adj are initialized to 0):

$4adj += 1
if ($1) {
  $2adj += $4adj
} else {
  $3adj += $4adj
}
$0adj += -$2adj / ($0)^2

So dy/dx becomes -1 / x^2 when x > 1, and 0 / x^2 otherwise. What if x == 0? Then 0 / x^2 is 0 / 0, so we get dy/dx = NaN.
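
The effect can be reproduced outside the compiler. Below is a minimal sketch in plain Python (not Taichi code) that hand-evaluates the adjoint statements above. The helper ieee_div and both function names are made up for this illustration. ieee_div exists only because Python raises ZeroDivisionError on float division by zero, whereas the hardware yields inf or NaN.

```python
import math


def ieee_div(a: float, b: float) -> float:
    """Divide with IEEE 754 results for a zero denominator (hypothetical helper)."""
    if b == 0.0:
        if a == 0.0 or math.isnan(a):
            return math.nan  # 0 / 0 is NaN
        # nonzero / 0 is an infinity whose sign follows both operands
        return math.copysign(math.inf, a) * math.copysign(1.0, b)
    return a / b


def grad_flattened(x: float) -> float:
    """dy/dx from the adjoint of the flattened IR: y = select(x > 1, 1.0 / x, 0)."""
    cond = x > 1                      # $1
    adj4 = 1.0                        # $4adj += 1
    adj2 = adj4 if cond else 0.0      # only the taken select operand gets $4adj
    return ieee_div(-adj2, x * x)     # $0adj += -$2adj / ($0)^2, always executed


def grad_branching(x: float) -> float:
    """dy/dx when the division stays inside the if."""
    if x > 1:
        return -1.0 / (x * x)
    return 0.0


for x in (2.0, 0.5, 0.0):
    print(x, grad_flattened(x), grad_branching(x))
```

The two versions agree for every x except x == 0, where only the flattened one evaluates 0 / 0.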

Possible Solutions

  1. Simply don't move statements that can cause NaN outside if branches. Safe, but may harm performance.
  2. Modify the autodiff pass to wrap $0adj += -$2adj / ($0)^2 in if ($0 != 0). That is, wrap the "make adjoint" of a BinaryOpStmt with BinaryOpType==div in if (stmt->rhs != 0), where the condition is itself another BinaryOpStmt. This adds one more if per division that needs autodiff, which may also harm performance.
  3. Add something called None (I suggest DataType::u0?), and initialize $0adj through $4adj to None instead of 0. None * a = None / a = None for all a (even if a is NaN), and a + None = a - None = a for all a. This may be a systematic solution, but it may take a lot of work and may cause new issues (e.g., how do we deal with possible Nones in all backends?).
  4. Ask users never to let anything that appears as a denominator anywhere in the kernel be zero at any time (e.g., replace 0 with 1e-30 in each assignment to a variable that appears as a denominator somewhere). Similarly, anything that appears in sqrt should be non-negative...
  5. Add a compilation flag for moving statements that can cause NaN outside if branches (and loops in the future).
  6. More solutions are welcome!
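
To make solutions 2 and 3 concrete, here is a rough sketch in plain Python of how each would change the adjoint of the example. It is not Taichi code. All function names are invented for this illustration, Python's None stands in for the proposed None value, and the sketch leaves open how a final None gradient would be read back (presumably as zero).

```python
def grad_guarded(x: float) -> float:
    """Solution 2: guard the division adjoint with if ($0 != 0)."""
    cond = x > 1
    adj2 = 1.0 if cond else 0.0
    adj0 = 0.0
    if x != 0:                        # the extra if per differentiated division
        adj0 += -adj2 / (x * x)
    return adj0


def adj_add(a, b):
    """a + None = a (and None + b = b)."""
    if a is None:
        return b
    if b is None:
        return a
    return a + b


def adj_neg(a):
    """Negating None keeps it None."""
    return None if a is None else -a


def adj_div(a, b):
    """None / b = None for every b, so b is never inspected in that case."""
    return None if a is None else a / b


def grad_with_none(x: float):
    """Solution 3: adjoints start as None instead of 0."""
    cond = x > 1
    adj0 = adj2 = adj3 = adj4 = None
    adj4 = adj_add(adj4, 1.0)         # $4adj += 1
    if cond:
        adj2 = adj_add(adj2, adj4)
    else:
        adj3 = adj_add(adj3, adj4)
    # $0adj += -$2adj / ($0)^2; stays None when $2adj never received a value
    adj0 = adj_add(adj0, adj_div(adj_neg(adj2), x * x))
    return adj0


print(grad_guarded(0.0), grad_with_none(0.0), grad_with_none(2.0))
```

In both variants the 0 / 0 never happens: solution 2 skips the statement, while solution 3 short-circuits it because the numerator is None.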

This issue blocks new optimizations such as CSE for global pointers, even though they seem unrelated... (We can do them if we choose solution 4, but I think it's (probably?) not a good solution.)

Labels: advanced optimization, discussion, potential bug

All 2 comments

I think x == 0 in the example may be a corner case, and we can choose solution 5, for now, to make our unit tests happy. @yuanming-hu WDYT? And how to name that flag?

Thanks for the nice summary and all the potential solutions!

  1. Simply don't move statements that can cause NaN outside if branches. Safe, but may harm performance.

After considering all of them, I actually think this would be the best solution. Note that not moving statements outside ifs doesn't necessarily mean lower performance:

  • On CPUs as long as the branches are predictable, we don't have to pay the branch-misprediction cost.
  • On GPUs, branches themselves are even cheaper than those on CPUs. What really harms performance would be warp divergence (again, if the threads in the warp are not uniform and unpredictable).
  • In some cases, moving a costly statement that's rarely executed outside the if may even harm performance: e.g., a sin in an if on CPUs.
  • If a user really cares about performance, they can write the select themselves.
  • Once we have compile-time inference of ADStack sizes, then the extra stacks in AutoDiff introduced by the branching statements will run much faster.
  • It's true that flattening branches can help the optimizer work better. But in this case, we can add a per-kernel or even per-if compilation flag to tell the compiler if it should move statements out of the if bodies. (And set that to false by default.)
  • Note that the CFG optimizations you have implemented make the optimizers do a better job when there are branches compared to the old basic-block-level optimization, so the gain of flattening the ifs for IR optimization is smaller now.
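
On the point about users writing select themselves: one way to do that without running into the NaN gradient is to also select a safe denominator, so the unconditional division never sees zero. This technique is not proposed in this thread. The sketch below is plain Python rather than Taichi code, with hand-derived adjoints and made-up names.

```python
def select(cond, a, b):
    """Stand-in for a select statement: both operands are already evaluated."""
    return a if cond else b


def forward(x: float) -> float:
    cond = x > 1
    x_safe = select(cond, x, 1.0)     # 1.0 is an arbitrary nonzero placeholder
    return select(cond, 1.0 / x_safe, 0.0)


def grad(x: float) -> float:
    """Hand-written reverse pass of forward()."""
    cond = x > 1
    x_safe = select(cond, x, 1.0)
    adj_y = 1.0
    adj_recip = select(cond, adj_y, 0.0)          # adjoint of 1.0 / x_safe
    adj_x_safe = -adj_recip / (x_safe * x_safe)   # denominator is never zero
    return select(cond, adj_x_safe, 0.0)          # adjoint flows to x only if cond


for x in (2.0, 0.5, 0.0):
    print(x, forward(x), grad(x))
```

The cost is one extra select per guarded operand, and the user has to pick the placeholder value for each operation.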

For now, maybe we can add a field CompileConfig::flatten_if and set that to false by default?
