Taichi: Basic Algebraic Simplification Pass

Created on 11 Feb 2020 · 15 Comments · Source: taichi-dev/taichi

Concisely describe the proposed feature
A simplification pass that reduces expressions such as a + 0, 0 + a, a - 0, a * 1, and a / 1 to a.

Describe the solution you'd like
Add an IR pass at https://github.com/taichi-dev/taichi/blob/1a42ea4a181994201fd7278f7ba0e76a0245ee23/taichi/transforms/alg_simp.cpp

Additional comments
An example test case is https://github.com/taichi-dev/taichi/blob/1a42ea4a181994201fd7278f7ba0e76a0245ee23/tests/cpp/alg_simp.cpp
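To make the requested rewrites concrete, here is a minimal sketch on a toy expression tree. The Expr struct and the var/lit/bin/alg_simp helpers are hypothetical and invented for illustration; they are not Taichi's IR or API. The rules are applied bottom-up, so nested cases collapse in a single call.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical toy expression tree used only for illustration; Taichi's real
// IR is a list of statements (BinaryOpStmt, ConstStmt, ...), not this struct.
struct Expr {
  char op;           // '+', '-', '*', '/' for binary nodes; 0 for leaves
  std::string name;  // variable name of a leaf; empty for a constant
  double value;      // constant value when name is empty
  std::shared_ptr<Expr> lhs, rhs;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr var(const std::string &n) {
  return std::make_shared<Expr>(Expr{0, n, 0.0, nullptr, nullptr});
}
ExprPtr lit(double v) {
  return std::make_shared<Expr>(Expr{0, "", v, nullptr, nullptr});
}
ExprPtr bin(char op, ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{op, "", 0.0, a, b});
}

// True if e is a constant leaf equal to v.
bool is_const(const ExprPtr &e, double v) {
  return e->op == 0 && e->name.empty() && e->value == v;
}

// Simplifies the operands first, then applies the identity rules.
ExprPtr alg_simp(const ExprPtr &e) {
  if (e->op == 0)
    return e;  // leaves are already simple
  ExprPtr a = alg_simp(e->lhs);
  ExprPtr b = alg_simp(e->rhs);
  switch (e->op) {
    case '+':
      if (is_const(b, 0)) return a;  // a + 0 -> a
      if (is_const(a, 0)) return b;  // 0 + a -> a
      break;
    case '-':
      if (is_const(b, 0)) return a;  // a - 0 -> a
      break;
    case '*':
      if (is_const(b, 1)) return a;  // a * 1 -> a
      if (is_const(a, 1)) return b;  // 1 * a -> a
      break;
    case '/':
      if (is_const(b, 1)) return a;  // a / 1 -> a
      break;
  }
  return bin(e->op, a, b);  // no rule applies; rebuild with simplified operands
}
```

With this sketch, simplifying (a * 1) + 0 yields the leaf a, while 0 - a is left unchanged because subtraction is not commutative.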

feature request

All 15 comments

@xumingkuan after you build Taichi from source, execute

ti test_cpp

to run the C++ tests, including simplify_add_zero. Have fun!

Possible solution:

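for (auto &s : block->statements) {
  if (!s->is<BinaryOpStmt>())
    continue;
  auto bin = s->as<BinaryOpStmt>();
  // alg_is_zero / alg_is_one: assumed helpers that test whether an operand
  // is a ConstStmt holding 0 / 1.
  Stmt *replacement = nullptr;
  if (bin->op_type == BinaryOpType::add || bin->op_type == BinaryOpType::sub) {
    if (alg_is_zero(bin->rhs))
      replacement = bin->lhs;  // a + 0, a - 0 -> a
    else if (bin->op_type == BinaryOpType::add && alg_is_zero(bin->lhs))
      replacement = bin->rhs;  // 0 + a -> a
  } else if (bin->op_type == BinaryOpType::mul || bin->op_type == BinaryOpType::div) {
    if (alg_is_one(bin->rhs))
      replacement = bin->lhs;  // a * 1, a / 1 -> a
    else if (bin->op_type == BinaryOpType::mul && alg_is_one(bin->lhs))
      replacement = bin->rhs;  // 1 * a -> a
  }
  // Assigning to s would not update the statements that use bin; redirect
  // every use instead and leave the dead statement for later cleanup.
  if (replacement)
    bin->replace_with(replacement);
}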
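In a statement-based IR, "replacing" a simplified statement means redirecting every operand that refers to it; overwriting the loop variable changes nothing for its users. A minimal sketch of that idea on a hypothetical toy SSA list (the Op/Stmt types and the simplified/redirect_uses/alg_simp functions are illustrative assumptions, not Taichi code), where operands are indices of earlier statements:

```cpp
#include <cassert>
#include <vector>

// Hypothetical toy SSA IR: a leaf (input or constant) or a binary op whose
// operands are indices of earlier statements.
enum class Op { Input, Const, Add, Sub, Mul, Div };

struct Stmt {
  Op op;
  double value;  // used when op == Op::Const
  int lhs, rhs;  // operand indices (unused for leaves)
};

bool is_const(const std::vector<Stmt> &b, int i, double v) {
  return b[i].op == Op::Const && b[i].value == v;
}

// Returns the index statement i can be replaced with, or i itself.
int simplified(const std::vector<Stmt> &b, int i) {
  const Stmt &s = b[i];
  switch (s.op) {
    case Op::Add:
      if (is_const(b, s.rhs, 0)) return s.lhs;  // a + 0
      if (is_const(b, s.lhs, 0)) return s.rhs;  // 0 + a
      break;
    case Op::Sub:
      if (is_const(b, s.rhs, 0)) return s.lhs;  // a - 0
      break;
    case Op::Mul:
      if (is_const(b, s.rhs, 1)) return s.lhs;  // a * 1
      if (is_const(b, s.lhs, 1)) return s.rhs;  // 1 * a
      break;
    case Op::Div:
      if (is_const(b, s.rhs, 1)) return s.lhs;  // a / 1
      break;
    default:
      break;
  }
  return i;
}

// Makes every operand that refers to `from` refer to `to` instead.
void redirect_uses(std::vector<Stmt> &b, int from, int to) {
  for (Stmt &s : b) {
    if (s.op == Op::Input || s.op == Op::Const) continue;
    if (s.lhs == from) s.lhs = to;
    if (s.rhs == from) s.rhs = to;
  }
}

// One forward sweep; simplified statements become dead and are left for a
// later dead-code-elimination pass.
void alg_simp(std::vector<Stmt> &b) {
  for (int i = 0; i < static_cast<int>(b.size()); i++) {
    int r = simplified(b, i);
    if (r != i) redirect_uses(b, i, r);
  }
}
```

Because operands are redirected immediately, a later statement such as (a * 1) + 0 already sees a * 1 replaced by a when the sweep reaches it.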

@archibate Thanks for the suggested solution! Our intern at MIT @xumingkuan is working on this :-)

I'm now implementing it. I wonder if it is necessary to iterate like this (in constant fold):

static void run(IRNode *node) {
  ConstantFold folder;
  while (true) {
    bool modified = false;
    try {
      node->accept(&folder);
    } catch (IRModified) {
      modified = true;
    }
    if (!modified)
      break;
  }
}

And it seems that constant fold only works for i32. Shall we also implement constant fold for other types?

I just wrote a simple void visit(BinaryOpStmt *stmt) override and std::cout << stmt->parent; outputs 0000000000000000. Why is this happening? It causes get_ir_root(); to crash.

UPD: Even just adding irpass::constant_fold(block.get()); in test_alg_simp.cpp (before irpass::alg_simp(block.get());, of course) and std::cout << "QAQ" << stmt->parent << std::endl; at https://github.com/taichi-dev/taichi/blob/master/taichi/transforms/constant_fold.cpp#L41 results in QAQ0000000000000000.

Awesome!

> I wonder if it is necessary to iterate like this (in constant fold)

Yes, please. This will allow us to simplify a * 1 * 1 + 0 into a.
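The loop in ConstantFold::run reruns the visitor because it stops at its first modification (by throwing IRModified). A minimal sketch of the same fixed-point pattern on a hypothetical toy expression tree (Expr, simplify_once, and run_to_fixed_point are illustrative names, not Taichi code), showing a * 1 * 1 + 0 reducing to a one rewrite per pass:

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical toy expression tree (not Taichi's IR).
struct Expr {
  char op;           // '+' or '*' for binary nodes; 0 for leaves
  std::string name;  // variable name; empty for a constant
  double value;      // constant value when name is empty
  std::shared_ptr<Expr> lhs, rhs;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr var(const std::string &n) {
  return std::make_shared<Expr>(Expr{0, n, 0.0, nullptr, nullptr});
}
ExprPtr lit(double v) {
  return std::make_shared<Expr>(Expr{0, "", v, nullptr, nullptr});
}
ExprPtr bin(char op, ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{op, "", 0.0, a, b});
}

bool is_const(const ExprPtr &e, double v) {
  return e->op == 0 && e->name.empty() && e->value == v;
}

// Performs at most one rewrite (a + 0 -> a, a * 1 -> a) and reports whether
// anything changed, mimicking a visitor that aborts after its first change.
bool simplify_once(ExprPtr &e) {
  if (e->op == 0)
    return false;
  if ((e->op == '+' && is_const(e->rhs, 0)) ||
      (e->op == '*' && is_const(e->rhs, 1))) {
    e = e->lhs;  // splice the left operand into the parent
    return true;
  }
  return simplify_once(e->lhs) || simplify_once(e->rhs);
}

// Fixed-point driver: rerun until a whole pass makes no change.
// Returns the number of passes that modified the tree.
int run_to_fixed_point(ExprPtr &root) {
  int modifications = 0;
  while (simplify_once(root))
    modifications++;
  return modifications;
}
```

Here ((a * 1) * 1) + 0 needs three modifying passes plus one final pass that confirms nothing is left to rewrite.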

> And it seems that constant fold only works for i32. Shall we also implement constant fold for other types?

Yeah, the constant folder is rather limited now. I have an idea for systematically upgrading the constant folding pass; let's talk about that after your alg_simp pass is done.

> I just wrote a simple void visit(BinaryOpStmt *stmt) override and std::cout << stmt->parent; outputs 0000000000000000. Why is this happening? It causes get_ir_root(); to crash.

Oh sorry. Currently Block::push_back does not set the parent of the new statement. Please add something like stmt->parent = this in that function, just as Block::insert does. My bad.
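A minimal sketch of why a missing parent pointer breaks a root lookup, using hypothetical, heavily simplified stand-ins for Block and Stmt (the real Taichi classes have many more members, and get_root here only imitates the spirit of get_ir_root()):

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

struct Block;

// Hypothetical stand-in for a statement: only the owner pointer matters here.
struct Stmt {
  Block *parent = nullptr;  // the block that owns this statement
};

// Hypothetical stand-in for a block of statements.
struct Block {
  Stmt *parent_stmt = nullptr;  // enclosing statement (e.g. a loop); null at the root
  std::vector<std::unique_ptr<Stmt>> statements;

  // Appends a statement and records ownership so parent walks succeed.
  Stmt *push_back(std::unique_ptr<Stmt> stmt) {
    stmt->parent = this;  // the line the fix adds
    statements.push_back(std::move(stmt));
    return statements.back().get();
  }
};

// Walks parent pointers up to the outermost block.
Block *get_root(Stmt *s) {
  Block *b = s->parent;
  assert(b != nullptr && "statement has no parent");
  while (b->parent_stmt != nullptr)
    b = b->parent_stmt->parent;
  return b;
}
```

Without the stmt->parent = this line, s->parent stays null and the walk fails on its first step, which matches the 0000000000000000 printed above.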

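What about a * 0? For integers it is always 0, but for floating-point a we can only optimize it when a is neither NaN nor infinity (both give NaN when multiplied by 0) and the sign of the resulting zero does not matter. (Then we may need to make this pass iterative.)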
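The IEEE 754 cases that make a * 0 -> 0 unsafe for floats can be checked directly. A minimal sketch (mul_zero_fold_is_exact is a hypothetical helper name; this assumes ordinary IEEE double arithmetic, i.e. the code is not itself compiled with fast-math flags):

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Returns true if rewriting x * 0.0 to +0.0 would give exactly the same
// result (same value and same sign of zero) under IEEE 754 arithmetic.
bool mul_zero_fold_is_exact(double x) {
  double product = x * 0.0;
  return product == 0.0 && !std::signbit(product);
}
```

Finite non-negative inputs fold exactly; NaN and infinity produce NaN, and negative inputs produce -0.0, which compares equal to 0.0 but behaves differently (for example, 1.0 / -0.0 is negative infinity).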

You can add a boolean flag fastmath to the pass. When fastmath is true you can just optimize it out.
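A minimal sketch of how such a flag could gate the rewrite, assuming a hypothetical rule function (Rewrite and simplify_mul are illustrative names, not part of Taichi): a * 1 is always folded, while a * 0 is folded for integers unconditionally and for floats only when fast_math is set.

```cpp
#include <cassert>

// Hypothetical result of examining a multiplication "a * c" with constant c.
enum class Rewrite { None, ToLhs, ToZero };

Rewrite simplify_mul(bool operand_is_float, double rhs_const, bool fast_math) {
  if (rhs_const == 1.0)
    return Rewrite::ToLhs;  // a * 1 -> a
  if (rhs_const == 0.0 && (!operand_is_float || fast_math))
    return Rewrite::ToZero;  // a * 0 -> 0: always for ints, floats only under fast math
  return Rewrite::None;
}
```

Keeping the decision in one function makes it easy to test both settings without constructing a whole program.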

Should fastmath default to true (i.e., just hard-code fastmath = true)? Or where can I get the value of the fastmath option?

Program::config::fast_math

current_program is nullptr?

Yeah, because in the test environment you don't really have a program, just a bunch of instructions. When actually doing the lowering, you will have prog->config.fast_math:
https://github.com/taichi-dev/taichi/blob/master/taichi/backends/codegen_x86.cpp#L611

For testing, you can just pass true/false to your pass as a parameter, depending on whether you are testing with fast_math or not.

So may I change void full_simplify(IRNode *root); to void full_simplify(IRNode *root, bool fast_math);? Or is there a better solution?

Good question! Please pass in a CompileConfig instance.
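Passing a config object rather than a bare bool lets the pass pick up further options later without changing every call site, and tests can build a config with whatever settings they need. A minimal sketch with hypothetical stand-ins (this CompileConfig has only one field, its default value is chosen arbitrarily for the sketch, and the IR is reduced to a toy list of multiplications by constants):

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in; the real CompileConfig has many more fields.
struct CompileConfig {
  bool fast_math = true;  // default chosen arbitrarily for this sketch
};

// Hypothetical toy IR statement: "a * rhs" where rhs is a constant.
struct MulByConst {
  bool is_float;
  double rhs;
  bool folded_to_zero;
};

// The pass reads its options from the config instead of a bool parameter.
void full_simplify(std::vector<MulByConst> &ir, const CompileConfig &config) {
  for (auto &s : ir)
    if (s.rhs == 0.0 && (!s.is_float || config.fast_math))
      s.folded_to_zero = true;  // fold a * 0 -> 0
}
```

A test can then construct one config with fast_math off and another with it on, and check that float multiplications by 0 are folded only in the second case.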
