Commit d846dcf

allanrenucci authored and Google-ML-Automation committed

[Pallas:MGPU] Add Pallas lowering for TMEM slices under WG semantics.

We lower TMEM slices to `memref.subview`.

PiperOrigin-RevId: 831909484

1 parent 7774f56 · commit d846dcf
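
For orientation, a hedged sketch of the user-visible effect (the kernel and shapes are illustrative, adapted from the tests in this commit): slicing a TMEM ref in a Pallas Mosaic GPU kernel now also works under warpgroup (WG) lowering semantics, where the slice is expressed as a `memref.subview` op.

    # Sketch only: `.at[:, :128]` column-slices the TMEM ref; under WG
    # semantics this now lowers through memref.subview instead of raising.
    from jax.experimental.pallas import mosaic_gpu as plgpu

    def kernel(o_ref, tmem_ref):
      del o_ref
      plgpu.print_layout("tmem: {}", tmem_ref.at[:, :128])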

File tree: 8 files changed (+168, -66 lines)

jax/_src/pallas/mosaic_gpu/lowering.py (4 additions, 1 deletion)

@@ -1449,7 +1449,10 @@ def _bubble_up(untransform_fn, data):
   indices = _bubble_up(
       lambda t, idxs: t.untransform_index(mlir_dtype, idxs), indices
   )
-  if isinstance(transformed_ref, tcgen05.TMEMRef):
+  if (
+      isinstance(transformed_ref, tcgen05.TMEMRef)
+      and ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane
+  ):
     transformed_ref = transformed_ref.slice(*indices)
   else:
     transformed_ref = mgpu.memref_slice(transformed_ref, indices)
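
Reading note (inferred from this change, not stated in it): under Lane semantics the ref at this point is a `tcgen05.TMEMRef`, which is sliced directly via `TMEMRef.slice`; under WG semantics TMEM refs are still ordinary memref values here, so they now fall through to `mgpu.memref_slice`, which emits the `memref.subview` that the new dialect lowering rule below handles.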

jax/experimental/mosaic/gpu/dialect_lowering.py (34 additions, 13 deletions)

@@ -1441,6 +1441,29 @@ def _memref_subview_op_lowering_rule(
 ) -> Sequence[ir.Value]:
   del ctx

+  if any(s != 1 for s in op.static_strides):
+    raise NotImplementedError("SubViewOp only supports static strides of 1.")
+  if op.sizes:
+    raise NotImplementedError("SubViewOp only supports static sizes.")
+  src_ty = ir.MemRefType(op.source.type)
+
+  if utils.is_memref_transposed(src_ty):
+    raise NotImplementedError("SubViewOp does not support transposed memrefs.")
+
+  if utils.is_tmem_ref(src_ty):
+    [in_tmem_layout] = inference_utils.in_tmem_layouts(op)
+    [out_tmem_layout] = inference_utils.out_tmem_layouts(op)
+    assert in_tmem_layout == out_tmem_layout
+    ref = _tmem_ref_from_ir(op.source, in_tmem_layout)
+    indices = []
+    dynamic_offset_index = 0
+    for offset, size in zip(op.static_offsets, op.static_sizes, strict=True):
+      if ir.ShapedType.is_dynamic_size(offset):
+        offset = op.offsets[dynamic_offset_index]
+        dynamic_offset_index += 1
+      indices.append(utils.DynamicSlice(offset, size))
+    return [_tmem_ref_to_ir(ref.slice(*indices))]
+
   in_transforms = inference_utils.in_transforms(op)[0]
   out_transforms = inference_utils.out_transforms(op)[0]
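
The static/dynamic walk in the new TMEM branch follows MLIR's encoding for mixed subview offsets: `static_offsets` holds one entry per dimension, with a dynamic-size sentinel wherever the offset is an SSA value, and the SSA values themselves are stored in order in the op's variadic `offsets` operands. A standalone sketch of that decoding, with a hypothetical helper name and import paths as used in JAX's Mosaic GPU sources:

    from jax._src.lib.mlir import ir
    from jax._src.lib.mlir.dialects import memref

    def mixed_offsets(op: memref.SubViewOp) -> list:
      """Merges static offsets with the SSA values backing dynamic ones."""
      offsets = []
      dynamic_offset_index = 0
      for offset in op.static_offsets:
        if ir.ShapedType.is_dynamic_size(offset):
          # Dynamic: the value is the next entry of the variadic `offsets`.
          offsets.append(op.offsets[dynamic_offset_index])
          dynamic_offset_index += 1
        else:
          offsets.append(offset)  # Static: a plain Python int.
      return offsets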

@@ -1449,22 +1472,11 @@ def _memref_subview_op_lowering_rule(
       "SubViewOp transforms for the input and output refs must be identical."
   )

-  if any(s != 1 for s in op.static_strides):
-    raise NotImplementedError(
-        "SubViewOp only supports static strides of 1."
-    )
-
-  if utils.is_memref_transposed(op.source.type):
-    raise NotImplementedError(
-        "SubViewOp does not support transposed memrefs."
-    )
-
   unwrapped_source_ref = unwrap_transformed_memref(op.source, in_transforms)
   swizzle, transforms = swizzle_and_transforms_from_transforms_attr(out_transforms)
   if swizzle != mgpu.SwizzlingMode.kNoSwizzle:
-    source_ty = ir.MemRefType(op.source.type)
-    swizzle_elems = swizzle * 8 // utils.bitwidth(source_ty.element_type)
-    source_strides, _ = source_ty.get_strides_and_offset()
+    swizzle_elems = swizzle * 8 // utils.bitwidth(src_ty.element_type)
+    source_strides, _ = src_ty.get_strides_and_offset()
     for stride, offset, size in zip(
         source_strides, op.static_offsets, op.static_sizes, strict=True
     ):
@@ -1774,6 +1786,14 @@ def _tmem_ref_from_ir(
   return tcgen05.TMEMRef(tmem_addr, shape, el_ty, tmem_layout)


+def _tmem_ref_to_ir(ref: tcgen05.TMEMRef) -> ir.Value:
+  """Returns an IR value from a TMEMRef."""
+  type = ir.MemRefType.get(ref.shape, ref.dtype, memory_space=mgpu_utils.tmem())
+  cast = builtin.UnrealizedConversionCastOp([type], [ref.address])
+  cast.attributes["layout"] = layouts_lib.to_layout_attr(ref.layout)
+  return cast.result
+
+
 @_register_lowering(mgpu.TcGen05MMAOp)
 def _tcgen05_mma_op_lowering_rule(
     ctx: LoweringContext, op: mgpu.TcGen05MMAOp
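
`_tmem_ref_to_ir` is the inverse of the existing `_tmem_ref_from_ir`: it wraps the TMEM address in an `unrealized_conversion_cast` back to a TMEM-space memref and records the layout as an attribute. A sketch of the round trip as the subview rule uses it (slice bounds are illustrative; assumes `op`, `layout`, and the module's own imports are in scope):

    # memref value -> TMEMRef -> sliced TMEMRef -> memref value
    ref = _tmem_ref_from_ir(op.source, layout)
    sliced = ref.slice(utils.DynamicSlice(0, 128), utils.DynamicSlice(8, 64))
    result = _tmem_ref_to_ir(sliced)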
@@ -2155,6 +2175,7 @@ def _should_lower(op: ir.OpView) -> bool:
       op.OPERATION_NAME.startswith("mosaic_gpu.")  # pytype: disable=attribute-error
       or inference_utils.should_have_layout(op)
       or inference_utils.should_have_transforms(op)
+      or inference_utils.should_have_tmem_layout(op)
       or any(bool(b) for r in op.regions for b in r)  # Does it have subblocks?
   )

jax/experimental/mosaic/gpu/equations.py (2 additions, 0 deletions)

@@ -529,6 +529,8 @@ def holds(self) -> bool | None:
         tiling = t
       case RegisterLayout(value=fa.TiledLayout() as layout):
         tiling = layout.base_tile_shape
+      case TMEMLayout(value):
+        tiling = value.base_tile_shape
       case _:
         return None
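
With this case, `Divides` constraints can now be checked against TMEM layouts too: the constraint holds when the queried shape is divisible by the layout's `base_tile_shape`. A usage sketch, mirroring the tests added below:

    from jax.experimental.mosaic.gpu import equations, tcgen05

    layout = tcgen05.tmem_default_layout(packing=1)
    tiling = equations.TMEMLayout(layout)
    assert equations.Divides(tiling, (0, 64)).holds()      # divisible
    assert not equations.Divides(tiling, (3, 64)).holds()  # 3 is not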

jax/experimental/mosaic/gpu/layout_inference.py (2 additions, 2 deletions)

@@ -1450,7 +1450,7 @@ def _memref_subview_equation_system(
   dest = ValueSite(op, VariableType.RESULT, 0)
   source_dest_var = ctx.producer_ref(source)

-  if any(map(lambda s: s != 1, op.static_strides)):
+  if any(s != 1 for s in op.static_strides):
     raise NotImplementedError(
         f"Only unit strides are supported but got {op.static_strides}."
     )

@@ -1473,7 +1473,7 @@ def _memref_subview_equation_system(
       if ir.ShapedType.is_dynamic_size(size):
         tiling_multiple = []
       else:
-        src_type = ir.ShapedType(op.source.type)
+        src_type = ir.MemRefType(op.source.type)
         divisibility_constraint = math.gcd(size, src_type.shape[i])
         if isinstance(offset, int):
           divisibility_constraint = math.gcd(divisibility_constraint, offset)
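
For intuition on the gcd chain above, with hypothetical numbers: a subview taking size 64 out of a source dimension of 256 at static offset 8 yields gcd(gcd(64, 256), 8) = 8, so only tilings dividing 8 along that dimension remain candidates:

    import math

    size, src_dim, offset = 64, 256, 8  # hypothetical subview parameters
    constraint = math.gcd(size, src_dim)       # 64
    constraint = math.gcd(constraint, offset)  # 8: a tiling must divide this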

tests/mosaic/gpu_equations_test.py (9 additions, 0 deletions)

@@ -21,6 +21,7 @@
 from jax.experimental.mosaic.gpu import equations
 from jax.experimental.mosaic.gpu import fragmented_array as fa
 from jax.experimental.mosaic.gpu import launch_context as lc
+from jax.experimental.mosaic.gpu import tcgen05

 config.parse_flags_with_absl()

@@ -469,6 +470,10 @@ def test_divides_constraints_are_satisfied_by_divisor_tiling(self):
     with self.subTest("RegisterLayout"):
       tiling = equations.RegisterLayout(fa.WGMMA_LAYOUT)
       self.assertTrue(equations.Divides(tiling, (0, 64)).holds())
+    with self.subTest("TMEMLayout"):
+      layout = tcgen05.tmem_default_layout(packing=1)
+      tiling = equations.TMEMLayout(layout)
+      self.assertTrue(equations.Divides(tiling, (0, 64)).holds())

   def test_divides_constraints_are_not_satisfied_by_non_divisor_tiling(self):
     with self.subTest("SMEMTiling"):

@@ -477,6 +482,10 @@ def test_divides_constraints_are_not_satisfied_by_non_divisor_tiling(self):
     with self.subTest("RegisterLayout"):
       tiling = equations.RegisterLayout(fa.WGMMA_LAYOUT)
       self.assertFalse(equations.Divides(tiling, (3, 64)).holds())
+    with self.subTest("TMEMLayout"):
+      layout = tcgen05.tmem_default_layout(packing=1)
+      tiling = equations.TMEMLayout(layout)
+      self.assertFalse(equations.Divides(tiling, (3, 64)).holds())

   def test_reduce_merges_divides_constraints_on_same_variable(self):
     v0, v1 = equations.Variable(0), equations.Variable(1)

tests/mosaic/gpu_layout_inference_test.py (75 additions, 41 deletions)

@@ -1660,36 +1660,52 @@ def test_infer_transforms_for_memref_cast_op(self, annotate_producer):
   def test_infer_transforms_for_subview_raises_on_slice_incompatible_with_tile(
       self, annotate_input
   ):
-    shape = (2, 64, 64)
-    elt_ty = ir.BF16Type.get()
-
-    in_ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem())
-    out_ref_ty = ir.MemRefType.get((2, 64, 32), elt_ty, memory_space=mgpu.utils.smem())
-
     with ir.InsertionPoint(self.module.body):
+      in_ref_ty = ir.MemRefType.get(
+          (2, 64, 64), ir.BF16Type.get(), memory_space=mgpu.utils.smem()
+      )
       [in_ref] = undefs(in_ref_ty)

       transforms = ir.ArrayAttr.get([
-        mgpu.dialect.TileTransformAttr.get((32, 16)),
-        mgpu.dialect.SwizzleTransformAttr.get(32),
+          mgpu.dialect.TileTransformAttr.get((32, 16)),
+          mgpu.dialect.SwizzleTransformAttr.get(32),
       ])

       if annotate_input:
         in_ref = mgpu.dialect.with_transforms(in_ref, transforms)

-      subview_op = memref.SubViewOp(
-        out_ref_ty,
-        in_ref,
-        [],
-        [],
-        [],
-        static_offsets = [1, 0, 0],
-        static_sizes = [2, 64, 8],
-        static_strides = [1, 1, 1]
+      out_ref = memref.subview(
+          in_ref, offsets=[1, 0, 0], sizes=[2, 64, 8], strides=[1, 1, 1]
+      )
+
+      if not annotate_input:
+        mgpu.dialect.with_transforms(out_ref, transforms)
+
+      with self.assertRaisesRegex(ValueError, "Failed to infer"):
+        mgpu.infer_layout(self.module)
+
+  @parameterized.parameters([False, True])
+  def test_infer_tmem_layouts_for_subview_raises_on_slice_incompatible_with_tile(
+      self, annotate_input
+  ):
+    with ir.InsertionPoint(self.module.body):
+      in_ref_ty = ir.MemRefType.get(
+          (128, 64), ir.BF16Type.get(), memory_space=mgpu.utils.tmem()
+      )
+      [in_ref] = undefs(in_ref_ty)
+
+      layout = tcgen05.tmem_default_layout(packing=1)
+      layout_attr = layouts.to_layout_attr(layout)
+
+      if annotate_input:
+        in_ref = mgpu.dialect.tmem_layout_cast(in_ref, layout_attr)
+
+      out_ref = memref.subview(
+          in_ref, offsets=[1, 0], sizes=[2, 64], strides=[1, 1]
       )

       if not annotate_input:
-        mgpu.dialect.with_transforms(subview_op.result, transforms)
+        mgpu.dialect.tmem_layout_cast(out_ref, layout_attr)

       with self.assertRaisesRegex(ValueError, "Failed to infer"):
         mgpu.infer_layout(self.module)

@@ -1807,13 +1823,10 @@ def test_infer_transforms_for_sibling_subviews_and_distant_op(
   def test_infer_transforms_for_subview_handles_dynamic_offsets(
       self, annotate_input
   ):
-    shape = (32, 32, 32, 32)
-    elt_ty = ir.BF16Type.get()
-
-    in_ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=mgpu.utils.smem())
-    out_ref_ty = ir.MemRefType.get((16, 16, 32, 32), elt_ty, memory_space=mgpu.utils.smem())
-
     with ir.InsertionPoint(self.module.body):
+      in_ref_ty = ir.MemRefType.get(
+          (32, 32, 32, 32), ir.BF16Type.get(), memory_space=mgpu.utils.smem()
+      )
       [in_ref] = undefs(in_ref_ty)

       transforms = ir.ArrayAttr.get([

@@ -1825,34 +1838,55 @@ def test_infer_transforms_for_subview_handles_dynamic_offsets(
       in_ref = mgpu.dialect.with_transforms(in_ref, transforms)

       c = lambda x: arith.constant(ir.IntegerType.get_signless(32), x)
-      subview_op = memref.SubViewOp(
-        out_ref_ty,
+      out_ref = memref.subview(
          in_ref,
-        [c(16), c(4), arith.muli(c(8), c(3))],
-        [],
-        [],
-        static_offsets=[
-          ir.ShapedType.get_dynamic_size(),
-          ir.ShapedType.get_dynamic_size(),
-          ir.ShapedType.get_dynamic_size(),
-          0,
-        ],
-        static_sizes=[16, 16, 32, 32],
-        static_strides=[1, 1, 1, 1],
+          offsets=[c(16), c(4), arith.muli(c(8), c(3)), 0],
+          sizes=[16, 16, 32, 32],
+          strides=[1, 1, 1, 1],
       )

       if not annotate_input:
-        mgpu.dialect.with_transforms(subview_op.result, transforms)
+        mgpu.dialect.with_transforms(out_ref, transforms)

       mgpu.infer_layout(self.module)
-
       self.assertSequenceEqual(
-        inference_utils.in_transforms(subview_op), [transforms]
+          inference_utils.in_transforms(out_ref.owner), [transforms]
       )
       self.assertSequenceEqual(
-        inference_utils.out_transforms(subview_op), [transforms]
+          inference_utils.out_transforms(out_ref.owner), [transforms]
       )

+  @parameterized.parameters([False, True])
+  def test_infer_tmem_layouts_for_subview_handles_dynamic_offsets(
+      self, annotate_input
+  ):
+    with ir.InsertionPoint(self.module.body):
+      in_ref_ty = ir.MemRefType.get(
+          (128, 256), ir.BF16Type.get(), memory_space=mgpu.utils.tmem()
+      )
+      [in_ref] = undefs(in_ref_ty)
+
+      layout = tcgen05.tmem_default_layout(packing=1)
+      layout_attr = layouts.to_layout_attr(layout)
+
+      if annotate_input:
+        in_ref = mgpu.dialect.tmem_layout_cast(in_ref, layout_attr)
+
+      c = lambda x: arith.constant(ir.IntegerType.get_signless(32), x)
+      out_ref = memref.subview(
+          in_ref,
+          offsets=[c(0), arith.muli(c(16), c(4))],
+          sizes=[128, 128],
+          strides=[1, 1],
+      )
+
+      if not annotate_input:
+        mgpu.dialect.tmem_layout_cast(out_ref, layout_attr)
+
+      mgpu.infer_layout(self.module)
+      self.checkInTmemLayouts(out_ref.owner, [layout])
+      self.checkOutTmemLayouts(out_ref.owner, [layout])
+
   def test_custom_primitive_op_retains_transforms(self):
     with ir.InsertionPoint(self.module.body):
       transforms = ir.ArrayAttr.get([
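
A note on the test refactoring above (our reading of the diff): the `memref.subview` convenience builder accepts mixed Python ints and `ir.Value`s for `offsets`, derives the result memref type itself, and returns the result `Value`, which is why the assertions now reach the op through `out_ref.owner`. Sketch, assuming an MLIR context, insertion point, and `in_ref` are in scope:

    out_ref = memref.subview(
        in_ref, offsets=[1, 0, 0], sizes=[2, 64, 8], strides=[1, 1, 1]
    )
    subview_op = out_ref.owner  # the memref.SubViewOp behind the result Value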

tests/mosaic/gpu_test.py (40 additions, 0 deletions)

@@ -5570,6 +5570,46 @@ def body(ctx, x, y, x_out, y_out, tmem):
     self.assertArraysEqual(x_out, x)
     self.assertArraysEqual(y_out, y)

+  def test_tmem_subview(self):
+    def body(ctx, in_ref, out_ref, tmem):
+      del ctx
+      # GMEM -> Registers -> TMEM
+      in_reg = mgpu_dialect.vector_load(in_ref)
+      slice_in = memref.subview(
+          tmem, offsets=[0, 8], sizes=[128, 200], strides=[1, 1]
+      )
+      slice_in = memref.subview(
+          slice_in, offsets=[0, 0], sizes=[128, 128], strides=[1, 1]
+      )
+      mgpu_dialect.async_store_tmem(in_reg, slice_in)
+      tcgen05.commit_tmem()
+
+      def dynamic_idx(idx: int) -> ir.Value:
+        idx_type = ir.IndexType.get()
+        return arith.constant(idx_type, idx)
+
+      # TMEM -> Registers -> GMEM
+      slice_out = memref.subview(
+          tmem,
+          offsets=[dynamic_idx(0), dynamic_idx(8)],
+          sizes=[128, 128],
+          strides=[1, 1],
+      )
+      out_reg = mgpu_dialect.async_load_tmem(slice_out)
+      mgpu_dialect.vector_store(out_reg, out_ref)
+
+    kernel = mgpu.as_gpu_kernel(
+        body,
+        grid=(1, 1, 1),
+        block=(128, 1, 1),
+        in_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32),
+        out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32),
+        smem_scratch_shape=mgpu.TMEM((128, 256), jnp.float32),
+        thread_semantics=mgpu.LoweringSemantics.Warpgroup,
+    )
+    x = self.prng.uniform(-100, 100, (128, 128)).astype(jnp.float32)
+    self.assertArraysEqual(kernel(x), x)
+

 class UtilsTest(TestCase):
   @parameterized.parameters(
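
Note on test_tmem_subview (our arithmetic, not from the diff): the two nested store-side subviews compose, offsets add and the inner sizes win, so they select the same (128, 128) window at column offset 8 that the dynamic-offset subview later reads back; the kernel is therefore an identity copy.

    # subview(tmem, offsets=[0, 8], sizes=[128, 200])
    #   then subview(., offsets=[0, 0], sizes=[128, 128])
    # == subview(tmem, offsets=[0, 8], sizes=[128, 128])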

tests/pallas/mosaic_gpu_test.py (2 additions, 9 deletions)

@@ -3204,13 +3204,8 @@ def test_print_layout_tmem(self):
     )
     def kernel(o_ref, tmem_ref):
       del o_ref
-      if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane:
-        # Slicing TMEM to make sure we handle transforms correctly.
-        plgpu.print_layout("tmem: {}", tmem_ref.at[:, :128])
-      else:
-        # TODO(b/415721295): Remove this branch once TMEM slicing is supported
-        # for WG semantics.
-        plgpu.print_layout("tmem: {}", tmem_ref)
+      # Slicing TMEM to make sure we handle transforms correctly.
+      plgpu.print_layout("tmem: {}", tmem_ref.at[:, :128])

     with self.capture_stdout() as output:
       jax.block_until_ready(kernel())

@@ -3412,7 +3407,6 @@ def kernel(x_ref, y_ref, tmem_ref, smem_ref, barrier_ref):
     np.testing.assert_array_equal(x_result, x + 1)

   def test_tmem_column_slicing(self):
-    self.skip_if_wg_semantics()
     transforms = self.default_transforms(dtype=jnp.float32)
     @functools.partial(
         self.kernel,

@@ -3806,7 +3800,6 @@ def kernel(a_gmem, b_gmem, out_gmem,
     np.testing.assert_allclose(result, x @ y, rtol=1e-3)

   def test_matmul_with_sliced_accumulator(self):
-    self.skip_if_wg_semantics()  # Slicing TMEM is not supported.
     dtype = jnp.bfloat16
     shape = (128, 128)
     tmem_shape = (128, 2 * 128)
