-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir][amdgpu] Promote gpu.shuffle to amdgpu.dpp #155158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
e9869a7
to
4b8e73c
Compare
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-backend-amdgpu Author: Tim Gymnich (tgymnich) Changes
Full diff: https://github.com/llvm/llvm-project/pull/155158.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
index 67cef8af1e3b5..33655f27a8838 100644
--- a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
@@ -11,12 +11,13 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
-#include "mlir/Dialect/GPU/Transforms/Passes.h"
-
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/PatternMatch.h"
#include <optional>
@@ -85,7 +86,7 @@ struct PromoteShuffleToPermlanePattern
int64_t offsetValue = *offset;
if (offsetValue != 16 && offsetValue != 32)
- return rewriter.notifyMatchFailure(op, "offset must be either 15 or 31");
+ return rewriter.notifyMatchFailure(op, "offset must be either 16 or 32");
Location loc = op.getLoc();
Value res = amdgpu::PermlaneSwapOp::create(
@@ -96,13 +97,153 @@ struct PromoteShuffleToPermlanePattern
}
};
+static Value getLaneId(RewriterBase &rewriter, Location loc) {
+ auto int32Type = IntegerType::get(rewriter.getContext(), 32);
+ Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
+ Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
+ NamedAttribute noundef = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr());
+ NamedAttribute lowRange = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getRangeAttrName(),
+ LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
+ APInt(32, 32)));
+ NamedAttribute highRange = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getRangeAttrName(),
+ LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
+ APInt(32, 64)));
+ Value mbcntLo = ROCDL::MbcntLoOp::create(
+ rewriter, loc, int32Type, minus1, zero, /*arg_attrs=*/{},
+ /*res_attrs=*/
+ rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, lowRange})));
+ Value laneId = ROCDL::MbcntHiOp::create(
+ rewriter, loc, int32Type, minus1, mbcntLo, /*arg_attrs=*/{},
+ rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, highRange})));
+ return laneId;
+}
+
+/// Try to promote `gpu.shuffle` to `amdgpu.dpp`, width must be 64
+/// and offset must be a constant integer in the set {16, 32}.
+struct PromoteShuffleToDPPPattern : public OpRewritePattern<gpu::ShuffleOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(gpu::ShuffleOp op,
+ PatternRewriter &rewriter) const override {
+ std::optional<int64_t> width = getConstantIntValue(op.getWidth());
+ if (!width)
+ return rewriter.notifyMatchFailure(op,
+ "width must be a constant integer");
+ int64_t widthValue = *width;
+ if (widthValue != 4 && widthValue != 8 && widthValue != 12 &&
+ widthValue != 16 && widthValue != 32 && widthValue != 48 &&
+ widthValue != 64)
+ return rewriter.notifyMatchFailure(
+ op, "width must be 4, 8, 12, 16, 32, 48 or 64");
+
+ std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
+ if (!offset)
+ return rewriter.notifyMatchFailure(op,
+ "offset must be a constant integer");
+
+ int64_t offsetValue = *offset;
+ Location loc = op.getLoc();
+ auto int32Type = IntegerType::get(rewriter.getContext(), 32);
+
+ amdgpu::DPPPerm kind;
+ Attribute permAttr = rewriter.getUnitAttr();
+ Value srcLane;
+ Value dstLane;
+ switch (op.getMode()) {
+ case gpu::ShuffleMode::XOR: {
+ if (offsetValue != 1 && offsetValue != 2)
+ return rewriter.notifyMatchFailure(
+ op, "xor shuffle mode is only supported for offsets of 1 or 2");
+ kind = amdgpu::DPPPerm::quad_perm;
+ srcLane = getLaneId(rewriter, loc);
+ dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLane,
+ op.getOffset());
+
+ if (offsetValue == 1)
+ permAttr = rewriter.getI32ArrayAttr({1, 0, 3, 2});
+ else if (offsetValue == 2)
+ permAttr = rewriter.getI32ArrayAttr({2, 3, 0, 1});
+ break;
+ }
+ case gpu::ShuffleMode::UP: {
+ if (offsetValue != 1)
+ return rewriter.notifyMatchFailure(
+ op, "up shuffle mode is only supported for offset 1");
+ kind = amdgpu::DPPPerm::wave_shr;
+ srcLane = getLaneId(rewriter, loc);
+ dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLane,
+ op.getOffset());
+ break;
+ }
+ case gpu::ShuffleMode::DOWN: {
+ if (offsetValue != 1)
+ return rewriter.notifyMatchFailure(
+ op, "down shuffle mode is only supported for offset 1");
+ kind = amdgpu::DPPPerm::wave_shl;
+ srcLane = getLaneId(rewriter, loc);
+ dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLane,
+ op.getOffset());
+ break;
+ }
+ case gpu::ShuffleMode::IDX:
+ return rewriter.notifyMatchFailure(op,
+ "idx shuffle mode is not supported");
+ }
+
+ unsigned bankMask = 0xF;
+ if (widthValue == 4)
+ bankMask = 0x1;
+ else if (widthValue == 8)
+ bankMask = 0x3;
+ else if (widthValue == 12)
+ bankMask = 0x7;
+
+ unsigned rowMask = 0xF;
+ if (widthValue == 16)
+ rowMask = 0x1;
+ else if (widthValue == 32)
+ rowMask = 0x3;
+ else if (widthValue == 48)
+ rowMask = 0x7;
+
+ constexpr bool boundCtrl = false;
+
+ Value negwidth =
+ arith::ConstantIntOp::create(rewriter, loc, int32Type, -widthValue);
+ Value add =
+ arith::AddIOp::create(rewriter, loc, int32Type, srcLane, op.getWidth());
+ Value widthOrZeroIfOutside =
+ arith::AndIOp::create(rewriter, loc, int32Type, add, negwidth);
+ Value isActiveSrcLane =
+ arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, dstLane,
+ widthOrZeroIfOutside);
+
+ Value dpp = amdgpu::DPPOp::create(rewriter, loc, op.getResult(0).getType(),
+ op.getValue(), op.getValue(), kind,
+ permAttr, rowMask, bankMask, boundCtrl);
+ Value poison =
+ LLVM::PoisonOp::create(rewriter, loc, op.getResult(0).getType());
+
+ Value selectResult =
+ arith::SelectOp::create(rewriter, loc, isActiveSrcLane, dpp, poison);
+
+ rewriter.replaceOp(op, {selectResult, isActiveSrcLane});
+ return success();
+ }
+};
+
} // namespace
void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset) {
patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext(),
/*benefit*/ 1);
+ patterns.add<PromoteShuffleToDPPPattern>(patterns.getContext(),
+ /*benefit*/ 2);
if (maybeChipset && *maybeChipset >= kGfx950)
patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext(),
- /*benefit*/ 2);
+ /*benefit*/ 3);
}
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 71c3e9974611e..6e8741b8e3efa 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -735,13 +735,18 @@ gpu.module @test_module {
}
// CHECK-LABEL: func @gpu_shuffle_promote()
- func.func @gpu_shuffle_promote() -> (f32, f32, f32) {
+ func.func @gpu_shuffle_promote() -> (f32, f32, f32, f32, f32) {
+ // CHECK: %[[#POISON:]] = llvm.mlir.poison : f32
+ // CHECK: %[[#NEGWIDTH:]] = llvm.mlir.constant(-64 : i32) : i32
// CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
%arg0 = arith.constant 1.0 : f32
%arg1 = arith.constant 4 : i32
%arg2 = arith.constant 16 : i32
%arg3 = arith.constant 32 : i32
+ // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(64 : i32) : i32
%arg4 = arith.constant 64 : i32
+ // CHECK: %[[#C1:]] = llvm.mlir.constant(1 : i32) : i32
+ %arg5 = arith.constant 1 : i32
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
// CHECK: %[[#MASK:]] = llvm.mlir.constant(4127 : i32) : i32
// CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
@@ -757,7 +762,78 @@ gpu.module @test_module {
// CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
%shfl3, %pred3 = gpu.shuffle xor %arg0, %arg3, %arg4 : f32
- func.return %shfl1, %shfl2, %shfl3 : f32, f32, f32
+ // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
+ // CHECK: %[[#SUB:]] = llvm.sub %[[#LANE_ID]], %[[#C1]] : i32
+ // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
+ // CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
+ // CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#SUB]], %[[#AND]] : i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 312, 15, 15, false : f32
+ // CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
+ %shflu, %predu = gpu.shuffle up %arg0, %arg5, %arg4 : f32
+ // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
+ // CHECK: %[[#OP:]] = llvm.add %[[#LANE_ID]], %[[#C1]] : i32
+ // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
+ // CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
+ // CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#OP]], %[[#AND]] : i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 304, 15, 15, false : f32
+ // CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
+ %shfld, %predd = gpu.shuffle down %arg0, %arg5, %arg4 : f32
+ func.return %shfl1, %shfl2, %shfl3, %shflu, %shfld : f32, f32, f32, f32, f32
+ }
+
+ // CHECK-LABEL: func @gpu_butterfly_shuffle()
+ func.func @gpu_butterfly_shuffle() -> (f32, f32, f32, f32, f32, f32) {
+ // CHECK: %[[#POISON:]] = llvm.mlir.poison : f32
+ // CHECK: %[[#NEGWIDTH:]] = llvm.mlir.constant(-64 : i32) : i32
+ // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+ %arg0 = arith.constant 1.0 : f32
+ // CHECK: %[[#C1:]] = llvm.mlir.constant(1 : i32) : i32
+ %c1 = arith.constant 1 : i32
+ // CHECK: %[[#C2:]] = llvm.mlir.constant(2 : i32) : i32
+ %c2 = arith.constant 2 : i32
+ %c4 = arith.constant 4 : i32
+ %c8 = arith.constant 8 : i32
+ %c16 = arith.constant 16 : i32
+ %c32 = arith.constant 32 : i32
+ // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(64 : i32) : i32
+ %c64 = arith.constant 64 : i32
+ // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
+ // CHECK: %[[#XOR:]] = llvm.xor %[[#LANE_ID]], %[[#C1]] : i32
+ // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
+ // CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
+ // CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#XOR]], %[[#AND]] : i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 177, 15, 15, false : f32
+ // CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
+ %shfl1, %pred1 = gpu.shuffle xor %arg0, %c1, %c64 : f32
+ // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
+ // CHECK: %[[#XOR:]] = llvm.xor %[[#LANE_ID]], %[[#C2]] : i32
+ // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
+ // CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
+ // CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#XOR]], %[[#AND]] : i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 78, 15, 15, false : f32
+ // CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
+ %shfl2, %pred2 = gpu.shuffle xor %arg0, %c2, %c64 : f32
+ // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+ // CHECK: %[[#MASK:]] = llvm.mlir.constant(4127 : i32) : i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
+ %shfl3, %pred3 = gpu.shuffle xor %arg0, %c4, %c64 : f32
+ // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+ // CHECK: %[[#MASK:]] = llvm.mlir.constant(8223 : i32) : i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
+ %shfl4, %pred4 = gpu.shuffle xor %arg0, %c8, %c64 : f32
+ // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.permlane16.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
+ // CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
+ %shfl5, %pred5 = gpu.shuffle xor %arg0, %c16, %c64 : f32
+ // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.permlane32.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
+ // CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
+ %shfl6, %pred6 = gpu.shuffle xor %arg0, %c32, %c64 : f32
+ func.return %shfl1, %shfl2, %shfl3, %shfl4, %shfl5, %shfl6 : f32, f32, f32, f32, f32, f32
}
// CHECK-LABEL: func @gpu_shuffle_vec
|
@llvm/pr-subscribers-mlir-gpu Author: Tim Gymnich (tgymnich) Changes
Full diff: https://github.com/llvm/llvm-project/pull/155158.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
index 67cef8af1e3b5..33655f27a8838 100644
--- a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
@@ -11,12 +11,13 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
-#include "mlir/Dialect/GPU/Transforms/Passes.h"
-
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/PatternMatch.h"
#include <optional>
@@ -85,7 +86,7 @@ struct PromoteShuffleToPermlanePattern
int64_t offsetValue = *offset;
if (offsetValue != 16 && offsetValue != 32)
- return rewriter.notifyMatchFailure(op, "offset must be either 15 or 31");
+ return rewriter.notifyMatchFailure(op, "offset must be either 16 or 32");
Location loc = op.getLoc();
Value res = amdgpu::PermlaneSwapOp::create(
@@ -96,13 +97,153 @@ struct PromoteShuffleToPermlanePattern
}
};
+static Value getLaneId(RewriterBase &rewriter, Location loc) {
+ auto int32Type = IntegerType::get(rewriter.getContext(), 32);
+ Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
+ Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
+ NamedAttribute noundef = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr());
+ NamedAttribute lowRange = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getRangeAttrName(),
+ LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
+ APInt(32, 32)));
+ NamedAttribute highRange = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getRangeAttrName(),
+ LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
+ APInt(32, 64)));
+ Value mbcntLo = ROCDL::MbcntLoOp::create(
+ rewriter, loc, int32Type, minus1, zero, /*arg_attrs=*/{},
+ /*res_attrs=*/
+ rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, lowRange})));
+ Value laneId = ROCDL::MbcntHiOp::create(
+ rewriter, loc, int32Type, minus1, mbcntLo, /*arg_attrs=*/{},
+ rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, highRange})));
+ return laneId;
+}
+
+/// Try to promote `gpu.shuffle` to `amdgpu.dpp`, width must be 64
+/// and offset must be a constant integer in the set {16, 32}.
+struct PromoteShuffleToDPPPattern : public OpRewritePattern<gpu::ShuffleOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(gpu::ShuffleOp op,
+ PatternRewriter &rewriter) const override {
+ std::optional<int64_t> width = getConstantIntValue(op.getWidth());
+ if (!width)
+ return rewriter.notifyMatchFailure(op,
+ "width must be a constant integer");
+ int64_t widthValue = *width;
+ if (widthValue != 4 && widthValue != 8 && widthValue != 12 &&
+ widthValue != 16 && widthValue != 32 && widthValue != 48 &&
+ widthValue != 64)
+ return rewriter.notifyMatchFailure(
+ op, "width must be 4, 8, 12, 16, 32, 48 or 64");
+
+ std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
+ if (!offset)
+ return rewriter.notifyMatchFailure(op,
+ "offset must be a constant integer");
+
+ int64_t offsetValue = *offset;
+ Location loc = op.getLoc();
+ auto int32Type = IntegerType::get(rewriter.getContext(), 32);
+
+ amdgpu::DPPPerm kind;
+ Attribute permAttr = rewriter.getUnitAttr();
+ Value srcLane;
+ Value dstLane;
+ switch (op.getMode()) {
+ case gpu::ShuffleMode::XOR: {
+ if (offsetValue != 1 && offsetValue != 2)
+ return rewriter.notifyMatchFailure(
+ op, "xor shuffle mode is only supported for offsets of 1 or 2");
+ kind = amdgpu::DPPPerm::quad_perm;
+ srcLane = getLaneId(rewriter, loc);
+ dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLane,
+ op.getOffset());
+
+ if (offsetValue == 1)
+ permAttr = rewriter.getI32ArrayAttr({1, 0, 3, 2});
+ else if (offsetValue == 2)
+ permAttr = rewriter.getI32ArrayAttr({2, 3, 0, 1});
+ break;
+ }
+ case gpu::ShuffleMode::UP: {
+ if (offsetValue != 1)
+ return rewriter.notifyMatchFailure(
+ op, "up shuffle mode is only supported for offset 1");
+ kind = amdgpu::DPPPerm::wave_shr;
+ srcLane = getLaneId(rewriter, loc);
+ dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLane,
+ op.getOffset());
+ break;
+ }
+ case gpu::ShuffleMode::DOWN: {
+ if (offsetValue != 1)
+ return rewriter.notifyMatchFailure(
+ op, "down shuffle mode is only supported for offset 1");
+ kind = amdgpu::DPPPerm::wave_shl;
+ srcLane = getLaneId(rewriter, loc);
+ dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLane,
+ op.getOffset());
+ break;
+ }
+ case gpu::ShuffleMode::IDX:
+ return rewriter.notifyMatchFailure(op,
+ "idx shuffle mode is not supported");
+ }
+
+ unsigned bankMask = 0xF;
+ if (widthValue == 4)
+ bankMask = 0x1;
+ else if (widthValue == 8)
+ bankMask = 0x3;
+ else if (widthValue == 12)
+ bankMask = 0x7;
+
+ unsigned rowMask = 0xF;
+ if (widthValue == 16)
+ rowMask = 0x1;
+ else if (widthValue == 32)
+ rowMask = 0x3;
+ else if (widthValue == 48)
+ rowMask = 0x7;
+
+ constexpr bool boundCtrl = false;
+
+ Value negwidth =
+ arith::ConstantIntOp::create(rewriter, loc, int32Type, -widthValue);
+ Value add =
+ arith::AddIOp::create(rewriter, loc, int32Type, srcLane, op.getWidth());
+ Value widthOrZeroIfOutside =
+ arith::AndIOp::create(rewriter, loc, int32Type, add, negwidth);
+ Value isActiveSrcLane =
+ arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, dstLane,
+ widthOrZeroIfOutside);
+
+ Value dpp = amdgpu::DPPOp::create(rewriter, loc, op.getResult(0).getType(),
+ op.getValue(), op.getValue(), kind,
+ permAttr, rowMask, bankMask, boundCtrl);
+ Value poison =
+ LLVM::PoisonOp::create(rewriter, loc, op.getResult(0).getType());
+
+ Value selectResult =
+ arith::SelectOp::create(rewriter, loc, isActiveSrcLane, dpp, poison);
+
+ rewriter.replaceOp(op, {selectResult, isActiveSrcLane});
+ return success();
+ }
+};
+
} // namespace
void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset) {
patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext(),
/*benefit*/ 1);
+ patterns.add<PromoteShuffleToDPPPattern>(patterns.getContext(),
+ /*benefit*/ 2);
if (maybeChipset && *maybeChipset >= kGfx950)
patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext(),
- /*benefit*/ 2);
+ /*benefit*/ 3);
}
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 71c3e9974611e..6e8741b8e3efa 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -735,13 +735,18 @@ gpu.module @test_module {
}
// CHECK-LABEL: func @gpu_shuffle_promote()
- func.func @gpu_shuffle_promote() -> (f32, f32, f32) {
+ func.func @gpu_shuffle_promote() -> (f32, f32, f32, f32, f32) {
+ // CHECK: %[[#POISON:]] = llvm.mlir.poison : f32
+ // CHECK: %[[#NEGWIDTH:]] = llvm.mlir.constant(-64 : i32) : i32
// CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
%arg0 = arith.constant 1.0 : f32
%arg1 = arith.constant 4 : i32
%arg2 = arith.constant 16 : i32
%arg3 = arith.constant 32 : i32
+ // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(64 : i32) : i32
%arg4 = arith.constant 64 : i32
+ // CHECK: %[[#C1:]] = llvm.mlir.constant(1 : i32) : i32
+ %arg5 = arith.constant 1 : i32
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
// CHECK: %[[#MASK:]] = llvm.mlir.constant(4127 : i32) : i32
// CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
@@ -757,7 +762,78 @@ gpu.module @test_module {
// CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
%shfl3, %pred3 = gpu.shuffle xor %arg0, %arg3, %arg4 : f32
- func.return %shfl1, %shfl2, %shfl3 : f32, f32, f32
+ // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
+ // CHECK: %[[#SUB:]] = llvm.sub %[[#LANE_ID]], %[[#C1]] : i32
+ // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
+ // CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
+ // CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#SUB]], %[[#AND]] : i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 312, 15, 15, false : f32
+ // CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
+ %shflu, %predu = gpu.shuffle up %arg0, %arg5, %arg4 : f32
+ // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
+ // CHECK: %[[#OP:]] = llvm.add %[[#LANE_ID]], %[[#C1]] : i32
+ // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
+ // CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
+ // CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#OP]], %[[#AND]] : i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 304, 15, 15, false : f32
+ // CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
+ %shfld, %predd = gpu.shuffle down %arg0, %arg5, %arg4 : f32
+ func.return %shfl1, %shfl2, %shfl3, %shflu, %shfld : f32, f32, f32, f32, f32
+ }
+
+ // CHECK-LABEL: func @gpu_butterfly_shuffle()
+ func.func @gpu_butterfly_shuffle() -> (f32, f32, f32, f32, f32, f32) {
+ // CHECK: %[[#POISON:]] = llvm.mlir.poison : f32
+ // CHECK: %[[#NEGWIDTH:]] = llvm.mlir.constant(-64 : i32) : i32
+ // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+ %arg0 = arith.constant 1.0 : f32
+ // CHECK: %[[#C1:]] = llvm.mlir.constant(1 : i32) : i32
+ %c1 = arith.constant 1 : i32
+ // CHECK: %[[#C2:]] = llvm.mlir.constant(2 : i32) : i32
+ %c2 = arith.constant 2 : i32
+ %c4 = arith.constant 4 : i32
+ %c8 = arith.constant 8 : i32
+ %c16 = arith.constant 16 : i32
+ %c32 = arith.constant 32 : i32
+ // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(64 : i32) : i32
+ %c64 = arith.constant 64 : i32
+ // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
+ // CHECK: %[[#XOR:]] = llvm.xor %[[#LANE_ID]], %[[#C1]] : i32
+ // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
+ // CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
+ // CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#XOR]], %[[#AND]] : i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 177, 15, 15, false : f32
+ // CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
+ %shfl1, %pred1 = gpu.shuffle xor %arg0, %c1, %c64 : f32
+ // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
+ // CHECK: %[[#XOR:]] = llvm.xor %[[#LANE_ID]], %[[#C2]] : i32
+ // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
+ // CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
+ // CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#XOR]], %[[#AND]] : i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 78, 15, 15, false : f32
+ // CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
+ %shfl2, %pred2 = gpu.shuffle xor %arg0, %c2, %c64 : f32
+ // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+ // CHECK: %[[#MASK:]] = llvm.mlir.constant(4127 : i32) : i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
+ %shfl3, %pred3 = gpu.shuffle xor %arg0, %c4, %c64 : f32
+ // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+ // CHECK: %[[#MASK:]] = llvm.mlir.constant(8223 : i32) : i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
+ %shfl4, %pred4 = gpu.shuffle xor %arg0, %c8, %c64 : f32
+ // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.permlane16.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
+ // CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
+ %shfl5, %pred5 = gpu.shuffle xor %arg0, %c16, %c64 : f32
+ // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
+ // CHECK: %[[#PERMUTE:]] = rocdl.permlane32.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
+ // CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
+ %shfl6, %pred6 = gpu.shuffle xor %arg0, %c32, %c64 : f32
+ func.return %shfl1, %shfl2, %shfl3, %shfl4, %shfl5, %shfl6 : f32, f32, f32, f32, f32, f32
}
// CHECK-LABEL: func @gpu_shuffle_vec
|
I haven't checked the patch in full, but, can you add a The issue is that having these transforms in |
4b8e73c
to
46b7fd3
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks OK, I haven't reviewed the logic carefully but the overall structure seems fine
Later on, it would be nice to move this code out of gpu transforms to clean up build dependencies
promote gpu.shuffle %src xor {1,2} {4,8,12,16,32,48,64}
toamdgpu.dpp quad_perm %src, %mask
promote gpu.shuffle %src up 1 {4,8,12,16,32,48,64}
toamdgpu.dpp wave_shr %src, %mask
promote gpu.shuffle %src down 1 {4,8,12,16,32,48,64}
toamdgpu.dpp wave_shl %src, %mask