From 1e5b33017627562d62a9f725337b15e34bd83a74 Mon Sep 17 00:00:00 2001 From: jongsoo-openai Date: Sat, 20 Sep 2025 22:31:26 -0700 Subject: [PATCH 1/3] [mxfp] remove col-major check for mx weight --- python/triton_kernels/triton_kernels/matmul_ogs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs.py index cc374c76d94f..852d2c3ca64e 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs.py @@ -334,7 +334,6 @@ def matmul_ogs(x, w, bias, w_scale = precision_config.weight_scale w_has_mx = w_scale is not None is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8 - if w_has_mx: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp" if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10" if not isinstance(w, Tensor): # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real From b1dc287b676172425082f08cdaec5d0bfca11dd1 Mon Sep 17 00:00:00 2001 From: jongsoo-openai Date: Sun, 21 Sep 2025 10:55:45 -0700 Subject: [PATCH 2/3] add test --- python/triton_kernels/tests/test_matmul.py | 87 ++++++++++++++++--- .../triton_kernels/matmul_ogs.py | 3 + 2 files changed, 80 insertions(+), 10 deletions(-) diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index afb8befd8f59..17315a387a22 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -34,6 +34,12 @@ def alloc_rand(shape, device, dtype, requires_grad=True): return tmp.to(dtype).requires_grad_(requires_grad) return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad) +# def alloc_ones(shape, device, dtype, requires_grad=True): +# return torch.ones(shape, device=device, dtype=dtype, requires_grad=requires_grad) + +# def alloc_zeros(shape, device, dtype, requires_grad=True): +# return torch.zeros(shape, device=device, dtype=dtype, requires_grad=requires_grad) + def alloc_rand_like(x): return alloc_rand(x.shape, x.device, x.dtype, x.requires_grad) @@ -162,6 +168,7 @@ class Case: x_transpose: bool = False w_transpose: bool = False y_transpose: bool = False + colmajor_mxfp_weight: bool = False @pytest.mark.parametrize( @@ -236,6 +243,7 @@ class Case: Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1), Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9), Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), + Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, colmajor_mxfp_weight=False), Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2), Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True), Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4), @@ -277,7 +285,8 @@ class Case: @pytest.mark.parametrize("has_y_gammas", [False, True]) @pytest.mark.parametrize("is_persistent", [False, True]) def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot, - n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile, + n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, + hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, x_transpose, w_transpose, y_transpose, device, opt_flags_scope, fresh_knobs): # TODO: remove when Triton FP8 supports proper RTNE @@ -409,14 +418,72 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout( mx_axis=mx_axis, num_warps=8) # downcast to mxfp - w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis) - w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis) - w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype - w_tri = wrap_torch_tensor(w_tri, w_tri_dtype) - w_scale_tri = wrap_torch_tensor(w_scale_tri) - # convert layouts - w_tri = convert_layout(w_tri, w_layout, **w_layout_opts) - w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts) + w_tri_orig = w_tri + if colmajor_mxfp_weight: + w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis) + w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis) + w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype + w_tri = wrap_torch_tensor(w_tri, w_tri_dtype) + w_scale_tri = wrap_torch_tensor(w_scale_tri) + # convert layouts + w_tri = convert_layout(w_tri, w_layout, **w_layout_opts) + w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts) + else: + if torch.cuda.get_device_capability()[0] < 10: + pytest.skip("transposed mxfp weight not supported with cuda capability < 10") + if block_m == 16: + pytest.skip("PassManager::run failed from Triton compiler") + # TODO: swizzling for rowmajor + + # A typical use case is we already quantized col-major weight, + # and we want matmul with its transposed row-major weight w/o + # requantization. + + # put abs_max of each 32x32 block to diagonal so scales of transposed agree + w_ndim = w_tri.ndim + if w_ndim == 2: + w_tri = w_tri.unsqueeze(0) + BLOCK_SIZE = int(MXFP_BLOCK_SIZE) + for e, i, j in itertools.product(range(w_tri.shape[0]), range(0, w_tri.shape[1], BLOCK_SIZE), range(0, w_tri.shape[2], BLOCK_SIZE)): + i_end = min(i+BLOCK_SIZE, w_tri.shape[1]) + j_end = min(j+BLOCK_SIZE, w_tri.shape[2]) + block = w_tri[e, i:i_end, j:j_end] + m_abs = block.abs().max() + i_len = i_end - i + j_len = j_end - j + min_len = min(i_len, j_len) + signs = torch.randint(0, 2, (max(i_len, j_len),), device=w_tri.device) * 2 - 1 + block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs + if j_len > i_len: + block[i_len - 1, i_len:] = signs[min_len:] * m_abs + elif i_len > j_len: + block[j_len:, j_len - 1] = signs[min_len:] * m_abs + if w_ndim == 2: + w_tri = w_tri.squeeze(0) + + # matmul with rowmajor weight expects scale is separately + # constructed (not much additional memory needed). + _, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis) + # reuse quantized value from colmajor + w_tri_rowmajor, w_scale_tri_rowmajor = downcast_to_mxfp(w_tri.mT.contiguous(), weight_dtype, axis=mx_axis) + w_ref = upcast_from_mxfp(w_tri_rowmajor, w_scale_tri_rowmajor, torch.bfloat16, axis=mx_axis).mT.contiguous() + w_tri = w_tri_rowmajor.data.mT + + def _pad_and_block(x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.pad(x, (0, x.shape[-1] % BLOCK_SIZE), mode="replicate") + return x.view(*x.shape[:-1], x.shape[-1] // BLOCK_SIZE, BLOCK_SIZE) + + # check if generated scale is transpose-invariant as intended construction + # [cdiv(K, 32), N] -> dedup to [cdiv(K, 32), cdiv(N, 32)] + w_scale_tri_blocked = _pad_and_block(w_scale_tri) + w_scale_tri_sampled = w_scale_tri_blocked[..., 0:1] + # [cdiv(N, 32), K] -> dedup to [cdiv(N, 32), cdiv(K, 32)] + w_scale_tri_rowmajor_blocked = _pad_and_block(w_scale_tri_rowmajor) + w_scale_tri_rowmajor_sampled = w_scale_tri_rowmajor_blocked[..., 0:1] + assert torch.equal(w_scale_tri_sampled.expand_as(w_scale_tri_blocked), w_scale_tri_blocked) + assert torch.equal(w_scale_tri_rowmajor_sampled.expand_as(w_scale_tri_rowmajor_blocked), w_scale_tri_rowmajor_blocked) + assert torch.equal(w_scale_tri_sampled.squeeze(-1), w_scale_tri_rowmajor_sampled.squeeze(-1).mT) + precision_opt.weight_scale = w_scale_tri epilogue = None if act_mxfp8: @@ -425,7 +492,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas is_input_batched = x_tri.ndim == 3 y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0] - y_shape = (y_shape[0], n_rows, w_tri.shape[-1]) + y_shape = (y_shape[0], n_rows, w_tri_orig.shape[-1]) if sindx is None or mode == "batched": if not is_input_batched: y_shape = (y_shape[1], y_shape[2]) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs.py index 852d2c3ca64e..967efedc6e86 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs.py @@ -17,6 +17,7 @@ from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn from .matmul_ogs_details._reduce_grouped import _reduce_grouped from .numerics_details.mxfp import MXFP_BLOCK_SIZE +from .tensor_details.layout_details.strided import StridedLayout from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints, InapplicableConstraint from .specialize import specialize from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor @@ -339,6 +340,8 @@ def matmul_ogs(x, w, bias, # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real dtype = FP4 if w.dtype == torch.uint8 else w.dtype w = wrap_torch_tensor(w, dtype=dtype) + if w_has_mx and (torch.cuda.get_device_capability()[0] < 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)): + assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)" if w_scale is not None and not isinstance(w_scale, Tensor): w_scale = Tensor(w_scale) if w_scale is not None: From 6778e93bf437cd8a0bbaeb492570e1e3194b87e1 Mon Sep 17 00:00:00 2001 From: jongsoo-openai Date: Sun, 21 Sep 2025 10:57:25 -0700 Subject: [PATCH 3/3] remove temp changes --- python/triton_kernels/tests/test_matmul.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index 17315a387a22..31e8c1652b25 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -34,12 +34,6 @@ def alloc_rand(shape, device, dtype, requires_grad=True): return tmp.to(dtype).requires_grad_(requires_grad) return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad) -# def alloc_ones(shape, device, dtype, requires_grad=True): -# return torch.ones(shape, device=device, dtype=dtype, requires_grad=requires_grad) - -# def alloc_zeros(shape, device, dtype, requires_grad=True): -# return torch.zeros(shape, device=device, dtype=dtype, requires_grad=requires_grad) - def alloc_rand_like(x): return alloc_rand(x.shape, x.device, x.dtype, x.requires_grad) @@ -168,7 +162,7 @@ class Case: x_transpose: bool = False w_transpose: bool = False y_transpose: bool = False - colmajor_mxfp_weight: bool = False + colmajor_mxfp_weight: bool = True @pytest.mark.parametrize(