Skip to content

Commit 124fffa

Browse files
[AMD] Support float8_e5m2 in 03-matrix-multiplication.py for gfx950
This enables the float8_e5m2 part of the test for AMD's gfx950 devices, which have gained support for this data type (see https://rocm.docs.amd.com/en/latest/reference/precision-support.html). This test was already enabled for CUDA devices.
1 parent 2263431 commit 124fffa

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

python/tutorials/03-matrix-multiplication.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@
154154
import triton
155155
import triton.language as tl
156156

157+
from triton.language.target_info import is_hip_cdna4
158+
157159
DEVICE = triton.runtime.driver.active.get_active_torch_device()
158160

159161

@@ -210,7 +212,7 @@ def get_hip_autotune_config():
210212
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
211213
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
212214
]
213-
return [triton.Config(s | {'matrix_instr_nonkdim': 16}, num_warps=8, num_stages=2) for s in sizes]
215+
return [triton.Config(s | {'matrix_instr_nonkdim': 32}, num_warps=8, num_stages=2) for s in sizes]
214216

215217

216218
def get_autotune_config():
@@ -372,7 +374,7 @@ def matmul(a, b, activation=""):
372374
print("❌ Triton and Torch differ")
373375

374376
TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
375-
if TORCH_HAS_FP8 and is_cuda():
377+
if TORCH_HAS_FP8 and (is_cuda() or is_hip_cdna4()):
376378
torch.manual_seed(0)
377379
a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
378380
b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
@@ -403,7 +405,7 @@ def matmul(a, b, activation=""):
403405

404406
configs = []
405407
for fp8_inputs in [False, True]:
406-
if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()):
408+
if fp8_inputs and (not TORCH_HAS_FP8 or (not is_cuda() and not is_hip_cdna4())):
407409
continue
408410
configs.append(
409411
triton.testing.Benchmark(

0 commit comments

Comments
 (0)