
[WIP] 20250722 benchmark sweep #347


Draft
wants to merge 12 commits into base: main
277 changes: 220 additions & 57 deletions benchmarks/run.py
@@ -27,6 +27,7 @@
import sys
import time
from typing import Any
from typing import Callable

# Maps tritonbench op names to Helion kernel examples
# Can map to a single kernel or a list of kernel variants
@@ -37,17 +38,6 @@
# - Multiple kernels with args: (tritonbench_module, [(helion_module, helion_func), ...], args_dict)
KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = { # pyright: ignore[reportAssignmentType]
# <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
"vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
"embedding": (
"tritonbench.operators.embedding.operator",
"examples.embedding",
"embedding_tritonbench",
),
"vector_exp": (
"tritonbench.operators.vector_exp.operator",
"examples.exp",
"exp_tritonbench",
),
"rms_norm": (
"tritonbench.operators.rms_norm.operator",
"examples.rms_norm",
@@ -56,32 +46,43 @@
"num_inputs": 3
}, # TODO(yf225): reduction dim size = 8192 currently throws error
),
"sum": ("tritonbench.operators.sum.operator", "examples.sum", "sum_tritonbench"),
"layer_norm": (
"tritonbench.operators.layer_norm.operator",
"examples.layer_norm",
"layer_norm_fwd",
),
"softmax": (
"tritonbench.operators.softmax.operator",
"examples.softmax",
"softmax",
),
"jagged_mean": (
"tritonbench.operators.jagged_mean.operator",
"examples.jagged_mean",
"jagged_mean_tritonbench",
{"B": 32, "M": 8, "seqlen": 64}
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
else {},
),
"fp8_gemm": (
"tritonbench.operators.fp8_gemm.fp8_gemm",
"examples.fp8_gemm",
"fp8_gemm_tritonbench",
),
"flash_attention": (
"tritonbench.operators.flash_attention.operator",
"examples.attention",
"attention",
{
"d_head": 128
}, # Set default head dimension to 128 for TLX attention compatibility
"only_shapes": [
[4096, 6912 * 2 - 4096],
[4096, 6976 * 2 - 4096],
[4096, 7040 * 2 - 4096],
[4096, 7104 * 2 - 4096],
[4096, 7168 * 2 - 4096],
[4096, 7232 * 2 - 4096],
[4096, 7296 * 2 - 4096],
[4096, 7360 * 2 - 4096],
[4096, 7424 * 2 - 4096],
[4096, 7488 * 2 - 4096],
[4096, 7552 * 2 - 4096],
[4096, 7616 * 2 - 4096],
[4096, 7680 * 2 - 4096],
[4096, 7744 * 2 - 4096],
[4096, 7808 * 2 - 4096],
[4096, 7872 * 2 - 4096],
[4096, 7936 * 2 - 4096],
[4096, 8000 * 2 - 4096],
[4096, 8064 * 2 - 4096],
[4096, 8128 * 2 - 4096],
[4096, 8192 * 2 - 4096],
[4096, 8256 * 2 - 4096],
[4096, 8320 * 2 - 4096],
[4096, 8384 * 2 - 4096],
]
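# First dim fixed at 4096; second dim sweeps 9728 (6912 * 2 - 4096) through
# 12672 (8384 * 2 - 4096) in steps of 128.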
},
),
"cross_entropy": (
"tritonbench.operators.cross_entropy.operator",
@@ -91,25 +92,105 @@
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
else {},
),
"fp8_attention": (
"tritonbench.operators.fp8_attention.operator",
"examples.fp8_attention",
"fp8_attention_tritonbench",
"sum": ("tritonbench.operators.sum.operator", "examples.sum", "sum_tritonbench"),
"jagged_mean": (
"tritonbench.operators.jagged_mean.operator",
"examples.jagged_mean",
"jagged_mean_tritonbench",
{"B": 32, "M": 8, "seqlen": 64}
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
else {"B": 512, "M": 64},
),
"layer_norm": (
"tritonbench.operators.layer_norm.operator",
"examples.layer_norm",
"layer_norm_fwd",
"vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
"embedding": (
"tritonbench.operators.embedding.operator",
"examples.embedding",
"embedding_tritonbench",
{
"only_shapes": [
(8, 2048, 4096, 16384),
(8, 2048, 4096, 32768),
(8, 2048, 4096, 65536),
(8, 2048, 4096, 131072),
]
},
),
# Multiple kernel variants:
"gemm": (
"tritonbench.operators.gemm.operator",
[
("examples.matmul", "matmul_tritonbench"),
("examples.matmul_split_k", "matmul_split_k_tritonbench"),
],
"vector_exp": (
"tritonbench.operators.vector_exp.operator",
"examples.exp",
"exp_tritonbench",
{
"only_shapes": [
65536,
131072,
262144,
524288,
1048576,
2097152,
4194304,
8388608,
16777216,
33554432,
67108864,
134217728,
]
},
),
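# The vector_exp only_shapes above sweep element counts from 2**16 (65536)
# through 2**27 (134217728), doubling at each step.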
# "fp8_gemm": (
# "tritonbench.operators.fp8_gemm.fp8_gemm",
# "examples.fp8_gemm",
# "fp8_gemm_tritonbench",
# ),
# "flash_attention": (
# "tritonbench.operators.flash_attention.operator",
# "examples.attention",
# "attention",
# {
# "d_head": 128
# }, # Set default head dimension to 128 for TLX attention compatibility
# ),
# "fp8_attention": (
# "tritonbench.operators.fp8_attention.operator",
# "examples.fp8_attention",
# "fp8_attention_tritonbench",
# ),
# # Multiple kernel variants:
# "gemm": (
# "tritonbench.operators.gemm.operator",
# [
# ("examples.matmul", "matmul_tritonbench"),
# ("examples.matmul_split_k", "matmul_split_k_tritonbench"),
# ],
# ),
}
# KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = { # pyright: ignore[reportAssignmentType]
# "flash_attention": (
# "tritonbench.operators.flash_attention.operator",
# "examples.attention",
# "attention",
# {
# "d_head": 128
# }, # Set default head dimension to 128 for TLX attention compatibility
# ),
# # Multiple kernel variants:
# "gemm": (
# "tritonbench.operators.gemm.operator",
# [
# ("examples.matmul", "matmul_tritonbench"),
# ("examples.matmul_split_k", "matmul_split_k_tritonbench"),
# ],
# ),
# "fp8_gemm": (
# "tritonbench.operators.fp8_gemm.fp8_gemm",
# "examples.fp8_gemm",
# "fp8_gemm_tritonbench",
# ),
# "fp8_attention": (
# "tritonbench.operators.fp8_attention.operator",
# "examples.fp8_attention",
# "fp8_attention_tritonbench",
# ),
# }
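# Illustrative sketch (not one of the real mappings above): a single-kernel
# entry with an extra-args dict. run_kernel() below copies the dict, pops
# "only_shapes" for local input filtering, and turns every remaining key into
# a --flag value pair for tritonbench. The op name and values here are
# hypothetical placeholders.
#
# "my_op": (
#     "tritonbench.operators.my_op.operator",  # tritonbench operator module
#     "examples.my_op",                        # Helion example module
#     "my_op_tritonbench",                     # Helion kernel function
#     {
#         "num_inputs": 4,                     # forwarded as "--num-inputs 4"
#         "only_shapes": [65536, 131072],      # consumed locally, not passed to tritonbench
#     },
# ),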


def get_system_memory_gb() -> float:
@@ -252,6 +333,7 @@ def run_kernel(

# Extract operator args if present
operator_args = {}
only_shapes = None

# Normalize to list of variants format
if isinstance(mapping[1], list):
@@ -260,15 +342,21 @@
variants = mapping[1]
# Check if last element is args dict
if len(mapping) > 2 and isinstance(mapping[2], dict):
operator_args = mapping[2]
operator_args = mapping[2].copy()
# Extract only_shapes if present
if "only_shapes" in operator_args:
only_shapes = operator_args.pop("only_shapes")
else:
# Single kernel format
if len(mapping) == 4 and isinstance(mapping[3], dict):
# With args
tritonbench_module = mapping[0]
module_path = mapping[1]
func_name = mapping[2]
operator_args = mapping[3] # pyright: ignore[reportGeneralTypeIssues]
operator_args = mapping[3].copy() # pyright: ignore[reportGeneralTypeIssues]
# Extract only_shapes if present
if "only_shapes" in operator_args:
only_shapes = operator_args.pop("only_shapes")
variants = [(module_path, func_name)]
else:
# Without args
@@ -284,6 +372,7 @@ def run_kernel(
tritonbench_args,
input_shard_info,
operator_args,
only_shapes,
)


@@ -294,6 +383,7 @@ def run_kernel_variants(
tritonbench_args: list[str],
input_shard_info: tuple[int, int] | None = None,
operator_args: dict[str, Any] | None = None,
only_shapes: list[Any] | None = None,
) -> None:
"""Run kernel variants in the same benchmark run."""

@@ -320,10 +410,23 @@

# Add operator-specific default args if provided
if operator_args:
print(
f"Applying custom args for {operator_name}: {operator_args}",
file=sys.stderr,
)
# First, remove any existing occurrences of these args
for arg_name, arg_value in operator_args.items():
arg_flag = f"--{arg_name.replace('_', '-')}"
if arg_flag not in tritonbench_args:
tritonbench_args.extend([arg_flag, str(arg_value)])
# Remove existing arg if present
while arg_flag in tritonbench_args:
idx = tritonbench_args.index(arg_flag)
tritonbench_args.pop(idx) # Remove flag
if idx < len(tritonbench_args) and not tritonbench_args[idx].startswith(
"--"
):
tritonbench_args.pop(idx) # Remove value
# Add the custom arg
tritonbench_args.extend([arg_flag, str(arg_value)])
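# Worked example (illustrative): with operator_args {"d_head": 128}, arg_flag
# becomes "--d-head"; any existing "--d-head <value>" pair is removed from
# tritonbench_args first and ["--d-head", "128"] is appended, so the value
# from KERNEL_MAPPINGS overrides a caller-supplied flag.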

# Parse known args and collect unknown ones for operator
tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args)
@@ -345,6 +448,69 @@
from tritonbench.utils.triton_op import ( # pyright: ignore[reportMissingImports]
register_benchmark,
)

# Inject only_shapes filter if provided
if only_shapes:
print(f"Using only_shapes for {kernel_name}: {only_shapes}", file=sys.stderr)

# Override the get_input_iter method for the operator class
original_get_input_iter = Operator.get_input_iter
original_get_x_val = Operator.get_x_val if hasattr(Operator, "get_x_val") else None

# Create a list to store filtered inputs and their shapes
filtered_inputs = []

# First, collect all inputs that match the shape filter
temp_operator = Operator(tb_args=tb_args, extra_args=unknown_args)
for inputs in original_get_input_iter(temp_operator):
# Get the shape value for this input
shape_value = None

if original_get_x_val:
# Use the operator's get_x_val method to get shape representation
shape_value = original_get_x_val(temp_operator, inputs)
else:
# Fallback: try to get shape from the inputs directly
if isinstance(inputs, tuple) and len(inputs) > 0:
if hasattr(inputs[0], "shape"):
shape_value = list(inputs[0].shape)
elif isinstance(inputs[0], (int, float)):
shape_value = inputs[0]
else:
# For complex inputs, try to extract meaningful shape info
shape_value = inputs

# Check if this shape matches any in our filter using direct comparison
match_found = False
for expected_shape in only_shapes:
if shape_value == expected_shape:
match_found = True
break
# Also check if shape_value is a tuple/list that matches
elif isinstance(shape_value, (tuple, list)) and isinstance(expected_shape, (tuple, list)):
if len(shape_value) == len(expected_shape) and all(a == b for a, b in zip(shape_value, expected_shape)):
match_found = True
break

if match_found:
filtered_inputs.append(inputs)
print(f" Including shape: {shape_value}", file=sys.stderr)

del temp_operator # Clean up temporary operator

if not filtered_inputs:
print(f"Warning: No shapes matched the filter for {kernel_name}", file=sys.stderr)

def filtered_get_input_iter(self):
"""Custom input iterator that yields only the filtered inputs."""
yield from filtered_inputs

# Monkey-patch the operator class
Operator.get_input_iter = filtered_get_input_iter

# Also override _available_num_inputs for proper sharding support
Operator._available_num_inputs = len(filtered_inputs)
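# Net effect of the monkey-patch above: the operator iterates only over the
# pre-collected filtered_inputs, and sharding logic that reads
# _available_num_inputs sees the filtered count instead of the operator's
# full input space.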

# Register all variants as separate methods
for module_path, func_name in variants:
@@ -389,6 +555,7 @@ def helion_method(
# This ensures we run autotuning even if the kernel has pre-specified configs
if os.environ.get("HELION_USE_DEFAULT_CONFIG", "0") != "1":
attr.settings.force_autotune = True
attr.settings.static_shapes = True

def _inner() -> Callable[..., Any] | object:
# BENCHMARK HOT PATH, do not add any new logic here
@@ -429,8 +596,6 @@ def _inner() -> Callable[..., Any] | object:
file=sys.stderr,
)

from tritonbench.run import _run

# Handle input sharding if requested
if input_shard_info:
shard_idx, total_shards = input_shard_info
@@ -464,11 +629,9 @@ def _inner() -> Callable[..., Any] | object:
["--input-id", str(start_idx), "--num-inputs", str(shard_size)]
)

# Re-parse args with the new input range
tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args)
from tritonbench.run import run as tritonbench_run

# Use tritonbench's _run function which handles arg processing
_run(tb_args, unknown_args)
tritonbench_run(tritonbench_args)
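# Note: tritonbench's run() is handed the raw argument list and (presumably)
# does its own parsing, so the sharding flags appended to tritonbench_args
# above are picked up without re-running parse_known_args here.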

# Force garbage collection multiple times to ensure memory is freed
for _ in range(3):