@@ -154,6 +154,8 @@
 import triton
 import triton.language as tl
 
+from triton.language.target_info import is_hip_cdna4
+
 DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
 
@@ -210,7 +212,7 @@ def get_hip_autotune_config():
         {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
         {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
     ]
-    return [triton.Config(s | {'matrix_instr_nonkdim': 16}, num_warps=8, num_stages=2) for s in sizes]
+    return [triton.Config(s | {'matrix_instr_nonkdim': 32}, num_warps=8, num_stages=2) for s in sizes]
 
 
 def get_autotune_config():
@@ -372,7 +374,7 @@ def matmul(a, b, activation=""):
         print("❌ Triton and Torch differ")
 
 TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
-if TORCH_HAS_FP8 and is_cuda():
+if TORCH_HAS_FP8 and (is_cuda() or is_hip_cdna4()):
     torch.manual_seed(0)
     a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
     b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
@@ -403,7 +405,7 @@ def matmul(a, b, activation=""):
 
 configs = []
 for fp8_inputs in [False, True]:
-    if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()):
+    if fp8_inputs and (not TORCH_HAS_FP8 or (not is_cuda() and not is_hip_cdna4())):
         continue
     configs.append(
         triton.testing.Benchmark(
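
Below is a minimal sketch (not part of the diff) of the device gating this change introduces, pulled out into a standalone helper for readability. The helper name fp8_benchmark_enabled and the local is_cuda definition are illustrative assumptions; the tutorial itself inlines the same condition rather than wrapping it in a function.

import torch
import triton
from triton.language.target_info import is_hip_cdna4  # import added by this change


def is_cuda():
    # Assumed to mirror the tutorial's backend check for NVIDIA targets.
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")


def fp8_benchmark_enabled() -> bool:
    # fp8 paths run only when torch exposes float8_e5m2 and the active target
    # is either a CUDA GPU or an AMD CDNA4 device.
    return TORCH_HAS_FP8 and (is_cuda() or is_hip_cdna4())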