
Commit b7f889a

[mlir][AMDGPU] Add canonicalizer for folding casts into gather_to_lds (#150503)
1 parent 42b101d commit b7f889a
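
For illustration only, here is a minimal sketch of the fold this canonicalizer performs, written against the op forms that appear in the tests below (the value names %global, %lds, %i, %j are hypothetical): a memref.cast that only erases static shape information is skipped, and amdgpu.gather_to_lds consumes the original, more statically typed memref directly.

  // Before canonicalization: the gather reads through a shape-erasing cast.
  %cast = memref.cast %global : memref<128x72xf32, 1> to memref<?x?xf32, 1>
  amdgpu.gather_to_lds %cast[%i, %j], %lds[%i, %j]
      : f32, memref<?x?xf32, 1>, memref<64x64xf32, 3>

  // After canonicalization: the cast is bypassed and the static type is used.
  amdgpu.gather_to_lds %global[%i, %j], %lds[%i, %j]
      : f32, memref<128x72xf32, 1>, memref<64x64xf32, 3>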

File tree

3 files changed: +70, -0 lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 1 addition & 0 deletions
@@ -921,6 +921,7 @@ def AMDGPU_GatherToLDSOp :
     $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
   }];
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def AMDGPU_TransposeLoadOp :

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 40 additions & 0 deletions
@@ -510,6 +510,10 @@ LogicalResult DPPOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GatherToLDSOp
+//===----------------------------------------------------------------------===//
+
 LogicalResult GatherToLDSOp::verify() {
   MemRefType srcType = cast<MemRefType>(getSrc().getType());
   MemRefType dstType = cast<MemRefType>(getDst().getType());
@@ -546,6 +550,42 @@ LogicalResult GatherToLDSOp::verify() {
   return success();
 }
 
+namespace {
+/// If the source/target of a GatherToLDSOp is a CastOp that only removes static
+/// information or changes layout, the cast can be skipped.
+struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    bool modified = false;
+    auto foldCast = [&](OpOperand &operand) {
+      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
+        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
+          rewriter.modifyOpInPlace(gatherOp,
+                                   [&] { operand.assign(castOp.getSource()); });
+          modified = true;
+        }
+      }
+    };
+
+    foldCast(gatherOp.getSrcMutable());
+    foldCast(gatherOp.getDstMutable());
+
+    return success(modified);
+  }
+};
+} // namespace
+
+void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<FoldGatherToLDSOfCast>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeLoadOp
+//===----------------------------------------------------------------------===//
+
 LogicalResult TransposeLoadOp::verify() {
   MemRefType srcType = cast<MemRefType>(getSrc().getType());
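
As a hedged aside (not part of the commit), the memref::CastOp::canFoldIntoConsumerOp guard above only succeeds when the cast moves toward a more dynamic type, so a cast that adds static shape information is left in place; roughly:

  // Hypothetical case the pattern does not touch: the cast refines dynamic
  // dimensions to static ones, so substituting its source would drop that
  // refinement, and canFoldIntoConsumerOp rejects the fold.
  %cast = memref.cast %global : memref<?x?xf32, 1> to memref<128x72xf32, 1>
  amdgpu.gather_to_lds %cast[%i, %j], %lds[%i, %j]
      : f32, memref<128x72xf32, 1>, memref<64x64xf32, 3>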

mlir/test/Dialect/AMDGPU/canonicalize.mlir

Lines changed: 29 additions & 0 deletions
@@ -130,3 +130,32 @@ func.func @dead_atomic_add(%arg0: memref<4xf32>, %arg1: f32) {
   amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %arg1 -> %arg0[%c4_i32] : f32 -> memref<4xf32>, i32
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_gather_to_lds_of_cast
+func.func @fold_gather_to_lds_of_cast(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) {
+// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1>
+  %c0 = arith.constant 0 : index
+  %0 = memref.cast %global : memref<128x72xf32, 1> to memref<?x?xf32, 1>
+  // CHECK: amdgpu.gather_to_lds %[[GLOBAL]]
+  // CHECK-SAME: : f32, memref<128x72xf32, 1>
+  amdgpu.gather_to_lds %0[%c0, %c0], %lds[%c0, %c0]
+      : f32, memref<?x?xf32, 1>, memref<64x64xf32, 3>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_gather_to_lds_of_cast_dest
+func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds: memref<64x64xf32, 3>) {
+// CHECK-SAME: %[[GLOBAL:[A-Za-z0-9]+]]: memref<128x72xf32, 1>
+// CHECK-SAME: %[[LDS:[A-Za-z0-9]+]]: memref<64x64xf32, 3>
+  %c0 = arith.constant 0 : index
+  %0 = memref.cast %lds : memref<64x64xf32, 3> to memref<?x?xf32, 3>
+  // CHECK: amdgpu.gather_to_lds %[[GLOBAL]][{{.*}}], %[[LDS]]
+  // CHECK-SAME: : f32, memref<128x72xf32, 1>, memref<64x64xf32, 3>
+  amdgpu.gather_to_lds %global[%c0, %c0], %0[%c0, %c0]
+      : f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
+  func.return
+}
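
The diff does not show the file's RUN line; these checks are presumably driven by the existing one at the top of canonicalize.mlir, along the lines of the sketch below (the exact flags are an assumption, not part of this commit):

  // RUN: mlir-opt %s --split-input-file --canonicalize | FileCheck %s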
