@@ -653,102 +653,85 @@ void VPlanTransforms::attachCheckBlock(VPlan &Plan, Value *Cond,
653
653
}
654
654
}
655
655
656
- static VPValue *getMinMaxCompareValue (VPSingleDefRecipe *MinMaxOp,
657
- VPReductionPHIRecipe *RedPhi) {
658
- auto *RepR = dyn_cast<VPReplicateRecipe>(MinMaxOp);
659
- if (!isa<VPWidenIntrinsicRecipe>(MinMaxOp) &&
660
- !(RepR && (isa<IntrinsicInst>(RepR->getUnderlyingInstr ()))))
661
- return nullptr ;
662
-
663
- if (MinMaxOp->getOperand (0 ) == RedPhi)
664
- return MinMaxOp->getOperand (1 );
665
- return MinMaxOp->getOperand (0 );
666
- }
667
-
668
- // / Returns true if the VPlan is read-only and execution can be resumed at the
669
- // / beginning of the last vector iteration in the scalar loop
670
- static bool canResumeInScalarLoopFromVectorLoop (VPlan &Plan) {
671
- for (VPBlockBase *VPB : vp_depth_first_shallow (
672
- Plan.getVectorLoopRegion ()->getEntryBasicBlock ())) {
673
- auto *VPBB = dyn_cast<VPBasicBlock>(VPB);
674
- if (!VPBB)
675
- return false ;
676
- for (auto &R : *VPBB) {
677
- if (match (&R, m_BranchOnCount (m_VPValue (), m_VPValue ())))
678
- continue ;
679
- if (R.mayWriteToMemory ())
680
- return false ;
681
- }
682
- }
683
- return true ;
684
- }
685
-
686
656
bool VPlanTransforms::handleMaxMinNumReductionsWithoutFastMath (VPlan &Plan) {
687
657
VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion ();
688
- VPValue *AnyNaN = nullptr ;
689
658
VPReductionPHIRecipe *RedPhiR = nullptr ;
690
- VPRecipeWithIRFlags *MinMaxOp = nullptr ;
659
+ VPValue *MinMaxOp = nullptr ;
691
660
bool HasUnsupportedPhi = false ;
661
+
662
+ auto GetMinMaxCompareValue = [](VPSingleDefRecipe *MinMaxOp,
663
+ VPReductionPHIRecipe *RedPhi) -> VPValue * {
664
+ auto *RepR = dyn_cast<VPReplicateRecipe>(MinMaxOp);
665
+ if (!isa<VPWidenIntrinsicRecipe>(MinMaxOp) &&
666
+ !(RepR && (isa<IntrinsicInst>(RepR->getUnderlyingInstr ()))))
667
+ return nullptr ;
668
+
669
+ if (MinMaxOp->getOperand (0 ) == RedPhi)
670
+ return MinMaxOp->getOperand (1 );
671
+ assert (MinMaxOp->getOperand (1 ) == RedPhi &&
672
+ " Reduction phi operand expected" );
673
+ return MinMaxOp->getOperand (0 );
674
+ };
675
+
692
676
for (auto &R : LoopRegion->getEntryBasicBlock ()->phis ()) {
677
+ // TODO: Also support first-order recurrence phis.
693
678
HasUnsupportedPhi |=
694
679
!isa<VPCanonicalIVPHIRecipe, VPWidenIntOrFpInductionRecipe,
695
680
VPReductionPHIRecipe>(&R);
696
681
auto *Cur = dyn_cast<VPReductionPHIRecipe>(&R);
697
682
if (!Cur)
698
683
continue ;
684
+ // For now, only a single reduction is supported.
685
+ // TODO: Support multiple MaxNum/MinNum reductions and other reductions.
699
686
if (RedPhiR)
700
687
return false ;
701
- if (Cur->getRecurrenceKind () != RecurKind::FMaxNumNoFMFs &&
702
- Cur->getRecurrenceKind () != RecurKind::FMinNumNoFMFs )
688
+ if (Cur->getRecurrenceKind () != RecurKind::FMaxNum &&
689
+ Cur->getRecurrenceKind () != RecurKind::FMinNum )
703
690
continue ;
704
691
705
692
RedPhiR = Cur;
706
- MinMaxOp = dyn_cast<VPRecipeWithIRFlags>(
693
+ auto *MinMaxR = dyn_cast<VPRecipeWithIRFlags>(
707
694
RedPhiR->getBackedgeValue ()->getDefiningRecipe ());
708
- if (!MinMaxOp )
695
+ if (!MinMaxR )
709
696
return false ;
710
- VPValue *In = getMinMaxCompareValue (MinMaxOp , RedPhiR);
711
- if (!In )
697
+ MinMaxOp = GetMinMaxCompareValue (MinMaxR , RedPhiR);
698
+ if (!MinMaxOp )
712
699
return false ;
713
-
714
- auto *IsNaN =
715
- new VPInstruction (Instruction::FCmp, {In, In}, {CmpInst::FCMP_UNO}, {});
716
- IsNaN->insertBefore (MinMaxOp);
717
- AnyNaN = new VPInstruction (VPInstruction::AnyOf, {IsNaN});
718
- AnyNaN->getDefiningRecipe ()->insertAfter (IsNaN);
719
700
}
720
701
721
- if (!AnyNaN )
702
+ if (!RedPhiR )
722
703
return true ;
723
704
724
- if (HasUnsupportedPhi || !canResumeInScalarLoopFromVectorLoop ( Plan))
705
+ if (HasUnsupportedPhi || !Plan. hasScalarTail ( ))
725
706
return false ;
726
707
708
+ // / Check if the vector loop of \p Plan can early exit and restart
709
+ // / execution of last vector iteration in the scalar loop. This requires all
710
+ // / recipes up to early exit point be side-effect free as they are
711
+ // / re-executed. Currently we check that the loop is free of any recipe that
712
+ // / may write to memory. Expected to operate on an early VPlan w/o nested
713
+ // / regions.
714
+ for (VPBlockBase *VPB : vp_depth_first_shallow (
715
+ Plan.getVectorLoopRegion ()->getEntryBasicBlock ())) {
716
+ auto *VPBB = cast<VPBasicBlock>(VPB);
717
+ for (auto &R : *VPBB) {
718
+ if (match (&R, m_BranchOnCount (m_VPValue (), m_VPValue ())))
719
+ continue ;
720
+ if (R.mayWriteToMemory ())
721
+ return false ;
722
+ }
723
+ }
724
+
727
725
auto *MiddleVPBB = Plan.getMiddleBlock ();
728
726
auto *RdxResult = dyn_cast<VPInstruction>(&MiddleVPBB->front ());
729
727
if (!RdxResult ||
730
728
RdxResult->getOpcode () != VPInstruction::ComputeReductionResult ||
731
729
RdxResult->getOperand (0 ) != RedPhiR)
732
730
return false ;
733
731
734
- auto *ScalarPH = Plan.getScalarPreheader ();
735
- // Update the resume phis in the scalar preheader. They all must either resume
736
- // from the reduction result or the canonical induction. Bail out if there are
737
- // other resume phis.
738
- for (auto &R : ScalarPH->phis ()) {
739
- auto *ResumeR = cast<VPPhi>(&R);
740
- VPValue *VecV = ResumeR->getOperand (0 );
741
- VPValue *BypassV = ResumeR->getOperand (ResumeR->getNumOperands () - 1 );
742
- if (VecV != RdxResult && VecV != &Plan.getVectorTripCount ())
743
- return false ;
744
- ResumeR->setOperand (
745
- 1 , VecV == &Plan.getVectorTripCount () ? Plan.getCanonicalIV () : VecV);
746
- ResumeR->addOperand (BypassV);
747
- }
748
-
749
732
// Create a new reduction phi recipe with either FMin/FMax, replacing
750
- // FMinNumNoFMFs/FMaxNumNoFMFs .
751
- RecurKind NewRK = RedPhiR->getRecurrenceKind () != RecurKind::FMinNumNoFMFs
733
+ // FMinNum/FMaxNum .
734
+ RecurKind NewRK = RedPhiR->getRecurrenceKind () == RecurKind::FMinNum
752
735
? RecurKind::FMin
753
736
: RecurKind::FMax;
754
737
auto *NewRedPhiR = new VPReductionPHIRecipe (
@@ -769,23 +752,40 @@ bool VPlanTransforms::handleMaxMinNumReductionsWithoutFastMath(VPlan &Plan) {
769
752
auto *IsLatchExitTaken =
770
753
Builder.createICmp (CmpInst::ICMP_EQ, LatchExitingBranch->getOperand (0 ),
771
754
LatchExitingBranch->getOperand (1 ));
755
+
756
+ VPValue *IsNaN = Builder.createFCmp (CmpInst::FCMP_UNO, MinMaxOp, MinMaxOp);
757
+ VPValue *AnyNaN = Builder.createNaryOp (VPInstruction::AnyOf, {IsNaN});
772
758
auto *AnyExitTaken =
773
759
Builder.createNaryOp (Instruction::Or, {AnyNaN, IsLatchExitTaken});
774
760
Builder.createNaryOp (VPInstruction::BranchOnCond, AnyExitTaken);
775
761
LatchExitingBranch->eraseFromParent ();
776
762
777
- // Split the middle block and introduce a new block, branching to the scalar
778
- // preheader to resume iteration in the scalar loop if any NaNs have been
779
- // encountered.
780
- MiddleVPBB->splitAt (std::prev (MiddleVPBB->end ()));
763
+ // If we exit early due to NaNs, compute the final reduction result based on
764
+ // the reduction phi at the beginning of the last vector iteration.
781
765
Builder.setInsertPoint (MiddleVPBB, MiddleVPBB->begin ());
782
766
auto *NewSel =
783
767
Builder.createSelect (AnyNaN, NewRedPhiR, RdxResult->getOperand (1 ));
784
768
RdxResult->setOperand (1 , NewSel);
785
- Builder.setInsertPoint (MiddleVPBB);
786
- Builder.createNaryOp (VPInstruction::BranchOnCond, AnyNaN);
787
- VPBlockUtils::connectBlocks (MiddleVPBB, ScalarPH);
788
- MiddleVPBB->swapSuccessors ();
789
- std::swap (ScalarPH->getPredecessors ()[1 ], ScalarPH->getPredecessors ().back ());
769
+
770
+ auto *ScalarPH = Plan.getScalarPreheader ();
771
+ // Update the resume phis for inductions in the scalar preheader. If AnyNaN is
772
+ // true, then resume from the start of the last vector iteration via the
773
+ // canonical IV, otherwise from the original value.
774
+ for (auto &R : ScalarPH->phis ()) {
775
+ auto *ResumeR = cast<VPPhi>(&R);
776
+ VPValue *VecV = ResumeR->getOperand (0 );
777
+ if (VecV == RdxResult)
778
+ continue ;
779
+ if (VecV != &Plan.getVectorTripCount ())
780
+ return false ;
781
+ auto *NewSel = Builder.createSelect (AnyNaN, Plan.getCanonicalIV (), VecV);
782
+ ResumeR->setOperand (0 , NewSel);
783
+ }
784
+
785
+ auto *MiddleTerm = MiddleVPBB->getTerminator ();
786
+ Builder.setInsertPoint (MiddleTerm);
787
+ VPValue *MiddleCond = MiddleTerm->getOperand (0 );
788
+ VPValue *NewCond = Builder.createAnd (MiddleCond, Builder.createNot (AnyNaN));
789
+ MiddleTerm->setOperand (0 , NewCond);
790
790
return true ;
791
791
}
0 commit comments