Skip to content

Commit 7d92894

Browse files
authored
[Bench][AMD] Update Parameters for Bf16 x Mxfp4 MoE Kernel (#8176)
1 parent 70e69cb commit 7d92894

File tree

1 file changed

+28
-9
lines changed
  • python/triton_kernels/triton_kernels/matmul_ogs_details

1 file changed

+28
-9
lines changed

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from triton_kernels.target_info import get_cdna_version
66
import torch
77
from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
8+
from triton_kernels.tensor import bitwidth
89

910

1011
@dataclass
@@ -80,15 +81,10 @@ def make_default_opt_flags_amd(
8081
num_xcds = 8
8182
xcd_swizzle = num_xcds
8283
# block_nk:
84+
# TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
8385
block_n, block_k = opt_flags_amd.compute_block_nk(
8486
n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config
8587
)
86-
# Replace block_k if provided in constraints.
87-
# TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
88-
if constraints.get("block_k", None) is not None:
89-
block_k = constraints["block_k"]
90-
if constraints.get("block_n", None) is not None:
91-
block_n = constraints["block_n"]
9288
is_persistent = constraints.get("is_persistent", False)
9389
# split_k:
9490
if constraints.get("split_k", None) is not None:
@@ -109,10 +105,33 @@ def make_default_opt_flags_amd(
109105
epilogue_subtile = constraints.get('epilogue_subtile', None)
110106
if epilogue_subtile is None:
111107
epilogue_subtile = 1
108+
109+
# Specific configs for 16-bit (FP16/BF16) x MXFP4 on CDNA4
110+
# Note that these configs will exceed the available LDS when async copy is enabled
111+
if is_cdna4 and bitwidth(lhs_dtype) == 16 and bitwidth(rhs_dtype) == 4 and precision_config.weight_scale is not None:
112+
split_k = 1
113+
if m <= 1024:
114+
target_kernel_kwargs["waves_per_eu"] = 3
115+
block_n = 128
116+
block_k = 256
117+
num_warps = 4
118+
else:
119+
target_kernel_kwargs["waves_per_eu"] = 0
120+
block_m = 64
121+
block_n = 512
122+
block_k = 256
123+
num_warps = 8
124+
125+
def replace_with_valid_constraint(k: str, v):
    """Return the caller-supplied constraint for ``k`` if one is set, else ``v``.

    Args:
        k: Key to look up in the ``constraints`` mapping from the
            surrounding scope.
        v: Fallback value used when ``constraints`` has no entry for ``k``
            or the entry is ``None``.

    Returns:
        ``constraints[k]`` when it is present and not ``None`` (falsy values
        such as ``0`` are still honoured); otherwise ``v``.
    """
    # Single lookup: a missing key and an explicit None both fall back to v.
    override = constraints.get(k)
    return v if override is None else override
130+
112131
ret = OptFlags(
113-
block_m=block_m,
114-
block_n=block_n,
115-
block_k=block_k,
132+
block_m=replace_with_valid_constraint('block_m', block_m),
133+
block_n=replace_with_valid_constraint('block_n', block_n),
134+
block_k=replace_with_valid_constraint('block_k', block_k),
116135
num_warps=num_warps,
117136
num_stages=num_stages,
118137
group_m=group_m,

0 commit comments

Comments
 (0)