diff --git a/compute/accelerator/benchmarks/mamf-finder.py b/compute/accelerator/benchmarks/mamf-finder.py
index 9200182..dc96341 100755
--- a/compute/accelerator/benchmarks/mamf-finder.py
+++ b/compute/accelerator/benchmarks/mamf-finder.py
@@ -156,6 +156,41 @@ def flush(self):
         if self.verbose: self.stdout.flush()
 
 
+# from https://github.com/pytorch/pytorch/blob/b432443cf2fdcd2575c6e3363d4f86448a5d6650/test/test_matmul_cuda.py#L924
+# XXX: hopefully pytorch will have a core function for that
+def ceil_div(a, b): return (a + b - 1) // b
+def to_blocked(input_matrix) -> torch.Tensor:
+    """
+    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
+
+    See:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        input_matrix: Input tensor of shape (H, W)
+
+    Returns:
+        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
+    """
+    rows, cols = input_matrix.shape
+    n_row_blocks = ceil_div(rows, 128)
+    n_col_blocks = ceil_div(cols, 4)
+
+    # Calculate the padded shape
+    padded_rows = n_row_blocks * 128
+    padded_cols = n_col_blocks * 4
+
+    padded = input_matrix
+    # Ideally we would use torch.nn.functional.pad but it doesn't support float8_e8m0fnu for now
+    if (rows, cols) != (padded_rows, padded_cols):
+        padded = torch.zeros((padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype)
+        padded[:rows, :cols] = input_matrix
+
+    # Rearrange the blocks
+    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
+    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
+
+    return rearranged.flatten()
 
 
 def print_benchmark_header(dtype, device, notes="None"):
@@ -228,7 +263,7 @@ def func_wrapper(*args, **kwargs):
 
     # fp8 requires special handling depending on the vendor:
     # float8_e4m3fn for nvidia, float8_e4m3fnuz for amd
-    fp8_dtypes = [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
+    fp8_dtypes = [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e8m0fnu]
    if dtype in fp8_dtypes:
        # torch._scaled_mm is different before pt-2.5
        if version.parse(torch.__version__) < version.parse("2.5"):
@@ -238,14 +273,33 @@ def func_wrapper(*args, **kwargs):
 
             A = torch.randn(m, k, dtype=torch.float32, device=device).contiguous()
             B = torch.randn(n, k, dtype=torch.float32, device=device).contiguous().t()
-            scale = torch.tensor([1.0]).to(device)
-            A = A.to(dtype)
-            B = B.to(dtype)
+
+            if dtype == torch.float8_e8m0fnu:
+                # mxfp8
+                BLOCK_SIZE = 32
+                # from https://github.com/pytorch/pytorch/blob/b432443cf2fdcd2575c6e3363d4f86448a5d6650/test/test_matmul_cuda.py#L950-L961
+
+                A = A.to(torch.float8_e4m3fn)
+                B = B.to(torch.float8_e4m3fn)
+
+                scale_a = torch.full((m, ceil_div(k, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
+                scale_b = torch.full((n, ceil_div(k, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
+                # convert to swizzled format
+                scale_a = to_blocked(scale_a)
+                scale_b = to_blocked(scale_b)
+
+                out_dtype = torch.bfloat16
+            else:
+                scale_a = torch.tensor([1.0]).to(device)
+                scale_b = scale_a
+                A = A.to(dtype)
+                B = B.to(dtype)
+                out_dtype = dtype
 
             # Simplified call for PyTorch 2.5+
             @time_it(total_iterations)
             def time_iterations():
-                C = torch._scaled_mm(A, B, scale, scale)
+                C = torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=out_dtype)
 
     else:
         A = torch.randn(m, k, dtype=dtype, device=device).contiguous()
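
For reference, here is a minimal sketch (not part of the patch) of the mxfp8 call pattern the new torch.float8_e8m0fnu branch exercises. It assumes a PyTorch build and accelerator that support torch.float8_e8m0fnu and mxfp8 in torch._scaled_mm, and that the ceil_div()/to_blocked() helpers added above are in scope (e.g. the snippet is pasted at the bottom of the patched mamf-finder.py); the shapes are arbitrary.

import torch

m, n, k = 1024, 1024, 1024
BLOCK_SIZE = 32  # mxfp8 shares one e8m0 exponent per 32 elements along k
device = "cuda"

# operands stay in e4m3; only the per-block scales use the e8m0 dtype
A = torch.randn(m, k, dtype=torch.float32, device=device).to(torch.float8_e4m3fn)
B = torch.randn(n, k, dtype=torch.float32, device=device).contiguous().t().to(torch.float8_e4m3fn)

# one scale per 32-wide block of k, swizzled into the cuBLAS blocked layout
scale_a = to_blocked(torch.full((m, ceil_div(k, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu))
scale_b = to_blocked(torch.full((n, ceil_div(k, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu))

C = torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=torch.bfloat16)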