Rust: Optimize away bounds check in loop indexing into slice, given an assertion

Created on 8 May 2020 · 18 comments · Source: rust-lang/rust

I wrote a simple loop indexing into a slice, to test rustc's ability to optimize away bounds checks when it knows an index is in bounds. Even with this very simple test case, I can't seem to get rustc to omit the bounds checks no matter what assert! I add. (I know that I could trivially write this code using iterators instead, but I'm trying to figure out rustc's ability to optimize here.)

Test case (edited since original posting to augment the assert! further):

#[no_mangle]
fn f(slice: &[u64], start: usize, end: usize) -> u64 {
    let mut total = 0;
    assert!(start < end && start < slice.len() && end <= slice.len());
    for i in start..end {
        total += slice[i];
    }
    total
}

I put that into the Compiler Explorer with -O, and the resulting assembly looks like this:

f:
        push    rax
        cmp     rdx, rcx
        jae     .LBB5_8
        cmp     rsi, rdx
        jbe     .LBB5_8
        cmp     rsi, rcx
        jb      .LBB5_8
        xor     eax, eax
.LBB5_4:
        cmp     rdx, rsi
        jae     .LBB5_7
        add     rax, qword ptr [rdi + 8*rdx]
        add     rdx, 1
        cmp     rcx, rdx
        jne     .LBB5_4
        pop     rcx
        ret
.LBB5_7:
        lea     rax, [rip + .L__unnamed_5]
        mov     rdi, rdx
        mov     rdx, rax
        call    qword ptr [rip + core::panicking::panic_bounds_check@GOTPCREL]
        ud2
.LBB5_8:
        call    std::panicking::begin_panic
        ud2

Based on the x86-64 System V calling convention, rdi contains the slice base address, rsi contains the slice length, rdx contains start, and rcx contains end.

So, the first three comparisons verify the assertion and jump to .LBB5_8 if it fails, to panic.

Then inside the loop, there's still another comparison of rdx to rsi, and a jump to .LBB5_7 to panic if out of bounds.

As far as I can tell, that's exactly the same comparison. Shouldn't rustc be able to optimize away that bounds check?

Things I've tested:

  • I tried replacing the assert! with an if and unreachable!, or an if and unsafe { std::hint::unreachable_unchecked() }, but in both cases the loop still checked if the index was in bounds on each iteration.
  • I tried using -Zmutable-noalias=yes, which didn't help.
  • I tried various forms of the assertion condition.

Ideally, rustc should be able to optimize away the bounds check in the loop, based on the assertion. Even better would be if rustc could hoist the bounds check out of the loop even without the assertion, but that seems like a harder problem.
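For reference, the if plus unreachable_unchecked variant mentioned in the first bullet above looks roughly like this (a sketch; it is semantically equivalent to the assert!, yet the in-loop bounds check still remains):

```rust
#[no_mangle]
fn f(slice: &[u64], start: usize, end: usize) -> u64 {
    let mut total = 0;
    // Replace the assert! with an explicit branch to unreachable_unchecked,
    // telling the optimizer this path cannot be taken for valid inputs.
    if !(start < end && start < slice.len() && end <= slice.len()) {
        unsafe { std::hint::unreachable_unchecked() }
    }
    for i in start..end {
        total += slice[i];
    }
    total
}
```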

A-LLVM C-enhancement I-slow T-compiler

All 18 comments

Did you mean end < slice.len()?

@ecstatic-morse I tried including that as well, among other permutations of the assert condition, and it didn't change the in-loop bounds check.

Adding std::intrinsics::assume(i < slice.len()) inside the loop eliminates the check and vectorizes it.
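std::intrinsics::assume is a nightly-only intrinsic; a sketch of the same workaround using its stable counterpart, std::hint::assert_unchecked, looks like this:

```rust
#[no_mangle]
fn f(slice: &[u64], start: usize, end: usize) -> u64 {
    let mut total = 0;
    assert!(start < end && start < slice.len() && end <= slice.len());
    for i in start..end {
        // SAFETY: the assertion above guarantees i < slice.len()
        // for every i in start..end.
        unsafe { std::hint::assert_unchecked(i < slice.len()) };
        total += slice[i];
    }
    total
}
```

Placing the hint inside the loop hands the optimizer the fact at exactly the point where the bounds check would be emitted, which is why it succeeds where the up-front assert! alone does not.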

Looks like LLVM is unable to propagate what it learned on an earlier conditional, into the later loop?

@joshtriplett Nevertheless, you should update your example. Without checking end < slice.len(), f(&[0], 0, 1337) will trigger an index-out-of-bounds panic without hitting the assertion.

@ecstatic-morse Done; edited the code and provided the new corresponding assembly. Doesn't affect the code of the loop, which still includes the bounds check.

This works, and given that an exclusive range is used, it is presumably also the intended use of the function:

#[no_mangle]
fn f(slice: &[u64], start: usize, end: usize) -> u64 {
    let mut total = 0;
    assert!(start < end && end <= slice.len());
    let end = std::cmp::min(slice.len(), end);
    for i in start..end {
        total += slice[i];
    }
    total
}

@the8472 I can confirm that that code eliminates the bounds check. And interestingly, if I reverse the two arguments to min, that does not eliminate the bounds check.

This works (no bounds check, vectorized):

#[no_mangle]
fn f(slice: &[u64], start: usize, end: usize) -> u64 {
    let mut total = 0;
    assert!(start < end && end <= slice.len());
    let end = std::cmp::min(slice.len(), end);
    for i in start..end {
        total += slice[i];
    }
    total
}

This doesn't work (includes the bounds check in the loop):

#[no_mangle]
fn f(slice: &[u64], start: usize, end: usize) -> u64 {
    let mut total = 0;
    assert!(start < end && end <= slice.len());
    let end = std::cmp::min(end, slice.len());
    for i in start..end {
        total += slice[i];
    }
    total
}

The only differences between the two argument orders to min are the direction of the comparison and which argument gets returned if they compare equal. And sure enough, I can confirm that open-coding the equivalent only works if we use slice.len() in place of end in the case where end == slice.len().

This works (no bounds check in the loop):

#[no_mangle]
fn f(slice: &[u64], start: usize, mut end: usize) -> u64 {
    let mut total = 0;
    assert!(start < end && start < slice.len() && end <= slice.len());
    if end >= slice.len() { end = slice.len(); }
    for i in start..end {
        total += slice[i];
    }
    total
}

This doesn't work (bounds check in the loop):

#[no_mangle]
fn f(slice: &[u64], start: usize, mut end: usize) -> u64 {
    let mut total = 0;
    assert!(start < end && start < slice.len() && end <= slice.len());
    if end > slice.len() { end = slice.len(); }
    for i in start..end {
        total += slice[i];
    }
    total
}

It looks like there are two separate bugs here:
1) rustc should optimize away the bounds check given just the assert!, without needing the redundant call to min or the equivalent open-coded change to end.
2) The optimization shouldn't require assigning end = slice.len() in the case where end == slice.len().

I produced a corresponding set of (naively translated) C++ test cases, and confirmed the same behavior from clang trunk.

The following C++ code does not optimize away the bounds check:

#include <cassert>
#include <cstdint>
#include <vector>
using namespace std;

uint64_t f(vector<uint64_t> slice, size_t start, size_t end)
{
    uint64_t total = 0;
    assert(start < end && start < slice.size() && end <= slice.size());
    for (size_t i = start; i < end; i++) {
        total += slice.at(i);
    }
    return total;
}

Nor does this:

#include <cassert>
#include <cstdint>
#include <vector>
using namespace std;

uint64_t f(vector<uint64_t> slice, size_t start, size_t end)
{
    uint64_t total = 0;
    assert(start < end && start < slice.size() && end <= slice.size());
    if (end > slice.size())
        end = slice.size();
    for (size_t i = start; i < end; i++) {
        total += slice.at(i);
    }
    return total;
}

But this does (note the >= in the if):

#include <cassert>
#include <cstdint>
#include <vector>
using namespace std;

uint64_t f(vector<uint64_t> slice, size_t start, size_t end)
{
    uint64_t total = 0;
    assert(start < end && start < slice.size() && end <= slice.size());
    if (end >= slice.size())
        end = slice.size();
    for (size_t i = start; i < end; i++) {
        total += slice.at(i);
    }
    return total;
}

A more intuitive way to achieve the desired result:

#[no_mangle]
fn f(slice: &[u64], start: usize, end: usize) -> u64 {
    let mut total = 0;
    assert!(start < end && start < slice.len() && end <= slice.len());
    for i in (start..end).take_while(|&i| i < slice.len()) {
        total += slice[i];
    }
    total
}

@the8472 As mentioned in the original comment, I understand that I could rewrite the loop in a way that Rust can figure out how to optimize. However, I'd like to see Rust optimizing the original code, which it has enough information to do; I expect that doing so would improve quite a bit of existing Rust code.

Here's another example:

pub fn copy_t(dest: &mut [u8], src: &[u8]) {
    let len = std::cmp::min(dest.len(), src.len());
    for i in 0..len {
        dest[i] = src[i]
    }
}

which compiles to

example::copy_t:
        push    rax
        cmp     rsi, rcx
        mov     r8, rsi
        cmova   r8, rcx
        test    r8, r8
        je      .LBB2_5
        xor     r9d, r9d
.LBB2_2:
        cmp     rcx, r9
        je      .LBB2_6
        cmp     rsi, r9
        je      .LBB2_7
        movzx   eax, byte ptr [rdx + r9]
        mov     byte ptr [rdi + r9], al
        add     r9, 1
        cmp     r9, r8
        jb      .LBB2_2
.LBB2_5:
        pop     rax
        ret
.LBB2_6:
        lea     rdx, [rip + .L__unnamed_1]
        mov     rdi, rcx
        mov     rsi, rcx
        call    qword ptr [rip + core::panicking::panic_bounds_check@GOTPCREL]
        ud2
.LBB2_7:
        lea     rdx, [rip + .L__unnamed_2]
        mov     rdi, rsi
        call    qword ptr [rip + core::panicking::panic_bounds_check@GOTPCREL]
        ud2

I tried different combinations of assertions from this issue and none of them worked for this one. Using unchecked indexing optimizes the loop to a memcpy:

pub fn copy_s(dest: &mut [u8], src: &[u8]) {
    let len = std::cmp::min(dest.len(), src.len());
    for i in 0..src.len() {
        unsafe {
            *dest.get_unchecked_mut(i) = *src.get_unchecked(i);
        }
    }
}

example::copy_s:
        test    rcx, rcx
        je      .LBB1_2
        push    rax
        mov     rsi, rdx
        mov     rdx, rcx
        call    qword ptr [rip + memcpy@GOTPCREL]
        add     rsp, 8
.LBB1_2:
        ret

How about this? https://godbolt.org/z/TQSMyv

pub fn copy_c(dest: &mut [u8], src: &[u8]) {
    let len = std::cmp::min(dest.len(), src.len());
    let (left, _) = dest.split_at_mut(len);
    left.copy_from_slice(&src[..len]);
}

Yes, copy_from_slice works, but only if you know ahead of time that it's going to be a simple byte-for-byte copy.

What I'm personally interested in is trying to get generic code that (after inlining) compiles down to the equivalent of the above example to optimize.

@tmandry Here is what goes into induction variable simplification for your case, after a bit of cleanup:

define void @test([0 x i8]* nocapture nonnull align 1 %dest.0, i64 %dest.1, [0 x i8]* noalias nocapture nonnull readonly align 1 %src.0, i64 %src.1) {
start:
  %i = icmp ugt i64 %dest.1, %src.1
  %umin = select i1 %i, i64 %src.1, i64 %dest.1
  %i1 = icmp eq i64 %umin, 0
  br i1 %i1, label %bb7, label %bb9.preheader

bb9.preheader:                                    ; preds = %start
  br label %bb9

bb7.loopexit:                                     ; preds = %bb11
  br label %bb7

bb7:                                              ; preds = %bb7.loopexit, %start
  ret void

bb9:                                              ; preds = %bb11, %bb9.preheader
  %iv = phi i64 [ %iv.inc, %bb11 ], [ 0, %bb9.preheader ]
  %iv.inc = add nuw i64 %iv, 1
  %_23 = icmp ult i64 %iv, %src.1
  br i1 %_23, label %bb10, label %panic

bb10:                                             ; preds = %bb9
  %_26 = icmp ult i64 %iv, %dest.1
  br i1 %_26, label %bb11, label %panic1

bb11:                                             ; preds = %bb10
  %i3 = getelementptr inbounds [0 x i8], [0 x i8]* %src.0, i64 0, i64 %iv
  %_20 = load i8, i8* %i3, align 1
  %i4 = getelementptr inbounds [0 x i8], [0 x i8]* %dest.0, i64 0, i64 %iv
  store i8 %_20, i8* %i4, align 1
  %i5 = icmp ult i64 %iv.inc, %umin
  br i1 %i5, label %bb9, label %bb7.loopexit

panic:                                            ; preds = %bb9
  %iter.sroa.0.015.lcssa = phi i64 [ %iv, %bb9 ]
  tail call void @abort(i64 %iter.sroa.0.015.lcssa)
  unreachable

panic1:                                           ; preds = %bb10
  %iter.sroa.0.015.lcssa16 = phi i64 [ %iv, %bb10 ]
  tail call void @abort(i64 %iter.sroa.0.015.lcssa16)
  unreachable
}

declare void @abort(i64)

The thing to note is that %umin != 0 is checked on entry and the loop uses a post-increment exit condition %iv + 1 < %umin.

The relevant SCEV parts are:

  %iv = phi i64 [ %iv.inc, %bb11 ], [ 0, %bb9.preheader ]
  -->  {0,+,1}<nuw><%bb9> U: [0,-1) S: [0,-1)       Exits: ((-1 + (%dest.1 umin %src.1)) umin %dest.1 umin %src.1)      LoopDispositions: { %bb9: Computable }
  %iv.inc = add nuw i64 %iv, 1
  -->  {1,+,1}<nuw><%bb9> U: [1,0) S: [1,0)     Exits: (1 + ((-1 + (%dest.1 umin %src.1)) umin %dest.1 umin %src.1))        LoopDispositions: { %bb9: Computable }
...
  exit count for bb9: %src.1
  exit count for bb10: %dest.1
  exit count for bb11: (-1 + (%dest.1 umin %src.1))

The -1 is what obscures things here, because it could be overflowing. SCEV is not capable of retaining that a subtraction is NUW, because it canonicalizes subtractions to additions.
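To make the wraparound concrete (a tiny Rust illustration, not from the issue): if %umin could be 0, the exit count %umin - 1 would wrap to the maximum usize, which is why SCEV must treat the subtraction as potentially overflowing unless it can prove otherwise.

```rust
fn main() {
    // The entry check `%umin != 0` rules this case out at runtime,
    // but SCEV loses that fact when it canonicalizes the subtraction:
    let umin: usize = 0;
    assert_eq!(umin.wrapping_sub(1), usize::MAX); // -1 wraps around
}
```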

What I'm personally interested in is trying to get generic code that (after inlining) compiles down to the equivalent of the above example to optimize.

Is this sufficiently generic? It compiles to a memcpy:

pub fn generic_copy<T: Clone>(dest: &mut [T], src: &[T]) {
    let len = std::cmp::min(dest.len(), src.len());
    let (dest, _) = dest.split_at_mut(len);
    let src = &src[..len];

    for i in 0..src.len() {
        dest[i] = src[i].clone()
    }
}

pub fn concrete(dest: &mut [u8], src: &[u8]) {
    generic_copy(dest, src)
}