-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[VPlan] Support multiple F(Max|Min)Num reductions. #161735
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
27e362a
c0c5161
e96fe0a
766d150
b65d362
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -826,7 +826,7 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) { | |
}; | ||
|
||
VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion(); | ||
VPReductionPHIRecipe *RedPhiR = nullptr; | ||
SmallVector<VPReductionPHIRecipe *> ReductionsToConvert; | ||
bool HasUnsupportedPhi = false; | ||
for (auto &R : LoopRegion->getEntryBasicBlock()->phis()) { | ||
if (isa<VPCanonicalIVPHIRecipe, VPWidenIntOrFpInductionRecipe>(&R)) | ||
|
@@ -837,19 +837,15 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) { | |
HasUnsupportedPhi = true; | ||
continue; | ||
} | ||
// For now, only a single reduction is supported. | ||
// TODO: Support multiple MaxNum/MinNum reductions and other reductions. | ||
if (RedPhiR) | ||
return false; | ||
if (Cur->getRecurrenceKind() != RecurKind::FMaxNum && | ||
Cur->getRecurrenceKind() != RecurKind::FMinNum) { | ||
HasUnsupportedPhi = true; | ||
continue; | ||
} | ||
RedPhiR = Cur; | ||
ReductionsToConvert.push_back(Cur); | ||
} | ||
|
||
if (!RedPhiR) | ||
if (ReductionsToConvert.empty()) | ||
return true; | ||
|
||
// We won't be able to resume execution in the scalar tail, if there are | ||
|
@@ -858,15 +854,6 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) { | |
if (HasUnsupportedPhi || !Plan.hasScalarTail()) | ||
return false; | ||
|
||
VPValue *MinMaxOp = GetMinMaxCompareValue(RedPhiR); | ||
if (!MinMaxOp) | ||
return false; | ||
|
||
RecurKind RedPhiRK = RedPhiR->getRecurrenceKind(); | ||
assert((RedPhiRK == RecurKind::FMaxNum || RedPhiRK == RecurKind::FMinNum) && | ||
"unsupported reduction"); | ||
(void)RedPhiRK; | ||
|
||
/// Check if the vector loop of \p Plan can early exit and restart | ||
/// execution of last vector iteration in the scalar loop. This requires all | ||
/// recipes up to early exit point be side-effect free as they are | ||
|
@@ -884,52 +871,69 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) { | |
} | ||
|
||
VPBasicBlock *LatchVPBB = LoopRegion->getExitingBasicBlock(); | ||
VPBasicBlock *MiddleVPBB = Plan.getMiddleBlock(); | ||
VPBuilder MiddleBuilder(MiddleVPBB, MiddleVPBB->begin()); | ||
VPBuilder Builder(LatchVPBB->getTerminator()); | ||
auto *LatchExitingBranch = cast<VPInstruction>(LatchVPBB->getTerminator()); | ||
assert(LatchExitingBranch->getOpcode() == VPInstruction::BranchOnCount && | ||
VPValue *AnyNaN = nullptr; | ||
SmallPtrSet<VPValue *, 2> RdxResults; | ||
for (VPReductionPHIRecipe *RedPhiR : ReductionsToConvert) { | ||
VPValue *MinMaxOp = GetMinMaxCompareValue(RedPhiR); | ||
if (!MinMaxOp) | ||
return false; | ||
|
||
RecurKind RedPhiRK = RedPhiR->getRecurrenceKind(); | ||
assert((RedPhiRK == RecurKind::FMaxNum || RedPhiRK == RecurKind::FMinNum) && | ||
"unsupported reduction"); | ||
(void)RedPhiRK; | ||
|
||
|
||
VPValue *IsNaN = Builder.createFCmp(CmpInst::FCMP_UNO, MinMaxOp, MinMaxOp); | ||
VPValue *HasNaN = Builder.createNaryOp(VPInstruction::AnyOf, {IsNaN}); | ||
if (AnyNaN) | ||
AnyNaN = Builder.createOr(AnyNaN, HasNaN); | ||
else | ||
AnyNaN = HasNaN; | ||
|
||
// If we exit early due to NaNs, compute the final reduction result based | ||
// on the reduction phi at the beginning of the last vector iteration. | ||
auto *RdxResult = find_singleton<VPSingleDefRecipe>( | ||
RedPhiR->users(), [](VPUser *U, bool) -> VPSingleDefRecipe * { | ||
|
||
auto *VPI = dyn_cast<VPInstruction>(U); | ||
if (VPI && VPI->getOpcode() == VPInstruction::ComputeReductionResult) | ||
return VPI; | ||
return nullptr; | ||
}); | ||
|
||
auto *NewSel = | ||
MiddleBuilder.createSelect(HasNaN, RedPhiR, RdxResult->getOperand(1)); | ||
RdxResult->setOperand(1, NewSel); | ||
RdxResults.insert(RdxResult); | ||
} | ||
|
||
auto *LatchExitingBranch = LatchVPBB->getTerminator(); | ||
assert(match(LatchExitingBranch, m_BranchOnCount(m_VPValue(), m_VPValue())) && | ||
"Unexpected terminator"); | ||
auto *IsLatchExitTaken = | ||
Builder.createICmp(CmpInst::ICMP_EQ, LatchExitingBranch->getOperand(0), | ||
LatchExitingBranch->getOperand(1)); | ||
|
||
VPValue *IsNaN = Builder.createFCmp(CmpInst::FCMP_UNO, MinMaxOp, MinMaxOp); | ||
VPValue *AnyNaN = Builder.createNaryOp(VPInstruction::AnyOf, {IsNaN}); | ||
auto *AnyExitTaken = | ||
Builder.createNaryOp(Instruction::Or, {AnyNaN, IsLatchExitTaken}); | ||
Builder.createNaryOp(VPInstruction::BranchOnCond, AnyExitTaken); | ||
LatchExitingBranch->eraseFromParent(); | ||
|
||
// If we exit early due to NaNs, compute the final reduction result based on | ||
// the reduction phi at the beginning of the last vector iteration. | ||
auto *RdxResult = find_singleton<VPSingleDefRecipe>( | ||
RedPhiR->users(), [](VPUser *U, bool) -> VPSingleDefRecipe * { | ||
auto *VPI = dyn_cast<VPInstruction>(U); | ||
if (VPI && VPI->getOpcode() == VPInstruction::ComputeReductionResult) | ||
return VPI; | ||
return nullptr; | ||
}); | ||
|
||
auto *MiddleVPBB = Plan.getMiddleBlock(); | ||
Builder.setInsertPoint(MiddleVPBB, MiddleVPBB->begin()); | ||
auto *NewSel = | ||
Builder.createSelect(AnyNaN, RedPhiR, RdxResult->getOperand(1)); | ||
RdxResult->setOperand(1, NewSel); | ||
|
||
auto *ScalarPH = Plan.getScalarPreheader(); | ||
// Update resume phis for inductions in the scalar preheader. If AnyNaN is | ||
// true, the resume from the start of the last vector iteration via the | ||
// canonical IV, otherwise from the original value. | ||
for (auto &R : ScalarPH->phis()) { | ||
for (auto &R : Plan.getScalarPreheader()->phis()) { | ||
auto *ResumeR = cast<VPPhi>(&R); | ||
VPValue *VecV = ResumeR->getOperand(0); | ||
if (VecV == RdxResult) | ||
if (RdxResults.contains(VecV)) | ||
continue; | ||
if (auto *DerivedIV = dyn_cast<VPDerivedIVRecipe>(VecV)) { | ||
if (DerivedIV->getNumUsers() == 1 && | ||
DerivedIV->getOperand(1) == &Plan.getVectorTripCount()) { | ||
auto *NewSel = Builder.createSelect(AnyNaN, Plan.getCanonicalIV(), | ||
&Plan.getVectorTripCount()); | ||
DerivedIV->moveAfter(&*Builder.getInsertPoint()); | ||
auto *NewSel = MiddleBuilder.createSelect(AnyNaN, Plan.getCanonicalIV(), | ||
&Plan.getVectorTripCount()); | ||
DerivedIV->moveAfter(&*MiddleBuilder.getInsertPoint()); | ||
DerivedIV->setOperand(1, NewSel); | ||
continue; | ||
} | ||
|
@@ -941,7 +945,8 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) { | |
"FMaxNum/FMinNum reduction.\n"); | ||
return false; | ||
} | ||
auto *NewSel = Builder.createSelect(AnyNaN, Plan.getCanonicalIV(), VecV); | ||
auto *NewSel = | ||
MiddleBuilder.createSelect(AnyNaN, Plan.getCanonicalIV(), VecV); | ||
ResumeR->setOperand(0, NewSel); | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the recurrence chain for RecurKind::FMinNum and RecurKind::FMaxNum limited to a single min/max operation (similar to FindLast, which has only one select)?
Seems like GetMinMaxCompareValue only returns the last min/max operation.
If multiple min/max operations can appear in the recurrence chain, should all of them be handled? Or only handle the last one is correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll check separately.