Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions python/triton_kernels/tests/test_mxfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,22 @@ def test_mxfp4_rounding_cases(dst_dtype, device):
assert_equal(dequant_torch, dequant)


@pytest.mark.parametrize("src_dtype", ["float4_e2m1", "float8_e5m2", "float8_e4m3fn"])
@pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16", "float32"])
def test_mxfp_extreme_values(src_dtype, dst_dtype, device):
if "float8" in src_dtype and (is_cuda() and torch.cuda.get_device_capability()[0] < 9):
pytest.skip("Float8 not tested on A100")
src_dtype = dtype_str_to_torch(src_dtype)
dst_dtype = dtype_str_to_torch(dst_dtype)
BIG_VALUE = 65470 if dst_dtype == torch.float16 else 3.3895e38
x = torch.tensor([BIG_VALUE, BIG_VALUE], dtype=dst_dtype, device=device)
xq_value, xq_scale = downcast_to_mxfp(x, src_dtype, axis=-1)
xdq = upcast_from_mxfp(xq_value, xq_scale, dst_dtype, axis=-1)
xdq_ref = upcast_from_mxfp_torch(xq_value, xq_scale, dst_dtype, axis=-1)
assert_equal(xdq_ref, xdq)
assert not xdq.isinf().any()


@pytest.mark.parametrize("src_dtype", ["float4_e2m1", "float8_e5m2", "float8_e4m3fn"])
@pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16", "float32"])
def test_mxfp_quant_dequant(src_dtype, dst_dtype, device):
Expand Down
11 changes: 11 additions & 0 deletions python/triton_kernels/triton_kernels/numerics_details/mxfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,17 @@ def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dty
padded_tensor = padded_tensor.view(*new_shape)
dq_scale_padded = dq_scale.unsqueeze(-1) # shape: [..., ceil(axis_shape/32), 1]
out_padded = padded_tensor * dq_scale_padded
# Need to clamp since due to rounding, we can have overflow that was within
# the range before quantization.
# e.g., 3.3895e+38 -> log2(3.3895e+38 / max_fp8e4m3=448) ~= 119.17 -> round
# up to 120 + exp_bias=127 -> scale=247
# 3.3895e+38 / 2**120 ~= 254.9976 -> round to 256 in fp8e4m3fn
# Dequantization: 256 * 2**120 > 3.4e38 overflowing 3.38953139e38
finfo = torch.finfo(target_dtype)
out_padded = (padded_tensor * dq_scale_padded).clamp(finfo.min, finfo.max)
if tensor.dtype == torch.float8_e5m2:
# fp8e5m2 can have inf and we want to preserve so separately handle
out_padded = out_padded.where(~padded_tensor.isinf(), padded_tensor.to(target_dtype))

# Flatten back and remove the padded tail
out_padded = out_padded.view(*fp32_tensor.shape[:-1], new_axis_shape)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_
scale = scale.reshape(dst_scale.shape)

out_tensor = dst_tensor * dst_scale
if dst_dtype == tl.float32:
max_fin = 3.4028234663852886e+38
elif dst_dtype == tl.bfloat16:
max_fin = 3.3895313892515355e+38
else:
tl.static_assert(dst_dtype == tl.float16)
max_fin = 65504
# TODO: handle infinity same as upcast_from_mxfp_torch together with the
# above FIXME
out_tensor = tl.clamp(out_tensor, min=-max_fin, max=max_fin)
# Correct any NaNs encoded via the scale.
out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor)
out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
Expand Down
Loading