Skip to content

Commit 9d28282

Browse files
committed
[LV] Vectorize FMax w/o fast-math flags.
Add a new recurrence kind for FMax reductions without fast-math flags and a corresponding VPlan transform tries to vectorize without fast-math flags. To do so, a new FindFirstIV reduction is added that tracks the first indices that contain the maximum values. This serves as tie breaker if the partial reduction vector contains NaNs or signed zeros. After the loop, the first index is used to retrieve the final max value after vectorization from the vector containing the partial maximum values
1 parent 4c27279 commit 9d28282

17 files changed

+569
-90
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ enum class RecurKind {
4747
FMul, ///< Product of floats.
4848
FMin, ///< FP min implemented in terms of select(cmp()).
4949
FMax, ///< FP max implemented in terms of select(cmp()).
50+
FMaxNoFMFs, ///< FP max implemented in terms of select(cmp()), but without
51+
///any fast-math flags. Users need to handle NaNs and signed zeros when generating code.
5052
FMinimum, ///< FP min with llvm.minimum semantics
5153
FMaximum, ///< FP max with llvm.maximum semantics
5254
FMinimumNum, ///< FP min with llvm.minimumnum semantics
@@ -250,8 +252,9 @@ class RecurrenceDescriptor {
250252
/// Returns true if the recurrence kind is a floating-point min/max kind.
251253
static bool isFPMinMaxRecurrenceKind(RecurKind Kind) {
252254
return Kind == RecurKind::FMin || Kind == RecurKind::FMax ||
253-
Kind == RecurKind::FMinimum || Kind == RecurKind::FMaximum ||
254-
Kind == RecurKind::FMinimumNum || Kind == RecurKind::FMaximumNum;
255+
Kind == RecurKind::FMaxNoFMFs || Kind == RecurKind::FMinimum ||
256+
Kind == RecurKind::FMaximum || Kind == RecurKind::FMinimumNum ||
257+
Kind == RecurKind::FMaximumNum;
255258
}
256259

257260
/// Returns true if the recurrence kind is any min/max kind.

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,8 @@ RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,
819819
if (match(I, m_OrdOrUnordFMin(m_Value(), m_Value())))
820820
return InstDesc(Kind == RecurKind::FMin, I);
821821
if (match(I, m_OrdOrUnordFMax(m_Value(), m_Value())))
822-
return InstDesc(Kind == RecurKind::FMax, I);
822+
return InstDesc(Kind == RecurKind::FMax || Kind == RecurKind::FMaxNoFMFs,
823+
I);
823824
if (match(I, m_FMinNum(m_Value(), m_Value())))
824825
return InstDesc(Kind == RecurKind::FMin, I);
825826
if (match(I, m_FMaxNum(m_Value(), m_Value())))
@@ -941,10 +942,15 @@ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
941942
m_Intrinsic<Intrinsic::minimumnum>(m_Value(), m_Value())) ||
942943
match(I, m_Intrinsic<Intrinsic::maximumnum>(m_Value(), m_Value()));
943944
};
944-
if (isIntMinMaxRecurrenceKind(Kind) ||
945-
(HasRequiredFMF() && isFPMinMaxRecurrenceKind(Kind)))
945+
if (isIntMinMaxRecurrenceKind(Kind))
946946
return isMinMaxPattern(I, Kind, Prev);
947-
else if (isFMulAddIntrinsic(I))
947+
if (isFPMinMaxRecurrenceKind(Kind)) {
948+
if (HasRequiredFMF())
949+
return isMinMaxPattern(I, Kind, Prev);
950+
if ((Kind == RecurKind::FMax || Kind == RecurKind::FMaxNoFMFs) &&
951+
isMinMaxPattern(I, Kind, Prev).isRecurrence())
952+
return InstDesc(I, RecurKind::FMaxNoFMFs);
953+
} else if (isFMulAddIntrinsic(I))
948954
return InstDesc(Kind == RecurKind::FMulAdd, I,
949955
I->hasAllowReassoc() ? nullptr : I);
950956
return InstDesc(false, I);
@@ -1207,6 +1213,7 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
12071213
case RecurKind::UMin:
12081214
return Instruction::ICmp;
12091215
case RecurKind::FMax:
1216+
case RecurKind::FMaxNoFMFs:
12101217
case RecurKind::FMin:
12111218
case RecurKind::FMaximum:
12121219
case RecurKind::FMinimum:

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,7 @@ constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
937937
return Intrinsic::vector_reduce_umax;
938938
case RecurKind::UMin:
939939
return Intrinsic::vector_reduce_umin;
940+
case RecurKind::FMaxNoFMFs:
940941
case RecurKind::FMax:
941942
return Intrinsic::vector_reduce_fmax;
942943
case RecurKind::FMin:
@@ -1085,6 +1086,7 @@ CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) {
10851086
case RecurKind::FMin:
10861087
return CmpInst::FCMP_OLT;
10871088
case RecurKind::FMax:
1089+
case RecurKind::FMaxNoFMFs:
10881090
return CmpInst::FCMP_OGT;
10891091
// We do not add FMinimum/FMaximum recurrence kind here since there is no
10901092
// equivalent predicate which compares signed zeroes according to the
@@ -1307,6 +1309,7 @@ Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
13071309
case RecurKind::UMax:
13081310
case RecurKind::UMin:
13091311
case RecurKind::FMax:
1312+
case RecurKind::FMaxNoFMFs:
13101313
case RecurKind::FMin:
13111314
case RecurKind::FMinimum:
13121315
case RecurKind::FMaximum:

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4346,8 +4346,13 @@ bool LoopVectorizationPlanner::isCandidateForEpilogueVectorization(
43464346
ElementCount VF) const {
43474347
// Cross iteration phis such as reductions need special handling and are
43484348
// currently unsupported.
4349-
if (any_of(OrigLoop->getHeader()->phis(),
4350-
[&](PHINode &Phi) { return Legal->isFixedOrderRecurrence(&Phi); }))
4349+
if (any_of(OrigLoop->getHeader()->phis(), [&](PHINode &Phi) {
4350+
return Legal->isFixedOrderRecurrence(&Phi) ||
4351+
(Legal->isReductionVariable(&Phi) &&
4352+
Legal->getReductionVars()
4353+
.find(&Phi)
4354+
->second.getRecurrenceKind() == RecurKind::FMaxNoFMFs);
4355+
}))
43514356
return false;
43524357

43534358
// Phis with uses outside of the loop require special handling and are
@@ -8808,6 +8813,9 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
88088813

88098814
// Adjust the recipes for any inloop reductions.
88108815
adjustRecipesForReductions(Plan, RecipeBuilder, Range.Start);
8816+
if (!VPlanTransforms::runPass(
8817+
VPlanTransforms::handleFMaxReductionsWithoutFastMath, *Plan))
8818+
return nullptr;
88118819

88128820
// Transform recipes to abstract recipes if it is legal and beneficial and
88138821
// clamp the range for better cost estimation.

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23375,6 +23375,7 @@ class HorizontalReduction {
2337523375
case RecurKind::FindFirstIVUMin:
2337623376
case RecurKind::FindLastIVSMax:
2337723377
case RecurKind::FindLastIVUMax:
23378+
case RecurKind::FMaxNoFMFs:
2337823379
case RecurKind::FMaximumNum:
2337923380
case RecurKind::FMinimumNum:
2338023381
case RecurKind::None:

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,10 @@ class VPInstruction : public VPRecipeWithIRFlags,
980980
ReductionStartVector,
981981
// Creates a step vector starting from 0 to VF with a step of 1.
982982
StepVector,
983-
983+
/// Extracts a single lane (first operand) from a set of vector operands.
984+
/// The lane specifies an index into a vector formed by combining all vector
985+
/// operands (all operands after the first one).
986+
ExtractLane,
984987
};
985988

986989
private:

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
8585
return ResTy;
8686
}
8787
case Instruction::ICmp:
88+
case Instruction::FCmp:
8889
case VPInstruction::ActiveLaneMask:
8990
assert(inferScalarType(R->getOperand(0)) ==
9091
inferScalarType(R->getOperand(1)) &&
@@ -110,6 +111,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
110111
case VPInstruction::BuildStructVector:
111112
case VPInstruction::BuildVector:
112113
return SetResultTyFromOp();
114+
case VPInstruction::ExtractLane:
115+
return inferScalarType(R->getOperand(1));
113116
case VPInstruction::FirstActiveLane:
114117
return Type::getIntNTy(Ctx, 64);
115118
case VPInstruction::ExtractLastElement:

llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#define DEBUG_TYPE "vplan"
2626

2727
using namespace llvm;
28+
using namespace VPlanPatternMatch;
2829

2930
namespace {
3031
// Class that is used to build the plain CFG for the incoming IR.
@@ -427,7 +428,6 @@ static void createLoopRegion(VPlan &Plan, VPBlockBase *HeaderVPB) {
427428
static void addCanonicalIVRecipes(VPlan &Plan, VPBasicBlock *HeaderVPBB,
428429
VPBasicBlock *LatchVPBB, Type *IdxTy,
429430
DebugLoc DL) {
430-
using namespace VPlanPatternMatch;
431431
Value *StartIdx = ConstantInt::get(IdxTy, 0);
432432
auto *StartV = Plan.getOrAddLiveIn(StartIdx);
433433

@@ -628,3 +628,114 @@ void VPlanTransforms::attachCheckBlock(VPlan &Plan, Value *Cond,
628628
Term->addMetadata(LLVMContext::MD_prof, BranchWeights);
629629
}
630630
}
631+
632+
bool VPlanTransforms::handleFMaxReductionsWithoutFastMath(VPlan &Plan) {
633+
VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion();
634+
VPReductionPHIRecipe *RedPhiR = nullptr;
635+
VPRecipeWithIRFlags *MinMaxOp = nullptr;
636+
VPWidenIntOrFpInductionRecipe *WideIV = nullptr;
637+
638+
// Check if there are any FMaxNoFMFs reductions using wide selects that we can
639+
// fix up. To do so, we also need a wide canonical IV to keep track of the
640+
// indices of the max values.
641+
for (auto &R : LoopRegion->getEntryBasicBlock()->phis()) {
642+
// We need a wide canonical IV
643+
if (auto *CurIV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&R)) {
644+
if (!CurIV->isCanonical())
645+
continue;
646+
WideIV = CurIV;
647+
continue;
648+
}
649+
650+
// And a single FMaxNoFMFs reduction phi.
651+
// TODO: Support FMin reductions as well.
652+
auto *CurRedPhiR = dyn_cast<VPReductionPHIRecipe>(&R);
653+
if (!CurRedPhiR)
654+
continue;
655+
if (RedPhiR)
656+
return false;
657+
if (CurRedPhiR->getRecurrenceKind() != RecurKind::FMaxNoFMFs ||
658+
CurRedPhiR->isInLoop() || CurRedPhiR->isOrdered())
659+
continue;
660+
RedPhiR = CurRedPhiR;
661+
662+
// MaxOp feeding the reduction phi must be a select (either wide or a
663+
// replicate recipe), where the phi is the last operand, and the compare
664+
// predicate is strict. This ensures NaNs won't get propagated unless the
665+
// initial value is NaN
666+
VPRecipeBase *Inc = RedPhiR->getBackedgeValue()->getDefiningRecipe();
667+
auto *RepR = dyn_cast<VPReplicateRecipe>(Inc);
668+
if (!isa<VPWidenSelectRecipe>(Inc) &&
669+
!(RepR && (isa<SelectInst>(RepR->getUnderlyingInstr()))))
670+
return false;
671+
672+
MinMaxOp = cast<VPRecipeWithIRFlags>(Inc);
673+
auto *Cmp = cast<VPRecipeWithIRFlags>(MinMaxOp->getOperand(0));
674+
if (MinMaxOp->getOperand(1) == RedPhiR ||
675+
!CmpInst::isStrictPredicate(Cmp->getPredicate()))
676+
return false;
677+
}
678+
679+
// Nothing to do.
680+
if (!RedPhiR)
681+
return true;
682+
683+
// A wide canonical IV is currently required.
684+
// TODO: Create an induction if no suitable existing one is available.
685+
if (!WideIV)
686+
return false;
687+
688+
// Create a reduction that tracks the first indices where the latest maximum
689+
// value has been selected. This is later used to select the max value from
690+
// the partial reductions in a way that correctly handles signed zeros and
691+
// NaNs in the input.
692+
// Note that we do not need to check if the induction may hit the sentinel
693+
// value. If the sentinel value gets hit, the final reduction value is at the
694+
// last index or the maximum was never set and all lanes contain the start
695+
// value. In either case, the correct value is selected.
696+
unsigned IVWidth =
697+
VPTypeAnalysis(Plan).inferScalarType(WideIV)->getScalarSizeInBits();
698+
LLVMContext &Ctx = Plan.getScalarHeader()->getIRBasicBlock()->getContext();
699+
VPValue *UMinSentinel =
700+
Plan.getOrAddLiveIn(ConstantInt::get(Ctx, APInt::getMaxValue(IVWidth)));
701+
auto *IdxPhi = new VPReductionPHIRecipe(nullptr, RecurKind::FindFirstIVUMin,
702+
*UMinSentinel, false, false, 1);
703+
IdxPhi->insertBefore(RedPhiR);
704+
auto *MinIdxSel = new VPInstruction(
705+
Instruction::Select, {MinMaxOp->getOperand(0), WideIV, IdxPhi});
706+
MinIdxSel->insertAfter(MinMaxOp);
707+
IdxPhi->addOperand(MinIdxSel);
708+
709+
// Find the first index of with the maximum value. This is used to extract the
710+
// lane with the final max value and is needed to handle signed zeros and NaNs
711+
// in the input.
712+
auto *MiddleVPBB = Plan.getMiddleBlock();
713+
auto *OrigRdxResult = cast<VPSingleDefRecipe>(&MiddleVPBB->front());
714+
VPBuilder Builder(OrigRdxResult->getParent(),
715+
std::next(OrigRdxResult->getIterator()));
716+
717+
// Create mask for lanes that have the max value and use it to mask out
718+
// indices that don't contain maximum values.
719+
auto *MaskFinalMaxValue = Builder.createNaryOp(
720+
Instruction::FCmp, {OrigRdxResult->getOperand(1), OrigRdxResult},
721+
VPIRFlags(CmpInst::FCMP_OEQ));
722+
auto *IndicesWithMaxValue = Builder.createNaryOp(
723+
Instruction::Select, {MaskFinalMaxValue, MinIdxSel, UMinSentinel});
724+
auto *FirstMaxIdx = Builder.createNaryOp(
725+
VPInstruction::ComputeFindIVResult,
726+
{IdxPhi, WideIV->getStartValue(), UMinSentinel, IndicesWithMaxValue});
727+
// Convert the index of the first max value to an index in the vector lanes of
728+
// the partial reduction results. This ensures we select the first max value
729+
// and acts as a tie-breaker if the partial reductions contain signed zeros.
730+
auto *FirstMaxLane =
731+
Builder.createNaryOp(Instruction::URem, {FirstMaxIdx, &Plan.getVFxUF()});
732+
733+
// Extract the final max value and update the users.
734+
auto *Res = Builder.createNaryOp(
735+
VPInstruction::ExtractLane, {FirstMaxLane, OrigRdxResult->getOperand(1)});
736+
OrigRdxResult->replaceUsesWithIf(Res,
737+
[MaskFinalMaxValue](VPUser &U, unsigned) {
738+
return &U != MaskFinalMaxValue;
739+
});
740+
return true;
741+
}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,7 @@ Value *VPInstruction::generate(VPTransformState &State) {
585585
Value *Op = State.get(getOperand(0), vputils::onlyFirstLaneUsed(this));
586586
return Builder.CreateFreeze(Op, Name);
587587
}
588+
case Instruction::FCmp:
588589
case Instruction::ICmp: {
589590
bool OnlyFirstLaneUsed = vputils::onlyFirstLaneUsed(this);
590591
Value *A = State.get(getOperand(0), OnlyFirstLaneUsed);
@@ -595,7 +596,8 @@ Value *VPInstruction::generate(VPTransformState &State) {
595596
llvm_unreachable("should be handled by VPPhi::execute");
596597
}
597598
case Instruction::Select: {
598-
bool OnlyFirstLaneUsed = vputils::onlyFirstLaneUsed(this);
599+
bool OnlyFirstLaneUsed =
600+
State.VF.isScalar() || vputils::onlyFirstLaneUsed(this);
599601
Value *Cond = State.get(getOperand(0), OnlyFirstLaneUsed);
600602
Value *Op1 = State.get(getOperand(1), OnlyFirstLaneUsed);
601603
Value *Op2 = State.get(getOperand(2), OnlyFirstLaneUsed);
@@ -858,7 +860,30 @@ Value *VPInstruction::generate(VPTransformState &State) {
858860
Value *Res = State.get(getOperand(0));
859861
for (VPValue *Op : drop_begin(operands()))
860862
Res = Builder.CreateOr(Res, State.get(Op));
861-
return Builder.CreateOrReduce(Res);
863+
return Res->getType()->isIntegerTy(1) ? Res : Builder.CreateOrReduce(Res);
864+
}
865+
case VPInstruction::ExtractLane: {
866+
Value *LaneToExtract = State.get(getOperand(0), true);
867+
Type *IdxTy = State.TypeAnalysis.inferScalarType(getOperand(0));
868+
Value *Res = nullptr;
869+
Value *RuntimeVF = getRuntimeVF(State.Builder, IdxTy, State.VF);
870+
871+
for (unsigned Idx = 1; Idx != getNumOperands(); ++Idx) {
872+
Value *VectorStart =
873+
Builder.CreateMul(RuntimeVF, ConstantInt::get(IdxTy, Idx - 1));
874+
Value *VectorIdx = Builder.CreateSub(LaneToExtract, VectorStart);
875+
Value *Ext = State.VF.isScalar()
876+
? State.get(getOperand(Idx))
877+
: Builder.CreateExtractElement(
878+
State.get(getOperand(Idx)), VectorIdx);
879+
if (Res) {
880+
Value *Cmp = Builder.CreateICmpUGE(LaneToExtract, VectorStart);
881+
Res = Builder.CreateSelect(Cmp, Ext, Res);
882+
} else {
883+
Res = Ext;
884+
}
885+
}
886+
return Res;
862887
}
863888
case VPInstruction::FirstActiveLane: {
864889
if (getNumOperands() == 1) {
@@ -984,7 +1009,8 @@ bool VPInstruction::isVectorToScalar() const {
9841009
getOpcode() == VPInstruction::ComputeAnyOfResult ||
9851010
getOpcode() == VPInstruction::ComputeFindIVResult ||
9861011
getOpcode() == VPInstruction::ComputeReductionResult ||
987-
getOpcode() == VPInstruction::AnyOf;
1012+
getOpcode() == VPInstruction::AnyOf ||
1013+
getOpcode() == VPInstruction::ExtractLane;
9881014
}
9891015

9901016
bool VPInstruction::isSingleScalar() const {
@@ -1031,6 +1057,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
10311057
switch (getOpcode()) {
10321058
case Instruction::ExtractElement:
10331059
case Instruction::Freeze:
1060+
case Instruction::FCmp:
10341061
case Instruction::ICmp:
10351062
case Instruction::Select:
10361063
case VPInstruction::AnyOf:
@@ -1066,6 +1093,7 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
10661093
return Op == getOperand(1);
10671094
case Instruction::PHI:
10681095
return true;
1096+
case Instruction::FCmp:
10691097
case Instruction::ICmp:
10701098
case Instruction::Select:
10711099
case Instruction::Or:
@@ -1098,6 +1126,7 @@ bool VPInstruction::onlyFirstPartUsed(const VPValue *Op) const {
10981126
switch (getOpcode()) {
10991127
default:
11001128
return false;
1129+
case Instruction::FCmp:
11011130
case Instruction::ICmp:
11021131
case Instruction::Select:
11031132
return vputils::onlyFirstPartUsed(this);
@@ -1782,7 +1811,7 @@ bool VPIRFlags::flagsValidForOpcode(unsigned Opcode) const {
17821811
return Opcode == Instruction::ZExt;
17831812
break;
17841813
case OperationType::Cmp:
1785-
return Opcode == Instruction::ICmp;
1814+
return Opcode == Instruction::FCmp || Opcode == Instruction::ICmp;
17861815
case OperationType::Other:
17871816
return true;
17881817
}

llvm/lib/Transforms/Vectorize/VPlanTransforms.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ struct VPlanTransforms {
196196
VPBasicBlock *LatchVPBB,
197197
VFRange &Range);
198198

199+
static bool handleFMaxReductionsWithoutFastMath(VPlan &Plan);
200+
199201
/// Replace loop regions with explicit CFG.
200202
static void dissolveLoopRegions(VPlan &Plan);
201203

0 commit comments

Comments
 (0)