-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Add support for fp8e4m3fnuz dtype in Triton #8231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
- Register in with mantissa/exponent details - Update canonicalization + bitwidth mapping in - Expose new dtype in public API - Add unit tests: * Verify dtype existence and repr * Round-trip float32 → fp8e4m3fnuz → float32 (skipped if CUDA/torch support is missing) - Removed incorrect placeholder mapping of Closes triton-lang#8164
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for float8 tests") | ||
@pytest.mark.skipif(not hasattr(torch, "float8_e4m3fnuz"), reason="PyTorch build does not expose float8_e4m3fnuz") | ||
def test_float8e4m3fnuz_roundtrip(): | ||
# Create random data | ||
x = torch.randn(32, device="cuda", dtype=torch.float32) | ||
# Cast to fp8e4m3fnuz | ||
y = x.to(torch.float8_e4m3fnuz) | ||
# Cast back to fp32 | ||
z = y.to(torch.float32) | ||
|
||
# Shapes must match | ||
assert z.shape == x.shape | ||
# Result must still be a tensor | ||
assert torch.is_tensor(z) | ||
# Values should be approximately equal (fp8 is lossy) | ||
assert torch.allclose(x, z, atol=1e-1, rtol=1e-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that doesn't test triton at all
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the callout—updated. The tests now JIT and run Triton kernels that cast inside Triton:
y = x.to(tl.float8e4m3fnuz).to(tl.float32)
We then compare against a PyTorch reference (torch.float8_e4m3fnuz).
The suite includes:
- a minimal round-trip kernel,
- a param sweep over (BLOCK, num_warps),
- a cast-only variant (no load/store dtype kwargs, for API compatibility).
Tests are backend-aware as they run where the backing FP8 is supported and skip cleanly otherwise. This ensures we’re actually testing Triton lowering/JIT, not just PyTorch
"float8_e4m3fn": "fp8e4nv", | ||
"float8e4b8": "fp8e4b8", | ||
"float8_e4m3fnuz": "fp8e4b8", | ||
"float8_e4m3fnuz": "fp8e4m3fnuz", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fp8e4b8
maps to e4m3_fnuz already, what does this fix?
Line 943 in 7d92894
.def("get_fp8e4b8_ty", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed—no new backend dtype is introduced here. This PR only exposes a user-facing alias tl.float8e4m3fnuz and keeps canonicalization:
_utils.py
"float8_e4m3fnuz": "fp8e4b8"
The goal is ergonomic parity with PyTorch’s torch.float8_e4m3fnuz so kernels can write:
x.to(tl.float8e4m3fnuz)
without needing to know the backend internal name. All IR paths would still go through the existing fp8e4b8 plumbing so behavior is unchanged.
…mapping; tidy utils
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addressing comments and sharing revision 2. Thanks for comments!
"float8_e4m3fn": "fp8e4nv", | ||
"float8e4b8": "fp8e4b8", | ||
"float8_e4m3fnuz": "fp8e4b8", | ||
"float8_e4m3fnuz": "fp8e4m3fnuz", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed—no new backend dtype is introduced here. This PR only exposes a user-facing alias tl.float8e4m3fnuz and keeps canonicalization:
_utils.py
"float8_e4m3fnuz": "fp8e4b8"
The goal is ergonomic parity with PyTorch’s torch.float8_e4m3fnuz so kernels can write:
x.to(tl.float8e4m3fnuz)
without needing to know the backend internal name. All IR paths would still go through the existing fp8e4b8 plumbing so behavior is unchanged.
fp8e4m3fnuz
indtype
with mantissa/exponent details_utils.py
triton.language
public APIfloat8_e4m3fnuz -> fp8e4b8
Closes #8164
New contributor declaration
I am not making a trivial change, such as fixing a typo in a comment.
I have written a PR description following these
rules.
I have run
pre-commit run --from-ref origin/main --to-ref HEAD
.Select one of the following.
/test
forlit
tests/unittest
for C++ tests/python/test
for end-to-end testsN/A
.Select one of the following.
lit
tests.lit
tests I have added follow these best practices,including the "tests should be minimal" section. (Usually running Python code
and using the instructions it generates is not minimal.)