Taichi: Upgrade the Constant Folding pass

Created on 21 Feb 2020 · 3 Comments · Source: taichi-dev/taichi

Concisely describe the proposed feature
The constant folder works only for DataType::i32 with BinaryOpType::mul and BinaryOpType::sub. We can systematically upgrade it.

Describe the solution you'd like
We may write the constant folder for each type and each operation, but it takes too much code to write them one by one. It would be better to have something like val(DataType dt) to avoid type checks like if (data_type == DataType::i32) return val.val_int32();.
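One way to get the effect of `val(DataType dt)` in C++ is to keep the runtime type check in a single dispatch helper, which hands a generic lambda a value of the matching C++ type. The sketch below is a minimal illustration under that assumption; `Constant`, `val<T>()`, `visit_type` and `fold_mul` are hypothetical names, and `Constant` is a simplified stand-in for `TypedConstant`, not Taichi's actual class.

```cpp
#include <cassert>
#include <cstdint>

// Simplified, hypothetical stand-ins for Taichi's DataType and TypedConstant.
enum class DataType { i32, i64, f32, f64 };

struct Constant {
  DataType dt;
  union {
    uint64_t value_bits;
    int32_t val_i32;
    int64_t val_i64;
    float val_f32;
    double val_f64;
  };

  explicit Constant(DataType dt) : dt(dt), value_bits(0) {}

  // Typed accessor standing in for val(DataType dt): the caller names the
  // C++ type, and each specialization returns the matching union member.
  template <typename T>
  T &val();
};

template <> int32_t &Constant::val<int32_t>() { return val_i32; }
template <> int64_t &Constant::val<int64_t>() { return val_i64; }
template <> float &Constant::val<float>() { return val_f32; }
template <> double &Constant::val<double>() { return val_f64; }

// The single place where a runtime DataType is mapped to a C++ type:
// f is called with a value-initialized object of the matching type.
template <typename F>
void visit_type(DataType dt, F &&f) {
  switch (dt) {
    case DataType::i32: f(int32_t{}); break;
    case DataType::i64: f(int64_t{}); break;
    case DataType::f32: f(float{}); break;
    case DataType::f64: f(double{}); break;
  }
}

// Folds lhs * rhs for every supported type with no per-type branches.
Constant fold_mul(Constant lhs, Constant rhs) {
  assert(lhs.dt == rhs.dt);
  Constant result(lhs.dt);
  visit_type(lhs.dt, [&](auto tag) {
    using T = decltype(tag);
    result.val<T>() = lhs.val<T>() * rhs.val<T>();
  });
  return result;
}
```

With this layout, a new operation is one generic lambda, and a new type only touches `visit_type` and one accessor specialization.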

Labels: feature request, ir

All 3 comments

This will be very helpful! It is worth considering that we have many operations to constant fold:

  • add, mul, div, sub
  • sin, cos, tan, tanh
  • sqrt, pow, ...
  • casting
  • ...

Implementing each operation for each data type by hand would be very time-consuming and error-prone.
A more systematic solution is to either:

  • JIT-compile the operation, specialized to the operand types, into a function and invoke that function to do the constant folding, or
  • use the LLVM interpreter to evaluate the result.
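Both routes above need compiler infrastructure at runtime. A lighter way to get the same "write each operation once" property is to let the C++ compiler instantiate the per-type code ahead of time. The sketch below swaps in `std::variant` and `std::visit` (a different technique from JIT or LLVM interpretation); `Value` and `fold_binary` are hypothetical names, and only same-type operands are folded.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <type_traits>
#include <variant>

// Hypothetical constant representation: one variant alternative per type.
using Value = std::variant<int32_t, int64_t, float, double>;

// Folds op(lhs, rhs) when both operands hold the same type. Mixed-type
// operands yield std::nullopt, so the caller leaves the expression as is.
// Integer division by zero and signed overflow are NOT checked here.
template <typename Op>
std::optional<Value> fold_binary(const Value &lhs, const Value &rhs, Op op) {
  return std::visit(
      [&](auto a, auto b) -> std::optional<Value> {
        using A = decltype(a);
        if constexpr (std::is_same_v<A, decltype(b)>) {
          // static_cast undoes integer promotion (e.g. int32_t * int32_t -> int).
          return Value(static_cast<A>(op(a, b)));
        } else {
          return std::nullopt;
        }
      },
      lhs, rhs);
}
```

Each arithmetic operation then becomes a one-liner such as `fold_binary(lhs, rhs, std::multiplies<>{})`, with `std::plus<>`, `std::minus<>` and `std::divides<>` covering the rest. This needs C++17, and a real folder would still have to reproduce the target's overflow and division semantics.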

It would be better to have something like val(DataType dt) to avoid type checks like if (data_type == DataType::i32) return val.val_int32();.

A possible solution is to use the union member value_bits in TypedConstant, e.g. for a cast whose source and destination types match:

      if (src_type == dst_type) {
        new_constant.value_bits = input.value_bits;
        success = true;
      }
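Copying `value_bits` works because all union members share the same storage, so copying the widest member copies whichever member is active. The sketch below models this with a hypothetical `Constant` struct and extends it to casts between different types, where a real value conversion (not a bit copy) is required. `Constant`, `as_double` and `fold_cast` are illustrative names, not Taichi's.

```cpp
#include <cassert>
#include <cstdint>

enum class DataType { i32, f32, f64 };

// Hypothetical constant: every value shares the 64-bit storage of value_bits.
struct Constant {
  DataType dt;
  union {
    uint64_t value_bits;
    int32_t val_i32;
    float val_f32;
    double val_f64;
  };
};

// Reads the active member as a double (exact for i32, f32 and f64 values).
double as_double(const Constant &c) {
  switch (c.dt) {
    case DataType::i32: return c.val_i32;
    case DataType::f32: return c.val_f32;
    case DataType::f64: return c.val_f64;
  }
  return 0.0;  // unreachable for a valid dt
}

// Folds a cast of `input` to `dst_type`.
Constant fold_cast(const Constant &input, DataType dst_type) {
  Constant result{};
  result.dt = dst_type;
  if (input.dt == dst_type) {
    // Same type: copy the raw bits. Reading value_bits after writing another
    // member is union type punning, which GCC and Clang support.
    result.value_bits = input.value_bits;
    return result;
  }
  double v = as_double(input);
  switch (dst_type) {
    // float -> int truncates toward zero; v must be in int32 range.
    case DataType::i32: result.val_i32 = static_cast<int32_t>(v); break;
    case DataType::f32: result.val_f32 = static_cast<float>(v); break;
    case DataType::f64: result.val_f64 = v; break;
  }
  return result;
}
```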

We may write the constant folder for each type and each operation, but it takes too much code to write them one by one.

In constant_fold.cpp, lines 58 and 68:

- new_constant.val_int32() = lhs->val[0].val_int32() * rhs->val[0].val_int32();
+ new_constant = lhs->val[0] * rhs->val[0];

Then write operator*() and operator+() for TypedConstant, for example:

TypedConstant operator*(const TypedConstant &other) const
{
  // Both operands must have the same data type.
  TI_ASSERT(other.dt == this->dt);
  TypedConstant result;
  result.dt = this->dt;
  switch (this->dt) {
#define REGISTER_TYPE(t) case DataType::t: result.val_##t = this->val_##t * other.val_##t; break;
  REGISTER_TYPE(i32)
  REGISTER_TYPE(f32)
  // ...
#undef REGISTER_TYPE
  default:
    TI_NOT_IMPLEMENTED;
  }
  return result;
}

sin, cos, tan, tanh
sqrt, pow, ...

The work for these can also be reduced with macros like REGISTER_FUNC(f).
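A minimal sketch of the `REGISTER_FUNC(f)` idea, assuming floating-point-only unary functions: each macro invocation generates one folding function, and the per-type switch is written once inside the macro. `Constant` and the generated `fold_<f>` names are hypothetical.

```cpp
#include <cassert>
#include <cmath>

enum class DataType { f32, f64 };

// Hypothetical floating-point constant.
struct Constant {
  DataType dt;
  union {
    float val_f32;
    double val_f64;
  };
};

// Generates `Constant fold_<f>(const Constant &)`, which applies std::f to
// the active member and keeps the operand's type.
#define REGISTER_FUNC(f)                                             \
  Constant fold_##f(const Constant &x) {                             \
    Constant result{};                                               \
    result.dt = x.dt;                                                \
    switch (x.dt) {                                                  \
      case DataType::f32: result.val_f32 = std::f(x.val_f32); break; \
      case DataType::f64: result.val_f64 = std::f(x.val_f64); break; \
    }                                                                \
    return result;                                                   \
  }

REGISTER_FUNC(sin)
REGISTER_FUNC(cos)
REGISTER_FUNC(tan)
REGISTER_FUNC(tanh)
REGISTER_FUNC(sqrt)
#undef REGISTER_FUNC
```

`pow` takes two arguments, so it would need a binary variant of the macro. Note also that folding with the host's math library can give results that differ in the last bits from the target device's implementation.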

Also try:

#define REGISTER_TYPE(t, op) // ...
#define REGISTER_OP(op) \
TypedConstant operator op(const TypedConstant &other) const /* no `##` here: pasting `operator` with a symbol like `*` does not form a valid token */ \
{ \
/* ... */ \
}

Hope this helps.
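Filled in, the `REGISTER_OP` idea might look like the sketch below: the outer macro stamps out one operator per arithmetic symbol, and the inner `REGISTER_TYPE(t, op)` stamps out one `case` per type. The operator is spelled `operator op` rather than `operator##op`. `Constant` and its members are simplified, hypothetical stand-ins for `TypedConstant`.

```cpp
#include <cassert>
#include <cstdint>

enum class DataType { i32, i64, f32, f64 };

struct Constant {
  DataType dt;
  union {
    int32_t val_i32;
    int64_t val_i64;
    float val_f32;
    double val_f64;
  };

// One `case` per type: applies `op` to the matching union member.
#define REGISTER_TYPE(t, op) \
  case DataType::t: result.val_##t = val_##t op other.val_##t; break;

// One operator per arithmetic symbol; both operands must share a type.
// Integer division by zero is not checked.
#define REGISTER_OP(op)                               \
  Constant operator op(const Constant &other) const { \
    assert(dt == other.dt);                           \
    Constant result{};                                \
    result.dt = dt;                                   \
    switch (dt) {                                     \
      REGISTER_TYPE(i32, op)                          \
      REGISTER_TYPE(i64, op)                          \
      REGISTER_TYPE(f32, op)                          \
      REGISTER_TYPE(f64, op)                          \
    }                                                 \
    return result;                                    \
  }

  REGISTER_OP(+)
  REGISTER_OP(-)
  REGISTER_OP(*)
  REGISTER_OP(/)
#undef REGISTER_OP
#undef REGISTER_TYPE
};
```

With these operators, the constant_fold.cpp change reduces to `new_constant = lhs->val[0] * rhs->val[0];` for every supported type.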

We can add analysis/is_expr_const.cpp:

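#include <algorithm>

// Assumes Expr exposes is_pure_function() (i.e. constexpr-like) and an
// `operands` container; both are hypothetical members.
bool is_expr_const(const Expr &expr) {
  return expr.is_pure_function() &&
         std::all_of(expr.operands.begin(), expr.operands.end(),
                     [](const Expr &operand) { return is_expr_const(operand); });
}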
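To show how such a check could drive folding, here is a self-contained toy: a hypothetical `Expr` tree whose `is_pure_function()` and operand list mirror the snippet's idea, plus a `fold` pass that replaces every subtree accepted by `is_expr_const` with a single constant. Only integer add and mul are modeled, and `make_const`, `make_node`, `evaluate` and `fold` are illustrative names; real Taichi IR statements are structured differently.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Toy IR: constants, reads of a runtime variable, and integer add/mul.
struct Expr {
  enum class Kind { Const, Var, Add, Mul } kind;
  long value = 0;  // used when kind == Const
  std::vector<std::shared_ptr<Expr>> operands;

  // Var reads runtime state, so it is the only impure kind here.
  bool is_pure_function() const { return kind != Kind::Var; }
};

using ExprPtr = std::shared_ptr<Expr>;

ExprPtr make_const(long v) {
  return std::make_shared<Expr>(Expr{Expr::Kind::Const, v, {}});
}

ExprPtr make_node(Expr::Kind kind, std::vector<ExprPtr> operands) {
  return std::make_shared<Expr>(Expr{kind, 0, std::move(operands)});
}

bool is_expr_const(const Expr &expr) {
  return expr.is_pure_function() &&
         std::all_of(expr.operands.begin(), expr.operands.end(),
                     [](const ExprPtr &e) { return is_expr_const(*e); });
}

// Evaluates an expression that is_expr_const has accepted.
long evaluate(const Expr &expr) {
  switch (expr.kind) {
    case Expr::Kind::Const: return expr.value;
    case Expr::Kind::Add: return evaluate(*expr.operands[0]) + evaluate(*expr.operands[1]);
    case Expr::Kind::Mul: return evaluate(*expr.operands[0]) * evaluate(*expr.operands[1]);
    default: assert(false && "cannot evaluate a non-constant expression"); return 0;
  }
}

// Replaces every maximal constant subtree with a single Const node.
ExprPtr fold(const ExprPtr &expr) {
  if (is_expr_const(*expr)) return make_const(evaluate(*expr));
  auto copy = std::make_shared<Expr>(*expr);
  for (auto &operand : copy->operands) operand = fold(operand);
  return copy;
}
```

For `x * (2 + 3)`, the product stays unfolded because `x` is a runtime value, while the `2 + 3` subtree becomes the constant `5`.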