Add more metrics to dashboard #2667

Draft · wants to merge 10 commits into base: main
82 changes: 72 additions & 10 deletions benchmarks/dashboard/ci_microbenchmark_runner.py
@@ -120,43 +120,105 @@ def run_ci_benchmarks(config_path: str) -> List[Dict[str, Any]]:
result = run_inference(config)

if result is not None:
# Create benchmark result in OSS format
speedup_result = create_benchmark_result(
## Create benchmark result in OSS format

# Compile mode speedup
compile_speedup_result = create_benchmark_result(
benchmark_name="TorchAO Quantization Benchmark",
shape=[config.m, config.k, config.n],
metric_name="Fwd Speedup (x)",
metric_values=[result.speedup],
metric_values=[result.compile_speedup_on_baseline],
quant_type=config.quantization,
device=config.device,
torch_compile_mode=config.torch_compile_mode,
)
results.append(speedup_result)
baseline_time_result = create_benchmark_result(
results.append(compile_speedup_result)

# Compile mode baseline
compile_baseline_time_result = create_benchmark_result(
benchmark_name="TorchAO Quantization Benchmark",
shape=[config.m, config.k, config.n],
metric_name="Bfloat16 Fwd Time (ms)",
metric_values=[result.baseline_inference_time_in_ms],
metric_values=[result.compile_baseline_inference_time_in_ms],
quant_type=config.quantization,
device=config.device,
torch_compile_mode=config.torch_compile_mode,
metric_extra_info={
"unit": "ms",
},
)
results.append(baseline_time_result)
quantize_time_result = create_benchmark_result(
results.append(compile_baseline_time_result)

# Compile mode quantized
compile_quantize_time_result = create_benchmark_result(
benchmark_name="TorchAO Quantization Benchmark",
shape=[config.m, config.k, config.n],
metric_name="Quantized Fwd Time (ms)",
metric_values=[result.model_inference_time_in_ms],
metric_values=[result.compile_model_inference_time_in_ms],
quant_type=config.quantization,
device=config.device,
torch_compile_mode=config.torch_compile_mode,
metric_extra_info={
"unit": "ms",
},
)
results.append(compile_quantize_time_result)

# Eager mode speedup
eager_speedup_result = create_benchmark_result(
benchmark_name="TorchAO Quantization Benchmark",
shape=[config.m, config.k, config.n],
metric_name="Fwd Speedup w/ Eager (x)",
metric_values=[result.eager_speedup_on_baseline],
quant_type=config.quantization,
device=config.device,
torch_compile_mode=config.torch_compile_mode,
)
results.append(eager_speedup_result)

# Eager mode baseline
eager_baseline_time_result = create_benchmark_result(
benchmark_name="TorchAO Quantization Benchmark",
shape=[config.m, config.k, config.n],
metric_name="Bfloat16 Fwd Time w/ Eager (ms)",
metric_values=[result.eager_baseline_inference_time_in_ms],
quant_type=config.quantization,
device=config.device,
torch_compile_mode=config.torch_compile_mode,
metric_extra_info={
"unit": "ms",
},
)
results.append(quantize_time_result)
results.append(eager_baseline_time_result)

# Eager mode quantized
eager_quantize_time_result = create_benchmark_result(
benchmark_name="TorchAO Quantization Benchmark",
shape=[config.m, config.k, config.n],
metric_name="Quantized Fwd Time w/ Eager (ms)",
metric_values=[result.eager_model_inference_time_in_ms],
quant_type=config.quantization,
device=config.device,
torch_compile_mode=config.torch_compile_mode,
metric_extra_info={
"unit": "ms",
},
)
results.append(eager_quantize_time_result)

## Compile vs eager results
compile_eager_speedup_result = create_benchmark_result(
benchmark_name="TorchAO Quantization Benchmark",
shape=[config.m, config.k, config.n],
metric_name="Eager vs Compile Fwd Speedup (x)",
metric_values=[result.compile_speedup_on_eager],
quant_type=config.quantization,
device=config.device,
torch_compile_mode=config.torch_compile_mode,
)
results.append(compile_eager_speedup_result)

## Memory results
allocated_memory_result = create_benchmark_result(
benchmark_name="TorchAO Quantization Benchmark",
shape=[config.m, config.k, config.n],
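The per-metric create_benchmark_result blocks above follow one pattern (metric name, result attribute, optional unit). The sketch below is only an editor's illustration of a table-driven alternative, not part of this PR; the metric names and result attributes are taken from the diff, everything else is assumed.

metric_specs = [
    ("Fwd Speedup (x)", result.compile_speedup_on_baseline, None),
    ("Bfloat16 Fwd Time (ms)", result.compile_baseline_inference_time_in_ms, "ms"),
    ("Quantized Fwd Time (ms)", result.compile_model_inference_time_in_ms, "ms"),
    ("Fwd Speedup w/ Eager (x)", result.eager_speedup_on_baseline, None),
    ("Bfloat16 Fwd Time w/ Eager (ms)", result.eager_baseline_inference_time_in_ms, "ms"),
    ("Quantized Fwd Time w/ Eager (ms)", result.eager_model_inference_time_in_ms, "ms"),
    ("Eager vs Compile Fwd Speedup (x)", result.compile_speedup_on_eager, None),
]
for name, value, unit in metric_specs:
    # Only the time metrics carry a unit, matching the calls in this diff.
    extra = {"metric_extra_info": {"unit": unit}} if unit else {}
    results.append(
        create_benchmark_result(
            benchmark_name="TorchAO Quantization Benchmark",
            shape=[config.m, config.k, config.n],
            metric_name=name,
            metric_values=[value],
            quant_type=config.quantization,
            device=config.device,
            torch_compile_mode=config.torch_compile_mode,
            **extra,
        )
    )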
@@ -14,7 +14,6 @@ model_params:
min_power: 10
max_power: 15
high_precision_dtype: "torch.bfloat16"
use_torch_compile: true
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "linear"
147 changes: 118 additions & 29 deletions benchmarks/microbenchmarks/benchmark_inference.py
@@ -13,6 +13,7 @@
import os
from copy import deepcopy
from pathlib import Path
from typing import Dict, Tuple

import torch

@@ -34,15 +35,70 @@
create_model_and_input_data,
)

# -----------------------------------------------------------------------------
# Baseline caching
#
# ``_BASELINE_CACHE`` maps a unique key to a tuple
# ``(eager_baseline_time, compile_baseline_time)``. See ``_make_cache_key`` for the key
# construction. The cache is internal to this module and should not be
# accessed directly. Only the baseline timings are cached; the base model
# is recreated for every run, so quantized variants never mutate cached
# state.

_BASELINE_CACHE: Dict[Tuple, Tuple[float, float]] = {}
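# Example entry (illustrative values only, editor's sketch):
#   ("linear", 1024, 1024, 1024, torch.bfloat16, "cuda", "max-autotune") -> (12.4, 3.7)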


def _make_cache_key(config: BenchmarkConfig) -> Tuple:
"""Create a key for caching based on benchmark configuration.

Parameters that affect baseline performance are included:

* model type (e.g. ``linear`` or ``transformer_block``)
* shape dimensions (m, k, n)
* high precision dtype (bf16, fp16, etc.)
* device (cuda, cpu, mps)
* compile settings (whether compile is enabled and compile mode)

Sparsity and quantization settings are deliberately excluded
because the baseline (non-quantized, non-sparse) performance is
independent of those attributes.
"""
return (
config.model_type,
config.m,
config.k,
config.n,
config.high_precision_dtype,
config.device,
config.torch_compile_mode,
)
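# Editor's sketch (not part of this diff): configs that differ only in
# quantization or sparsity share a cache key, so the expensive baseline run
# is measured once and reused. ``SimpleNamespace`` stands in for
# ``BenchmarkConfig`` purely to show which attributes the key reads; the
# recipe names here are assumptions.
#
#     from types import SimpleNamespace
#     shared = dict(model_type="linear", m=1024, k=1024, n=1024,
#                   high_precision_dtype="torch.bfloat16", device="cuda",
#                   torch_compile_mode="max-autotune")
#     cfg_a = SimpleNamespace(**shared, quantization="int8wo")
#     cfg_b = SimpleNamespace(**shared, quantization="float8dq")
#     assert _make_cache_key(cfg_a) == _make_cache_key(cfg_b)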


def run(config: BenchmarkConfig) -> BenchmarkResult:
"""Run inference benchmarks"""
"""
Run inference benchmarks.

The function first checks whether baselines for the given configuration
already exist in the internal cache. If not, it measures the eager and
compiled baseline inference times and caches them. On a cache hit, the
cached baselines are reused to compute the speedup metrics.

Args:
config (BenchmarkConfig): Benchmark configuration.

Returns:
BenchmarkResult: Result of the benchmark.
"""
try:
clean_caches() # Clean caches

# Create output directory if it doesn't exist
Path(config.output_dir).mkdir(parents=True, exist_ok=True)

# Prepare result container
result = BenchmarkResult(config=config)

# Create model and input data
base_model, input_data = create_model_and_input_data(
config.model_type,
config.m,
@@ -51,28 +107,46 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
high_precision_dtype=config.high_precision_dtype,
device=config.device,
)
# Copy base model for quantizing
m_copy = deepcopy(base_model)

# Run benchmarks
result = BenchmarkResult(config=config)
# Generate a cache key for the current configuration
cache_key = _make_cache_key(config)

# Store result in model for memory profiling
base_model._benchmark_result = result
# Check if the baseline for this configuration has been computed
if cache_key not in _BASELINE_CACHE:
# Switch model to eval and move to device
base_model = base_model.eval().to(config.device)
print("Benchmarking eager baseline inference.....")
eager_baseline_time = model_inference_time_in_ms(
model=base_model, input_data=input_data
)

# Run baseline benchmarking
base_model = base_model.eval().to(config.device)
if config.use_torch_compile:
print("Compiling baseline model....")
print("Benchmarking compile baseline inference.....")
base_model = torch.compile(
base_model, mode=config.torch_compile_mode, fullgraph=True
)
# Benchmark time to run an inference call for baseline model
print("Benchmarking baseline inference.....")
result.baseline_inference_time_in_ms = model_inference_time_in_ms(
model=base_model, input_data=input_data
)
compile_baseline_time = model_inference_time_in_ms(
model=base_model, input_data=input_data
)

# Cache the eager and compiled baseline times for this configuration
_BASELINE_CACHE[cache_key] = (eager_baseline_time, compile_baseline_time)

result.eager_baseline_inference_time_in_ms = eager_baseline_time
result.compile_baseline_inference_time_in_ms = compile_baseline_time
else:
# Retrieve cached values
cached_eager_time, cached_compile_time = _BASELINE_CACHE[cache_key]
result.eager_baseline_inference_time_in_ms = cached_eager_time
result.compile_baseline_inference_time_in_ms = cached_compile_time

# At this point the eager and compiled baseline times are stored on
# ``result``; ``base_model`` and ``input_data`` are reused below to build
# and benchmark the quantized model.

# Copy base model for quantizing/sparsifying
m_copy = deepcopy(base_model)

# Determine quantization/sparsity configuration
ao_base_config = string_to_config(
config.quantization,
config.sparsity,
@@ -101,24 +175,39 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
m_copy = m_copy.eval().to(config.device)
quantize_(m_copy, ao_base_config)

if config.use_torch_compile:
print("Compiling quantized model....")
m_copy = torch.compile(
m_copy, mode=config.torch_compile_mode, fullgraph=True
)

# Store result in model for memory profiling
m_copy._benchmark_result = result

# Benchmark time to run an inference call for quantized model
# Measure eager inference time for the quantized model
print("Benchmarking eager quantized model.....")
result.eager_model_inference_time_in_ms = model_inference_time_in_ms(
model=m_copy, input_data=input_data
)

# Measure inference time for compiled quantized model
print("Benchmarking quantized model.....")
result.model_inference_time_in_ms = model_inference_time_in_ms(
m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True)
result.compile_model_inference_time_in_ms = model_inference_time_in_ms(
model=m_copy, input_data=input_data
)

# Calculate speedup w.r.t. baseline
result.speedup = round(
result.baseline_inference_time_in_ms / result.model_inference_time_in_ms, 2
# Compute eager speedup relative to baseline
result.eager_speedup_on_baseline = round(
result.eager_baseline_inference_time_in_ms
/ result.eager_model_inference_time_in_ms,
2,
)
# Compute compile speedup relative to baseline
result.compile_speedup_on_baseline = round(
result.compile_baseline_inference_time_in_ms
/ result.compile_model_inference_time_in_ms,
2,
)
# Compute compile speedup for quantized model relative to eager quantized model
result.compile_speedup_on_eager = round(
result.eager_model_inference_time_in_ms
/ result.compile_model_inference_time_in_ms,
2,
)
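# Worked example (illustrative numbers, editor's note): with an eager baseline of
# 12.0 ms, eager quantized time of 8.0 ms, compiled baseline of 6.0 ms and compiled
# quantized time of 4.0 ms, the three metrics above come out to 1.5x, 1.5x and 2.0x.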

# Run profiler if enabled
@@ -165,9 +254,9 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
result.memory_profile_path
)
except ValueError as e:
if "not enough values to unpack" in e:
if "not enough values to unpack" in str(e):
print(
"Failed due to existing bugs, re-run the code to generate memory profile. Please raise an issue if it persists."
"Failed due to existing bugs, rerun the code to generate memory profile. Please raise an issue if it persists."
)
except Exception as e:
print(f"Error running memory profiler: {e}")
3 changes: 0 additions & 3 deletions benchmarks/microbenchmarks/benchmark_runner.py
@@ -139,9 +139,6 @@ def get_quantization_sparsity_recipes(
"""
config_recipes = set()

# Always include baseline without sparsity
config_recipes.add(("baseline", None))

# Add all quantization techniques without sparsity
for quant_config in quantization_recipes:
config_recipes.add((quant_config, None))
4 changes: 0 additions & 4 deletions benchmarks/microbenchmarks/test/benchmark_config.yml
@@ -13,7 +13,6 @@ model_params:
min_power: 14
max_power: 16
high_precision_dtype: "torch.bfloat16"
use_torch_compile: true
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "linear"
@@ -27,7 +26,6 @@ model_params:
[2048, 4096, 1024],
]
high_precision_dtype: "torch.bfloat16"
use_torch_compile: true
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "ln_linear_sigmoid"
@@ -41,7 +39,6 @@ model_params:
[2048, 4096, 1024], # For transformer_block, k is the hidden dimension
]
high_precision_dtype: "torch.bfloat16"
use_torch_compile: true
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "transformer_block" # TODO: Add a custom model (Figure out how to do this, maybe pass a .py file with model definition)
@@ -58,7 +55,6 @@ model_params:
min_power: 10 # 1024
max_power: 11 # 2048
high_precision_dtype: "torch.bfloat16"
use_torch_compile: true
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "linear"