From 82336b681fa6e917d4fe217a227614f481aa30ac Mon Sep 17 00:00:00 2001 From: Nimit Sachdeva Date: Mon, 28 Jul 2025 17:23:11 -0400 Subject: [PATCH 1/2] Optimize usub.sat fix for #79690 --- .../InstCombine/InstCombineSelect.cpp | 143 +++++++++++++++++- .../InstCombine/usub_sat_to_msb_mask.ll | 126 +++++++++++++++ 2 files changed, 263 insertions(+), 6 deletions(-) create mode 100644 llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index eb4332fbc0959..74544009e6872 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -42,6 +42,8 @@ #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include +#include +#include #include #define DEBUG_TYPE "instcombine" @@ -50,7 +52,6 @@ using namespace llvm; using namespace PatternMatch; - /// Replace a select operand based on an equality comparison with the identity /// constant of a binop. static Instruction *foldSelectBinOpIdentity(SelectInst &Sel, @@ -1713,7 +1714,6 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant()))) return nullptr; - Value *SelVal0, *SelVal1; // We do not care which one is from where. match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1))); // At least one of these values we are selecting between must be a constant @@ -1993,6 +1993,135 @@ Value *InstCombinerImpl::foldSelectWithConstOpToBinOp(ICmpInst *Cmp, return BinOp; } +/// Folds: +/// %a_sub = call @llvm.usub.sat(x, IntConst1) +/// %b_sub = call @llvm.usub.sat(y, IntConst2) +/// %or = or %a_sub, %b_sub +/// %cmp = icmp eq %or, 0 +/// %sel = select %cmp, 0, MostSignificantBit +/// into: +/// %a_sub' = usub.sat(x, IntConst1 - MostSignificantBit) +/// %b_sub' = usub.sat(y, IntConst2 - MostSignificantBit) +/// %or = or %a_sub', %b_sub' +/// %and = and %or, MostSignificantBit +/// If the args are vectors +/// +static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp( + SelectInst &SI, ICmpInst *ICI, InstCombiner::BuilderTy &Builder) { + auto *CI = dyn_cast(SI.getCondition()); + if (!CI) { + return nullptr; + } + + Value *CmpLHS = CI->getOperand(0); + Value *CmpRHS = CI->getOperand(1); + if (!match(CmpRHS, m_Zero())) { + return nullptr; + } + auto Pred = CI->getPredicate(); + auto *TrueVal = SI.getTrueValue(); + auto *FalseVal = SI.getFalseValue(); + + if (Pred != ICmpInst::ICMP_EQ) + return nullptr; + + // Match: icmp eq (or (usub.sat A, IntConst1), (usub.sat B, IntConst2)), 0 + Value *A, *B; + ConstantInt *IntConst1, *IntConst2, *PossibleMSBInt; + + if (match(CmpLHS, m_Or(m_Intrinsic( + m_Value(A), m_ConstantInt(IntConst1)), + m_Intrinsic( + m_Value(B), m_ConstantInt(IntConst2)))) && + match(TrueVal, m_Zero()) && + match(FalseVal, m_ConstantInt(PossibleMSBInt))) { + auto *Ty = A->getType(); + unsigned BW = Ty->getIntegerBitWidth(); + APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1); + + if (PossibleMSBInt->getValue() != MostSignificantBit) + return nullptr; + // Ensure IntConst1 and IntConst2 are >= MostSignificantBit + if (IntConst1->getValue().ult(MostSignificantBit) || + IntConst2->getValue().ult(MostSignificantBit)) + return nullptr; + + // Rewrite: + Value *NewA = Builder.CreateBinaryIntrinsic( + Intrinsic::usub_sat, A, + ConstantInt::get(Ty, IntConst1->getValue() - MostSignificantBit + 1)); + Value *NewB = Builder.CreateBinaryIntrinsic( + Intrinsic::usub_sat, B, + ConstantInt::get(Ty, IntConst2->getValue() - MostSignificantBit + 1)); + Value *Or = Builder.CreateOr(NewA, NewB); + Value *And = + Builder.CreateAnd(Or, ConstantInt::get(Ty, MostSignificantBit)); + return cast(And); + } + Constant *Const1, *Const2, *PossibleMSB; + if (match(CmpLHS, m_Or(m_Intrinsic(m_Value(A), + m_Constant(Const1)), + m_Intrinsic( + m_Value(B), m_Constant(Const2)))) && + match(TrueVal, m_Zero()) && match(FalseVal, m_Constant(PossibleMSB))) { + auto *VecTy1 = dyn_cast(Const1->getType()); + auto *VecTy2 = dyn_cast(Const2->getType()); + auto *VecTyMSB = dyn_cast(PossibleMSB->getType()); + if (!VecTy1 || !VecTy2 || !VecTyMSB) { + return nullptr; + } + + unsigned NumElements = VecTy1->getNumElements(); + + if (NumElements != VecTy2->getNumElements() || + NumElements != VecTyMSB->getNumElements() || NumElements == 0) { + return nullptr; + } + auto *SplatMSB = + dyn_cast(PossibleMSB->getAggregateElement(0u)); + unsigned BW = SplatMSB->getValue().getBitWidth(); + APInt MostSignificantBit = APInt::getOneBitSet(BW, BW - 1); + if (!SplatMSB || SplatMSB->getValue() != MostSignificantBit) { + return nullptr; + } + for (unsigned int i = 1; i < NumElements; ++i) { + auto *Element = + dyn_cast(PossibleMSB->getAggregateElement(i)); + if (!Element || Element->getValue() != SplatMSB->getValue()) { + return nullptr; + } + } + SmallVector Arg1, Arg2; + for (unsigned int i = 0; i < NumElements; ++i) { + auto *E1 = dyn_cast(Const1->getAggregateElement(i)); + auto *E2 = dyn_cast(Const2->getAggregateElement(i)); + if (!E1 || !E2) { + return nullptr; + } + if (E1->getValue().ult(SplatMSB->getValue()) || + E2->getValue().ult(SplatMSB->getValue())) { + return nullptr; + } + Arg1.emplace_back( + ConstantInt::get(A->getType()->getScalarType(), + E1->getValue() - MostSignificantBit + 1)); + Arg2.emplace_back( + ConstantInt::get(B->getType()->getScalarType(), + E2->getValue() - MostSignificantBit + 1)); + } + Constant *ConstVec1 = ConstantVector::get(Arg1); + Constant *ConstVec2 = ConstantVector::get(Arg2); + Value *NewA = + Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, ConstVec1); + Value *NewB = + Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, B, ConstVec2); + Value *Or = Builder.CreateOr(NewA, NewB); + Value *And = Builder.CreateAnd(Or, PossibleMSB); + return cast(And); + } + return nullptr; +} + /// Visit a SelectInst that has an ICmpInst as its first operand. Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI) { @@ -2009,6 +2138,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Instruction *NewSel = tryToReuseConstantFromSelectInComparison(SI, *ICI, *this)) return NewSel; + if (Instruction *Folded = + foldICmpUSubSatWithAndForMostSignificantBitCmp(SI, ICI, Builder)) + return replaceInstUsesWith(SI, Folded); // NOTE: if we wanted to, this is where to detect integer MIN/MAX bool Changed = false; @@ -4200,10 +4332,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { bool IsCastNeeded = LHS->getType() != SelType; Value *CmpLHS = cast(CondVal)->getOperand(0); Value *CmpRHS = cast(CondVal)->getOperand(1); - if (IsCastNeeded || - (LHS->getType()->isFPOrFPVectorTy() && - ((CmpLHS != LHS && CmpLHS != RHS) || - (CmpRHS != LHS && CmpRHS != RHS)))) { + if (IsCastNeeded || (LHS->getType()->isFPOrFPVectorTy() && + ((CmpLHS != LHS && CmpLHS != RHS) || + (CmpRHS != LHS && CmpRHS != RHS)))) { CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered); Value *Cmp; diff --git a/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll b/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll new file mode 100644 index 0000000000000..ffa77f4b42138 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/usub_sat_to_msb_mask.ll @@ -0,0 +1,126 @@ + +; RUN: opt -passes=instcombine -S < %s 2>&1 | FileCheck %s + +declare i8 @llvm.usub.sat.i8(i8, i8) +declare i16 @llvm.usub.sat.i16(i16, i16) +declare i32 @llvm.usub.sat.i32(i32, i32) +declare i64 @llvm.usub.sat.i64(i64, i64) + +define i8 @test_i8(i8 %a, i8 %b) { +; CHECK-LABEL: @test_i8( +; CHECK-NEXT: call i8 @llvm.usub.sat.i8(i8 %a, i8 96) +; CHECK-NEXT: call i8 @llvm.usub.sat.i8(i8 %b, i8 112) +; CHECK-NEXT: or i8 +; CHECK-NEXT: and i8 +; CHECK-NEXT: ret i8 + + %a_sub = call i8 @llvm.usub.sat.i8(i8 %a, i8 223) + %b_sub = call i8 @llvm.usub.sat.i8(i8 %b, i8 239) + %or = or i8 %a_sub, %b_sub + %cmp = icmp eq i8 %or, 0 + %res = select i1 %cmp, i8 0, i8 128 + ret i8 %res +} + +define i16 @test_i16(i16 %a, i16 %b) { +; CHECK-LABEL: @test_i16( +; CHECK-NEXT: call i16 @llvm.usub.sat.i16(i16 %a, i16 32642) +; CHECK-NEXT: call i16 @llvm.usub.sat.i16(i16 %b, i16 32656) +; CHECK-NEXT: or i16 +; CHECK-NEXT: and i16 +; CHECK-NEXT: ret i16 + + %a_sub = call i16 @llvm.usub.sat.i16(i16 %a, i16 65409) + %b_sub = call i16 @llvm.usub.sat.i16(i16 %b, i16 65423) + %or = or i16 %a_sub, %b_sub + %cmp = icmp eq i16 %or, 0 + %res = select i1 %cmp, i16 0, i16 32768 + ret i16 %res +} + +define i32 @test_i32(i32 %a, i32 %b) { +; CHECK-LABEL: @test_i32( +; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %a, i32 224) +; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %b, i32 240) +; CHECK-NEXT: or i32 +; CHECK-NEXT: and i32 +; CHECK-NEXT: ret i32 + + %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 2147483871) + %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 2147483887) + %or = or i32 %a_sub, %b_sub + %cmp = icmp eq i32 %or, 0 + %res = select i1 %cmp, i32 0, i32 2147483648 + ret i32 %res +} + +define i64 @test_i64(i64 %a, i64 %b) { +; CHECK-LABEL: @test_i64( +; CHECK-NEXT: call i64 @llvm.usub.sat.i64(i64 %a, i64 224) +; CHECK-NEXT: call i64 @llvm.usub.sat.i64(i64 %b, i64 240) +; CHECK-NEXT: or i64 +; CHECK-NEXT: and i64 +; CHECK-NEXT: ret i64 + + %a_sub = call i64 @llvm.usub.sat.i64(i64 %a, i64 9223372036854776031) + %b_sub = call i64 @llvm.usub.sat.i64(i64 %b, i64 9223372036854776047) + %or = or i64 %a_sub, %b_sub + %cmp = icmp eq i64 %or, 0 + %res = select i1 %cmp, i64 0, i64 9223372036854775808 + ret i64 %res +} + +define i32 @no_fold_due_to_small_K(i32 %a, i32 %b) { +; CHECK-LABEL: @no_fold_due_to_small_K( +; CHECK: call i32 @llvm.usub.sat.i32(i32 %a, i32 100) +; CHECK: call i32 @llvm.usub.sat.i32(i32 %b, i32 239) +; CHECK: or i32 +; CHECK: icmp eq i32 +; CHECK: select +; CHECK: ret i32 + + %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 100) + %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 239) + %or = or i32 %a_sub, %b_sub + %cmp = icmp eq i32 %or, 0 + %res = select i1 %cmp, i32 0, i32 2147483648 + ret i32 %res +} + +define i32 @commuted_test_neg(i32 %a, i32 %b) { +; CHECK-LABEL: @commuted_test_neg( +; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %b, i32 239) +; CHECK-NEXT: call i32 @llvm.usub.sat.i32(i32 %a, i32 223) +; CHECK-NEXT: or i32 +; CHECK-NEXT: icmp eq i32 +; CHECK-NEXT: select +; CHECK-NEXT: ret i32 + + %b_sub = call i32 @llvm.usub.sat.i32(i32 %b, i32 239) + %a_sub = call i32 @llvm.usub.sat.i32(i32 %a, i32 223) + %or = or i32 %b_sub, %a_sub + %cmp = icmp eq i32 %or, 0 + %res = select i1 %cmp, i32 0, i32 2147483648 + ret i32 %res +} +define <4 x i32> @vector_test(<4 x i32> %a, <4 x i32> %b) { +; CHECK-LABEL: @vector_test( +; CHECK-NEXT: call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %a, <4 x i32> splat (i32 224)) +; CHECK-NEXT: call <4 x i32> @llvm.usub.sat.v4i32(<4 x i32> %b, <4 x i32> splat (i32 240)) +; CHECK-NEXT: or <4 x i32> +; CHECK-NEXT: and <4 x i32> +; CHECK-NEXT: ret <4 x i32> + + + %a_sub = call <4 x i32> @llvm.usub.sat.v4i32( + <4 x i32> %a, + <4 x i32> ) + %b_sub = call <4 x i32> @llvm.usub.sat.v4i32( + <4 x i32> %b, + <4 x i32> ) + %or = or <4 x i32> %a_sub, %b_sub + %cmp = icmp eq <4 x i32> %or, zeroinitializer + %res = select <4 x i1> %cmp, <4 x i32> zeroinitializer, + <4 x i32> + ret <4 x i32> %res +} From dbfd5989029598038bb9b3bef4cbba31619aa764 Mon Sep 17 00:00:00 2001 From: Nimit Sachdeva Date: Mon, 28 Jul 2025 17:56:10 -0400 Subject: [PATCH 2/2] refactorization --- .../Transforms/InstCombine/InstCombineSelect.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 74544009e6872..7e4eaa9745917 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -42,8 +42,6 @@ #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include -#include -#include #include #define DEBUG_TYPE "instcombine" @@ -52,6 +50,7 @@ using namespace llvm; using namespace PatternMatch; + /// Replace a select operand based on an equality comparison with the identity /// constant of a binop. static Instruction *foldSelectBinOpIdentity(SelectInst &Sel, @@ -1714,6 +1713,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, if (Pred == CmpInst::ICMP_ULT && match(X, m_Add(m_Value(), m_Constant()))) return nullptr; + Value *SelVal0, *SelVal1; // We do not care which one is from where. match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1))); // At least one of these values we are selecting between must be a constant @@ -2004,8 +2004,7 @@ Value *InstCombinerImpl::foldSelectWithConstOpToBinOp(ICmpInst *Cmp, /// %b_sub' = usub.sat(y, IntConst2 - MostSignificantBit) /// %or = or %a_sub', %b_sub' /// %and = and %or, MostSignificantBit -/// If the args are vectors -/// +/// Likewise, for vector arguments as well. static Instruction *foldICmpUSubSatWithAndForMostSignificantBitCmp( SelectInst &SI, ICmpInst *ICI, InstCombiner::BuilderTy &Builder) { auto *CI = dyn_cast(SI.getCondition()); @@ -4332,9 +4331,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { bool IsCastNeeded = LHS->getType() != SelType; Value *CmpLHS = cast(CondVal)->getOperand(0); Value *CmpRHS = cast(CondVal)->getOperand(1); - if (IsCastNeeded || (LHS->getType()->isFPOrFPVectorTy() && - ((CmpLHS != LHS && CmpLHS != RHS) || - (CmpRHS != LHS && CmpRHS != RHS)))) { + if (IsCastNeeded || + (LHS->getType()->isFPOrFPVectorTy() && + ((CmpLHS != LHS && CmpLHS != RHS) || + (CmpRHS != LHS && CmpRHS != RHS)))) { CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered); Value *Cmp;