Skip to content

Commit 645ee06

Browse files
authored
[Rebase]: ttg.ViewOp->ttg.ReshapeOp in commit: 607190f… (#131)
* [Rebase]: ttg.ViewOp->ttg.ReshapeOp in commit: 607190f * Update TODO failed cases: test_convertmma2mma
1 parent 055bcce commit 645ee06

File tree

6 files changed

+70
-11
lines changed

6 files changed

+70
-11
lines changed

.github/tests/triton_todo_failure_tests.log

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,43 @@ test_subprocess.py::test_print[device_print_large-int32]
155155
test_subprocess.py::test_print[device_print_multiple_args-int32]
156156
test_subprocess.py::test_print[print_multiple_args-int32]
157157
test_subprocess.py::test_print[print_no_arg-int32]
158+
test_core.py::test_convertmma2mma[mma_pair0-float16-64-1]
159+
test_core.py::test_convertmma2mma[mma_pair0-float16-1-64]
160+
test_core.py::test_convertmma2mma[mma_pair0-float16-64-64]
161+
test_core.py::test_convertmma2mma[mma_pair0-float16-128-128]
162+
test_core.py::test_convertmma2mma[mma_pair0-float16-256-256]
163+
test_core.py::test_convertmma2mma[mma_pair1-float16-64-1]
164+
test_core.py::test_convertmma2mma[mma_pair1-float16-1-64]
165+
test_core.py::test_convertmma2mma[mma_pair1-float16-64-64]
166+
test_core.py::test_convertmma2mma[mma_pair1-float16-128-128]
167+
test_core.py::test_convertmma2mma[mma_pair1-float16-256-256]
168+
test_core.py::test_convertmma2mma[mma_pair2-float16-64-1]
169+
test_core.py::test_convertmma2mma[mma_pair2-float16-1-64]
170+
test_core.py::test_convertmma2mma[mma_pair2-float16-64-64]
171+
test_core.py::test_convertmma2mma[mma_pair2-float16-128-128]
172+
test_core.py::test_convertmma2mma[mma_pair2-float16-256-256]
173+
test_core.py::test_convertmma2mma[mma_pair3-float16-64-1]
174+
test_core.py::test_convertmma2mma[mma_pair3-float16-1-64]
175+
test_core.py::test_convertmma2mma[mma_pair3-float16-64-64]
176+
test_core.py::test_convertmma2mma[mma_pair3-float16-128-128]
177+
test_core.py::test_convertmma2mma[mma_pair3-float16-256-256]
178+
test_core.py::test_convertmma2mma[mma_pair0-float16-64-1]
179+
test_core.py::test_convertmma2mma[mma_pair0-float16-1-64]
180+
test_core.py::test_convertmma2mma[mma_pair0-float16-64-64]
181+
test_core.py::test_convertmma2mma[mma_pair0-float16-128-128]
182+
test_core.py::test_convertmma2mma[mma_pair0-float16-256-256]
183+
test_core.py::test_convertmma2mma[mma_pair1-float16-64-1]
184+
test_core.py::test_convertmma2mma[mma_pair1-float16-1-64]
185+
test_core.py::test_convertmma2mma[mma_pair1-float16-64-64]
186+
test_core.py::test_convertmma2mma[mma_pair1-float16-128-128]
187+
test_core.py::test_convertmma2mma[mma_pair1-float16-256-256]
188+
test_core.py::test_convertmma2mma[mma_pair2-float16-64-1]
189+
test_core.py::test_convertmma2mma[mma_pair2-float16-1-64]
190+
test_core.py::test_convertmma2mma[mma_pair2-float16-64-64]
191+
test_core.py::test_convertmma2mma[mma_pair2-float16-128-128]
192+
test_core.py::test_convertmma2mma[mma_pair2-float16-256-256]
193+
test_core.py::test_convertmma2mma[mma_pair3-float16-64-1]
194+
test_core.py::test_convertmma2mma[mma_pair3-float16-1-64]
195+
test_core.py::test_convertmma2mma[mma_pair3-float16-64-64]
196+
test_core.py::test_convertmma2mma[mma_pair3-float16-128-128]
197+
test_core.py::test_convertmma2mma[mma_pair3-float16-256-256]

lib/Conversion/TritonGPUToSPIRV/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
add_mlir_conversion_library(TritonGPUToSPIRV
22
TritonGPUToSPIRV.cpp
3-
ViewOpToSPIRV.cpp
3+
ReshapeOpToSPIRV.cpp
44
ElementwiseOpToSPIRV.cpp
55
TritonGPUToSPIRVPass.cpp
66
ConvertLayoutOpToSPIRV/SharedToDotOperandFMA.cpp

lib/Conversion/TritonGPUToSPIRV/ViewOpToSPIRV.cpp renamed to lib/Conversion/TritonGPUToSPIRV/ReshapeOpToSPIRV.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "ViewOpToSPIRV.h"
1+
#include "ReshapeOpToSPIRV.h"
22

33
using namespace mlir;
44
using namespace mlir::triton;
@@ -145,16 +145,35 @@ struct CatOpSPIRVConversion
145145
}
146146
};
147147

148-
struct ViewOpSPIRVConversion : public ConvertTritonGPUOpToSPIRVPattern<ViewOp> {
149-
using OpAdaptor = typename ViewOp::Adaptor;
148+
struct ReshapeOpSPIRVConversion
149+
: public ConvertTritonGPUOpToSPIRVPattern<ReshapeOp> {
150+
using OpAdaptor = typename ReshapeOp::Adaptor;
150151
using ConvertTritonGPUOpToSPIRVPattern<
151-
ViewOp>::ConvertTritonGPUOpToSPIRVPattern;
152+
ReshapeOp>::ConvertTritonGPUOpToSPIRVPattern;
152153

153154
LogicalResult
154-
matchAndRewrite(ViewOp op, OpAdaptor adaptor,
155+
matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
155156
ConversionPatternRewriter &rewriter) const override {
156157
Location loc = op->getLoc();
157158
auto resultTy = op.getType().template cast<RankedTensorType>();
159+
auto srcTy = op.getSrc().getType().template cast<RankedTensorType>();
160+
if (!op.getAllowReorder()) {
161+
// Only support trivial block layouts for now.
162+
auto mod = op->getParentOfType<ModuleOp>();
163+
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
164+
int threadsPerWarp =
165+
triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
166+
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
167+
assert(resultTy.getEncoding() == triton::gpu::getDefaultBlockedEncoding(
168+
op.getContext(), resultTy.getShape(),
169+
numWarps, threadsPerWarp, numCTAs) &&
170+
"ReshapeOp lowering only support block encoding right now.");
171+
assert(srcTy.getEncoding() == triton::gpu::getDefaultBlockedEncoding(
172+
op.getContext(), srcTy.getShape(),
173+
numWarps, threadsPerWarp, numCTAs) &&
174+
"ReshapeOp lowering only support block encoding right now.");
175+
}
176+
158177
auto vals = this->getTypeConverter()->unpackLLElements(
159178
loc, adaptor.getSrc(), rewriter, op.getOperand().getType());
160179
Value ret =
@@ -235,7 +254,7 @@ void populateViewOpToSPIRVPatterns(
235254
mlir::ModuleAllocation *allocation, mlir::Value smem,
236255
mlir::PatternBenefit benefit,
237256
std::map<std::string, int> &computeCapability) {
238-
patterns.add<ViewOpSPIRVConversion>(typeConverter, context, benefit);
257+
patterns.add<ReshapeOpSPIRVConversion>(typeConverter, context, benefit);
239258
patterns.add<ExpandDimsOpSPIRVConversion>(typeConverter, context, benefit);
240259
patterns.add<SplatOpSPIRVConversion>(typeConverter, context, benefit);
241260
patterns.add<ArithConstantSplatOpSPIRVConversion>(

lib/Conversion/TritonGPUToSPIRV/ViewOpToSPIRV.h renamed to lib/Conversion/TritonGPUToSPIRV/ReshapeOpToSPIRV.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef TRITON_VIEWOPTOSPIRV_H
2-
#define TRITON_VIEWOPTOSPIRV_H
1+
#ifndef TRITON_RESHAPEOPTOSPIRV_H
2+
#define TRITON_RESHAPEOPTOSPIRV_H
33

44
#include "TritonGPUToSPIRVBase.h"
55
#include "triton/Analysis/Membar.h"

lib/Conversion/TritonGPUToSPIRV/TritonGPUToSPIRVPass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
#include "ElementwiseOpToSPIRV.h"
2323
#include "LoadStoreOpToSPIRV.h"
2424
#include "ReduceOpToSPIRV.h"
25+
#include "ReshapeOpToSPIRV.h"
2526
#include "ScanOpToSPIRV.h"
2627
#include "TritonGPUToSPIRV.h"
2728
#include "TypeConverter.h"
28-
#include "ViewOpToSPIRV.h"
2929

3030
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
3131

triton_hash.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
f168b148ecdd067205c6066bc3e6939fd67ab893
1+
607190fda19daa9ecbccd7b10ed49d24b7916fd2

0 commit comments

Comments
 (0)