Skip to content

Commit f7b6501

Browse files
authored
[PGO] Add llvm.loop.estimated_trip_count metadata (#148758)
This patch implements the `llvm.loop.estimated_trip_count` metadata discussed in [[RFC] Fix Loop Transformations to Preserve Block Frequencies](https://discourse.llvm.org/t/rfc-fix-loop-transformations-to-preserve-block-frequencies/85785). As [suggested in the RFC comments](https://discourse.llvm.org/t/rfc-fix-loop-transformations-to-preserve-block-frequencies/85785/4), it adds the new metadata to all loops at the time of profile ingestion and estimates each trip count from the loop's `branch_weights` metadata. As [suggested in the PR #128785 review](#128785 (comment)), it does so via a new `PGOEstimateTripCountsPass` pass, which creates the new metadata for each loop but omits the value if it cannot estimate a trip count due to the loop's form. An important observation not previously discussed is that `PGOEstimateTripCountsPass` *often* cannot estimate a loop's trip count, but later passes can sometimes transform the loop in a way that makes it possible. Currently, such passes do not necessarily update the metadata, but eventually that should be fixed. Until then, if the new metadata has no value, `llvm::getLoopEstimatedTripCount` disregards it and tries again to estimate the trip count from the loop's current `branch_weights` metadata.
1 parent 3e579d9 commit f7b6501

36 files changed

+938
-198
lines changed

llvm/docs/LangRef.rst

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7960,6 +7960,67 @@ The attributes in this metadata is added to all followup loops of the
79607960
loop distribution pass. See
79617961
:ref:`Transformation Metadata <transformation-metadata>` for details.
79627962

7963+
'``llvm.loop.estimated_trip_count``' Metadata
7964+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7965+
7966+
This metadata records an estimated trip count for the loop. The first operand
7967+
is the string ``llvm.loop.estimated_trip_count``. The second operand is an
7968+
integer constant of type ``i32`` or smaller specifying the count, which might be
7969+
omitted for the reasons described below. For example:
7970+
7971+
.. code-block:: llvm
7972+
7973+
!0 = !{!"llvm.loop.estimated_trip_count", i32 8}
7974+
!1 = !{!"llvm.loop.estimated_trip_count"}
7975+
7976+
Purpose
7977+
"""""""
7978+
7979+
A loop's estimated trip count is an estimate of the average number of loop
7980+
iterations (specifically, the number of times the loop's header executes) each
7981+
time execution reaches the loop. It is usually only an estimate based on, for
7982+
example, profile data. The actual number of iterations might vary widely.
7983+
7984+
The estimated trip count serves as a parameter for various loop transformations
7985+
and typically helps estimate transformation cost. For example, it can help
7986+
determine how many iterations to peel or how aggressively to unroll.
7987+
7988+
Initialization and Maintenance
7989+
""""""""""""""""""""""""""""""
7990+
7991+
The ``pgo-estimate-trip-counts`` pass typically runs immediately after profile
7992+
ingestion to add this metadata to all loops. It estimates each loop's trip
7993+
count from the loop's ``branch_weights`` metadata. This way of initially
7994+
estimating trip counts appears to be useful for the passes that consume them.
7995+
7996+
As passes transform existing loops and create new loops, they must be free to
7997+
update and create ``branch_weights`` metadata to maintain accurate block
7998+
frequencies. Trip counts estimated from this new ``branch_weights`` metadata
7999+
are not necessarily useful to the passes that consume them. In general, when
8000+
passes transform and create loops, they should separately estimate new trip
8001+
counts from previously estimated trip counts, and they should record them by
8002+
creating or updating this metadata. For this or any other work involving
8003+
estimated trip counts, passes should always call
8004+
``llvm::getLoopEstimatedTripCount`` and ``llvm::setLoopEstimatedTripCount``.
8005+
8006+
Missing Metadata and Values
8007+
"""""""""""""""""""""""""""
8008+
8009+
If the current implementation of ``pgo-estimate-trip-counts`` cannot estimate a
8010+
trip count from the loop's ``branch_weights`` metadata due to the loop's form or
8011+
due to missing profile data, it creates this metadata for the loop but omits the
8012+
value. This situation is currently common (e.g., the LLVM IR loop that Clang
8013+
emits for a simple C ``for`` loop). A later pass (e.g., ``loop-rotate``) might
8014+
modify the loop's form in a way that enables estimating its trip count even if
8015+
those modifications provably never impact the actual number of loop iterations.
8016+
That later pass should then add an appropriate value to the metadata.
8017+
8018+
However, not all such passes currently do so. Thus, if this metadata has no
8019+
value, ``llvm::getLoopEstimatedTripCount`` will disregard it and estimate the
8020+
trip count from the loop's ``branch_weights`` metadata. It does the same when
8021+
the metadata is missing altogether, perhaps because ``pgo-estimate-trip-counts``
8022+
was not specified in a minimal pass list to a tool like ``opt``.
8023+
79638024
'``llvm.licm.disable``' Metadata
79648025
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
79658026

llvm/include/llvm/Analysis/LoopInfo.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -637,9 +637,13 @@ LLVM_ABI std::optional<bool> getOptionalBoolLoopAttribute(const Loop *TheLoop,
637637
/// Returns true if Name is applied to TheLoop and enabled.
638638
LLVM_ABI bool getBooleanLoopAttribute(const Loop *TheLoop, StringRef Name);
639639

640-
/// Find named metadata for a loop with an integer value.
641-
LLVM_ABI std::optional<int> getOptionalIntLoopAttribute(const Loop *TheLoop,
642-
StringRef Name);
640+
/// Find named metadata for a loop with an integer value. Return
641+
/// \c std::nullopt if the metadata has no value or is missing altogether. If
642+
/// \p Missing, set \c *Missing to indicate whether the metadata is missing
643+
/// altogether.
644+
LLVM_ABI std::optional<int>
645+
getOptionalIntLoopAttribute(const Loop *TheLoop, StringRef Name,
646+
bool *Missing = nullptr);
643647

644648
/// Find named metadata for a loop with an integer value. Return \p Default if
645649
/// not set.

llvm/include/llvm/IR/Metadata.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -919,8 +919,8 @@ class MDOperand {
919919

920920
// Check if MDOperand is of type MDString and equals `Str`.
921921
bool equalsStr(StringRef Str) const {
922-
return isa<MDString>(this->get()) &&
923-
cast<MDString>(this->get())->getString() == Str;
922+
return isa_and_nonnull<MDString>(get()) &&
923+
cast<MDString>(get())->getString() == Str;
924924
}
925925

926926
~MDOperand() { untrack(); }
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//===- PGOEstimateTripCounts.h ----------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_TRANSFORMS_INSTRUMENTATION_PGOESTIMATETRIPCOUNTS_H
10+
#define LLVM_TRANSFORMS_INSTRUMENTATION_PGOESTIMATETRIPCOUNTS_H
11+
12+
#include "llvm/IR/PassManager.h"
13+
14+
namespace llvm {
15+
16+
struct PGOEstimateTripCountsPass
17+
: public PassInfoMixin<PGOEstimateTripCountsPass> {
18+
PGOEstimateTripCountsPass() {}
19+
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
20+
};
21+
22+
} // namespace llvm
23+
24+
#endif // LLVM_TRANSFORMS_INSTRUMENTATION_PGOESTIMATETRIPCOUNTS_H

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ typedef std::pair<const RuntimeCheckingPtrGroup *,
5252
template <typename T, unsigned N> class SmallSetVector;
5353
template <typename T, unsigned N> class SmallPriorityWorklist;
5454

55+
const char *const LLVMLoopEstimatedTripCount = "llvm.loop.estimated_trip_count";
56+
5557
LLVM_ABI BasicBlock *InsertPreheaderForLoop(Loop *L, DominatorTree *DT,
5658
LoopInfo *LI,
5759
MemorySSAUpdater *MSSAU,
@@ -316,28 +318,81 @@ LLVM_ABI TransformationMode hasDistributeTransformation(const Loop *L);
316318
LLVM_ABI TransformationMode hasLICMVersioningTransformation(const Loop *L);
317319
/// @}
318320

319-
/// Set input string into loop metadata by keeping other values intact.
320-
/// If the string is already in loop metadata update value if it is
321-
/// different.
322-
LLVM_ABI void addStringMetadataToLoop(Loop *TheLoop, const char *MDString,
323-
unsigned V = 0);
324-
325-
/// Returns a loop's estimated trip count based on branch weight metadata.
326-
/// In addition if \p EstimatedLoopInvocationWeight is not null it is
327-
/// initialized with weight of loop's latch leading to the exit.
328-
/// Returns a valid positive trip count, saturated at UINT_MAX, or std::nullopt
329-
/// when a meaningful estimate cannot be made.
321+
/// Set the string \p MDString into the loop metadata of \p TheLoop while
322+
/// keeping other loop metadata intact. Set \p *V as its value, or set it
323+
/// without a value if \p V is \c std::nullopt to indicate the value is unknown.
324+
/// If \p MDString is already in the loop metadata, update it if its value (or
325+
/// lack of value) is different. Return true if metadata was changed.
326+
LLVM_ABI bool addStringMetadataToLoop(Loop *TheLoop, const char *MDString,
327+
std::optional<unsigned> V = 0);
328+
329+
/// Return either:
330+
/// - The value of \c llvm.loop.estimated_trip_count from the loop metadata of
331+
/// \p L, if that metadata is present and has a value.
332+
/// - Else, a new estimate of the trip count from the latch branch weights of
333+
/// \p L, if the estimation's implementation is able to handle the loop form
334+
/// of \p L (e.g., \p L must have a latch block that controls the loop exit).
335+
/// - Else, \c std::nullopt.
336+
///
337+
/// An estimated trip count is always a valid positive trip count, saturated at
338+
/// \c UINT_MAX.
339+
///
340+
/// Via \c LLVM_DEBUG, emit diagnostics that include "WARNING" when the metadata
341+
/// is in an unexpected state as that indicates some transformation has
342+
/// corrupted it. If \p DbgForInit, expect the metadata to be missing.
343+
/// Otherwise, expect the metadata to be present, and expect it to have no value
344+
/// only if the trip count is currently inestimable from the latch branch
345+
/// weights.
346+
///
347+
/// In addition, if \p EstimatedLoopInvocationWeight, then either:
348+
/// - Set \p *EstimatedLoopInvocationWeight to the weight of the latch's branch
349+
/// to the loop exit.
350+
/// - Do not set it and return \c std::nullopt if the current implementation
351+
/// cannot compute that weight (e.g., if \p L does not have a latch block that
352+
/// controls the loop exit) or the weight is zero (because zero cannot be
353+
/// used to compute new branch weights that reflect the estimated trip count).
354+
///
355+
/// TODO: Eventually, once all passes have migrated away from setting branch
356+
/// weights to indicate estimated trip counts, this function will drop the
357+
/// \p EstimatedLoopInvocationWeight parameter.
358+
///
359+
/// TODO: There are also passes that currently do not consider estimated trip
360+
/// counts at all but that, for example, affect whether trip counts can be
361+
/// estimated from branch weights. Once all such passes have been adjusted to
362+
/// update this metadata, this function might stop estimating trip counts from
363+
/// branch weights and instead simply get the \c llvm.loop_estimated_trip_count
364+
/// metadata. See also the \c llvm.loop.estimated_trip_count entry in
365+
/// \c LangRef.rst.
330366
LLVM_ABI std::optional<unsigned>
331367
getLoopEstimatedTripCount(Loop *L,
332-
unsigned *EstimatedLoopInvocationWeight = nullptr);
333-
334-
/// Set a loop's branch weight metadata to reflect that loop has \p
335-
/// EstimatedTripCount iterations and \p EstimatedLoopInvocationWeight exits
336-
/// through latch. Returns true if metadata is successfully updated, false
337-
/// otherwise. Note that loop must have a latch block which controls loop exit
338-
/// in order to succeed.
339-
LLVM_ABI bool setLoopEstimatedTripCount(Loop *L, unsigned EstimatedTripCount,
340-
unsigned EstimatedLoopInvocationWeight);
368+
unsigned *EstimatedLoopInvocationWeight = nullptr,
369+
bool DbgForInit = false);
370+
371+
/// Set \c llvm.loop.estimated_trip_count with the value \c *EstimatedTripCount
372+
/// in the loop metadata of \p L, or set it without a value if
373+
/// \c !EstimatedTripCount to indicate that \c getLoopEstimatedTripCount cannot
374+
/// estimate the trip count from latch branch weights. If
375+
/// \c !EstimatedTripCount but \c getLoopEstimatedTripCount can estimate the
376+
/// trip counts, future calls to \c getLoopEstimatedTripCount will diagnose the
377+
/// metadata as corrupt.
378+
///
379+
/// In addition, if \p EstimatedLoopInvocationWeight, set the branch weight
380+
/// metadata of \p L to reflect that \p L has an estimated
381+
/// \c *EstimatedTripCount iterations and has \c *EstimatedLoopInvocationWeight
382+
/// exit weight through the loop's latch.
383+
///
384+
/// Return false if \c llvm.loop.estimated_trip_count was already set according
385+
/// to \p EstimatedTripCount and so was not updated. Return false if
386+
/// \p EstimatedLoopInvocationWeight and if branch weight metadata could not be
387+
/// successfully updated (e.g., if \p L does not have a latch block that
388+
/// controls the loop exit). Otherwise, return true.
389+
///
390+
/// TODO: Eventually, once all passes have migrated away from setting branch
391+
/// weights to indicate estimated trip counts, this function will drop the
392+
/// \p EstimatedLoopInvocationWeight parameter.
393+
LLVM_ABI bool setLoopEstimatedTripCount(
394+
Loop *L, std::optional<unsigned> EstimatedTripCount,
395+
std::optional<unsigned> EstimatedLoopInvocationWeight = std::nullopt);
341396

342397
/// Check inner loop (L) backedge count is known to be invariant on all
343398
/// iterations of its outer loop. If the loop has no parent, this is trivially

llvm/lib/Analysis/LoopInfo.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,9 +1111,13 @@ bool llvm::getBooleanLoopAttribute(const Loop *TheLoop, StringRef Name) {
11111111
}
11121112

11131113
std::optional<int> llvm::getOptionalIntLoopAttribute(const Loop *TheLoop,
1114-
StringRef Name) {
1115-
const MDOperand *AttrMD =
1116-
findStringMetadataForLoop(TheLoop, Name).value_or(nullptr);
1114+
StringRef Name,
1115+
bool *Missing) {
1116+
std::optional<const MDOperand *> AttrMDOpt =
1117+
findStringMetadataForLoop(TheLoop, Name);
1118+
if (Missing)
1119+
*Missing = !AttrMDOpt;
1120+
const MDOperand *AttrMD = AttrMDOpt.value_or(nullptr);
11171121
if (!AttrMD)
11181122
return std::nullopt;
11191123

llvm/lib/IR/Verifier.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@
121121
#include "llvm/Support/MathExtras.h"
122122
#include "llvm/Support/ModRef.h"
123123
#include "llvm/Support/raw_ostream.h"
124+
#include "llvm/Transforms/Utils/LoopUtils.h"
124125
#include <algorithm>
125126
#include <cassert>
126127
#include <cstdint>
@@ -1071,6 +1072,21 @@ void Verifier::visitMDNode(const MDNode &MD, AreDebugLocsAllowed AllowLocs) {
10711072
}
10721073
}
10731074

1075+
// Check llvm.loop.estimated_trip_count.
1076+
if (MD.getNumOperands() > 0 &&
1077+
MD.getOperand(0).equalsStr(LLVMLoopEstimatedTripCount)) {
1078+
Check(MD.getNumOperands() == 1 || MD.getNumOperands() == 2,
1079+
"Expected one or two operands", &MD);
1080+
if (MD.getNumOperands() == 2) {
1081+
auto *Count = dyn_cast_or_null<ConstantAsMetadata>(MD.getOperand(1));
1082+
Check(Count && Count->getType()->isIntegerTy() &&
1083+
cast<IntegerType>(Count->getType())->getBitWidth() <= 32,
1084+
"Expected optional second operand to be an integer constant of "
1085+
"type i32 or smaller",
1086+
&MD);
1087+
}
1088+
}
1089+
10741090
// Check these last, so we diagnose problems in operands first.
10751091
Check(!MD.isTemporary(), "Expected no forward declarations!", &MD);
10761092
Check(MD.isResolved(), "All nodes should be resolved!", &MD);

llvm/lib/Passes/PassBuilder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@
252252
#include "llvm/Transforms/Instrumentation/NumericalStabilitySanitizer.h"
253253
#include "llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h"
254254
#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
255+
#include "llvm/Transforms/Instrumentation/PGOEstimateTripCounts.h"
255256
#include "llvm/Transforms/Instrumentation/PGOForceFunctionAttrs.h"
256257
#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
257258
#include "llvm/Transforms/Instrumentation/RealtimeSanitizer.h"

llvm/lib/Passes/PassBuilderPipelines.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
#include "llvm/Transforms/Instrumentation/MemProfUse.h"
8181
#include "llvm/Transforms/Instrumentation/PGOCtxProfFlattening.h"
8282
#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
83+
#include "llvm/Transforms/Instrumentation/PGOEstimateTripCounts.h"
8384
#include "llvm/Transforms/Instrumentation/PGOForceFunctionAttrs.h"
8485
#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
8586
#include "llvm/Transforms/Scalar/ADCE.h"
@@ -1239,6 +1240,7 @@ PassBuilder::buildModuleSimplificationPipeline(OptimizationLevel Level,
12391240
MPM.addPass(AssignGUIDPass());
12401241
if (IsCtxProfUse) {
12411242
MPM.addPass(PGOCtxProfFlatteningPass(/*IsPreThinlink=*/true));
1243+
MPM.addPass(PGOEstimateTripCountsPass());
12421244
return MPM;
12431245
}
12441246
// Block further inlining in the instrumented ctxprof case. This avoids
@@ -1268,8 +1270,10 @@ PassBuilder::buildModuleSimplificationPipeline(OptimizationLevel Level,
12681270
MPM.addPass(MemProfUsePass(PGOOpt->MemoryProfile, PGOOpt->FS));
12691271

12701272
if (PGOOpt && (PGOOpt->Action == PGOOptions::IRUse ||
1271-
PGOOpt->Action == PGOOptions::SampleUse))
1273+
PGOOpt->Action == PGOOptions::SampleUse)) {
12721274
MPM.addPass(PGOForceFunctionAttrsPass(PGOOpt->ColdOptType));
1275+
}
1276+
MPM.addPass(PGOEstimateTripCountsPass());
12731277

12741278
MPM.addPass(AlwaysInlinerPass(/*InsertLifetimeIntrinsics=*/true));
12751279

@@ -2355,4 +2359,4 @@ AAManager PassBuilder::buildDefaultAAPipeline() {
23552359
bool PassBuilder::isInstrumentedPGOUse() const {
23562360
return (PGOOpt && PGOOpt->Action == PGOOptions::IRUse) ||
23572361
!UseCtxProfile.empty();
2358-
}
2362+
}

llvm/lib/Passes/PassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ MODULE_PASS("openmp-opt", OpenMPOptPass())
124124
MODULE_PASS("openmp-opt-postlink",
125125
OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink))
126126
MODULE_PASS("partial-inliner", PartialInlinerPass())
127+
MODULE_PASS("pgo-estimate-trip-counts", PGOEstimateTripCountsPass())
127128
MODULE_PASS("pgo-icall-prom", PGOIndirectCallPromotion())
128129
MODULE_PASS("pgo-instr-gen", PGOInstrumentationGen())
129130
MODULE_PASS("pgo-instr-use", PGOInstrumentationUse())

0 commit comments

Comments
 (0)