Commit 376b9b9

[AMD][Gluon] Expose buffer_load and buffer_store (#7738)

Expose AMD buffer_load and buffer_store to Gluon. Example usage looks like:

```
def buffer_ldst_kernel(x, y):
    layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[1, 64],
                                                warps_per_cta=[4, 1], order=[1, 0])
    offsets = ttgl.arange(0, 64 * 64, layout=layout)
    a = ttgl.amd.cdna3.buffer_load(ptr=x, offsets=offsets)
    ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets)
```
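
For orientation, a hedged host-side launch sketch (not part of the commit itself): it assumes the kernel above is decorated with `@gluon.jit`, as the tests in this commit do, and that PyTorch with ROCm is installed.

```
import torch

# Hypothetical driver code; the tensor size matches the 64 * 64 offsets used above.
x = torch.randn(64 * 64, device="cuda", dtype=torch.float32)  # ROCm exposes HIP devices through the "cuda" backend
y = torch.empty_like(x)
buffer_ldst_kernel[(1, )](x, y)  # a single program instance covers the whole tile
```
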
1 parent 2bb88ff commit 376b9b9

File tree

python/src/gluon_ir.cc
python/test/gluon/test_frontend.py
python/triton/experimental/gluon/language/amd/cdna3/__init__.py
python/triton/experimental/gluon/language/amd/cdna4/__init__.py

4 files changed: +210 −10 lines changed

python/src/gluon_ir.cc

Lines changed: 21 additions & 5 deletions
```diff
@@ -20,6 +20,7 @@ namespace tt = triton;
 namespace ttg = triton::gpu;
 namespace ttng = triton::nvidia_gpu;
 namespace gluon = mlir::triton::gluon;
+namespace ttag = mlir::triton::amdgpu;
 
 // Helper to check if an MLIR type or attribute has a verifier method.
 template <typename AttrOrType>
@@ -246,7 +247,8 @@ void init_gluon_ir(py::module &&m) {
             auto ctx = self.getContext();
             return self.getChecked<ttg::MemDescType>(
                 shape, elementType, layout,
-                ttg::SharedMemorySpaceAttr::get(ctx), /*mutableMemory=*/true,
+                ttg::SharedMemorySpaceAttr::get(ctx),
+                /*mutableMemory=*/true,
                 /*allocShape=*/allocShape);
           })
       .def("get_tensor_mem_desc_ty",
@@ -256,7 +258,8 @@ void init_gluon_ir(py::module &&m) {
             auto ctx = self.getContext();
             return self.getChecked<ttg::MemDescType>(
                 shape, elementType, layout,
-                ttng::TensorMemorySpaceAttr::get(ctx), /*mutableMemory=*/true,
+                ttng::TensorMemorySpaceAttr::get(ctx),
+                /*mutableMemory=*/true,
                 /*allocShape=*/allocShape);
           })
       .def("get_blocked_layout",
@@ -404,8 +407,8 @@ void init_gluon_ir(py::module &&m) {
              tt::CacheModifier cacheModifier,
              tt::EvictionPolicy evictionPolicy, bool isVolatile) {
             self.create<ttg::AsyncCopyGlobalToLocalOp>(
-                pointer, smem, mask, /*other*/ Value{}, cacheModifier,
-                evictionPolicy, isVolatile);
+                pointer, smem, mask,
+                /*other*/ Value{}, cacheModifier, evictionPolicy, isVolatile);
           })
       .def("create_async_copy_mbarrier_arrive",
           [](GluonOpBuilder &self, Value mbarrier, bool incrementCount) {
@@ -622,11 +625,24 @@ void init_gluon_ir(py::module &&m) {
             return self.create<ttg::WarpSpecializeOp>(
                 resultTypes, explicitCaptures, partitionNumWarps);
           })
+      .def("create_buffer_load",
+           [](GluonOpBuilder &self, Type resultType, Value ptr, Value offsets,
+              Value mask, Value other, tt::CacheModifier cache) -> Value {
+             return self.create<ttag::BufferLoadOp>(resultType, ptr, offsets,
+                                                    Value() /*stride*/, cache,
+                                                    mask, other);
+           })
+      .def("create_buffer_store",
+           [](GluonOpBuilder &self, Value storedValue, Value ptr, Value offsets,
+              Value mask, tt::CacheModifier cache) {
+             self.create<ttag::BufferStoreOp>(storedValue, ptr, offsets,
+                                              Value() /*stride*/, cache, mask);
+           })
       .def("create_buffer_load_to_local",
            [](GluonOpBuilder &self, Value dest, Value ptr, Value offsets,
               Value mask, Value other, Value stride,
               tt::CacheModifier cacheModifier) {
-            self.create<triton::amdgpu::BufferLoadToLocalOp>(
+            self.create<ttag::BufferLoadToLocalOp>(
                 dest, ptr, offsets, mask, other, stride, cacheModifier);
           });
```

python/test/gluon/test_frontend.py

Lines changed: 106 additions & 0 deletions
```diff
@@ -1605,3 +1605,109 @@ def kernel(ptr):
   }
 }
 """)
+
+
+@gluon.jit
+def buffer_load_store_kernel(x, y):
+    layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[1, 64], warps_per_cta=[4, 1],
+                                                order=[1, 0])
+
+    offsets = ttgl.arange(0, 64 * 64).reshape(64, 64)
+    offsets = ttgl.convert_layout(offsets, layout=layout)
+    mask = ttgl.full((64, 64), 1, tl.int1, layout=layout)
+    other = ttgl.full((64, 64), 1.0, tl.float32, layout=layout)
+    a = ttgl.amd.cdna3.buffer_load(ptr=x, offsets=offsets, mask=mask, other=other, cache='.ca')
+    ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.ca')
+
+    a = ttgl.amd.cdna4.buffer_load(ptr=x, offsets=offsets, mask=mask, other=other, cache='.ca')
+    ttgl.amd.cdna4.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.ca')
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
+def test_buffer_load_store(target):
+    x = MockTensor(ttgl.float32)
+    y = MockTensor(ttgl.float32)
+    module = run_parser(buffer_load_store_kernel, *make_args(x, y), target=target)
+
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @buffer_load_store_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %0 = tt.make_range {end = 4096 : i32, start = 0 : i32} : tensor<4096xi32, #gluon.auto_encoding>
+    %1 = tt.reshape %0 : tensor<4096xi32, #gluon.auto_encoding> -> tensor<64x64xi32, #gluon.auto_encoding>
+    %2 = ttg.convert_layout %1 : tensor<64x64xi32, #gluon.auto_encoding> -> tensor<64x64xi32, #blocked>
+    %true = arith.constant true
+    %cst = arith.constant dense<true> : tensor<64x64xi1, #blocked>
+    %cst_0 = arith.constant 1.000000e+00 : f32
+    %cst_1 = arith.constant dense<1.000000e+00> : tensor<64x64xf32, #blocked>
+    %3 = amdgpu.buffer_load %arg0[%2], %cst, %cst_1 cacheModifier = ca : tensor<64x64xf32, #blocked>
+    amdgpu.buffer_store %3, %arg1[%2], %cst cacheModifier = ca : tensor<64x64xf32, #blocked>
+    %4 = amdgpu.buffer_load %arg0[%2], %cst, %cst_1 cacheModifier = ca : tensor<64x64xf32, #blocked>
+    amdgpu.buffer_store %4, %arg1[%2], %cst cacheModifier = ca : tensor<64x64xf32, #blocked>
+    tt.return
+  }
+}
+""")
+
+
+@gluon.jit
+def buffer_load_store_with_broadcast_kernel(x, y):
+    layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[1, 64], warps_per_cta=[4, 1],
+                                                order=[1, 0])
+
+    offsets = ttgl.arange(0, 64 * 64).reshape(64, 64)
+    offsets = ttgl.convert_layout(offsets, layout=layout)
+    other = ttgl.full((64, 64), 1.0, tl.float32, layout=layout)
+
+    mask = ttgl.full((64, 1), 1, tl.int1, layout=layout)
+    a = ttgl.amd.cdna3.buffer_load(ptr=x, offsets=offsets, mask=mask, other=other, cache='.ca')
+    ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.ca')
+
+    mask = ttgl.full((1, 64), 1, tl.int1, layout=layout)
+    a = ttgl.amd.cdna3.buffer_load(ptr=x, offsets=offsets, mask=mask, other=other, cache='.ca')
+    ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.ca')
+
+    other = 1.0
+    a = ttgl.amd.cdna3.buffer_load(ptr=x, offsets=offsets, mask=mask, other=other, cache='.ca')
+    ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.ca')
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
+def test_buffer_load_store_with_broadcast(target):
+    x = MockTensor(ttgl.float32)
+    y = MockTensor(ttgl.float32)
+    module = run_parser(buffer_load_store_with_broadcast_kernel, *make_args(x, y), target=target)
+
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @buffer_load_store_with_broadcast_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %0 = tt.make_range {end = 4096 : i32, start = 0 : i32} : tensor<4096xi32, #gluon.auto_encoding>
+    %1 = tt.reshape %0 : tensor<4096xi32, #gluon.auto_encoding> -> tensor<64x64xi32, #gluon.auto_encoding>
+    %2 = ttg.convert_layout %1 : tensor<64x64xi32, #gluon.auto_encoding> -> tensor<64x64xi32, #blocked>
+    %cst = arith.constant 1.000000e+00 : f32
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<64x64xf32, #blocked>
+    %true = arith.constant true
+    %cst_1 = arith.constant dense<true> : tensor<64x1xi1, #blocked>
+    %3 = tt.broadcast %cst_1 : tensor<64x1xi1, #blocked> -> tensor<64x64xi1, #blocked>
+    %4 = amdgpu.buffer_load %arg0[%2], %3, %cst_0 cacheModifier = ca : tensor<64x64xf32, #blocked>
+    %5 = tt.broadcast %cst_1 : tensor<64x1xi1, #blocked> -> tensor<64x64xi1, #blocked>
+    amdgpu.buffer_store %4, %arg1[%2], %5 cacheModifier = ca : tensor<64x64xf32, #blocked>
+    %true_2 = arith.constant true
+    %cst_3 = arith.constant dense<true> : tensor<1x64xi1, #blocked>
+    %6 = tt.broadcast %cst_3 : tensor<1x64xi1, #blocked> -> tensor<64x64xi1, #blocked>
+    %7 = amdgpu.buffer_load %arg0[%2], %6, %cst_0 cacheModifier = ca : tensor<64x64xf32, #blocked>
+    %8 = tt.broadcast %cst_3 : tensor<1x64xi1, #blocked> -> tensor<64x64xi1, #blocked>
+    amdgpu.buffer_store %7, %arg1[%2], %8 cacheModifier = ca : tensor<64x64xf32, #blocked>
+    %cst_4 = arith.constant 1.000000e+00 : f32
+    %9 = tt.broadcast %cst_3 : tensor<1x64xi1, #blocked> -> tensor<64x64xi1, #blocked>
+    %cst_5 = arith.constant dense<1.000000e+00> : tensor<64x64xf32, #blocked>
+    %10 = amdgpu.buffer_load %arg0[%2], %9, %cst_5 cacheModifier = ca : tensor<64x64xf32, #blocked>
+    %11 = tt.broadcast %cst_3 : tensor<1x64xi1, #blocked> -> tensor<64x64xi1, #blocked>
+    amdgpu.buffer_store %10, %arg1[%2], %11 cacheModifier = ca : tensor<64x64xf32, #blocked>
+    tt.return
+  }
+}
+""")
```

python/triton/experimental/gluon/language/amd/cdna3/__init__.py

Lines changed: 81 additions & 3 deletions
```diff
@@ -1,8 +1,29 @@
-from ..._core import builtin, int32, uint32
-from ..._semantic import _check
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+from triton.experimental.gluon.language import _core as ttgl
 from triton._C.libtriton import ir
+from ..._core import builtin, int32, uint32, _unwrap_if_constexpr
+from ..._semantic import _check
+
+if TYPE_CHECKING:
+    from ..._semantic import GluonSemantic
+
+__all__ = ["buffer_load_to_shared", "buffer_load", "buffer_store"]
 
-__all__ = ["buffer_load_to_shared"]
+
+def _verify_buffer_load_store(ptr, offsets, mask, other=None):
+    assert ptr.type.is_ptr(), "ptr must be a scalar pointer type"
+
+    assert isinstance(offsets.type, ttgl.distributed_type), "expected offsets type to be a distributed_type"
+    assert offsets.dtype.is_int32() or offsets.dtype.is_uint32(), "offsets element type must be int32 or uint32"
+
+    element_type = ptr.type.scalar.element_ty
+
+    if other is not None:
+        assert mask is not None, "when other is not None, mask should not be None"
+        assert other.shape == offsets.shape, "other shape must match the offsets shape"
+        assert other.dtype == element_type, "other must have the same data type as ptr scalar type"
 
 
 @builtin
@@ -32,3 +53,60 @@ def buffer_load_to_shared(dest, ptr, offsets, mask=None, other=None, cache_modif
     cache_modifier = _semantic._str_to_load_cache_modifier(cache_modifier)
 
     builder.create_buffer_load_to_local(dest.handle, ptr.handle, offsets.handle, mask, other, stride, cache_modifier)
+
+
+@builtin
+def buffer_load(ptr, offsets, mask=None, other=None, cache=None, _semantic=None):
+    """
+    AMD buffer load from global memory via a scalar base pointer and a tensor of
+    offsets instead of a tensor of pointers. This operation loads data
+    directly into registers.
+
+    Args:
+        ptr (pointer to scalar): Global memory scalar base pointer to load from.
+        offsets (tensor): Offsets tensor for the load operation.
+        mask (tensor, optional): Mask tensor for predicated loads. Defaults to None.
+        other (tensor, optional): Tensor providing default values for masked elements. Defaults to None.
+        cache (str, optional): Cache modifier specifier. Defaults to None.
+    """
+    mask = _unwrap_if_constexpr(mask)
+    if mask is not None:
+        offsets, mask = _semantic.broadcast_impl_value(offsets, mask)
+
+    other = _unwrap_if_constexpr(other)
+    if other is not None:
+        offsets, other = _semantic.broadcast_impl_value(offsets, other)
+
+    _verify_buffer_load_store(ptr, offsets, mask, other)
+
+    other = other.handle if other is not None else ir.value()
+    mask = mask.handle if mask is not None else ir.value()
+    cache_modifier = _semantic._str_to_load_cache_modifier(cache) if cache is not None else ir.CACHE_MODIFIER.NONE
+
+    ret_ty = offsets.type.with_element_ty(ptr.type.scalar.element_ty)
+    builder = _semantic.builder
+    handle = builder.create_buffer_load(ret_ty.to_ir(builder), ptr.handle, offsets.handle, mask, other, cache_modifier)
+    return ttgl.tensor(handle, ret_ty)
+
+
+@builtin
+def buffer_store(stored_value, ptr, offsets, mask, cache=None, _semantic: GluonSemantic = None):
+    """
+    AMD buffer store of a tensor directly to global memory via a scalar base pointer
+    and a tensor of offsets instead of a tensor of pointers.
+    Args:
+        stored_value (tensor): The tensor to be stored to global memory.
+        ptr (pointer to scalar): Global memory scalar base pointer to store to.
+        offsets (tensor): Offsets tensor for the store operation.
+        mask (tensor): Mask tensor for predicated stores.
+        cache (str, optional): Cache modifier specifier. Defaults to None.
+    """
+    if mask is not None:
+        offsets, mask = _semantic.broadcast_impl_value(offsets, mask)
+
+    _verify_buffer_load_store(ptr, offsets, mask)
+
+    mask = mask.handle if mask is not None else ir.value()
+    cache_modifier = _semantic._str_to_load_cache_modifier(cache) if cache is not None else ir.CACHE_MODIFIER.NONE
+
+    _semantic.builder.create_buffer_store(stored_value.handle, ptr.handle, offsets.handle, mask, cache_modifier)
```
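
To make the masking and broadcasting behavior of the new builtins concrete, here is a minimal kernel sketch modeled on the tests above; the kernel name, tile size, (64, 1) mask shape, scalar `other`, and `.ca` cache modifier are illustrative choices, and the imports mirror those used by test_frontend.py.

```
import triton.language as tl
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def masked_copy_kernel(x, y):
    layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[1, 64],
                                                warps_per_cta=[4, 1], order=[1, 0])
    offsets = ttgl.arange(0, 64 * 64).reshape(64, 64)
    offsets = ttgl.convert_layout(offsets, layout=layout)

    # buffer_load/buffer_store broadcast a (64, 1) mask against the (64, 64) offsets,
    # and a scalar `other` is splatted to the offsets shape before verification.
    mask = ttgl.full((64, 1), 1, tl.int1, layout=layout)
    a = ttgl.amd.cdna3.buffer_load(ptr=x, offsets=offsets, mask=mask, other=0.0, cache='.ca')
    ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.ca')
```

As the expected IR in the tests shows, this lowers to `amdgpu.buffer_load` / `amdgpu.buffer_store` ops, with the mask expanded through `tt.broadcast`.
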
python/triton/experimental/gluon/language/amd/cdna4/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,3 +1,3 @@
-from ..cdna3 import buffer_load_to_shared
+from ..cdna3 import buffer_load_to_shared, buffer_load, buffer_store
 
-__all__ = ["buffer_load_to_shared"]
+__all__ = ["buffer_load_to_shared", "buffer_load", "buffer_store"]
```
