Skip to content

Commit 2f9b83b

Browse files
authored
Refactor BlockwiseMatrixParamsAttr (#2059)
* Refactor to use BlockwiseMatrixParamsAttr
1 parent a78bfd0 commit 2f9b83b

29 files changed

+452
-646
lines changed

mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,8 @@ struct AccelEmitter {
100100
/// is dependent on the type of accelerator we are targeting
101101
virtual Value
102102
wrapLDSBufferForLoad(OpBuilder &b, Location loc, Value buffer,
103-
int64_t blockSize, int64_t dInCopyPerThread,
104-
StringRef dName, bool rotateDWithK, bool directToLds,
105-
bool ldsLayoutDxK,
106-
bool doSplitKAcrossThreadsFirst = false) const = 0;
103+
const BlockwiseMatrixParamsAttr &matrixParams,
104+
int64_t blockSize, StringRef dName) const = 0;
107105

108106
/// This functions creates the subtile views that is :
109107
/// 1) gridSubTileView :
@@ -177,12 +175,9 @@ struct MfmaEmitter : public AccelEmitter {
177175
void emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, Value argB,
178176
Value bufferC, ValueRange regCOffset) override;
179177

180-
Value
181-
wrapLDSBufferForLoad(OpBuilder &b, Location loc, Value buffer,
182-
int64_t blockSize, int64_t dInCopyPerThread,
183-
StringRef dName, bool rotateDWithK, bool directToLds,
184-
bool ldsLayoutDxK,
185-
bool doSplitKAcrossThreadsFirst = false) const override;
178+
Value wrapLDSBufferForLoad(OpBuilder &b, Location loc, Value buffer,
179+
const BlockwiseMatrixParamsAttr &matrixParams,
180+
int64_t blockSize, StringRef dName) const override;
186181

187182
FailureOr<RegsAsMatrixSubTiles> createAccelGemmOperandTransforms(
188183
OpBuilder &b, Location loc, int64_t kIters,
@@ -225,12 +220,9 @@ struct WmmaEmitter : public AccelEmitter {
225220
void emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, Value argB,
226221
Value bufferC, ValueRange regCOffset) override;
227222

228-
Value
229-
wrapLDSBufferForLoad(OpBuilder &b, Location loc, Value buffer,
230-
int64_t blockSize, int64_t dInCopyPerThread,
231-
StringRef dName, bool rotateDWithK, bool directToLds,
232-
bool ldsLayoutDxK,
233-
bool doSplitKAcrossThreadsFirst = false) const override;
223+
Value wrapLDSBufferForLoad(OpBuilder &b, Location loc, Value buffer,
224+
const BlockwiseMatrixParamsAttr &matrixParams,
225+
int64_t blockSize, StringRef dName) const override;
234226

235227
FailureOr<RegsAsMatrixSubTiles> createAccelGemmOperandTransforms(
236228
OpBuilder &b, Location loc, int64_t kIters,

mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,4 +556,29 @@ def Rock_PrefillAttr : Rock_Attr<"Prefill"> {
556556
let assemblyFormat = "`<` params `>`";
557557
}
558558

559+
def Rock_BlockwiseMatrixParamsAttr : Rock_Attr<"BlockwiseMatrixParams", []> {
560+
let mnemonic = "blockwise_matrix_params";
561+
let description = [{
562+
Encapsulates rock.blockwise_load_tile and rock.blockwise_gemm_accel parameters.
563+
- elementType: Element type of the matrix operation.
564+
- elementTypeLoad: Element type of that was actually loaded from memory (before any input fusion).
565+
- rotateDWithK: Trick to reduce LDS bank conflicts (see more info here: https://github.com/ROCm/rocMLIR/pull/1209)
566+
- swapThreadIterSubDims: Trick to reduce LDS bank conflicts (see more info here: https://github.com/ROCm/rocMLIR/pull/1209)
567+
- LDSLayoutDxK: Wheter the layout in LDS is DxK
568+
- directToLDS: Wheter direct to LDS is enabled
569+
- splitKAcrossThreadsFirst: Used for attention, when bypassing LDS for the result of the first GEMM, explanation here: https://github.com/ROCm/rocMLIR-internal/issues/1201#issuecomment-1898925539
570+
- g: gemm parameter G
571+
- d: gemm parameter D (could be M or N)
572+
- inDPerThread: How many elements of D (M or N) each thread is going to load from memory.
573+
}];
574+
let parameters = (ins "Type":$elementType, "Type":$elementTypeLoad,
575+
"bool":$rotateDWithK, "bool":$swapThreadIterSubDims, "bool":$LDSLayoutDxK,
576+
"bool":$directToLDS, "bool":$splitKAcrossThreadsFirst, "int64_t":$g,
577+
"int64_t":$d, "int64_t":$inDPerThread);
578+
579+
let assemblyFormat = [{
580+
`<` struct(params) `>`
581+
}];
582+
}
583+
559584
#endif

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,26 +1365,22 @@ def Rock_BlockwiseGemmAccelOp
13651365
[AttrSizedOperandSegments,
13661366
DeclareOpInterfaceMethods<RockGemmFeaturesInterface>,
13671367
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]>,
1368-
Arguments<(ins Optional<MemRefOf<LdsBufferTypes>>:$matrixA,
1368+
Arguments<(ins MemRefOf<LdsBufferTypes>:$bufferA,
1369+
MemRefOf<LdsBufferTypes>:$bufferB, MemRefOf<AccelResTypes>:$matrixC,
1370+
Rock_BlockwiseMatrixParamsAttr:$matrixParamsA,
1371+
Rock_BlockwiseMatrixParamsAttr:$matrixParamsB,
1372+
Optional<MemRefOf<LdsBufferTypes>>:$matrixA,
13691373
Optional<MemRefOf<LdsBufferTypes>>:$matrixB,
13701374
Optional<MemRefOf<LdsBufferTypes>>:$scaleA,
1371-
Optional<MemRefOf<LdsBufferTypes>>:$scaleB, I32Attr:$inMPerThread,
1372-
I32Attr:$inNPerThread, UnitAttr:$rotateMWithK, UnitAttr:$rotateNWithK,
1373-
UnitAttr:$loadAfromLDS, UnitAttr:$loadBfromLDS,
1374-
UnitAttr:$splitKAcrossThreadsFirstA,
1375-
UnitAttr:$splitKAcrossThreadsFirstB, UnitAttr:$directToLDS,
1376-
UnitAttr:$ldsLayoutMxK, UnitAttr:$ldsLayoutNxK,
1377-
MemRefOf<LdsBufferTypes>:$bufferA, MemRefOf<LdsBufferTypes>:$bufferB,
1378-
MemRefOf<AccelResTypes>:$matrixC,
1375+
Optional<MemRefOf<LdsBufferTypes>>:$scaleB,
13791376
Optional<MemRefOf<LdsBufferTypes>>:$bufferScaleA,
13801377
Optional<MemRefOf<LdsBufferTypes>>:$bufferScaleB,
1381-
TypeAttr:$elementTypeA, TypeAttr:$elementTypeB,
13821378
OptionalAttr<Rock_GemmFeaturesAttr>:$features, I32Attr:$blockSize,
13831379
RockAccelTuningParamAttrInterface:$params)> {
13841380
let summary = "Blockwise GEMM accelerated version";
13851381
let description = [{
13861382
The `rock.blockwise_gemm_accel` op does GEMM at workgroup (block) level.
1387-
- Matrix A and Matrix B shall reside on LDS or registers (depending on loadAfromLDS and loadBfromLDS).
1383+
- Matrix A and Matrix B shall reside on registers (if matrixA or matrixB are passed, we load them from LDS).
13881384
- Matrix C shall be vectors.
13891385

13901386
The elements of matrices A and B should be vectors of length kpack, or
@@ -1410,11 +1406,11 @@ def Rock_BlockwiseLoadTileOp
14101406
Arg<Optional<MemRefOf<LdsBufferTypes>>, "destination LDS">:$destLDS,
14111407
Arg<Optional<MemRefOf<NativeMemoryOpTypes>>,
14121408
"destination registers">:$destRegisters,
1413-
Rock_GemmLoadTileTypeAttr:$loadType, UnitAttr:$isA,
1414-
TypeAttr:$elementTypeA, TypeAttr:$elementTypeB, TypeAttr:$elementType,
1415-
TypeAttr:$elementLoadType, UnitAttr:$rotateWithK,
1416-
UnitAttr:$swapThreadIterSubDims, UnitAttr:$LDSLayoutDxK,
1417-
Variadic<Index>:$sourceIndices, I64Attr:$G, I64Attr:$M, I64Attr:$N,
1409+
Rock_GemmLoadTileTypeAttr:$loadType, TypeAttr:$elementType,
1410+
TypeAttr:$elementLoadType,
1411+
Rock_BlockwiseMatrixParamsAttr:$matrixParamsA,
1412+
Rock_BlockwiseMatrixParamsAttr:$matrixParamsB, UnitAttr:$isA,
1413+
Variadic<Index>:$sourceIndices,
14181414
OptionalAttr<Rock_GemmFeaturesAttr>:$features, I32Attr:$blockSize,
14191415
RockAccelTuningParamAttrInterface:$params)> {
14201416
let summary =
@@ -1427,6 +1423,8 @@ def Rock_BlockwiseLoadTileOp
14271423
- Default: Creates two stages, (1) load from memory, (2) write to LDS.
14281424
- BypassLDS: Bypasses LDS and loads from device memory to registers directly (only one stage).
14291425
- DoubleBuffer: Creates three stages, (1) load from memory, (2) write to LDS, (3) load to registers.
1426+
- DirectToLDSDefault: Same as Default, but a single stage loads from memory and writes to LDS.
1427+
- DirectToLDSDoubleBuffer: Same as DoubleBuffer, but a single stage loads from memory and writes to LDS.
14301428

14311429
`isA` determines if we are loading an A matrix or B matrix. `G`, `M` and `N` are the GEMM sizes.
14321430
`elementTypeA` and `elementTypeB` are used to construct AccelEmitter. They are data types for the Matrix A & B of the GEMMs.
@@ -1464,9 +1462,9 @@ def Rock_ThreadwiseGemmOp
14641462
let hasVerifier = 1;
14651463
}
14661464

1467-
// threadwise_accel_gemm
1468-
def Rock_ThreadwiseAccelGemmOp
1469-
: Rock_Op<"threadwise_accel_gemm",
1465+
// threadwise_gemm_accel
1466+
def Rock_ThreadwiseGemmAccelOp
1467+
: Rock_Op<"threadwise_gemm_accel",
14701468
[DeclareOpInterfaceMethods<RockGemmFeaturesInterface>,
14711469
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
14721470
AttrSizedOperandSegments]>,
@@ -1481,7 +1479,7 @@ def Rock_ThreadwiseAccelGemmOp
14811479
RockAccelTuningParamAttrInterface:$params)> {
14821480
let summary = "Accelerated GEMM";
14831481
let description = [{
1484-
The `rock.accel_gemm` op is an abstraction of doing GEMM based on an accelerator.
1482+
The `rock.threadwise_gemm_accel` op is an abstraction of doing GEMM based on an accelerator.
14851483
It would employ a series of accelerator (e.g., mfma or wmma) operations.
14861484

14871485
Matrices A and B reside in LDS, the buffers live in registers, C is a vector

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2214,6 +2214,14 @@ LogicalResult BlockwiseLoadTileOp::verify() {
22142214
GemmLoadTileType loadType = getLoadType();
22152215
bool singleBuffer = loadType == GemmLoadTileType::Default ||
22162216
loadType == GemmLoadTileType::DirectToLDSDefault;
2217+
bool directToLDS = loadType == GemmLoadTileType::DirectToLDSDefault ||
2218+
loadType == GemmLoadTileType::DirectToLDSDoubleBuffer;
2219+
2220+
bool paramsDirectToLDS = getIsA() ? getMatrixParamsA().getDirectToLDS()
2221+
: getMatrixParamsB().getDirectToLDS();
2222+
2223+
if (paramsDirectToLDS != directToLDS)
2224+
return emitOpError("Inconsistency between params and load type");
22172225

22182226
if (!destLDS && loadType != GemmLoadTileType::BypassLDS)
22192227
return emitOpError("destLDS must be set unless loadType is BypassLDS");
@@ -2261,20 +2269,17 @@ void BlockwiseGemmOp::getEffects(
22612269
//===----------------------------------------------------------------------===//
22622270

22632271
LogicalResult BlockwiseGemmAccelOp::verify() {
2264-
bool loadAFromLDS = getLoadAfromLDS();
2265-
bool loadBFromLDS = getLoadBfromLDS();
22662272
bool hasA = getMatrixA() != nullptr;
22672273
bool hasB = getMatrixB() != nullptr;
2274+
bool directToLDS = getMatrixParamsA().getDirectToLDS() ||
2275+
getMatrixParamsB().getDirectToLDS();
22682276

2269-
if (loadAFromLDS && !hasA)
2270-
return emitOpError("If loadAFromLDS is enabled, matrixA must be non-null.");
2271-
if (loadBFromLDS && !hasB)
2272-
return emitOpError("If loadBFromLDS is enabled, matrixB must be non-null.");
2273-
2274-
if (hasA && getElementTypeOrSelfRecursive(getMatrixA()) != getElementTypeA())
2277+
if (hasA && getElementTypeOrSelfRecursive(getMatrixA()) !=
2278+
getMatrixParamsA().getElementType())
22752279
return emitOpError("ElementTypeA and matrixA element type don't match");
22762280

2277-
if (hasB && getElementTypeOrSelfRecursive(getMatrixB()) != getElementTypeB())
2281+
if (hasB && getElementTypeOrSelfRecursive(getMatrixB()) !=
2282+
getMatrixParamsB().getElementType())
22782283
return emitOpError("ElementTypeA and matrixA element type don't match");
22792284

22802285
bool hasScaleABuffer = getBufferScaleA() != nullptr;
@@ -2289,9 +2294,9 @@ LogicalResult BlockwiseGemmAccelOp::verify() {
22892294
StringAttr archAttr = rock::getArch(*this).value_or(
22902295
StringAttr::get(this->getContext(), "gfx00"));
22912296

2292-
if (loadAFromLDS && loadBFromLDS)
2297+
if (hasA && hasB)
22932298
if (failed(verifyGemmTypes(*this, rock::getFeatures(*this), archAttr, aType,
2294-
bType, cType)))
2299+
bType, directToLDS ? nullptr : cType)))
22952300
return failure();
22962301
auto verifyMatrixAndScale = [&](bool loadFromLds, Value matrix, Value lds,
22972302
Value bufferScale, ShapedType bufferType,
@@ -2363,12 +2368,12 @@ LogicalResult BlockwiseGemmAccelOp::verify() {
23632368
};
23642369

23652370
// Verify matrix A and its scales
2366-
if (failed(verifyMatrixAndScale(loadAFromLDS, getMatrixA(), getScaleA(),
2371+
if (failed(verifyMatrixAndScale(hasA, getMatrixA(), getScaleA(),
23672372
getBufferScaleA(), aBufferType, "A")))
23682373
return failure();
23692374

23702375
// Verify matrix B and its scales
2371-
if (failed(verifyMatrixAndScale(loadBFromLDS, getMatrixB(), getScaleB(),
2376+
if (failed(verifyMatrixAndScale(hasB, getMatrixB(), getScaleB(),
23722377
getBufferScaleB(), bBufferType, "B")))
23732378
return failure();
23742379

@@ -2380,7 +2385,7 @@ LogicalResult BlockwiseGemmAccelOp::verify() {
23802385
}
23812386

23822387
SmallVector<mlir::Type> BlockwiseGemmAccelOp::getTypesForFeature() {
2383-
return {getMatrixA().getType()};
2388+
return {getMatrixParamsA().getElementType()};
23842389
}
23852390

23862391
void BlockwiseGemmAccelOp::getEffects(
@@ -2398,17 +2403,15 @@ void BlockwiseGemmAccelOp::getEffects(
23982403
effects.emplace_back(read, &getBufferScaleBMutable()[0]);
23992404
}
24002405
// if we load from LDS, we need to write to registers
2401-
if (getLoadAfromLDS()) {
2402-
assert(getMatrixA() != nullptr);
2406+
if (getMatrixA() != nullptr) {
24032407
effects.emplace_back(read, &getMatrixAMutable()[0]);
24042408
effects.emplace_back(write, &getBufferAMutable());
24052409
if (getScaleA()) {
24062410
effects.emplace_back(read, &getScaleAMutable()[0]);
24072411
effects.emplace_back(write, &getBufferScaleAMutable()[0]);
24082412
}
24092413
}
2410-
if (getLoadBfromLDS()) {
2411-
assert(getMatrixB() != nullptr);
2414+
if (getMatrixB() != nullptr) {
24122415
effects.emplace_back(read, &getMatrixBMutable()[0]);
24132416
effects.emplace_back(write, &getBufferBMutable());
24142417
if (getScaleB()) {
@@ -2443,13 +2446,13 @@ void ThreadwiseGemmOp::getEffects(
24432446
}
24442447

24452448
//===----------------------------------------------------------------------===//
2446-
// ThreadwiseAccelGemmOp
2449+
// ThreadwiseGemmAccelOp
24472450
//===----------------------------------------------------------------------===//
2448-
SmallVector<mlir::Type> ThreadwiseAccelGemmOp::getTypesForFeature() {
2451+
SmallVector<mlir::Type> ThreadwiseGemmAccelOp::getTypesForFeature() {
24492452
return {getMatrixA().getType()};
24502453
}
24512454

2452-
LogicalResult ThreadwiseAccelGemmOp::verify() {
2455+
LogicalResult ThreadwiseGemmAccelOp::verify() {
24532456
ShapedType aType = cast<ShapedType>(getMatrixA().getType());
24542457
ShapedType bType = cast<ShapedType>(getMatrixB().getType());
24552458

@@ -2489,7 +2492,7 @@ LogicalResult ThreadwiseAccelGemmOp::verify() {
24892492
return success();
24902493
}
24912494

2492-
void ThreadwiseAccelGemmOp::getEffects(
2495+
void ThreadwiseGemmAccelOp::getEffects(
24932496
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
24942497
if (getScaleA()) {
24952498
auto *read = MemoryEffects::Read::get();

mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -411,11 +411,13 @@ struct BlockwiseGemmAccelRewritePattern
411411
int64_t kpackPerBlock = tuningParams.getKpackPerBlock();
412412
int64_t mPerWave = tuningParams.getMPerWave();
413413
int64_t nPerWave = tuningParams.getNPerWave();
414-
bool loadAFromLDS = adaptor.getLoadAfromLDS();
415-
bool loadBFromLDS = adaptor.getLoadBfromLDS();
414+
bool loadAFromLDS = adaptor.getMatrixA() != nullptr;
415+
bool loadBFromLDS = adaptor.getMatrixB() != nullptr;
416+
BlockwiseMatrixParamsAttr matrixParamsA = op.getMatrixParamsA();
417+
BlockwiseMatrixParamsAttr matrixParamsB = op.getMatrixParamsB();
416418

417-
Type dataTypeA = adaptor.getElementTypeA();
418-
Type dataTypeB = adaptor.getElementTypeB();
419+
Type dataTypeA = matrixParamsA.getElementType();
420+
Type dataTypeB = matrixParamsB.getElementType();
419421

420422
auto features = rock::getFeatures(op);
421423
auto accelEmitterPtr = rock::accel::AccelEmitter::select(
@@ -447,7 +449,8 @@ struct BlockwiseGemmAccelRewritePattern
447449
<< "kpackPerBlock: " << kpackPerBlock << "\n"
448450
<< "loadAFromLDS: " << loadAFromLDS << "\n"
449451
<< "loadBFromLDS: " << loadBFromLDS << "\n"
450-
<< "rotateMWithK: " << op.getRotateMWithK() << "\n"
452+
<< "rotateMWithK: " << matrixParamsA.getRotateDWithK() << "\n"
453+
<< "rotateNWithK: " << matrixParamsB.getRotateDWithK() << "\n"
451454
<< "bufferA type: " << adaptor.getBufferA().getType() << "\n"
452455
<< "bufferB type: " << adaptor.getBufferB().getType() << "\n");
453456

@@ -466,24 +469,20 @@ struct BlockwiseGemmAccelRewritePattern
466469
// considered a temporary hack until we have a proper way of "searching"
467470
// through different schedules (either heuristically or automatically)
468471

469-
bool directToLDS = op.getDirectToLDS();
470472
Value wrappedLDSBufferForLoadA, wrappedLDSBufferForLoadB;
471473
if (loadAFromLDS) {
472474
wrappedLDSBufferForLoadA = accelEmitterPtr->wrapLDSBufferForLoad(
473-
b, loc, op.getMatrixA(), op.getBlockSize(), op.getInMPerThread(), "m",
474-
op.getRotateMWithK(), directToLDS, op.getLdsLayoutMxK(),
475-
op.getSplitKAcrossThreadsFirstA());
475+
b, loc, op.getMatrixA(), matrixParamsA, op.getBlockSize(), "m");
476476
}
477477
if (loadBFromLDS) {
478478
wrappedLDSBufferForLoadB = accelEmitterPtr->wrapLDSBufferForLoad(
479-
b, loc, op.getMatrixB(), op.getBlockSize(), op.getInNPerThread(), "n",
480-
op.getRotateNWithK(), directToLDS, op.getLdsLayoutNxK(),
481-
op.getSplitKAcrossThreadsFirstB());
479+
b, loc, op.getMatrixB(), matrixParamsB, op.getBlockSize(), "n");
482480
}
483481

484482
auto loadBuffer = [&](Value buffer, Value wrappedLDSBufferForLoad,
485483
Value loopVar, Type argType, int64_t repeats,
486-
bool loadFromLDS, bool isA) -> Value {
484+
bool loadFromLDS, bool directToLDS,
485+
bool isA) -> Value {
487486
Value inputBuffer = buffer;
488487
SmallVector<int64_t> shape;
489488
if (directToLDS) {
@@ -544,8 +543,9 @@ struct BlockwiseGemmAccelRewritePattern
544543
Value i = mLoop.getInductionVar();
545544

546545
Value bufferA = adaptor.getBufferA();
547-
bufferA = loadBuffer(bufferA, wrappedLDSBufferForLoadA, i, argTypeA,
548-
mRepeats, loadAFromLDS, true);
546+
bufferA =
547+
loadBuffer(bufferA, wrappedLDSBufferForLoadA, i, argTypeA, mRepeats,
548+
loadAFromLDS, matrixParamsA.getDirectToLDS(), true);
549549
Value viewA =
550550
accelEmitterPtr->generateThreadwiseViewBufferA(b, loc, bufferA);
551551

@@ -556,8 +556,9 @@ struct BlockwiseGemmAccelRewritePattern
556556
Value j = nLoop.getInductionVar();
557557

558558
Value bufferB = adaptor.getBufferB();
559-
bufferB = loadBuffer(bufferB, wrappedLDSBufferForLoadB, j, argTypeB,
560-
nRepeats, loadBFromLDS, false);
559+
bufferB =
560+
loadBuffer(bufferB, wrappedLDSBufferForLoadB, j, argTypeB, nRepeats,
561+
loadBFromLDS, matrixParamsB.getDirectToLDS(), false);
561562
Value viewB =
562563
accelEmitterPtr->generateThreadwiseViewBufferB(b, loc, bufferB);
563564

@@ -569,8 +570,8 @@ struct BlockwiseGemmAccelRewritePattern
569570
Value viewC = accelEmitterPtr->generateThreadwiseViewBufferC(
570571
b, loc, adaptor.getMatrixC());
571572
Value k = kLoop.getInductionVar();
572-
ThreadwiseAccelGemmOp::create(b, loc, viewA, viewB, viewC,
573-
/*aScale=*/nullptr, /*bScale=*/nullptr,
573+
ThreadwiseGemmAccelOp::create(b, loc, viewA, viewB, viewC,
574+
/*scaleA=*/nullptr, /*scaleB=*/nullptr,
574575
ValueRange{i, j, k},
575576
op.getFeaturesAttr(), tuningParams);
576577
}

0 commit comments

Comments
 (0)