10 changes: 6 additions & 4 deletions python/triton_kernels/triton_kernels/matmul_ogs.py
@@ -334,7 +334,8 @@ def matmul_ogs(x, w, bias,
w_scale = precision_config.weight_scale
w_has_mx = w_scale is not None
is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
if w_has_mx: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp"
is_hopper_mx = is_cuda() and not target_info.cuda_capability_geq(10, 0) and w_scale is not None
if is_hopper_mx: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp on capability < 10"
if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
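# Illustrative sketch (not part of this diff): a column-major `w` satisfying the
# asserts above can be built by transposing a contiguous tensor; shapes and dtype
# below are hypothetical.
#   w = torch.empty(n_expts, N, K, dtype=torch.uint8).transpose(-1, -2)
#   assert w.shape == (n_expts, K, N) and w.stride(-2) == 1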
if not isinstance(w, Tensor):
# TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
@@ -439,7 +440,7 @@ def matmul_ogs(x, w, bias,
has_scatter_tma = opt_flags.fused_scatter and target_info.has_tma_gather()
y = wrap_torch_tensor(out_matmul.view(math.prod(out_matmul.shape[:-1]), out_matmul.shape[-1]) if opt_flags.fused_scatter else out_matmul.view(math.prod(out_matmul.shape[:-2]), *out_matmul.shape[-2:]))
x_storage = _canonicalize_storage(x.storage, 2 if has_gather_tma else 3, flex.lhs_data)
w_storage = _canonicalize_storage(w.storage, 3, flex.rhs_data)
w_storage = _canonicalize_storage(w.storage, 5 if w.storage.layout.name == "BLACKWELL_VALUE" else 3, flex.rhs_data)
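# Note (illustration only, not from the PR): weights in the BLACKWELL_VALUE layout
# are stored as swizzled 5-D tiles (see swizzle_data in layout_details/blackwell_value),
# hence their storage is canonicalized to 5 dims instead of the usual 3.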
y_storage = _canonicalize_storage(y.storage, 2 if has_scatter_tma else 3, flex.out_data)
# create tma descriptor for x
x_has_tma = opt_flags.is_persistent and (has_gather_tma or not has_gather)
@@ -465,14 +466,15 @@ def matmul_ogs(x, w, bias,
x_scale_strides = (0, ) * (3 - len(x_scale_strides)) + x_scale_strides
w_scale_strides = w_scale.stride() if w_has_mx and not w_scale_has_tma else (None, None, None)
w_scale_strides = (0, ) * (3 - len(w_scale_strides)) + w_scale_strides
# w_strides = (None, None, None) if opt_flags.is_persistent else w_storage.data.stride()
w_strides = w_storage.data.stride()[-3:]
out_matmul_scale_strides = out_matmul_scale.stride() if out_matmul_has_mx else (None, None, None, None)
out_matmul_scale_strides = (0, ) * (4 - len(out_matmul_scale_strides)) + out_matmul_scale_strides
# launch kernel
kernels = get_kernels(epilogue.specs, matmul_fused_activation.specs)
# When stride(-2) == stride(-1) == 1, it is ambiguous whether W is transposed
# (i.e. column-major). Since the case where w_has_mx is True and w_transpose
# is True is the fast code path, stride(-2) == 1 takes precedence over, e.g.,
# w_transpose = w_storage.data.stride()[-1] != 1
w_transpose = w_storage.data.stride()[-2] == 1
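# Illustrative example (not part of this diff) of the ambiguity described above,
# using a hypothetical weight whose last dimension has size 1:
#   w = torch.empty(8, 512, 1, dtype=torch.uint8)   # strides (512, 1, 1)
#   w.stride(-1) != 1   # False -> would be read as row-major
#   w.stride(-2) == 1   # True  -> read as column-major, i.e. the fast mx path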
(kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)](
y_tensor_or_tma, y_storage.data, *out_matmul.stride(),
Expand All @@ -481,7 +483,7 @@ def matmul_ogs(x, w, bias,
x_tensor_or_tma, x_storage.data, *x_strides, x_transpose,
flex.lhs_data.scale,
None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides,
w_tensor_or_tma, w_storage.data, *w_storage.data.stride(), w_transpose,
w_tensor_or_tma, w_storage.data, *w_strides, w_transpose,
flex.rhs_data.scale,
w_scale_tensor_or_tma, *w_scale_strides,
bias, bias_stride,
@@ -6,6 +6,7 @@
from triton.tools.ragged_tma import load_ragged, store_ragged
from triton_kernels import target_info
from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
from triton_kernels.tensor_details.layout_details.blackwell_value import unswizzle_mx_value_bw
from triton_kernels.numerics_details.flexpoint import (
float_to_flex,
load_scale,
@@ -113,7 +114,7 @@ def _p_matmul_ogs(
# optimization config
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
# NYI: Must be None
# One of ["BLACKWELL", None]
SWIZZLE_MX_VALUE: tl.constexpr,
# One of ["BLACKWELL", None]
SWIZZLE_MX_SCALE: tl.constexpr,
@@ -127,7 +128,6 @@ def _p_matmul_ogs(
UPCAST_INDICES:tl.constexpr=False,
SWAP_XW: tl.constexpr = False,
IS_EPILOGUE_QUANT_MXFP8: tl.constexpr = False):
# tl.static_assert(SWIZZLE_MX_VALUE is None, "NYI. Value swizzling")

# why is this faster than using host-side tensor descriptor?!
if Y_TMA_MODE is not None:
@@ -310,10 +310,12 @@ def _p_matmul_ogs(
x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)

# --- load w ---
if W_TRANSPOSE:
w = tl.reshape(W.load([expt_id, off_n, off_k_w]), W.block_shape[1:]).T
if SWIZZLE_MX_VALUE == "BLACKWELL_VALUE":
w = unswizzle_mx_value_bw(tl.reshape(W.load([expt_id, off_n // 2, off_k_w // 64, 0, 0]), W.block_shape[1:]))
else:
w = tl.reshape(W.load([expt_id, off_k_w, off_n]), W.block_shape[1:])
w = tl.reshape(W.load([expt_id, off_n, off_k_w]), W.block_shape[1:])
if W_TRANSPOSE:
w = w.T
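# Note (illustration only, not from the PR): on the SWIZZLE_MX_VALUE path the W
# descriptor is 5-D, so the load picks one tile per (off_n // 2, off_k_w // 64) and
# unswizzle_mx_value_bw flattens that tile back into a 2-D block before the optional
# transpose.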

# --- load w_scale ---
if is_w_microscaled:
@@ -1,8 +1,10 @@
# isort: off
# fmt: off
from dataclasses import dataclass

import triton
from triton_kernels.target_info import get_cdna_version
from triton_kernels.tensor import FP4
import torch
from .opt_flags_details import opt_flags_amd, opt_flags_nvidia

@@ -165,7 +167,11 @@ def make_default_opt_flags_nvidia(
elif enforce_bitwise_invariance:
block_m = 128
else:
block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
if tokens_per_expt <= 64 and routing_data is not None and routing_data.expt_hist is not None:
# Ragged and likely memory bound; use a larger block size so the weights are not re-loaded more often than necessary.
block_m = max(16, min(triton.next_power_of_2(2 * tokens_per_expt), 64))
else:
block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
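# Worked example (illustrative, not part of this diff) of the heuristic above:
#   ragged branch:  tokens_per_expt = 24  -> next_power_of_2(48)  = 64  -> block_m = 64
#                   tokens_per_expt = 6   -> next_power_of_2(12)  = 16  -> block_m = 16
#   dense branch:   tokens_per_expt = 200 -> next_power_of_2(200) = 256 -> block_m = 128 (capped)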
# block n
arch = None
block_n, block_n_tma = opt_flags_nvidia.compute_block_n(n, arch, precision_config)
@@ -188,6 +194,10 @@ def make_default_opt_flags_nvidia(
block_k = constraints["block_k"]
else:
block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config)
# if block_n == 256 and block_k == 128 and block_m <= 64 and is_persistent and rhs_dtype == FP4 and k >= 4096 and tokens_per_expt > 1:
# # Swap block_n and block_k for mxfp4 weights so that block_k is a full cacheline, so long as K is sufficiently large.
# # TODO: swizzle the HBM layout of the weights instead
# block_n, block_k = block_k, block_n
# split_k
if constraints.get("split_k", None) is not None:
split_k = constraints["split_k"]
6 changes: 1 addition & 5 deletions python/triton_kernels/triton_kernels/tensor.py
@@ -52,11 +52,7 @@ def make_dense_tma(self, block_shape, transpose=False):
shape = shape[:-2] + [shape[-1], shape[-2]]
strides = strides[:-2] + [strides[-1], strides[-2]]
if self.data.dtype == torch.uint8 and self.layout.name == "BLACKWELL_VALUE":
indx = strides.index(1)
block_shape[indx] = block_shape[indx] // 2
if shape[-1] % 128 != 0:
raise ValueError("inner shape need to be multiple of 128 for "
"mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs.")
block_shape[-1] = block_shape[-1] // 2
block_shape = self.layout.swizzle_block_shape(block_shape)
return TensorDescriptor(self.data, shape, strides, block_shape)

@@ -1,3 +1,5 @@
import triton
import triton.language as tl
import torch
from .base import Layout

@@ -10,24 +12,69 @@ def __init__(self, shape) -> None:
self.shape = shape

def swizzle_data(self, data):
# permutation needed to make `data` row major
to_row_major = sorted(range(data.ndim), key=lambda d: (data.stride(d), d))[::-1]
# permutation needed to retrieve original order
inv = [0] * data.ndim
for i, d in enumerate(to_row_major):
inv[d] = i
# leading dimension must be padded to be aligned to 128
align_dim = lambda x: (x + 128 - 1) // 128 * 128
assert data.shape == self.shape, "Mismatch between data and recorded shape"

major_dim = data.stride().index(1)
pad = align_dim(data.shape[major_dim]) - data.shape[major_dim]
data = torch.nn.functional.pad(data.permute(to_row_major), (0, pad)).permute(inv)
minor_dim = major_dim - 1 if major_dim == data.ndim - 1 else major_dim + 1

col_major = (major_dim, minor_dim) == (data.ndim - 2, data.ndim - 1)
row_major = (minor_dim, major_dim) == (data.ndim - 2, data.ndim - 1)
assert col_major or row_major

align_to = lambda x, alignment: (x + alignment - 1) // alignment * alignment

pad_major = align_to(data.shape[major_dim], 128) - data.shape[major_dim]
pad_minor = align_to(data.shape[minor_dim], 2) - data.shape[minor_dim]

padding = []
for dim in reversed(range(min(major_dim, minor_dim), data.ndim)):
if dim == major_dim:
padding.extend((0, pad_major))
elif dim == minor_dim:
padding.extend((0, pad_minor))
else:
padding.extend((0, 0))
data = torch.nn.functional.pad(data, tuple(padding))
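# Worked example (illustrative, not part of this diff): for a row-major (E, R, C)
# tensor, major_dim = 2 and minor_dim = 1, so the loop builds
# padding = (0, pad_major, 0, pad_minor): F.pad extends the contiguous C dim to a
# multiple of 128 and the R dim to a multiple of 2.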

*leading_shape, R, C = data.shape
leading_dims = range(data.ndim - 2)

if col_major:
data = data.reshape(*leading_shape, R // 64, 64, C // 2, 2)
data = data.permute(*leading_dims, -2, -4, -1, -3)
data = data.flatten(-2, -1)
data = data.reshape(*leading_shape, C // 2, R // 64, 2, 64)
data = data.transpose(-1, -2)
else:
data = data.reshape(*leading_shape, R // 2, 2, C // 64, 64)
data = data.transpose(-2, -3).flatten(-2, -1).reshape(*leading_shape, R // 2, C // 64, 2, 64)

return data
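# Round-trip sketch (illustrative, not part of this diff) of the row-major branch
# above, using a small int32 tensor instead of packed-fp4 bytes:
#   R, C = 4, 128
#   data = torch.arange(R * C, dtype=torch.int32).reshape(R, C)
#   sw = data.reshape(R // 2, 2, C // 64, 64).transpose(-2, -3).flatten(-2, -1)
#   sw = sw.reshape(R // 2, C // 64, 2, 64)
#   back = sw.transpose(-2, -3).flatten(-4, -3).flatten(-2, -1)
#   assert torch.equal(back, data)   # mirrors unswizzle_data's non-transposed path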

def unswizzle_data(self, data: torch.Tensor):
# Trim padding along all dims back to the original shape recorded at init.
assert data.ndim == len(self.shape), "Rank mismatch between data and recorded shape"
sizes = [min(data.size(i), self.shape[i]) for i in range(data.ndim)]
return data[tuple(slice(0, s) for s in sizes)]
assert data.ndim == len(self.shape) + 2, "Rank mismatch between data and recorded shape"
transpose = data.stride(-1) != 1

if transpose:
*leading_shape, C2, R64, a, b = data.shape
assert (a, b) == (64, 2)
data = data.transpose(-3, -4)
else:
*leading_shape, R2, C64, a, b = data.shape
assert (a, b) == (2, 64)
data = data.transpose(-2, -3)
data = data.flatten(-4, -3).flatten(-2, -1)

return data

def swizzle_block_shape(self, block_shape):
return block_shape
*leading_shape, BLOCK_N, BLOCK_K = block_shape
return (*leading_shape, BLOCK_N // 2, BLOCK_K // 64, 2, 64)
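# Example (illustration only, not from the PR): block_shape = (1, 128, 64) maps to
# (1, 64, 1, 2, 64), matching the 5-D tile layout produced by swizzle_data.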

@triton.jit
def unswizzle_mx_value_bw(x):
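# Descriptive note (illustrative, not part of this diff): flattens a swizzled value
# tile of shape (shape_0, shape_1, 2, 64) back to (shape_0 * 2, shape_1 * 64); the
# static assert below restricts this to shape_1 == 1 for now.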
shape_0: tl.constexpr = x.shape[0]
shape_1: tl.constexpr = x.shape[1]
tl.static_assert(x.shape[1] == 1, "unswizzle_mx_value_bw requires shape[1] == 1")
x = x.reshape(shape_0 * 2, shape_1 * 64)
return x