Skip to content

Commit 416d2a4

Browse files
committed
[LV] Also clamp MaxVF by trip count when maximizing vector bandwidth.
When maximizing vector bandwidth, also clamp the maximum VF by the maximum trip count; otherwise we may end up choosing a VF for which the vector loop never executes.
1 parent 5c7c855 commit 416d2a4

File tree

3 files changed

+188
-136
lines changed

3 files changed

+188
-136
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,6 +1505,11 @@ class LoopVectorizationCostModel {
15051505
ElementCount UserVF,
15061506
bool FoldTailByMasking);
15071507

1508+
/// If \p VF > MaxTripCount, clamps it to the next lower VF that is <=
1509+
/// MaxTripCount.
1510+
ElementCount clampVFByMaxTripCount(ElementCount VF, unsigned MaxTripCount,
1511+
bool FoldTailByMasking) const;
1512+
15081513
/// \return the maximized element count based on the targets vector
15091514
/// registers and the loop trip-count, but limited to a maximum safe VF.
15101515
/// This is a helper function of computeFeasibleMaxVF.
@@ -3854,6 +3859,38 @@ bool LoopVectorizationCostModel::useMaxBandwidth(
38543859
Legal->hasVectorCallVariants())));
38553860
}
38563861

3862+
ElementCount LoopVectorizationCostModel::clampVFByMaxTripCount(
3863+
ElementCount VF, unsigned MaxTripCount, bool FoldTailByMasking) const {
3864+
unsigned EstimatedVF = VF.getKnownMinValue();
3865+
if (VF.isScalable() && TheFunction->hasFnAttribute(Attribute::VScaleRange)) {
3866+
auto Attr = TheFunction->getFnAttribute(Attribute::VScaleRange);
3867+
auto Min = Attr.getVScaleRangeMin();
3868+
EstimatedVF *= Min;
3869+
}
3870+
3871+
// When a scalar epilogue is required, at least one iteration of the scalar
3872+
// loop has to execute. Adjust MaxTripCount accordingly to avoid picking a
3873+
// max VF that results in a dead vector loop.
3874+
if (MaxTripCount > 0 && requiresScalarEpilogue(true))
3875+
MaxTripCount -= 1;
3876+
3877+
if (MaxTripCount && MaxTripCount <= EstimatedVF &&
3878+
(!FoldTailByMasking || isPowerOf2_32(MaxTripCount))) {
3879+
// If upper bound loop trip count (TC) is known at compile time there is no
3880+
// point in choosing VF greater than TC (as done in the loop below). Select
3881+
// maximum power of two which doesn't exceed TC. If VF is
3882+
// scalable, we only fall back on a fixed VF when the TC is less than or
3883+
// equal to the known number of lanes.
3884+
auto ClampedUpperTripCount = llvm::bit_floor(MaxTripCount);
3885+
LLVM_DEBUG(dbgs() << "LV: Clamping the MaxVF to maximum power of two not "
3886+
"exceeding the constant trip count: "
3887+
<< ClampedUpperTripCount << "\n");
3888+
return ElementCount::get(ClampedUpperTripCount,
3889+
FoldTailByMasking ? VF.isScalable() : false);
3890+
}
3891+
return VF;
3892+
}
3893+
38573894
ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget(
38583895
unsigned MaxTripCount, unsigned SmallestType, unsigned WidestType,
38593896
ElementCount MaxSafeVF, bool FoldTailByMasking) {
@@ -3885,40 +3922,14 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget(
38853922
return ElementCount::getFixed(1);
38863923
}
38873924

3888-
unsigned WidestRegisterMinEC = MaxVectorElementCount.getKnownMinValue();
3889-
if (MaxVectorElementCount.isScalable() &&
3890-
TheFunction->hasFnAttribute(Attribute::VScaleRange)) {
3891-
auto Attr = TheFunction->getFnAttribute(Attribute::VScaleRange);
3892-
auto Min = Attr.getVScaleRangeMin();
3893-
WidestRegisterMinEC *= Min;
3894-
}
3895-
3896-
// When a scalar epilogue is required, at least one iteration of the scalar
3897-
// loop has to execute. Adjust MaxTripCount accordingly to avoid picking a
3898-
// max VF that results in a dead vector loop.
3899-
if (MaxTripCount > 0 && requiresScalarEpilogue(true))
3900-
MaxTripCount -= 1;
3901-
3902-
if (MaxTripCount && MaxTripCount <= WidestRegisterMinEC &&
3903-
(!FoldTailByMasking || isPowerOf2_32(MaxTripCount))) {
3904-
// If upper bound loop trip count (TC) is known at compile time there is no
3905-
// point in choosing VF greater than TC (as done in the loop below). Select
3906-
// maximum power of two which doesn't exceed TC. If MaxVectorElementCount is
3907-
// scalable, we only fall back on a fixed VF when the TC is less than or
3908-
// equal to the known number of lanes.
3909-
auto ClampedUpperTripCount = llvm::bit_floor(MaxTripCount);
3910-
LLVM_DEBUG(dbgs() << "LV: Clamping the MaxVF to maximum power of two not "
3911-
"exceeding the constant trip count: "
3912-
<< ClampedUpperTripCount << "\n");
3913-
return ElementCount::get(
3914-
ClampedUpperTripCount,
3915-
FoldTailByMasking ? MaxVectorElementCount.isScalable() : false);
3916-
}
3925+
ElementCount MaxVF = clampVFByMaxTripCount(MaxVectorElementCount,
3926+
MaxTripCount, FoldTailByMasking);
3927+
if (MaxVF != MaxVectorElementCount)
3928+
return MaxVF;
39173929

39183930
TargetTransformInfo::RegisterKind RegKind =
39193931
ComputeScalableMaxVF ? TargetTransformInfo::RGK_ScalableVector
39203932
: TargetTransformInfo::RGK_FixedWidthVector;
3921-
ElementCount MaxVF = MaxVectorElementCount;
39223933

39233934
if (MaxVF.isScalable())
39243935
MaxPermissibleVFWithoutMaxBW.ScalableVF = MaxVF;
@@ -3940,6 +3951,8 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget(
39403951
}
39413952
}
39423953

3954+
MaxVF = clampVFByMaxTripCount(MaxVF, MaxTripCount, FoldTailByMasking);
3955+
39433956
// Invalidate any widening decisions we might have made, in case the loop
39443957
// requires prediction (decided later), but we have already made some
39453958
// load/store widening decisions.

llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1904,39 +1904,28 @@ define i64 @dotp_cost_disagreement(ptr %a, ptr %b) #0 {
19041904
; CHECK-MAXBW-LABEL: define i64 @dotp_cost_disagreement(
19051905
; CHECK-MAXBW-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0]] {
19061906
; CHECK-MAXBW-NEXT: entry:
1907-
; CHECK-MAXBW-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
1908-
; CHECK-MAXBW-NEXT: [[TMP1:%.*]] = mul nuw i64 [[TMP0]], 8
1909-
; CHECK-MAXBW-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 16, [[TMP1]]
1910-
; CHECK-MAXBW-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
1907+
; CHECK-MAXBW-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
19111908
; CHECK-MAXBW: vector.ph:
1912-
; CHECK-MAXBW-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
1913-
; CHECK-MAXBW-NEXT: [[TMP3:%.*]] = mul nuw i64 [[TMP2]], 8
1914-
; CHECK-MAXBW-NEXT: [[N_MOD_VF:%.*]] = urem i64 16, [[TMP3]]
1915-
; CHECK-MAXBW-NEXT: [[N_VEC:%.*]] = sub i64 16, [[N_MOD_VF]]
1916-
; CHECK-MAXBW-NEXT: [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
1917-
; CHECK-MAXBW-NEXT: [[TMP5:%.*]] = mul nuw i64 [[TMP4]], 8
19181909
; CHECK-MAXBW-NEXT: br label [[VECTOR_BODY:%.*]]
19191910
; CHECK-MAXBW: vector.body:
19201911
; CHECK-MAXBW-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
1921-
; CHECK-MAXBW-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 8 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP19:%.*]], [[VECTOR_BODY]] ]
1912+
; CHECK-MAXBW-NEXT: [[VEC_PHI:%.*]] = phi <16 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP13:%.*]], [[VECTOR_BODY]] ]
19221913
; CHECK-MAXBW-NEXT: [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[INDEX]]
19231914
; CHECK-MAXBW-NEXT: [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP7]], i32 0
1924-
; CHECK-MAXBW-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 8 x i8>, ptr [[TMP8]], align 1
1925-
; CHECK-MAXBW-NEXT: [[TMP9:%.*]] = zext <vscale x 8 x i8> [[WIDE_LOAD]] to <vscale x 8 x i64>
1915+
; CHECK-MAXBW-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
1916+
; CHECK-MAXBW-NEXT: [[TMP2:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i64>
19261917
; CHECK-MAXBW-NEXT: [[TMP10:%.*]] = add nuw nsw i64 [[INDEX]], 1
19271918
; CHECK-MAXBW-NEXT: [[TMP11:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP10]]
19281919
; CHECK-MAXBW-NEXT: [[TMP12:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP11]], i32 0
1929-
; CHECK-MAXBW-NEXT: [[WIDE_LOAD1:%.*]] = load <vscale x 8 x i8>, ptr [[TMP12]], align 1
1930-
; CHECK-MAXBW-NEXT: [[TMP13:%.*]] = zext <vscale x 8 x i8> [[WIDE_LOAD1]] to <vscale x 8 x i64>
1931-
; CHECK-MAXBW-NEXT: [[TMP14:%.*]] = mul nuw nsw <vscale x 8 x i64> [[TMP13]], [[TMP9]]
1932-
; CHECK-MAXBW-NEXT: [[TMP19]] = add <vscale x 8 x i64> [[VEC_PHI]], [[TMP14]]
1933-
; CHECK-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
1934-
; CHECK-MAXBW-NEXT: [[TMP15:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
1935-
; CHECK-MAXBW-NEXT: br i1 [[TMP15]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP21:![0-9]+]]
1920+
; CHECK-MAXBW-NEXT: [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP12]], align 1
1921+
; CHECK-MAXBW-NEXT: [[TMP6:%.*]] = zext <16 x i8> [[WIDE_LOAD1]] to <16 x i64>
1922+
; CHECK-MAXBW-NEXT: [[TMP14:%.*]] = mul nuw nsw <16 x i64> [[TMP6]], [[TMP2]]
1923+
; CHECK-MAXBW-NEXT: [[TMP13]] = add <16 x i64> [[VEC_PHI]], [[TMP14]]
1924+
; CHECK-MAXBW-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
1925+
; CHECK-MAXBW-NEXT: br i1 true, label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP21:![0-9]+]]
19361926
; CHECK-MAXBW: middle.block:
1937-
; CHECK-MAXBW-NEXT: [[TMP16:%.*]] = call i64 @llvm.vector.reduce.add.nxv8i64(<vscale x 8 x i64> [[TMP19]])
1938-
; CHECK-MAXBW-NEXT: [[CMP_N:%.*]] = icmp eq i64 16, [[N_VEC]]
1939-
; CHECK-MAXBW-NEXT: br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]]
1927+
; CHECK-MAXBW-NEXT: [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v16i64(<16 x i64> [[TMP13]])
1928+
; CHECK-MAXBW-NEXT: br i1 true, label [[EXIT:%.*]], label [[SCALAR_PH]]
19401929
; CHECK-MAXBW: scalar.ph:
19411930
;
19421931
entry:

0 commit comments

Comments
 (0)