Background
Currently, we move some statements that don't seem to have side effects out of if branches as an optimization. After this, some ifs can be flattened into a select statement (i.e., the conditional, or ternary, operator in many programming languages), which can be much faster in some cases. This optimization has existed for over a year, and it's not considered an "advanced optimization".
However, when fast_math=True, moving statements that can produce NaN is UB (undefined behavior). Such statements include binary and unary operations like a / b, sqrt(a), log(a), asin(a)...
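A minimal Python sketch of what flattening does to the example below. `select` is a hypothetical helper that mirrors the IR's select: both operands are computed before the choice is made, so the hoisted division runs even when the condition is false. Plain Python raises ZeroDivisionError on `1.0 / 0.0`, so the sketch also assumes a hypothetical `ieee_div` helper that returns inf/NaN the way compiled kernels do.

```python
import math


def ieee_div(a: float, b: float) -> float:
    """Hypothetical helper: float division with IEEE 754 results.

    Plain Python raises ZeroDivisionError on x / 0.0; compiled kernels
    return inf or NaN instead, which is what this mimics.
    """
    if b != 0.0:
        return a / b
    if a == 0.0 or math.isnan(a):
        return math.nan  # 0/0 and NaN/0 are NaN
    # nonzero / +-0 is +-inf; the sign is sign(a) * sign(b)
    return math.copysign(math.inf, a) * math.copysign(1.0, b)


def select(cond: bool, if_true: float, if_false: float) -> float:
    # Like the IR's select: both operands were computed before this call.
    return if_true if cond else if_false


def branchy(x: float) -> float:
    # Original program: the division only runs when x > 1.
    y = 0.0
    if x > 1:
        y = 1.0 / x
    return y


def flattened(x: float) -> float:
    # After hoisting: the division always runs, then select picks a result.
    q = ieee_div(1.0, x)
    return select(x > 1, q, 0.0)
```

The forward results agree on every input (at x == 0 the inf is computed and then discarded by `select`), which is why the hoisting looks harmless until autodiff differentiates the discarded branch.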
Bug Example
This usually doesn't cause problems, but I found a case that can be problematic (even when fast_math=False):
if x > 1:
    y = 1.0 / x
dy/dx should be -1 / x^2 when x > 1, and 0 otherwise.
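For reference, the intended derivative can be written as a plain piecewise function and sanity-checked against a central finite difference. This is a Python sketch; `y`, `dy_dx`, and `central_difference` are illustrative names, not Taichi APIs.

```python
def y(x: float) -> float:
    # The example program: y stays 0 unless x > 1.
    return 1.0 / x if x > 1 else 0.0


def dy_dx(x: float) -> float:
    # Intended derivative: -1/x^2 for x > 1, 0 otherwise.
    # (At x == 1 itself y jumps from 0 to 1, so the true derivative is
    # undefined there; like the branchy program, this returns 0.)
    return -1.0 / (x * x) if x > 1 else 0.0


def central_difference(f, x: float, h: float = 1e-6) -> float:
    # Numerical derivative; only meaningful away from the jump at x == 1.
    return (f(x + h) - f(x - h)) / (2.0 * h)
```

Note that at x == 0 both the analytic and the numerical derivative are exactly 0: nothing in the original program divides by x there.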
However, if we move the division outside the if clause, we'll be doing autodiff on an IR like this:
$0 = load x
$1 = $0 > 1
$2 = 1.0 / $0
$3 = 0
$4 = select($1, $2, $3)
The reverse-mode pass then produces the following ($0adj through $4adj are initialized to 0):
$4adj += 1
if ($1) {
    $2adj += $4adj
} else {
    $3adj += $4adj
}
$0adj += -$2adj / ($0)^2
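So dy/dx becomes -1 / x^2 when x > 1, and -0 / x^2 otherwise, which is 0 only as long as x != 0. When x == 0, the adjoint evaluates 0 / 0 and we get dy/dx = NaN, even though the original program never divides by x in that case.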
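The failure can be reproduced by replaying the adjoint code above statement by statement in Python. This is a sketch: `grad_flattened` is a hypothetical name, `v*`/`a*` stand for `$*`/`$*adj`, and `ieee_div` again assumes IEEE 754 division semantics instead of Python's ZeroDivisionError.

```python
import math


def ieee_div(a: float, b: float) -> float:
    # Float division with IEEE 754 results (plain Python would raise here).
    if b != 0.0:
        return a / b
    if a == 0.0 or math.isnan(a):
        return math.nan
    return math.copysign(math.inf, a) * math.copysign(1.0, b)


def grad_flattened(x: float) -> float:
    """Replays the adjoint code for the flattened IR, line by line."""
    # Forward pass ($0 .. $4); only $0 and $1 are needed by the adjoints.
    v0 = x
    v1 = v0 > 1
    # Adjoints ($0adj .. $4adj) are initialized to 0, then seeded.
    a0 = a2 = a3 = 0.0
    a4 = 1.0  # $4adj += 1
    # Adjoint of select($1, $2, $3): route $4adj to the chosen operand.
    if v1:
        a2 += a4
    else:
        a3 += a4
    # Adjoint of $2 = 1.0 / $0, emitted unconditionally because the
    # division itself was hoisted out of the branch.
    a0 += ieee_div(-a2, v0 * v0)
    return a0
```

At x == 2 this gives -0.25 and at x == 0.5 it gives 0, but at x == 0 the last line computes -0.0 / 0.0, so the returned gradient is NaN.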
Possible Solutions
1. Simply don't move statements that can cause NaN outside if branches. Safe, but may hurt performance.
2. Wrap $0adj += -$2adj / ($0)^2 in if ($0 != 0) (i.e., wrap the "make adjoint" of BinaryOpStmt with BinaryOpType==div in if (stmt->rhs != 0), where the condition is itself another BinaryOpStmt). Then we'll have one more if per division that needs autodiff, which may also hurt performance.
3. Add a new data type representing None (I suggest DataType::u0?), and initialize $0~4adj to None instead of 0, with None * a = None / a = None for all a (even if a is NaN), and a + None = a - None = a for all a. This may be a systematic solution, but it may take a lot of work and introduce new issues (e.g., how do we handle possible Nones in all backends?).
4. Make sure NaN is never produced (e.g., replace 0 with 1e-30 for each assignment to a variable that appears as a denominator somewhere). Similarly, anything that appears in sqrt should be non-negative...
5. Add a flag that controls whether statements that can cause NaN are moved outside if branches (and loops in the future).

This issue also blocks new, seemingly unrelated optimizations such as CSE for global pointers. (We could do them if we chose solution 4, but I think it's (probably?) not a good solution.)
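To make solutions 2 and 3 concrete, here is a Python sketch that replays the example's adjoint code under each scheme. The names `grad_guarded`, `_NoneAdjoint`, `NONE`, and `grad_none` are hypothetical; `_NoneAdjoint` only imitates the proposed `DataType::u0` at the Python level and says nothing about how a backend would represent it.

```python
def grad_guarded(x: float) -> float:
    """Solution 2: only run the division's adjoint when the divisor is nonzero."""
    v0 = x
    v1 = v0 > 1
    a0 = a2 = a3 = 0.0
    a4 = 1.0
    if v1:
        a2 += a4
    else:
        a3 += a4
    # The extra if per division; it also keeps Python from raising on 0.0.
    if v0 != 0.0:
        a0 += -a2 / (v0 * v0)
    return a0


class _NoneAdjoint:
    """Solution 3: an absorbing 'None' adjoint (hypothetical stand-in for DataType::u0).

    None * a = a * None = None / a = None, and a + None = a - None = a,
    so a branch that was never taken contributes nothing, even if a is NaN
    or the divisor is 0.
    """

    def __mul__(self, other):
        return self

    __rmul__ = __mul__

    def __truediv__(self, other):
        return self

    def __neg__(self):
        return self

    def __add__(self, other):
        return other

    __radd__ = __add__

    def __rsub__(self, other):
        return other  # a - None = a


NONE = _NoneAdjoint()


def grad_none(x: float) -> float:
    v0 = x
    v1 = v0 > 1
    a0 = a2 = a3 = NONE  # adjoints start as None instead of 0
    a4 = NONE + 1.0      # $4adj += 1
    if v1:
        a2 += a4
    else:
        a3 += a4
    # No guard needed: None / 0.0 is None, so 0 / 0 is never evaluated.
    a0 += -a2 / (v0 * v0)
    # Read back: an adjoint that stayed None means a zero gradient.
    return 0.0 if a0 is NONE else a0
```

Both versions return -0.25 at x == 2 and 0 at x == 0. The guarded version pays one extra branch per division; the None version needs no branch at that site, but every operation that touches adjoints must understand the extra value, which is the backend concern raised in solution 3.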
I think x == 0 in the example may be a corner case, and we can choose solution 5, for now, to make our unit tests happy. @yuanming-hu WDYT? And how to name that flag?
Thanks for the nice summary and all the potential solutions!
- Simply don't move statements that can cause NaN outside if branches. Safe, but may hurt performance.
After considering all of them, I actually think this would be the best solution. Note that not moving statements outside ifs doesn't necessarily mean lower performance:
- … if may even harm performance: e.g., a sin inside an if on CPUs.
- … select on its own.
- … ADStack sizes, then the extra stacks in AutoDiff introduced by the branching statements will run much faster.
- … if bodies. (And set that to false by default.)
- … ifs for IR optimization is smaller now.

For now, maybe we can add a field CompileConfig::flatten_if and set that to false by default?
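A Python sketch of how such a flag could gate hoisting, assuming it only guards the NaN-prone statements as in solution 5 (it could equally gate all hoisting out of if bodies). `CompileConfig` here is a hypothetical Python mirror of the proposed C++ `CompileConfig::flatten_if` field, and `may_hoist_out_of_if` and `NAN_PRONE_OPS` are illustrative names, not existing Taichi code.

```python
from dataclasses import dataclass

# Operations from the discussion that can produce NaN (or inf) on part of
# their input domain; this list is illustrative, not exhaustive.
NAN_PRONE_OPS = frozenset({"div", "sqrt", "log", "asin"})


@dataclass
class CompileConfig:
    # Hypothetical Python mirror of the proposed CompileConfig::flatten_if.
    flatten_if: bool = False


def may_hoist_out_of_if(op: str, has_side_effects: bool, config: CompileConfig) -> bool:
    """Decide whether a statement may be moved out of an if body."""
    if has_side_effects:
        return False  # never hoist stores, prints, etc.
    if op in NAN_PRONE_OPS:
        # Hoisting these can change autodiff results (and is UB under
        # fast_math), so only do it when the user explicitly opts in.
        return config.flatten_if
    return True
```

With the default `flatten_if=False`, `1.0 / x` stays inside the branch and the example differentiates to 0 at x == 0; opting in restores the select-based code path for users who know their inputs stay in range.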