Skip to content

Commit e15cb57

Browse files
authored
[WS]: partition-scheduler annotates all ops + fixes (#8215)
* all ops, except parent ops (if/reduce), are annotated with partition sets after `partition-scheduler` * patch up downstream passes to handle this properly * the partition set of a parent op is inferred in the `partition-loop` pass, which also infers a partition set for each result * extend `partition-loop` to partition if-stmts across different partitions when loops are split * patch `lower-mma-specialization` to annotate all ops, and patch the unit test * patch up the remaining passes to play well with the requirement that all ops must be annotated in ws-loops * there are a few remaining issues before `aref-tmem-insertion` can be enabled and `load-mma-specialization` disabled, which is the subject of a subsequent PR
1 parent e954c36 commit e15cb57

File tree

16 files changed

+633
-171
lines changed

16 files changed

+633
-171
lines changed

include/triton/Dialect/TritonGPU/Transforms/Partition.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class ForOp;
1717
} // namespace mlir
1818

1919
static constexpr char kPartitionAttrName[] = "ttg.partition";
20+
static constexpr char kPartitionOutputsAttrName[] = "ttg.partition.outputs";
2021
static constexpr char kPartitionStagesAttrName[] = "ttg.partition.stages";
2122
static constexpr char kWarpSpecializeTagAttrName[] = "ttg.warp_specialize.tag";
2223

include/triton/Dialect/TritonGPU/Transforms/PartitionBuilder.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,47 @@
22
#define TRITON_TRITONGPU_TRANSFORMS_PARTITIONBUILDER_H
33

44
#include "mlir/IR/ImplicitLocOpBuilder.h"
5+
#include "llvm/ADT/SetVector.h"
56

67
namespace mlir::triton::gpu {
78

89
class Partition;
910

1011
using StageCluster = std::optional<std::pair<int, int>>;
1112

13+
// Set / get the stage and cluster for an operation. Setting is a no-op when `stageCluster` is empty; getting returns a value only if one is assigned.
14+
void setStageCluster(OpBuilder &b, Operation *op, StageCluster stageCluster);
15+
StageCluster getStageCluster(Operation *op);
16+
1217
struct PartitionBuilder : public ImplicitLocOpBuilder {
1318
using ImplicitLocOpBuilder::ImplicitLocOpBuilder;
1419

1520
Value intCst(int value, unsigned width = 32);
1621
Value boolCst(bool value);
1722

18-
void assignStage(Operation *op, StageCluster stageCluster);
1923
void assignPartition(Operation *op, Partition &partition);
2024

2125
template <typename OpT, typename... Args>
2226
auto createInto(Partition &partition, StageCluster stageCluster,
2327
Args &&...args) {
2428
auto op = create<OpT>(std::forward<Args>(args)...);
2529
assignPartition(op, partition);
26-
assignStage(op, stageCluster);
30+
setStageCluster(*this, op, stageCluster);
2731
return op;
2832
}
2933
};
3034

31-
// Get the stage and cluster for an operation, if it has one assigned.
32-
StageCluster getStageCluster(Operation *op);
35+
template <typename OpT, typename... Args>
36+
OpT createInto(OpBuilder &b, Location loc,
37+
std::optional<SetVector<int>> partitionSet,
38+
StageCluster stageCluster, Args &&...args) {
39+
auto op = b.create<OpT>(loc, std::forward<Args>(args)...);
40+
if (partitionSet) {
41+
setPartition(op, *partitionSet);
42+
setStageCluster(b, op, stageCluster);
43+
}
44+
return op;
45+
}
3346

3447
} // namespace mlir::triton::gpu
3548

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ void AutomaticWarpSpecialization::runOnOperation() {
3636
OpPassManager pm;
3737
pm.addPass(createTritonGPUPartitionScheduling());
3838
pm.addPass(createNVWSInsertAref());
39+
#if 0
40+
pm.addPass(createNVWSInsertTmemAref());
41+
#else
3942
pm.addPass(createTritonGPULoadMMASpecialization({numStages}));
43+
#endif
4044
pm.addPass(createTritonGPURewritePartitionDependencies());
4145
// `int-range-optimizations` and SCCP are good at cleaning up loop arithmetic.
4246
// FIXME: Re-enable integer range analysis once it is fixed.

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
573573
for (int i = 0; i < nodes.size(); ++i) {
574574
Node &cur = nodes[i];
575575
Node &next = nodes[(i + 1) % nodes.size()];
576-
if (!samePartition(inBody(cur.op), inBody(next.op))) {
576+
if (!samePartition(cur.op, next.op)) {
577577
cur.barNext = createBarrierAlloc(loop, numMmaStages);
578578
next.barPrev = cur.barNext;
579579
}
@@ -616,6 +616,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
616616
continue;
617617
b.setInsertionPoint(node.op);
618618
Value view = createSingleBufferView(b, allocOp, node.index);
619+
b.assignPartition(view.getDefiningOp(), *partitions.getPartition(node.op));
619620
if (auto storeOp = dyn_cast<ttng::TMEMStoreOp>(node.op)) {
620621
storeOp.getDstMutable().assign(view);
621622
storeOp.getDepMutable().clear();
@@ -671,7 +672,6 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
671672
Operation *defOp = operand.getDefiningOp();
672673
if (!defOp || loop.isDefinedOutsideOfLoop(operand))
673674
continue;
674-
defOp = inBody(defOp);
675675

676676
if (partitions.isInRootPartition(defOp)) {
677677
// If the MMA operand is coming from outside the loop, move the alloc out.
@@ -717,7 +717,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
717717
}
718718

719719
for (Node &node : nodes) {
720-
Partition *partition = partitions.getPartition(inBody(node.op));
720+
Partition *partition = partitions.getPartition(node.op);
721721
PartitionBuilder b(node.op->getLoc(), loop);
722722

723723
SmallVector<Operation *> defs;
@@ -746,11 +746,13 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
746746
domInfo.properlyDominates(mmaOp, userPred.getDefiningOp())) {
747747
b.restoreInsertionPoint(*incrementPt);
748748
Value bar = createSingleBufferView(b, node.barPrev, curIndex);
749+
b.assignPartition(bar.getDefiningOp(), *partition);
749750
b.createInto<ttng::WaitBarrierOp>(*partition, nodeStageCluster, bar,
750751
curPhase, userPred);
751752
} else {
752753
b.setInsertionPoint(domOp);
753754
Value bar = createSingleBufferView(b, node.barPrev, node.index);
755+
b.assignPartition(bar.getDefiningOp(), *partition);
754756
b.createInto<ttng::WaitBarrierOp>(*partition, nodeStageCluster, bar,
755757
node.phase, userPred);
756758
}
@@ -759,6 +761,7 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
759761
if (isa<scf::IfOp>(domOp->getParentOp()) && accIsMultiBuffered)
760762
b.setInsertionPointToStart(domOp->getBlock());
761763
Value bar = createSingleBufferView(b, node.barPrev, node.index);
764+
b.assignPartition(bar.getDefiningOp(), *partition);
762765
b.createInto<ttng::WaitBarrierOp>(*partition, nodeStageCluster, bar,
763766
node.phase);
764767
}
@@ -767,13 +770,15 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
767770
if (mmaOp == node.op) {
768771
b.setInsertionPoint(mmaOp);
769772
Value bar = createSingleBufferView(b, node.barNext, node.index);
773+
b.assignPartition(bar.getDefiningOp(), *partitions.getPartition(mmaOp));
770774
mmaOp.addCompletionBarrier(bar, userPred);
771775
mmaOp.setIsAsync(true);
772776
} else {
773777
b.setInsertionPointAfter(lastOp);
774778
if (isa<scf::IfOp>(lastOp->getParentOp()) && accIsMultiBuffered)
775779
b.setInsertionPoint(lastOp->getBlock()->getTerminator());
776780
Value bar = createSingleBufferView(b, node.barNext, node.index);
781+
b.assignPartition(bar.getDefiningOp(), *partition);
777782
b.createInto<ttng::ArriveBarrierOp>(*partition, nodeStageCluster, bar,
778783
1);
779784
}
@@ -799,20 +804,26 @@ static LogicalResult pipelineMMA(scf::ForOp &loop, PipelinedMMA &mma,
799804
StageCluster srcStageCluster = getStageCluster(domOp);
800805
b.setInsertionPoint(domOp);
801806
Value emptyView = createSingleBufferView(b, emptyBar, index);
807+
b.assignPartition(emptyView.getDefiningOp(), *partition);
802808
b.createInto<ttng::WaitBarrierOp>(*partition, srcStageCluster, emptyView,
803809
phase);
804810

805811
b.setInsertionPointAfter(lastOp);
806812
Value readyView = createSingleBufferView(b, readyBar, index);
813+
b.assignPartition(readyView.getDefiningOp(), *partition);
807814
b.createInto<ttng::ArriveBarrierOp>(*partition, srcStageCluster, readyView,
808815
1);
809816

810817
b.setInsertionPoint(mmaOp);
811818
Value readyView2 = createSingleBufferView(b, readyBar, index);
819+
b.assignPartition(readyView2.getDefiningOp(),
820+
*partitions.getPartition(mmaOp));
812821
b.createInto<ttng::WaitBarrierOp>(*partitions.getPartition(mmaOp),
813822
getStageCluster(mmaOp), readyView2,
814823
phase);
815824
Value emptyView2 = createSingleBufferView(b, emptyBar, index);
825+
b.assignPartition(emptyView2.getDefiningOp(),
826+
*partitions.getPartition(mmaOp));
816827
mmaOp.addCompletionBarrier(emptyView2, b.boolCst(true));
817828
mmaOp.setIsAsync(true);
818829
}

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionBuilder.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,6 @@ Value PartitionBuilder::boolCst(bool value) {
1414
return intCst(value, /*width=*/1);
1515
}
1616

17-
void PartitionBuilder::assignStage(Operation *op, StageCluster stageCluster) {
18-
if (stageCluster) {
19-
op->setAttr(kLoopStageAttrName, getI32IntegerAttr(stageCluster->first));
20-
op->setAttr(kLoopClusterAttrName, getI32IntegerAttr(stageCluster->second));
21-
}
22-
}
23-
2417
void PartitionBuilder::assignPartition(Operation *op, Partition &partition) {
2518
setPartition(op, &partition);
2619
}
@@ -32,3 +25,12 @@ StageCluster triton::gpu::getStageCluster(Operation *op) {
3225
return std::nullopt;
3326
return std::make_pair(stageAttr.getInt(), clusterAttr.getInt());
3427
}
28+
29+
void triton::gpu::setStageCluster(OpBuilder &b, Operation *op,
30+
StageCluster stageCluster) {
31+
if (stageCluster) {
32+
op->setAttr(kLoopStageAttrName, b.getI32IntegerAttr(stageCluster->first));
33+
op->setAttr(kLoopClusterAttrName,
34+
b.getI32IntegerAttr(stageCluster->second));
35+
}
36+
}

0 commit comments

Comments
 (0)