
Commit bc94c4c

Refactor test_mm_fp4.py
1 parent 14cf404 commit bc94c4c

1 file changed: +25 / -16 lines

tests/gemm/test_mm_fp4.py

Lines changed: 25 additions & 16 deletions
@@ -12,16 +12,7 @@
 from flashinfer.gemm import CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR
 
 
-# TODO: Consdier splitting this function up for the various backends
-@pytest.mark.parametrize("m", [1, 48, 128, 256, 512])
-@pytest.mark.parametrize("n", [128, 256, 512])
-@pytest.mark.parametrize("k", [128, 256, 512])
-@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("backend", ["trtllm", "cudnn", "cutlass"])
-@pytest.mark.parametrize("use_128x4_sf_layout", [False, True])
-@pytest.mark.parametrize("auto_tuning", [False, True])
-@pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"])
-def test_mm_fp4(
+def _test_mm_fp4(
     m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type
 ):
     use_nvfp4 = fp4_type == "nvfp4"
@@ -40,8 +31,8 @@ def test_mm_fp4(
         pytest.skip("trtllm gemm does not support SM110/SM120/SM121 GPUs.")
     if not use_128x4_sf_layout and backend != "trtllm":
         pytest.skip("Skipping test for non-trtllm fp4 with use_128x4_sf_layout=False")
-    if not use_nvfp4 and backend != "cudnn":
-        pytest.skip("mx_fp4 is only supported for cudnn backend")
+    if not use_nvfp4 and backend not in ["cudnn", "auto"]:
+        pytest.skip("mx_fp4 is only supported for cudnn and auto backends")
 
     input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
     mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
@@ -103,19 +94,37 @@ def test_mm_fp4(
         pytest.fail(str(e))
 
 
+# TODO: Consdier splitting this function up for the various backends
+@pytest.mark.parametrize("m", [1, 48, 128, 256, 512])
+@pytest.mark.parametrize("n", [128, 256, 512])
+@pytest.mark.parametrize("k", [128, 256, 512])
+@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("backend", ["trtllm", "cudnn", "cutlass"])
+@pytest.mark.parametrize("use_128x4_sf_layout", [False, True])
+@pytest.mark.parametrize("auto_tuning", [False, True])
+@pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"])
+def test_mm_fp4(
+    m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type
+):
+    # Non-auto backends
+    _test_mm_fp4(
+        m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type
+    )
+
+
 # Split tests for checking auto functionality
 @pytest.mark.parametrize("m", [1, 48, 256, 512])
 @pytest.mark.parametrize("n", [256, 512])
 @pytest.mark.parametrize("k", [256, 512])
 @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("backend", ["auto"])
-@pytest.mark.parametrize("use_128x4_sf_layout", [False, True])
+@pytest.mark.parametrize("use_128x4_sf_layout", [True])
 @pytest.mark.parametrize("auto_tuning", [False, True])
 @pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"])
 def test_mm_fp4_backend_auto(
-    m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type
+    m, n, k, res_dtype, use_128x4_sf_layout, auto_tuning, fp4_type
 ):
-    test_mm_fp4(m, n, k, res_dtype, "auto", use_128x4_sf_layout, auto_tuning, fp4_type)
+    # Some test cases for auto backend.
+    _test_mm_fp4(m, n, k, res_dtype, "auto", use_128x4_sf_layout, auto_tuning, fp4_type)
 
 
 if __name__ == "__main__":
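
The refactor follows a common pytest pattern: the shared test body moves into an underscore-prefixed helper that pytest does not collect on its own, and thin parametrized wrappers (one sweeping the explicitly named backends, one fixing the backend to "auto" over a smaller grid) forward their arguments to it. A minimal sketch of the pattern, using hypothetical names (_run_case, test_explicit_backends, test_auto_backend) rather than the ones in this file:

import pytest


def _run_case(m, n, k, backend):
    # The leading underscore keeps pytest from collecting this as a test;
    # it only runs when one of the wrappers below calls it.
    if backend == "cutlass" and m == 1:
        pytest.skip("placeholder for a backend-specific skip inside the helper")
    assert m * n * k > 0


@pytest.mark.parametrize("m", [1, 48, 128])
@pytest.mark.parametrize("n", [128, 256])
@pytest.mark.parametrize("k", [128, 256])
@pytest.mark.parametrize("backend", ["trtllm", "cudnn", "cutlass"])
def test_explicit_backends(m, n, k, backend):
    # Full grid over the explicitly named backends.
    _run_case(m, n, k, backend)


@pytest.mark.parametrize("m", [1, 48])
@pytest.mark.parametrize("n", [256])
@pytest.mark.parametrize("k", [256])
def test_auto_backend(m, n, k):
    # Smaller grid; the backend is hard-coded to "auto" rather than parametrized.
    _run_case(m, n, k, "auto")

With this split, only the two wrapper tests appear in collection, and the auto-backend subset in the actual file can still be selected on its own, e.g. pytest tests/gemm/test_mm_fp4.py -k backend_auto.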
