Rust: slice::iter() does not preserve number of iterations information for optimizer causing unneeded bounds checks

Created on 26 Aug 2020 · 15 Comments · Source: rust-lang/rust

Godbolt link to the code below: https://rust.godbolt.org/z/aKf3Wq

pub fn foo1(x: &[u32], y: &[u32]) -> u32 {
    let mut sum = 0;
    let chunk_size = y.len();
    for (c, y) in y.iter().enumerate() {
        for chunk in x.chunks_exact(chunk_size) {
            sum += chunk[c] + y;
        }
    }
    sum
}

This code has a bounds check for chunk[c] although c < chunk_size by construction.

The same code, written a bit more convolutedly, gets rid of the bounds check:

pub fn foo2(x: &[u32], y: &[u32]) -> u32 {
    let mut sum = 0;
    let chunk_size = y.len();
    for c in 0..chunk_size {
        let y = y[c];
        for chunk in x.chunks_exact(chunk_size) {
            sum += chunk[c] + y;
        }
    }
    sum
}

It seems like the optimizer loses the information that 0 <= c < y.len() when going via y.iter().enumerate(). So this is not specific to chunks_exact(), but I can't come up with an equivalent example without it.
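One stable-Rust way to restate the lost fact is a plain assert!, hoisted out of the inner loop. The sketch below (the name foo1_hoisted is hypothetical) is foo1 with one added line. The assert can never fire, and the hope is that LLVM then drops the per-chunk check on chunk[c]. This has not been verified against the Godbolt link above, so inspect the assembly before relying on it.

```rust
/// Hypothetical variant of `foo1` that re-states `c < chunk_size`
/// once per outer iteration instead of relying on the optimizer
/// to derive it from `y.iter().enumerate()`.
pub fn foo1_hoisted(x: &[u32], y: &[u32]) -> u32 {
    let mut sum = 0;
    let chunk_size = y.len();
    for (c, y) in y.iter().enumerate() {
        // Always true by construction. It is checked once per `c`,
        // which may let LLVM remove the check on `chunk[c]` below.
        assert!(c < chunk_size);
        for chunk in x.chunks_exact(chunk_size) {
            sum += chunk[c] + y;
        }
    }
    sum
}
```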

Edit: As noticed in https://github.com/rust-lang/rust/issues/75935#issuecomment-680807329, this can be worked around by defining a custom slice iterator that counts the remaining elements instead of comparing against an end pointer.

The problem is that slice::iter() uses an end pointer to know when iteration can stop, and it keeps no information around that tells the optimizer it will iterate exactly N times. It is unclear to me how this information can be preserved without changing how the iterator works, which would probably have other negative effects.
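To make the end-pointer design concrete, here is a simplified model. It is not std's actual slice::Iter code. EndPtrIter is a hypothetical name, and the model rejects zero-sized types instead of handling them. Termination is a pointer comparison, so the trip count (end - ptr) / size_of::<T>() is never stored anywhere the optimizer can read it directly.

```rust
use std::marker::PhantomData;

/// Simplified model of an end-pointer slice iterator (hypothetical,
/// not std's implementation). Zero-sized types are rejected.
pub struct EndPtrIter<'a, T> {
    ptr: *const T,
    end: *const T,
    _marker: PhantomData<&'a [T]>,
}

impl<'a, T> EndPtrIter<'a, T> {
    pub fn new(s: &'a [T]) -> Self {
        // For a ZST, start == end for every length, so this model
        // would yield nothing; std needs special handling there.
        assert!(std::mem::size_of::<T>() != 0, "ZSTs not supported");
        let range = s.as_ptr_range();
        EndPtrIter { ptr: range.start, end: range.end, _marker: PhantomData }
    }
}

impl<'a, T> Iterator for EndPtrIter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<&'a T> {
        // Stop when the cursor reaches the end pointer. The number of
        // remaining elements is only implicit in `end - ptr`.
        if self.ptr == self.end {
            return None;
        }
        // SAFETY: ptr != end, so ptr points at a live element of the slice.
        unsafe {
            let item = &*self.ptr;
            self.ptr = self.ptr.add(1);
            Some(item)
        }
    }
}
```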

A-iterators A-slice

All 15 comments

I should probably add that while this code is very contrived, it's based on real code that shows the same behaviour.

Maybe a duplicate of #74938, which needs an upgrade to LLVM 12 or so to be fixed.

I'll try adding only that change to the LLVM version used by rustc. Let's see if that solves anything (or works at all :) ).

No, it does not work at all. It segfaults in exactly the function that was changed inside LLVM.

Got it to work. It doesn't fix this issue, but it does fix #75936. This one is still valid.

It seems like the information that 0 <= c < y.len() somehow gets lost for the optimizer when going via y.iter().enumerate().

My guess is that this is because the slice iterators don't go via a counter but instead via an end pointer, so it's not obvious anymore that it's just iterating exactly self.len() times.

Yes, going with a simple iterator that counts instead gets rid of the bounds check. Code:

pub fn foo1(x: &[u32], y: &[u32]) -> u32 {
    let mut sum = 0;
    let chunk_size = y.len();
    for (c, y) in Iter::new(y).enumerate() {
        for chunk in x.chunks_exact(chunk_size) {
            sum += chunk[c] + y;
        }
    }
    sum
}

struct Iter<'a> {
    ptr: *const u32,
    len: usize,
    phantom: std::marker::PhantomData<&'a [u32]>,
}

impl<'a> Iter<'a> {
    fn new(v: &'a [u32]) -> Iter<'a> {
        Iter {
            ptr: v.as_ptr(),
            len: v.len(),
            phantom: std::marker::PhantomData,
        }
    }
}

impl<'a> Iterator for Iter<'a> {
    type Item = &'a u32;

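    fn next(&mut self) -> Option<&'a u32> {
        // Terminate on the element count rather than an end pointer,
        // so the trip count (the slice length) stays visible.
        if self.len == 0 {
            return None;
        }

        // SAFETY: len > 0, so ptr points at a live element of the slice.
        unsafe {
            let item = &*self.ptr;
            self.ptr = self.ptr.add(1);
            self.len -= 1;
            Some(item)
        }
    }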
}

You could update the issue description with the new information.


Indeed, thanks. Done!

Is there any particular reason why the std-implementation (slice::Iter) is doing iteration through end pointer equality compared to the counting variant?


I don't know the history, but in theory it saves one instruction per iteration (one pointer addition vs. one pointer addition plus one counter update). It might also be taken advantage of in some specialized impls, but I don't know.

It might be worth looking at what std::vector iterators in C++ are doing; I'd hope those optimize well with clang++.


The comment in the code says that it's because of an optimization for zero-sized types (ZSTs):
https://github.com/rust-lang/rust/blob/118860a7e76daaac3564c7655d46ac65a14fc612/library/core/src/slice/mod.rs#L4009-L4014


With a counter it would work basically the same way: you'd just check idx == len, remainder == 0 or similar. For ZSTs, the current encoding seems to be just a way to avoid worrying about ZST special cases in most other places of the code.
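The point about ZSTs can be illustrated with a generic counting iterator. CountingIter below is a hypothetical sketch, not a proposed std implementation. Because termination depends only on the remaining count, zero-sized element types need no special case: ptr.add(1) moves by zero bytes, and the slice's dangling-but-aligned pointer is valid for creating a &() reference.

```rust
use std::marker::PhantomData;

/// Hypothetical counting slice iterator, generic over `T`.
pub struct CountingIter<'a, T> {
    ptr: *const T,
    remaining: usize,
    _marker: PhantomData<&'a [T]>,
}

impl<'a, T> CountingIter<'a, T> {
    pub fn new(s: &'a [T]) -> Self {
        CountingIter { ptr: s.as_ptr(), remaining: s.len(), _marker: PhantomData }
    }
}

impl<'a, T> Iterator for CountingIter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<&'a T> {
        // The only termination test is the counter, so ZSTs and
        // non-ZSTs take the same path.
        if self.remaining == 0 {
            return None;
        }
        self.remaining -= 1;
        // SAFETY: remaining was > 0, so ptr points at an element of the
        // slice. For a ZST, add(1) is a zero-byte offset.
        unsafe {
            let item = &*self.ptr;
            self.ptr = self.ptr.add(1);
            Some(item)
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact, because the count is stored explicitly.
        (self.remaining, Some(self.remaining))
    }
}

impl<'a, T> ExactSizeIterator for CountingIter<'a, T> {}
```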

I can do an implementation of slice::iter() that counts next week, but what would be the best way to check that this doesn't cause any performance regressions elsewhere? How are such things usually checked (e.g. is there some extensive benchmark suite that I could run, ...)?

Here is another variant that specializes the Enumerate iterator, which should also catch various other cases (e.g. the Chunks and ChunksExact iterators on slices). This yields the best code so far: no bounds checks, and the loop is unrolled and auto-vectorized nicely.

Check the assembly here; the new one is foo2().

The std::intrinsics::assume() in the specialized impl is the part that makes it work nicely.

I'll create a PR for this later.

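// Nightly only: specialization, TrustedLen and intrinsics::assume are unstable.
#![feature(specialization, trusted_len, core_intrinsics)]

pub fn foo2(x: &[u32], y: &[u32]) -> u32 {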
    let mut sum = 0;
    let chunk_size = y.len();
    for (c, y) in Enumerate::new(y.iter()) {
        for chunk in x.chunks_exact(chunk_size) {
            sum += chunk[c] + y;
        }
    }
    sum
}

struct Enumerate<I> {
    iter: I,
    count: usize,
    len: usize,
}

impl<I: Iterator> Enumerate<I> {
    fn new(iter: I) -> Self {
        EnumerateImpl::new(iter)
    }
}

impl<I> Iterator for Enumerate<I>
where
    I: Iterator,
{
    type Item = (usize, <I as Iterator>::Item);

    #[inline]
    fn next(&mut self) -> Option<(usize, <I as Iterator>::Item)> {
        EnumerateImpl::next(self)
    }
}

// Enumerate specialization trait
#[doc(hidden)]
trait EnumerateImpl<I> {
    type Item;
    fn new(iter: I) -> Self;
    fn next(&mut self) -> Option<(usize, Self::Item)>;
}

impl<I> EnumerateImpl<I> for Enumerate<I>
where
    I: Iterator,
{
    type Item = I::Item;

    default fn new(iter: I) -> Self {
        Enumerate {
            iter,
            count: 0,
            len: 0, // unused
        }
    }

    #[inline]
    default fn next(&mut self) -> Option<(usize, I::Item)> {
        let a = self.iter.next()?;
        let i = self.count;
        // Only overflows after usize::MAX items: panics when overflow
        // checks are enabled, wraps otherwise.
        self.count += 1;
        Some((i, a))
    }
}

impl<I> EnumerateImpl<I> for Enumerate<I>
where
    // FIXME: Should probably be TrustedRandomAccess because otherwise size_hint() might be expensive?
    I: std::iter::TrustedLen + ExactSizeIterator + Iterator,
{
    fn new(iter: I) -> Self {
        let len = iter.size_hint().0;

        Enumerate {
            iter,
            count: 0,
            len,
        }
    }

    #[inline]
    fn next(&mut self) -> Option<(usize, I::Item)> {
        let a = self.iter.next()?;
        unsafe { std::intrinsics::assume(self.count < self.len); }
        let i = self.count;
        // Cannot overflow: count < len <= usize::MAX.
        self.count += 1;
        Some((i, a))
    }
}
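On current stable Rust, a hint equivalent to the std::intrinsics::assume() call above can be written with std::hint::assert_unchecked (stable since Rust 1.81) and no specialization. The sketch below applies it only to slices. The helper enumerate_hinted and the function foo3 are hypothetical names, not std APIs. Whether the bounds check on chunk[c] actually disappears depends on the compiler version, so check the generated assembly.

```rust
use std::hint::assert_unchecked;

/// Hypothetical helper: enumerate a slice while telling the optimizer
/// that every yielded index is below the slice length.
fn enumerate_hinted<'a, T>(s: &'a [T]) -> impl Iterator<Item = (usize, &'a T)> {
    let len = s.len();
    s.iter().enumerate().map(move |(i, x)| {
        // SAFETY: `enumerate` over a slice of length `len` only
        // yields indices in 0..len, so this condition always holds.
        unsafe { assert_unchecked(i < len) };
        (i, x)
    })
}

/// Same computation as `foo1`, but iterating via `enumerate_hinted`.
pub fn foo3(x: &[u32], y: &[u32]) -> u32 {
    let mut sum = 0;
    let chunk_size = y.len();
    for (c, y) in enumerate_hinted(y) {
        for chunk in x.chunks_exact(chunk_size) {
            sum += chunk[c] + y;
        }
    }
    sum
}
```

Compared with specializing Enumerate, this keeps the hint local to one call site and is unsafe only in the single assert_unchecked line, whose condition follows directly from how enumerate() works.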