diff --git a/benchmarks/dashboard/ci_microbenchmark_runner.py b/benchmarks/dashboard/ci_microbenchmark_runner.py
index a8b7ae048d..d492712d85 100644
--- a/benchmarks/dashboard/ci_microbenchmark_runner.py
+++ b/benchmarks/dashboard/ci_microbenchmark_runner.py
@@ -120,22 +120,26 @@ def run_ci_benchmarks(config_path: str) -> List[Dict[str, Any]]:
         result = run_inference(config)
 
         if result is not None:
-            # Create benchmark result in OSS format
-            speedup_result = create_benchmark_result(
+            ## Create benchmark result in OSS format
+
+            # Compile mode speedup
+            compile_speedup_result = create_benchmark_result(
                 benchmark_name="TorchAO Quantization Benchmark",
                 shape=[config.m, config.k, config.n],
                 metric_name="Fwd Speedup (x)",
-                metric_values=[result.speedup],
+                metric_values=[result.compile_speedup_on_baseline],
                 quant_type=config.quantization,
                 device=config.device,
                 torch_compile_mode=config.torch_compile_mode,
             )
-            results.append(speedup_result)
-            baseline_time_result = create_benchmark_result(
+            results.append(compile_speedup_result)
+
+            # Compile mode baseline
+            compile_baseline_time_result = create_benchmark_result(
                 benchmark_name="TorchAO Quantization Benchmark",
                 shape=[config.m, config.k, config.n],
                 metric_name="Bfloat16 Fwd Time (ms)",
-                metric_values=[result.baseline_inference_time_in_ms],
+                metric_values=[result.compile_baseline_inference_time_in_ms],
                 quant_type=config.quantization,
                 device=config.device,
                 torch_compile_mode=config.torch_compile_mode,
@@ -143,12 +147,41 @@ def run_ci_benchmarks(config_path: str) -> List[Dict[str, Any]]:
                     "unit": "ms",
                 },
             )
-            results.append(baseline_time_result)
-            quantize_time_result = create_benchmark_result(
+            results.append(compile_baseline_time_result)
+
+            # Compile mode quantized
+            compile_quantize_time_result = create_benchmark_result(
                 benchmark_name="TorchAO Quantization Benchmark",
                 shape=[config.m, config.k, config.n],
                 metric_name="Quantized Fwd Time (ms)",
-                metric_values=[result.model_inference_time_in_ms],
+                metric_values=[result.compile_model_inference_time_in_ms],
+                quant_type=config.quantization,
+                device=config.device,
+                torch_compile_mode=config.torch_compile_mode,
+                metric_extra_info={
+                    "unit": "ms",
+                },
+            )
+            results.append(compile_quantize_time_result)
+
+            # Eager mode speedup
+            eager_speedup_result = create_benchmark_result(
+                benchmark_name="TorchAO Quantization Benchmark",
+                shape=[config.m, config.k, config.n],
+                metric_name="Fwd Speedup w/ Eager (x)",
+                metric_values=[result.eager_speedup_on_baseline],
+                quant_type=config.quantization,
+                device=config.device,
+                torch_compile_mode=config.torch_compile_mode,
+            )
+            results.append(eager_speedup_result)
+
+            # Eager mode baseline
+            eager_baseline_time_result = create_benchmark_result(
+                benchmark_name="TorchAO Quantization Benchmark",
+                shape=[config.m, config.k, config.n],
+                metric_name="Bfloat16 Fwd Time w/ Eager (ms)",
+                metric_values=[result.eager_baseline_inference_time_in_ms],
                 quant_type=config.quantization,
                 device=config.device,
                 torch_compile_mode=config.torch_compile_mode,
@@ -156,7 +189,36 @@ def run_ci_benchmarks(config_path: str) -> List[Dict[str, Any]]:
                     "unit": "ms",
                 },
             )
-            results.append(quantize_time_result)
+            results.append(eager_baseline_time_result)
+
+            # Eager mode quantized
+            eager_quantize_time_result = create_benchmark_result(
+                benchmark_name="TorchAO Quantization Benchmark",
+                shape=[config.m, config.k, config.n],
+                metric_name="Quantized Fwd Time w/ Eager (ms)",
+                metric_values=[result.eager_model_inference_time_in_ms],
+                quant_type=config.quantization,
+                device=config.device,
+                torch_compile_mode=config.torch_compile_mode,
+                metric_extra_info={
+                    "unit": "ms",
+                },
+            )
+            results.append(eager_quantize_time_result)
+
+            ## Compile vs eager results
+            compile_eager_speedup_result = create_benchmark_result(
+                benchmark_name="TorchAO Quantization Benchmark",
+                shape=[config.m, config.k, config.n],
+                metric_name="Eager vs Compile Fwd Speedup (x)",
+                metric_values=[result.compile_speedup_on_eager],
+                quant_type=config.quantization,
+                device=config.device,
+                torch_compile_mode=config.torch_compile_mode,
+            )
+            results.append(compile_eager_speedup_result)
+
+            ## Memory results
             allocated_memory_result = create_benchmark_result(
                 benchmark_name="TorchAO Quantization Benchmark",
                 shape=[config.m, config.k, config.n],
diff --git a/benchmarks/dashboard/microbenchmark_quantization_config.yml b/benchmarks/dashboard/microbenchmark_quantization_config.yml
index 774237d54c..8156422668 100644
--- a/benchmarks/dashboard/microbenchmark_quantization_config.yml
+++ b/benchmarks/dashboard/microbenchmark_quantization_config.yml
@@ -14,7 +14,6 @@ model_params:
         min_power: 10
         max_power: 15
     high_precision_dtype: "torch.bfloat16"
-    use_torch_compile: true
     torch_compile_mode: "max-autotune"
     device: "cuda"
     model_type: "linear"
diff --git a/benchmarks/microbenchmarks/benchmark_inference.py b/benchmarks/microbenchmarks/benchmark_inference.py
index 77ae7080ef..b0d617f32d 100644
--- a/benchmarks/microbenchmarks/benchmark_inference.py
+++ b/benchmarks/microbenchmarks/benchmark_inference.py
@@ -13,6 +13,7 @@
 import os
 from copy import deepcopy
 from pathlib import Path
+from typing import Dict, Tuple
 
 import torch
 
@@ -34,15 +35,70 @@
     create_model_and_input_data,
 )
 
+# -----------------------------------------------------------------------------
+# Baseline caching
+#
+# ``_BASELINE_CACHE`` maps a unique key to a tuple
+# ``(eager_baseline_time, compile_baseline_time)``. See ``_make_cache_key``
+# for the key construction. Users should not access this cache directly; it
+# is internal to this module. Caching the baseline timings lets repeated
+# configurations that share a model type, shape, dtype, device and compile
+# mode skip re-measuring the unquantized baseline model.
+
+_BASELINE_CACHE: Dict[Tuple, Tuple[float, float]] = {}
+
+
+def _make_cache_key(config: BenchmarkConfig) -> Tuple:
+    """Create a key for caching based on benchmark configuration.
+
+    Parameters that affect baseline performance are included:
+
+    * model type (e.g. ``linear`` or ``transformer_block``)
+    * shape dimensions (m, k, n)
+    * high precision dtype (bf16, fp16, etc.)
+    * device (cuda, cpu, mps)
+    * torch.compile mode (``torch_compile_mode``)
+
+    Sparsity and quantization settings are deliberately excluded
+    because the baseline (non-quantized, non-sparse) performance is
+    independent of those attributes.
+    """
+    return (
+        config.model_type,
+        config.m,
+        config.k,
+        config.n,
+        config.high_precision_dtype,
+        config.device,
+        config.torch_compile_mode,
+    )
+
 
 def run(config: BenchmarkConfig) -> BenchmarkResult:
-    """Run inference benchmarks"""
+    """
+    Run inference benchmarks.
+
+    The function first checks if a baseline for the given configuration
+    already exists in the internal cache. If not, it measures the baseline
+    inference time and stores the result. When the baseline is cached,
+    the function reuses the cached baselines to calculate speedup metrics.
+
+    Args:
+        config (BenchmarkConfig): Benchmark configuration.
+
+    Returns:
+        BenchmarkResult: Result of the benchmark.
+    """
     try:
         clean_caches()  # Clean caches
 
         # Create output directory if it doesn't exist
         Path(config.output_dir).mkdir(parents=True, exist_ok=True)
 
+        # Prepare result container
+        result = BenchmarkResult(config=config)
+
+        # Create model and input data
         base_model, input_data = create_model_and_input_data(
             config.model_type,
             config.m,
@@ -51,28 +107,46 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
             high_precision_dtype=config.high_precision_dtype,
             device=config.device,
         )
-        # Copy base model for quantizing
-        m_copy = deepcopy(base_model)
 
-        # Run benchmarks
-        result = BenchmarkResult(config=config)
+        # Generate a cache key for the current configuration
+        cache_key = _make_cache_key(config)
 
-        # Store result in model for memory profiling
-        base_model._benchmark_result = result
+        # Check if the baseline for this configuration has been computed
+        if cache_key not in _BASELINE_CACHE:
+            # Switch model to eval and move to device
+            base_model = base_model.eval().to(config.device)
+            print("Benchmarking eager baseline inference.....")
+            eager_baseline_time = model_inference_time_in_ms(
+                model=base_model, input_data=input_data
+            )
 
-        # Run baseline benchmarking
-        base_model = base_model.eval().to(config.device)
-        if config.use_torch_compile:
-            print("Compiling baseline model....")
+            print("Benchmarking compile baseline inference.....")
             base_model = torch.compile(
                 base_model, mode=config.torch_compile_mode, fullgraph=True
            )
-        # Benchmark time to run an inference call for baseline model
-        print("Benchmarking baseline inference.....")
-        result.baseline_inference_time_in_ms = model_inference_time_in_ms(
-            model=base_model, input_data=input_data
-        )
+            compile_baseline_time = model_inference_time_in_ms(
+                model=base_model, input_data=input_data
+            )
+
+            # Cache the eager and compile baseline times for this configuration
+            _BASELINE_CACHE[cache_key] = (eager_baseline_time, compile_baseline_time)
+
+            result.eager_baseline_inference_time_in_ms = eager_baseline_time
+            result.compile_baseline_inference_time_in_ms = compile_baseline_time
+        else:
+            # Retrieve cached values
+            cached_eager_time, cached_compile_time = _BASELINE_CACHE[cache_key]
+            result.eager_baseline_inference_time_in_ms = cached_eager_time
+            result.compile_baseline_inference_time_in_ms = cached_compile_time
+
+        # At this point the eager and compile baseline times for this
+        # configuration are stored on ``result``; ``base_model`` and
+        # ``input_data`` are reused below for the quantized measurements.
+
+        # Copy base model for quantizing/sparsifying
+        m_copy = deepcopy(base_model)
 
+        # Determine quantization/sparsity configuration
         ao_base_config = string_to_config(
             config.quantization,
             config.sparsity,
@@ -101,24 +175,39 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
         m_copy = m_copy.eval().to(config.device)
         quantize_(m_copy, ao_base_config)
 
-        if config.use_torch_compile:
-            print("Compiling quantized model....")
-            m_copy = torch.compile(
-                m_copy, mode=config.torch_compile_mode, fullgraph=True
-            )
-
         # Store result in model for memory profiling
         m_copy._benchmark_result = result
 
-        # Benchmark time to run an inference call for quantized model
+        # Measure inference time for quantized model
+        print("Benchmarking eager quantized model.....")
+        result.eager_model_inference_time_in_ms = model_inference_time_in_ms(
+            model=m_copy, input_data=input_data
+        )
+
+        # Measure inference time for compiled quantized model
         print("Benchmarking quantized model.....")
-        result.model_inference_time_in_ms = model_inference_time_in_ms(
+        m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True)
+        result.compile_model_inference_time_in_ms = model_inference_time_in_ms(
             model=m_copy, input_data=input_data
         )
 
-        # Calculate speedup w.r.t. baseline
-        result.speedup = round(
-            result.baseline_inference_time_in_ms / result.model_inference_time_in_ms, 2
+        # Compute eager speedup relative to baseline
+        result.eager_speedup_on_baseline = round(
+            result.eager_baseline_inference_time_in_ms
+            / result.eager_model_inference_time_in_ms,
+            2,
+        )
+        # Compute compile speedup relative to baseline
+        result.compile_speedup_on_baseline = round(
+            result.compile_baseline_inference_time_in_ms
+            / result.compile_model_inference_time_in_ms,
+            2,
+        )
+        # Compute compile speedup for quantized model relative to eager quantized model
+        result.compile_speedup_on_eager = round(
+            result.eager_model_inference_time_in_ms
+            / result.compile_model_inference_time_in_ms,
+            2,
         )
 
         # Run profiler if enabled
@@ -165,9 +254,9 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
                         result.memory_profile_path
                     )
             except ValueError as e:
-                if "not enough values to unpack" in e:
+                if "not enough values to unpack" in str(e):
                     print(
-                        "Failed due to existing bugs, re-run the code to generate memory profile. Please raise an issue if it persists."
+                        "Failed due to existing bugs, re‑run the code to generate memory profile. Please raise an issue if it persists."
                    )
             except Exception as e:
                 print(f"Error running memory profiler: {e}")
diff --git a/benchmarks/microbenchmarks/benchmark_runner.py b/benchmarks/microbenchmarks/benchmark_runner.py
index 8066b71714..45a0534ee0 100644
--- a/benchmarks/microbenchmarks/benchmark_runner.py
+++ b/benchmarks/microbenchmarks/benchmark_runner.py
@@ -139,9 +139,6 @@ def get_quantization_sparsity_recipes(
     """
     config_recipes = set()
 
-    # Always include baseline without sparsity
-    config_recipes.add(("baseline", None))
-
     # Add all quantization techniques without sparsity
     for quant_config in quantization_recipes:
         config_recipes.add((quant_config, None))
diff --git a/benchmarks/microbenchmarks/test/benchmark_config.yml b/benchmarks/microbenchmarks/test/benchmark_config.yml
index 4fd5eb2018..40db49e223 100644
--- a/benchmarks/microbenchmarks/test/benchmark_config.yml
+++ b/benchmarks/microbenchmarks/test/benchmark_config.yml
@@ -13,7 +13,6 @@ model_params:
         min_power: 14
         max_power: 16
     high_precision_dtype: "torch.bfloat16"
-    use_torch_compile: true
     torch_compile_mode: "max-autotune"
     device: "cuda"
     model_type: "linear"
@@ -27,7 +26,6 @@ model_params:
          [2048, 4096, 1024],
        ]
     high_precision_dtype: "torch.bfloat16"
-    use_torch_compile: true
     torch_compile_mode: "max-autotune"
     device: "cuda"
     model_type: "ln_linear_sigmoid"
@@ -41,7 +39,6 @@ model_params:
          [2048, 4096, 1024], # For transformer_block, k is the hidden dimension
        ]
     high_precision_dtype: "torch.bfloat16"
-    use_torch_compile: true
     torch_compile_mode: "max-autotune"
     device: "cuda"
     model_type: "transformer_block" # TODO: Add a custom model (Figure out how to do this, maybe pass a .py file with model definition)
@@ -58,7 +55,6 @@ model_params:
         min_power: 10 # 1024
         max_power: 11 # 2048
     high_precision_dtype: "torch.bfloat16"
-    use_torch_compile: true
     torch_compile_mode: "max-autotune"
     device: "cuda"
     model_type: "linear"
diff --git a/benchmarks/microbenchmarks/test/test_benchmark_inference.py b/benchmarks/microbenchmarks/test/test_benchmark_inference.py
index 22863dcbcf..38ffcc5a6c 100644
--- a/benchmarks/microbenchmarks/test/test_benchmark_inference.py
+++ b/benchmarks/microbenchmarks/test/test_benchmark_inference.py
@@ -21,7 +21,6 @@ def setUp(self):
             sparsity="semi-sparse",
             params={
                 "high_precision_dtype": "torch.float32",
-                "use_torch_compile": False,
                 "device": "cpu",
                 "model_type": "linear",
             },
@@ -46,7 +45,7 @@ def test_run_inference(self, mock_string_to_config):
 
         result = run(self.config)
         self.assertIsInstance(result, BenchmarkResult)
-        self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
+        self.assertTrue(hasattr(result, "compile_model_inference_time_in_ms"))
 
     @patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
     def test_run_inference_with_semi_sparse_marlin(self, mock_string_to_config):
@@ -64,7 +63,6 @@ def test_run_inference_with_semi_sparse_marlin(self, mock_string_to_config):
             sparsity="semi-sparse",
             params={
                 "high_precision_dtype": "torch.float32",
-                "use_torch_compile": False,
                 "device": "cpu",
                 "model_type": "linear",
             },
@@ -75,7 +73,7 @@
         )
         result = run(config)
         self.assertIsInstance(result, BenchmarkResult)
-        self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
+        self.assertTrue(hasattr(result, "compile_model_inference_time_in_ms"))
 
     @patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
     def test_run_inference_with_block_sparsity(self, mock_string_to_config):
@@ -92,7 +90,6 @@ def test_run_inference_with_block_sparsity(self, mock_string_to_config):
             sparsity="block",
             params={
                 "high_precision_dtype": "torch.float32",
-                "use_torch_compile": False,
                 "device": "cpu",
                 "model_type": "linear",
             },
@@ -103,7 +100,7 @@ def test_run_inference_with_block_sparsity(self, mock_string_to_config):
         )
         result = run(config)
         self.assertIsInstance(result, BenchmarkResult)
-        self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
+        self.assertTrue(hasattr(result, "compile_model_inference_time_in_ms"))
 
 
 if __name__ == "__main__":
diff --git a/benchmarks/microbenchmarks/test/test_benchmark_profiler.py b/benchmarks/microbenchmarks/test/test_benchmark_profiler.py
index 92689c4802..d0c36d8cfe 100644
--- a/benchmarks/microbenchmarks/test/test_benchmark_profiler.py
+++ b/benchmarks/microbenchmarks/test/test_benchmark_profiler.py
@@ -270,13 +270,12 @@ def test_memory_profiler_cuda_unavailable(self):
             f"{config.name}_{self.m}_{self.k}_{self.n}_memory_profile.json",
         )
 
-        # Generate memory profile
-        result, memory_stats = generate_memory_profile(
-            self.model, self.input_data, memory_profile_path
-        )
-
         # Should return None when CUDA is unavailable
-        self.assertIsNone(result)
+        self.assertIsNone(
+            generate_memory_profile(
+                self.model, self.input_data, memory_profile_path
+            )
+        )
 
         # Should not create file when CUDA is unavailable
         self.assertFalse(os.path.exists(memory_profile_path))
diff --git a/benchmarks/microbenchmarks/test/test_benchmark_runner.py b/benchmarks/microbenchmarks/test/test_benchmark_runner.py
index 2f7e5ba541..f7e54e4bec 100644
--- a/benchmarks/microbenchmarks/test/test_benchmark_runner.py
+++ b/benchmarks/microbenchmarks/test/test_benchmark_runner.py
@@ -39,7 +39,6 @@ def setUp(self):
                         }
                     ],
                     "high_precision_dtype": "torch.bfloat16",
-                    "use_torch_compile": True,
                     "torch_compile_mode": "max-autotune",
                     "device": "cpu",
                     "model_type": "linear",
@@ -130,7 +129,6 @@ def test_get_param_combinations(self):
         self.assertEqual(len(shapes), 1)
         self.assertEqual(shapes[0], ("custom", [1024, 1024, 1024]))
         self.assertEqual(params["high_precision_dtype"], "torch.bfloat16")
-        self.assertEqual(params["use_torch_compile"], True)
 
     @patch("argparse.Namespace")
     def test_load_benchmark_configs(self, mock_args):
diff --git a/benchmarks/microbenchmarks/test/test_utils.py b/benchmarks/microbenchmarks/test/test_utils.py
index 06f557a8f4..864c521251 100644
--- a/benchmarks/microbenchmarks/test/test_utils.py
+++ b/benchmarks/microbenchmarks/test/test_utils.py
@@ -33,7 +33,6 @@ def setUp(self):
         self.test_params = {
             "name": "test_model",
             "high_precision_dtype": "torch.bfloat16",
-            "use_torch_compile": True,
             "torch_compile_mode": "max-autotune",
             "device": "cpu",
             "model_type": "linear",
@@ -57,7 +56,6 @@ def test_benchmark_config(self):
         self.assertEqual(config.k, 1024)
         self.assertEqual(config.n, 1024)
         self.assertEqual(config.high_precision_dtype, torch.bfloat16)
-        self.assertEqual(config.use_torch_compile, True)
         self.assertEqual(config.torch_compile_mode, "max-autotune")
         self.assertEqual(config.device, "cpu")
         self.assertEqual(config.model_type, "linear")
@@ -76,7 +74,7 @@ def test_benchmark_result(self):
         result = BenchmarkResult(config=config)
 
         self.assertEqual(result.config, config)
-        self.assertEqual(result.model_inference_time_in_ms, 0.0)
+        self.assertEqual(result.compile_model_inference_time_in_ms, 0.0)
 
     def test_get_default_device(self):
         # Test CPU fallback
diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py
index 40bce5c33d..94c6f19b81 100644
--- a/benchmarks/microbenchmarks/utils.py
+++ b/benchmarks/microbenchmarks/utils.py
@@ -73,18 +73,13 @@
         self.high_precision_dtype = self._parse_precision(
             params.get("high_precision_dtype", "torch.bfloat16")
         )
-        self.use_torch_compile = bool(params.get("use_torch_compile", False))
-        self.torch_compile_mode = (
-            params.get("torch_compile_mode", "default")
-            if self.use_torch_compile
-            else None
-        )
+        self.torch_compile_mode = params.get("torch_compile_mode", "default")
         self.device = get_default_device(params.get("device", None))
         self.model_type = params.get("model_type", "linear")
         self.output_dir = f"{output_dir}/{self.benchmark_mode}"
         self.name = params.get(
             "name",
-            f"benchmark_{self.quantization}_{self.model_type}_m{self.m}_k{self.k}_n{self.n}{'_compile' if self.use_torch_compile else ''}",
+            f"benchmark_{self.quantization}_{self.model_type}_m{self.m}_k{self.k}_n{self.n}_compile",
         )
         self.enable_profiler = bool(params.get("enable_profiler", False))
         self.enable_memory_profiler = bool(params.get("enable_memory_profiler", False))
@@ -108,7 +103,6 @@ def to_dict(self) -> Dict[str, Any]:
             "k": self.k,
             "n": self.n,
             "high_precision_dtype": self.high_precision_dtype,
-            "use_torch_compile": self.use_torch_compile,
             "torch_compile_mode": self.torch_compile_mode,
             "device": self.device,
             "model_type": self.model_type,
@@ -125,9 +119,13 @@ def __init__(
     ):
         self.config = config
         self.output_dir = config.output_dir
-        self.baseline_inference_time_in_ms = 0.0
-        self.model_inference_time_in_ms = 0.0
-        self.speedup = 0.0
+        self.eager_baseline_inference_time_in_ms = 0.0
+        self.eager_model_inference_time_in_ms = 0.0
+        self.compile_baseline_inference_time_in_ms = 0.0
+        self.compile_model_inference_time_in_ms = 0.0
+        self.eager_speedup_on_baseline = 0.0
+        self.compile_speedup_on_baseline = 0.0
+        self.compile_speedup_on_eager = 0.0
         self.profiler_json_path: Optional[str] = None
         self.memory_profile_path: Optional[str] = None
         self.memory_visualization_path: Optional[str] = None
@@ -137,9 +135,13 @@ def to_dict(self) -> Dict[str, Any]:
         """Convert result to dictionary for main function"""
         result_dict = {
             **self.config.to_dict(),
-            "baseline_inference_time_in_ms": self.baseline_inference_time_in_ms,
-            "model_inference_time_in_ms": self.model_inference_time_in_ms,
-            "speedup": self.speedup,
+            "eager_baseline_inference_time_in_ms": self.eager_baseline_inference_time_in_ms,
+            "eager_model_inference_time_in_ms": self.eager_model_inference_time_in_ms,
+            "compile_baseline_inference_time_in_ms": self.compile_baseline_inference_time_in_ms,
+            "compile_model_inference_time_in_ms": self.compile_model_inference_time_in_ms,
+            "eager speedup on baseline": self.eager_speedup_on_baseline,
+            "compile speedup on baseline": self.compile_speedup_on_baseline,
+            "eager vs compile speedup": self.compile_speedup_on_eager,
             "profiler_json_path": self.profiler_json_path,
             "memory_profile_path": self.memory_profile_path,
             "memory_visualization_path": self.memory_visualization_path,
@@ -408,9 +410,13 @@ def print_results(results: List[BenchmarkResult]):
             result.config.quantization or "baseline",
             result.config.sparsity or "none",
             f"{result.config.shape_name} ({result.config.m}, {result.config.k}, {result.config.n})",
-            f"{result.baseline_inference_time_in_ms:.2f}",
-            f"{result.model_inference_time_in_ms:.2f}",
-            f"{result.speedup:.2f}x",
+            f"{result.eager_baseline_inference_time_in_ms:.2f}",
+            f"{result.eager_model_inference_time_in_ms:.2f}",
+            f"{result.eager_speedup_on_baseline:.2f}x",
+            f"{result.compile_baseline_inference_time_in_ms:.2f}",
+            f"{result.compile_model_inference_time_in_ms:.2f}",
+            f"{result.compile_speedup_on_baseline:.2f}x",
+            f"{result.compile_speedup_on_eager:.2f}x",
             str(result.config.enable_profiler),
         ]
 
@@ -422,9 +428,13 @@ def print_results(results: List[BenchmarkResult]):
         "Quantization",
         "Sparsity",
         "Shape",
-        "Baseline Inference Time (ms)",
-        "Inference Time (ms)",
-        "Speedup",
+        "Eager Baseline Inference Time (ms)",
+        "Eager Model Inference Time (ms)",
+        "Eager Speedup",
+        "Compile Baseline Inference Time (ms)",
+        "Compile Model Inference Time (ms)",
+        "Compile Speedup",
+        "Eager vs Compile Speedup",
         "Profiler Enabled",
     ]
 
diff --git a/docs/source/benchmarking_api_guide.md b/docs/source/benchmarking_api_guide.md
index b07a0e14ff..bd81a7f65f 100644
--- a/docs/source/benchmarking_api_guide.md
+++ b/docs/source/benchmarking_api_guide.md
@@ -122,7 +122,6 @@ model_params:
         min_power: 10
         max_power: 15
     high_precision_dtype: "torch.bfloat16"
-    use_torch_compile: true
     torch_compile_mode: "max-autotune"
     device: "cuda"
     model_type: "linear"
@@ -199,9 +198,8 @@ python -m unittest discover benchmarks/microbenchmarks/test
 ### Common Issues
 
 1. **CUDA Out of Memory**: Reduce batch size or matrix dimensions
-2. **Compilation Errors**: Set `use_torch_compile: false` for debugging
-3. **Missing Quantization Methods**: Ensure TorchAO is properly installed
-4. **Device Not Available**: Check device availability and drivers
+2. **Missing Quantization Methods**: Ensure TorchAO is properly installed
+3. **Device Not Available**: Check device availability and drivers
 
 ### Best Practices
 
diff --git a/torchao/testing/model_architectures.py b/torchao/testing/model_architectures.py
index f59a1271b1..9f02bbebc5 100644
--- a/torchao/testing/model_architectures.py
+++ b/torchao/testing/model_architectures.py
@@ -156,6 +156,7 @@ def create_model_and_input_data(
         high_precision_dtype (torch.dtype): data type of the model
         m, k, n (int): dimensions of the model and input data
     """
+    torch.manual_seed(42)
     if model_type == "linear":
         model = ToyLinearModel(k, n, high_precision_dtype).to(device)
         input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype)