Currently, TVM's boundary check prevents some invalid global memory accesses, but it ignores the case where the range arguments of a reduce_axis require a global memory access (e.g. to an index tensor; this is common when dealing with sparse or ragged tensors).
Below is a simple example (segment sum) that reproduces the problem. Given an input tensor x and an offset tensor offsets that encodes the segment boundaries (it starts with 0 and ends with the length of x), for each segment i it computes the sum of the elements of x inside that segment, sum(x[offsets[i]:offsets[i+1]]), and stores the result in out[i].

```python
import tvm
import tvm.te as te

num_elements = te.var('num_elements', dtype='int32')
num_segments = te.var('num_segments', dtype='int32')
x = te.placeholder((num_elements,), dtype='float32', name='x')
# offsets[i] is the start of segment i; offsets[num_segments] == num_elements
offsets = te.placeholder((num_segments + 1,), dtype='int32', name='offsets')


def segment_sum(i):
    """Compute sum(x[offsets[i]:offsets[i+1]])"""
    # The reduction extent depends on the spatial index i via reads of offsets
    k = te.reduce_axis((0, offsets[i + 1] - offsets[i]))
    return te.sum(x[k + offsets[i]], axis=k)


out = te.compute(
    (num_segments,),
    segment_sum,
    name='out'
)

s = te.create_schedule(out.op)
segment_axis = out.op.axis[0]
# Split the segments into blocks of 4 threads; the last block may be partial
segment_outer, segment_inner = s[out.op].split(segment_axis, factor=4)
s[out.op].bind(segment_inner, te.thread_axis('threadIdx.x'))
s[out.op].bind(segment_outer, te.thread_axis('blockIdx.x'))
print(tvm.lower(s, [x, offsets, out]))
```
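For reference, the intended semantics of out can be written as a plain-Python loop nest. This is a sketch, not TVM code, and segment_sum_reference is a hypothetical helper name. It mirrors the compute: the outer loop is the spatial axis i, and the inner loop is the reduce axis k, whose extent is read from offsets and is therefore only valid for i < num_segments.

```python
def segment_sum_reference(x, offsets):
    """Plain-Python reference for the segment sum computed by `out` above.

    x:       list of floats (length num_elements)
    offsets: list of ints (length num_segments + 1), with offsets[0] == 0
             and offsets[-1] == len(x)
    Returns a list out with out[i] == sum(x[offsets[i]:offsets[i + 1]]).
    """
    assert offsets[0] == 0 and offsets[-1] == len(x)
    num_segments = len(offsets) - 1
    out = []
    for i in range(num_segments):
        # The reduction extent depends on the spatial index i through two
        # reads of offsets; both are in bounds only because i < num_segments.
        extent = offsets[i + 1] - offsets[i]
        acc = 0.0
        for k in range(extent):
            acc += x[k + offsets[i]]
        out.append(acc)
    return out


print(segment_sum_reference([1.0, 2.0, 3.0, 4.0, 5.0], [0, 2, 2, 5]))  # [3.0, 0.0, 12.0]
```

An empty segment (offsets[i] == offsets[i+1]) simply yields 0.0, matching the zero initialization of out in the lowered code.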
Below is the generated code:

```
primfn(x_1: handle, offsets_1: handle, out_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {out: Buffer(out_2: Pointer(float32), float32, [num_segments: int32], [stride: int32], type="auto"),
             x: Buffer(x_2: Pointer(float32), float32, [num_elements: int32], [stride_1: int32], type="auto"),
             offsets: Buffer(offsets_2: Pointer(int32), int32, [(num_segments + 1)], [])}
  buffer_map = {x_1: x, offsets_1: offsets, out_1: out} {
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = floordiv((num_segments + 3), 4);
  attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 4;
  if (blockIdx.x < floordiv(num_segments, 4)) {
    out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] = 0f32
    for (rv: int32, 0, ((int32*)offsets_2[(((blockIdx.x*4) + threadIdx.x) + 1)] - (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])) {
      if (((blockIdx.x*4) + threadIdx.x) < num_segments) {
        out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] = ((float32*)out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] + (float32*)x_2[((rv + (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])*stride_1)])
      }
    }
  } else {
    if (((blockIdx.x*4) + threadIdx.x) < num_segments) {
      out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] = 0f32
    }
    for (rv_1: int32, 0, ((int32*)offsets_2[(((blockIdx.x*4) + threadIdx.x) + 1)] - (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])) {
      if (((blockIdx.x*4) + threadIdx.x) < num_segments) {
        out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] = ((float32*)out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] + (float32*)x_2[((rv_1 + (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])*stride_1)])
      }
    }
  }
}
```
Note that in `for (rv_1: int32, 0, ((int32*)offsets_2[(((blockIdx.x*4) + threadIdx.x) + 1)] - (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])) {`, the reads of offsets_2 in the loop extent are evaluated before the bounds check inside the loop, so they are not protected. This causes an invalid memory access whenever ((blockIdx.x*4) + threadIdx.x) is greater than or equal to num_segments.
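To make this concrete: with num_segments = 5, the launch has floordiv(5 + 3, 4) = 2 blocks of 4 threads (thread indices 0 to 7), but offsets only holds 6 elements. The plain-Python sketch below (a hypothetical helper, not part of TVM, assuming the 4-thread split from the schedule above) enumerates which reads in the unguarded loop extent fall outside offsets.

```python
def unguarded_offset_reads(num_segments, threads_per_block=4):
    """Return (idx, j) pairs where thread idx reads offsets[j] out of bounds.

    Emulates the launch in the lowered code: floordiv(num_segments + 3, 4)
    blocks of 4 threads, where every thread evaluates the loop extent
        offsets[idx + 1] - offsets[idx],   idx = blockIdx.x * 4 + threadIdx.x
    before the `idx < num_segments` guard inside the loop is reached.
    """
    num_blocks = (num_segments + threads_per_block - 1) // threads_per_block
    offsets_len = num_segments + 1  # valid indices are 0 .. num_segments
    bad_reads = []
    for block in range(num_blocks):
        for thread in range(threads_per_block):
            idx = block * threads_per_block + thread
            for j in (idx, idx + 1):  # the two offsets reads in the extent
                if j >= offsets_len:
                    bad_reads.append((idx, j))
    return bad_reads


print(unguarded_offset_reads(5))  # [(5, 6), (6, 6), (6, 7), (7, 7), (7, 8)]
print(unguarded_offset_reads(8))  # [] (no partial block, so no stray threads)
```

Thread num_segments itself can still read offsets[num_segments] legally; its first bad read is offsets[num_segments + 1], which is why the problem starts exactly at idx == num_segments.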
If we swap the order of the if-statement and the for-loop, the program should work correctly:

```
if (((blockIdx.x*4) + threadIdx.x) < num_segments) {
  for (rv_1: int32, 0, ((int32*)offsets_2[(((blockIdx.x*4) + threadIdx.x) + 1)] - (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])) {
    out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] = ((float32*)out_2[(((blockIdx.x*4) + threadIdx.x)*stride)] + (float32*)x_2[((rv_1 + (int32*)offsets_2[((blockIdx.x*4) + threadIdx.x)])*stride_1)])
  }
}
```
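The effect of the reordering can be checked with a sequential emulation of the kernel. This is a hypothetical plain-Python sketch, not TVM output: each (block, thread) pair plays one GPU thread, and Python's IndexError stands in for the invalid memory access on the GPU. The hoist_guard flag is an illustrative parameter that selects between the generated order and the reordered one.

```python
def emulate_kernel(x, offsets, hoist_guard, threads_per_block=4):
    """Run the lowered segment-sum kernel sequentially, one thread at a time.

    hoist_guard=False mimics the generated code: the loop extent reads
    offsets before the `idx < num_segments` check.
    hoist_guard=True mimics the reordered code: the check comes first.
    Python raises IndexError where the GPU would access invalid memory.
    """
    num_segments = len(offsets) - 1
    num_blocks = (num_segments + threads_per_block - 1) // threads_per_block
    out = [0.0] * num_segments  # stands in for the guarded zero init
    for block in range(num_blocks):
        for thread in range(threads_per_block):
            idx = block * threads_per_block + thread
            if hoist_guard and idx >= num_segments:
                continue  # fixed order: leave before touching offsets
            extent = offsets[idx + 1] - offsets[idx]  # loop extent
            for rv in range(extent):
                if idx < num_segments:  # per-iteration guard from the IR
                    out[idx] += x[rv + offsets[idx]]
    return out


x = [1.0, 2.0, 3.0, 4.0, 5.0]
offsets = [0, 2, 2, 5]  # 3 segments, so the single block is partial
print(emulate_kernel(x, offsets, hoist_guard=True))  # [3.0, 0.0, 12.0]
try:
    emulate_kernel(x, offsets, hoist_guard=False)
except IndexError as err:
    print("generated order fails:", err)  # thread 3 reads offsets[4]
```

When num_segments is a multiple of 4, both orders behave identically, which matches the first (full-block) branch of the lowered code being safe.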
This bug was also mentioned on the TVM forum.
I think this error is related to https://github.com/apache/incubator-tvm/blob/f13fed55cfe872ba7f40970f6a35f965d186a30a/src/tir/transforms/bound_checker.cc. How could I change it so that it is aware of global memory accesses in reduce_axis?
cc @junrushao1994
Just a kind reminder that #5130 has been open for half a year :(
I see, I'm looking into the case. Both this example and #5130 seem to appear when the iteration bound depends on an outer loop variable. We will need a detector that inserts the condition properly when generating the loops.
The current compute bound insertion assumes that loop bounds do not depend on other iterators. It might also be useful to think about what can (and cannot) be done with iterator-dependent loops. The original assumption of compute is that the axes are independent of each other, which certain scheduling transformations rely on for correctness. A dependent reduction bound has different implications (for example, it no longer makes sense to reorder a spatial axis and the reduction axis) and may require different analyses for the schedule to remain correct.
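As an illustration of why reordering becomes problematic, here is a plain-Python sketch that interchanges a spatial loop with a dependent reduction loop. It is not TVM's implementation; the bounding-box-plus-predicate strategy is only one assumed way to do it, and both helper names are hypothetical. The iteration set is preserved only because an extra predicate is added, and evaluating that predicate needs the extents of all i up front.

```python
def iterate_dependent(offsets):
    """Original nest: spatial i outside, reduction k inside; k's extent depends on i."""
    pairs = []
    for i in range(len(offsets) - 1):
        for k in range(offsets[i + 1] - offsets[i]):
            pairs.append((i, k))
    return pairs


def iterate_reordered(offsets):
    """Reduction k outside, spatial i inside.

    k no longer has a single extent, so this sketch iterates k up to the
    largest extent and adds a predicate k < extent(i) that the original
    nest did not need. Computing the largest extent requires reading
    offsets for every i before the loops start.
    """
    num_segments = len(offsets) - 1
    extents = [offsets[i + 1] - offsets[i] for i in range(num_segments)]
    pairs = []
    for k in range(max(extents, default=0)):
        for i in range(num_segments):
            if k < extents[i]:  # new predicate introduced by the reorder
                pairs.append((i, k))
    return pairs


offsets = [0, 2, 2, 5]
print(iterate_dependent(offsets))  # [(0, 0), (0, 1), (2, 0), (2, 1), (2, 2)]
print(sorted(iterate_reordered(offsets)) == iterate_dependent(offsets))  # True
```

With independent axes, the interchange would need neither the predicate nor the up-front reads, which is the assumption the current bound insertion relies on.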
A contribution that enhances support for this case would be more than welcome. However, we should also think about the long-term implications and how to correctly support this kind of workload.
Also cc @spectrometerHBH @Hzfengsy since this might bring some food for thought for the TIR schedule design.