Commit ba9bd40

[FRONTEND] Support for ragged TMAs (#7783)
This allows us to use higher-dimensional TMA descriptors to emulate ragged-batching support with automatic bounds checking.
1 parent 376b9b9 commit ba9bd40

File tree

2 files changed: +126 -0 lines changed
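To make the intent concrete, here is a minimal sketch (not part of this commit) of how the new helpers might be used: one buffer concatenates variable-length segments along dim 0, and each program instance copies one BLOCK-row tile of one segment, with the hardware masking any rows outside [offset, offset + size). Descriptors are created host-side and passed to the kernel, as in the test below; the driver function copy_ragged, its argument names, and the [BLOCK, 128] box shape are illustrative assumptions.

import torch
import triton
import triton.language as tl
from triton.tools.ragged_tma import create_ragged_descriptor, load_ragged, store_ragged


@triton.jit
def copy_segment_kernel(X, Y, off, size, BLOCK: tl.constexpr):
    # each program copies one BLOCK-row tile of a single segment; rows
    # beyond `size` read back as zeros and their stores are masked
    r = tl.program_id(0) * BLOCK
    data = load_ragged(X, off, size, [r, 0])
    store_ragged(Y, off, size, [r, 0], data)


def copy_ragged(src, dst, offsets, sizes, BLOCK=32):
    # hypothetical host-side driver: (offsets[i], sizes[i]) delimit segment i
    X = create_ragged_descriptor(src, [BLOCK, 128])
    Y = create_ragged_descriptor(dst, [BLOCK, 128])
    for off, size in zip(offsets, sizes):
        copy_segment_kernel[(triton.cdiv(size, BLOCK), )](X, Y, off, size, BLOCK=BLOCK)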


python/test/unit/cuda/test_tma_descriptor.py

Lines changed: 45 additions & 0 deletions
@@ -1,6 +1,8 @@
from contextlib import nullcontext
import pytest
import torch
import triton
from triton.tools.ragged_tma import create_ragged_descriptor, load_ragged, store_ragged
from triton.tools.tensor_descriptor import TensorDescriptor


@@ -44,3 +46,46 @@ def test_2d_tma_descriptor_exception(M, N, BLOCK_M, BLOCK_N, expect_error_n, exp
    ctx = pytest.raises(exc_type, match=match) if expect_error else nullcontext()
    with ctx:
        _ = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_N])


@triton.jit
def example_load_store_kernel(X, Y, x_off, y_off, x_size, y_size):
    data = load_ragged(X, x_off, x_size, [0, 0])
    store_ragged(Y, y_off, y_size, [0, 0], data)


@pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
def test_ragged_tma(dtype):
    if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 9:
        pytest.skip("Test requires Hopper or Blackwell target.")
        return

    dtype = getattr(torch, dtype)

    src = torch.randn((1024, 80), dtype=torch.float32, device="cuda").to(dtype)
    ref = torch.randn((1024, 80), dtype=torch.float32, device="cuda").to(dtype)
    dst = 1.0 * ref

    X = create_ragged_descriptor(src, [32, 128])
    Y = create_ragged_descriptor(dst, [32, 128])

    x_off = 42
    y_off = 51
    x_size = 17
    y_size = 24

    example_load_store_kernel[(1, )](X, Y, x_off, y_off, x_size, y_size)

    # the initial and final segments are unchanged:
    res0 = torch.equal(dst[:y_off], ref[:y_off])
    res1 = torch.equal(dst[y_off + y_size:], ref[y_off + y_size:])

    # this segment will be copied verbatim from src:
    res2 = torch.equal(dst[y_off:y_off + x_size], src[x_off:x_off + x_size])

    # this segment will have read OOB zeroes and written them here:
    res3 = torch.all(dst[y_off + x_size:y_off + y_size] == 0.0).item()

    assert [res0, res1, res2, res3] == [True, True, True, True]
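A worked check of what the ragged semantics imply for the constants above (an illustration, not part of the diff): the kernel loads one 32-row box starting at row 0 of src's segment [42, 42 + 17) and stores it at row 0 of dst's segment [51, 51 + 24). Box rows 0..16 carry real data, so dst[51:68] == src[42:59]; box rows 17..23 fall outside the source segment, so the load returns zeros and dst[68:75] == 0; box rows 24..31 fall outside the destination segment, so those stores are masked and dst[75:] keeps ref's values. These are exactly the four assertions res0..res3.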

python/triton/tools/ragged_tma.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor

# fmt: off


def create_ragged_descriptor(T, block_shape):
    """
    Given a 2- or 3-dimensional tensor T, this creates a 'ragged descriptor'
    which behaves like a concatenation (along the first axis) of subarrays
    of potentially unequal size.

    The load_ragged and store_ragged device functions can be used to read
    and write from subarrays T[batch_offset : batch_offset + batch_size]
    with hardware bounds-checking preventing any sort of leakage outside
    the subarray.
    """

    block_shape = list(block_shape)
    tensor_shape = list(T.shape)

    assert 2 <= len(tensor_shape) <= 3, "ragged tensors must have dimension 2 or 3"
    assert len(tensor_shape) == len(block_shape), "block shape must match tensor shape"

    max_int = 0x7fff0000
    billion = 0x40000000  # == 2**30

    assert tensor_shape[0] <= billion, "number of rows may not exceed 2**30"

    # we prepend an extra two dimensions and rely on the fact that pointers
    # have 64-bit wraparound semantics:
    tma_stride = [2**34 - T.stride(0), T.stride(0)] + [T.stride(i) for i in range(len(tensor_shape))]
    tma_shape = [max_int, max_int, billion] + tensor_shape[1:]
    box_shape = [1, 1] + block_shape

    return TensorDescriptor(T, tma_shape, tma_stride, box_shape)


@triton.jit
def to_ragged_indices(batch_offset, batch_size, row):
    """
    Helper function for load_ragged and store_ragged.
    """

    billion = 0x40000000  # == 2**30
    x = billion - batch_size + row
    y = batch_offset + batch_size

    return billion, y, x


@triton.jit
def load_ragged(TMA, batch_offset, batch_size, coords):
    """
    Read from a subarray T[batch_offset : batch_offset + batch_size] with
    hardware bounds-checking, where reading outside the subarray gives zeros.

    Coords should be an appropriately-sized list of integers, just like in
    TMA.load().
    """

    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[0])
    data = TMA.load([c0, c1, c2] + coords[1:])
    data = tl.reshape(data, data.shape[2:])
    return data


@triton.jit
def store_ragged(TMA, batch_offset, batch_size, coords, data):
    """
    Write to a subarray T[batch_offset : batch_offset + batch_size] with
    hardware bounds-checking, where writes outside the subarray are masked
    correctly.

    Coords should be an appropriately-sized list of integers, just like in
    TMA.store().
    """

    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[0])
    data = tl.reshape(data, [1, 1] + data.shape)
    TMA.store([c0, c1, c2] + coords[1:], data)
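The 64-bit wraparound trick can be sanity-checked with plain Python integers. Here is a sketch (not part of the commit) of the effective element offset produced by the indices from to_ragged_indices, for a 2D tensor with dim-0 stride s0; the helper name effective_offset is illustrative.

def effective_offset(batch_offset, batch_size, row, s0):
    # indices chosen by to_ragged_indices ...
    billion = 2**30
    c0 = billion
    c1 = batch_offset + batch_size
    c2 = billion - batch_size + row
    # ... dotted with the strides chosen by create_ragged_descriptor;
    # c0 * (2**34 - s0) contributes 2**30 * 2**34 == 2**64, which wraps
    # to 0, minus 2**30 * s0, which cancels against the 2**30 * s0
    # hidden inside c2 * s0
    strides = [2**34 - s0, s0, s0]
    return (c0 * strides[0] + c1 * strides[1] + c2 * strides[2]) % 2**64


# the descriptor therefore addresses row batch_offset + row of the tensor:
assert effective_offset(42, 17, 5, s0=80) == (42 + 5) * 80

Meanwhile the hardware bounds check on the third descriptor dimension (extent 2**30) masks exactly the rows with c2 >= 2**30, i.e. row >= batch_size, which is where the zero-fill on loads and the write-masking on stores come from.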
