1+ from __future__ import annotations
12from triton .compiler .code_generator import unflatten_ir_values
23from ..ampere import async_copy
34from . import mbarrier , tma
45from ... import _core
56
7+ from typing import List , Tuple , TYPE_CHECKING
8+ if TYPE_CHECKING :
9+ from triton ._C .libtriton import ir
10+
611__all__ = ["async_copy" , "fence_async_shared" , "mbarrier" , "tma" , "warpgroup_mma" , "warpgroup_mma_wait" ]
712
813
@@ -18,6 +23,43 @@ def fence_async_shared(cluster=False, _semantic=None):
1823 _semantic .builder .create_fence_async_shared (cluster )
1924
2025
26+ class warpgroup_mma_accumulator_type (_core .base_type ):
27+ tensor_type : _core .dtype
28+
29+ def __init__ (self , tensor_type : _core .dtype ):
30+ self .tensor_type = tensor_type
31+
32+ def __str__ (self ) -> str :
33+ return f"warpgroup_mma_accumulator<{ self .tensor_type } >"
34+
35+ def _unflatten_ir (self , handles : List [ir .value ], cursor : int ) -> Tuple [warpgroup_mma_accumulator , int ]:
36+ return warpgroup_mma_accumulator (handles [cursor ], self .tensor_type ), cursor + 1
37+
38+ def _flatten_ir_types (self , builder : ir .builder , out : List [ir .type ]) -> None :
39+ self .tensor_type ._flatten_ir_types (builder , out )
40+
41+ def mangle (self ) -> str :
42+ return f"FT{ self .tensor_type .mangle ()} FT"
43+
44+
45+ class warpgroup_mma_accumulator (_core .base_value ):
46+ handle : ir .value
47+ type : warpgroup_mma_accumulator_type
48+
49+ def __init__ (self , handle , tensor_type : _core .dtype ):
50+ self .handle = handle
51+ self .type = warpgroup_mma_accumulator_type (tensor_type )
52+
53+ def _flatten_ir (self , handles : List [ir .value ]) -> None :
54+ handles .append (self .handle )
55+
56+
57+ @_core .builtin
58+ def warpgroup_mma_init (value , _semantic ):
59+ assert isinstance (value , _core .tensor )
60+ return warpgroup_mma_accumulator (value .handle , value .type )
61+
62+
2163@_core .builtin
2264def warpgroup_mma (a , b , acc , * , use_acc = True , precision = None , max_num_imprecise_acc = None , is_async = False ,
2365 _semantic = None ):
@@ -35,7 +77,7 @@ def warpgroup_mma(a, b, acc, *, use_acc=True, precision=None, max_num_imprecise_
3577 is_async (bool): Whether operation is asynchronous. Defaults to False.
3678
3779 Returns:
38- tensor: Result of warpgroup MMA operation .
80+ tensor or warpgroup_mma_accumulator: Returns the result if synchronous, or a token to load the value once computed if asynchronous .
3981 """
4082 use_acc = _semantic .to_tensor (use_acc )
4183
@@ -59,7 +101,11 @@ def warpgroup_mma(a, b, acc, *, use_acc=True, precision=None, max_num_imprecise_
59101
60102 handle = _semantic .builder .create_warpgroup_mma (a .handle , b .handle , acc .handle , use_acc .handle , precision ,
61103 max_num_imprecise_acc , is_async )
62- return _core .tensor (handle , acc .type )
104+ tensor_ty = acc .type .tensor_type if isinstance (acc , warpgroup_mma_accumulator ) else acc .type
105+ if is_async :
106+ return warpgroup_mma_accumulator (handle , tensor_ty )
107+ else :
108+ return _core .tensor (handle , tensor_ty )
63109
64110
65111@_core .builtin
@@ -71,10 +117,13 @@ def warpgroup_mma_wait(num_outstanding=0, deps=None, _semantic=None):
71117 num_outstanding (int): Number of outstanding warpgroup MMA operations to wait for. Defaults to 0.
72118 deps (Sequence[tensor]): List of dependencies that need to be kept alive while the mma is unfinished.
73119 """
120+ if deps is None :
121+ raise ValueError ("warpgroup_mma_wait deps must be given" )
74122 deps_handles = [x .handle for x in deps ] if deps is not None else []
75123 num_outstanding = _core ._unwrap_if_constexpr (num_outstanding )
76124 results = _semantic .builder .create_warpgroup_mma_wait (deps_handles , num_outstanding )
77- results = tuple (unflatten_ir_values (results , [dep .type for dep in deps ]))
78- if len (results ) == 1 :
79- return results [0 ]
125+ result_types = [dep .type .tensor_type if isinstance (dep , warpgroup_mma_accumulator ) else dep .type for dep in deps ]
126+ results = unflatten_ir_values (results , result_types )
127+ if len (deps ) == 1 :
128+ return next (results )
80129 return tuple (results )
0 commit comments