Rust: Bad codegen with simple match statement

Created on 5 Feb 2020 · 14 Comments · Source: rust-lang/rust

The following code:

/// Alias for the `f32` payload carried by every variant below.
type CSSFloat = f32;

/// Four-variant enum in which every variant wraps a single `CSSFloat`.
///
/// The variant names suggest the CSS viewport units (vw, vh, vmin, vmax).
pub enum ViewportPercentageLength {
    Vw(CSSFloat),
    Vh(CSSFloat),
    Vmin(CSSFloat),
    Vmax(CSSFloat),
}

impl ViewportPercentageLength {
    /// Adds two lengths that are the same variant.
    ///
    /// Returns `Ok` holding that variant with the two payloads added when
    /// `self` and `other` are the same variant, and `Err(())` for every
    /// mixed pair.
    fn try_sum(&self, other: &Self) -> Result<Self, ()> {
        // Bring the variant names into scope so the patterns below stay short.
        use self::ViewportPercentageLength::*;
        // One arm per matching pair of variants; each arm rebuilds the same variant.
        Ok(match (self, other) {
            (&Vw(one), &Vw(other)) => Vw(one + other),
            (&Vh(one), &Vh(other)) => Vh(one + other),
            (&Vmin(one), &Vmin(other)) => Vmin(one + other),
            (&Vmax(one), &Vmax(other)) => Vmax(one + other),
            // Any pair of differing variants: leave the function early with an error.
            _ => return Err(()),
        })
    }
}

/// `extern "C"` wrapper around `try_sum`, exported under the unmangled name `sum_them`.
///
/// When `try_sum` succeeds, the result is written to `*out` and `true` is
/// returned; otherwise `false` is returned and `*out` is not written.
#[no_mangle]
pub extern "C" fn sum_them(
    one: &ViewportPercentageLength,
    other: &ViewportPercentageLength,
    out: &mut ViewportPercentageLength,
) -> bool {
    match one.try_sum(other) {
        Ok(v) => {
            *out = v;
            true
        }
        Err(()) => false,
    }
}

Generates the following assembly on Rust Nightly when compiled with -C opt-level=3:

# Jump-table version: dispatches on the 32-bit word at [rdi], then each target
# compares the word at [rsi] against that case's constant.
# NOTE(review): rdi/rsi/rdx presumably hold `one`/`other`/`out`, with the enum
# tag at offset 0 and the f32 at offset 4 - inferred from the loads and stores
# below; confirm against the calling convention.
sum_them:
        mov     eax, dword ptr [rdi]
        movss   xmm0, dword ptr [rdi + 4]
        mov     ecx, dword ptr [rsi]
        movss   xmm1, dword ptr [rsi + 4]
        # Load the table entry indexed by the word from [rdi], add the table
        # address, and jump there.
        lea     rsi, [rip + .LJTI0_0]
        movsxd  rax, dword ptr [rsi + 4*rax]
        add     rax, rsi
        jmp     rax
# Table entry 0: succeed only if the word from [rsi] is also 0.
.LBB0_1:
        xor     eax, eax
        test    ecx, ecx
        je      .LBB0_8
        ret
# Table entry 2: succeed only if the word from [rsi] is also 2.
.LBB0_3:
        mov     eax, 2
        cmp     ecx, 2
        je      .LBB0_8
# Shared failure exit: return with eax = 0.
.LBB0_9:
        xor     eax, eax
        ret
# Table entry 3: succeed only if the word from [rsi] is also 3.
.LBB0_5:
        mov     eax, 3
        cmp     ecx, 3
        jne     .LBB0_9
# Shared success tail: add the two floats, store eax at [rdx] and the sum at
# [rdx + 4], then return with al = 1.
.LBB0_8:
        addss   xmm0, xmm1
        mov     dword ptr [rdx], eax
        movss   dword ptr [rdx + 4], xmm0
        mov     al, 1
        ret
# Table entry 1: succeed only if the word from [rsi] is also 1.
.LBB0_7:
        mov     eax, 1
        cmp     ecx, 1
        jne     .LBB0_9
        jmp     .LBB0_8
# Jump table: four 32-bit offsets, each relative to the table's own address.
.LJTI0_0:
        .long   .LBB0_1-.LJTI0_0
        .long   .LBB0_7-.LJTI0_0
        .long   .LBB0_3-.LJTI0_0
        .long   .LBB0_5-.LJTI0_0

Godbolt link: https://rust.godbolt.org/z/JfkEez

It seems to generate one branch for each case of the statement, when I would've expected it to look more like:

if one.enum_discriminant != other.enum_discriminant {
    jump to error case
}
write enum into outparam with tag = one.enum_discriminant and value = one.value + other.value

cc @michaelwoerister @heycam

A-codegen A-mir A-mir-opt I-slow T-compiler

Most helpful comment

After #75119 the final asm appears much better.
I'll be using the following test file 68867.rs:

/// Alias for the `f32` payload carried by every variant below.
type CSSFloat = f32;

/// Four-variant enum in which every variant wraps a single `CSSFloat`.
///
/// The variant names suggest the CSS viewport units (vw, vh, vmin, vmax).
pub enum ViewportPercentageLength {
    Vw(CSSFloat),
    Vh(CSSFloat),
    Vmin(CSSFloat),
    Vmax(CSSFloat),
}

impl ViewportPercentageLength {
    /// Adds two lengths that are the same variant.
    ///
    /// Returns `Ok` holding that variant with the two payloads added when
    /// `self` and `other` are the same variant, and `Err(())` for every
    /// mixed pair.
    fn try_sum(&self, other: &Self) -> Result<Self, ()> {
        // Bring the variant names into scope so the patterns below stay short.
        use self::ViewportPercentageLength::*;
        // One arm per matching pair of variants; each arm rebuilds the same variant.
        Ok(match (self, other) {
            (&Vw(one), &Vw(other)) => Vw(one + other),
            (&Vh(one), &Vh(other)) => Vh(one + other),
            (&Vmin(one), &Vmin(other)) => Vmin(one + other),
            (&Vmax(one), &Vmax(other)) => Vmax(one + other),
            // Any pair of differing variants: leave the function early with an error.
            _ => return Err(()),
        })
    }
}

/// `extern "C"` wrapper around `try_sum`, exported under the unmangled name `sum_them`.
///
/// When `try_sum` succeeds, the result is written to `*out` and `true` is
/// returned; otherwise `false` is returned and `*out` is not written.
#[no_mangle]
pub extern "C" fn sum_them(
    one: &ViewportPercentageLength,
    other: &ViewportPercentageLength,
    out: &mut ViewportPercentageLength,
) -> bool {
    match one.try_sum(other) {
        Ok(v) => {
            *out = v;
            true
        }
        Err(()) => false,
    }
}

// Empty entry point; presumably only here so the file builds with the plain
// `rustc 68867.rs` invocations quoted alongside it.
fn main() {}

before: rustup run nightly-2020-09-20 rustc 68867.rs --emit=asm -O -C opt-level=3 -o 2020-09-20.s

    # "Before" output: jump-table dispatch on the 32-bit word at (%rdi), followed
    # by a per-case comparison of the word at (%rsi).
    # NOTE(review): %rdi/%rsi/%rdx presumably hold `one`/`other`/`out`, with the
    # enum tag at offset 0 and the f32 at offset 4 - inferred from the loads and
    # stores below; confirm against the calling convention.
    .section    .text.sum_them,"ax",@progbits
    .globl  sum_them
    .p2align    4, 0x90
    .type   sum_them,@function
sum_them:
    .cfi_startproc
    movl    (%rdi), %eax
    movss   4(%rdi), %xmm0
    movl    (%rsi), %ecx
    movss   4(%rsi), %xmm1
    # Load the table entry indexed by the word from (%rdi), add the table
    # address, and jump there.
    leaq    .LJTI5_0(%rip), %rsi
    movslq  (%rsi,%rax,4), %rax
    addq    %rsi, %rax
    jmpq    *%rax
# Table entry 0: succeed only if the word from (%rsi) is also 0.
.LBB5_1:
    xorl    %eax, %eax
    testl   %ecx, %ecx
    je  .LBB5_8
    retq
# Table entry 2: succeed only if the word from (%rsi) is also 2.
.LBB5_3:
    movl    $2, %eax
    cmpl    $2, %ecx
    je  .LBB5_8
# Shared failure exit: return with %eax = 0.
.LBB5_9:
    xorl    %eax, %eax
    retq
# Table entry 3: succeed only if the word from (%rsi) is also 3.
.LBB5_5:
    movl    $3, %eax
    cmpl    $3, %ecx
    jne .LBB5_9
# Shared success tail: add the two floats, store %eax at (%rdx) and the sum at
# 4(%rdx), then return with %al = 1.
.LBB5_8:
    addss   %xmm1, %xmm0
    movl    %eax, (%rdx)
    movss   %xmm0, 4(%rdx)
    movb    $1, %al
    retq
# Table entry 1: succeed only if the word from (%rsi) is also 1.
.LBB5_7:
    movl    $1, %eax
    cmpl    $1, %ecx
    jne .LBB5_9
    jmp .LBB5_8
.Lfunc_end5:
    .size   sum_them, .Lfunc_end5-sum_them
    .cfi_endproc
    .section    .rodata.sum_them,"a",@progbits
    .p2align    2
# Jump table: four 32-bit offsets, each relative to the table's own address.
.LJTI5_0:
    .long   .LBB5_1-.LJTI5_0
    .long   .LBB5_7-.LJTI5_0
    .long   .LBB5_3-.LJTI5_0
    .long   .LBB5_5-.LJTI5_0

after: rustup run nightly-2020-09-21 rustc 68867.rs --emit=asm -O -C opt-level=3 -o 2020-09-21.s

    # "After" output: the 32-bit words at (%rdi) and (%rsi) are compared once;
    # there is no jump table.
    .section    .text.sum_them,"ax",@progbits
    .globl  sum_them
    .p2align    4, 0x90
    .type   sum_them,@function
sum_them:
    .cfi_startproc
    movl    (%rdi), %eax
    cmpl    %eax, (%rsi)
    jne .LBB5_1
    # Words equal: add the floats at offset 4, store the word and the sum
    # through %rdx, and return with %al = 1.
    movss   4(%rdi), %xmm0
    addss   4(%rsi), %xmm0
    movl    %eax, (%rdx)
    movss   %xmm0, 4(%rdx)
    movb    $1, %al
    retq
# Words differ: return with %eax = 0.
.LBB5_1:
    xorl    %eax, %eax
    retq
.Lfunc_end5:
    .size   sum_them, .Lfunc_end5-sum_them
    .cfi_endproc

All 14 comments

Also cc @nox / @SimonSapin as servo uses this kind of pattern quite a lot too

I wonder if this affects the built-in #[derive(PartialEq)] and so on.

Seems like PartialEq is not affected and manages to do this (I used i32 instead of f32 to avoid floating point instructions): https://rust.godbolt.org/z/DXIbja

However a manually implemented version of that is: https://rust.godbolt.org/z/fmYvsb

It'd be good to know what kind of code rustc generates for this case... Does it exploit internals to poke at the representation directly?

The code that the derive expands to is: https://rust.godbolt.org/z/P3ptVh

Might be a good candidate for a MIR opt. I believe LLVM has historically shied away from doing this kind of thing out of compile time concerns.

The code that the derive expands to is: https://rust.godbolt.org/z/P3ptVh

Cool, TIL!

So the unreachable is key there, and same for putting the discriminant_value check outside the match statement... Is there a way to get the discriminant value in stable without forcing #[repr(u8)] or such to be on the enum?

std::mem::discriminant is stable and returns something that implements Eq

Ah, true! And std::hint::unreachable_unchecked as well... That'd work for me, I think (though improvements in the compiler itself to detect this would be awesome).

I've always wanted to write such a peephole optimisation for match expressions over (T, T) where all the arms but the last one are (T::Foo(..), T::Foo(..)) but I didn't have the time to do it yet.

That's low-hanging fruit, and I suspect it will bring improvements all across the board, but days are still only 24 hours long.

After #75119 the final asm appears much better.
I'll be using the following test file 68867.rs:

/// Alias for the `f32` payload carried by every variant below.
type CSSFloat = f32;

/// Four-variant enum in which every variant wraps a single `CSSFloat`.
///
/// The variant names suggest the CSS viewport units (vw, vh, vmin, vmax).
pub enum ViewportPercentageLength {
    Vw(CSSFloat),
    Vh(CSSFloat),
    Vmin(CSSFloat),
    Vmax(CSSFloat),
}

impl ViewportPercentageLength {
    /// Adds two lengths that are the same variant.
    ///
    /// Returns `Ok` holding that variant with the two payloads added when
    /// `self` and `other` are the same variant, and `Err(())` for every
    /// mixed pair.
    fn try_sum(&self, other: &Self) -> Result<Self, ()> {
        // Bring the variant names into scope so the patterns below stay short.
        use self::ViewportPercentageLength::*;
        // One arm per matching pair of variants; each arm rebuilds the same variant.
        Ok(match (self, other) {
            (&Vw(one), &Vw(other)) => Vw(one + other),
            (&Vh(one), &Vh(other)) => Vh(one + other),
            (&Vmin(one), &Vmin(other)) => Vmin(one + other),
            (&Vmax(one), &Vmax(other)) => Vmax(one + other),
            // Any pair of differing variants: leave the function early with an error.
            _ => return Err(()),
        })
    }
}

/// `extern "C"` wrapper around `try_sum`, exported under the unmangled name `sum_them`.
///
/// When `try_sum` succeeds, the result is written to `*out` and `true` is
/// returned; otherwise `false` is returned and `*out` is not written.
#[no_mangle]
pub extern "C" fn sum_them(
    one: &ViewportPercentageLength,
    other: &ViewportPercentageLength,
    out: &mut ViewportPercentageLength,
) -> bool {
    match one.try_sum(other) {
        Ok(v) => {
            *out = v;
            true
        }
        Err(()) => false,
    }
}

// Empty entry point; presumably only here so the file builds with the plain
// `rustc 68867.rs` invocations quoted alongside it.
fn main() {}

before: rustup run nightly-2020-09-20 rustc 68867.rs --emit=asm -O -C opt-level=3 -o 2020-09-20.s

    # "Before" output: jump-table dispatch on the 32-bit word at (%rdi), followed
    # by a per-case comparison of the word at (%rsi).
    # NOTE(review): %rdi/%rsi/%rdx presumably hold `one`/`other`/`out`, with the
    # enum tag at offset 0 and the f32 at offset 4 - inferred from the loads and
    # stores below; confirm against the calling convention.
    .section    .text.sum_them,"ax",@progbits
    .globl  sum_them
    .p2align    4, 0x90
    .type   sum_them,@function
sum_them:
    .cfi_startproc
    movl    (%rdi), %eax
    movss   4(%rdi), %xmm0
    movl    (%rsi), %ecx
    movss   4(%rsi), %xmm1
    # Load the table entry indexed by the word from (%rdi), add the table
    # address, and jump there.
    leaq    .LJTI5_0(%rip), %rsi
    movslq  (%rsi,%rax,4), %rax
    addq    %rsi, %rax
    jmpq    *%rax
# Table entry 0: succeed only if the word from (%rsi) is also 0.
.LBB5_1:
    xorl    %eax, %eax
    testl   %ecx, %ecx
    je  .LBB5_8
    retq
# Table entry 2: succeed only if the word from (%rsi) is also 2.
.LBB5_3:
    movl    $2, %eax
    cmpl    $2, %ecx
    je  .LBB5_8
# Shared failure exit: return with %eax = 0.
.LBB5_9:
    xorl    %eax, %eax
    retq
# Table entry 3: succeed only if the word from (%rsi) is also 3.
.LBB5_5:
    movl    $3, %eax
    cmpl    $3, %ecx
    jne .LBB5_9
# Shared success tail: add the two floats, store %eax at (%rdx) and the sum at
# 4(%rdx), then return with %al = 1.
.LBB5_8:
    addss   %xmm1, %xmm0
    movl    %eax, (%rdx)
    movss   %xmm0, 4(%rdx)
    movb    $1, %al
    retq
# Table entry 1: succeed only if the word from (%rsi) is also 1.
.LBB5_7:
    movl    $1, %eax
    cmpl    $1, %ecx
    jne .LBB5_9
    jmp .LBB5_8
.Lfunc_end5:
    .size   sum_them, .Lfunc_end5-sum_them
    .cfi_endproc
    .section    .rodata.sum_them,"a",@progbits
    .p2align    2
# Jump table: four 32-bit offsets, each relative to the table's own address.
.LJTI5_0:
    .long   .LBB5_1-.LJTI5_0
    .long   .LBB5_7-.LJTI5_0
    .long   .LBB5_3-.LJTI5_0
    .long   .LBB5_5-.LJTI5_0

after: rustup run nightly-2020-09-21 rustc 68867.rs --emit=asm -O -C opt-level=3 -o 2020-09-21.s

    # "After" output: the 32-bit words at (%rdi) and (%rsi) are compared once;
    # there is no jump table.
    .section    .text.sum_them,"ax",@progbits
    .globl  sum_them
    .p2align    4, 0x90
    .type   sum_them,@function
sum_them:
    .cfi_startproc
    movl    (%rdi), %eax
    cmpl    %eax, (%rsi)
    jne .LBB5_1
    # Words equal: add the floats at offset 4, store the word and the sum
    # through %rdx, and return with %al = 1.
    movss   4(%rdi), %xmm0
    addss   4(%rsi), %xmm0
    movl    %eax, (%rdx)
    movss   %xmm0, 4(%rdx)
    movb    $1, %al
    retq
# Words differ: return with %eax = 0.
.LBB5_1:
    xorl    %eax, %eax
    retq
.Lfunc_end5:
    .size   sum_them, .Lfunc_end5-sum_them
    .cfi_endproc

Yeah indeed, thanks for fixing this!

Was this page helpful?
0 / 5 - 0 ratings