80 changes: 70 additions & 10 deletions python/triton_kernels/tests/test_matmul.py
@@ -197,6 +197,7 @@ class Case:
x_transpose: bool = False
w_transpose: bool = False
y_transpose: bool = False
colmajor_mxfp_weight: bool = True


@pytest.mark.parametrize(
@@ -269,6 +270,7 @@ class Case:
Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, colmajor_mxfp_weight=False),
Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
@@ -315,7 +317,7 @@ class Case:
@pytest.mark.parametrize("has_y_gammas", [False, True])
@pytest.mark.parametrize("is_persistent", [False, True])
def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
x_transpose, w_transpose, y_transpose,
device, opt_flags_scope):
# TODO: remove when Triton FP8 supports proper RTNE
@@ -462,14 +464,72 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
mx_axis=mx_axis, num_warps=8)
# downcast to mxfp
w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
w_scale_tri = wrap_torch_tensor(w_scale_tri)
# convert layouts
w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
w_tri_orig = w_tri
if colmajor_mxfp_weight:
w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
w_scale_tri = wrap_torch_tensor(w_scale_tri)
# convert layouts
w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
else:
if torch.cuda.get_device_capability()[0] < 10:
pytest.skip("transposed mxfp weight not supported with cuda capability < 10")
if block_m == 16:
pytest.skip("PassManager::run failed from Triton compiler")
# TODO: swizzling for rowmajor

# A typical use case: the weight has already been quantized column-major,
# and we want to run the matmul with its transposed, row-major view
# without requantizing.

# put the abs max of each 32x32 block on its diagonal so the scales of the transposed weight agree
w_ndim = w_tri.ndim
if w_ndim == 2:
w_tri = w_tri.unsqueeze(0)
BLOCK_SIZE = int(MXFP_BLOCK_SIZE)
for e, i, j in itertools.product(range(w_tri.shape[0]), range(0, w_tri.shape[1], BLOCK_SIZE), range(0, w_tri.shape[2], BLOCK_SIZE)):
i_end = min(i+BLOCK_SIZE, w_tri.shape[1])
j_end = min(j+BLOCK_SIZE, w_tri.shape[2])
block = w_tri[e, i:i_end, j:j_end]
m_abs = block.abs().max()
i_len = i_end - i
j_len = j_end - j
min_len = min(i_len, j_len)
signs = torch.randint(0, 2, (max(i_len, j_len),), device=w_tri.device) * 2 - 1
block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs
if j_len > i_len:
block[i_len - 1, i_len:] = signs[min_len:] * m_abs
elif i_len > j_len:
block[j_len:, j_len - 1] = signs[min_len:] * m_abs
if w_ndim == 2:
w_tri = w_tri.squeeze(0)

# matmul with a row-major weight expects the scale to be constructed
# separately (this needs little additional memory).
_, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
# reuse quantized value from colmajor
w_tri_rowmajor, w_scale_tri_rowmajor = downcast_to_mxfp(w_tri.mT.contiguous(), weight_dtype, axis=mx_axis)
w_ref = upcast_from_mxfp(w_tri_rowmajor, w_scale_tri_rowmajor, torch.bfloat16, axis=mx_axis).mT.contiguous()
w_tri = w_tri_rowmajor.data.mT

def _pad_and_block(x: torch.Tensor) -> torch.Tensor:
# pad the last dim up to a multiple of BLOCK_SIZE, replicating the edge value
x = torch.nn.functional.pad(x, (0, (-x.shape[-1]) % BLOCK_SIZE), mode="replicate")
return x.view(*x.shape[:-1], x.shape[-1] // BLOCK_SIZE, BLOCK_SIZE)

# check that the generated scales are transpose-invariant, as intended by the construction above
# [cdiv(K, 32), N] -> dedup to [cdiv(K, 32), cdiv(N, 32)]
w_scale_tri_blocked = _pad_and_block(w_scale_tri)
w_scale_tri_sampled = w_scale_tri_blocked[..., 0:1]
# [cdiv(N, 32), K] -> dedup to [cdiv(N, 32), cdiv(K, 32)]
w_scale_tri_rowmajor_blocked = _pad_and_block(w_scale_tri_rowmajor)
w_scale_tri_rowmajor_sampled = w_scale_tri_rowmajor_blocked[..., 0:1]
assert torch.equal(w_scale_tri_sampled.expand_as(w_scale_tri_blocked), w_scale_tri_blocked)
assert torch.equal(w_scale_tri_rowmajor_sampled.expand_as(w_scale_tri_rowmajor_blocked), w_scale_tri_rowmajor_blocked)
assert torch.equal(w_scale_tri_sampled.squeeze(-1), w_scale_tri_rowmajor_sampled.squeeze(-1).mT)

precision_opt.weight_scale = w_scale_tri
epilogue = None
if act_mxfp8:
@@ -478,7 +538,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
is_input_batched = x_tri.ndim == 3
y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
y_shape = (y_shape[0], n_rows, w_tri_orig.shape[-1])
if sindx is None or mode == "batched":
if not is_input_batched:
y_shape = (y_shape[1], y_shape[2])
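The abs-max-on-diagonal construction above is what lets the row-major path skip requantization: every 32-long row segment and every 32-long column segment of a 32x32 block contains that block's abs max, so the per-group scale comes out the same whichever axis the weight is quantized along. Below is a minimal standalone sketch of that property in plain PyTorch, independent of the triton_kernels helpers; the block size, shapes, and the use of the raw per-group abs max as a stand-in for the mxfp scale are illustrative assumptions, and the random signs used in the test are omitted since they do not affect the abs max.

import torch

BLOCK = 32                      # stands in for MXFP_BLOCK_SIZE
K, N = 4 * BLOCK, 3 * BLOCK     # multiples of BLOCK keep the sketch short
w = torch.randn(K, N)

# Mirror the test: write each 32x32 block's abs max onto that block's diagonal.
for i in range(0, K, BLOCK):
    for j in range(0, N, BLOCK):
        blk = w[i:i + BLOCK, j:j + BLOCK]
        blk.diagonal()[:] = blk.abs().max()

# Per-group abs max over the last axis in chunks of 32 -- the quantity an mxfp
# scale is derived from.
def group_absmax(x):
    return x.reshape(*x.shape[:-1], x.shape[-1] // BLOCK, BLOCK).abs().amax(dim=-1)

scales_k = group_absmax(w.mT)   # quantize w along K:            shape [N, K // BLOCK]
scales_n = group_absmax(w)      # quantize w.T along its K (= N): shape [K, N // BLOCK]

# Every 32-long segment's max equals its 32x32 block's max, so the two scale
# grids carry the same values: sampling one row per block and transposing matches.
assert torch.equal(scales_k[::BLOCK].mT, scales_n[::BLOCK])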
4 changes: 3 additions & 1 deletion python/triton_kernels/triton_kernels/matmul_ogs.py
@@ -17,6 +17,7 @@
from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
from .matmul_ogs_details._reduce_grouped import _reduce_grouped
from .numerics_details.mxfp import MXFP_BLOCK_SIZE
from .tensor_details.layout_details.strided import StridedLayout
from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints, InapplicableConstraint
from .specialize import specialize
from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor
@@ -441,12 +442,13 @@ def matmul_ogs(x, w, bias,
w_scale = precision_config.weight_scale
w_has_mx = w_scale is not None
is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
if w_has_mx: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp"
if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
if not isinstance(w, Tensor):
# TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
dtype = FP4 if w.dtype == torch.uint8 else w.dtype
w = wrap_torch_tensor(w, dtype=dtype)
if w_has_mx and (torch.cuda.get_device_capability()[0] < 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)):
assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)"
if w_scale is not None and not isinstance(w_scale, Tensor):
w_scale = Tensor(w_scale)
if w_scale is not None:
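The change above replaces the unconditional column-major requirement for mx weights with a narrower one. A hedged paraphrase of the new condition as a standalone predicate follows; the function name and the absolute import path are illustrative assumptions, and the real check lives inline in matmul_ogs and asserts on w.stride(-2).

import torch
# Path assumed from the relative import added above.
from triton_kernels.tensor_details.layout_details.strided import StridedLayout

def mx_weight_must_be_colmajor(w_has_mx: bool, w_layout) -> bool:
    """An mx-typed weight still has to be column-major when running below
    Blackwell (cuda capability < 10) or when its value tensor uses a
    non-strided (swizzled) layout; otherwise a row-major view is now accepted."""
    if not w_has_mx:
        return False
    below_blackwell = torch.cuda.get_device_capability()[0] < 10
    swizzled = w_layout is not None and not isinstance(w_layout, StridedLayout)
    return below_blackwell or swizzled

# Corresponds to the in-kernel guard:
#   if mx_weight_must_be_colmajor(w_has_mx, w.storage.layout):
#       assert w.stride(-2) == 1, "`w` must be column-major ..."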