Skip to content

Simplify discriminant codegen for niche-encoded variants which don't wrap across an integer boundary #143784

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 19, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion compiler/rustc_abi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use std::fmt;
#[cfg(feature = "nightly")]
use std::iter::Step;
use std::num::{NonZeroUsize, ParseIntError};
use std::ops::{Add, AddAssign, Deref, Mul, RangeInclusive, Sub};
use std::ops::{Add, AddAssign, Deref, Mul, RangeFull, RangeInclusive, Sub};
use std::str::FromStr;

use bitflags::bitflags;
Expand Down Expand Up @@ -1391,12 +1391,45 @@ impl WrappingRange {
}

/// Returns `true` if `size` completely fills the range.
///
/// Note that this is *not* the same as `self == WrappingRange::full(size)`.
/// Niche calculations can produce full ranges which are not the canonical one;
/// for example `Option<NonZero<u16>>` gets `valid_range: (..=0) | (1..)`.
#[inline]
fn is_full_for(&self, size: Size) -> bool {
let max_value = size.unsigned_int_max();
debug_assert!(self.start <= max_value && self.end <= max_value);
self.start == (self.end.wrapping_add(1) & max_value)
}

/// Checks whether this range is considered non-wrapping when the values are
/// interpreted as *unsigned* numbers of width `size`.
///
/// Returns `Ok(true)` if there's no wrap-around, `Ok(false)` if there is,
/// and `Err(..)` if the range is full so it depends how you think about it.
#[inline]
pub fn no_unsigned_wraparound(&self, size: Size) -> Result<bool, RangeFull> {
if self.is_full_for(size) { Err(..) } else { Ok(self.start <= self.end) }
}

/// Checks whether this range is considered non-wrapping when the values are
/// interpreted as *signed* numbers of width `size`.
///
/// This is heavily dependent on the `size`, as `100..=200` does wrap when
/// interpreted as `i8`, but doesn't when interpreted as `i16`.
///
/// Returns `Ok(true)` if there's no wrap-around, `Ok(false)` if there is,
/// and `Err(..)` if the range is full so it depends how you think about it.
#[inline]
pub fn no_signed_wraparound(&self, size: Size) -> Result<bool, RangeFull> {
if self.is_full_for(size) {
Err(..)
} else {
let start: i128 = size.sign_extend(self.start);
let end: i128 = size.sign_extend(self.end);
Ok(start <= end)
}
}
}

impl fmt::Debug for WrappingRange {
Expand Down
77 changes: 50 additions & 27 deletions compiler/rustc_codegen_ssa/src/mir/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
// value and the variant index match, since that's all `Niche` can encode.

let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
let niche_start_const = bx.cx().const_uint_big(tag_llty, niche_start);

// We have a subrange `niche_start..=niche_end` inside `range`.
// If the value of the tag is inside this subrange, it's a
Expand All @@ -511,35 +512,44 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
// } else {
// untagged_variant
// }
let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start_const);
let tagged_discr =
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
(is_niche, tagged_discr, 0)
} else {
// The special cases don't apply, so we'll have to go with
// the general algorithm.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the comment be repositioned?

let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));

let tag_range = tag_scalar.valid_range(&dl);
let tag_size = tag_scalar.size(&dl);
let niche_end = u128::from(relative_max).wrapping_add(niche_start);
let niche_end = tag_size.truncate(niche_end);

let relative_discr = bx.sub(tag, niche_start_const);
let cast_tag = bx.intcast(relative_discr, cast_to, false);
let is_niche = bx.icmp(
IntPredicate::IntULE,
relative_discr,
bx.cx().const_uint(tag_llty, relative_max as u64),
);

// Thanks to parameter attributes and load metadata, LLVM already knows
// the general valid range of the tag. It's possible, though, for there
// to be an impossible value *in the middle*, which those ranges don't
// communicate, so it's worth an `assume` to let the optimizer know.
if niche_variants.contains(&untagged_variant)
&& bx.cx().sess().opts.optimize != OptLevel::No
{
let impossible =
u64::from(untagged_variant.as_u32() - niche_variants.start().as_u32());
let impossible = bx.cx().const_uint(tag_llty, impossible);
let ne = bx.icmp(IntPredicate::IntNE, relative_discr, impossible);
bx.assume(ne);
}
let is_niche = if tag_range.no_unsigned_wraparound(tag_size) == Ok(true) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some description of the special case here? Does a diagram like this make sense?

//         niche_start                       niche_end           
//              |                                |               
//              v                                v               
// 0u8----------+--------------------------------+----------255u8
//         ^    |            is niche            |               
//         |    +--------------------------------+               
//         |                                     |               
// tag_range.start                        tag_range.end          

if niche_start == tag_range.start {
let niche_end_const = bx.cx().const_uint_big(tag_llty, niche_end);
bx.icmp(IntPredicate::IntULE, tag, niche_end_const)
} else {
assert_eq!(niche_end, tag_range.end);
bx.icmp(IntPredicate::IntUGE, tag, niche_start_const)
}
} else if tag_range.no_signed_wraparound(tag_size) == Ok(true) {
if niche_start == tag_range.start {
let niche_end_const = bx.cx().const_uint_big(tag_llty, niche_end);
bx.icmp(IntPredicate::IntSLE, tag, niche_end_const)
} else {
assert_eq!(niche_end, tag_range.end);
bx.icmp(IntPredicate::IntSGE, tag, niche_start_const)
}
} else {
bx.icmp(
IntPredicate::IntULE,
relative_discr,
bx.cx().const_uint(tag_llty, relative_max as u64),
)
};

(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
};
Expand All @@ -550,11 +560,24 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
};

let discr = bx.select(
is_niche,
tagged_discr,
bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
);
let untagged_variant_const =
bx.cx().const_uint(cast_to, u64::from(untagged_variant.as_u32()));

// Thanks to parameter attributes and load metadata, LLVM already knows
// the general valid range of the tag. It's possible, though, for there
// to be an impossible value *in the middle*, which those ranges don't
// communicate, so it's worth an `assume` to let the optimizer know.
// Most importantly, this means when optimizing a variant test like
// `SELECT(is_niche, complex, CONST) == CONST` it's ok to simplify that
// to `!is_niche` because the `complex` part can't possibly match.
if niche_variants.contains(&untagged_variant)
&& bx.cx().sess().opts.optimize != OptLevel::No
{
let ne = bx.icmp(IntPredicate::IntNE, tagged_discr, untagged_variant_const);
bx.assume(ne);
}

let discr = bx.select(is_niche, tagged_discr, untagged_variant_const);

// In principle we could insert assumes on the possible range of `discr`, but
// currently in LLVM this isn't worth it because the original `tag` will
Expand Down
223 changes: 223 additions & 0 deletions tests/codegen/enum/enum-discriminant-eq.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
//@ compile-flags: -Copt-level=3 -Zmerge-functions=disabled
//@ min-llvm-version: 20
//@ only-64bit

// The `derive(PartialEq)` on enums with field-less variants compares discriminants,
// so make sure we emit that in some reasonable way.

#![crate_type = "lib"]
#![feature(ascii_char)]
#![feature(core_intrinsics)]
#![feature(repr128)]

use std::ascii::Char as AC;
use std::cmp::Ordering;
use std::intrinsics::discriminant_value;
use std::num::NonZero;

// A type that's bigger than `isize`, unlike the usual cases that have small tags.
#[repr(u128)]
pub enum Giant {
Two = 2,
Three = 3,
Four = 4,
}

#[unsafe(no_mangle)]
pub fn opt_bool_eq_discr(a: Option<bool>, b: Option<bool>) -> bool {
// CHECK-LABEL: @opt_bool_eq_discr(
// CHECK: %[[A:.+]] = icmp ne i8 %a, 2
// CHECK: %[[B:.+]] = icmp eq i8 %b, 2
// CHECK: %[[R:.+]] = xor i1 %[[A]], %[[B]]
// CHECK: ret i1 %[[R]]

discriminant_value(&a) == discriminant_value(&b)
}

#[unsafe(no_mangle)]
pub fn opt_ord_eq_discr(a: Option<Ordering>, b: Option<Ordering>) -> bool {
// CHECK-LABEL: @opt_ord_eq_discr(
// CHECK: %[[A:.+]] = icmp ne i8 %a, 2
// CHECK: %[[B:.+]] = icmp eq i8 %b, 2
// CHECK: %[[R:.+]] = xor i1 %[[A]], %[[B]]
// CHECK: ret i1 %[[R]]

discriminant_value(&a) == discriminant_value(&b)
}

#[unsafe(no_mangle)]
pub fn opt_nz32_eq_discr(a: Option<NonZero<u32>>, b: Option<NonZero<u32>>) -> bool {
// CHECK-LABEL: @opt_nz32_eq_discr(
// CHECK: %[[A:.+]] = icmp ne i32 %a, 0
// CHECK: %[[B:.+]] = icmp eq i32 %b, 0
// CHECK: %[[R:.+]] = xor i1 %[[A]], %[[B]]
// CHECK: ret i1 %[[R]]

discriminant_value(&a) == discriminant_value(&b)
}

#[unsafe(no_mangle)]
pub fn opt_ac_eq_discr(a: Option<AC>, b: Option<AC>) -> bool {
// CHECK-LABEL: @opt_ac_eq_discr(
// CHECK: %[[A:.+]] = icmp ne i8 %a, -128
// CHECK: %[[B:.+]] = icmp eq i8 %b, -128
// CHECK: %[[R:.+]] = xor i1 %[[A]], %[[B]]
// CHECK: ret i1 %[[R]]

discriminant_value(&a) == discriminant_value(&b)
}

#[unsafe(no_mangle)]
pub fn opt_giant_eq_discr(a: Option<Giant>, b: Option<Giant>) -> bool {
// CHECK-LABEL: @opt_giant_eq_discr(
// CHECK: %[[A:.+]] = icmp ne i128 %a, 1
// CHECK: %[[B:.+]] = icmp eq i128 %b, 1
// CHECK: %[[R:.+]] = xor i1 %[[A]], %[[B]]
// CHECK: ret i1 %[[R]]

discriminant_value(&a) == discriminant_value(&b)
}

pub enum Mid<T> {
Before,
Thing(T),
After,
}

#[unsafe(no_mangle)]
pub fn mid_bool_eq_discr(a: Mid<bool>, b: Mid<bool>) -> bool {
// CHECK-LABEL: @mid_bool_eq_discr(

// CHECK: %[[A_REL_DISCR:.+]] = add nsw i8 %a, -2
// CHECK: %[[A_IS_NICHE:.+]] = icmp samesign ugt i8 %a, 1
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %[[A_REL_DISCR]], 1
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
// CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1

// CHECK: %[[B_REL_DISCR:.+]] = add nsw i8 %b, -2
// CHECK: %[[B_IS_NICHE:.+]] = icmp samesign ugt i8 %b, 1
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %[[B_REL_DISCR]], 1
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
// CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1

// CHECK: ret i1 %[[R]]
discriminant_value(&a) == discriminant_value(&b)
}

#[unsafe(no_mangle)]
pub fn mid_ord_eq_discr(a: Mid<Ordering>, b: Mid<Ordering>) -> bool {
// CHECK-LABEL: @mid_ord_eq_discr(

// CHECK: %[[A_REL_DISCR:.+]] = add nsw i8 %a, -2
// CHECK: %[[A_IS_NICHE:.+]] = icmp sgt i8 %a, 1
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %[[A_REL_DISCR]], 1
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
// CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1

// CHECK: %[[B_REL_DISCR:.+]] = add nsw i8 %b, -2
// CHECK: %[[B_IS_NICHE:.+]] = icmp sgt i8 %b, 1
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %[[B_REL_DISCR]], 1
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
// CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1

// CHECK: %[[R:.+]] = icmp eq i8 %[[A_DISCR]], %[[B_DISCR]]
// CHECK: ret i1 %[[R]]
discriminant_value(&a) == discriminant_value(&b)
}

#[unsafe(no_mangle)]
pub fn mid_nz32_eq_discr(a: Mid<NonZero<u32>>, b: Mid<NonZero<u32>>) -> bool {
// CHECK-LABEL: @mid_nz32_eq_discr(
// CHECK: %[[R:.+]] = icmp eq i32 %a.0, %b.0
// CHECK: ret i1 %[[R]]
discriminant_value(&a) == discriminant_value(&b)
}

#[unsafe(no_mangle)]
pub fn mid_ac_eq_discr(a: Mid<AC>, b: Mid<AC>) -> bool {
// CHECK-LABEL: @mid_ac_eq_discr(

// CHECK: %[[A_REL_DISCR:.+]] = xor i8 %a, -128
// CHECK: %[[A_IS_NICHE:.+]] = icmp slt i8 %a, 0
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %a, -127
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
// CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1

// CHECK: %[[B_REL_DISCR:.+]] = xor i8 %b, -128
// CHECK: %[[B_IS_NICHE:.+]] = icmp slt i8 %b, 0
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %b, -127
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
// CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1

// CHECK: %[[R:.+]] = icmp eq i8 %[[A_DISCR]], %[[B_DISCR]]
// CHECK: ret i1 %[[R]]
discriminant_value(&a) == discriminant_value(&b)
}

// FIXME: This should be improved once our LLVM fork picks up the fix for
// <https://github.com/llvm/llvm-project/issues/134024>
#[unsafe(no_mangle)]
pub fn mid_giant_eq_discr(a: Mid<Giant>, b: Mid<Giant>) -> bool {
// CHECK-LABEL: @mid_giant_eq_discr(

// CHECK: %[[A_TRUNC:.+]] = trunc nuw nsw i128 %a to i64
// CHECK: %[[A_REL_DISCR:.+]] = add nsw i64 %[[A_TRUNC]], -5
// CHECK: %[[A_IS_NICHE:.+]] = icmp samesign ugt i128 %a, 4
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i64 %[[A_REL_DISCR]], 1
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
// CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i64 %[[A_REL_DISCR]], i64 1

// CHECK: %[[B_TRUNC:.+]] = trunc nuw nsw i128 %b to i64
// CHECK: %[[B_REL_DISCR:.+]] = add nsw i64 %[[B_TRUNC]], -5
// CHECK: %[[B_IS_NICHE:.+]] = icmp samesign ugt i128 %b, 4
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i64 %[[B_REL_DISCR]], 1
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
// CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i64 %[[B_REL_DISCR]], i64 1

// CHECK: %[[R:.+]] = icmp eq i64 %[[A_DISCR]], %[[B_DISCR]]
// CHECK: ret i1 %[[R]]
discriminant_value(&a) == discriminant_value(&b)
}

// In niche-encoded enums, testing for the untagged variant should optimize to a
// straight-forward comparison looking for the natural range of the payload value.

#[unsafe(no_mangle)]
pub fn mid_bool_is_thing(a: Mid<bool>) -> bool {
// CHECK-LABEL: @mid_bool_is_thing(
// CHECK: %[[R:.+]] = icmp samesign ult i8 %a, 2
// CHECK: ret i1 %[[R]]
discriminant_value(&a) == 1
}

#[unsafe(no_mangle)]
pub fn mid_ord_is_thing(a: Mid<Ordering>) -> bool {
// CHECK-LABEL: @mid_ord_is_thing(
// CHECK: %[[R:.+]] = icmp slt i8 %a, 2
// CHECK: ret i1 %[[R]]
discriminant_value(&a) == 1
}

#[unsafe(no_mangle)]
pub fn mid_nz32_is_thing(a: Mid<NonZero<u32>>) -> bool {
// CHECK-LABEL: @mid_nz32_is_thing(
// CHECK: %[[R:.+]] = icmp eq i32 %a.0, 1
// CHECK: ret i1 %[[R]]
discriminant_value(&a) == 1
}

#[unsafe(no_mangle)]
pub fn mid_ac_is_thing(a: Mid<AC>) -> bool {
// CHECK-LABEL: @mid_ac_is_thing(
// CHECK: %[[R:.+]] = icmp sgt i8 %a, -1
// CHECK: ret i1 %[[R]]
discriminant_value(&a) == 1
}

#[unsafe(no_mangle)]
pub fn mid_giant_is_thing(a: Mid<Giant>) -> bool {
// CHECK-LABEL: @mid_giant_is_thing(
// CHECK: %[[R:.+]] = icmp samesign ult i128 %a, 5
// CHECK: ret i1 %[[R]]
discriminant_value(&a) == 1
}
Loading
Loading