import pathlib
import re

import numpy as np
from numpy.random import RandomState
import pytest
import torch

import triton
from triton._internal_testing import numpy_random, to_numpy

# Smallest sub-group size reported by the XPU device; it gates the extra layout
# configurations below and is embedded in the TTGIR module attributes.
MIN_GROUP_SIZE = torch.xpu.get_device_capability()['sub_group_sizes'][0]


class DpasLayout:

    def __init__(self, repeatCount, systolic_depth, execution_size, ops_per_chan, threads_per_warp, warps_per_cta,
                 rep_cluster):
        self.repeatCount = repeatCount
        self.systolic_depth = systolic_depth
        self.execution_size = execution_size
        self.ops_per_chan = ops_per_chan
        self.threads_per_warp = threads_per_warp
        self.warps_per_cta = warps_per_cta
        self.rep_cluster = rep_cluster

    def __str__(self):
        return f"#triton_intel_gpu.dpas<{{repeatCount={self.repeatCount}, systolicDepth={self.systolic_depth}, executionSize = {self.execution_size}, opsPerChan = {self.ops_per_chan}, threadsPerWarp = {self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, repCluster={self.rep_cluster}}}>"


class BlockedLayout:

    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1],
                 cta_split_num=[1, 1], cta_order=[0, 1]):
        self.sz_per_thread = size_per_thread
        self.threads_per_warp = threads_per_warp
        self.warps_per_cta = warps_per_cta
        self.order = order
        self.ctas_per_cga = ctas_per_cga
        self.cta_split_num = cta_split_num
        self.cta_order = cta_order

    def __str__(self):
        return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"


layouts = [
    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=16,
               warps_per_cta=[2, 2], rep_cluster=[1, 1]),
    BlockedLayout([1, 1], [1, 16], [2, 2], [1, 0])
]

if MIN_GROUP_SIZE == 16:
    # Add threads_per_warp=32 cases.
    layouts += [
        DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=32,
                   warps_per_cta=[2, 2], rep_cluster=[1, 1]),
        BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0])
    ]


def warps_per_cta(layout, shape):
    return layout.warps_per_cta


GPU_DIALECT = "ttg"


@pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [32, 128], [32, 32], [64, 32], [16, 16]])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"])
@pytest.mark.parametrize("reduce_op", ["sum", "max"])
def test_horizontal_simd_reduce(M, N, src_layout, dtype_str, reduce_op, device, tmp_path: pathlib.Path):
    ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str]
    arith_op = {
        "max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"},  #
        "sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"}
    }[reduce_op][dtype_str]
    numpy_op = {"max": np.max, "sum": np.sum}[reduce_op]
    rdims_1d = f"{M}"
    rdims_2d = f"{M}x1"
    store_range = "%1"
    warps = src_layout.warps_per_cta
    threads_per_warp = int(np.prod(src_layout.threads_per_warp))
    num_warps = int(np.prod(warps))
    blocked = BlockedLayout([1, 1], [16, threads_per_warp // 16], [4, num_warps // 4], [0, 1], [1, 1], [1, 1], [0, 1])
    one_d_layout = BlockedLayout([1], [threads_per_warp], [num_warps], [0], [1], [1], [0])

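    # TTGIR under test: load an MxN tile through #blocked, convert it to the #src
    # layout being exercised, reduce along axis 1, and store the Mx1 result.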
    ir = f"""
    #blocked = {blocked}
    #src = {src_layout}
    #one_d_layout = {one_d_layout}
    module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {threads_per_warp} : i32, "triton_intel_gpu.min_sg_size" = {MIN_GROUP_SIZE} }} {{
    tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
    %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>
    %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked>
    %2 = tt.splat %arg1 : i32 -> tensor<{M}x1xi32, #blocked>
    %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked>
    %4 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x1x!tt.ptr<{ty}>, #blocked>
    %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked>, tensor<{M}x1xi32, #blocked>
    %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>
    %7 = tt.expand_dims %6 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked>
    %8 = tt.broadcast %5 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>
    %9 = tt.broadcast %7 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked>
    %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>, tensor<{M}x{N}xi32, #blocked>
    %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>
    %12 = {GPU_DIALECT}.convert_layout %11 : tensor<{M}x{N}x{ty}, #blocked> -> tensor<{M}x{N}x{ty}, #src>
    %13 = "tt.reduce"(%12) ({{
    ^bb0(%arg3: {ty}, %arg4: {ty}):
      %17 = {arith_op} %arg3, %arg4 : {ty}
      tt.reduce.return %17 : {ty}
    }}) {{axis = 1 : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
    %14 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>
    %15 = tt.addptr %14, {store_range} : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>, tensor<{rdims_2d}xi32, #blocked>
    %16 = {GPU_DIALECT}.convert_layout %13 : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>
    %17 = tt.expand_dims %16 {{axis = 1 : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{rdims_2d}x{ty}, #blocked>
    tt.store %15, %17 : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>
    tt.return
    }}
    }}
    """

    # Materialize the hand-written TTGIR and compile it directly, bypassing the
    # Python front end.
    temp_file = tmp_path / "test_reduce_layouts.ttgir"
    temp_file.write_text(ir)
    kernel = triton.compile(str(temp_file))

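    # Seeded host input and a zero-initialized output buffer for the reference check.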
    rs = RandomState(17)
    x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10)
    z_shape = (M, 1)
    z = np.zeros(z_shape).astype(dtype_str)

    x_tri = torch.tensor(x, device=device)
    z_tri = torch.tensor(z, device=device)

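    # Launch a single workgroup: one CTA reduces the whole MxN tile along axis 1
    # and writes the Mx1 result, which is then checked against the NumPy reference.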
    kernel[(1, 1, 1)](x_tri, x_tri.stride(0), z_tri)
    z_ref = numpy_op(x, axis=1, keepdims=True)

    llir = kernel.asm['llir']
    # The sub-group (horizontal) reduction is expected to lower to an inline vISA asm call.
    assert re.search(r'call .* asm', llir), 'no inline vISA asm call found in llir'

    if dtype_str == 'float16':
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
    else:
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)