59 changes: 43 additions & 16 deletions test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir
@@ -1,5 +1,10 @@
// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=0" --verify-diagnostics | FileCheck %s --check-prefixes MFMA0,CHECK
// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=16" --verify-diagnostics | FileCheck %s --check-prefixes MFMA16,CHECK
// RUN: split-file %s %t
// RUN: cat %t/common.mlir %t/mfma0.mlir > %t/run-mfma0.mlir
// RUN: triton-opt %t/run-mfma0.mlir -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=0" --verify-diagnostics | FileCheck %t/run-mfma0.mlir --check-prefixes=MFMA0,CHECK
// RUN: cat %t/common.mlir %t/mfma16.mlir > %t/run-mfma16.mlir
// RUN: triton-opt %t/run-mfma16.mlir -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=16" --verify-diagnostics | FileCheck %t/run-mfma16.mlir --check-prefixes=MFMA16,CHECK

//--- common.mlir
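// As set up by the RUN lines above, lit's split-file cuts this file at each
// `//--- <name>` marker: this common part is prepended to both variants,
// while mfma0.mlir and mfma16.mlir carry the cases whose expected output
// differs between the matrix-instruction-size=0 and =16 runs.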

#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}>
// CHECK-LABEL: mfma_dot_fp8e5m2_fp8e4m3fn
@@ -64,6 +69,28 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n

// -----

// MFMA0-NOT: amd_mfma
// MFMA16-NOT: amd_mfma
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}>
// CHECK-LABEL: mfma_dot_small_k
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @mfma_dot_small_k(
%arg0: tensor<128x4xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
%arg1: tensor<4x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
%arg2: tensor<128x256x!tt.ptr<f32>, #blocked> ) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
// expected-remark @+2 {{Unable to select MFMA intrinsic}}
// expected-remark @+1 {{Attempting to map dot operation to FMA intrinsic.}}
%1 = tt.dot %arg0, %arg1, %cst : tensor<128x4xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<4x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
tt.store %arg2, %1 : tensor<128x256x!tt.ptr<f32>, #blocked>
tt.return
}
}

// -----

//--- mfma0.mlir

// MFMA0-NOT: amd_mfma
// MFMA16: #mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [1, 2], instrShape = [16, 16, 16], isTransposed = true}>
// CHECK-LABEL: small_m_size_fma
@@ -74,26 +101,26 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
%b: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>)
-> tensor<1x128xf32, #blocked> {
%zero_f32 = arith.constant dense<0.000000e+00> : tensor<1x128xf32, #blocked>
// expected-remark @+2 {{Unable to select MFMA intrinsic}}
// expected-remark @+1 {{Attempting to map dot operation to FMA intrinsic.}}
%result = tt.dot %a, %b, %zero_f32 : tensor<1x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<1x128xf32, #blocked>
tt.return %result : tensor<1x128xf32, #blocked>
}
}


// -----
//--- mfma16.mlir

// MFMA0-NOT: amd_mfma
// MFMA16-NOT: amd_mfma
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}>
// CHECK-LABEL: mfma_dot_small_k
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @mfma_dot_small_k(
%arg0: tensor<128x4xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
%arg1: tensor<4x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
%arg2: tensor<128x256x!tt.ptr<f32>, #blocked> ) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
%1 = tt.dot %arg0, %arg1, %cst : tensor<128x4xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<4x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
tt.store %arg2, %1 : tensor<128x256x!tt.ptr<f32>, #blocked>
tt.return
// MFMA16: #mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [1, 2], instrShape = [16, 16, 16], isTransposed = true}>
// CHECK-LABEL: small_m_size_fma
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 64], warpsPerCTA = [1, 2], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 64 : i32} {
tt.func @small_m_size_fma(
%a: tensor<1x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
%b: tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>)
-> tensor<1x128xf32, #blocked> {
%zero_f32 = arith.constant dense<0.000000e+00> : tensor<1x128xf32, #blocked>
%result = tt.dot %a, %b, %zero_f32 : tensor<1x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<1x128xf32, #blocked>
tt.return %result : tensor<1x128xf32, #blocked>
}
}
@@ -199,16 +199,37 @@ chooseMfmaInstruction(Location loc, int mfmaVersion, RankedTensorType cType,
aElemType, bElemType, withScale, allowXF32);

// Fallback to FMA if the M/N dim is not supported by MFMA.
if (failed(maybeMfmaIntrinsic))
if (failed(maybeMfmaIntrinsic)) {
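// The emitted remark reads roughly as follows (values are illustrative only):
//   Unable to select MFMA intrinsic for the request: version=3,
//   result-shape=(128x256), selected-tiles=(32x32), inputKSize=4,
//   aElemType=f16, bElemType=f16, withScale=false, allowXF32=false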
mlir::emitRemark(loc) << "Unable to select MFMA intrinsic for the request: "
<< "version=" << mfmaVersion << ", result-shape=("
<< M << "x" << N << "), selected-tiles=(" << mDim
<< "x" << nDim << "), inputKSize=" << inputKSize
<< ", aElemType=" << aElemType
<< ", bElemType=" << bElemType
<< ", withScale=" << (withScale ? "true" : "false")
<< ", allowXF32=" << (allowXF32 ? "true" : "false")
<< (enforcedNonKDim != 0
? (llvm::Twine(", enforcedNonKDim=") +
llvm::Twine(enforcedNonKDim))
.str()
: "");
return failure();
}

kDim = maybeMfmaIntrinsic->kDim;
assert(kDim != 0);
assert(enforcedNonKDim != 0 || (M % mDim == 0 && N % nDim == 0));
// If inputKSize % kDim != 0 (including the case where inputKSize < kDim),
// this layout will introduce data duplication.
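// For example (hypothetical numbers): kDim = 16 with inputKSize = 4 gives
// 4 % 16 != 0, so the 4 K-elements would have to be duplicated to fill the
// intrinsic's 16-wide k-dimension; reject the intrinsic instead.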
if (inputKSize % kDim != 0)
if (inputKSize % kDim != 0) {
mlir::emitRemark(loc)
<< "Unable to select MFMA intrinsic '" << maybeMfmaIntrinsic->name
<< "': tile k-dimension size inputKSize=" << inputKSize
<< " is not a multiple of the intrinsic k-dimension size kDim=" << kDim
<< ". Using this intrinsic would introduce data duplication.";
return failure();
}
return maybeMfmaIntrinsic;
}

@@ -548,11 +569,15 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim, withScale);
if (failed(mfmaInstr)) {
if (!withScale) {
return failure();
return rewriter.notifyMatchFailure(
dotOp,
"Unable to choose preferable MFMA intrinsic for dot operation.");
}
mfmaInstr = chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim, false);
if (failed(mfmaInstr))
return failure();
if (failed(mfmaInstr)) {
return rewriter.notifyMatchFailure(
dotOp, "Unable to choose MFMA intrinsic for dot operation.");
}

withScale = false;
}
@@ -769,7 +794,8 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
FailureOr<MfmaIntrinsic> mfmaInstr =
chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim, useFp16);
if (failed(mfmaInstr))
return rewriter.notifyMatchFailure(dotOp, "cannot choose mfma intrinsic");
return rewriter.notifyMatchFailure(
dotOp, "Unable to choose MFMA intrinsic for scaled dot operation.");

if (useFp16) {
dotOp.emitRemark(
@@ -895,6 +921,13 @@ class DecomposeAMDScaledBlocked final : public ttg::DecomposeScaledBlocked {
: ttg::DecomposeScaledBlocked(context, benefit) {}
using TensorValue = TypedValue<RankedTensorType>;

LogicalResult matchAndRewrite(tt::DotScaledOp dotOp,
PatternRewriter &rewriter) const override {
dotOp.emitRemark() << "Decomposing scaled dot operation into regular dot "
"operation with explicit scaling.";
return ttg::DecomposeScaledBlocked::matchAndRewrite(dotOp, rewriter);
}

RankedTensorType getScaleType(RankedTensorType vType, int32_t kDim,
bool isFp4) const {
if (!isFp4)
@@ -1018,9 +1051,11 @@ class ScaledBlockedToScaledMFMAF8F6F4 final
// Choose a suitable Scaled MFMA instruction for this scaled dot op.
FailureOr<MfmaIntrinsic> mfmaInstr =
chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim);
if (failed(mfmaInstr))
if (failed(mfmaInstr)) {
return rewriter.notifyMatchFailure(dotOp,
"cannot choose scaled mfma intrinsic");
"Unable to choose preferable MFMA "
"intrinsic for scaled dot operation.");
}

auto mDim = mfmaInstr->mDim;
auto nDim = mfmaInstr->nDim;
@@ -1474,7 +1509,8 @@ class BlockedToWMMA : public OpRewritePattern<tt::DotOp> {
FailureOr<WmmaIntrinsic> wmmaInstr =
chooseWmmaInstruction(dotOp, operandTypes, wmmaVersion, nonKDim);
if (failed(wmmaInstr)) {
return failure();
return rewriter.notifyMatchFailure(
dotOp, "Unable to choose WMMA intrinsic for dot operation.");
}

auto mDim = wmmaInstr->mDim;
@@ -1625,7 +1661,8 @@ class AccelerateBlocked : public OpRewritePattern<DotOp> {
LogicalResult tryAccelerateF16WithVDot(DotOp dotOp, PatternRewriter &rewriter,
const DotElTypes &dotTypes) const {
if (!AMD::supportsVDot(arch))
return failure();
return rewriter.notifyMatchFailure(
dotOp, "Target architecture does not support V_DOT instruction.");

// If this is the fp16 x fp16 -> fp16 case, prioritize using v_dot.
auto aOpType = dotOp.getA().getType();
@@ -1641,7 +1678,8 @@ class AccelerateBlocked : public OpRewritePattern<DotOp> {
rewriter.replaceOp(dotOp, newD);
return success();
}
return failure();
return rewriter.notifyMatchFailure(
dotOp, "Unable to choose V_DOT instruction for dot operation.");
}

LogicalResult tryLegalizeFMA(DotOp dotOp, PatternRewriter &rewriter,
@@ -1687,7 +1725,10 @@ class AccelerateBlocked : public OpRewritePattern<DotOp> {
LogicalResult matchAndRewrite(DotOp dotOp,
PatternRewriter &rewriter) const override {
if (!isa<BlockedEncodingAttr>(dotOp.getD().getType().getEncoding()))
return failure();
return rewriter.notifyMatchFailure(
dotOp, "expected blocked encoding result tensor");

dotOp.emitRemark() << "Attempting to map dot operation to FMA intrinsic.";

DotElTypes dotTypes;
dotTypes.a = dotOp.getA().getType().getElementType();
@@ -1697,7 +1738,8 @@ class AccelerateBlocked : public OpRewritePattern<DotOp> {

// Check that dot is not legalized already
if (isLegalFMAForm(dotOp, dotTypes)) {
return failure();
return rewriter.notifyMatchFailure(
dotOp, "Dot operation is already in FMA form.");
}

// TODO: enable this condition, when fp32 -> fp16 cast works correctly