
Commit 492a6e5

jongsoo-openai authored and ThomasRaoux committed
[mxfp] remove col-major assert for mx weight (#8249)
Follow-up to triton-lang/triton#7795: now that transposed weights are supported, remove the unnecessary assertion that an mx weight must be column-major.

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [x] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

---------

Co-authored-by: Thomas Raoux <[email protected]>
1 parent b526748 · commit 492a6e5
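In this code base, "column-major" is a stride property of the weight rather than a distinct tensor type: a weight of logical shape (K, N) is column-major when its K dimension is the contiguous one, i.e. `w.stride(-2) == 1`, which is exactly what the removed assertion checked. A minimal PyTorch illustration (tensor names and shapes are made up for the example):

```python
import torch

# Weight of logical shape (K, N). The old assertion required w.stride(-2) == 1,
# i.e. the K dimension is contiguous ("column-major").
K, N = 64, 128
w_rowmajor = torch.empty(K, N)      # strides (N, 1): row-major, K is not contiguous
w_colmajor = torch.empty(N, K).mT   # strides (1, K): column-major view with shape (K, N)

assert w_rowmajor.stride(-2) != 1   # previously rejected for any mxfp weight
assert w_colmajor.stride(-2) == 1   # the layout the old assert always required
```

After this change, a row-major mx weight such as `w_rowmajor` is accepted on compute capability >= 10 (Blackwell-class GPUs) as long as its storage layout is strided rather than swizzled, per the relaxed guard in `matmul_ogs.py` below.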

2 files changed: +73, -11 lines

python/triton_kernels/tests/test_matmul.py

Lines changed: 70 additions & 10 deletions
@@ -194,6 +194,7 @@ class Case:
     x_transpose: bool = False
     w_transpose: bool = False
     y_transpose: bool = False
+    colmajor_mxfp_weight: bool = True
 
 
 @pytest.mark.parametrize(
@@ -266,6 +267,7 @@ class Case:
         Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+        Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, colmajor_mxfp_weight=False),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
         Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
@@ -312,7 +314,7 @@ class Case:
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
-            n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
+            n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope):
     # TODO: remove when Triton FP8 supports proper RTNE
@@ -460,14 +462,72 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
         w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
             mx_axis=mx_axis, num_warps=8)
         # downcast to mxfp
-        w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
-        w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
-        w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
-        w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
-        w_scale_tri = wrap_torch_tensor(w_scale_tri)
-        # convert layouts
-        w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
-        w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+        w_tri_orig = w_tri
+        if colmajor_mxfp_weight:
+            w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+            w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
+            w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
+            w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
+            w_scale_tri = wrap_torch_tensor(w_scale_tri)
+            # convert layouts
+            w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
+            w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+        else:
+            if torch.cuda.get_device_capability()[0] < 10:
+                pytest.skip("transposed mxfp weight not supported with cuda capability < 10")
+            if block_m == 16:
+                pytest.skip("PassManager::run failed from Triton compiler")
+            # TODO: swizzling for rowmajor
+
+            # A typical use case is we already quantized col-major weight,
+            # and we want matmul with its transposed row-major weight w/o
+            # requantization.
+
+            # put abs_max of each 32x32 block to diagonal so scales of transposed agree
+            w_ndim = w_tri.ndim
+            if w_ndim == 2:
+                w_tri = w_tri.unsqueeze(0)
+            BLOCK_SIZE = int(MXFP_BLOCK_SIZE)
+            for e, i, j in itertools.product(range(w_tri.shape[0]), range(0, w_tri.shape[1], BLOCK_SIZE), range(0, w_tri.shape[2], BLOCK_SIZE)):
+                i_end = min(i+BLOCK_SIZE, w_tri.shape[1])
+                j_end = min(j+BLOCK_SIZE, w_tri.shape[2])
+                block = w_tri[e, i:i_end, j:j_end]
+                m_abs = block.abs().max()
+                i_len = i_end - i
+                j_len = j_end - j
+                min_len = min(i_len, j_len)
+                signs = torch.randint(0, 2, (max(i_len, j_len),), device=w_tri.device) * 2 - 1
+                block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs
+                if j_len > i_len:
+                    block[i_len - 1, i_len:] = signs[min_len:] * m_abs
+                elif i_len > j_len:
+                    block[j_len:, j_len - 1] = signs[min_len:] * m_abs
+            if w_ndim == 2:
+                w_tri = w_tri.squeeze(0)
+
+            # matmul with rowmajor weight expects scale is separately
+            # constructed (not much additional memory needed).
+            _, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+            # reuse quantized value from colmajor
+            w_tri_rowmajor, w_scale_tri_rowmajor = downcast_to_mxfp(w_tri.mT.contiguous(), weight_dtype, axis=mx_axis)
+            w_ref = upcast_from_mxfp(w_tri_rowmajor, w_scale_tri_rowmajor, torch.bfloat16, axis=mx_axis).mT.contiguous()
+            w_tri = w_tri_rowmajor.data.mT
+
+            def _pad_and_block(x: torch.Tensor) -> torch.Tensor:
+                x = torch.nn.functional.pad(x, (0, x.shape[-1] % BLOCK_SIZE), mode="replicate")
+                return x.view(*x.shape[:-1], x.shape[-1] // BLOCK_SIZE, BLOCK_SIZE)
+
+            # check if generated scale is transpose-invariant as intended construction
+            # [cdiv(K, 32), N] -> dedup to [cdiv(K, 32), cdiv(N, 32)]
+            w_scale_tri_blocked = _pad_and_block(w_scale_tri)
+            w_scale_tri_sampled = w_scale_tri_blocked[..., 0:1]
+            # [cdiv(N, 32), K] -> dedup to [cdiv(N, 32), cdiv(K, 32)]
+            w_scale_tri_rowmajor_blocked = _pad_and_block(w_scale_tri_rowmajor)
+            w_scale_tri_rowmajor_sampled = w_scale_tri_rowmajor_blocked[..., 0:1]
+            assert torch.equal(w_scale_tri_sampled.expand_as(w_scale_tri_blocked), w_scale_tri_blocked)
+            assert torch.equal(w_scale_tri_rowmajor_sampled.expand_as(w_scale_tri_rowmajor_blocked), w_scale_tri_rowmajor_blocked)
+            assert torch.equal(w_scale_tri_sampled.squeeze(-1), w_scale_tri_rowmajor_sampled.squeeze(-1).mT)
+
         precision_opt.weight_scale = w_scale_tri
         epilogue = None
         if act_mxfp8:
@@ -476,7 +536,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
     is_input_batched = x_tri.ndim == 3
     y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
     n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
-    y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
+    y_shape = (y_shape[0], n_rows, w_tri_orig.shape[-1])
     if sindx is None or mode == "batched":
         if not is_input_batched:
             y_shape = (y_shape[1], y_shape[2])
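The diagonal construction in the test above exploits a simple property: the mxfp scale of a 32-element group is derived from the group's absolute maximum, and if every 32x32 block of the weight holds its absolute maximum on the diagonal, then every 32-element group along K and every 32-element group along N inside that block sees the same maximum, so the scales computed for the column-major and the transposed row-major orientation agree. A self-contained sketch of that invariant in plain PyTorch (shapes are made up and chosen divisible by 32; this mirrors the test's intent, not its exact code, and checks group maxima rather than the actual mxfp scales):

```python
import torch

BLOCK = 32          # MXFP_BLOCK_SIZE
K, N = 128, 96      # arbitrary, divisible by 32 to keep the sketch simple
w = torch.randn(K, N)

# Place +/- abs-max of every 32x32 block on its diagonal, as the test does.
for i in range(0, K, BLOCK):
    for j in range(0, N, BLOCK):
        block = w[i:i + BLOCK, j:j + BLOCK]
        m_abs = block.abs().max()
        signs = torch.randint(0, 2, (BLOCK,)) * 2 - 1
        block.diagonal()[:] = signs * m_abs

# Group-wise abs-max along K (what quantizing the column-major weight sees) ...
amax_k = w.abs().reshape(K // BLOCK, BLOCK, N).amax(dim=1)        # [K/32, N]
# ... and along N (what quantizing the transposed, row-major weight sees).
amax_n = w.abs().mT.reshape(N // BLOCK, BLOCK, K).amax(dim=1)     # [N/32, K]

# Every group inside a 32x32 block shares that block's maximum, in both
# orientations, so the implied per-block scale is transpose-invariant.
blocked_k = amax_k.reshape(K // BLOCK, N // BLOCK, BLOCK)
blocked_n = amax_n.reshape(N // BLOCK, K // BLOCK, BLOCK)
assert torch.equal(blocked_k, blocked_k[..., :1].expand_as(blocked_k))
assert torch.equal(blocked_n, blocked_n[..., :1].expand_as(blocked_n))
assert torch.equal(blocked_k[..., 0], blocked_n[..., 0].mT)
```

This is why the test can reuse the quantized column-major values via a `.mT` view and only construct the row-major scale separately, without requantizing the weight.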

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 3 additions & 1 deletion
@@ -16,6 +16,7 @@
 from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
 from .matmul_ogs_details._reduce_grouped import _reduce_grouped
 from .numerics_details.mxfp import MXFP_BLOCK_SIZE
+from .tensor_details.layout_details.strided import StridedLayout
 from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints, InapplicableConstraint
 from .specialize import specialize
 from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor, RaggedTensorMetadata
@@ -480,12 +481,13 @@ def matmul_ogs(x, w, bias,
     w_scale = precision_config.weight_scale
     w_has_mx = w_scale is not None
     is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
-    if w_has_mx: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp"
     if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
     if not isinstance(w, Tensor):
         # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
         dtype = FP4 if w.dtype == torch.uint8 else w.dtype
         w = wrap_torch_tensor(w, dtype=dtype)
+    if w_has_mx and (torch.cuda.get_device_capability()[0] < 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)):
+        assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)"
     if w_scale is not None and not isinstance(w_scale, Tensor):
         w_scale = Tensor(w_scale)
     if w_scale is not None:
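Spelled out, the relaxed guard only keeps the column-major requirement in two remaining cases. A small restatement of the condition added above (the helper name is hypothetical and it assumes a CUDA device; the logic is copied from the diff):

```python
import torch
from triton_kernels.tensor_details.layout_details.strided import StridedLayout


def mx_weight_must_be_colmajor(storage_layout) -> bool:
    # Hypothetical helper mirroring the guard in matmul_ogs: column-major is
    # still required before Blackwell (compute capability < 10), or when the
    # weight storage uses a swizzled (non-strided) layout.
    pre_blackwell = torch.cuda.get_device_capability()[0] < 10
    swizzled = storage_layout is not None and not isinstance(storage_layout, StridedLayout)
    return pre_blackwell or swizzled
```

In every other case, i.e. a strided mx weight on capability >= 10, a row-major (transposed) weight is now accepted, which is exactly what the new `colmajor_mxfp_weight=False` test case exercises.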
