import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor

# fmt: off

def create_ragged_descriptor(T, block_shape):
    """
    Given a 2- or 3-dimensional tensor T, this creates a 'ragged descriptor'
    which behaves like a concatenation (along the first axis) of subarrays
    of potentially unequal size.

    The load_ragged and store_ragged device functions can be used to read
    and write from subarrays T[batch_offset : batch_offset + batch_size]
    with hardware bounds-checking preventing any sort of leakage outside
    the subarray.
    """
| 18 | + |
| 19 | + block_shape = list(block_shape) |
| 20 | + tensor_shape = list(T.shape) |
| 21 | + |
| 22 | + assert 2 <= len(tensor_shape) <= 3, "ragged tensors must have dimension 2 or 3" |
| 23 | + assert len(tensor_shape) == len(block_shape), "block shape must match tensor shape" |
| 24 | + |
| 25 | + max_int = 0x7fff0000 |
| 26 | + billion = 0x40000000 # == 2**30 |
| 27 | + |
| 28 | + assert tensor_shape[0] <= billion, "number of rows may not exceed 2**30" |
| 29 | + |

    # We prepend two extra dimensions and rely on the fact that pointers
    # have 64-bit wraparound semantics: the first index is always 2**30,
    # and 2**30 * (2**34 - stride(0)) wraps to -(2**30) * stride(0) mod
    # 2**64, which the second and third indices (see to_ragged_indices)
    # cancel out, leaving exactly (batch_offset + row) * stride(0).
    tma_stride = [2**34 - T.stride(0), T.stride(0)] + [T.stride(i) for i in range(len(tensor_shape))]
    tma_shape = [max_int, max_int, billion] + tensor_shape[1:]
    box_shape = [1, 1] + block_shape

    return TensorDescriptor(T, tma_shape, tma_stride, box_shape)


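# Example (illustrative, not part of this module): one descriptor covers a
# whole dense tensor that stores several variable-length batches stacked
# along axis 0; the tensor name, sizes and dtype below are hypothetical.
#
#     T = torch.empty(4096, 256, device="cuda", dtype=torch.float16)
#     desc = create_ragged_descriptor(T, [64, 64])
#
# The same `desc` can then serve every (batch_offset, batch_size) pair.

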
@triton.jit
def to_ragged_indices(batch_offset, batch_size, row):
    """
    Helper function for load_ragged and store_ragged.
    """

    billion = 0x40000000  # == 2**30

    # These indices are chosen so that, against the strides built by
    # create_ragged_descriptor, the pointer offsets cancel modulo 2**64
    # to give (batch_offset + row) * stride(0), while the hardware bounds
    # check on the third axis masks rows at or beyond batch_size.
    x = billion - batch_size + row
    y = batch_offset + batch_size

    return billion, y, x


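# Worked example (illustrative): with batch_offset = 100, batch_size = 50 and
# row = 0, to_ragged_indices returns (2**30, 150, 2**30 - 50). Against the
# strides [2**34 - s0, s0, s0, ...] the resulting pointer offset is
#
#     2**30 * (2**34 - s0) + 150 * s0 + (2**30 - 50) * s0
#   = 2**64 + (150 - 50) * s0  ==  100 * s0   (mod 2**64),
#
# i.e. the first row of the subarray, while the bounds check on the third
# axis (extent 2**30) masks any row index >= batch_size.

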
@triton.jit
def load_ragged(TMA, batch_offset, batch_size, coords):
    """
    Read from a subarray T[batch_offset : batch_offset + batch_size] with
    hardware bounds-checking, where reading outside the subarray gives zeros.

    Coords should be an appropriately-sized list of integers, just like in
    TMA.load().
    """

    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[0])
    data = TMA.load([c0, c1, c2] + coords[1:])
    # Drop the two singleton dimensions introduced by the ragged descriptor.
    data = tl.reshape(data, data.shape[2:])
    return data


@triton.jit
def store_ragged(TMA, batch_offset, batch_size, coords, data):
    """
    Write to a subarray T[batch_offset : batch_offset + batch_size] with
    hardware bounds-checking, where writes outside the subarray are masked
    correctly.

    Coords should be an appropriately-sized list of integers, just like in
    TMA.store().
    """

    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[0])
    # Reinstate the two singleton dimensions expected by the ragged descriptor.
    data = tl.reshape(data, [1, 1] + data.shape)
    TMA.store([c0, c1, c2] + coords[1:], data)
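

# Usage sketch (illustrative, not part of this module): a tile-copy kernel
# built on the helpers above. The kernel name and launch geometry are
# hypothetical; it assumes 2D descriptors created with block shape
# [BLOCK_M, BLOCK_N] via create_ragged_descriptor.

@triton.jit
def _ragged_copy_kernel(in_desc, out_desc, batch_offset, batch_size,
                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Tiles that stick out past batch_size load zeros, and their
    # out-of-bounds rows are masked on store, so no manual edge
    # handling is needed.
    tile = load_ragged(in_desc, batch_offset, batch_size,
                       [pid_m * BLOCK_M, pid_n * BLOCK_N])
    store_ragged(out_desc, batch_offset, batch_size,
                 [pid_m * BLOCK_M, pid_n * BLOCK_N], tile)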