Conversation

@arnav-garg1 commented Sep 19, 2025

  • Register fp8e4m3fnuz in dtype with mantissa/exponent details
  • Update canonicalization + bitwidth mapping in _utils.py
  • Expose the new dtype in the triton.language public API (see the sketch after this description)
  • Add unit tests:
    • Verify dtype existence and repr
    • Round-trip float32 → fp8e4m3fnuz → float32 (skipped if CUDA/torch support is missing)
  • Remove the incorrect placeholder mapping of float8_e4m3fnuz -> fp8e4b8

Closes #8164
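
For context, a rough sketch of what the user-facing wiring could look like (the file locations, the canonical dtype name, and the attribute check below are assumptions for illustration, not the actual diff):

# Sketch only: in python/triton/language/core.py, fp8 dtypes are declared by
# their canonical name, so the new type would sit alongside the existing ones:
#
#     float8e4b8     = dtype('fp8e4b8')
#     float8e4m3fnuz = dtype('fp8e4m3fnuz')   # name assumed, added by this PR
#
# and be re-exported from python/triton/language/__init__.py. A quick smoke
# check of the exposure and repr (mirrors the "dtype existence and repr" test):

import triton.language as tl

dt = tl.float8e4m3fnuz
assert dt.primitive_bitwidth == 8   # 1 sign + 4 exponent + 3 mantissa bits
print(repr(dt))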

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because N/A.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

Comment on lines +103 to +118
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for float8 tests")
@pytest.mark.skipif(not hasattr(torch, "float8_e4m3fnuz"), reason="PyTorch build does not expose float8_e4m3fnuz")
def test_float8e4m3fnuz_roundtrip():
    # Create random data
    x = torch.randn(32, device="cuda", dtype=torch.float32)
    # Cast to fp8e4m3fnuz
    y = x.to(torch.float8_e4m3fnuz)
    # Cast back to fp32
    z = y.to(torch.float32)

    # Shapes must match
    assert z.shape == x.shape
    # Result must still be a tensor
    assert torch.is_tensor(z)
    # Values should be approximately equal (fp8 is lossy)
    assert torch.allclose(x, z, atol=1e-1, rtol=1e-1)
Collaborator

That doesn't test Triton at all.

Author

Thanks for the callout; updated. The tests now JIT-compile and run Triton kernels that perform the cast inside Triton:

y = x.to(tl.float8e4m3fnuz).to(tl.float32)

We then compare against a PyTorch reference (torch.float8_e4m3fnuz).
The suite includes:

  • a minimal round-trip kernel,
  • a param sweep over (BLOCK, num_warps),
  • a cast-only variant (no load/store dtype kwargs, for API compatibility).

The tests are backend-aware: they run where the backing FP8 format is supported and skip cleanly otherwise. This ensures we are actually exercising Triton's lowering/JIT path, not just PyTorch.
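
For concreteness, a minimal sketch of the round-trip test described above (kernel and test names, block size, and tolerances are placeholders, not the exact code in the PR):

import pytest
import torch
import triton
import triton.language as tl


@triton.jit
def _roundtrip_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # Cast inside Triton: fp32 -> fp8e4m3fnuz -> fp32 (alias added by this PR)
    y = x.to(tl.float8e4m3fnuz).to(tl.float32)
    tl.store(out_ptr + offsets, y, mask=mask)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
@pytest.mark.skipif(not hasattr(torch, "float8_e4m3fnuz"),
                    reason="PyTorch build does not expose float8_e4m3fnuz")
def test_triton_float8e4m3fnuz_roundtrip():
    x = torch.randn(1024, device="cuda", dtype=torch.float32)
    out = torch.empty_like(x)
    BLOCK = 128
    grid = (triton.cdiv(x.numel(), BLOCK),)
    _roundtrip_kernel[grid](x, out, x.numel(), BLOCK=BLOCK)

    # PyTorch reference: the same lossy round trip through the fnuz format
    ref = x.to(torch.float8_e4m3fnuz).to(torch.float32)
    assert torch.allclose(out, ref, atol=1e-1, rtol=1e-1)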

"float8_e4m3fn": "fp8e4nv",
"float8e4b8": "fp8e4b8",
"float8_e4m3fnuz": "fp8e4b8",
"float8_e4m3fnuz": "fp8e4m3fnuz",
Collaborator

fp8e4b8 maps to e4m3_fnuz already, what does this fix?

.def("get_fp8e4b8_ty",

Author

Agreed: no new backend dtype is introduced here. This PR only exposes a user-facing alias, tl.float8e4m3fnuz, and keeps the canonicalization:

_utils.py

"float8_e4m3fnuz": "fp8e4b8"

The goal is ergonomic parity with PyTorch’s torch.float8_e4m3fnuz so kernels can write:

x.to(tl.float8e4m3fnuz)

without needing to know the backend-internal name. All IR paths would still go through the existing fp8e4b8 plumbing, so behavior is unchanged.
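
In other words, the alias is purely ergonomic: under this canonicalization the two spellings below should lower identically (illustrative fragment from inside a @triton.jit kernel, not test code):

y1 = x.to(tl.float8e4m3fnuz)  # new PyTorch-style alias from this PR
y2 = x.to(tl.float8e4b8)      # existing backend-facing name; same fnuz format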

Author

@arnav-garg1 left a comment

Addressing the comments and sharing revision 2. Thanks for the feedback!

"float8_e4m3fn": "fp8e4nv",
"float8e4b8": "fp8e4b8",
"float8_e4m3fnuz": "fp8e4b8",
"float8_e4m3fnuz": "fp8e4m3fnuz",
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed—no new backend dtype is introduced here. This PR only exposes a user-facing alias tl.float8e4m3fnuz and keeps canonicalization:

_utils.py

"float8_e4m3fnuz": "fp8e4b8"

The goal is ergonomic parity with PyTorch’s torch.float8_e4m3fnuz so kernels can write:

x.to(tl.float8e4m3fnuz)

without needing to know the backend internal name. All IR paths would still go through the existing fp8e4b8 plumbing so behavior is unchanged.

Successfully merging this pull request may close these issues: Add triton dtype for torch.float8_e4m3fnuz to allow explicit casts