Concisely describe the proposed feature
An algebraic simplification pass that reduces expressions such as a + 0, 0 + a, a - 0, a * 1, and a / 1 to a.
Describe the solution you'd like
Add an IR pass at https://github.com/taichi-dev/taichi/blob/1a42ea4a181994201fd7278f7ba0e76a0245ee23/taichi/transforms/alg_simp.cpp
Additional comments
An example test case is https://github.com/taichi-dev/taichi/blob/1a42ea4a181994201fd7278f7ba0e76a0245ee23/tests/cpp/alg_simp.cpp
@xumingkuan after you build Taichi from source, execute
ti test_cpp
to run the C++ tests, including simplify_add_zero. Have fun!
Possible solution:
for (auto &s: block) {
  if (s->is<BinaryOpStmt>()) {
    if (s->op_type == BinaryOpType::add || s->op_type == BinaryOpType::sub) {
      if (alg_is_zero(&s->rhs)) {
        s = s->lhs;
      } else if (s->op_type == BinaryOpType::add && alg_is_zero(&s->lhs)) {
        s = s->rhs;
      }
    } else if (s->op_type == BinaryOpType::mul || s->op_type == BinaryOpType::div) {
      if (alg_is_one(&s->rhs)) {
        s = s->lhs;
      } else if (s->op_type == BinaryOpType::mul && alg_is_one(&s->lhs)) {
        s = s->rhs;
      }
    }
  }
}
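To make the rewrite rules above concrete outside of Taichi, here is a minimal self-contained sketch on a toy expression tree. Everything in it (Expr, Op, var, constant, binary, is_const, simplify) is a hypothetical stand-in and not Taichi's IR or API. Note that subtraction and division only simplify when the constant is on the right-hand side, as in the pseudocode above.

```cpp
#include <memory>
#include <string>
#include <utility>

// Toy expression tree. These types are hypothetical, NOT Taichi's IR.
enum class Op { Var, Const, Add, Sub, Mul, Div };

struct Expr;
using ExprPtr = std::shared_ptr<Expr>;

struct Expr {
  Op op;
  std::string name;  // used when op == Op::Var
  double value;      // used when op == Op::Const
  ExprPtr lhs, rhs;  // used for binary operators
};

ExprPtr var(const std::string &n) {
  return std::make_shared<Expr>(Expr{Op::Var, n, 0.0, nullptr, nullptr});
}
ExprPtr constant(double v) {
  return std::make_shared<Expr>(Expr{Op::Const, "", v, nullptr, nullptr});
}
ExprPtr binary(Op op, ExprPtr l, ExprPtr r) {
  return std::make_shared<Expr>(Expr{op, "", 0.0, std::move(l), std::move(r)});
}

bool is_const(const ExprPtr &e, double v) {
  return e->op == Op::Const && e->value == v;
}

// Bottom-up simplification: operands are simplified first, then the
// identity rules are applied to the current node.
ExprPtr simplify(const ExprPtr &e) {
  if (e->op == Op::Var || e->op == Op::Const)
    return e;
  ExprPtr l = simplify(e->lhs);
  ExprPtr r = simplify(e->rhs);
  switch (e->op) {
    case Op::Add:  // a + 0 and 0 + a
      if (is_const(r, 0.0)) return l;
      if (is_const(l, 0.0)) return r;
      break;
    case Op::Sub:  // a - 0 only; 0 - a is a negation, not a
      if (is_const(r, 0.0)) return l;
      break;
    case Op::Mul:  // a * 1 and 1 * a
      if (is_const(r, 1.0)) return l;
      if (is_const(l, 1.0)) return r;
      break;
    case Op::Div:  // a / 1 only; 1 / a is a reciprocal, not a
      if (is_const(r, 1.0)) return l;
      break;
    default:
      break;
  }
  return binary(e->op, l, r);
}
```

For example, simplify(binary(Op::Add, var("a"), constant(0))) returns the original node for a, while binary(Op::Sub, constant(0), var("a")) is left as a subtraction.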
@archibate Thanks for the suggested solution! Our intern at MIT @xumingkuan is working on this :-)
I'm now implementing it. I wonder if it is necessary to iterate like this (in constant fold):
static void run(IRNode *node) {
  ConstantFold folder;
  while (true) {
    bool modified = false;
    try {
      node->accept(&folder);
    } catch (IRModified) {
      modified = true;
    }
    if (!modified)
      break;
  }
}
And it seems that constant fold only works for i32. Shall we also implement constant fold for other types?
I just wrote a simple void visit(BinaryOpStmt *stmt) override and std::cout << stmt->parent; outputs 0000000000000000. Why is this happening? It causes get_ir_root(); to crash.
UPD: Even just adding irpass::constant_fold(block.get()); in test_alg_simp.cpp (before irpass::alg_simp(block.get());, of course) and std::cout << "QAQ" << stmt->parent << std::endl; at https://github.com/taichi-dev/taichi/blob/master/taichi/transforms/constant_fold.cpp#L41 results in QAQ0000000000000000.
Awesome!
I wonder if it is necessary to iterate like this (in constant fold)
Yes, please. This will allow us to simplify a * 1 * 1 + 0 into a.
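A small sketch of why the loop matters, assuming a pass that performs one rewrite and then throws to restart, as in the constant fold code quoted above. The flat IR here (Kind, Stmt, Program, replace_with, simplify_once, simplify) is hypothetical and only mimics the restart pattern; it is not Taichi's IR.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical flat IR: binary statements refer to earlier statements by index.
enum class Kind { Input, Const, Add, Mul, Nop };

struct Stmt {
  Kind kind;
  double value;  // meaningful for Kind::Const only
  int lhs, rhs;  // operand indices, meaningful for Add and Mul only
};

struct Program {
  std::vector<Stmt> stmts;
  int result;  // index of the statement whose value the program returns
};

struct IRModified {};  // thrown after a rewrite to request another pass

bool is_const(const Program &p, int idx, double v) {
  return p.stmts[idx].kind == Kind::Const && p.stmts[idx].value == v;
}

// Redirect every use of statement `from` to `to`, then retire `from`.
void replace_with(Program &p, int from, int to) {
  for (Stmt &s : p.stmts) {
    if (s.kind == Kind::Add || s.kind == Kind::Mul) {
      if (s.lhs == from) s.lhs = to;
      if (s.rhs == from) s.rhs = to;
    }
  }
  if (p.result == from) p.result = to;
  p.stmts[from].kind = Kind::Nop;
}

// One pass: applies at most one rewrite, then throws IRModified.
void simplify_once(Program &p) {
  for (std::size_t i = 0; i < p.stmts.size(); ++i) {
    const Stmt s = p.stmts[i];  // copy, since replace_with mutates the list
    double identity;
    if (s.kind == Kind::Add) identity = 0.0;
    else if (s.kind == Kind::Mul) identity = 1.0;
    else continue;
    if (is_const(p, s.rhs, identity)) {
      replace_with(p, static_cast<int>(i), s.lhs);
      throw IRModified();
    }
    if (is_const(p, s.lhs, identity)) {
      replace_with(p, static_cast<int>(i), s.rhs);
      throw IRModified();
    }
  }
}

// Re-run the pass until it completes without modifying anything.
// Returns the number of rewrites performed.
int simplify(Program &p) {
  int rewrites = 0;
  while (true) {
    try {
      simplify_once(p);
      return rewrites;
    } catch (const IRModified &) {
      ++rewrites;
    }
  }
}
```

Encoding a * 1 * 1 + 0 as six statements (a, 1, a * 1, that product * 1, 0, and the final sum), the driver needs three rewrites before a clean pass confirms that the result is the input a itself. A single run of simplify_once would have stopped after removing only the innermost multiplication.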
And it seems that constant fold only works for i32. Shall we also implement constant fold for other types?
Yeah, the constant folder is rather limited now. I think we need a way to systematically upgrade the constant folding pass. Let's talk about that after your alg_simp pass is done.
I just wrote a simple void visit(BinaryOpStmt *stmt) override and std::cout << stmt->parent; outputs 0000000000000000. Why is this happening? It causes get_ir_root(); to crash.
Oh sorry. Current Stmt->push_back does not set the parent of the new statement. Please add something like stmt->parent = this in that function, just like Block->insert. My bad.
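A minimal sketch of the kind of fix described here, using toy Block and Stmt types that are hypothetical and much simpler than Taichi's real classes. The only point is that the container records itself as the parent before taking ownership of the statement.

```cpp
#include <memory>
#include <utility>
#include <vector>

struct Block;

// Toy statement: only the back-pointer to its enclosing block matters here.
struct Stmt {
  Block *parent = nullptr;
};

// Toy block that owns its statements.
struct Block {
  std::vector<std::unique_ptr<Stmt>> statements;

  // Appends a statement and sets its parent, so that code walking
  // upwards from the statement never sees a null parent.
  Stmt *push_back(std::unique_ptr<Stmt> stmt) {
    stmt->parent = this;
    statements.push_back(std::move(stmt));
    return statements.back().get();
  }
};
```

Without the stmt->parent = this line, printing the parent of a freshly appended statement would show a null pointer, which matches the 0000000000000000 output reported above.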
What about a * 0? We can optimize it to 0 when a is neither NaN nor infinite. (Then we may need to make this pass iterative.)
You can add a boolean flag fastmath to the pass. When fastmath is true you can just optimize it out.
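To see why a * 0 needs a flag at all, the following standalone check shows how the product behaves under IEEE 754 arithmetic for a few inputs. It assumes the code is compiled without fast-math style compiler options; the names kNaN, kInf, and mul_zero_equals_zero are made up for this illustration.

```cpp
#include <limits>

const double kNaN = std::numeric_limits<double>::quiet_NaN();
const double kInf = std::numeric_limits<double>::infinity();

// Returns true when a * 0 really evaluates to zero for this value of a.
// The comparison also accepts -0.0, which is what negative finite inputs give.
bool mul_zero_equals_zero(double a) {
  const double product = a * 0.0;
  return product == 0.0;  // false when the product is NaN
}
```

Finite inputs such as 3.0 or -3.0 give zero, but kNaN and kInf both produce NaN, so replacing a * 0 with 0 changes the result for those inputs. That is the behavior a fast-math flag would explicitly allow the pass to ignore.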
Will fastmath be true by default (that is, should I just write fastmath = true in the code)? Or where can I get the option value of fastmath?
Program::config::fast_math
current_program is nullptr?
Yeah, because for the testing env you don't really have a program, just a bunch of instructions. When actually doing the lowering you will have prog->config.fast_math:
https://github.com/taichi-dev/taichi/blob/master/taichi/backends/codegen_x86.cpp#L611
For testing you can just pass in true/false as a parameter to your pass, depending on whether you are testing with fast_math or not.
So may I change void full_simplify(IRNode *root); to void full_simplify(IRNode *root, bool fast_math);? Or are there better solutions?
Good question! Please pass in a CompileConfig instance.
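A rough sketch of what passing a config object instead of a bare bool could look like. The CompileConfig below is a one-field stand-in (the real Taichi struct is assumed to carry many more options), and Result and simplify_mul_by_zero are hypothetical names, not the actual pass signature.

```cpp
// Hypothetical stand-in for the real config; only fast_math is modeled.
struct CompileConfig {
  bool fast_math;
};

enum class Result { Unchanged, FoldedToZero };

// Toy decision for `a * c`, where c is a known constant right-hand side:
// fold to 0 only when c is zero and the config allows ignoring NaN/inf.
Result simplify_mul_by_zero(double rhs_constant, const CompileConfig &config) {
  if (rhs_constant == 0.0 && config.fast_math)
    return Result::FoldedToZero;
  return Result::Unchanged;
}
```

A test can then construct CompileConfig{true} or CompileConfig{false} directly, without needing a running program, and later options can be added to the struct without changing the pass signature again.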