5 changes: 5 additions & 0 deletions python/triton_kernels/tests/test_matmul.py
@@ -173,6 +173,11 @@ def _build_test_op_cases():
        Case(*even_shape, "ragged", "bfloat16", "bfloat16", epilogue_subtile=val, swiglu_opts=(1.1, 1.4))
        for val in (1, 2, 4)
    ])
    # swiglu together with mxfp8 downcast epilogue
    test_cases.extend([
        Case(*shape, mode, "mxfloat8_e4m3fn", "mxfloat4_e2m1", hbm_swizzling=True, split_k=split_k, swiglu_opts=(1.1, 7))
        for shape in [odd_shape2, even_shape] for mode in ["ragged", "batched"] for split_k in [1, 5]
    ])

    return test_cases
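For reference, one element produced by the new comprehension, written out by hand (shape=even_shape, mode="ragged", split_k=1); this is only an illustration, the authoritative parameters are the ones in the comprehension itself:

    Case(*even_shape, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1",
         hbm_swizzling=True, split_k=1, swiglu_opts=(1.1, 7))

Each such case exercises the swiglu activation fused into the same epilogue that performs the mxfp8 downcast of the output, which is the combination the opt_flags change below guards against.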

@@ -210,6 +210,9 @@ def make_default_opt_flags_nvidia(
            block_m = max(16, min(triton.next_power_of_2(8 * slice_size), 128))
        else:
            block_m = max(16, min(triton.next_power_of_2(2 * slice_size), 64))
            if block_m == 64 and precision_config.c_mx_scale is not None and rhs_dtype == FP4 and torch.cuda.get_device_capability()[0] >= 10:
                # when both a fused activation and an mxfp8 downcast run in the epilogue, block_m=64 causes a shared memory overflow
                block_m = 128
    else:
        block_m = max(16, min(triton.next_power_of_2(slice_size), 128))
    # block n
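To make the new special case easier to reason about in isolation, here is a minimal standalone sketch of the block_m choice on this path; the helper name and the FP4 placeholder are assumptions based only on the hunk above, not the full opt_flags source.

    import torch
    import triton

    FP4 = "float4_e2m1"  # stand-in for the FP4 dtype constant used in opt_flags.py

    def pick_block_m_small_slice(slice_size, c_mx_scale, rhs_dtype):
        # reproduces only the branch shown in the hunk above
        block_m = max(16, min(triton.next_power_of_2(2 * slice_size), 64))
        if (block_m == 64 and c_mx_scale is not None and rhs_dtype == FP4
                and torch.cuda.get_device_capability()[0] >= 10):
            # a fused activation plus an mxfp8 downcast in the epilogue would overflow
            # shared memory at block_m=64 on compute capability 10+, so widen the tile
            block_m = 128
        return block_m

On GPUs with compute capability below 10, or when no mxfp8 output scale is configured, the original cap of 64 is kept unchanged.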