Skip to content

Commit 1440ad1

Browse files
committed
!fixup dont transform reductions with intermediate stores.
1 parent 3cebf8e commit 1440ad1

File tree

3 files changed

+62
-4
lines changed

3 files changed

+62
-4
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,8 @@ class TargetTransformInfo {
638638
/// Fall back to the generic logic to determine whether multi-exit unrolling
639639
/// is profitable if set to false.
640640
bool RuntimeUnrollMultiExit;
641+
642+
DenseMap<PHINode *, RecurrenceDescriptor> ParallelizeReductions;
641643
};
642644

643645
/// Get target-customized preferences for the generic loop unrolling

llvm/lib/Transforms/Utils/LoopUnroll.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,12 +1210,12 @@ MDNode *llvm::GetUnrollMetadata(MDNode *LoopID, StringRef Name) {
12101210
std::optional<RecurrenceDescriptor>
12111211
llvm::canParallelizeReductionWhenUnrolling(PHINode &Phi, Loop *L,
12121212
ScalarEvolution *SE) {
1213-
RecurrenceDescriptor RedDes;
1214-
if (!RecurrenceDescriptor::isReductionPHI(&Phi, L, RedDes,
1213+
RecurrenceDescriptor RdxDesc;
1214+
if (!RecurrenceDescriptor::isReductionPHI(&Phi, L, RdxDesc,
12151215
/*DemandedBits=*/nullptr,
12161216
/*AC=*/nullptr, /*DT=*/nullptr, SE))
12171217
return std::nullopt;
1218-
RecurKind RK = RedDes.getRecurrenceKind();
1218+
RecurKind RK = RdxDesc.getRecurrenceKind();
12191219
// Skip unsupported reductions.
12201220
// TODO: Handle additional reductions, including FP and min-max
12211221
// reductions.
@@ -1225,6 +1225,9 @@ llvm::canParallelizeReductionWhenUnrolling(PHINode &Phi, Loop *L,
12251225
RecurrenceDescriptor::isMinMaxRecurrenceKind(RK))
12261226
return std::nullopt;
12271227

1228+
if (RdxDesc.IntermediateStore)
1229+
return std::nullopt;
1230+
12281231
// Don't unroll reductions with constant ops; those can be folded to a
12291232
// single induction update.
12301233
if (any_of(cast<Instruction>(Phi.getIncomingValueForBlock(L->getLoopLatch()))
@@ -1239,5 +1242,5 @@ llvm::canParallelizeReductionWhenUnrolling(PHINode &Phi, Loop *L,
12391242
&Phi))
12401243
return std::nullopt;
12411244

1242-
return RedDes;
1245+
return RdxDesc;
12431246
}

llvm/test/Transforms/LoopUnroll/partial-unroll-reductions.ll

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,3 +456,56 @@ loop:
456456
exit:
457457
ret i64 %rdx.next
458458
}
459+
460+
define void @reduction_with_intermediate_store(ptr %src, ptr %sum) {
461+
; CHECK-LABEL: define void @reduction_with_intermediate_store(
462+
; CHECK-SAME: ptr [[SRC:%.*]], ptr [[SUM:%.*]]) {
463+
; CHECK-NEXT: [[ENTRY:.*]]:
464+
; CHECK-NEXT: [[SUM_PROMOTED:%.*]] = load i32, ptr [[SUM]], align 4
465+
; CHECK-NEXT: br label %[[LOOP:.*]]
466+
; CHECK: [[LOOP]]:
467+
; CHECK-NEXT: [[IV:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[IV_NEXT_3:%.*]], %[[LOOP]] ]
468+
; CHECK-NEXT: [[RED:%.*]] = phi i32 [ [[SUM_PROMOTED]], %[[ENTRY]] ], [ [[RED_NEXT_3:%.*]], %[[LOOP]] ]
469+
; CHECK-NEXT: [[GEP_SRC:%.*]] = getelementptr inbounds nuw i32, ptr [[SRC]], i64 [[IV]]
470+
; CHECK-NEXT: [[L:%.*]] = load i32, ptr [[GEP_SRC]], align 4
471+
; CHECK-NEXT: [[RED_NEXT:%.*]] = add nsw i32 [[RED]], [[L]]
472+
; CHECK-NEXT: store i32 [[RED_NEXT]], ptr [[SUM]], align 4
473+
; CHECK-NEXT: [[IV_NEXT:%.*]] = add nuw nsw i64 [[IV]], 1
474+
; CHECK-NEXT: [[GEP_SRC_1:%.*]] = getelementptr inbounds nuw i32, ptr [[SRC]], i64 [[IV_NEXT]]
475+
; CHECK-NEXT: [[L_1:%.*]] = load i32, ptr [[GEP_SRC_1]], align 4
476+
; CHECK-NEXT: [[RED_NEXT_1:%.*]] = add nsw i32 [[RED_NEXT]], [[L_1]]
477+
; CHECK-NEXT: store i32 [[RED_NEXT_1]], ptr [[SUM]], align 4
478+
; CHECK-NEXT: [[IV_NEXT_1:%.*]] = add nuw nsw i64 [[IV]], 2
479+
; CHECK-NEXT: [[GEP_SRC_2:%.*]] = getelementptr inbounds nuw i32, ptr [[SRC]], i64 [[IV_NEXT_1]]
480+
; CHECK-NEXT: [[L_2:%.*]] = load i32, ptr [[GEP_SRC_2]], align 4
481+
; CHECK-NEXT: [[RED_NEXT_2:%.*]] = add nsw i32 [[RED_NEXT_1]], [[L_2]]
482+
; CHECK-NEXT: store i32 [[RED_NEXT_2]], ptr [[SUM]], align 4
483+
; CHECK-NEXT: [[IV_NEXT_2:%.*]] = add nuw nsw i64 [[IV]], 3
484+
; CHECK-NEXT: [[GEP_SRC_3:%.*]] = getelementptr inbounds nuw i32, ptr [[SRC]], i64 [[IV_NEXT_2]]
485+
; CHECK-NEXT: [[L_3:%.*]] = load i32, ptr [[GEP_SRC_3]], align 4
486+
; CHECK-NEXT: [[RED_NEXT_3]] = add nsw i32 [[RED_NEXT_2]], [[L_3]]
487+
; CHECK-NEXT: store i32 [[RED_NEXT_3]], ptr [[SUM]], align 4
488+
; CHECK-NEXT: [[IV_NEXT_3]] = add nuw nsw i64 [[IV]], 4
489+
; CHECK-NEXT: [[EC_3:%.*]] = icmp eq i64 [[IV_NEXT_3]], 10000
490+
; CHECK-NEXT: br i1 [[EC_3]], label %[[EXIT:.*]], label %[[LOOP]]
491+
; CHECK: [[EXIT]]:
492+
; CHECK-NEXT: ret void
493+
;
494+
entry:
495+
%sum.promoted = load i32, ptr %sum, align 4
496+
br label %loop
497+
498+
loop:
499+
%iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
500+
%red = phi i32 [ %sum.promoted, %entry ], [ %red.next, %loop ]
501+
%gep.src = getelementptr inbounds nuw i32, ptr %src, i64 %iv
502+
%l = load i32, ptr %gep.src, align 4
503+
%red.next = add nsw i32 %red, %l
504+
store i32 %red.next, ptr %sum, align 4
505+
%iv.next = add nuw nsw i64 %iv, 1
506+
%ec = icmp eq i64 %iv.next, 10000
507+
br i1 %ec, label %exit, label %loop
508+
509+
exit:
510+
ret void
511+
}

0 commit comments

Comments
 (0)