diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py
index 6c4a093ebd1d..1ab2bafa64b2 100644
--- a/python/tutorials/09-persistent-matmul.py
+++ b/python/tutorials/09-persistent-matmul.py
@@ -31,16 +31,29 @@
 from typing import Optional
 
-if torch.cuda.is_available():
+
+def is_cuda():
+    return triton.runtime.driver.active.get_current_target().backend == "cuda"
+
+
+def is_hip():
+    return triton.runtime.driver.active.get_current_target().backend == "hip"
+
+
+if is_cuda():
     from triton._C.libtriton import nvidia
 
-    cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
-    cublas = nvidia.cublas.CublasLt(cublas_workspace)
+    device_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
+    device_blas = nvidia.cublas.CublasLt(device_workspace)
+elif is_hip():
+    from triton._C.libtriton import amd
+    device_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
+    device_blas = amd.hipblas.HipblasLt(device_workspace)
 else:
-    cublas = None
+    device_blas = None
 
 
-def is_cuda():
-    return triton.runtime.driver.active.get_current_target().backend == "cuda"
+def device_blas_name():
+    return 'cuBLAS' if is_cuda() else 'hipBLAS'
 
 
 def supports_tma():
@@ -592,7 +605,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
     return c
 
 
-def cublas_matmul(a, b):
+def device_blas_matmul(a, b):
     # Check constraints.
     assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
     M, K = a.shape
@@ -601,9 +614,10 @@
     c = torch.empty((M, N), device=a.device, dtype=dtype)
     bytes_per_elem = a.element_size()
     flops_str = f"flops{bytes_per_elem * 8}"
-    with proton.scope(f"cublas [M={M}, N={N}, K={K}]",
+    blas_name = device_blas_name()
+    with proton.scope(f"{blas_name} [M={M}, N={N}, K={K}]",
                       {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}):
-        cublas.matmul(a, b, c)
+        device_blas.matmul(a, b, c)
     return c
 
 
@@ -645,8 +659,9 @@ def bench(K, dtype, reps=10000, warmup_reps=10000):
 
     b = b.T.contiguous()
 
-    if cublas is not None:
-        bench_fn("cublas", reps, warmup_reps, cublas_matmul, a, b)
+    if device_blas is not None:
+        blas_name = device_blas_name()
+        bench_fn(blas_name, reps, warmup_reps, device_blas_matmul, a, b)
     if dtype == torch.float16:
         bench_fn("torch", reps, warmup_reps, torch_matmul, a, b)
     bench_fn("naive", reps, warmup_reps, matmul, a, b.T)
@@ -682,7 +697,7 @@ def validate(M, N, K, dtype):
     naive_result = matmul(a, b.T).to(torch.float16)
 
     run_test(naive_result, torch_matmul, a, b, "Torch", enabled=dtype == torch.float16)
-    run_test(naive_result, cublas_matmul, a, b, "cuBLAS", enabled=cublas is not None)
+    run_test(naive_result, device_blas_matmul, a, b, device_blas_name(), enabled=device_blas is not None)
     run_test(naive_result, matmul_persistent, a, b.T, "Persistent")
 
     kernels = [
@@ -722,7 +737,7 @@ def show_profile(precision, profile_name):
     args = parser.parse_args()
 
     if args.prec == 'fp8' and (not hasattr(torch, "float8_e4m3fn") or not is_cuda()):
-        print("This example requires CUDA with fp8 support.")
+        print("This example requires CUDA/HIP with fp8 support.")
     else:
         dtype = torch.float8_e4m3fn if args.prec == 'fp8' else torch.float16
 