Describe the bug
Currently, Triton does not have a dtype for torch.float8_e4m3fnuz:

triton/python/triton/language/core.py, line 380 at aff4b7a:
FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64']
The datatype is supported by the backend, and a cast through a pointer's element type, e.g. some_tensor.to(target_pointer.dtype.element_ty), works. However, an explicit cast such as some_tensor.to(tl.fp8e4fnuz) is not possible.
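
For illustration, a minimal sketch of the two cases (kernel name, shapes, and launch parameters are hypothetical, and fp8 tensor support on the device side assumes a ROCm/MI300-class setup as in this report; only the pointer-derived cast is expected to compile today):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def cast_copy(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(src_ptr + offs, mask=mask)

    # Works: the fp8 element type is taken from the destination pointer,
    # so it never has to be named at the Triton language level.
    y = x.to(dst_ptr.dtype.element_ty)

    # Does not work: there is no tl-level dtype for torch.float8_e4m3fnuz,
    # so an explicit cast cannot be spelled out:
    # y = x.to(tl.fp8e4fnuz)  # hypothetical dtype name requested in this issue

    tl.store(dst_ptr + offs, y, mask=mask)


# Hypothetical launch: dst carries the fnuz dtype, so the kernel can borrow it.
src = torch.randn(1024, device="cuda", dtype=torch.float16)
dst = torch.empty(1024, device="cuda", dtype=torch.float8_e4m3fnuz)
cast_copy[(triton.cdiv(1024, 256),)](src, dst, 1024, BLOCK=256)
```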
This is a problem in situations where the kernel cannot access the correct datatype via a PyTorch pointer (one example: vllm-project/vllm#24503 (comment)).
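
One such situation, sketched under hypothetical names: a kernel whose arguments are all fp16, but which wants an fp8 round-trip on an intermediate value, so there is no pointer whose .dtype.element_ty could supply the type (the blocked line is commented out because it cannot be written today):

```python
import triton
import triton.language as tl


@triton.jit
def fake_quant(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)

    # Desired: emulate fp8 e4m3fnuz quantization error on an intermediate
    # value.  Neither x_ptr nor out_ptr has that element type, so the
    # pointer-based workaround is unavailable and an explicit dtype is needed:
    # x = x.to(tl.fp8e4fnuz).to(tl.float16)  # not expressible without the new dtype

    tl.store(out_ptr + offs, x, mask=mask)
```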
I know this might be hard to keep compatible across platforms, but it may still be worth considering.
Environment details
Triton 3.4.0, MI300