Skip to content

Commit 3ef7fde

Browse files
authored
[WS] enable e2e aref (#8262)
* enable the e2e aref flow
* remove `load-mma-specialization`
* update `partition-scheduling` to place unannotated ops into the partition of their users
* remove the canonicalization passes from the `automatic-warp-specialization` pass
* improve `assign-stage-phase` to track whether an `aref.buffer` comes from a `put` or a `get`, and do not thread stage/phase through control flow where unnecessary
* `aref-tmem-insertion` doesn't require an annotation on a `tmem_alloc` without a source, and if a `tmem_alloc` with a source is in the same partition as its user, we do not use arefs for it.

~Pending issue with `schedule-loops` and this test https://github.com/triton-lang/triton/blob/main/test/TritonGPU/automatic-warp-specialization.mlir#L118 when there is an if-stmt (being discussed). Marking the PR as draft for now.~
1 parent c6683c8 commit 3ef7fde

File tree

8 files changed

+209
-2625
lines changed

8 files changed

+209
-2625
lines changed

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -176,24 +176,6 @@ def TritonGPUPartitionScheduling : Pass<"tritongpu-partition-scheduling", "mlir:
176176
}];
177177
}
178178

179-
def TritonGPULoadMMASpecialization : Pass<"tritongpu-load-mma-specialization", "mlir::ModuleOp"> {
180-
let summary = "load MMA specialization";
181-
182-
let description = [{
183-
The `tritongpu-load-mma-specialization` pass looks for matmul loops in the
184-
module and attempts to create a partition schedule, separating async loads
185-
and async MMAs into separate partitions.
186-
}];
187-
188-
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
189-
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
190-
191-
let options = [
192-
Option<"numStages", "num-stages", "int32_t", /*default*/"3",
193-
"number of pipeline stages">
194-
];
195-
}
196-
197179
def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> {
198180
let summary = "3xTF32 trick";
199181

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ add_triton_library(TritonGPUTransforms
2929
Utility.cpp
3030
LayoutPropagationUtility.cpp
3131
WarpSpecialization/AutomaticWarpSpecialization.cpp
32-
WarpSpecialization/LoadMMASpecialization.cpp
3332
WarpSpecialization/Partition.cpp
3433
WarpSpecialization/OptimizePartitionWarps.cpp
3534
WarpSpecialization/PartitionBuilder.cpp

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/AutomaticWarpSpecialization.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,7 @@ void AutomaticWarpSpecialization::runOnOperation() {
3636
OpPassManager pm;
3737
pm.addPass(createTritonGPUPartitionScheduling());
3838
pm.addPass(createNVWSInsertAref());
39-
#if 0
4039
pm.addPass(createNVWSInsertTmemAref());
41-
#else
42-
pm.addPass(createTritonGPULoadMMASpecialization({numStages}));
43-
#endif
4440
pm.addPass(createTritonGPURewritePartitionDependencies());
4541
// `int-range-optimizations` and SCCP are good at cleaning up loop arithmetic.
4642
// FIXME: Re-enable integer range analysis once it is fixed.
@@ -50,19 +46,6 @@ void AutomaticWarpSpecialization::runOnOperation() {
5046
pm.addPass(createNVWSLowerAref({numStages}));
5147
pm.addPass(createTritonGPUPartitionLoops());
5248
pm.addPass(createNVWSLowerWarpGroup());
53-
if (failed(runPipeline(pm, getOperation())))
54-
return signalPassFailure();
55-
56-
// Cleanup code generated by warp specialization.
57-
RewritePatternSet patterns(&getContext());
58-
populateForOpDeadArgumentElimination(patterns);
59-
scf::ForOp::getCanonicalizationPatterns(patterns, &getContext());
60-
scf::IfOp::getCanonicalizationPatterns(patterns, &getContext());
61-
WarpSpecializeOp::getCanonicalizationPatterns(patterns, &getContext());
62-
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
63-
return signalPassFailure();
64-
65-
pm.clear();
6649
pm.addPass(createTritonGPUOptimizePartitionWarps());
6750
pm.addPass(createTritonGPUScheduleLoops());
6851
if (failed(runPipeline(pm, getOperation())))

0 commit comments

Comments
 (0)