4040#include " llvm/Transforms/Utils/LoopUtils.h"
4141#include " llvm/Transforms/Utils/ScalarEvolutionExpander.h"
4242#include " llvm/Transforms/Utils/UnrollLoop.h"
43+ #include < cmath>
4344
4445using namespace llvm ;
4546
@@ -195,6 +196,21 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,
195196 }
196197}
197198
199+ // / Assume, due to our position in the remainder loop or its guard, anywhere
200+ // / from 0 to \p N more iterations can possibly execute. Among such cases in
201+ // / the original loop (with loop probability \p OriginalLoopProb), what is the
202+ // / probability of executing at least one more iteration?
203+ static BranchProbability
204+ probOfNextInRemainder (BranchProbability OriginalLoopProb, unsigned N) {
205+ // Each of these variables holds the original loop's probability that the
206+ // number of iterations it will execute is some m in the specified range.
207+ BranchProbability ProbOne = OriginalLoopProb; // 1 <= m
208+ BranchProbability ProbTooMany = ProbOne.pow (N + 1 ); // N + 1 <= m
209+ BranchProbability ProbNotTooMany = ProbTooMany.getCompl (); // 0 <= m <= N
210+ BranchProbability ProbOneNotTooMany = ProbOne - ProbTooMany; // 1 <= m <= N
211+ return ProbOneNotTooMany / ProbNotTooMany;
212+ }
213+
198214// / Connect the unrolling epilog code to the original loop.
199215// / The unrolling epilog code contains code to execute the
200216// / 'extra' iterations if the run-time trip count modulo the
@@ -221,7 +237,8 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
221237 BasicBlock *EpilogPreHeader, BasicBlock *NewPreHeader,
222238 ValueToValueMapTy &VMap, DominatorTree *DT,
223239 LoopInfo *LI, bool PreserveLCSSA, ScalarEvolution &SE,
224- unsigned Count, AssumptionCache &AC) {
240+ unsigned Count, AssumptionCache &AC,
241+ BranchProbability OriginalLoopProb) {
225242 BasicBlock *Latch = L->getLoopLatch ();
226243 assert (Latch && " Loop must have a latch" );
227244 BasicBlock *EpilogLatch = cast<BasicBlock>(VMap[Latch]);
@@ -332,12 +349,19 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
332349 PreserveLCSSA);
333350 // Add the branch to the exit block (around the epilog loop)
334351 MDNode *BranchWeights = nullptr ;
335- if (hasBranchWeightMD (*Latch->getTerminator ())) {
352+ if (OriginalLoopProb.isUnknown () &&
353+ hasBranchWeightMD (*Latch->getTerminator ())) {
336354 // Assume equal distribution in interval [0, Count).
337355 MDBuilder MDB (B.getContext ());
338356 BranchWeights = MDB.createBranchWeights (1 , Count - 1 );
339357 }
340- B.CreateCondBr (BrLoopExit, EpilogPreHeader, Exit, BranchWeights);
358+ BranchInst *RemainderLoopGuard =
359+ B.CreateCondBr (BrLoopExit, EpilogPreHeader, Exit, BranchWeights);
360+ if (!OriginalLoopProb.isUnknown ()) {
361+ setBranchProbability (RemainderLoopGuard,
362+ probOfNextInRemainder (OriginalLoopProb, Count - 1 ),
363+ /* ForFirstTarget=*/ true );
364+ }
341365 InsertPt->eraseFromParent ();
342366 if (DT) {
343367 auto *NewDom = DT->findNearestCommonDominator (Exit, NewExit);
@@ -357,14 +381,15 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
357381// / The cloned blocks should be inserted between InsertTop and InsertBot.
358382// / InsertTop should be new preheader, InsertBot new loop exit.
359383// / Returns the new cloned loop that is created.
360- static Loop *
361- CloneLoopBlocks (Loop *L, Value *NewIter, const bool UseEpilogRemainder,
362- const bool UnrollRemainder,
363- BasicBlock *InsertTop,
364- BasicBlock *InsertBot, BasicBlock *Preheader,
384+ static Loop *CloneLoopBlocks (Loop *L, Value *NewIter,
385+ const bool UseEpilogRemainder,
386+ const bool UnrollRemainder, BasicBlock *InsertTop,
387+ BasicBlock *InsertBot, BasicBlock *Preheader,
365388 std::vector<BasicBlock *> &NewBlocks,
366389 LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap,
367- DominatorTree *DT, LoopInfo *LI, unsigned Count) {
390+ DominatorTree *DT, LoopInfo *LI, unsigned Count,
391+ std::optional<unsigned > OriginalTripCount,
392+ BranchProbability OriginalLoopProb) {
368393 StringRef suffix = UseEpilogRemainder ? " epil" : " prol" ;
369394 BasicBlock *Header = L->getHeader ();
370395 BasicBlock *Latch = L->getLoopLatch ();
@@ -419,7 +444,8 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
419444 Builder.CreateAdd (NewIdx, One, NewIdx->getName () + " .next" );
420445 Value *IdxCmp = Builder.CreateICmpNE (IdxNext, NewIter, NewIdx->getName () + " .cmp" );
421446 MDNode *BranchWeights = nullptr ;
422- if (hasBranchWeightMD (*LatchBR)) {
447+ if ((OriginalLoopProb.isUnknown () || !UseEpilogRemainder) &&
448+ hasBranchWeightMD (*LatchBR)) {
423449 uint32_t ExitWeight;
424450 uint32_t BackEdgeWeight;
425451 if (Count >= 3 ) {
@@ -437,7 +463,29 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
437463 MDBuilder MDB (Builder.getContext ());
438464 BranchWeights = MDB.createBranchWeights (BackEdgeWeight, ExitWeight);
439465 }
440- Builder.CreateCondBr (IdxCmp, FirstLoopBB, InsertBot, BranchWeights);
466+ BranchInst *RemainderLoopLatch =
467+ Builder.CreateCondBr (IdxCmp, FirstLoopBB, InsertBot, BranchWeights);
468+ if (!OriginalLoopProb.isUnknown () && UseEpilogRemainder) {
469+ // Compute the total frequency of the original loop body from the
470+ // remainder iterations. Once we've reached them, the first of them
471+ // always executes, so its frequency and probability are 1.
472+ double FreqRemIters = 1 ;
473+ if (Count > 2 ) {
474+ BranchProbability ProbReaching = BranchProbability::getOne ();
475+ for (unsigned N = Count - 2 ; N >= 1 ; --N) {
476+ ProbReaching *= probOfNextInRemainder (OriginalLoopProb, N);
477+ FreqRemIters += double (ProbReaching.getNumerator ()) /
478+ ProbReaching.getDenominator ();
479+ }
480+ }
481+ // Solve for the loop probability that would produce that frequency.
482+ // Sum(i=0..inf)(Prob^i) = 1/(1-Prob) = FreqRemIters.
483+ double ProbDouble = 1 - 1 / FreqRemIters;
484+ BranchProbability Prob = BranchProbability::getBranchProbability (
485+ std::round (ProbDouble * BranchProbability::getDenominator ()),
486+ BranchProbability::getDenominator ());
487+ setBranchProbability (RemainderLoopLatch, Prob, /* ForFirstTarget=*/ true );
488+ }
441489 NewIdx->addIncoming (Zero, InsertTop);
442490 NewIdx->addIncoming (IdxNext, NewBB);
443491 LatchBR->eraseFromParent ();
@@ -461,6 +509,9 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
461509 Loop *NewLoop = NewLoops[L];
462510 assert (NewLoop && " L should have been cloned" );
463511
512+ if (OriginalTripCount && UseEpilogRemainder)
513+ setLoopEstimatedTripCount (NewLoop, *OriginalTripCount % Count);
514+
464515 // Add unroll disable metadata to disable future unrolling for this loop.
465516 if (!UnrollRemainder)
466517 NewLoop->setLoopAlreadyUnrolled ();
@@ -588,7 +639,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
588639 LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC,
589640 const TargetTransformInfo *TTI, bool PreserveLCSSA,
590641 unsigned SCEVExpansionBudget, bool RuntimeUnrollMultiExit,
591- Loop **ResultLoop) {
642+ Loop **ResultLoop, std::optional<unsigned > OriginalTripCount,
643+ BranchProbability OriginalLoopProb) {
592644 LLVM_DEBUG (dbgs () << " Trying runtime unrolling on Loop: \n " );
593645 LLVM_DEBUG (L->dump ());
594646 LLVM_DEBUG (UseEpilogRemainder ? dbgs () << " Using epilog remainder.\n "
@@ -808,12 +860,23 @@ bool llvm::UnrollRuntimeLoopRemainder(
808860 BasicBlock *UnrollingLoop = UseEpilogRemainder ? NewPreHeader : PrologExit;
809861 // Branch to either remainder (extra iterations) loop or unrolling loop.
810862 MDNode *BranchWeights = nullptr ;
811- if (hasBranchWeightMD (*Latch->getTerminator ())) {
863+ if ((OriginalLoopProb.isUnknown () || !UseEpilogRemainder) &&
864+ hasBranchWeightMD (*Latch->getTerminator ())) {
812865 // Assume loop is nearly always entered.
813866 MDBuilder MDB (B.getContext ());
814867 BranchWeights = MDB.createBranchWeights (EpilogHeaderWeights);
815868 }
816- B.CreateCondBr (BranchVal, RemainderLoop, UnrollingLoop, BranchWeights);
869+ BranchInst *UnrollingLoopGuard =
870+ B.CreateCondBr (BranchVal, RemainderLoop, UnrollingLoop, BranchWeights);
871+ if (!OriginalLoopProb.isUnknown () && UseEpilogRemainder) {
872+ // The original loop's first iteration always happens. Compute the
873+ // probability of the original loop executing Count-1 iterations after that
874+ // to complete the first iteration of the unrolled loop.
875+ BranchProbability ProbOne = OriginalLoopProb;
876+ BranchProbability ProbRest = ProbOne.pow (Count - 1 );
877+ setBranchProbability (UnrollingLoopGuard, ProbRest,
878+ /* ForFirstTarget=*/ false );
879+ }
817880 PreHeaderBR->eraseFromParent ();
818881 if (DT) {
819882 if (UseEpilogRemainder)
@@ -840,9 +903,10 @@ bool llvm::UnrollRuntimeLoopRemainder(
840903 // iterations. This function adds the appropriate CFG connections.
841904 BasicBlock *InsertBot = UseEpilogRemainder ? LatchExit : PrologExit;
842905 BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader;
843- Loop *remainderLoop = CloneLoopBlocks (
844- L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop, InsertBot,
845- NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI, Count);
906+ Loop *remainderLoop =
907+ CloneLoopBlocks (L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop,
908+ InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, DT,
909+ LI, Count, OriginalTripCount, OriginalLoopProb);
846910
847911 // Insert the cloned blocks into the function.
848912 F->splice (InsertBot->getIterator (), F, NewBlocks[0 ]->getIterator (), F->end ());
@@ -941,7 +1005,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
9411005 // Connect the epilog code to the original loop and update the
9421006 // PHI functions.
9431007 ConnectEpilog (L, ModVal, NewExit, LatchExit, PreHeader, EpilogPreHeader,
944- NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count, *AC);
1008+ NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count, *AC,
1009+ OriginalLoopProb);
9451010
9461011 // Update counter in loop for unrolling.
9471012 // Use an incrementing IV. Pre-incr/post-incr is backedge/trip count.
0 commit comments