@@ -5,6 +5,7 @@
 from triton_kernels.target_info import get_cdna_version
 import torch
 from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
+from triton_kernels.tensor import bitwidth


 @dataclass
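The `bitwidth` import feeds the new CDNA4 fast-path gate in the last hunk. A minimal sketch of that gate, assuming `bitwidth` returns the element width in bits for both torch dtypes and the library's wrapped FP4 types (the `is_f16_mxfp4` name is hypothetical):

```python
from triton_kernels.tensor import bitwidth

# Hypothetical helper mirroring the condition added below; assumes
# precision_config.weight_scale is non-None exactly when the weights
# use a microscaled (MX) format such as MXFP4.
def is_f16_mxfp4(lhs_dtype, rhs_dtype, precision_config):
    return (
        bitwidth(lhs_dtype) == 16                      # fp16/bf16 activations
        and bitwidth(rhs_dtype) == 4                   # 4-bit packed weights
        and precision_config.weight_scale is not None  # MX scale tensor present
    )
```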
@@ -80,15 +81,10 @@ def make_default_opt_flags_amd(
     num_xcds = 8
     xcd_swizzle = num_xcds
     # block_nk:
+    # TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
     block_n, block_k = opt_flags_amd.compute_block_nk(
         n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config
     )
-    # Replace block_k if provided in constraints.
-    # TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
-    if constraints.get("block_k", None) is not None:
-        block_k = constraints["block_k"]
-    if constraints.get("block_n", None) is not None:
-        block_n = constraints["block_n"]
     is_persistent = constraints.get("is_persistent", False)
     # split_k:
     if constraints.get("split_k", None) is not None:
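Moving the `block_n`/`block_k` constraint overrides out of this spot and into the `OptFlags` construction below keeps user-pinned values from being clobbered by the new CDNA4 heuristics: constraints are now applied last. A small sketch of the resulting precedence, assuming `constraints` is a plain dict of optional user overrides (values here are illustrative):

```python
constraints = {"block_k": 128}   # user pins block_k only
block_n, block_k = 256, 64       # heuristic output

def replace_with_valid_constraint(k: str, v):
    # Same helper as in the last hunk: a constraint, if present, wins.
    return constraints[k] if constraints.get(k, None) is not None else v

assert replace_with_valid_constraint("block_k", block_k) == 128  # pinned value wins
assert replace_with_valid_constraint("block_n", block_n) == 256  # heuristic kept
```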
@@ -109,10 +105,33 @@ def make_default_opt_flags_amd(
     epilogue_subtile = constraints.get('epilogue_subtile', None)
     if epilogue_subtile is None:
         epilogue_subtile = 1
+
+    # specific configs for F16 x MXFP4 on CDNA4
+    # Note that these configs will exceed LDS usage with async copy enabled
+    if is_cdna4 and bitwidth(lhs_dtype) == 16 and bitwidth(rhs_dtype) == 4 and precision_config.weight_scale is not None:
+        split_k = 1
+        if m <= 1024:
+            target_kernel_kwargs["waves_per_eu"] = 3
+            block_n = 128
+            block_k = 256
+            num_warps = 4
+        else:
+            target_kernel_kwargs["waves_per_eu"] = 0
+            block_m = 64
+            block_n = 512
+            block_k = 256
+            num_warps = 8
+
+    def replace_with_valid_constraint(k: str, v):
+        if constraints.get(k, None) is not None:
+            return constraints[k]
+        else:
+            return v
+
     ret = OptFlags(
-        block_m=block_m,
-        block_n=block_n,
-        block_k=block_k,
+        block_m=replace_with_valid_constraint('block_m', block_m),
+        block_n=replace_with_valid_constraint('block_n', block_n),
+        block_k=replace_with_valid_constraint('block_k', block_k),
         num_warps=num_warps,
         num_stages=num_stages,
         group_m=group_m,
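On the LDS caveat in the comment above: a rough back-of-the-envelope for the `m > 1024` config, assuming both operand tiles are double-buffered in LDS when async copy is enabled and 160 KiB of LDS per workgroup on CDNA4 (both are assumptions, not taken from this diff):

```python
block_m, block_n, block_k = 64, 512, 256
lhs_tile = block_m * block_k * 2    # fp16 LHS tile: 2 bytes/elem = 32 KiB
rhs_tile = block_n * block_k // 2   # mxfp4 RHS tile: 4 bits/elem = 64 KiB
per_stage = lhs_tile + rhs_tile     # ~96 KiB per pipeline stage
print(2 * per_stage > 160 * 1024)   # True: two stages overflow 160 KiB of LDS
```

A single-buffered pipeline (~96 KiB) still fits, which is presumably why these configs remain usable with async copy disabled.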