Skip to content

Commit 4b46fe6

Browse files
committed
[VPlan] Add start VPV to compute-reduction-result.
1 parent d71d1d1 commit 4b46fe6

File tree

6 files changed

+76
-24
lines changed

6 files changed

+76
-24
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7485,6 +7485,13 @@ static void addRuntimeUnrollDisableMetaData(Loop *L) {
74857485
}
74867486
}
74877487

7488+
/// Returns the IR value backing the start value recorded on the reduction
/// result \p RdxResult.
///
/// The start value is read from operand 1 of \p RdxResult; in this file the
/// helper is only called for FindLastIV reductions, whose result is built
/// with operands {PhiR, Start, NewExitingVPV}. If that operand is a Freeze
/// VPInstruction, the value being frozen is used instead.
///
/// NOTE(review): getLiveInIRValue() presumably requires the resulting VPValue
/// to be a live-in; there is no opcode or live-in check here -- confirm
/// callers guarantee this.
static Value *getStartValueFromReductionResult(VPInstruction *RdxResult) {
7489+
using namespace VPlanPatternMatch;
7490+
// Operand 1 carries the start value (possibly wrapped in a freeze).
VPValue *StartVPV = RdxResult->getOperand(1);
7491+
// Look through an optional freeze: the match result is deliberately ignored.
// On a match StartVPV is rebound to the frozen operand; otherwise it is
// expected to stay unchanged (assumes the matcher only binds on success).
match(StartVPV, m_Freeze(m_VPValue(StartVPV)));
7492+
return StartVPV->getLiveInIRValue();
7493+
}
7494+
74887495
// If \p R is a ComputeAnyOfResult, ComputeReductionResult or ComputeFindLastIVResult when vectorizing the epilog loop,
74897496
// fix the reduction's scalar PHI node by adding the incoming value from the
74907497
// main vector loop.
@@ -7493,7 +7500,8 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
74937500
BasicBlock *BypassBlock) {
74947501
auto *EpiRedResult = dyn_cast<VPInstruction>(R);
74957502
if (!EpiRedResult ||
7496-
(EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult &&
7503+
(EpiRedResult->getOpcode() != VPInstruction::ComputeAnyOfResult &&
7504+
EpiRedResult->getOpcode() != VPInstruction::ComputeReductionResult &&
74977505
EpiRedResult->getOpcode() != VPInstruction::ComputeFindLastIVResult))
74987506
return;
74997507

@@ -7505,15 +7513,19 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
75057513
EpiRedHeaderPhi->getStartValue()->getUnderlyingValue();
75067514
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(
75077515
RdxDesc.getRecurrenceKind())) {
7516+
Value *StartV = EpiRedResult->getOperand(1)->getLiveInIRValue();
7517+
(void)StartV;
75087518
auto *Cmp = cast<ICmpInst>(MainResumeValue);
75097519
assert(Cmp->getPredicate() == CmpInst::ICMP_NE &&
75107520
"AnyOf expected to start with ICMP_NE");
7511-
assert(Cmp->getOperand(1) == RdxDesc.getRecurrenceStartValue() &&
7521+
assert(Cmp->getOperand(1) == StartV &&
75127522
"AnyOf expected to start by comparing main resume value to original "
75137523
"start value");
75147524
MainResumeValue = Cmp->getOperand(0);
75157525
} else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(
75167526
RdxDesc.getRecurrenceKind())) {
7527+
Value *StartV = getStartValueFromReductionResult(EpiRedResult);
7528+
(void)StartV;
75177529
using namespace llvm::PatternMatch;
75187530
Value *Cmp, *OrigResumeV, *CmpOp;
75197531
bool IsExpectedPattern =
@@ -7522,10 +7534,7 @@ static void fixReductionScalarResumeWhenVectorizingEpilog(
75227534
m_Value(OrigResumeV))) &&
75237535
(match(Cmp, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(OrigResumeV),
75247536
m_Value(CmpOp))) &&
7525-
(match(CmpOp,
7526-
m_Freeze(m_Specific(RdxDesc.getRecurrenceStartValue()))) ||
7527-
(CmpOp == RdxDesc.getRecurrenceStartValue() &&
7528-
isGuaranteedNotToBeUndefOrPoison(CmpOp))));
7537+
((CmpOp == StartV && isGuaranteedNotToBeUndefOrPoison(CmpOp))));
75297538
assert(IsExpectedPattern && "Unexpected reduction resume pattern");
75307539
(void)IsExpectedPattern;
75317540
MainResumeValue = OrigResumeV;
@@ -9467,7 +9476,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
94679476
OrigExitingVPV->replaceUsesWithIf(NewExitingVPV, [](VPUser &U, unsigned) {
94689477
return isa<VPInstruction>(&U) &&
94699478
(cast<VPInstruction>(&U)->getOpcode() ==
9479+
VPInstruction::ComputeAnyOfResult ||
9480+
cast<VPInstruction>(&U)->getOpcode() ==
94709481
VPInstruction::ComputeReductionResult ||
9482+
94719483
cast<VPInstruction>(&U)->getOpcode() ==
94729484
VPInstruction::ComputeFindLastIVResult);
94739485
});
@@ -9497,6 +9509,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
94979509
FinalReductionResult =
94989510
Builder.createNaryOp(VPInstruction::ComputeFindLastIVResult,
94999511
{PhiR, Start, NewExitingVPV}, ExitDL);
9512+
} else if (RecurrenceDescriptor::isAnyOfRecurrenceKind(
9513+
RdxDesc.getRecurrenceKind())) {
9514+
VPValue *Start = PhiR->getStartValue();
9515+
FinalReductionResult =
9516+
Builder.createNaryOp(VPInstruction::ComputeAnyOfResult,
9517+
{PhiR, Start, NewExitingVPV}, ExitDL);
95009518
} else {
95019519
VPIRFlags Flags = RecurrenceDescriptor::isFloatingPointRecurrenceKind(
95029520
RdxDesc.getRecurrenceKind())
@@ -10050,23 +10068,36 @@ preparePlanForEpilogueVectorLoop(VPlan &Plan, Loop *L,
1005010068
Value *ResumeV = nullptr;
1005110069
// TODO: Move setting of resume values to prepareToExecute.
1005210070
if (auto *ReductionPhi = dyn_cast<VPReductionPHIRecipe>(&R)) {
10071+
auto *RdxResult =
10072+
cast<VPInstruction>(*find_if(ReductionPhi->users(), [](VPUser *U) {
10073+
auto *VPI = dyn_cast<VPInstruction>(U);
10074+
return VPI &&
10075+
(VPI->getOpcode() == VPInstruction::ComputeReductionResult ||
10076+
VPI->getOpcode() == VPInstruction::ComputeFindLastIVResult);
10077+
}));
1005310078
ResumeV = cast<PHINode>(ReductionPhi->getUnderlyingInstr())
1005410079
->getIncomingValueForBlock(L->getLoopPreheader());
1005510080
const RecurrenceDescriptor &RdxDesc =
1005610081
ReductionPhi->getRecurrenceDescriptor();
1005710082
RecurKind RK = RdxDesc.getRecurrenceKind();
1005810083
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) {
10084+
Value *StartV = RdxResult->getOperand(1)->getLiveInIRValue();
10085+
assert(RdxDesc.getRecurrenceStartValue() == StartV &&
10086+
"start value from ComputeAnyOfResult must match");
10087+
1005910088
// VPReductionPHIRecipes for AnyOf reductions expect a boolean as
1006010089
// start value; compare the final value from the main vector loop
1006110090
// to the start value.
1006210091
BasicBlock *PBB = cast<Instruction>(ResumeV)->getParent();
1006310092
IRBuilder<> Builder(PBB, PBB->getFirstNonPHIIt());
10064-
ResumeV =
10065-
Builder.CreateICmpNE(ResumeV, RdxDesc.getRecurrenceStartValue());
10093+
ResumeV = Builder.CreateICmpNE(ResumeV, StartV);
1006610094
} else if (RecurrenceDescriptor::isFindLastIVRecurrenceKind(RK)) {
10067-
ToFrozen[RdxDesc.getRecurrenceStartValue()] =
10068-
cast<PHINode>(ResumeV)->getIncomingValueForBlock(
10069-
EPI.MainLoopIterationCountCheck);
10095+
Value *StartV = getStartValueFromReductionResult(RdxResult);
10096+
assert(RdxDesc.getRecurrenceStartValue() == StartV &&
10097+
"start value from ComputeFindLastIVResult must match");
10098+
10099+
ToFrozen[StartV] = cast<PHINode>(ResumeV)->getIncomingValueForBlock(
10100+
EPI.MainLoopIterationCountCheck);
1007010101

1007110102
// VPReductionPHIRecipe for FindLastIV reductions requires an adjustment
1007210103
// to the resume value. The resume value is adjusted to the sentinel
@@ -10076,8 +10107,7 @@ preparePlanForEpilogueVectorLoop(VPlan &Plan, Loop *L,
1007610107
// variable.
1007710108
BasicBlock *ResumeBB = cast<Instruction>(ResumeV)->getParent();
1007810109
IRBuilder<> Builder(ResumeBB, ResumeBB->getFirstNonPHIIt());
10079-
Value *Cmp = Builder.CreateICmpEQ(
10080-
ResumeV, ToFrozen[RdxDesc.getRecurrenceStartValue()]);
10110+
Value *Cmp = Builder.CreateICmpEQ(ResumeV, ToFrozen[StartV]);
1008110111
ResumeV =
1008210112
Builder.CreateSelect(Cmp, RdxDesc.getSentinelValue(), ResumeV);
1008310113
}

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
907907
BranchOnCount,
908908
BranchOnCond,
909909
Broadcast,
910+
ComputeAnyOfResult,
910911
ComputeFindLastIVResult,
911912
ComputeReductionResult,
912913
// Extracts the last lane from its operand if it is a vector, or the last

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
8787
inferScalarType(R->getOperand(1)) &&
8888
"different types inferred for different operands");
8989
return IntegerType::get(Ctx, 1);
90+
case VPInstruction::ComputeAnyOfResult:
91+
return inferScalarType(R->getOperand(1));
9092
case VPInstruction::ComputeFindLastIVResult:
9193
case VPInstruction::ComputeReductionResult: {
9294
auto *PhiR = cast<VPReductionPHIRecipe>(R->getOperand(0));

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,12 @@ m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
318318
{Op0, Op1, Op2});
319319
}
320320

321+
/// Matcher for a unary VPInstruction with opcode Instruction::Freeze whose
/// single operand is matched by \p Op0. Thin wrapper that forwards to
/// m_VPInstruction<Instruction::Freeze>.
template <typename Op0_t>
322+
inline UnaryVPInstruction_match<Op0_t, Instruction::Freeze>
323+
m_Freeze(const Op0_t &Op0) {
324+
return m_VPInstruction<Instruction::Freeze>(Op0);
325+
}
326+
321327
template <typename Op0_t>
322328
inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not>
323329
m_Not(const Op0_t &Op0) {

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,20 @@ Value *VPInstruction::generate(VPTransformState &State) {
604604
return Builder.CreateVectorSplat(
605605
State.VF, State.get(getOperand(0), /*IsScalar*/ true), "broadcast");
606606
}
607+
case VPInstruction::ComputeAnyOfResult: {
608+
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
609+
// and will be removed by breaking up the recipe further.
610+
auto *PhiR = cast<VPReductionPHIRecipe>(getOperand(0));
611+
auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
612+
Value *ReducedPartRdx = State.get(getOperand(2));
613+
for (unsigned Idx = 3; Idx < getNumOperands(); ++Idx)
614+
ReducedPartRdx = Builder.CreateBinOp(
615+
(Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(
616+
RecurKind::AnyOf),
617+
State.get(getOperand(Idx)), ReducedPartRdx, "bin.rdx");
618+
return createAnyOfReduction(Builder, ReducedPartRdx,
619+
State.get(getOperand(1), VPLane(0)), OrigPhi);
620+
}
607621
case VPInstruction::ComputeFindLastIVResult: {
608622
// FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary
609623
// and will be removed by breaking up the recipe further.
@@ -670,19 +684,11 @@ Value *VPInstruction::generate(VPTransformState &State) {
670684

671685
// Create the reduction after the loop. Note that inloop reductions create
672686
// the target reduction in the loop using a Reduction recipe.
673-
if ((State.VF.isVector() ||
674-
RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) &&
675-
!PhiR->isInLoop()) {
687+
if (State.VF.isVector() && !PhiR->isInLoop()) {
676688
// TODO: Support in-order reductions based on the recurrence descriptor.
677689
// All ops in the reduction inherit fast-math-flags from the recurrence
678690
// descriptor.
679-
if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) {
680-
auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue());
681-
ReducedPartRdx =
682-
createAnyOfReduction(Builder, ReducedPartRdx,
683-
RdxDesc.getRecurrenceStartValue(), OrigPhi);
684-
} else
685-
ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
691+
ReducedPartRdx = createSimpleReduction(Builder, ReducedPartRdx, RK);
686692
}
687693

688694
return ReducedPartRdx;
@@ -813,6 +819,7 @@ bool VPInstruction::isVectorToScalar() const {
813819
getOpcode() == VPInstruction::ExtractPenultimateElement ||
814820
getOpcode() == Instruction::ExtractElement ||
815821
getOpcode() == VPInstruction::FirstActiveLane ||
822+
getOpcode() == VPInstruction::ComputeAnyOfResult ||
816823
getOpcode() == VPInstruction::ComputeFindLastIVResult ||
817824
getOpcode() == VPInstruction::ComputeReductionResult ||
818825
getOpcode() == VPInstruction::AnyOf;
@@ -908,6 +915,7 @@ bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const {
908915
return true;
909916
case VPInstruction::PtrAdd:
910917
return Op == getOperand(0) || vputils::onlyFirstLaneUsed(this);
918+
case VPInstruction::ComputeAnyOfResult:
911919
case VPInstruction::ComputeFindLastIVResult:
912920
return Op == getOperand(1);
913921
};
@@ -988,6 +996,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
988996
case VPInstruction::ExtractPenultimateElement:
989997
O << "extract-penultimate-element";
990998
break;
999+
case VPInstruction::ComputeAnyOfResult:
1000+
O << "compute-anyof-result";
1001+
break;
9911002
case VPInstruction::ComputeFindLastIVResult:
9921003
O << "compute-find-last-iv-result";
9931004
break;

llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,9 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
327327
// Add all VPValues for all parts to ComputeReductionResult which combines
328328
// the parts to compute the final reduction value.
329329
VPValue *Op1;
330-
if (match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
330+
if (match(&R, m_VPInstruction<VPInstruction::ComputeAnyOfResult>(
331+
m_VPValue(), m_VPValue(), m_VPValue(Op1))) ||
332+
match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
331333
m_VPValue(), m_VPValue(Op1))) ||
332334
match(&R, m_VPInstruction<VPInstruction::ComputeFindLastIVResult>(
333335
m_VPValue(), m_VPValue(), m_VPValue(Op1)))) {

0 commit comments

Comments
 (0)