
Commit dd58234

[Gluon] Add warpgroup_mma_accumulator object (#7760)

Parent: ba9bd40

4 files changed: +66 -9 lines

python/test/gluon/test_consan.py

Lines changed: 1 addition & 1 deletion

@@ -562,7 +562,7 @@ def multibuffered_loop_wgmma_kernel(input_desc, XBLOCK: ttgl.constexpr, FAILURE:
 
     mma_layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1],
                                                              instr_shape=[16, 32, 16])
-    acc = ttgl.zeros([XBLOCK, XBLOCK], ttgl.float32, mma_layout)
+    acc = hopper.warpgroup_mma_init(ttgl.zeros([XBLOCK, XBLOCK], ttgl.float32, mma_layout))
 
     smemA = ttgl.allocate_shared_memory(ttgl.float16, [num_buffers, XBLOCK, XBLOCK], input_desc.layout)
     smemB = ttgl.allocate_shared_memory(ttgl.float16, [num_buffers, XBLOCK, XBLOCK], input_desc.layout)
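
For orientation, a minimal sketch (not part of the commit) of how the wrapped accumulator flows through an asynchronous WGMMA. It assumes `XBLOCK`, `mma_layout`, and shared-memory tiles `smemA_tile`/`smemB_tile` are already defined inside a `@gluon.jit` kernel, as in the test above:

    # Wrap the initial value so it can be carried as an accumulator token.
    acc = hopper.warpgroup_mma_init(ttgl.zeros([XBLOCK, XBLOCK], ttgl.float32, mma_layout))
    # is_async=True now returns a warpgroup_mma_accumulator rather than a tensor.
    acc = hopper.warpgroup_mma(smemA_tile, smemB_tile, acc, is_async=True)
    # Waiting unwraps the token back into an ordinary ttgl.tensor.
    acc = hopper.warpgroup_mma_wait(num_outstanding=0, deps=[acc])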

python/test/gluon/test_frontend.py

Lines changed: 10 additions & 2 deletions

@@ -584,7 +584,11 @@ def warpgroup_mma_kernel(nvmma_layout: ttgl.constexpr, acc_layout: ttgl.constexp
     a = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout)
     b = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout)
     acc = ttgl.full([128, 128], 0, dtype=ttgl.float16, layout=acc_layout)
-    hopper.warpgroup_mma(a, b, acc)
+    acc = hopper.warpgroup_mma(a, b, acc)
+    ttgl.static_assert(isinstance(acc, ttgl.tensor))
+
+    acc = hopper.warpgroup_mma(a, b, acc, is_async=True)
+    ttgl.static_assert(isinstance(acc, hopper.warpgroup_mma_accumulator))
 
 
 def test_warpgroup_mma():
@@ -608,6 +612,8 @@ def test_warpgroup_mma():
     %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #mma>
     %true = arith.constant true
     %2 = ttng.warp_group_dot %0, %1, %cst_0, %true {inputPrecision = 0 : i32} : !ttg.memdesc<128x128xf16, #shared, #smem, mutable> * !ttg.memdesc<128x128xf16, #shared, #smem, mutable> -> tensor<128x128xf16, #mma>
+    %true_1 = arith.constant true
+    %3 = ttng.warp_group_dot %0, %1, %2, %true_1 {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x128xf16, #shared, #smem, mutable> * !ttg.memdesc<128x128xf16, #shared, #smem, mutable> -> tensor<128x128xf16, #mma>
     tt.return
   }
 }
@@ -617,8 +623,9 @@ def test_warpgroup_mma():
 @gluon.jit
 def warpgroup_mma_wait_kernel():
     layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 32, 16])
-    acc = ttgl.full([128, 128], 0, dtype=ttgl.float16, layout=layout)
+    acc = hopper.warpgroup_mma_init(ttgl.full([128, 128], 0, dtype=ttgl.float16, layout=layout))
     acc = hopper.warpgroup_mma_wait(num_outstanding=1, deps=[acc])
+    _ = acc + acc
 
 
 def test_warpgroup_mma_wait():
@@ -631,6 +638,7 @@ def test_warpgroup_mma_wait():
     %cst = arith.constant 0.000000e+00 : f16
     %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #mma>
     %0 = ttng.warp_group_dot_wait %cst_0 {pendings = 1 : i32} : tensor<128x128xf16, #mma>
+    %1 = arith.addf %0, %0 : tensor<128x128xf16, #mma>
     tt.return
   }
 }
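
As the updated tests show, `warpgroup_mma_wait` returns one value per dependency: a single dependency yields the value directly, several yield a tuple. A hedged sketch of waiting on two in-flight accumulators at once, inside a kernel; the tile and init names (`a0`, `b0`, `a1`, `b1`, `init0`, `init1`) are illustrative, not from the commit:

    acc0 = hopper.warpgroup_mma(a0, b0, hopper.warpgroup_mma_init(init0), is_async=True)
    acc1 = hopper.warpgroup_mma(a1, b1, hopper.warpgroup_mma_init(init1), is_async=True)
    # With more than one dep, the wait yields a tuple of tensors in dep order.
    acc0, acc1 = hopper.warpgroup_mma_wait(num_outstanding=0, deps=[acc0, acc1])
    _ = acc0 + acc1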

python/triton/experimental/gluon/language/nvidia/hopper/__init__.py

Lines changed: 54 additions & 5 deletions

@@ -1,8 +1,13 @@
+from __future__ import annotations
 from triton.compiler.code_generator import unflatten_ir_values
 from ..ampere import async_copy
 from . import mbarrier, tma
 from ... import _core
 
+from typing import List, Tuple, TYPE_CHECKING
+if TYPE_CHECKING:
+    from triton._C.libtriton import ir
+
 __all__ = ["async_copy", "fence_async_shared", "mbarrier", "tma", "warpgroup_mma", "warpgroup_mma_wait"]
 
 
@@ -18,6 +23,43 @@ def fence_async_shared(cluster=False, _semantic=None):
     _semantic.builder.create_fence_async_shared(cluster)
 
 
+class warpgroup_mma_accumulator_type(_core.base_type):
+    tensor_type: _core.dtype
+
+    def __init__(self, tensor_type: _core.dtype):
+        self.tensor_type = tensor_type
+
+    def __str__(self) -> str:
+        return f"warpgroup_mma_accumulator<{self.tensor_type}>"
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[warpgroup_mma_accumulator, int]:
+        return warpgroup_mma_accumulator(handles[cursor], self.tensor_type), cursor + 1
+
+    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+        self.tensor_type._flatten_ir_types(builder, out)
+
+    def mangle(self) -> str:
+        return f"FT{self.tensor_type.mangle()}FT"
+
+
+class warpgroup_mma_accumulator(_core.base_value):
+    handle: ir.value
+    type: warpgroup_mma_accumulator_type
+
+    def __init__(self, handle, tensor_type: _core.dtype):
+        self.handle = handle
+        self.type = warpgroup_mma_accumulator_type(tensor_type)
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        handles.append(self.handle)
+
+
+@_core.builtin
+def warpgroup_mma_init(value, _semantic):
+    assert isinstance(value, _core.tensor)
+    return warpgroup_mma_accumulator(value.handle, value.type)
+
+
 @_core.builtin
 def warpgroup_mma(a, b, acc, *, use_acc=True, precision=None, max_num_imprecise_acc=None, is_async=False,
                   _semantic=None):
@@ -35,7 +77,7 @@ def warpgroup_mma(a, b, acc, *, use_acc=True, precision=None, max_num_imprecise_
         is_async (bool): Whether operation is asynchronous. Defaults to False.
 
     Returns:
-        tensor: Result of warpgroup MMA operation.
+        tensor or warpgroup_mma_accumulator: Returns the result if synchronous, or a token to load the value once computed if asynchronous.
     """
     use_acc = _semantic.to_tensor(use_acc)
 
@@ -59,7 +101,11 @@ def warpgroup_mma(a, b, acc, *, use_acc=True, precision=None, max_num_imprecise_
 
     handle = _semantic.builder.create_warpgroup_mma(a.handle, b.handle, acc.handle, use_acc.handle, precision,
                                                     max_num_imprecise_acc, is_async)
-    return _core.tensor(handle, acc.type)
+    tensor_ty = acc.type.tensor_type if isinstance(acc, warpgroup_mma_accumulator) else acc.type
+    if is_async:
+        return warpgroup_mma_accumulator(handle, tensor_ty)
+    else:
+        return _core.tensor(handle, tensor_ty)
 
 
 @_core.builtin
@@ -71,10 +117,13 @@ def warpgroup_mma_wait(num_outstanding=0, deps=None, _semantic=None):
         num_outstanding (int): Number of outstanding warpgroup MMA operations to wait for. Defaults to 0.
         deps (Sequence[tensor]): List of dependencies that need to be kept alive while the mma is unfinished.
     """
+    if deps is None:
+        raise ValueError("warpgroup_mma_wait deps must be given")
     deps_handles = [x.handle for x in deps] if deps is not None else []
     num_outstanding = _core._unwrap_if_constexpr(num_outstanding)
    results = _semantic.builder.create_warpgroup_mma_wait(deps_handles, num_outstanding)
-    results = tuple(unflatten_ir_values(results, [dep.type for dep in deps]))
-    if len(results) == 1:
-        return results[0]
+    result_types = [dep.type.tensor_type if isinstance(dep, warpgroup_mma_accumulator) else dep.type for dep in deps]
+    results = unflatten_ir_values(results, result_types)
+    if len(deps) == 1:
+        return next(results)
     return tuple(results)
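
The new classes mainly implement Gluon's `base_type`/`base_value` flattening contract, which is what lets the accumulator be carried across statements and control flow. Purely as an illustration (not part of the commit), the round trip can be exercised outside a JIT context with a plain string standing in for the `ir.value` handle; the import paths are assumed to mirror the file layout:

    from triton.experimental.gluon import language as ttgl
    from triton.experimental.gluon.language.nvidia.hopper import warpgroup_mma_accumulator

    acc = warpgroup_mma_accumulator("fake_handle", ttgl.float32)
    handles = []
    acc._flatten_ir(handles)                          # handles == ["fake_handle"]
    rebuilt, cursor = acc.type._unflatten_ir(handles, 0)
    assert isinstance(rebuilt, warpgroup_mma_accumulator) and cursor == 1
    print(acc.type)                                   # e.g. warpgroup_mma_accumulator<fp32>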

python/triton/experimental/gluon/language/nvidia/hopper/tma.py

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
         self.strides_type._flatten_ir_types(builder, out)
 
     def mangle(self) -> str:
-        return f"TD{self.block_type.mangle}_{self.layout.mangle()}TD"
+        return f"TD{self.block_type.mangle()}_{self.layout.mangle()}TD"
 
 
 class tensor_descriptor(base_value):
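
The one-character fix above addresses a common Python pitfall: interpolating a bound method instead of calling it embeds the method's repr in the mangled name. A standalone illustration with a hypothetical class (not from the Triton code base):

    class Layout:
        def mangle(self):
            return "L1"

    layout = Layout()
    print(f"TD{layout.mangle}TD")    # TD<bound method Layout.mangle of ...>TD  (old behaviour)
    print(f"TD{layout.mangle()}TD")  # TDL1TD                                   (fixed)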
