diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py
index 4639c5df4e..212f5f8e0d 100644
--- a/thunder/benchmarks/benchmark_inference.py
+++ b/thunder/benchmarks/benchmark_inference.py
@@ -22,7 +22,6 @@ import warnings
 from typing import Any
 from collections.abc import Callable
 
-from looseversion import LooseVersion
 import torch
 import torch.distributed as dist
 
@@ -30,9 +29,8 @@ from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel, ColwiseParallel
 from tqdm import tqdm
 
-import transformers
 from transformers import AutoConfig, AutoModelForCausalLM
-from transformers.cache_utils import HybridChunkedCache, StaticCache
+from transformers.cache_utils import StaticCache
 from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
 from torch.distributed.tensor.placement_types import Shard
 from torch.distributed.tensor import DTensor
@@ -335,36 +333,20 @@ def _load_model(self) -> torch.nn.Module:
 
         return model
 
-    def generate_batch(self) -> tuple[torch.Tensor, HybridChunkedCache]:
+    def generate_batch(self) -> tuple[torch.Tensor, StaticCache]:
         """Generate a batch of input tokens"""
         batch_size = self.config.batch_size
         input_length = self.config.input_length
         input_ids = torch.randint(0, self.vocab_size, (batch_size, input_length), device=DEVICE)
 
-        if LooseVersion(transformers.__version__) >= LooseVersion("4.55"):
-            # Transformers deprecated HybridChunkedCache in favour of static in 4.55.x
-            past_key_values = StaticCache(
-                config=self.hf_config,
-                max_batch_size=input_ids.shape[0],
-                max_cache_len=input_ids.shape[1] + self.config.output_length,
-                device=DEVICE,
-                dtype=torch.bfloat16,
-            )
-        else:
-            past_key_values = HybridChunkedCache(
-                self.hf_config, input_ids.shape[0], input_ids.shape[1] + self.config.output_length
-            )
-            for layer_idx in range(self.hf_config.num_hidden_layers):
-                # key_states.shape[1] is used to retrieve the number of key value heads, all other dimensions can be 1 and ignored
-                # https://github.com/huggingface/transformers/blob/9300728665aaeb0ebf4db99f9d9fbce916b4a183/src/transformers/cache_utils.py#L1822
-                dummy_key_states = torch.empty(1, self.hf_config.num_key_value_heads // WORLD_SIZE, 1, 1, device=DEVICE)
-                past_key_values.initialise_cache_layer(layer_idx, dummy_key_states)
+        past_key_values = StaticCache(
+            config=self.hf_config,
+            max_cache_len=input_ids.shape[1] + self.config.output_length,
+        )
 
         return input_ids, past_key_values
 
-    def get_next_token(
-        self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache | StaticCache
-    ) -> torch.Tensor:
+    def get_next_token(self, input_ids: torch.Tensor, past_key_values: StaticCache) -> torch.Tensor:
         start_pos = past_key_values.get_seq_length()
         cache_position = start_pos + torch.arange(0, input_ids.shape[1], device=start_pos.device, dtype=start_pos.dtype)
         with torch.no_grad():
@@ -376,7 +358,7 @@ def get_next_token(
             next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
         return next_token
 
-    def prefill(self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache) -> torch.Tensor:
+    def prefill(self, input_ids: torch.Tensor, past_key_values: StaticCache) -> torch.Tensor:
         """
         Prefill phase: Process the entire input prompt at once. Returns the next token.
 
@@ -385,7 +367,7 @@ def prefill(self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache)
         """
         return self.get_next_token(input_ids, past_key_values)
 
-    def decode_one_token(self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache) -> torch.Tensor:
+    def decode_one_token(self, input_ids: torch.Tensor, past_key_values: StaticCache) -> torch.Tensor:
         """
         Decode phase: Generate a single token given the current sequence.
         Returns the next token.
@@ -401,9 +383,7 @@ def decode_one_token(self, input_ids: torch.Tensor, past_key_values: HybridChunk
     # [rank1]: ~^^^^^
     # [rank1]: RuntimeError: Cannot set version_counter for inference tensor
     # @torch.inference_mode()
-    def generate(
-        self, input_ids: torch.Tensor, max_new_tokens: int, past_key_values: HybridChunkedCache
-    ) -> dict[str, Any]:
+    def generate(self, input_ids: torch.Tensor, max_new_tokens: int, past_key_values: StaticCache) -> dict[str, Any]:
         """
         Generate tokens using separate prefill and decode phases.
         Returns detailed metrics for both phases.
@@ -431,7 +411,7 @@ def generate(
         }
 
     def measure_inference_step(
-        self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache, max_new_tokens: int
+        self, input_ids: torch.Tensor, past_key_values: StaticCache, max_new_tokens: int
     ) -> dict[str, float]:
         """Measure a single inference step with detailed timing using separate prefill/decode"""
         # Generate tokens with separate prefill/decode tracking
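For context on the simplified cache setup, here is a minimal sketch (not part of the patch) of what the new `generate_batch` body boils down to, assuming transformers >= 4.55. `LlamaConfig` and the sizes below are placeholders for `self.hf_config` and the benchmark's `BenchmarkConfig` values, which the script loads elsewhere.

```python
import torch
from transformers import LlamaConfig
from transformers.cache_utils import StaticCache

# Stand-in for self.hf_config; the benchmark loads the real config via AutoConfig.
hf_config = LlamaConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2)

# Placeholder sizes standing in for config.batch_size / input_length / output_length.
batch_size, input_length, output_length = 2, 128, 32
input_ids = torch.randint(0, hf_config.vocab_size, (batch_size, input_length))

# With transformers >= 4.55 the StaticCache constructor only needs the model config and
# the maximum cache length; batch size, device, and dtype appear to be picked up lazily
# from the first key/value states written into the cache, which is what lets the patch
# drop the HybridChunkedCache fallback and the manual initialise_cache_layer() loop.
past_key_values = StaticCache(
    config=hf_config,
    max_cache_len=input_length + output_length,
)
```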