
Commit b1dc287

add test
1 parent 1e5b330 commit b1dc287

2 files changed, +80 -10 lines
python/triton_kernels/tests/test_matmul.py

Lines changed: 77 additions & 10 deletions
@@ -34,6 +34,12 @@ def alloc_rand(shape, device, dtype, requires_grad=True):
         return tmp.to(dtype).requires_grad_(requires_grad)
     return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad)

+# def alloc_ones(shape, device, dtype, requires_grad=True):
+#     return torch.ones(shape, device=device, dtype=dtype, requires_grad=requires_grad)
+
+# def alloc_zeros(shape, device, dtype, requires_grad=True):
+#     return torch.zeros(shape, device=device, dtype=dtype, requires_grad=requires_grad)
+

 def alloc_rand_like(x):
     return alloc_rand(x.shape, x.device, x.dtype, x.requires_grad)
@@ -162,6 +168,7 @@ class Case:
     x_transpose: bool = False
     w_transpose: bool = False
     y_transpose: bool = False
+    colmajor_mxfp_weight: bool = False


 @pytest.mark.parametrize(
@@ -236,6 +243,7 @@ class Case:
         Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
+        Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, colmajor_mxfp_weight=False),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
         Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
         Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
@@ -277,7 +285,8 @@ class Case:
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot,
-            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
+            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m,
+            hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope, fresh_knobs):
     # TODO: remove when Triton FP8 supports proper RTNE
@@ -409,14 +418,72 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
             mx_axis=mx_axis, num_warps=8)
         # downcast to mxfp
-        w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
-        w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
-        w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
-        w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
-        w_scale_tri = wrap_torch_tensor(w_scale_tri)
-        # convert layouts
-        w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
-        w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+        w_tri_orig = w_tri
+        if colmajor_mxfp_weight:
+            w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+            w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
+            w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
+            w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
+            w_scale_tri = wrap_torch_tensor(w_scale_tri)
+            # convert layouts
+            w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
+            w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
+        else:
+            if torch.cuda.get_device_capability()[0] < 10:
+                pytest.skip("transposed mxfp weight not supported with cuda capability < 10")
+            if block_m == 16:
+                pytest.skip("PassManager::run failed from Triton compiler")
+            # TODO: swizzling for rowmajor
+
+            # A typical use case: a col-major weight has already been quantized and we
+            # want to run matmul with its transposed row-major view without
+            # requantization.
+
+            # put the abs-max of each 32x32 block on its diagonal so the scales of the transposed weight agree
+            w_ndim = w_tri.ndim
+            if w_ndim == 2:
+                w_tri = w_tri.unsqueeze(0)
+            BLOCK_SIZE = int(MXFP_BLOCK_SIZE)
+            for e, i, j in itertools.product(range(w_tri.shape[0]), range(0, w_tri.shape[1], BLOCK_SIZE), range(0, w_tri.shape[2], BLOCK_SIZE)):
+                i_end = min(i + BLOCK_SIZE, w_tri.shape[1])
+                j_end = min(j + BLOCK_SIZE, w_tri.shape[2])
+                block = w_tri[e, i:i_end, j:j_end]
+                m_abs = block.abs().max()
+                i_len = i_end - i
+                j_len = j_end - j
+                min_len = min(i_len, j_len)
+                signs = torch.randint(0, 2, (max(i_len, j_len),), device=w_tri.device) * 2 - 1
+                block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs
+                if j_len > i_len:
+                    block[i_len - 1, i_len:] = signs[min_len:] * m_abs
+                elif i_len > j_len:
+                    block[j_len:, j_len - 1] = signs[min_len:] * m_abs
+            if w_ndim == 2:
+                w_tri = w_tri.squeeze(0)
+
+            # matmul with a row-major weight expects the scale to be constructed
+            # separately (not much additional memory is needed).
+            _, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
+            # reuse the quantized values from the col-major weight
+            w_tri_rowmajor, w_scale_tri_rowmajor = downcast_to_mxfp(w_tri.mT.contiguous(), weight_dtype, axis=mx_axis)
+            w_ref = upcast_from_mxfp(w_tri_rowmajor, w_scale_tri_rowmajor, torch.bfloat16, axis=mx_axis).mT.contiguous()
+            w_tri = w_tri_rowmajor.data.mT
+
+            def _pad_and_block(x: torch.Tensor) -> torch.Tensor:
+                x = torch.nn.functional.pad(x, (0, x.shape[-1] % BLOCK_SIZE), mode="replicate")
+                return x.view(*x.shape[:-1], x.shape[-1] // BLOCK_SIZE, BLOCK_SIZE)

+            # check that the generated scales are transpose-invariant, as intended by the construction above
+            # [cdiv(K, 32), N] -> dedup to [cdiv(K, 32), cdiv(N, 32)]
+            w_scale_tri_blocked = _pad_and_block(w_scale_tri)
+            w_scale_tri_sampled = w_scale_tri_blocked[..., 0:1]
+            # [cdiv(N, 32), K] -> dedup to [cdiv(N, 32), cdiv(K, 32)]
+            w_scale_tri_rowmajor_blocked = _pad_and_block(w_scale_tri_rowmajor)
+            w_scale_tri_rowmajor_sampled = w_scale_tri_rowmajor_blocked[..., 0:1]
+            assert torch.equal(w_scale_tri_sampled.expand_as(w_scale_tri_blocked), w_scale_tri_blocked)
+            assert torch.equal(w_scale_tri_rowmajor_sampled.expand_as(w_scale_tri_rowmajor_blocked), w_scale_tri_rowmajor_blocked)
+            assert torch.equal(w_scale_tri_sampled.squeeze(-1), w_scale_tri_rowmajor_sampled.squeeze(-1).mT)
+
         precision_opt.weight_scale = w_scale_tri
     epilogue = None
     if act_mxfp8:
@@ -425,7 +492,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     is_input_batched = x_tri.ndim == 3
     y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
     n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
-    y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
+    y_shape = (y_shape[0], n_rows, w_tri_orig.shape[-1])
     if sindx is None or mode == "batched":
         if not is_input_batched:
             y_shape = (y_shape[1], y_shape[2])

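The `colmajor_mxfp_weight=False` branch above relies on a simple invariant: every 32-element quantization group of a 32x32 block contains that block's abs-max (it is written onto the diagonal), so the mxfp scale derived for a group is the same whether the weight is quantized along K or its transpose is quantized along N. Below is a minimal standalone sketch of that invariant, not part of the commit; it uses a plain per-group abs-max as a stand-in for the real shared-exponent mxfp scale, and the helper name `group_absmax_along_rows` and the toy shapes are assumptions made only for illustration.

# standalone sketch, not part of the commit: a per-group abs-max stands in for the
# real mxfp shared-exponent scale, which is itself derived from that abs-max
import torch

BLOCK = 32
torch.manual_seed(0)
w = torch.randn(2 * BLOCK, 3 * BLOCK)

# emulate the construction in the test: per 32x32 block, write +/- the block abs-max on the diagonal
for i in range(0, w.shape[0], BLOCK):
    for j in range(0, w.shape[1], BLOCK):
        block = w[i:i + BLOCK, j:j + BLOCK]   # view into w, modified in place
        m_abs = block.abs().max()
        signs = torch.randint(0, 2, (BLOCK,)) * 2 - 1
        block.diagonal()[:] = signs * m_abs

def group_absmax_along_rows(x):
    # one value per 32-long segment of each column -> shape [rows // 32, cols]
    return x.reshape(x.shape[0] // BLOCK, BLOCK, x.shape[1]).abs().amax(dim=1)

scale_colmajor = group_absmax_along_rows(w)     # groups run along K
scale_rowmajor = group_absmax_along_rows(w.T)   # groups run along K of the transposed weight

# every 32-element group of a 32x32 block contains the diagonal, hence the block abs-max,
# so both scale grids collapse to the same per-block values
block_absmax = scale_colmajor.reshape(scale_colmajor.shape[0], -1, BLOCK).amax(dim=-1)
assert torch.equal(scale_colmajor, block_absmax.repeat_interleave(BLOCK, dim=1))
assert torch.equal(scale_rowmajor, block_absmax.T.repeat_interleave(BLOCK, dim=1))
print("per-group scales are transpose-invariant")

The test's `_pad_and_block` asserts perform the analogous de-duplication on the real `downcast_to_mxfp` scales: within each 32-wide block, all columns (or rows, for the transposed weight) must carry an identical scale.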
python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 3 additions & 0 deletions
@@ -17,6 +17,7 @@
 from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
 from .matmul_ogs_details._reduce_grouped import _reduce_grouped
 from .numerics_details.mxfp import MXFP_BLOCK_SIZE
+from .tensor_details.layout_details.strided import StridedLayout
 from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints, InapplicableConstraint
 from .specialize import specialize
 from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor
@@ -339,6 +340,8 @@ def matmul_ogs(x, w, bias,
         # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
         dtype = FP4 if w.dtype == torch.uint8 else w.dtype
         w = wrap_torch_tensor(w, dtype=dtype)
+    if w_has_mx and (torch.cuda.get_device_capability()[0] < 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)):
+        assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)"
     if w_scale is not None and not isinstance(w_scale, Tensor):
         w_scale = Tensor(w_scale)
     if w_scale is not None:

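The new guard in `matmul_ogs` is a pure stride check: when the weight carries mxfp data and is either swizzled or running on a GPU with compute capability below 10 (pre-Blackwell), its reduction dimension must be the contiguous one, i.e. `w.stride(-2) == 1`. A minimal sketch of what does and does not satisfy that condition, using plain PyTorch; the sizes `K` and `N` and the uint8 dtype (as used for packed mxfp4 elsewhere in this file) are illustrative assumptions, not part of the commit.

import torch

K, N = 256, 128   # illustrative sizes only

# a freshly allocated (K, N) tensor is row-major: stride(-2) == N, which would fail the stride check above
w_rowmajor = torch.empty(K, N, dtype=torch.uint8)
assert w_rowmajor.stride(-2) == N

# allocate the buffer as (N, K) and view it transposed: same logical (K, N) weight,
# but stride(-2) == 1, i.e. column-major, which is what the mxfp path expects here
w_colmajor = torch.empty(N, K, dtype=torch.uint8).mT
assert w_colmajor.shape == (K, N) and w_colmajor.stride(-2) == 1

In practice, a caller hitting this assert can materialize the weight buffer transposed (or take `.mT` of a contiguous (N, K) tensor, as above) instead of passing a row-major (K, N) tensor.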