Commit 40335eb

[LAYOUTS] Implement toLinearLayout for TensorMemoryEncodingAttr (#7748)
We do so by modelling M/N as describing elements rather than the hardware 32-bit registers. This lets us avoid the issue of two elements pointing to the same register when `unpacked=False`. We also tighten the `MemDescType` verifier and the `TensorMemoryEncodingAttr` verifier to be consistent with this definition, which forces us to update a ton of lit tests that were silently wrong...
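
The packed case is easiest to see with plain index arithmetic. The sketch below is illustrative only (standalone C++, not Triton code; the helper name is made up): with 16-bit elements and `unpacked=False`, two consecutive elements share one 32-bit TMEM register, so a register-indexed layout cannot tell them apart, whereas an element-indexed layout can.

// Each TMEM column holds one 32-bit register. Hypothetical helper for
// illustration: map an element column to the register column that stores it.
unsigned registerColumn(unsigned elemCol, unsigned bitwidth, bool unpacked) {
  unsigned elemsPerReg = unpacked ? 1 : 32 / bitwidth; // 2 for packed 16-bit
  return elemCol / elemsPerReg;
}
// registerColumn(6, 16, false) == registerColumn(7, 16, false) == 3: two
// elements alias the same register, which is why the layout now tracks
// element coordinates (at 16-bit granularity) instead of registers.
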
1 parent 09d9bd9 commit 40335eb

28 files changed, +419 -193 lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 0 additions & 1 deletion
@@ -333,7 +333,6 @@ def TTG_MemDescReinterpretOp : TTG_Op<"memdesc_reinterpret", [Pure, MemDescViewT
     $src attr-dict `:` qualified(type($src)) `->` qualified(type($result))
   }];
 
-  let hasVerifier = 1;
   let hasFolder = 1;
 }
 
include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td

Lines changed: 11 additions & 10 deletions
@@ -13,21 +13,21 @@ def TTG_TensorMemorySpace : AttrDef<TritonNvidiaGPU_Dialect, "TensorMemorySpace"
     across TMEM 128 rows.
 
     Blocks are distributed along M dimension first and then N dimension. This is an arbitrary
-    convention that need to be followed operations reading/writing to TMEM.
+    convention that needs to be followed by operations reading/writing to TMEM.
 
-    a tensor <128x128xf32> with blockM = 64 and blockN = 64 will be distributed as follows:
+    a tensor <128x128xf32> with blockM = 64 and blockN = 32 will be distributed as follows:
 
-      \ col   0        1           31       32       64       96       127
-    rows: 0  ( 0, 0)  ( 0, 1)  ... ( 0, 31) (64, 0)  ... (0, 64)  ... (64, 64) ... (64, 96)
+      \ col   0        1           31       32       64       96       127
+    rows: 0  ( 0, 0)  ( 0, 1)  ... ( 0, 31) ( 0, 32) ... ( 0, 64) ... ( 0, 96) ... ( 0, 127)
          1
          ...
-         15 (15, 0)  (15, 1)  ... (15, 31) (79, 0)  ... (15, 64) ... (79, 64) ... (79, 96)
-         16 ( 0, 32) ( 0, 33) ... ( 0, 63) (64, 32) ... ( 0, 96) ... (64, 96) ... (64, 127)
+         15 (15, 0)  (15, 1)  ... (15, 31) (15, 32) ... (15, 64) ... (15, 96) ... (15, 127)
+         16 (64, 0)  (64, 1)  ... (64, 31) (64, 32) ... (64, 64) ... (64, 96) ... (64, 127)
          ...
-         31 (15, 32) (15, 33) ... (15, 63) (79, 32) ... (15, 96) ... (79, 96) ... (79, 127)
-         32 (16, 0)  (16, 1)  ... (16, 31) (80, 0)  ... (16, 64) ... (80, 64) ... (80, 96)
-         ...
-         127 (63, 32) (63, 33) ... (63, 63) (127, 32) ... (63, 96) ... (127, 96)... (127, 127)
+         31 (79, 0)  (79, 1)  ... (79, 31) (79, 32) ... (79, 64) ... (79, 96) ... (79, 127)
+         32 (16, 0)  (16, 1)  ... (16, 31) (16, 32) ... (16, 64) ... (16, 96) ... (16, 127)
+         ..
+         127 (127, 0) (127, 1) ... (127, 31) (127, 32) ... (127, 64) ... (127, 96) ... (127, 127)
     }];
 }
 
@@ -47,6 +47,7 @@ def TTG_TensorMemoryEncodingAttr : AttrDef<TritonNvidiaGPU_Dialect, "TensorMemor
     DefaultValuedParameter<"unsigned", "1">:$CTASplitM,
     DefaultValuedParameter<"unsigned", "1">:$CTASplitN
   );
+  let genVerifyDecl = 1;
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
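
The row interleaving in the table above (blockM = 64 on a 128-row tensor) can be written down as a small standalone sketch. The helper below is illustrative only, derived from the table, and is not part of the commit:

// Tensor row held by a given TMEM row when blockM = 64 and the tensor has
// 128 rows (matches the table: 15 -> 15, 16 -> 64, 31 -> 79, 32 -> 16).
unsigned tensorRowForTmemRow(unsigned tmemRow) {
  unsigned withinChunk = tmemRow % 16;  // position inside a 16-row chunk
  unsigned mBlock = (tmemRow / 16) % 2; // odd chunks hold the second M block
  unsigned chunk = tmemRow / 32;        // 16-row chunk index within a block
  return withinChunk + 64 * mBlock + 16 * chunk;
}
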

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 75 additions & 2 deletions
@@ -4,6 +4,7 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
 #include "triton/Tools/LayoutUtils.h"
 #include "triton/Tools/LinearLayout.h"
@@ -13,6 +14,8 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
 
+using mlir::triton::nvidia_gpu::TensorMemoryEncodingAttr;
+
 namespace mlir::triton::gpu {
 namespace {
 
@@ -1185,6 +1188,72 @@ LinearLayout SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
                       llvm::to_vector(sliceLL.getOutDimNames()));
 }
 
+LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
+                                        TensorMemoryEncodingAttr encoding) {
+  // We model packed layouts as having the rows/cols dimensions of bitwidth=16
+  // This means that a layout with unpacked=True is the same as one with
+  // unpacked=False
+  assert(shape.size() == 2);
+  auto *ctx = encoding.getContext();
+  auto kRow = S("row");
+  auto kCol = S("col");
+  auto dims = standardOutDimNames(ctx, 2);
+  // The CTAOrder = [0, 1] so se start by N so that it ends up as
+  // ((tile * splitM) * splitN)
+  if (encoding.getCTASplitN() > 1) {
+    auto split =
+        LinearLayout::identity1D(encoding.getCTASplitN(), kCol, dims[1]);
+    auto newEncoding = TensorMemoryEncodingAttr::get(
+        ctx, encoding.getBlockM(), encoding.getBlockN(), encoding.getUnpacked(),
+        encoding.getCTASplitM(), 1);
+    return tensorMemoryToLinearLayout(
+               {shape[0], shape[1] / encoding.getCTASplitN()}, newEncoding) *
+           split;
+  }
+  if (encoding.getCTASplitM() > 1) {
+    auto split =
+        LinearLayout::identity1D(encoding.getCTASplitM(), kCol, dims[0]);
+    auto newEncoding = TensorMemoryEncodingAttr::get(
+        ctx, encoding.getBlockM(), encoding.getBlockN(), encoding.getUnpacked(),
+        1, encoding.getCTASplitN());
+    return tensorMemoryToLinearLayout(
+               {shape[0] / encoding.getCTASplitM(), shape[1]}, newEncoding) *
+           split;
+  }
+  assert(encoding.getCTASplitM() == 1 && encoding.getCTASplitN() == 1);
+
+  auto blockM = encoding.getBlockM();
+  auto blockN = encoding.getBlockN();
+  assert(blockM == 64 || blockM == 128);
+  LinearLayout tile;
+  if (blockM == 64) {
+    tile = LinearLayout::identity1D(16, kRow, dims[0]) *
+           LinearLayout::identity1D(blockN, kCol, dims[1]);
+    auto bases = tile.getBases();
+    if (shape[0] > blockM) {
+      bases[kRow].push_back({64, 0});
+    } else if (shape[1] > blockN) {
+      bases[kRow].push_back({0, static_cast<int32_t>(blockN)});
+    } else {
+      // Empty. This is modelled as broadcasting, same as for TMA(fp4)
+      bases[kRow].push_back({0, 0});
+    }
+    bases[kRow].push_back({16, 0});
+    bases[kRow].push_back({32, 0});
+    tile = LinearLayout(bases, dims);
+  } else {
+    tile = LinearLayout::identity1D(blockM, kRow, dims[0]) *
+           LinearLayout::identity1D(blockN, kCol, dims[1]);
+  }
+  auto repsM = shape[0] / tile.getOutDimSize(dims[0]);
+  auto repsN = shape[1] / tile.getOutDimSize(dims[1]);
+  assert(repsM >= 1 && repsN >= 1);
+  // Broadcast the remaining dimensions in order [0, 1]
+  tile = tile * LinearLayout::identity1D(repsM, kCol, dims[0]) *
+         LinearLayout::identity1D(repsN, kCol, dims[1]);
+  return tile;
+}
+
 LinearLayout
 TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
                                  ArrayRef<int64_t> allocationShape) {
@@ -1204,7 +1273,8 @@ TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
     result = distributed.toLinearLayout(shape);
   } else {
     assert(!allocationShape.empty() &&
-           "allocationShape not supported for shared layout");
+           "allocationShape must be given for SharedMemory and TensorMemory "
+           "encodings");
     allocationShape = allocationShape.take_back(shape.size());
     assert(llvm::all_of(allocationShape,
                         [](int64_t dim) {
@@ -1216,13 +1286,16 @@ TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
                           return std::get<0>(dims) >= std::get<1>(dims);
                         }) &&
            "allocationShape must be at least as large as shape");
-
     if (auto shared = dyn_cast<SwizzledSharedEncodingAttr>(layout)) {
       result = swizzledSharedToLinearLayout(allocationShape, shared);
     } else if (auto shared = dyn_cast<NVMMASharedEncodingAttr>(layout)) {
      result = nvmmaSharedToLinearLayout(allocationShape, shared);
     } else if (auto sbl = dyn_cast<AMDRotatingSharedEncodingAttr>(layout)) {
       result = sharedToLinearLayoutAMDRotating(allocationShape, sbl);
+    } else if (auto tensorMemoryEncoding =
+                   dyn_cast<TensorMemoryEncodingAttr>(layout)) {
+      result =
+          tensorMemoryToLinearLayout(allocationShape, tensorMemoryEncoding);
     } else {
       assert(0 && "unknown layout");
     }
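
A minimal usage sketch of the helper added above, assuming an MLIRContext `ctx` is in scope and using the same blockM = 64 example as the TMEM documentation (illustrative only, not part of the commit):

// Build the layout for a 128x128 TMEM allocation, blockM = 64, blockN = 32,
// no CTA splitting. The attribute parameters follow the order used above.
auto enc = TensorMemoryEncodingAttr::get(ctx, /*blockM=*/64, /*blockN=*/32,
                                         /*unpacked=*/true,
                                         /*CTASplitM=*/1, /*CTASplitN=*/1);
LinearLayout ll = tensorMemoryToLinearLayout({128, 128}, enc);
// Because shape[0] > blockM, the kRow bases get the extra {64, 0} entry, so
// TMEM rows 16..31 map to tensor rows 64..79 (the interleaving documented in
// TritonNvidiaGPUAttrDefs.td above).
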

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 0 additions & 7 deletions
@@ -581,13 +581,6 @@ LogicalResult MemDescReshapeOp::inferReturnTypes(
   return success();
 }
 
-// MemDescReinterpretOp
-LogicalResult MemDescReinterpretOp::verify() {
-  if (getSrc().getType().getMemorySpace() != getType().getMemorySpace())
-    return emitError("source and destination memory space must match");
-  return success();
-}
-
 OpFoldResult MemDescReinterpretOp::fold(FoldAdaptor adaptor) {
   if (getType() == getSrc().getType())
    return getSrc();

lib/Dialect/TritonGPU/IR/Types.cpp

Lines changed: 72 additions & 6 deletions
@@ -1,6 +1,8 @@
 #include "triton/Dialect/TritonGPU/IR/Types.h"
 #include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
+#include "triton/Tools/LayoutUtils.h"
 #include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
 
 using namespace mlir;
@@ -52,12 +54,15 @@ Type MemDescType::parse(AsmParser &parser) {
   if (parser.parseGreater())
     return Type();
 
-  if (allocShape.size() > 0)
-    return MemDescType::get(parser.getContext(), dimensions, elementType,
-                            encoding, memorySpace, mutableMemory, allocShape);
+  Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
+  if (!allocShape.empty())
+    return MemDescType::getChecked(loc, parser.getContext(), dimensions,
+                                   elementType, encoding, memorySpace,
+                                   mutableMemory, allocShape);
 
-  return MemDescType::get(parser.getContext(), dimensions, elementType,
-                          encoding, memorySpace, mutableMemory, dimensions);
+  return MemDescType::getChecked(loc, parser.getContext(), dimensions,
+                                 elementType, encoding, memorySpace,
+                                 mutableMemory, dimensions);
 }
 
 void MemDescType::print(AsmPrinter &printer) const {
@@ -87,8 +92,69 @@ LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
                                   Attribute encoding, Attribute memorySpace,
                                   bool mutableMemory,
                                   ArrayRef<int64_t> allocShape) {
+  // Every dimension but the first (to allow for pipelining) must be a power of
+  // 2
+  if (!isa<PaddedSharedEncodingAttr>(encoding) &&
+      llvm::any_of(shape.drop_front(1),
+                   [](int64_t dim) { return !llvm::isPowerOf2_64(dim); }))
+    return emitError() << "shape must have power-of-2 dimensions; got "
+                       << shape;
   if (allocShape.size() < shape.size())
-    emitError() << "alloc shape must have at least as many dimensions as shape";
+    return emitError()
+           << "alloc shape must have at least as many dimensions as shape";
+  if (llvm::any_of(
+          llvm::zip(shape, allocShape.take_back(shape.size())),
+          [](auto pair) { return std::get<0>(pair) > std::get<1>(pair); }))
+    return emitError() << "shape must be less than or equal to allocShape. "
+                       << "shape = " << shape
+                       << ", allocShape = " << allocShape;
+  auto ctx = encoding.getContext();
+  if (auto enc = dyn_cast<nvidia_gpu::TensorMemoryEncodingAttr>(encoding)) {
+    if (memorySpace != nvidia_gpu::TensorMemorySpaceAttr::get(ctx)) {
+      return emitError() << "memorySpace must be TensorMemorySpace";
+    }
+    if (shape.size() != 2 && shape.size() != 3) {
+      return emitError() << "rank must be 2 or 3";
+    }
+    auto bitwidth = elementType.getIntOrFloatBitWidth();
+    if (!enc.getUnpacked() && bitwidth != 16) {
+      return emitError() << "bitwidth must be 16 for packed tensor memory";
+    }
+    if (bitwidth != 16 && bitwidth != 32) {
+      return emitError() << "bitwidth must be 16 or 32";
+    }
+    shape = shape.take_back(2);
+    allocShape = allocShape.take_back(2);
+    if (allocShape[0] < enc.getBlockM() * enc.getCTASplitM() ||
+        allocShape[1] < enc.getBlockN() * enc.getCTASplitN()) {
+      return emitError() << "the allocation shape must be at least "
+                         << enc.getBlockM() * enc.getCTASplitM() << "x"
+                         << enc.getBlockN() * enc.getCTASplitN() << ". Got "
+                         << allocShape;
+    }
+    auto ll = toLinearLayout(shape, enc, allocShape);
+    auto dims = standardOutDimNames(ctx, 2);
+    if (ll.getOutDimSize(dims[0]) != allocShape[0] ||
+        ll.getOutDimSize(dims[1]) != allocShape[1]) {
+      return emitError() << "allocation shape must be equal to "
+                         << ll.getOutDimSize(dims[0]) << "x"
+                         << ll.getOutDimSize(dims[1]);
+    }
+  } else if (auto enc = dyn_cast<SharedEncodingTrait>(encoding)) {
+    if (memorySpace != SharedMemorySpaceAttr::get(ctx)) {
+      return emitError()
+             << "memorySpace must be SharedMemorySpace for shared encoding. "
+             << "Got " << memorySpace;
+    }
+  } else if (auto enc = dyn_cast<nvidia_gpu::TensorMemoryScalesEncodingAttr>(
+                 encoding)) {
+    if (memorySpace != nvidia_gpu::TensorMemorySpaceAttr::get(ctx)) {
+      return emitError() << "memorySpace must be TensorMemorySpace";
+    }
+    // TODO Add rest of verifier
+  } else {
+    return emitError() << encoding << " is not a valid encoding";
+  }
   return success();
 }
 

lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp

Lines changed: 18 additions & 0 deletions
@@ -247,6 +247,24 @@ bool isDistributedLayoutTMemCompatible(Operation *op,
   });
 }
 
+LogicalResult TensorMemoryEncodingAttr::verify(
+    function_ref<InFlightDiagnostic()> emitError, unsigned blockM,
+    unsigned blockN, bool unpacked, unsigned CTASplitM, unsigned CTASplitN) {
+  if (CTASplitM < 1 || CTASplitN < 1) {
+    return emitError() << "CTASplitM and CTASplitN must be greater than 0";
+  }
+  if (blockM != 64 && blockM != 128) {
+    return emitError() << "blockM must be 64 or 128";
+  }
+  if (!llvm::isPowerOf2_32(blockN) || blockN > 512) {
+    return emitError() << "blockN must be a power of 2 and less than 512";
+  }
+  if (!unpacked && blockN < 2) {
+    return emitError() << "blockN must be at least 2 for packed tensor memory";
+  }
+  return success();
+}
+
 LogicalResult impl::verifyMMAv5Op(Operation *op) {
   auto isInterleaved = [](MemDescType memdesc) {
     auto enc = dyn_cast<TensorMemoryEncodingAttr>(memdesc.getEncoding());

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 1 addition & 6 deletions
@@ -528,9 +528,6 @@ static LogicalResult verifyTMEMOperand(Operation *op, RankedTensorType type,
 }
 
 LogicalResult TMEMStoreOp::verify() {
-  if (!isa<triton::nvidia_gpu::TensorMemorySpaceAttr>(
-          getDst().getType().getMemorySpace()))
-    return emitOpError("destination must be a tensor memory buffer.");
   if (!isa<triton::nvidia_gpu::TensorMemoryEncodingAttr,
            TensorMemoryScalesEncodingAttr>(getDst().getType().getEncoding()))
     return emitOpError("should use tensor memory encoding.");
@@ -559,8 +556,6 @@ LogicalResult TMEMLoadOp::verify() {
 
 // -- TMEMAllocOp --
 LogicalResult TMEMAllocOp::verify() {
-  if (!isa<TensorMemorySpaceAttr>(getType().getMemorySpace()))
-    return emitOpError("should create a buffer of tensor memory");
   if (!isa<TensorMemoryEncodingAttr, TensorMemoryScalesEncodingAttr>(
           getType().getEncoding()))
     return emitOpError("should use tensor memory encoding");
@@ -662,7 +657,7 @@ void TMEMSubSliceOp::build(OpBuilder &builder, OperationState &state,
       encoding.getUnpacked(), encoding.getCTASplitM(), encoding.getCTASplitN());
   auto subsliceType = gpu::MemDescType::get(
       shape, allocTy.getElementType(), newEncoding, allocTy.getMemorySpace(),
-      allocTy.getMutableMemory());
+      allocTy.getMutableMemory(), allocTy.getAllocShape());
   build(builder, state, subsliceType, alloc, offset);
 }
 

lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp

Lines changed: 10 additions & 2 deletions
@@ -55,10 +55,18 @@ template <class MMAOpTy> class LHSToTMem : public OpRewritePattern<MMAOpTy> {
         tcGen5MMAOp.getD().getType().getEncoding());
     ArrayRef<unsigned> CTASplitNum =
         triton::gpu::getCTALayout(srcLayout).getCTASplitNum();
-    // TMem encoding for A operand is the same as for D (Acc), but packed.
+    // TMem encoding for A operand is the same as for D (Acc), but packed for
+    // bitwidth=16
+    unsigned elemBitWidth =
+        lhs.getType().getElementType().getIntOrFloatBitWidth();
+    // We don't currently support fp8 (not sure if we can)
+    if (elemBitWidth != 16 && elemBitWidth != 32) {
+      return failure();
+    }
+    bool unpacked = elemBitWidth != 16;
     auto aTMemEncoding = TensorMemoryEncodingAttr::get(
         context, accTMemEncoding.getBlockM(), lhs.getType().getShape()[1],
-        /*unpacked=*/false, CTASplitNum[0], CTASplitNum[1]);
+        /*unpacked=*/unpacked, CTASplitNum[0], CTASplitNum[1]);
     Attribute tensorMemorySpace =
         triton::nvidia_gpu::TensorMemorySpaceAttr::get(context);
     ttg::MemDescType lhsMemDescType = ttg::MemDescType::get(

python/test/unit/blackwell/test_tmem.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ def test_tmem_copy_2d(tmp_path: pathlib.Path):
 #blocked = #ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[32, 1], warpsPerCTA=[4, 1], order=[0, 1]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8}>
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
-#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 32, unpacked = false>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 32, unpacked = true>
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @kernel_0d1d(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
 """ + ir_body + """

python/tutorials/gluon/01-attention-forward.py

Lines changed: 2 additions & 2 deletions
@@ -479,15 +479,15 @@ def _borrow_s_as_p(config, s_tmem):
 @gluon.jit
 def _borrow_s_as_alpha(config, s_tmem):
     alpha_tmem = s_tmem.slice(config.BLOCK_N // 2, 1)
-    alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=False)
+    alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True)
     return alpha_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], alpha_layout)
 
 
 @gluon.jit
 def _borrow_s_for_epilogue(config, s_tmem):
     m_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 1, 1)
     l_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 2, 1)
-    layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=False)
+    layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True)
     m_i_tmem = m_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
     l_i_tmem = l_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
     return m_i_tmem, l_i_tmem
