Tvm: [TIR] Bugs in HoistIfThenElse

Created on 11 May 2020 · 7 Comments · Source: apache/tvm

HoistIfThenElse is a pass currently not enabled in TVM. I tried to enable it in #5553, but there are too many bugs in this pass. Let's fix them first.

BUG 1: HoistIfThenElse transforms

for (n.inner, 0, 2) {
  for (o.inner, 0, 2) {
    if ((((threadIdx.y*2) + n.inner) < 2)) {
      if ((((threadIdx.z*2) + o.inner) < 4)) {
        if ((threadIdx.y < 1)) {
          if ((threadIdx.z < 2)) {
            tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16, ((n.inner*2) + o.inner), tvm_access_ptr(type_annotation(), Conv, (((((threadIdx.y*401408) + (n.inner*200704)) + (blockIdx.z*1024)) + (threadIdx.z*512)) + (o.inner*256)), 256, 2), 16, "row_major")
          }
        }
      }
    }
  }
}

into

if ((((threadIdx.y*2) + n.inner) < 2)) {
  if ((threadIdx.y < 1)) {
    if ((threadIdx.z < 2)) {
      for (n.inner, 0, 2) {
        for (o.inner, 0, 2) {
          if ((((threadIdx.z*2) + o.inner) < 4)) {
            tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16, ((n.inner*2) + o.inner), tvm_access_ptr(type_annotation(), Conv, (((((threadIdx.y*401408) + (n.inner*200704)) + (blockIdx.z*1024)) + (threadIdx.z*512)) + (o.inner*256)), 256, 2), 16, "row_major")
          }
        }
      }
    }
  }
}

Possible cause:

https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/src/tir/pass/hoist_if_then_else.cc#L295

It only checks whether if_stmt has a preferred position, but that position is not guaranteed to be the loop currently being processed. Changing the condition to

if (if_position_map.count(if_stmt.get()) &&
    if_position_map.at(if_stmt.get()).as<ForNode>()->loop_var.get() == top_for_var) {

may solve the problem.
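To illustrate why the extra comparison matters, here is a minimal sketch in plain Python. It is a toy model, not TVM code: the string keys, the dictionary layout, and the helper names `hoistable_loose` and `hoistable_strict` are all hypothetical, and it assumes that `if_position_map` records, for each if-condition, the loop it should be hoisted above.

```python
# Toy model (not TVM's API): map each if-condition to the loop variable of
# the For node it should be hoisted above, i.e. its preferred position.
if_position_map = {
    "threadIdx.y*2 + n.inner < 2": "o.inner",  # uses n.inner, so it may only move above the o.inner loop
    "threadIdx.y < 1": "n.inner",              # uses no loop variable, so it may move above the n.inner loop
    "threadIdx.z < 2": "n.inner",
}


def hoistable_loose(cond, current_loop_var):
    """Loose check: any entry in the map counts, whatever loop it names."""
    return cond in if_position_map


def hoistable_strict(cond, current_loop_var):
    """Strict check: the preferred position must be the loop being processed."""
    return if_position_map.get(cond) == current_loop_var


cond = "threadIdx.y*2 + n.inner < 2"
print(hoistable_loose(cond, "n.inner"))   # accepted, although the condition uses n.inner
print(hoistable_strict(cond, "n.inner"))  # rejected
print(hoistable_strict(cond, "o.inner"))  # accepted at its preferred position
```

With the loose check, the condition on n.inner is accepted while the n.inner loop is being processed, which matches the wrong output shown above.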

BUG 2: src/tir/transforms/split_host_device.cc expects the IR to be in SSA form, where each variable is defined only once. Since we copy loops into both the "then" and the "else" branches, we have to rename the loop variables in the "else" branches so that they differ from those in the "then" branches. I have already written some code for this; see #5553.
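The renaming can be sketched on a toy IR. This is a plain-Python illustration under stated assumptions, not the code from #5553: statements are tuples of the form `("for", var, body)`, `("if", cond, then_body, else_body)` or `("use", var)`, and the helpers `rename_loop_vars`, `hoist_if_out_of_loop` and `defined_loop_vars` are hypothetical names.

```python
import itertools

_fresh_ids = itertools.count(1)  # source of unique suffixes for renamed variables


def rename_loop_vars(stmt, mapping=None):
    """Copy `stmt`, giving every For loop in it a fresh loop variable."""
    mapping = dict(mapping or {})  # copy so sibling subtrees do not interfere
    kind = stmt[0]
    if kind == "for":
        _, var, body = stmt
        fresh = f"{var}_{next(_fresh_ids)}"
        mapping[var] = fresh
        return ("for", fresh, rename_loop_vars(body, mapping))
    if kind == "if":
        _, cond, then_body, else_body = stmt
        return ("if", cond,
                rename_loop_vars(then_body, mapping),
                rename_loop_vars(else_body, mapping))
    if kind == "use":
        # Rewrite references to renamed variables; leave others untouched.
        return ("use", mapping.get(stmt[1], stmt[1]))
    raise ValueError(f"unknown statement kind: {kind}")


def hoist_if_out_of_loop(loop):
    """Turn `for v { if c {A} else {B} }` into `if c { for v {A} } else { for v' {B} }`."""
    _, var, (kind, cond, then_body, else_body) = loop
    assert kind == "if"
    then_loop = ("for", var, then_body)
    else_loop = rename_loop_vars(("for", var, else_body))  # keeps the result in SSA form
    return ("if", cond, then_loop, else_loop)


def defined_loop_vars(stmt):
    """List every loop variable defined in `stmt`, in traversal order."""
    kind = stmt[0]
    if kind == "for":
        return [stmt[1]] + defined_loop_vars(stmt[2])
    if kind == "if":
        return defined_loop_vars(stmt[2]) + defined_loop_vars(stmt[3])
    return []


original = ("for", "i", ("if", "c", ("use", "i"), ("use", "i")))
hoisted = hoist_if_out_of_loop(original)
print(hoisted)
```

After hoisting, the two copies of the loop define different variables, so a single-definition check over `defined_loop_vars(hoisted)` passes.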

BUG 3: IfThenElse nodes whose conditions contain thread indices should not be hoisted above the definition of those indices. This can happen when the Attr node for thread_extent is scheduled into the body of a For node, for example via compute_at. I have already written some code for this; see #5553.
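The constraint can be expressed as a scope check. The sketch below is a toy model in plain Python, not the code from #5553: the list-of-pairs representation and the helper name `hoist_position` are hypothetical. It treats both For nodes and thread_extent Attr nodes as scopes that define a variable, and stops hoisting at the innermost scope that defines something the condition uses.

```python
def hoist_position(cond_vars, enclosing):
    """Return how many enclosing scopes the if-statement must stay inside.

    `enclosing` lists (kind, defined_var) pairs from outermost to innermost.
    The if-statement may be placed inside enclosing[:k] and above enclosing[k:].
    """
    k = len(enclosing)
    # Walk outward while the scope just above does not define a used variable.
    while k > 0 and enclosing[k - 1][1] not in cond_vars:
        k -= 1
    return k


# A thread_extent Attr node that was scheduled into the body of a For node.
scopes = [("for", "i"), ("attr", "threadIdx.x"), ("for", "j")]

print(hoist_position({"threadIdx.x"}, scopes))  # stays inside `for i` and the Attr node
print(hoist_position({"i"}, scopes))            # stays inside `for i` only
print(hoist_position({"n"}, scopes))            # free of all scopes: can go to the top
```

A check that only looks at For nodes would let the threadIdx.x condition move all the way to the top, above the Attr node that defines it.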

BUG 4:

https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/src/tir/pass/hoist_if_then_else.cc#L371

Look at this line. if_stmt may already have been updated by the time this line runs. Consider the example below.

for (i, 0, 10) {
  for (j, 0, 10) {
    for (k, 0, 10) {
      if ((i >= 3)) {
        if ((j >= 3)) {
          data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
        }
      }
    }
  }
}

After hoisting j >= 3, it becomes

for (i, 0, 10) {
  for (j, 0, 10) {
    if ((j >= 3)) {
      for (k, 0, 10) {
        if ((i >= 3)) {
          data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
        }
      }
    }
  }
}

Now, when we are hoisting i >= 3, we need to compare and remove

if ((i >= 3)) {
  if ((j >= 3)) {
    data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
  }
}

But j >= 3 is already gone, so RemoveIf fails. We have to track updates to IfThenElse nodes, just as we already do for For nodes.
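One way to track such updates is to record every rewrite and resolve a stale reference to its latest version before comparing. The sketch below is a toy model in plain Python, not the pass's actual bookkeeping: `Node` and `UpdateTracker` are hypothetical classes, and the node texts are abbreviations of the example above.

```python
class Node:
    """Stand-in for an IR node; compared and hashed by identity."""

    def __init__(self, text):
        self.text = text


class UpdateTracker:
    """Remember which node replaced which, so stale references can be resolved."""

    def __init__(self):
        # Keyed by the old node itself, which also keeps the old node alive.
        self._latest = {}

    def record(self, old, new):
        if old is not new:  # a self-mapping would make resolve() loop forever
            self._latest[old] = new

    def resolve(self, node):
        # Follow the chain of rewrites until reaching a node with no successor.
        while node in self._latest:
            node = self._latest[node]
        return node


tracker = UpdateTracker()
v0 = Node("if (i >= 3) { if (j >= 3) { body } }")  # what the pass recorded at first
v1 = Node("if (i >= 3) { body }")                   # same if after j >= 3 was hoisted out
tracker.record(v0, v1)

current = tracker.resolve(v0)
print(current.text)
```

Here a lookup that starts from the recorded node `v0` ends at `v1`, which is the statement that is actually present in the loop body when i >= 3 is hoisted.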

BUG 5: This one is in the tests.

https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/tests/python/unittest/test_tir_pass_hoist_if.py#L175

Why do we expect a ('For', 'j') nested inside itself? To avoid a potential problem, maybe we should rename the variables so that the test does not contain two i's and two j's.

These are all the bugs I found.

Besides, I suggest changing every for (size_t i = 0; i < xxx.size(); i++) into for (size_t i = 0, n = xxx.size(); i < n; i++), since the C++ compiler cannot always prove that xxx.size() is loop-invariant.

@kevinthesun Maybe you can have a look.

All 7 comments

@kevinthesun it would be great if you can follow up

@roastduck Thank you for bringing these up. This pass has been tested only on a limited number of CUDA conv2d workloads and is not production-ready yet. It would be great if you could help fix or improve this pass.

@roastduck would you be interested in taking over the pass?

I ran into some difficulties improving this pass. For now, I'm not going to take it over.

This pass relies heavily on low-level mechanisms such as PostOrderVisit (instead of StmtExprMutator) and raw pointers to Object, and it depends on manually tracking the updates to these pointers, which is hard to follow. Maybe we should develop an improved StmtExprMutator that can track updates to the nodes.

We could certainly rewrite the pass completely instead of relying on PostOrderVisit.

Given that this pass is not production-ready and we have not yet migrated it to the transform infrastructure, perhaps we can remove it for now and add it back once we have a better implementation. I'm leaving the thread open for a week to see what everyone thinks.
