
Add triton dtype for torch.float8_e4m3fnuz to allow explicit casts #8164

@bringlein

Description


Describe the bug

Currently, Triton does not have a dtype for torch.float8_e4m3fnuz:

FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64']

The datatype is supported by the backend, and a cast like some_tensor.to(target_pointer.dtype.element_ty) works.
However, an explicit type cast like some_tensor.to(tl.fp8e4fnuz) is not possible.
This is a problem in situations where the kernel cannot access the correct datatype via a PyTorch pointer (one example: vllm-project/vllm#24503 (comment)).
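
For illustration, here is a minimal sketch of the workaround described above: the cast targets the element dtype of the pointer that PyTorch passes in, so the fp8 type is never named explicitly inside the kernel. The kernel name, BLOCK size, and the usage snippet are hypothetical examples, not taken from the issue.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def cast_and_store_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(src_ptr + offsets, mask=mask)
    # Works today: the target dtype is derived from the destination pointer,
    # so no explicit tl.* fp8 dtype has to be spelled out.
    y = x.to(dst_ptr.dtype.element_ty)
    tl.store(dst_ptr + offsets, y, mask=mask)


# Usage sketch (assumes a ROCm build where torch.float8_e4m3fnuz is available):
# src = torch.randn(1024, device="cuda", dtype=torch.float32)
# dst = torch.empty(1024, device="cuda", dtype=torch.float8_e4m3fnuz)
# grid = (triton.cdiv(src.numel(), 256),)
# cast_and_store_kernel[grid](src, dst, src.numel(), BLOCK=256)
```

The feature request is for the case where no such pointer is available, i.e. the kernel needs to name the dtype directly, which requires a tl.* alias for torch.float8_e4m3fnuz.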

I know this might be hard to keep compatible across platforms, but it may still be worth considering.

Environment details

Triton 3.4.0, MI300
