
Commit f99c8e1

More error catching; run kernels from explicit list
1 parent 831db9c commit f99c8e1

File tree

benchmarks/run.py
benchmarks/run_input_shard.sh

2 files changed: +124 -64 lines changed

benchmarks/run.py

Lines changed: 93 additions & 48 deletions
@@ -38,17 +38,6 @@
 # - Multiple kernels with args: (tritonbench_module, [(helion_module, helion_func), ...], args_dict)
 KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = {  # pyright: ignore[reportAssignmentType]
     # <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
-    "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
-    "embedding": (
-        "tritonbench.operators.embedding.operator",
-        "examples.embedding",
-        "embedding_tritonbench",
-    ),
-    "vector_exp": (
-        "tritonbench.operators.vector_exp.operator",
-        "examples.exp",
-        "exp_tritonbench",
-    ),
     "rms_norm": (
         "tritonbench.operators.rms_norm.operator",
         "examples.rms_norm",
@@ -57,33 +46,16 @@
             "num_inputs": 3
         },  # TODO(yf225): reduction dim size = 8192 currently throws error
     ),
-    "sum": ("tritonbench.operators.sum.operator", "examples.sum", "sum_tritonbench"),
+    "layer_norm": (
+        "tritonbench.operators.layer_norm.operator",
+        "examples.layer_norm",
+        "layer_norm_fwd",
+    ),
     "softmax": (
         "tritonbench.operators.softmax.operator",
         "examples.softmax",
         "softmax",
     ),
-    "jagged_mean": (
-        "tritonbench.operators.jagged_mean.operator",
-        "examples.jagged_mean",
-        "jagged_mean_tritonbench",
-        {"B": 32, "M": 8, "seqlen": 64}
-        if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
-        else {},
-    ),
-    "fp8_gemm": (
-        "tritonbench.operators.fp8_gemm.fp8_gemm",
-        "examples.fp8_gemm",
-        "fp8_gemm_tritonbench",
-    ),
-    "flash_attention": (
-        "tritonbench.operators.flash_attention.operator",
-        "examples.attention",
-        "attention",
-        {
-            "d_head": 128
-        },  # Set default head dimension to 128 for TLX attention compatibility
-    ),
     "cross_entropy": (
         "tritonbench.operators.cross_entropy.operator",
         "examples.cross_entropy",
@@ -92,25 +64,81 @@
         if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
         else {},
     ),
-    "fp8_attention": (
-        "tritonbench.operators.fp8_attention.operator",
-        "examples.fp8_attention",
-        "fp8_attention_tritonbench",
+    "sum": ("tritonbench.operators.sum.operator", "examples.sum", "sum_tritonbench"),
+    "jagged_mean": (
+        "tritonbench.operators.jagged_mean.operator",
+        "examples.jagged_mean",
+        "jagged_mean_tritonbench",
+        {"B": 32, "M": 8, "seqlen": 64}
+        if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
+        else {},
     ),
-    "layer_norm": (
-        "tritonbench.operators.layer_norm.operator",
-        "examples.layer_norm",
-        "layer_norm_fwd",
+    "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
+    "embedding": (
+        "tritonbench.operators.embedding.operator",
+        "examples.embedding",
+        "embedding_tritonbench",
     ),
-    # Multiple kernel variants:
-    "gemm": (
-        "tritonbench.operators.gemm.operator",
-        [
-            ("examples.matmul", "matmul_tritonbench"),
-            ("examples.matmul_split_k", "matmul_split_k_tritonbench"),
-        ],
+    "vector_exp": (
+        "tritonbench.operators.vector_exp.operator",
+        "examples.exp",
+        "exp_tritonbench",
     ),
+    # "fp8_gemm": (
+    #     "tritonbench.operators.fp8_gemm.fp8_gemm",
+    #     "examples.fp8_gemm",
+    #     "fp8_gemm_tritonbench",
+    # ),
+    # "flash_attention": (
+    #     "tritonbench.operators.flash_attention.operator",
+    #     "examples.attention",
+    #     "attention",
+    #     {
+    #         "d_head": 128
+    #     },  # Set default head dimension to 128 for TLX attention compatibility
+    # ),
+    # "fp8_attention": (
+    #     "tritonbench.operators.fp8_attention.operator",
+    #     "examples.fp8_attention",
+    #     "fp8_attention_tritonbench",
+    # ),
+    # # Multiple kernel variants:
+    # "gemm": (
+    #     "tritonbench.operators.gemm.operator",
+    #     [
+    #         ("examples.matmul", "matmul_tritonbench"),
+    #         ("examples.matmul_split_k", "matmul_split_k_tritonbench"),
+    #     ],
+    # ),
 }
+# KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = {  # pyright: ignore[reportAssignmentType]
+#     "flash_attention": (
+#         "tritonbench.operators.flash_attention.operator",
+#         "examples.attention",
+#         "attention",
+#         {
+#             "d_head": 128
+#         },  # Set default head dimension to 128 for TLX attention compatibility
+#     ),
+#     # Multiple kernel variants:
+#     "gemm": (
+#         "tritonbench.operators.gemm.operator",
+#         [
+#             ("examples.matmul", "matmul_tritonbench"),
+#             ("examples.matmul_split_k", "matmul_split_k_tritonbench"),
+#         ],
+#     ),
+#     "fp8_gemm": (
+#         "tritonbench.operators.fp8_gemm.fp8_gemm",
+#         "examples.fp8_gemm",
+#         "fp8_gemm_tritonbench",
+#     ),
+#     "fp8_attention": (
+#         "tritonbench.operators.fp8_attention.operator",
+#         "examples.fp8_attention",
+#         "fp8_attention_tritonbench",
+#     ),
+# }
 
 
 def get_system_memory_gb() -> float:
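
For reference, each KERNEL_MAPPINGS entry pairs a tritonbench operator module with one or more Helion example kernels, optionally followed by an operator-args dict (see the format comments at the top of the dict). Below is a minimal sketch of how such an entry could be unpacked and imported at runtime; resolve_mapping is a hypothetical helper written for illustration, not the dispatch code that run.py actually uses.

import importlib
from typing import Callable


def resolve_mapping(mapping: tuple) -> tuple[str, list[tuple[str, Callable]], dict]:
    """Hypothetical helper: unpack a KERNEL_MAPPINGS value into
    (tritonbench_module, [(helion_func_name, helion_fn), ...], extra_args)."""
    tritonbench_module = mapping[0]
    rest = mapping[1:]
    # A trailing dict (if present) carries operator-specific arguments.
    extra_args = rest[-1] if rest and isinstance(rest[-1], dict) else {}
    if extra_args:
        rest = rest[:-1]
    if isinstance(rest[0], list):
        # Multiple kernel variants: [(helion_module, helion_func), ...]
        variants = rest[0]
    else:
        # Single kernel: (helion_module, helion_func)
        variants = [tuple(rest)]
    resolved = [
        (func, getattr(importlib.import_module(mod), func)) for mod, func in variants
    ]
    return tritonbench_module, resolved, extra_args


# Example: the single-kernel "softmax" entry resolves to examples.softmax.softmax
# tb_mod, fns, args = resolve_mapping(KERNEL_MAPPINGS["softmax"])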
@@ -298,6 +326,23 @@ def run_kernel_variants(
 ) -> None:
     """Run kernel variants in the same benchmark run."""
 
+    # Configure Helion to use fewer generations for faster autotuning during benchmarks
+    import helion
+    from helion.autotuner import DifferentialEvolutionSearch, LocalAutotuneCache
+    from helion.runtime.kernel import BoundKernel
+    from typing import Sequence
+
+    def fast_autotuner_fn(
+        bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
+    ) -> LocalAutotuneCache:
+        # Use only 1 generation instead of default 20 for faster benchmarking
+        return LocalAutotuneCache(
+            DifferentialEvolutionSearch(bound_kernel, args, num_generations=1, **kwargs)
+        )
+
+    # Set the custom autotuner function
+    helion.set_default_settings(helion.Settings(autotuner_fn=fast_autotuner_fn))
+
     # Import tritonbench components
     try:
         from tritonbench.utils.parser import (  # pyright: ignore[reportMissingImports]
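
The fast_autotuner_fn added above trades autotuning quality for turnaround time by running a single differential-evolution generation instead of the default 20. A minimal sketch of the same override with a configurable generation count is shown below; it reuses only the Helion entry points that already appear in the diff (helion.Settings, helion.set_default_settings, DifferentialEvolutionSearch, LocalAutotuneCache), and the HELION_BENCH_AUTOTUNE_GENERATIONS environment variable is an illustrative name, not an existing flag.

import os
from typing import Sequence

import helion
from helion.autotuner import DifferentialEvolutionSearch, LocalAutotuneCache
from helion.runtime.kernel import BoundKernel

# Illustrative knob (not part of this commit): how many differential-evolution
# generations to run while autotuning during benchmarks.
NUM_GENERATIONS = int(os.environ.get("HELION_BENCH_AUTOTUNE_GENERATIONS", "1"))


def benchmark_autotuner_fn(
    bound_kernel: BoundKernel, args: Sequence[object], **kwargs: object
) -> LocalAutotuneCache:
    # Same shape as fast_autotuner_fn in the diff, but with a tunable generation count.
    return LocalAutotuneCache(
        DifferentialEvolutionSearch(
            bound_kernel, args, num_generations=NUM_GENERATIONS, **kwargs
        )
    )


helion.set_default_settings(helion.Settings(autotuner_fn=benchmark_autotuner_fn))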

benchmarks/run_input_shard.sh

Lines changed: 31 additions & 16 deletions
@@ -7,27 +7,42 @@ TIMESTAMP=$(date +%s)
 OUTPUT_FILE="benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}.txt"
 CSV_OUTPUT_DIR="benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}_csv"
 
+KERNEL_NAME_LIST=(
+    "rms_norm"
+    "layer_norm"
+    "softmax"
+    "cross_entropy"
+    "sum"
+    "jagged_mean"
+    "vector_add"
+    "embedding"
+    "vector_exp"
+)
+
 # Retry until success
 attempt=0
-while true; do
-# while (( attempt < 10 )); do
-    attempt=$((attempt + 1))
-    echo "Attempt $attempt: Running benchmark for shard ${SHARD}/${WORLD_SIZE}..."
+for KERNEL_NAME in "${KERNEL_NAME_LIST[@]}"; do
+    while true; do
+    # while (( attempt < 10 )); do
+        attempt=$((attempt + 1))
+        echo "Attempt $attempt: Running benchmark for shard ${SHARD}/${WORLD_SIZE}..."
 
-    # TIMESTAMP=$(date +%s)
-    # OUTPUT_FILE="benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}.txt"
+        # TIMESTAMP=$(date +%s)
+        # OUTPUT_FILE="benchmarks_autotune_${TIMESTAMP}_input_shard_${SHARD}_of_${WORLD_SIZE}.txt"
 
-    mkdir -p ${CSV_OUTPUT_DIR} || true
-    CUDA_VISIBLE_DEVICES=$((RANK_OFFSET+SHARD-1)) python benchmarks/run.py --input-shard ${SHARD}/${WORLD_SIZE} --metrics accuracy,tflops,gbps,speedup --csv --output-dir ${CSV_OUTPUT_DIR} >"$OUTPUT_FILE" 2>&1
+        mkdir -p ${CSV_OUTPUT_DIR} || true
+        CUDA_VISIBLE_DEVICES=$((RANK_OFFSET+SHARD-1)) python benchmarks/run.py --input-shard ${SHARD}/${WORLD_SIZE} --kernel ${KERNEL_NAME} --metrics accuracy,tflops,gbps,speedup --csv --output-dir ${CSV_OUTPUT_DIR} >"$OUTPUT_FILE" 2>&1
 
-    exit_code=$?
-    if [ $exit_code -eq 0 ]; then
-        echo "Success! Benchmark completed for shard ${SHARD}/${WORLD_SIZE}"
-        break
-    else
-        echo "Failed with exit code $exit_code. Retrying..."
-        sleep 10  # wait a few seconds before retrying
-    fi
+        exit_code=$?
+        # Check for success: exit code 0 AND no exception message in output
+        if [ $exit_code -eq 0 ] && ! grep -q "Caught exception, terminating early with partial results" "$OUTPUT_FILE"; then
+            echo "Success! Benchmark completed for shard $((SHARD+1))/${WORLD_SIZE}"
+            break
+        else
+            echo "Failed with exit code $exit_code. Retrying..."
+            sleep 10  # wait a few seconds before retrying
+        fi
+    done
 done
 
 # Runs the 1st shard of input on GPU-0:
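
The grep guard matters because the benchmark process can exit with status 0 even when tritonbench catches an exception and reports only partial results, so the loop retries until both conditions hold: a zero exit code and no failure marker in the captured log. A rough Python restatement of that success check, for illustration only (run_shard_once and the example command are hypothetical, not part of the commit):

import subprocess

FAILURE_MARKER = "Caught exception, terminating early with partial results"


def run_shard_once(cmd: list[str], output_file: str) -> bool:
    """Illustrative re-statement of the shell success check: the run counts as
    successful only if the process exits 0 AND the captured log does not
    contain the tritonbench failure marker."""
    with open(output_file, "w") as log:
        proc = subprocess.run(cmd, stdout=log, stderr=subprocess.STDOUT)
    with open(output_file) as log:
        caught_exception = FAILURE_MARKER in log.read()
    return proc.returncode == 0 and not caught_exception


# Example (hypothetical invocation mirroring the script):
# ok = run_shard_once(
#     ["python", "benchmarks/run.py", "--kernel", "sum", "--metrics", "accuracy"],
#     "benchmark_log.txt",
# )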
