Skip to content

Commit 42d1194

Browse files
committed
[LV] Vectorize maxnum/minnum w/o fast-math flags.
Update LV to vectorize maxnum/minnum reductions without fast-math flags, by adding an extra check in the loop if any inputs to maxnum/minnum are NaN. If any input is NaN, *exit the vector loop, *compute the reduction result up to the vector iteration that contained NaN inputs and * resume in the scalar loop New recurrence kinds are added for reductions using maxnum/minnum without fast-math flags. The new recurrence kinds are not supported in the code to generate IR to perform the reductions to prevent accidential mis-use. Users need to add the required checks ensuring no NaN inputs, and convert to regular FMin/FMax recurrence kinds.
1 parent 2cdcc4f commit 42d1194

16 files changed

+492
-62
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ enum class RecurKind {
4747
FMul, ///< Product of floats.
4848
FMin, ///< FP min implemented in terms of select(cmp()).
4949
FMax, ///< FP max implemented in terms of select(cmp()).
50+
FMinNumNoFMFs, ///< FP min with llvm.minnum semantics and no fast-math flags.
51+
FMaxNumNoFMFs, ///< FP max with llvm.maxnumsemantics and no fast-math flags.
5052
FMinimum, ///< FP min with llvm.minimum semantics
5153
FMaximum, ///< FP max with llvm.maximum semantics
5254
FMinimumNum, ///< FP min with llvm.minimumnum semantics

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -941,10 +941,27 @@ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
941941
m_Intrinsic<Intrinsic::minimumnum>(m_Value(), m_Value())) ||
942942
match(I, m_Intrinsic<Intrinsic::maximumnum>(m_Value(), m_Value()));
943943
};
944-
if (isIntMinMaxRecurrenceKind(Kind) ||
945-
(HasRequiredFMF() && isFPMinMaxRecurrenceKind(Kind)))
944+
if (isIntMinMaxRecurrenceKind(Kind))
946945
return isMinMaxPattern(I, Kind, Prev);
947-
else if (isFMulAddIntrinsic(I))
946+
if (isFPMinMaxRecurrenceKind(Kind)) {
947+
if (HasRequiredFMF())
948+
return isMinMaxPattern(I, Kind, Prev);
949+
// We may be able to vectorize FMax/FMin reductions using maxnum/minnum
950+
// intrinsics with extra checks ensuring the inputs are not NaN.
951+
auto *StartV = dyn_cast<ConstantFP>(
952+
OrigPhi->getIncomingValueForBlock(L->getLoopPredecessor()));
953+
if (StartV && !StartV->getValue().isNaN() &&
954+
isMinMaxPattern(I, Kind, Prev).isRecurrence()) {
955+
if (((Kind == RecurKind::FMax &&
956+
match(I, m_Intrinsic<Intrinsic::maxnum>(m_Value(), m_Value()))) ||
957+
Kind == RecurKind::FMaxNumNoFMFs))
958+
return InstDesc(I, RecurKind::FMaxNumNoFMFs);
959+
if (((Kind == RecurKind::FMin &&
960+
match(I, m_Intrinsic<Intrinsic::minnum>(m_Value(), m_Value()))) ||
961+
Kind == RecurKind::FMinNumNoFMFs))
962+
return InstDesc(I, RecurKind::FMinNumNoFMFs);
963+
}
964+
} else if (isFMulAddIntrinsic(I))
948965
return InstDesc(Kind == RecurKind::FMulAdd, I,
949966
I->hasAllowReassoc() ? nullptr : I);
950967
return InstDesc(false, I);

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -938,8 +938,10 @@ constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
938938
case RecurKind::UMin:
939939
return Intrinsic::vector_reduce_umin;
940940
case RecurKind::FMax:
941+
case RecurKind::FMaxNumNoFMFs:
941942
return Intrinsic::vector_reduce_fmax;
942943
case RecurKind::FMin:
944+
case RecurKind::FMinNumNoFMFs:
943945
return Intrinsic::vector_reduce_fmin;
944946
case RecurKind::FMaximum:
945947
return Intrinsic::vector_reduce_fmaximum;

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4347,8 +4347,15 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
43474347
ElementCount VF) const {
43484348
// Cross iteration phis such as reductions need special handling and are
43494349
// currently unsupported.
4350-
if (any_of(OrigLoop->getHeader()->phis(),
4351-
[&](PHINode &Phi) { return Legal->isFixedOrderRecurrence(&Phi); }))
4350+
if (any_of(OrigLoop->getHeader()->phis(), [&](PHINode &Phi) {
4351+
if (Legal->isReductionVariable(&Phi)) {
4352+
RecurKind RK =
4353+
Legal->getRecurrenceDescriptor(&Phi).getRecurrenceKind();
4354+
return RK == RecurKind::FMinNumNoFMFs ||
4355+
RK == RecurKind::FMaxNumNoFMFs;
4356+
}
4357+
return Legal->isFixedOrderRecurrence(&Phi);
4358+
}))
43524359
return false;
43534360

43544361
// Phis with uses outside of the loop require special handling and are
@@ -8769,6 +8776,9 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
87698776

87708777
// Adjust the recipes for any inloop reductions.
87718778
adjustRecipesForReductions(Plan, RecipeBuilder, Range.Start);
8779+
if (!VPlanTransforms::runPass(
8780+
VPlanTransforms::handleMaxMinNumReductionsWithoutFastMath, *Plan))
8781+
return nullptr;
87728782

87738783
// Transform recipes to abstract recipes if it is legal and beneficial and
87748784
// clamp the range for better cost estimation.

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23193,6 +23193,8 @@ class HorizontalReduction {
2319323193
case RecurKind::FindFirstIVUMin:
2319423194
case RecurKind::FindLastIVSMax:
2319523195
case RecurKind::FindLastIVUMax:
23196+
case RecurKind::FMaxNumNoFMFs:
23197+
case RecurKind::FMinNumNoFMFs:
2319623198
case RecurKind::FMaximumNum:
2319723199
case RecurKind::FMinimumNum:
2319823200
case RecurKind::None:
@@ -23330,6 +23332,8 @@ class HorizontalReduction {
2333023332
case RecurKind::FindFirstIVUMin:
2333123333
case RecurKind::FindLastIVSMax:
2333223334
case RecurKind::FindLastIVUMax:
23335+
case RecurKind::FMaxNumNoFMFs:
23336+
case RecurKind::FMinNumNoFMFs:
2333323337
case RecurKind::FMaximumNum:
2333423338
case RecurKind::FMinimumNum:
2333523339
case RecurKind::None:
@@ -23432,6 +23436,8 @@ class HorizontalReduction {
2343223436
case RecurKind::FindFirstIVUMin:
2343323437
case RecurKind::FindLastIVSMax:
2343423438
case RecurKind::FindLastIVUMax:
23439+
case RecurKind::FMaxNumNoFMFs:
23440+
case RecurKind::FMinNumNoFMFs:
2343523441
case RecurKind::FMaximumNum:
2343623442
case RecurKind::FMinimumNum:
2343723443
case RecurKind::None:

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,16 +1356,16 @@ class LLVM_ABI_FOR_TEST VPWidenRecipe : public VPRecipeWithIRFlags,
13561356
unsigned Opcode;
13571357

13581358
public:
1359+
VPWidenRecipe(Instruction &I, ArrayRef<VPValue *> Operands)
1360+
: VPRecipeWithIRFlags(VPDef::VPWidenSC, Operands, I), VPIRMetadata(I),
1361+
Opcode(I.getOpcode()) {}
1362+
13591363
VPWidenRecipe(unsigned Opcode, ArrayRef<VPValue *> Operands,
13601364
const VPIRFlags &Flags, const VPIRMetadata &Metadata,
13611365
DebugLoc DL)
13621366
: VPRecipeWithIRFlags(VPDef::VPWidenSC, Operands, Flags, DL),
13631367
VPIRMetadata(Metadata), Opcode(Opcode) {}
13641368

1365-
VPWidenRecipe(Instruction &I, ArrayRef<VPValue *> Operands)
1366-
: VPRecipeWithIRFlags(VPDef::VPWidenSC, Operands, I), VPIRMetadata(I),
1367-
Opcode(I.getOpcode()) {}
1368-
13691369
~VPWidenRecipe() override = default;
13701370

13711371
VPWidenRecipe *clone() override {

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
8484
return ResTy;
8585
}
8686
case Instruction::ICmp:
87+
case Instruction::FCmp:
8788
case VPInstruction::ActiveLaneMask:
8889
assert(inferScalarType(R->getOperand(0)) ==
8990
inferScalarType(R->getOperand(1)) &&

llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,3 +652,140 @@ void VPlanTransforms::attachCheckBlock(VPlan &Plan, Value *Cond,
652652
Term->addMetadata(LLVMContext::MD_prof, BranchWeights);
653653
}
654654
}
655+
656+
static VPValue *getMinMaxCompareValue(VPSingleDefRecipe *MinMaxOp,
657+
VPReductionPHIRecipe *RedPhi) {
658+
auto *RepR = dyn_cast<VPReplicateRecipe>(MinMaxOp);
659+
if (!isa<VPWidenIntrinsicRecipe>(MinMaxOp) &&
660+
!(RepR && (isa<IntrinsicInst>(RepR->getUnderlyingInstr()))))
661+
return nullptr;
662+
663+
if (MinMaxOp->getOperand(0) == RedPhi)
664+
return MinMaxOp->getOperand(1);
665+
return MinMaxOp->getOperand(0);
666+
}
667+
668+
/// Returns true if there VPlan is read-only and execution can be resumed at the
669+
/// beginning of the last vector iteration in the scalar loop
670+
static bool canResumeInScalarLoopFromVectorLoop(VPlan &Plan) {
671+
for (VPBlockBase *VPB : vp_depth_first_shallow(
672+
Plan.getVectorLoopRegion()->getEntryBasicBlock())) {
673+
auto *VPBB = dyn_cast<VPBasicBlock>(VPB);
674+
if (!VPBB)
675+
return false;
676+
for (auto &R : *VPBB) {
677+
if (match(&R, m_BranchOnCount(m_VPValue(), m_VPValue())))
678+
continue;
679+
if (R.mayWriteToMemory())
680+
return false;
681+
}
682+
}
683+
return true;
684+
}
685+
686+
bool VPlanTransforms::handleMaxMinNumReductionsWithoutFastMath(VPlan &Plan) {
687+
VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion();
688+
VPValue *AnyNaN = nullptr;
689+
VPReductionPHIRecipe *RedPhiR = nullptr;
690+
VPRecipeWithIRFlags *MinMaxOp = nullptr;
691+
bool HasUnsupportedPhi = false;
692+
for (auto &R : LoopRegion->getEntryBasicBlock()->phis()) {
693+
HasUnsupportedPhi |=
694+
!isa<VPCanonicalIVPHIRecipe, VPWidenIntOrFpInductionRecipe,
695+
VPReductionPHIRecipe>(&R);
696+
auto *Cur = dyn_cast<VPReductionPHIRecipe>(&R);
697+
if (!Cur)
698+
continue;
699+
if (RedPhiR)
700+
return false;
701+
if (Cur->getRecurrenceKind() != RecurKind::FMaxNumNoFMFs &&
702+
Cur->getRecurrenceKind() != RecurKind::FMinNumNoFMFs)
703+
continue;
704+
705+
RedPhiR = Cur;
706+
MinMaxOp = dyn_cast<VPRecipeWithIRFlags>(
707+
RedPhiR->getBackedgeValue()->getDefiningRecipe());
708+
if (!MinMaxOp)
709+
return false;
710+
VPValue *In = getMinMaxCompareValue(MinMaxOp, RedPhiR);
711+
if (!In)
712+
return false;
713+
714+
auto *IsNaN =
715+
new VPInstruction(Instruction::FCmp, {In, In}, {CmpInst::FCMP_UNO}, {});
716+
IsNaN->insertBefore(MinMaxOp);
717+
AnyNaN = new VPInstruction(VPInstruction::AnyOf, {IsNaN});
718+
AnyNaN->getDefiningRecipe()->insertAfter(IsNaN);
719+
}
720+
721+
if (!AnyNaN)
722+
return true;
723+
724+
if (HasUnsupportedPhi || !canResumeInScalarLoopFromVectorLoop(Plan))
725+
return false;
726+
727+
auto *MiddleVPBB = Plan.getMiddleBlock();
728+
auto *RdxResult = dyn_cast<VPInstruction>(&MiddleVPBB->front());
729+
if (!RdxResult ||
730+
RdxResult->getOpcode() != VPInstruction::ComputeReductionResult ||
731+
RdxResult->getOperand(0) != RedPhiR)
732+
return false;
733+
734+
auto *ScalarPH = Plan.getScalarPreheader();
735+
// Update the resume phis in the scalar preheader. They all must either resume
736+
// from the reduction result or the canonical induction. Bail out if there are
737+
// other resume phis.
738+
for (auto &R : ScalarPH->phis()) {
739+
auto *ResumeR = cast<VPPhi>(&R);
740+
VPValue *VecV = ResumeR->getOperand(0);
741+
VPValue *BypassV = ResumeR->getOperand(ResumeR->getNumOperands() - 1);
742+
if (VecV != RdxResult && VecV != &Plan.getVectorTripCount())
743+
return false;
744+
ResumeR->setOperand(
745+
1, VecV == &Plan.getVectorTripCount() ? Plan.getCanonicalIV() : VecV);
746+
ResumeR->addOperand(BypassV);
747+
}
748+
749+
// Create a new reduction phi recipe with either FMin/FMax, replacing
750+
// FMinNumNoFMFs/FMaxNumNoFMFs.
751+
RecurKind NewRK = RedPhiR->getRecurrenceKind() != RecurKind::FMinNumNoFMFs
752+
? RecurKind::FMin
753+
: RecurKind::FMax;
754+
auto *NewRedPhiR = new VPReductionPHIRecipe(
755+
cast<PHINode>(RedPhiR->getUnderlyingValue()), NewRK,
756+
*RedPhiR->getStartValue(), RedPhiR->isInLoop(), RedPhiR->isOrdered());
757+
NewRedPhiR->addOperand(RedPhiR->getOperand(1));
758+
NewRedPhiR->insertBefore(RedPhiR);
759+
RedPhiR->replaceAllUsesWith(NewRedPhiR);
760+
RedPhiR->eraseFromParent();
761+
762+
// Update the loop exit condition to exit if either any of the inputs is NaN
763+
// or the vector trip count is reached.
764+
VPBasicBlock *LatchVPBB = LoopRegion->getExitingBasicBlock();
765+
VPBuilder Builder(LatchVPBB->getTerminator());
766+
auto *LatchExitingBranch = cast<VPInstruction>(LatchVPBB->getTerminator());
767+
assert(LatchExitingBranch->getOpcode() == VPInstruction::BranchOnCount &&
768+
"Unexpected terminator");
769+
auto *IsLatchExitTaken =
770+
Builder.createICmp(CmpInst::ICMP_EQ, LatchExitingBranch->getOperand(0),
771+
LatchExitingBranch->getOperand(1));
772+
auto *AnyExitTaken =
773+
Builder.createNaryOp(Instruction::Or, {AnyNaN, IsLatchExitTaken});
774+
Builder.createNaryOp(VPInstruction::BranchOnCond, AnyExitTaken);
775+
LatchExitingBranch->eraseFromParent();
776+
777+
// Split the middle block and introduce a new block, branching to the scalar
778+
// preheader to resume iteration in the scalar loop if any NaNs have been
779+
// encountered.
780+
MiddleVPBB->splitAt(std::prev(MiddleVPBB->end()));
781+
Builder.setInsertPoint(MiddleVPBB, MiddleVPBB->begin());
782+
auto *NewSel =
783+
Builder.createSelect(AnyNaN, NewRedPhiR, RdxResult->getOperand(1));
784+
RdxResult->setOperand(1, NewSel);
785+
Builder.setInsertPoint(MiddleVPBB);
786+
Builder.createNaryOp(VPInstruction::BranchOnCond, AnyNaN);
787+
VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
788+
MiddleVPBB->swapSuccessors();
789+
std::swap(ScalarPH->getPredecessors()[1], ScalarPH->getPredecessors().back());
790+
return true;
791+
}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
587587
Value *Op = State.get(getOperand(0), vputils::onlyFirstLaneUsed(this));
588588
return Builder.CreateFreeze(Op, Name);
589589
}
590+
case Instruction::FCmp:
590591
case Instruction::ICmp: {
591592
bool OnlyFirstLaneUsed = vputils::onlyFirstLaneUsed(this);
592593
Value *A = State.get(getOperand(0), OnlyFirstLaneUsed);
@@ -860,7 +861,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
860861
Value *Res = State.get(getOperand(0));
861862
for (VPValue *Op : drop_begin(operands()))
862863
Res = Builder.CreateOr(Res, State.get(Op));
863-
return Builder.CreateOrReduce(Res);
864+
return State.VF.isScalar() ? Res : Builder.CreateOrReduce(Res);
864865
}
865866
case VPInstruction::FirstActiveLane: {
866867
if (getNumOperands() == 1) {
@@ -1033,6 +1034,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
10331034
switch (getOpcode()) {
10341035
case Instruction::ExtractElement:
10351036
case Instruction::Freeze:
1037+
case Instruction::FCmp:
10361038
case Instruction::ICmp:
10371039
case Instruction::Select:
10381040
case VPInstruction::AnyOf:
@@ -1068,6 +1070,7 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
10681070
return Op == getOperand(1);
10691071
case Instruction::PHI:
10701072
return true;
1073+
case Instruction::FCmp:
10711074
case Instruction::ICmp:
10721075
case Instruction::Select:
10731076
case Instruction::Or:
@@ -1100,6 +1103,7 @@ bool VPInstruction::onlyFirstPartUsed(const VPValue *Op) const {
11001103
switch (getOpcode()) {
11011104
default:
11021105
return false;
1106+
case Instruction::FCmp:
11031107
case Instruction::ICmp:
11041108
case Instruction::Select:
11051109
return vputils::onlyFirstPartUsed(this);
@@ -1786,7 +1790,7 @@ bool VPIRFlags::flagsValidForOpcode(unsigned Opcode) const {
17861790
return Opcode == Instruction::ZExt;
17871791
break;
17881792
case OperationType::Cmp:
1789-
return Opcode == Instruction::ICmp;
1793+
return Opcode == Instruction::FCmp || Opcode == Instruction::ICmp;
17901794
case OperationType::Other:
17911795
return true;
17921796
}

llvm/lib/Transforms/Vectorize/VPlanTransforms.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ struct VPlanTransforms {
103103
/// not valid.
104104
static bool adjustFixedOrderRecurrences(VPlan &Plan, VPBuilder &Builder);
105105

106+
/// Check if \p Plan contains any FMaxNumNoFMFs or FMinNumNoFMFs reductions.
107+
/// If they do, try to update the vector loop to exit early if any input is
108+
/// NaN and resume executing in the scalar loop to handle the NaNs there.
109+
static bool handleMaxMinNumReductionsWithoutFastMath(VPlan &Plan);
110+
106111
/// Clear NSW/NUW flags from reduction instructions if necessary.
107112
static void clearReductionWrapFlags(VPlan &Plan);
108113

0 commit comments

Comments
 (0)