 # - Multiple kernels with args: (tritonbench_module, [(helion_module, helion_func), ...], args_dict)
 KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = {  # pyright: ignore[reportAssignmentType]
     # <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
-    "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
-    "embedding": (
-        "tritonbench.operators.embedding.operator",
-        "examples.embedding",
-        "embedding_tritonbench",
-    ),
-    "vector_exp": (
-        "tritonbench.operators.vector_exp.operator",
-        "examples.exp",
-        "exp_tritonbench",
-    ),
     "rms_norm": (
         "tritonbench.operators.rms_norm.operator",
         "examples.rms_norm",
             "num_inputs": 3
         },  # TODO(yf225): reduction dim size = 8192 currently throws error
     ),
-    "sum": ("tritonbench.operators.sum.operator", "examples.sum", "sum_tritonbench"),
+    "layer_norm": (
+        "tritonbench.operators.layer_norm.operator",
+        "examples.layer_norm",
+        "layer_norm_fwd",
+    ),
     "softmax": (
         "tritonbench.operators.softmax.operator",
         "examples.softmax",
         "softmax",
     ),
-    "jagged_mean": (
-        "tritonbench.operators.jagged_mean.operator",
-        "examples.jagged_mean",
-        "jagged_mean_tritonbench",
-        {"B": 32, "M": 8, "seqlen": 64}
-        if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
-        else {},
-    ),
-    "fp8_gemm": (
-        "tritonbench.operators.fp8_gemm.fp8_gemm",
-        "examples.fp8_gemm",
-        "fp8_gemm_tritonbench",
-    ),
-    "flash_attention": (
-        "tritonbench.operators.flash_attention.operator",
-        "examples.attention",
-        "attention",
-        {
-            "d_head": 128
-        },  # Set default head dimension to 128 for TLX attention compatibility
-    ),
     "cross_entropy": (
         "tritonbench.operators.cross_entropy.operator",
         "examples.cross_entropy",
         if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
         else {},
     ),
-    "fp8_attention": (
-        "tritonbench.operators.fp8_attention.operator",
-        "examples.fp8_attention",
-        "fp8_attention_tritonbench",
+    "sum": ("tritonbench.operators.sum.operator", "examples.sum", "sum_tritonbench"),
+    "jagged_mean": (
+        "tritonbench.operators.jagged_mean.operator",
+        "examples.jagged_mean",
+        "jagged_mean_tritonbench",
+        {"B": 32, "M": 8, "seqlen": 64}
+        if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
+        else {},
     ),
-    "layer_norm": (
-        "tritonbench.operators.layer_norm.operator",
-        "examples.layer_norm",
-        "layer_norm_fwd",
+    "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
+    "embedding": (
+        "tritonbench.operators.embedding.operator",
+        "examples.embedding",
+        "embedding_tritonbench",
     ),
-    # Multiple kernel variants:
-    "gemm": (
-        "tritonbench.operators.gemm.operator",
-        [
-            ("examples.matmul", "matmul_tritonbench"),
-            ("examples.matmul_split_k", "matmul_split_k_tritonbench"),
-        ],
+    "vector_exp": (
+        "tritonbench.operators.vector_exp.operator",
+        "examples.exp",
+        "exp_tritonbench",
     ),
+    # "fp8_gemm": (
+    #     "tritonbench.operators.fp8_gemm.fp8_gemm",
+    #     "examples.fp8_gemm",
+    #     "fp8_gemm_tritonbench",
+    # ),
+    # "flash_attention": (
+    #     "tritonbench.operators.flash_attention.operator",
+    #     "examples.attention",
+    #     "attention",
+    #     {
+    #         "d_head": 128
+    #     },  # Set default head dimension to 128 for TLX attention compatibility
+    # ),
+    # "fp8_attention": (
+    #     "tritonbench.operators.fp8_attention.operator",
+    #     "examples.fp8_attention",
+    #     "fp8_attention_tritonbench",
+    # ),
+    # # Multiple kernel variants:
+    # "gemm": (
+    #     "tritonbench.operators.gemm.operator",
+    #     [
+    #         ("examples.matmul", "matmul_tritonbench"),
+    #         ("examples.matmul_split_k", "matmul_split_k_tritonbench"),
+    #     ],
+    # ),
 }
+# KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = {  # pyright: ignore[reportAssignmentType]
+#     "flash_attention": (
+#         "tritonbench.operators.flash_attention.operator",
+#         "examples.attention",
+#         "attention",
+#         {
+#             "d_head": 128
+#         },  # Set default head dimension to 128 for TLX attention compatibility
+#     ),
+#     # Multiple kernel variants:
+#     "gemm": (
+#         "tritonbench.operators.gemm.operator",
+#         [
+#             ("examples.matmul", "matmul_tritonbench"),
+#             ("examples.matmul_split_k", "matmul_split_k_tritonbench"),
+#         ],
+#     ),
+#     "fp8_gemm": (
+#         "tritonbench.operators.fp8_gemm.fp8_gemm",
+#         "examples.fp8_gemm",
+#         "fp8_gemm_tritonbench",
+#     ),
+#     "fp8_attention": (
+#         "tritonbench.operators.fp8_attention.operator",
+#         "examples.fp8_attention",
+#         "fp8_attention_tritonbench",
+#     ),
+# }


 def get_system_memory_gb() -> float:
@@ -298,6 +326,23 @@ def run_kernel_variants(
 ) -> None:
     """Run kernel variants in the same benchmark run."""

+    # Configure Helion to use fewer generations for faster autotuning during benchmarks
+    import helion
+    from helion.autotuner import DifferentialEvolutionSearch, LocalAutotuneCache
+    from helion.runtime.kernel import BoundKernel
+    from typing import Sequence
+
+    def fast_autotuner_fn(
+        bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
+    ) -> LocalAutotuneCache:
+        # Use only 1 generation instead of default 20 for faster benchmarking
+        return LocalAutotuneCache(
+            DifferentialEvolutionSearch(bound_kernel, args, num_generations=1, **kwargs)
+        )
+
+    # Set the custom autotuner function
+    helion.set_default_settings(helion.Settings(autotuner_fn=fast_autotuner_fn))
+
     # Import tritonbench components
     try:
         from tritonbench.utils.parser import (  # pyright: ignore[reportMissingImports]
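
The override above hard-codes a single DifferentialEvolutionSearch generation. As a hedged variant (the HELION_BENCH_AUTOTUNE_GENERATIONS environment variable is invented here for illustration and is not something the benchmark script actually reads), the same autotuner_fn hook could make the generation count tunable without editing the file:

# Hypothetical variant: HELION_BENCH_AUTOTUNE_GENERATIONS is an invented knob,
# not an environment variable the benchmark script actually reads.
import os
from typing import Sequence

import helion
from helion.autotuner import DifferentialEvolutionSearch, LocalAutotuneCache
from helion.runtime.kernel import BoundKernel


def configurable_autotuner_fn(
    bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
) -> LocalAutotuneCache:
    # Default to a single generation, matching the benchmark setting above.
    generations = int(os.environ.get("HELION_BENCH_AUTOTUNE_GENERATIONS", "1"))
    return LocalAutotuneCache(
        DifferentialEvolutionSearch(
            bound_kernel, args, num_generations=generations, **kwargs
        )
    )


helion.set_default_settings(helion.Settings(autotuner_fn=configurable_autotuner_fn))

Since set_default_settings installs the hook as a default, kernels bound after this point would presumably pick up the reduced search budget, which is the trade the commit makes: faster benchmark turnaround at the cost of less thorough autotuning.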
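For reference, here is a minimal sketch of how a KERNEL_MAPPINGS entry could be resolved into callables. It is an illustration only; the resolve_mapping helper is hypothetical, not the loader benchmarks/run.py actually uses. It covers the three shapes described in the comment at the top of the table: single kernel, multiple variants, and an optional trailing args dict.

# Hypothetical helper for illustration -- not the loader benchmarks/run.py actually uses.
import importlib
from typing import Callable


def resolve_mapping(entry: tuple) -> tuple[str, list[Callable[..., object]], dict]:
    """Split one KERNEL_MAPPINGS value into (tritonbench_module, helion_fns, extra_args)."""
    tb_module = entry[0]
    rest = list(entry[1:])
    # An optional trailing dict carries extra operator args (e.g. {"num_inputs": 3}).
    extra_args = rest.pop() if isinstance(rest[-1], dict) else {}
    if isinstance(rest[0], list):
        # Multiple variants: [(helion_module, helion_func), ...]
        specs = rest[0]
    else:
        # Single kernel: (helion_module, helion_func)
        specs = [(rest[0], rest[1])]
    helion_fns = [getattr(importlib.import_module(mod), fn) for mod, fn in specs]
    return tb_module, helion_fns, extra_args

For example, resolve_mapping(KERNEL_MAPPINGS["softmax"]) would import examples.softmax and return its softmax function together with an empty extra-args dict, while the commented-out "gemm" entry would yield both matmul variants.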