
Commit 70e69cb

[BACKEND] Generic tcgen05.cp lowering (#8225)
We also fix a ton of issues here and there that we found while working on this:
- We add full support for `memdesc_trans` and `memdesc_reshape` using the newly minted `SharedLinearLayout`.
- We fix a few issues we left out of `SharedLinearLayout`'s initial implementation.
- We now make `tcgen05.cp` take the correct layout, and we fix the OptimizeDotOperands pass to use `memdesc_trans`/`memdesc_reshape` to reflect this.
- We fix a number of previously broken tests.

We still need to tighten the `memdesc_copy` verifier to make it a bit more user-friendly, though.
1 parent e15cb57 commit 70e69cb
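For readers new to the layout this PR leans on: a SharedLinearLayout describes a shared-memory layout by giving, for each bit of the linear offset, a basis vector of logical coordinates, and set bits combine by XOR. A minimal sketch of that mapping (the helper below is our illustration, not code from this commit; the bases are the ones the updated test_tmem_copy_2d uses):

def apply_offset_bases(offset: int, offset_bases: list[list[int]]) -> tuple[int, ...]:
    # Each set bit of the linear shared-memory offset contributes its basis
    # vector; contributions combine with XOR (the layout is GF(2)-linear).
    coords = [0] * len(offset_bases[0])
    for bit, basis in enumerate(offset_bases):
        if offset & (1 << bit):
            coords = [c ^ b for c, b in zip(coords, basis)]
    return tuple(coords)

# Bases from the updated test below: 10 bits cover a 64x16 int8 buffer.
bases = [[0, 1], [0, 2], [32, 0], [0, 4], [1, 0], [2, 0],
         [4, 0], [8, 0], [16, 0], [0, 8]]
assert apply_offset_bases(1, bases) == (0, 1)    # bit 0 -> element (0, 1)
assert apply_offset_bases(4, bases) == (32, 0)   # bit 2 -> element (32, 0)
assert apply_offset_bases(5, bases) == (32, 1)   # XOR of the two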

File tree: 16 files changed, +282 -282 lines changed

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Types.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Tools/LayoutUtils.h"
 
 using namespace mlir;

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 28 additions & 25 deletions
@@ -1664,7 +1664,7 @@ void SharedLinearEncodingAttr::print(AsmPrinter &printer) const {
         layout.sublayout({kOffset}, llvm::to_vector(layout.getOutDimNames()));
   }
   printLinearLayout(printer, layout);
-  printer << "}, alignment = " << getAlignment() << "}>";
+  printer << "}, alignment = " << getAlignment() << ">";
 }
 
 Attribute SharedLinearEncodingAttr::parse(AsmParser &parser, Type type) {
@@ -2644,19 +2644,17 @@ struct TritonGPUInferLayoutInterface
     }
 
     if (auto enc = dyn_cast<NVMMASharedEncodingAttr>(operandEncoding)) {
-      if (failed(checkRank(enc.getRank())))
-        return failure();
-      if (order != ArrayRef<int32_t>({1, 0})) {
-        return emitOptionalError(
-            loc, "NVMMSharedEncoding can only be transposed in 2D");
-      }
+      if (order == ArrayRef<int32_t>({1, 0})) {
+        if (failed(checkRank(enc.getRank())))
+          return failure();
 
-      CTALayoutAttr ctaLayout =
-          permuteCTALayout(ctx, enc.getCTALayout(), order);
-      resultEncoding = NVMMASharedEncodingAttr::get(
-          ctx, enc.getSwizzlingByteWidth(), !enc.getTransposed(),
-          enc.getElementBitWidth(), enc.getFp4Padded(), ctaLayout);
-      return success();
+        CTALayoutAttr ctaLayout =
+            permuteCTALayout(ctx, enc.getCTALayout(), order);
+        resultEncoding = NVMMASharedEncodingAttr::get(
+            ctx, enc.getSwizzlingByteWidth(), !enc.getTransposed(),
+            enc.getElementBitWidth(), enc.getFp4Padded(), ctaLayout);
+        return success();
+      }
     }
 
     if (auto enc = dyn_cast<BlockedEncodingAttr>(operandEncoding)) {
@@ -2672,20 +2670,25 @@ struct TritonGPUInferLayoutInterface
           applyPermutation(invOrderUnsigned, enc.getOrder()), ctaLayout);
       return success();
     }
+    // Generic case
+    auto padded = dyn_cast<PaddedSharedEncodingAttr>(operandEncoding);
 
-    if (auto enc = dyn_cast<PaddedSharedEncodingAttr>(operandEncoding)) {
-      if (failed(checkRank(enc.getRank())))
-        return failure();
-      const auto &transLL =
-          transposeLinearLayout(enc.getLinearComponent(), order);
-      resultEncoding = PaddedSharedEncodingAttr::get(
-          ctx, enc.getIntervals(), enc.getPaddings(), transLL);
-      return success();
-    }
-
-    auto ll = toLinearLayout(shape, operandEncoding);
+    auto ll = padded ? padded.getLinearComponent()
+                     : toLinearLayout(shape, operandEncoding);
+    if (failed(checkRank(ll.getNumOutDims())))
+      return failure();
    auto transposedLl = transposeLinearLayout(ll, order);
-    resultEncoding = LinearEncodingAttr::get(ctx, std::move(transposedLl));
+    if (isa<DistributedEncodingTrait>(operandEncoding)) {
+      resultEncoding = LinearEncodingAttr::get(ctx, std::move(transposedLl));
+    } else if (padded) {
+      resultEncoding = PaddedSharedEncodingAttr::get(ctx, padded.getIntervals(),
+                                                     padded.getPaddings(),
+                                                     std::move(transposedLl));
+    } else {
+      auto shared = cast<SharedEncodingTrait>(operandEncoding);
+      resultEncoding = SharedLinearEncodingAttr::get(
+          ctx, std::move(transposedLl), shared.getAlignment());
+    }
    return success();
  }
 
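The generic path above reduces transposition to permuting the output dimensions of the layout's bases. A hedged Python sketch of what transposeLinearLayout does to bases like the ones shown earlier (our illustration, not the C++ implementation):

def transpose_bases(offset_bases: list[list[int]], order: list[int]) -> list[list[int]]:
    # Permuting a layout's output dims permutes each basis vector's entries;
    # order = [1, 0] swaps rows and columns.
    return [[basis[d] for d in order] for basis in offset_bases]

assert transpose_bases([[0, 1], [32, 0]], [1, 0]) == [[1, 0], [0, 32]]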

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 34 additions & 30 deletions
@@ -529,40 +529,44 @@ static LogicalResult inferMemDescReshapeOpEncoding(ArrayRef<int64_t> srcShape,
                                                    Attribute srcEnc,
                                                    ArrayRef<int64_t> dstShape,
                                                    Attribute &dstEnc) {
+  // TODO Delete this once SharedLinearEncodingAttr is more widely supported.
   if (auto mmaEncoding = dyn_cast<NVMMASharedEncodingAttr>(srcEnc)) {
-    // TODO: supporting reshape of CTA layouts is non-trivial.
-    if (getNumCTAs(mmaEncoding) > 1)
-      return failure();
-    int innerDimDst =
-        mmaEncoding.getTransposed() ? dstShape.front() : dstShape.back();
-    int innerDimSrc =
-        mmaEncoding.getTransposed() ? srcShape.front() : srcShape.back();
-    // For now disallow reshape of the inner dimension.
-    if (innerDimDst != innerDimSrc)
-      return failure();
     auto *ctx = srcEnc.getContext();
-
-    // CTALayout can be all 1's because we bailed on multi-CTA layouts above.
-    auto CTALayout = CTALayoutAttr::get(
-        ctx,
-        /*CTAsPerCGA=*/SmallVector<unsigned>(dstShape.size(), 1),
-        /*CTASplitNum=*/SmallVector<unsigned>(dstShape.size(), 1),
-        /*CTAOrder=*/llvm::to_vector(llvm::seq<unsigned>(dstShape.size())));
-    dstEnc = NVMMASharedEncodingAttr::get(
-        ctx, mmaEncoding.getSwizzlingByteWidth(), mmaEncoding.getTransposed(),
-        mmaEncoding.getElementBitWidth(), mmaEncoding.getFp4Padded(),
-        CTALayout);
-    // Big guns, check linear layouts are equivalent
-    // We disallow reshaping memdesc_subslice in the verifier
-    // so allocShape == shape
-    auto srcLL = toLinearLayout(srcShape, srcEnc);
-    auto dstLL = toLinearLayout(dstShape, dstEnc);
-    if (reshapeLayout(ctx, srcLL, dstShape) != dstLL) {
-      return failure();
+    if (getNumCTAs(mmaEncoding) == 1) {
+      int innerDimDst =
+          mmaEncoding.getTransposed() ? dstShape.front() : dstShape.back();
+      int innerDimSrc =
+          mmaEncoding.getTransposed() ? srcShape.front() : srcShape.back();
+      // We can keep an NVMMAShared encoding only if the innermost dimension is
+      // preserved. Otherwise fall back to the generic shared-linear encoding
+      // logic below.
+      if (innerDimDst == innerDimSrc) {
+        auto CTALayout = CTALayoutAttr::get(
+            ctx,
+            /*CTAsPerCGA=*/SmallVector<unsigned>(dstShape.size(), 1),
+            /*CTASplitNum=*/SmallVector<unsigned>(dstShape.size(), 1),
+            /*CTAOrder=*/llvm::to_vector(llvm::seq<unsigned>(dstShape.size())));
+        auto candidateEncoding = NVMMASharedEncodingAttr::get(
+            ctx, mmaEncoding.getSwizzlingByteWidth(),
+            mmaEncoding.getTransposed(), mmaEncoding.getElementBitWidth(),
+            mmaEncoding.getFp4Padded(), CTALayout);
+        auto srcLL = toLinearLayout(srcShape, srcEnc);
+        auto dstLL = toLinearLayout(dstShape, candidateEncoding);
+        if (reshapeLayout(ctx, srcLL, dstShape) == dstLL) {
+          dstEnc = candidateEncoding;
+          return success();
+        }
+      }
     }
-    return success();
   }
-  return failure();
+
+  // Generic LL case
+  auto sharedEnc = cast<SharedEncodingTrait>(srcEnc);
+  auto *ctx = srcEnc.getContext();
+  auto srcLL = toLinearLayout(srcShape, srcEnc);
+  auto dstLL = reshapeLayout(ctx, srcLL, dstShape);
+  dstEnc = SharedLinearEncodingAttr::get(ctx, dstLL, sharedEnc.getAlignment());
+  return success();
 }
 
 LogicalResult MemDescReshapeOp::inferReturnTypes(
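Conceptually, the generic fall-back treats reshape as re-expressing each basis's flattened row-major output index in the destination shape, then wrapping the result in a SharedLinearEncodingAttr. A numpy sketch under that assumption (illustrative, not the reshapeLayout implementation; power-of-two shapes assumed):

import numpy as np

def reshape_basis(basis, src_shape, dst_shape):
    # Flatten the basis's output coordinates row-major, then re-split the
    # flat index into the destination shape.
    flat = np.ravel_multi_index(tuple(basis), src_shape)
    return list(np.unravel_index(flat, dst_shape))

# E.g. basis (0, 8) of a (64, 16) layout becomes (0, 2, 0) in (64, 4, 4).
assert reshape_basis([0, 8], (64, 16), (64, 4, 4)) == [0, 2, 0]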

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 18 additions & 1 deletion
@@ -257,7 +257,24 @@ class UseShmemForScales
     if (!isTmemCopyCompatible(localLoad.getSrc().getType(), usesTMAload))
       return failure();
 
-    opOperand.assign(localLoad.getSrc());
+    PatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(tmemAlloc);
+
+    Value shared = localLoad.getSrc();
+
+    Value reshaped5D = rewriter.create<MemDescReshapeOp>(
+        reshapeOp5D.getLoc(), shared, reshape5DShape);
+    SmallVector<int32_t> transposeOrder32(transposeOrder.begin(),
+                                          transposeOrder.end());
+    Value transposed = rewriter.create<MemDescTransOp>(
+        transOp.getLoc(), reshaped5D, transposeOrder32);
+    SmallVector<int64_t> scale2DShapeVec(scale2DShape.begin(),
+                                         scale2DShape.end());
+    Value reshaped2D = rewriter.create<MemDescReshapeOp>(
+        reshapeOp2D.getLoc(), transposed, scale2DShapeVec);
+
+    opOperand.assign(reshaped2D);
+    rewriter.eraseOp(tmemAlloc);
     return success();
   }
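The rewrite now materializes the scale shuffling in shared memory as a reshape -> transpose -> reshape chain instead of assigning the raw local_load source. A torch analogue of that chain (all shapes and the permutation here are hypothetical stand-ins for reshape5DShape, transposeOrder, and scale2DShape, which the pass computes):

import torch

shared = torch.arange(256 * 8).reshape(256, 8)   # stand-in for the smem buffer
reshaped_5d = shared.reshape(2, 32, 4, 2, 4)     # MemDescReshapeOp(reshape5DShape)
transposed = reshaped_5d.permute(0, 3, 2, 1, 4)  # MemDescTransOp(transposeOrder)
reshaped_2d = transposed.reshape(256, 8)         # MemDescReshapeOp(scale2DShape)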

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 12 additions & 11 deletions
@@ -680,14 +680,20 @@ LogicalResult TMEMCopyOp::verify() {
           getSrc().getType().getMemorySpace()))
     return emitOpError("The source must be a shared memory buffer");
 
+  auto srcTy = cast<triton::gpu::MemDescType>(getSrc().getType());
+  auto dstTy = cast<triton::gpu::MemDescType>(getDst().getType());
+  if (srcTy.getShape() != dstTy.getShape())
+    return emitOpError("source shape ")
+           << srcTy.getShape() << " must match destination shape "
+           << dstTy.getShape();
+
   if (getBarrier() && !isa<triton::gpu::SharedMemorySpaceAttr>(
                           getBarrier().getType().getMemorySpace())) {
     return emitOpError("The optional barrier should be a shared memory buffer");
   }
   if (!getDst().getType().getMutableMemory()) {
     return emitOpError("Cannot copy into an immutable alloc");
   }
-  auto srcTy = cast<triton::gpu::MemDescType>(getSrc().getType());
   auto sharedEnc =
       dyn_cast<triton::gpu::SharedEncodingTrait>(srcTy.getEncoding());
   if (sharedEnc.getAlignment() < 16) {
@@ -700,21 +706,16 @@ LogicalResult TMEMCopyOp::verify() {
   if (numCTAs != 1)
     return emitOpError("NYI: Only one CTA is supported for now.");
 
+  // Fp4 we could lift if we needed
   auto nvmmaEnc =
       dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(srcTy.getEncoding());
-  if (!nvmmaEnc) {
-    return emitOpError("Source must have nvmma layout.");
-  }
-  // Fp4 we could lift if we needed
-  if (nvmmaEnc.getTransposed() || nvmmaEnc.getFp4Padded())
+  if (nvmmaEnc && (nvmmaEnc.getTransposed() || nvmmaEnc.getFp4Padded())) {
     return emitOpError("The source should not be transposed or padded");
+  }
   if (isa<TensorMemoryScalesEncodingAttr>(getDst().getType().getEncoding())) {
-    if (nvmmaEnc.getSwizzlingByteWidth() != 0) {
+    if (nvmmaEnc && nvmmaEnc.getSwizzlingByteWidth() != 0) {
       return emitOpError("The source should not be swizzled for now");
     }
-    if (!triton::gpu::isInnermostContiguous(srcTy, 512)) {
-      return emitOpError("The source must be in a row-major order.");
-    }
   } else {
     if (getSrc().getType().getShape() != getDst().getType().getShape()) {
       return emitOpError(
@@ -728,7 +729,7 @@ LogicalResult TMEMCopyOp::verify() {
     if (tmemEnc.getBlockM() != 128) {
      return emitOpError("Tmem layout should have M=128.");
     }
-    if (nvmmaEnc.getSwizzlingByteWidth() == 0) {
+    if (nvmmaEnc && nvmmaEnc.getSwizzlingByteWidth() == 0) {
      return emitOpError("Source layout should be swizzled.");
     }
     // When we lift this, we should make sure we handle unpacked cleanly
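With the verifier above, ttng.tmem_copy now insists that source and destination shapes match. A hedged Gluon fragment showing a copy that satisfies it (shapes illustrative; import paths and bases taken from the tests updated in this PR):

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.nvidia.blackwell import (
    TensorMemoryScalesLayout, allocate_tensor_memory, tcgen05_copy)

@gluon.jit
def tmem_copy_fragment():
    # Both sides use the same (64, 16) shape, as the verifier now requires.
    smem_layout: ttgl.constexpr = ttgl.SharedLinearLayout(
        offset_bases=[[0, 1], [0, 2], [32, 0], [0, 4], [1, 0], [2, 0],
                      [4, 0], [8, 0], [16, 0], [0, 8]])
    smem = ttgl.allocate_shared_memory(ttgl.int8, (64, 16), layout=smem_layout)
    tmem = allocate_tensor_memory(ttgl.int8, (64, 16),
                                  layout=TensorMemoryScalesLayout())
    tcgen05_copy(smem, tmem)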

python/src/gluon_ir.cc

Lines changed: 14 additions & 4 deletions
@@ -98,6 +98,7 @@ struct GluonLayouts {
   py::handle NVMMADistributedLayout;
   py::handle NVMMASharedLayout;
   py::handle SwizzledSharedLayout;
+  py::handle SharedLinearLayout;
   py::handle AMDMFMALayout;
   py::handle AMDWMMALayout;
   py::handle PaddedSharedLayout;
@@ -119,6 +120,8 @@ struct GluonLayouts {
     NVMMASharedLayout = py::object(layouts.attr("NVMMASharedLayout")).release();
     SwizzledSharedLayout =
         py::object(layouts.attr("SwizzledSharedLayout")).release();
+    SharedLinearLayout =
+        py::object(layouts.attr("SharedLinearLayout")).release();
     AMDMFMALayout = py::object(amdLayouts.attr("AMDMFMALayout")).release();
     AMDWMMALayout = py::object(amdLayouts.attr("AMDWMMALayout")).release();
     PaddedSharedLayout =
@@ -203,6 +206,14 @@ py::object layoutToGluon(Attribute layout) {
         toStdVector(ctaLayout.getCTAsPerCGA()),
         toStdVector(ctaLayout.getCTASplitNum()),
         toStdVector(ctaLayout.getCTAOrder()));
+  } else if (auto sharedLl = dyn_cast<ttg::SharedLinearEncodingAttr>(layout)) {
+    const auto &ll = sharedLl.getLinearLayout();
+    auto ctx = layout.getContext();
+    auto kOffset = mlir::StringAttr::get(ctx, "offset");
+    auto kBlock = mlir::StringAttr::get(ctx, "block");
+    return layouts.SharedLinearLayout(
+        toStdVector(ll.getBases().lookup(kOffset)),
+        toStdVector(ll.getBases().lookup(kBlock)), sharedLl.getAlignment());
   } else if (auto autoEnc = dyn_cast<gluon::AutoEncodingAttr>(layout)) {
     return layouts.AutoLayout();
   } else if (auto amdMfma = dyn_cast<ttg::AMDMfmaEncodingAttr>(layout)) {
@@ -410,14 +421,13 @@ void init_gluon_ir(py::module &&m) {
       .def("get_shared_linear_layout",
           [](GluonOpBuilder &self, std::vector<std::vector<int>> &offsetBases,
              std::vector<std::vector<int>> &blockBases,
-             std::vector<int64_t> &shape, unsigned alignment) -> Attribute {
+             unsigned alignment) -> Attribute {
             auto ctx = self.getContext();
             auto kOffset = mlir::StringAttr::get(ctx, "offset");
             auto kBlock = mlir::StringAttr::get(ctx, "block");
+            auto outDims = tt::standardOutDimNames(ctx, offsetBases[0].size());
             auto ll = tt::LinearLayout(
-                {{kOffset, offsetBases}, {kBlock, blockBases}},
-                tt::standardOutDimPairs(ctx, shape),
-                /*requireSurjective=*/true);
+                {{kOffset, offsetBases}, {kBlock, blockBases}}, outDims);
             return self.getChecked<ttg::SharedLinearEncodingAttr>(ctx, ll,
                                                                   alignment);
           })
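On the Python side, get_shared_linear_layout no longer takes a shape: the number of output dimensions is derived from the basis rank (offsetBases[0].size()). A small sketch of the calling convention implied above (values hypothetical):

offset_bases = [[0, 1], [0, 2], [1, 0], [2, 0]]  # hypothetical 2-D bases
rank = len(offset_bases[0])                      # -> standardOutDimNames(ctx, rank)
layout = ttgl.SharedLinearLayout(offset_bases=offset_bases)  # no shape argument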

python/test/gluon/test_core.py

Lines changed: 22 additions & 14 deletions
@@ -514,13 +514,12 @@ def fast_dividef_kernel(x_ptr, y_ptr, z_ptr, warp_size: ttgl.constexpr, num_warp
     torch.testing.assert_close(z, torch.div(x, y), atol=1e-5, rtol=1e-4)
 
 
-@pytest.mark.xfail(reason="copy to tmem with scale layout is currently broken in Gluon.")
 @pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
 def test_tmem_copy_2d():
     device = "cuda"
 
-    smem_h = 256
-    smem_w = 4
+    smem_h = 64
+    smem_w = 16
     num_rows = 128
     num_cols = smem_h * smem_w // 32
 
@@ -530,13 +529,14 @@ def kernel(in_ptr, out_ptr, smem_h: ttgl.constexpr, smem_w: ttgl.constexpr, num_
         in_ptrs = in_ptr + ttgl.arange(0, smem_h)[:, None] * smem_w + ttgl.arange(0, smem_w)[None, :]
         out_ptrs = out_ptr + ttgl.arange(0, num_rows)[:, None] * num_cols + ttgl.arange(0, num_cols)[None, :]
 
-        blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 4], [32, 1], [4, 1], [0, 1])
+        blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 4], [32, 1], [4, 1], [1, 0])
         value = ttgl.load(ttgl.set_auto_layout(in_ptrs, blocked))
 
-        smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=0, element_bitwidth=8, rank=2)
+        smem_layout: ttgl.constexpr = ttgl.SharedLinearLayout(
+            offset_bases=[[0, 1], [0, 2], [32, 0], [0, 4], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]])
         tmem_layout: ttgl.constexpr = TensorMemoryScalesLayout()
         smem = ttgl.allocate_shared_memory(ttgl.int8, (smem_h, smem_w), layout=smem_layout)
-        tmem = allocate_tensor_memory(ttgl.int8, (num_rows, num_cols), layout=tmem_layout)
+        tmem = allocate_tensor_memory(ttgl.int8, (smem_h, smem_w), layout=tmem_layout)
 
         barrier = ttgl.allocate_shared_memory(ttgl.int64, [1], ttgl.constexpr(mbarrier.MBarrierLayout()))
         mbarrier.init(barrier, count=1)
@@ -546,22 +546,30 @@ def kernel(in_ptr, out_ptr, smem_h: ttgl.constexpr, smem_w: ttgl.constexpr, num_
         tcgen05_copy(smem, tmem)
         tcgen05_commit(barrier)
         mbarrier.wait(barrier, phase=0)
-        tmem_alias: ttgl.constexpr = TensorMemoryLayout((128, 32), col_stride=1)
+        tmem_alias: ttgl.constexpr = TensorMemoryLayout((num_rows, num_cols), col_stride=1)
         tmem = tmem._reinterpret(ttgl.int8, (num_rows, num_cols), tmem_alias)
         value = tmem.load(blocked)
+        ttgl.static_print(ttgl.to_linear_layout(blocked, (smem_h, smem_w)))
+        ttgl.static_print(ttgl.to_linear_layout(blocked, (num_rows, num_cols)))
         ttgl.store(ttgl.set_auto_layout(out_ptrs, blocked), value)
 
+    torch.manual_seed(0)
     x = torch.randint(size=(smem_h, smem_w), low=-100, high=100, dtype=torch.int8).to(device)
+    #x = torch.arange(smem_h * smem_w, dtype=torch.int8, device=device).reshape(smem_h, smem_w)
     z_tri = torch.zeros(size=(num_rows, num_cols), dtype=torch.int8).to(device)
     kernel[(1, )](x, z_tri, smem_h, smem_w, num_rows, num_cols)
 
-    num_rep_m = smem_h // 32
-
-    for m in range(num_rep_m):
-        col_offset = m * 4
-        for i in range(4):
-            # Copied values are duplicated across warps
-            assert torch.equal(x[m * 32:(m + 1) * 32], z_tri[32 * i:32 * (i + 1), col_offset:(col_offset + 4)])
+    # offset_bases=[[0, 1], [0, 2], [32, 0], [0, 4], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]],
+    # Split into contiguous shmem chunks
+    x_res = x.reshape(2, 32, 2, 2, 4)
+    # Put tmem cols first then rows
+    x_res = x_res.permute(1, 2, 3, 0, 4)
+    # Reshape as 32xnum_cols
+    x_res = x_res.reshape(num_rows // 4, num_cols)
+
+    warps = torch.chunk(z_tri, chunks=4, dim=0)
+    for warp in warps:
+        torch.testing.assert_close(x_res, warp)
 
 
 @pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")

python/test/gluon/test_frontend.py

Lines changed: 1 addition & 18 deletions
@@ -7,7 +7,7 @@
 from triton.experimental.gluon import language as ttgl
 from triton.experimental.gluon.language.nvidia import blackwell
 from triton.experimental.gluon.language.nvidia import hopper
-from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma, TensorMemoryLayout, TensorMemoryScalesLayout, async_copy
+from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma, TensorMemoryLayout, async_copy
 from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
 from triton.experimental.gluon.language.amd import _layouts as amd_layouts
 from triton.experimental.gluon.language.amd.cdna4 import async_copy as cdna4_async_copy
@@ -613,23 +613,6 @@ def test_tcgen05_mma_mbar():
     """)
 
 
-@filecheck_test
-@gluon.jit
-def test_tcgen05_copy():
-    # CHECK-LABEL: test_tcgen05_copy
-    smem_h: ttgl.constexpr = 256
-    num_cols: ttgl.constexpr = smem_h * 4 // 32
-
-    shared_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=0, element_bitwidth=8, rank=2)
-    tmem_layout: ttgl.constexpr = TensorMemoryScalesLayout()
-    # CHECK: [[SRC:%.*]] = ttg.local_alloc
-    src = ttgl.allocate_shared_memory(ttgl.int8, [smem_h, 4], shared_layout)
-    # CHECK: [[DST:%.*]] = ttng.tmem_alloc
-    dst = blackwell.allocate_tensor_memory(ttgl.int8, [128, num_cols], tmem_layout)
-    # CHECK: ttng.tmem_copy [[SRC]], [[DST]]
-    blackwell.tcgen05_copy(src, dst)
-
-
 @filecheck_test
 @gluon.jit
 def test_tcgen05_commit():
