From 79e095a62d7f7ee5b44855f2eaee61bb64e72269 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Thu, 31 Jul 2025 18:51:57 -0700 Subject: [PATCH 01/11] Support input_embeds in torch exportable decoders --- src/transformers/integrations/executorch.py | 108 ++++++++++++++------ 1 file changed, 78 insertions(+), 30 deletions(-) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 6fa0e6348d66..36f150dd8095 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -37,6 +37,8 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module): def __init__( self, model: PreTrainedModel, + config: PretrainedConfig, + generation_config: GenerationConfig, max_batch_size: int = 1, max_cache_len: int = 4096, ): @@ -45,6 +47,8 @@ def __init__( Args: model (`PreTrainedModel`): The pretrained model to wrap. + config (`PreTrainedConfig`): The pretrained text config for the decoder model. + generation_config (`GenerationConfig`): The generation config for the model. max_batch_size (int): Maximum batch size for the cache. max_cache_len (int): Maximum sequence length for the cache. @@ -53,10 +57,10 @@ def __init__( """ super().__init__() - if not hasattr(model.config, "use_cache") or model.config.use_cache is False: + if not hasattr(config, "use_cache") or config.use_cache is False: raise ValueError("The model must have caching enabled to be performant.") - if hasattr(model.config, "layer_types") and getattr(model.config, "sliding_window", None) is not None: + if hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None: self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len) else: # If `layer_types` is not specified explicitly in the config or `sliding_window` is null, @@ -64,7 +68,7 @@ def __init__( logging.info( "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config." ) - self.model = TorchExportableModuleWithStaticCache(model) + self.model = TorchExportableModuleWithStaticCache(model, config, generation_config) # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) @@ -72,7 +76,8 @@ def __init__( def forward( self, - input_ids: torch.Tensor, + input_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, cache_position: torch.Tensor, ) -> torch.Tensor: """ @@ -80,16 +85,22 @@ def forward( Args: input_ids (`torch.Tensor`): Tensor representing current input token id to the module. + inputs_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module. cache_position (`torch.Tensor`): Tensor representing current input position in the cache. Returns: torch.Tensor: Logits output from the model. 
""" - return self.model.forward(input_ids, cache_position) + return self.model.forward( + cache_position=cache_position, + input_ids=input_ids, + inputs_embeds=inputs_embeds, + ) def export( self, input_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, cache_position: Optional[torch.Tensor] = None, dynamic_shapes: Optional[dict] = None, strict: Optional[bool] = None, @@ -99,7 +110,9 @@ def export( Args: input_ids (`Optional[torch.Tensor]`): - Tensor representing current input token id to the module. If not provided, a default tensor will be used. + Tensor representing current input token id to the module. If this and inputs_embeds are not provided, a default tensor will be used. + inputs_embeds (`Optional[torch.Tensor]`): + Tensor representing current input embeddings to the module. cache_position (`Optional[torch.Tensor]`): Tensor representing current input position in the cache. If not provided, a default tensor will be used. dynamic_shapes (`Optional[dict]`): @@ -118,20 +131,40 @@ def export( "TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default." ) - example_input_ids = ( - input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long, device=model_device) - ) + if input_ids is not None and inputs_embeds is not None: + raise ValueError("Can't specify both input_ids and inputs_embeds.") + example_cache_position = ( cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device) ) - exported_program = torch.export.export( - self.model, - args=(example_input_ids, example_cache_position), - kwargs={}, - dynamic_shapes=dynamic_shapes, - strict=strict if strict is not None else True, - ) + if input_ids: + exported_program = torch.export.export( + self.model, + args=(), + kwargs={"input_ids": input_ids, "cache_position": example_cache_position}, + dynamic_shapes=dynamic_shapes, + strict=strict if strict is not None else True, + ) + elif input_emebds: + exported_program = torch.export.export( + self.model, + args=(), + kwargs={"inputs_embeds": inputs_embeds, "cache_position": example_cache_position}, + dynamic_shapes=dynamic_shapes, + strict=strict if strict is not None else True, + ) + else: + # No inputs specified, assume we are exporting with input_ids for legacy reasons. + example_input_ids = torch.tensor([[1]], dtype=torch.long, device=model_device) + exported_program = torch.export.export( + self.model, + args=(), + kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position}, + dynamic_shapes=dynamic_shapes, + strict=strict if strict is not None else True, + ) + return exported_program @staticmethod @@ -180,7 +213,7 @@ def generate( curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device) # Forward pass - _ = exported_module(curr_input_ids, curr_cache_position) + _ = exported_module(input_ids=curr_input_ids, cache_position=curr_cache_position) curr_position += 1 # Generate new tokens @@ -190,7 +223,7 @@ def generate( curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device) # Forward pass to get next token logits - outputs = exported_module(curr_input_ids, curr_cache_position) + outputs = exported_module(input_ids=curr_input_ids, cache_position=curr_cache_position) # Get the next token ID if do_sample: @@ -254,13 +287,20 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module): in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`. 
""" - def __init__(self, model: PreTrainedModel): + def __init__( + self, + model: PreTrainedModel, + config: PretrainedConfig, + generation_config: GenerationConfig, + ): """ Initializes the wrapper module with the pretrained model. Args: model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching enabled and use a 'static' caching implementation. + config (`PreTrainedConfig`): The pretrained text config for the decoder model. + generation_config (`GenerationConfig`): The generation config for the model. Raises: AssertionError: If the pretrained model does not have caching enabled or if it does @@ -269,42 +309,52 @@ def __init__(self, model: PreTrainedModel): super().__init__() # Sanity checks - if model.generation_config is None: + if generation_config is None: raise AssertionError( "The model must have a generation config to be exported with static caching. " "Please set `generation_config`." ) - if not model.generation_config.use_cache: + if not generation_config.use_cache: raise AssertionError( "The model must have caching enabled to be exported with static caching. " "Please set `generation_config.use_cache=True`." ) - if model.generation_config.cache_implementation != "static": + if generation_config.cache_implementation != "static": raise AssertionError( "The model must use a 'static' caching implementation to be exported with static caching. " "Please set `generation_config.cache_implementation='static'`." ) self.model = model + self.config = config + self.generation_config = generation_config self.static_cache = StaticCache( - config=self.model.config, - max_batch_size=self.model.generation_config.cache_config.get("batch_size"), - max_cache_len=self.model.generation_config.cache_config.get("max_cache_len"), - device=self.model.generation_config.cache_config.get("device"), + config=config, + max_batch_size=self.generation_config.cache_config.get("batch_size"), + max_cache_len=self.generation_config.cache_config.get("max_cache_len"), + device=self.generation_config.cache_config.get("device"), dtype=self.model.dtype, ) + for i in range(len(self.static_cache)): self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False) self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False) - def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor): + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: torch.Tensor = None, + ): """ Forward pass of the module, which is compatible with the ExecuTorch runtime. Args: input_ids (`torch.Tensor`): Tensor representing current input token id to the module. + inputs_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module. cache_position (`torch.Tensor`): Tensor representing current input position in the cache. Returns: @@ -320,14 +370,12 @@ def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor): The adapter matches the model's forward signature with that in `executorch/extension/llm/runner`, ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box. 
""" - _, seqlen = input_ids.shape - position_ids = cache_position.unsqueeze(0) past_key_values = self.static_cache outs = self.model( input_ids=input_ids, + inputs_embeds=inputs_embeds, attention_mask=None, - position_ids=position_ids, cache_position=cache_position, past_key_values=past_key_values, use_cache=True, From 1ad0627574ba6f59b545c730f699b048bf792b57 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Thu, 31 Jul 2025 21:47:49 -0700 Subject: [PATCH 02/11] Hybrid cache update --- src/transformers/integrations/executorch.py | 50 ++++++++++++++------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 36f150dd8095..dafe6c1a4766 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -61,7 +61,11 @@ def __init__( raise ValueError("The model must have caching enabled to be performant.") if hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None: - self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len) + self.model = TorchExportableModuleWithHybridCache( + model, + config=config, + generation_config=generation_config + ) else: # If `layer_types` is not specified explicitly in the config or `sliding_window` is null, # there is only 1 type of layers, so export will use `StaticCache` by default. @@ -146,7 +150,7 @@ def export( dynamic_shapes=dynamic_shapes, strict=strict if strict is not None else True, ) - elif input_emebds: + elif inputs_embeds: exported_program = torch.export.export( self.model, args=(), @@ -299,7 +303,7 @@ def __init__( Args: model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching enabled and use a 'static' caching implementation. - config (`PreTrainedConfig`): The pretrained text config for the decoder model. + config (`PreTrainedConfig`): The pretrained text config for the model. generation_config (`GenerationConfig`): The generation config for the model. Raises: @@ -446,14 +450,16 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module): def __init__( self, model: PreTrainedModel, - max_batch_size: int = 1, - max_cache_len: int = 4096, + config: PretrainedConfig, + generation_config: GenerationConfig, ): """ Initializes the exportable module with `HybridCache`. Args: model (`PreTrainedModel`): The pretrained model to wrap. + config (`PreTrainedConfig`): The pretrained text config for the model. + generation_config (`GenerationConfig`): The generation config for the model. max_batch_size (int): Maximum batch size for the cache. max_cache_len (int): Maximum sequence length for the cache. 
@@ -462,17 +468,19 @@ def __init__( """ super().__init__() self.model = model + self.config = config + self.generation_config = generation_config # Verify the model is configured for HybridCache - if not self.model.config.use_cache: - raise AssertionError("Model must have caching enabled") + if not config.use_cache: + raise AssertionError("Model must have caching enabled.") # Initialize the HybridCache self.cache = HybridCache( - config=self.model.config, - max_batch_size=max_batch_size, - max_cache_len=max_cache_len, - device=self.model.device, + config=config, + max_batch_size=self.generation_config.cache_config.get("batch_size"), + max_cache_len=self.generation_config.cache_config.get("max_cache_len"), + device=self.generation_config.cache_config.get("device"), dtype=self.model.dtype, ) @@ -483,23 +491,31 @@ def __init__( def forward( self, - input_ids: torch.Tensor, - cache_position: torch.Tensor, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: torch.Tensor = None, ) -> torch.Tensor: """ Forward pass of the module, which is compatible with the ExecuTorch llm runner. Args: input_ids (`torch.Tensor`): Tensor representing current input token id to the module. + inputs_embeds (`Optional[torch.Tensor]`): Tensor representing current input embeddings to the module. cache_position (`torch.Tensor`): Tensor representing current input position in the cache. Returns: torch.Tensor: Logits output from the model. """ - batch_size = input_ids.shape[0] + batch_size = None + if input_ids: + batch_size = input_ids.shape[0] + elif inputs_embeds: + batch_size = inputs_embeds.shape[0] # Generate position_ids from cache_position - position_ids = cache_position.unsqueeze(0).expand(batch_size, -1) + position_ids = None + if batch_size: + position_ids = cache_position.unsqueeze(0).expand(batch_size, -1) # Forward pass with the model outputs = self.model( @@ -517,6 +533,8 @@ def forward( def convert_and_export_with_cache( model: PreTrainedModel, + config: PreTrainedConfig, + generation_config: GenerationConfig, example_input_ids: Optional[torch.Tensor] = None, example_cache_position: Optional[torch.Tensor] = None, dynamic_shapes: Optional[dict] = None, @@ -528,6 +546,8 @@ def convert_and_export_with_cache( Args: model (`PreTrainedModel`): The pretrained model to be exported. + config (`PreTrainedConfig`): The pretrained text config for the decoder model. + generation_config (`GenerationConfig`): The generation config for the model. example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`. example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`. dynamic_shapes(`Optional[dict]`): Dynamic shapes used by `torch.export`. 
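Taken together, the first two patches let the exportable decoder wrapper be built from an explicit text config and generation config and exported from either token ids or input embeddings. A condensed sketch of the call pattern that the updated call sites in the rest of this series use follows; the checkpoint name, cache sizes, and dummy inputs are placeholders, not values taken from the diff.

```python
import torch
from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

# Placeholder checkpoint; any decoder-only model with caching enabled follows the same pattern.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
model.eval()
generation_config = GenerationConfig(
    use_cache=True,
    cache_implementation="static",
    cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
)

wrapper = TorchExportableModuleForDecoderOnlyLM(
    model, config=model.config, generation_config=generation_config
)

# Export a graph driven by token ids ...
input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
exported_ids = wrapper.export(
    input_ids=input_ids,
    cache_position=torch.arange(input_ids.shape[-1], dtype=torch.long),
)

# ... or by input embeddings, the new path this series adds.
inputs_embeds = torch.randn(1, 3, model.config.hidden_size)
exported_embeds = wrapper.export(
    inputs_embeds=inputs_embeds,
    cache_position=torch.arange(inputs_embeds.shape[1], dtype=torch.long),
)
```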
From 35bb9a4c7536ef28a24fe965b65796eadfcf5150 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Thu, 31 Jul 2025 21:52:45 -0700 Subject: [PATCH 03/11] Manually change some callsites --- src/transformers/integrations/executorch.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index dafe6c1a4766..7b8a88b0c5dc 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -581,9 +581,13 @@ def convert_and_export_with_cache( if is_torch_greater_or_equal("2.6.0"): exported_program = torch.export.export( - TorchExportableModuleWithStaticCache(model), - args=(example_input_ids, example_cache_position), - kwargs={}, + TorchExportableModuleWithStaticCache( + model=model, + config=config, + generation_config=generation_config + ), + args=(), + kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position}, dynamic_shapes=dynamic_shapes, strict=strict if strict is not None else True, ) @@ -599,9 +603,13 @@ def convert_and_export_with_cache( # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release. exported_program = torch.export._trace._export( - TorchExportableModuleWithStaticCache(model), - args=(example_input_ids,), - kwargs={"cache_position": example_cache_position}, + TorchExportableModuleWithStaticCache( + model=model, + config=config, + generation_config=generation_config, + ), + args=(), + kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position}, pre_dispatch=False, strict=True, ) From 68f21b8db62bd596abc4bd16fc2c8efa9bc1a4f3 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Thu, 31 Jul 2025 22:11:11 -0700 Subject: [PATCH 04/11] AI changes the rest of the call sites --- tests/models/cohere2/test_modeling_cohere2.py | 4 +++- tests/models/exaone4/test_modeling_exaone4.py | 4 +++- tests/models/gemma/test_modeling_gemma.py | 8 ++++++-- tests/models/gemma2/test_modeling_gemma2.py | 17 +++++++++++++---- tests/models/gemma3/test_modeling_gemma3.py | 9 +++++++-- tests/models/llama/test_modeling_llama.py | 8 ++++++-- tests/models/olmo/test_modeling_olmo.py | 8 ++++++-- tests/models/olmo2/test_modeling_olmo2.py | 4 +++- tests/models/phi3/test_modeling_phi3.py | 8 ++++++-- tests/models/qwen2/test_modeling_qwen2.py | 8 ++++++-- tests/models/qwen3/test_modeling_qwen3.py | 8 ++++++-- tests/models/smollm3/test_modeling_smollm3.py | 4 +++- 12 files changed, 68 insertions(+), 22 deletions(-) diff --git a/tests/models/cohere2/test_modeling_cohere2.py b/tests/models/cohere2/test_modeling_cohere2.py index 71335c37075e..b21cda7d1535 100644 --- a/tests/models/cohere2/test_modeling_cohere2.py +++ b/tests/models/cohere2/test_modeling_cohere2.py @@ -275,7 +275,9 @@ def test_export_static_cache(self): max_new_tokens = 30 - prompt_token_ids.shape[-1] # Static Cache + export - exported_program = convert_and_export_with_cache(model) + exported_program = convert_and_export_with_cache( + model, config=model.config, generation_config=model.generation_config + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git 
a/tests/models/exaone4/test_modeling_exaone4.py b/tests/models/exaone4/test_modeling_exaone4.py index 4ac87ce900b5..f8c525e5e1c0 100644 --- a/tests/models/exaone4/test_modeling_exaone4.py +++ b/tests/models/exaone4/test_modeling_exaone4.py @@ -400,7 +400,9 @@ def test_export_static_cache(self): max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] # Static Cache + export - exported_program = convert_and_export_with_cache(model) + exported_program = convert_and_export_with_cache( + model, config=model.config, generation_config=model.generation_config + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 9276fb12b328..4ef26c9bc76f 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -459,8 +459,12 @@ def test_export_static_cache(self): # Static Cache + export from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export() + exportable_module = TorchExportableModuleForDecoderOnlyLM( + model, config=model.config, generation_config=model.generation_config + ) + exported_program = exportable_module.export( + input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device) + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 589e08dd1d98..53d03025a670 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -364,8 +364,12 @@ def test_export_static_cache(self): # Static Cache + export from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export() + exportable_module = TorchExportableModuleForDecoderOnlyLM( + model, config=model.config, generation_config=model.generation_config + ) + exported_program = exportable_module.export( + input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device) + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) @@ -388,8 +392,13 @@ def test_export_hybrid_cache(self): # Export + HybridCache model.eval() - exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export() + exportable_module = TorchExportableModuleForDecoderOnlyLM( + model, config=model.config, generation_config=model.generation_config + ) + exported_program = exportable_module.export( + input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device), + cache_position=torch.tensor([0], dtype=torch.long, device=model.device) + ) # Test generation with the exported model prompt = "What is the capital of France?" 
diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index 43ac57dbb566..e0ce6cd34fe4 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -808,8 +808,13 @@ def test_export_text_only_with_hybrid_cache(self): # Export + HybridCache model.eval() - exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export() + exportable_module = TorchExportableModuleForDecoderOnlyLM( + model, config=model.config, generation_config=model.generation_config + ) + exported_program = exportable_module.export( + input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device), + cache_position=torch.tensor([0], dtype=torch.long, device=model.device) + ) logging.info(f"\nExported program: {exported_program}") # Test generation with the exported model diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 136f76f48c9a..fc3343c3f74d 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -352,8 +352,12 @@ def test_export_static_cache(self): # Static Cache + export from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export() + exportable_module = TorchExportableModuleForDecoderOnlyLM( + model, config=model.config, generation_config=model.generation_config + ) + exported_program = exportable_module.export( + input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device) + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/models/olmo/test_modeling_olmo.py b/tests/models/olmo/test_modeling_olmo.py index 86913f254fbb..5e77bf14b8ae 100644 --- a/tests/models/olmo/test_modeling_olmo.py +++ b/tests/models/olmo/test_modeling_olmo.py @@ -383,8 +383,12 @@ def test_export_static_cache(self): # Static Cache + export from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export() + exportable_module = TorchExportableModuleForDecoderOnlyLM( + model, config=model.config, generation_config=model.generation_config + ) + exported_program = exportable_module.export( + input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device) + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/models/olmo2/test_modeling_olmo2.py b/tests/models/olmo2/test_modeling_olmo2.py index 20b0c49d3f0b..53971d3e6f0c 100644 --- a/tests/models/olmo2/test_modeling_olmo2.py +++ b/tests/models/olmo2/test_modeling_olmo2.py @@ -383,7 +383,9 @@ def test_export_static_cache(self): self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text) # Static Cache + export - exported_program = convert_and_export_with_cache(model) + exported_program = convert_and_export_with_cache( + model, config=model.config, generation_config=model.generation_config + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, 
prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py index 387eb6c4df79..f80ef773d47d 100644 --- a/tests/models/phi3/test_modeling_phi3.py +++ b/tests/models/phi3/test_modeling_phi3.py @@ -416,8 +416,12 @@ def test_export_static_cache(self): # Static Cache + export from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export() + exportable_module = TorchExportableModuleForDecoderOnlyLM( + model, config=model.config, generation_config=model.generation_config + ) + exported_program = exportable_module.export( + input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device) + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index d48226394c33..3508959732cf 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -299,11 +299,15 @@ def test_export_static_cache(self): # Static Cache + export from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - exportable_module = TorchExportableModuleForDecoderOnlyLM(model) + exportable_module = TorchExportableModuleForDecoderOnlyLM( + model, config=model.config, generation_config=model.generation_config + ) strict = version.parse(torch.__version__) != version.parse( "2.7.0" ) # Due to https://github.com/pytorch/pytorch/issues/150994 - exported_program = exportable_module.export(strict=strict) + exported_program = exportable_module.export( + input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), strict=strict + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/models/qwen3/test_modeling_qwen3.py b/tests/models/qwen3/test_modeling_qwen3.py index a37df40ed4a8..830ebead5a88 100644 --- a/tests/models/qwen3/test_modeling_qwen3.py +++ b/tests/models/qwen3/test_modeling_qwen3.py @@ -292,8 +292,12 @@ def test_export_static_cache(self): # Static Cache + export from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - exportable_module = TorchExportableModuleForDecoderOnlyLM(model) - exported_program = exportable_module.export(strict=strict) + exportable_module = TorchExportableModuleForDecoderOnlyLM( + model, config=model.config, generation_config=model.generation_config + ) + exported_program = exportable_module.export( + input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), strict=strict + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/models/smollm3/test_modeling_smollm3.py b/tests/models/smollm3/test_modeling_smollm3.py index f855e0b36a5f..80baaf9fd15b 100644 --- a/tests/models/smollm3/test_modeling_smollm3.py +++ b/tests/models/smollm3/test_modeling_smollm3.py @@ -219,7 +219,9 @@ def test_export_static_cache(self): # 
Static Cache + export strict = is_torch_greater_or_equal("2.7.0") # Due to https://github.com/pytorch/pytorch/issues/150994 - exported_program = convert_and_export_with_cache(model, strict=strict) + exported_program = convert_and_export_with_cache( + model, config=model.config, generation_config=model.generation_config, strict=strict + ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) From eda53a49622c7be28e26c9199d315e5ac8fdecc2 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Thu, 31 Jul 2025 22:14:47 -0700 Subject: [PATCH 05/11] Make either input_ids/inputs_embeds mandatory --- src/transformers/integrations/executorch.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 7b8a88b0c5dc..7d4fe0a68952 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -135,8 +135,8 @@ def export( "TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default." ) - if input_ids is not None and inputs_embeds is not None: - raise ValueError("Can't specify both input_ids and inputs_embeds.") + if not input_ids ^ inputs_embeds: + raise ValueError("Need to specify either input_ids or inputs_embeds.") example_cache_position = ( cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device) @@ -150,7 +150,7 @@ def export( dynamic_shapes=dynamic_shapes, strict=strict if strict is not None else True, ) - elif inputs_embeds: + else: # inputs_embeds exported_program = torch.export.export( self.model, args=(), @@ -158,16 +158,6 @@ def export( dynamic_shapes=dynamic_shapes, strict=strict if strict is not None else True, ) - else: - # No inputs specified, assume we are exporting with input_ids for legacy reasons. - example_input_ids = torch.tensor([[1]], dtype=torch.long, device=model_device) - exported_program = torch.export.export( - self.model, - args=(), - kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position}, - dynamic_shapes=dynamic_shapes, - strict=strict if strict is not None else True, - ) return exported_program From 62da12e5f4ca14d2cde05527628015b079d962ef Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Thu, 31 Jul 2025 22:25:03 -0700 Subject: [PATCH 06/11] Clean up --- src/transformers/integrations/executorch.py | 62 ++++++++++++++++----- tests/utils/test_cache_utils.py | 25 ++++++++- 2 files changed, 69 insertions(+), 18 deletions(-) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 7d4fe0a68952..447c47747a4f 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -16,6 +16,7 @@ import torch from ..cache_utils import DynamicCache, EncoderDecoderCache, HybridCache, StaticCache +from ..configuration_utils import PretrainedConfig from ..generation.configuration_utils import GenerationConfig from ..masking_utils import ( ALL_MASK_ATTENTION_FUNCTIONS, @@ -47,7 +48,7 @@ def __init__( Args: model (`PreTrainedModel`): The pretrained model to wrap. - config (`PreTrainedConfig`): The pretrained text config for the decoder model. 
+ config (`PretrainedConfig`): The pretrained text config for the decoder model. generation_config (`GenerationConfig`): The generation config for the model. max_batch_size (int): Maximum batch size for the cache. max_cache_len (int): Maximum sequence length for the cache. @@ -82,7 +83,7 @@ def forward( self, input_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, - cache_position: torch.Tensor, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Forward pass of the module, which is compatible with the ExecuTorch llm runner. @@ -114,16 +115,50 @@ def export( Args: input_ids (`Optional[torch.Tensor]`): - Tensor representing current input token id to the module. If this and inputs_embeds are not provided, a default tensor will be used. + Tensor representing current input token id to the module. Must specify either this or inputs_embeds. inputs_embeds (`Optional[torch.Tensor]`): - Tensor representing current input embeddings to the module. + Tensor representing current input embeddings to the module. Must specify either this or input_ids. cache_position (`Optional[torch.Tensor]`): Tensor representing current input position in the cache. If not provided, a default tensor will be used. dynamic_shapes (`Optional[dict]`): Dynamic shapes to use for export if specified. strict(`Optional[bool]`): Flag to instruct `torch.export` to use `torchdynamo`. + + Returns: + torch.export.ExportedProgram: The exported program that can be used for inference. + + Examples: + Export with input_ids: + ```python + # Prepare inputs + input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long, device=model.device) + cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model.device) + + # Export + exported = exportable_module.export( + input_ids=input_ids, + cache_position=cache_position + ) + ``` + + Export with inputs_embeds: + ```python + # Prepare embeddings + inputs_embeds = torch.randn(1, 3, 768, device=model.device) # batch_size=1, seq_len=3, hidden_size=768 + cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model.device) + + # Export + exported = exportable_module.export( + inputs_embeds=inputs_embeds, + cache_position=cache_position + ) + ``` """ + # Validate inputs early for fail-fast behavior + if not input_ids ^ inputs_embeds: + raise ValueError("Need to specify either input_ids or inputs_embeds.") + if hasattr(self.model, "base_model_prefix"): base = getattr(self.model, self.model.base_model_prefix, self.model) model_device = base.device @@ -135,9 +170,6 @@ def export( "TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default." ) - if not input_ids ^ inputs_embeds: - raise ValueError("Need to specify either input_ids or inputs_embeds.") - example_cache_position = ( cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device) ) @@ -293,7 +325,7 @@ def __init__( Args: model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching enabled and use a 'static' caching implementation. - config (`PreTrainedConfig`): The pretrained text config for the model. + config (`PretrainedConfig`): The pretrained text config for the model. generation_config (`GenerationConfig`): The generation config for the model. 
Raises: @@ -340,8 +372,8 @@ def __init__( def forward( self, input_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - cache_position: torch.Tensor = None, + inputs_embeds: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, ): """ Forward pass of the module, which is compatible with the ExecuTorch runtime. @@ -448,7 +480,7 @@ def __init__( Args: model (`PreTrainedModel`): The pretrained model to wrap. - config (`PreTrainedConfig`): The pretrained text config for the model. + config (`PretrainedConfig`): The pretrained text config for the model. generation_config (`GenerationConfig`): The generation config for the model. max_batch_size (int): Maximum batch size for the cache. max_cache_len (int): Maximum sequence length for the cache. @@ -482,8 +514,8 @@ def __init__( def forward( self, input_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - cache_position: torch.Tensor = None, + inputs_embeds: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Forward pass of the module, which is compatible with the ExecuTorch llm runner. @@ -523,7 +555,7 @@ def forward( def convert_and_export_with_cache( model: PreTrainedModel, - config: PreTrainedConfig, + config: PretrainedConfig, generation_config: GenerationConfig, example_input_ids: Optional[torch.Tensor] = None, example_cache_position: Optional[torch.Tensor] = None, @@ -536,7 +568,7 @@ def convert_and_export_with_cache( Args: model (`PreTrainedModel`): The pretrained model to be exported. - config (`PreTrainedConfig`): The pretrained text config for the decoder model. + config (`PretrainedConfig`): The pretrained text config for the decoder model. generation_config (`GenerationConfig`): The generation config for the model. example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`. example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`. 
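The `convert_and_export_with_cache` entry point documented just above now takes the text config and generation config explicitly as well. The model test updates earlier in this series drive it roughly as sketched below; the checkpoint and the `cache_config` values are illustrative placeholders, shown only to make the new signature concrete.

```python
import torch
from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.integrations.executorch import (
    TorchExportableModuleWithStaticCache,
    convert_and_export_with_cache,
)

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
model.eval()
# The static-cache wrapper reads batch size, cache length, and device from cache_config.
model.generation_config = GenerationConfig(
    use_cache=True,
    cache_implementation="static",
    cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
)

exported_program = convert_and_export_with_cache(
    model, config=model.config, generation_config=model.generation_config
)

prompt_token_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=5
)
```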
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 14b29344f190..02b89afe2ae6 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -813,7 +813,9 @@ def test_static_cache_exportability(self): from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - exportable_module = TorchExportableModuleForDecoderOnlyLM(model) + exportable_module = TorchExportableModuleForDecoderOnlyLM( + model, config=model.config, generation_config=model.generation_config + ) exported_program = exportable_module.export( input_ids=input_ids, cache_position=cache_position, @@ -841,8 +843,25 @@ def test_hybrid_cache_exportability(self): model.eval() max_batch_size = 1 max_cache_len = 23 - exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len) - exported_program = exportable_module.export() + # Create generation config for the hybrid cache model + from transformers.generation.configuration_utils import GenerationConfig + generation_config = GenerationConfig( + use_cache=True, + cache_implementation="hybrid", + max_length=max_cache_len, + cache_config={ + "batch_size": max_batch_size, + "max_cache_len": max_cache_len, + "device": model.device, + }, + ) + exportable_module = TorchExportableModuleForDecoderOnlyLM( + model, config=model.config, generation_config=generation_config + ) + exported_program = exportable_module.export( + input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device), + cache_position=torch.tensor([0], dtype=torch.long, device=model.device) + ) n_g_key_caches = n_g_value_caches = 0 for buffer_name, buffer in exported_program.named_buffers(): if buffer_name.startswith("key_cache"): From e03b3e0e63f9ea977b048588d113d7a6dcae0a90 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Thu, 31 Jul 2025 22:44:35 -0700 Subject: [PATCH 07/11] Ruff check --fix --- src/transformers/integrations/executorch.py | 21 +++++++-------------- tests/models/gemma/test_modeling_gemma.py | 3 ++- tests/models/gemma2/test_modeling_gemma2.py | 7 ++++--- tests/models/gemma3/test_modeling_gemma3.py | 4 ++-- tests/models/llama/test_modeling_llama.py | 3 ++- tests/models/olmo/test_modeling_olmo.py | 3 ++- tests/models/phi3/test_modeling_phi3.py | 3 ++- tests/models/qwen2/test_modeling_qwen2.py | 4 +++- tests/models/qwen3/test_modeling_qwen3.py | 4 +++- tests/utils/test_cache_utils.py | 3 ++- 10 files changed, 29 insertions(+), 26 deletions(-) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 447c47747a4f..ecd4cfba4d73 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -63,9 +63,7 @@ def __init__( if hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None: self.model = TorchExportableModuleWithHybridCache( - model, - config=config, - generation_config=generation_config + model, config=config, generation_config=generation_config ) else: # If `layer_types` is not specified explicitly in the config or `sliding_window` is null, @@ -134,21 +132,21 @@ def export( # Prepare inputs input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long, device=model.device) cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model.device) - + # Export exported = exportable_module.export( - input_ids=input_ids, + input_ids=input_ids, cache_position=cache_position ) ``` - + Export with inputs_embeds: 
```python # Prepare embeddings inputs_embeds = torch.randn(1, 3, 768, device=model.device) # batch_size=1, seq_len=3, hidden_size=768 cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model.device) - - # Export + + # Export exported = exportable_module.export( inputs_embeds=inputs_embeds, cache_position=cache_position @@ -368,7 +366,6 @@ def __init__( self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False) self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False) - def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -603,11 +600,7 @@ def convert_and_export_with_cache( if is_torch_greater_or_equal("2.6.0"): exported_program = torch.export.export( - TorchExportableModuleWithStaticCache( - model=model, - config=config, - generation_config=generation_config - ), + TorchExportableModuleWithStaticCache(model=model, config=config, generation_config=generation_config), args=(), kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position}, dynamic_shapes=dynamic_shapes, diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 4ef26c9bc76f..718a8c0f541d 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -463,7 +463,8 @@ def test_export_static_cache(self): model, config=model.config, generation_config=model.generation_config ) exported_program = exportable_module.export( - input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device) + input_ids=prompt_token_ids, + cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 53d03025a670..e00637f235eb 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -368,7 +368,8 @@ def test_export_static_cache(self): model, config=model.config, generation_config=model.generation_config ) exported_program = exportable_module.export( - input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device) + input_ids=prompt_token_ids, + cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens @@ -396,8 +397,8 @@ def test_export_hybrid_cache(self): model, config=model.config, generation_config=model.generation_config ) exported_program = exportable_module.export( - input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device), - cache_position=torch.tensor([0], dtype=torch.long, device=model.device) + input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device), + cache_position=torch.tensor([0], dtype=torch.long, device=model.device), ) # Test generation with the exported model diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index e0ce6cd34fe4..f256ed9e3a6e 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -812,8 +812,8 @@ 
def test_export_text_only_with_hybrid_cache(self): model, config=model.config, generation_config=model.generation_config ) exported_program = exportable_module.export( - input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device), - cache_position=torch.tensor([0], dtype=torch.long, device=model.device) + input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device), + cache_position=torch.tensor([0], dtype=torch.long, device=model.device), ) logging.info(f"\nExported program: {exported_program}") diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index fc3343c3f74d..14cb614b08d0 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -356,7 +356,8 @@ def test_export_static_cache(self): model, config=model.config, generation_config=model.generation_config ) exported_program = exportable_module.export( - input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device) + input_ids=prompt_token_ids, + cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens diff --git a/tests/models/olmo/test_modeling_olmo.py b/tests/models/olmo/test_modeling_olmo.py index 5e77bf14b8ae..67fc3e900b6a 100644 --- a/tests/models/olmo/test_modeling_olmo.py +++ b/tests/models/olmo/test_modeling_olmo.py @@ -387,7 +387,8 @@ def test_export_static_cache(self): model, config=model.config, generation_config=model.generation_config ) exported_program = exportable_module.export( - input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device) + input_ids=prompt_token_ids, + cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py index f80ef773d47d..55997aa26f7f 100644 --- a/tests/models/phi3/test_modeling_phi3.py +++ b/tests/models/phi3/test_modeling_phi3.py @@ -420,7 +420,8 @@ def test_export_static_cache(self): model, config=model.config, generation_config=model.generation_config ) exported_program = exportable_module.export( - input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device) + input_ids=prompt_token_ids, + cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index 3508959732cf..0e0f1d005f35 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -306,7 +306,9 @@ def test_export_static_cache(self): "2.7.0" ) # Due to https://github.com/pytorch/pytorch/issues/150994 exported_program = exportable_module.export( - input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), strict=strict + 
input_ids=prompt_token_ids, + cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), + strict=strict, ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens diff --git a/tests/models/qwen3/test_modeling_qwen3.py b/tests/models/qwen3/test_modeling_qwen3.py index 830ebead5a88..7fa18310ffc1 100644 --- a/tests/models/qwen3/test_modeling_qwen3.py +++ b/tests/models/qwen3/test_modeling_qwen3.py @@ -296,7 +296,9 @@ def test_export_static_cache(self): model, config=model.config, generation_config=model.generation_config ) exported_program = exportable_module.export( - input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), strict=strict + input_ids=prompt_token_ids, + cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), + strict=strict, ) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 02b89afe2ae6..ca46cd7d788c 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -845,6 +845,7 @@ def test_hybrid_cache_exportability(self): max_cache_len = 23 # Create generation config for the hybrid cache model from transformers.generation.configuration_utils import GenerationConfig + generation_config = GenerationConfig( use_cache=True, cache_implementation="hybrid", @@ -860,7 +861,7 @@ def test_hybrid_cache_exportability(self): ) exported_program = exportable_module.export( input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device), - cache_position=torch.tensor([0], dtype=torch.long, device=model.device) + cache_position=torch.tensor([0], dtype=torch.long, device=model.device), ) n_g_key_caches = n_g_value_caches = 0 for buffer_name, buffer in exported_program.named_buffers(): From bcb30e97c83bdcce29796e5841e002ea1f85d87d Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Fri, 1 Aug 2025 10:31:21 -0700 Subject: [PATCH 08/11] Fix test --- src/transformers/integrations/executorch.py | 30 ++++++++------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index ecd4cfba4d73..22416a4f1759 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -95,9 +95,9 @@ def forward( torch.Tensor: Logits output from the model. 
""" return self.model.forward( - cache_position=cache_position, input_ids=input_ids, inputs_embeds=inputs_embeds, + cache_position=cache_position, ) def export( @@ -153,8 +153,7 @@ def export( ) ``` """ - # Validate inputs early for fail-fast behavior - if not input_ids ^ inputs_embeds: + if not (input_ids is None) ^ (inputs_embeds is None): raise ValueError("Need to specify either input_ids or inputs_embeds.") if hasattr(self.model, "base_model_prefix"): @@ -172,15 +171,19 @@ def export( cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device) ) - if input_ids: + if input_ids is not None: + if cache_position is None: + cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long) exported_program = torch.export.export( self.model, args=(), - kwargs={"input_ids": input_ids, "cache_position": example_cache_position}, + kwargs={"input_ids": input_ids, "cache_position": cache_position}, dynamic_shapes=dynamic_shapes, strict=strict if strict is not None else True, ) else: # inputs_embeds + if cache_position is None: + cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long) exported_program = torch.export.export( self.model, args=(), @@ -398,8 +401,8 @@ def forward( outs = self.model( input_ids=input_ids, inputs_embeds=inputs_embeds, - attention_mask=None, cache_position=cache_position, + attention_mask=None, past_key_values=past_key_values, use_cache=True, ) @@ -525,25 +528,14 @@ def forward( Returns: torch.Tensor: Logits output from the model. """ - batch_size = None - if input_ids: - batch_size = input_ids.shape[0] - elif inputs_embeds: - batch_size = inputs_embeds.shape[0] - - # Generate position_ids from cache_position - position_ids = None - if batch_size: - position_ids = cache_position.unsqueeze(0).expand(batch_size, -1) - # Forward pass with the model outputs = self.model( input_ids=input_ids, + inputs_embeds=inputs_embeds, + cache_position=cache_position, attention_mask=None, - position_ids=position_ids, past_key_values=self.cache, use_cache=True, - cache_position=cache_position, ) # Return only the logits to simplify the export From 8ea782104134f3ddf50c949c8e537f9d62ed8a40 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Fri, 1 Aug 2025 11:22:44 -0700 Subject: [PATCH 09/11] pr review --- src/transformers/integrations/executorch.py | 25 ++-- tests/test_executorch.py | 125 ++++++++++++++++++++ 2 files changed, 137 insertions(+), 13 deletions(-) create mode 100644 tests/test_executorch.py diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 22416a4f1759..b7f1dfea09e6 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -38,10 +38,8 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module): def __init__( self, model: PreTrainedModel, - config: PretrainedConfig, - generation_config: GenerationConfig, - max_batch_size: int = 1, - max_cache_len: int = 4096, + config: Optional[PretrainedConfig] = None, + generation_config: Optional[GenerationConfig] = None, ): """ Initializes the exportable module with `HybridCache`. @@ -49,15 +47,20 @@ def __init__( Args: model (`PreTrainedModel`): The pretrained model to wrap. config (`PretrainedConfig`): The pretrained text config for the decoder model. + If not specified will try to resolve with the model's config. generation_config (`GenerationConfig`): The generation config for the model. 
- max_batch_size (int): Maximum batch size for the cache. - max_cache_len (int): Maximum sequence length for the cache. + If not specified will try to resolve with the model's generation config. Raises: ValueError: If the model is configured with a unsupported cache implementation. """ super().__init__() + if not config: + config = model.config + if not generation_config: + generation_config = model.generation_config + if not hasattr(config, "use_cache") or config.use_cache is False: raise ValueError("The model must have caching enabled to be performant.") @@ -167,13 +170,9 @@ def export( "TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default." ) - example_cache_position = ( - cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device) - ) - if input_ids is not None: if cache_position is None: - cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long) + cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, model=model_device) exported_program = torch.export.export( self.model, args=(), @@ -183,11 +182,11 @@ def export( ) else: # inputs_embeds if cache_position is None: - cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long) + cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long, model=model_device) exported_program = torch.export.export( self.model, args=(), - kwargs={"inputs_embeds": inputs_embeds, "cache_position": example_cache_position}, + kwargs={"inputs_embeds": inputs_embeds, "cache_position": cache_position}, dynamic_shapes=dynamic_shapes, strict=strict if strict is not None else True, ) diff --git a/tests/test_executorch.py b/tests/test_executorch.py new file mode 100644 index 000000000000..f36b0fc739e5 --- /dev/null +++ b/tests/test_executorch.py @@ -0,0 +1,125 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from transformers import AutoModelForCausalLM, set_seed +from transformers.generation.configuration_utils import GenerationConfig +from transformers.integrations.executorch import ( + TorchExportableModuleForDecoderOnlyLM, + TorchExportableModuleWithHybridCache, + TorchExportableModuleWithStaticCache, +) +from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3 +from transformers.testing_utils import require_torch + + +@require_torch +class ExecutorchTest(unittest.TestCase): + def setUp(self): + if not is_torch_greater_or_equal_than_2_3: + self.skipTest("torch >= 2.3 is required") + + set_seed(0) + self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM") + self.model.eval() + + # Create generation config with static cache for the model + self.model.generation_config = GenerationConfig( + use_cache=True, + cache_implementation="static", + cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"}, + ) + + self.input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + self.inputs_embeds = torch.randn(1, 3, self.model.config.hidden_size) + self.cache_position = torch.arange(3, dtype=torch.long) + + def test_static_cache_module_forward(self): + """Test TorchExportableModuleWithStaticCache forward with both input types""" + generation_config = GenerationConfig( + use_cache=True, + cache_implementation="static", + cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"}, + ) + + module = TorchExportableModuleWithStaticCache(self.model, self.model.config, generation_config) + + # Test with input_ids + eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits + wrapped_output_ids = module.forward(input_ids=self.input_ids, cache_position=self.cache_position) + torch.testing.assert_close(eager_output_ids, wrapped_output_ids, atol=1e-4, rtol=1e-4) + + # Test with inputs_embeds + eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits + wrapped_output_embeds = module.forward(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position) + torch.testing.assert_close(eager_output_embeds, wrapped_output_embeds, atol=1e-4, rtol=1e-4) + + def test_hybrid_cache_module_forward(self): + """Test TorchExportableModuleWithHybridCache forward with both input types""" + config = self.model.config + config.sliding_window = 16 + config.layer_types = ["full_attention"] * config.num_hidden_layers + + generation_config = GenerationConfig( + use_cache=True, + cache_implementation="hybrid", + cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"}, + ) + + module = TorchExportableModuleWithHybridCache(self.model, config, generation_config) + + # Test with input_ids + eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits + wrapped_output_ids = module.forward(input_ids=self.input_ids, cache_position=self.cache_position) + torch.testing.assert_close(eager_output_ids, wrapped_output_ids, atol=1e-4, rtol=1e-4) + + # Test with inputs_embeds + eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits + wrapped_output_embeds = module.forward(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position) + torch.testing.assert_close(eager_output_embeds, wrapped_output_embeds, atol=1e-4, rtol=1e-4) + + def test_decoder_only_lm_export_validation(self): + """Test TorchExportableModuleForDecoderOnlyLM export validation""" + module = 
TorchExportableModuleForDecoderOnlyLM(self.model) + + # Should fail with both input_ids and inputs_embeds + with self.assertRaises(ValueError): + module.export(input_ids=self.input_ids, inputs_embeds=self.inputs_embeds) + + # Should fail with neither + with self.assertRaises(ValueError): + module.export() + + def test_decoder_only_lm_export(self): + """Test TorchExportableModuleForDecoderOnlyLM export with both input types""" + module = TorchExportableModuleForDecoderOnlyLM(self.model) + + # Test export with input_ids + exported_program_ids = module.export(input_ids=self.input_ids, cache_position=self.cache_position) + eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits + exported_output_ids = exported_program_ids.module()( + input_ids=self.input_ids, cache_position=self.cache_position + ) + torch.testing.assert_close(eager_output_ids, exported_output_ids, atol=1e-4, rtol=1e-4) + + # Test export with inputs_embeds + exported_program_embeds = module.export(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position) + eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits + exported_output_embeds = exported_program_embeds.module()( + inputs_embeds=self.inputs_embeds, cache_position=self.cache_position + ) + torch.testing.assert_close(eager_output_embeds, exported_output_embeds, atol=1e-4, rtol=1e-4) From 14610ed8215ac3d6c4310a31700dab834e9c6a1a Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Mon, 4 Aug 2025 12:59:08 -0700 Subject: [PATCH 10/11] Revert config/generation_config changes --- src/transformers/integrations/executorch.py | 102 +++++++----------- tests/models/cohere2/test_modeling_cohere2.py | 4 +- tests/models/exaone4/test_modeling_exaone4.py | 4 +- tests/models/gemma/test_modeling_gemma.py | 4 +- tests/models/gemma2/test_modeling_gemma2.py | 8 +- tests/models/gemma3/test_modeling_gemma3.py | 4 +- tests/models/llama/test_modeling_llama.py | 4 +- tests/models/olmo/test_modeling_olmo.py | 4 +- tests/models/olmo2/test_modeling_olmo2.py | 4 +- tests/models/phi3/test_modeling_phi3.py | 4 +- tests/models/qwen2/test_modeling_qwen2.py | 4 +- tests/models/qwen3/test_modeling_qwen3.py | 4 +- tests/models/smollm3/test_modeling_smollm3.py | 4 +- tests/test_executorch.py | 8 +- tests/utils/test_cache_utils.py | 12 +-- 15 files changed, 61 insertions(+), 113 deletions(-) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index b7f1dfea09e6..536cb5994c8b 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -16,7 +16,6 @@ import torch from ..cache_utils import DynamicCache, EncoderDecoderCache, HybridCache, StaticCache -from ..configuration_utils import PretrainedConfig from ..generation.configuration_utils import GenerationConfig from ..masking_utils import ( ALL_MASK_ATTENTION_FUNCTIONS, @@ -38,43 +37,33 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module): def __init__( self, model: PreTrainedModel, - config: Optional[PretrainedConfig] = None, - generation_config: Optional[GenerationConfig] = None, ): """ Initializes the exportable module with `HybridCache`. Args: model (`PreTrainedModel`): The pretrained model to wrap. - config (`PretrainedConfig`): The pretrained text config for the decoder model. - If not specified will try to resolve with the model's config. - generation_config (`GenerationConfig`): The generation config for the model. 
-                If not specified will try to resolve with the model's generation config.

         Raises:
             ValueError: If the model is configured with a unsupported cache implementation.
         """
         super().__init__()

-        if not config:
-            config = model.config
-        if not generation_config:
-            generation_config = model.generation_config
+        config = model.config.get_text_config()
+        generation_config = model.generation_config

         if not hasattr(config, "use_cache") or config.use_cache is False:
             raise ValueError("The model must have caching enabled to be performant.")

         if hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None:
-            self.model = TorchExportableModuleWithHybridCache(
-                model, config=config, generation_config=generation_config
-            )
+            self.model = TorchExportableModuleWithHybridCache(model)
         else:
             # If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
             # there is only 1 type of layers, so export will use `StaticCache` by default.
             logging.info(
                 "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
             )
-            self.model = TorchExportableModuleWithStaticCache(model, config, generation_config)
+            self.model = TorchExportableModuleWithStaticCache(model)
         # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
         ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
         ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
@@ -171,25 +160,27 @@ def export(
             )

         if input_ids is not None:
-            if cache_position is None:
-                cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model_device)
-            exported_program = torch.export.export(
-                self.model,
-                args=(),
-                kwargs={"input_ids": input_ids, "cache_position": cache_position},
-                dynamic_shapes=dynamic_shapes,
-                strict=strict if strict is not None else True,
-            )
+            input_kwargs = {
+                "input_ids": input_ids,
+                "cache_position": cache_position
+                if cache_position is not None
+                else torch.arange(input_ids.shape[-1], dtype=torch.long, device=model_device),
+            }
         else:  # inputs_embeds
-            if cache_position is None:
-                cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model_device)
+            input_kwargs = {
+                "inputs_embeds": inputs_embeds,
+                "cache_position": cache_position
+                if cache_position is not None
+                else torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model_device),
+            }
+
+        exported_program = torch.export.export(
             self.model,
             args=(),
-            kwargs={"inputs_embeds": inputs_embeds, "cache_position": cache_position},
+            kwargs=input_kwargs,
             dynamic_shapes=dynamic_shapes,
             strict=strict if strict is not None else True,
         )

         return exported_program
@@ -316,17 +307,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-        config: PretrainedConfig,
-        generation_config: GenerationConfig,
     ):
         """
         Initializes the wrapper module with the pretrained model.

         Args:
             model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
-                enabled and use a 'static' caching implementation.
-            config (`PretrainedConfig`): The pretrained text config for the model.
-            generation_config (`GenerationConfig`): The generation config for the model.
+                enabled and use a 'static' caching implementation.
Raises: AssertionError: If the pretrained model does not have caching enabled or if it does @@ -334,6 +321,9 @@ def __init__( """ super().__init__() + config = model.config.get_text_config() + generation_config = model.generation_config + # Sanity checks if generation_config is None: raise AssertionError( @@ -354,13 +344,11 @@ def __init__( ) self.model = model - self.config = config - self.generation_config = generation_config self.static_cache = StaticCache( config=config, - max_batch_size=self.generation_config.cache_config.get("batch_size"), - max_cache_len=self.generation_config.cache_config.get("max_cache_len"), - device=self.generation_config.cache_config.get("device"), + max_batch_size=generation_config.cache_config.get("batch_size"), + max_cache_len=generation_config.cache_config.get("max_cache_len"), + device=generation_config.cache_config.get("device"), dtype=self.model.dtype, ) @@ -471,26 +459,20 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module): def __init__( self, model: PreTrainedModel, - config: PretrainedConfig, - generation_config: GenerationConfig, ): """ Initializes the exportable module with `HybridCache`. Args: model (`PreTrainedModel`): The pretrained model to wrap. - config (`PretrainedConfig`): The pretrained text config for the model. - generation_config (`GenerationConfig`): The generation config for the model. - max_batch_size (int): Maximum batch size for the cache. - max_cache_len (int): Maximum sequence length for the cache. Raises: AssertionError: If the model doesn't have the expected configuration for HybridCache. """ super().__init__() self.model = model - self.config = config - self.generation_config = generation_config + config = model.config.get_text_config() + generation_config = model.generation_config # Verify the model is configured for HybridCache if not config.use_cache: @@ -499,9 +481,9 @@ def __init__( # Initialize the HybridCache self.cache = HybridCache( config=config, - max_batch_size=self.generation_config.cache_config.get("batch_size"), - max_cache_len=self.generation_config.cache_config.get("max_cache_len"), - device=self.generation_config.cache_config.get("device"), + max_batch_size=generation_config.cache_config.get("batch_size"), + max_cache_len=generation_config.cache_config.get("max_cache_len"), + device=generation_config.cache_config.get("device"), dtype=self.model.dtype, ) @@ -543,8 +525,6 @@ def forward( def convert_and_export_with_cache( model: PreTrainedModel, - config: PretrainedConfig, - generation_config: GenerationConfig, example_input_ids: Optional[torch.Tensor] = None, example_cache_position: Optional[torch.Tensor] = None, dynamic_shapes: Optional[dict] = None, @@ -556,8 +536,6 @@ def convert_and_export_with_cache( Args: model (`PreTrainedModel`): The pretrained model to be exported. - config (`PretrainedConfig`): The pretrained text config for the decoder model. - generation_config (`GenerationConfig`): The generation config for the model. example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`. example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`. dynamic_shapes(`Optional[dict]`): Dynamic shapes used by `torch.export`. 
@@ -591,7 +569,7 @@ def convert_and_export_with_cache( if is_torch_greater_or_equal("2.6.0"): exported_program = torch.export.export( - TorchExportableModuleWithStaticCache(model=model, config=config, generation_config=generation_config), + TorchExportableModuleWithStaticCache(model), args=(), kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position}, dynamic_shapes=dynamic_shapes, @@ -609,11 +587,7 @@ def convert_and_export_with_cache( # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release. exported_program = torch.export._trace._export( - TorchExportableModuleWithStaticCache( - model=model, - config=config, - generation_config=generation_config, - ), + TorchExportableModuleWithStaticCache(model), args=(), kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position}, pre_dispatch=False, diff --git a/tests/models/cohere2/test_modeling_cohere2.py b/tests/models/cohere2/test_modeling_cohere2.py index b21cda7d1535..71335c37075e 100644 --- a/tests/models/cohere2/test_modeling_cohere2.py +++ b/tests/models/cohere2/test_modeling_cohere2.py @@ -275,9 +275,7 @@ def test_export_static_cache(self): max_new_tokens = 30 - prompt_token_ids.shape[-1] # Static Cache + export - exported_program = convert_and_export_with_cache( - model, config=model.config, generation_config=model.generation_config - ) + exported_program = convert_and_export_with_cache(model) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/models/exaone4/test_modeling_exaone4.py b/tests/models/exaone4/test_modeling_exaone4.py index f8c525e5e1c0..4ac87ce900b5 100644 --- a/tests/models/exaone4/test_modeling_exaone4.py +++ b/tests/models/exaone4/test_modeling_exaone4.py @@ -400,9 +400,7 @@ def test_export_static_cache(self): max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] # Static Cache + export - exported_program = convert_and_export_with_cache( - model, config=model.config, generation_config=model.generation_config - ) + exported_program = convert_and_export_with_cache(model) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 718a8c0f541d..8a1e2ea9eb7f 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -459,9 +459,7 @@ def test_export_static_cache(self): # Static Cache + export from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - exportable_module = TorchExportableModuleForDecoderOnlyLM( - model, config=model.config, generation_config=model.generation_config - ) + exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exported_program = exportable_module.export( input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index e00637f235eb..4a6c326f780c 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -364,9 +364,7 @@ def 
test_export_static_cache(self): # Static Cache + export from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - exportable_module = TorchExportableModuleForDecoderOnlyLM( - model, config=model.config, generation_config=model.generation_config - ) + exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exported_program = exportable_module.export( input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), @@ -393,9 +391,7 @@ def test_export_hybrid_cache(self): # Export + HybridCache model.eval() - exportable_module = TorchExportableModuleForDecoderOnlyLM( - model, config=model.config, generation_config=model.generation_config - ) + exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exported_program = exportable_module.export( input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device), cache_position=torch.tensor([0], dtype=torch.long, device=model.device), diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index f256ed9e3a6e..e1b444e2c546 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -808,9 +808,7 @@ def test_export_text_only_with_hybrid_cache(self): # Export + HybridCache model.eval() - exportable_module = TorchExportableModuleForDecoderOnlyLM( - model, config=model.config, generation_config=model.generation_config - ) + exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exported_program = exportable_module.export( input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device), cache_position=torch.tensor([0], dtype=torch.long, device=model.device), diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 14cb614b08d0..a6c2c3eee2b6 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -352,9 +352,7 @@ def test_export_static_cache(self): # Static Cache + export from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - exportable_module = TorchExportableModuleForDecoderOnlyLM( - model, config=model.config, generation_config=model.generation_config - ) + exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exported_program = exportable_module.export( input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), diff --git a/tests/models/olmo/test_modeling_olmo.py b/tests/models/olmo/test_modeling_olmo.py index 67fc3e900b6a..ea23f4e96fda 100644 --- a/tests/models/olmo/test_modeling_olmo.py +++ b/tests/models/olmo/test_modeling_olmo.py @@ -383,9 +383,7 @@ def test_export_static_cache(self): # Static Cache + export from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - exportable_module = TorchExportableModuleForDecoderOnlyLM( - model, config=model.config, generation_config=model.generation_config - ) + exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exported_program = exportable_module.export( input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), diff --git a/tests/models/olmo2/test_modeling_olmo2.py b/tests/models/olmo2/test_modeling_olmo2.py index 53971d3e6f0c..20b0c49d3f0b 100644 --- a/tests/models/olmo2/test_modeling_olmo2.py +++ b/tests/models/olmo2/test_modeling_olmo2.py @@ -383,9 +383,7 @@ def 
test_export_static_cache(self): self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text) # Static Cache + export - exported_program = convert_and_export_with_cache( - model, config=model.config, generation_config=model.generation_config - ) + exported_program = convert_and_export_with_cache(model) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens ) diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py index 55997aa26f7f..6887c0c6cd64 100644 --- a/tests/models/phi3/test_modeling_phi3.py +++ b/tests/models/phi3/test_modeling_phi3.py @@ -416,9 +416,7 @@ def test_export_static_cache(self): # Static Cache + export from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - exportable_module = TorchExportableModuleForDecoderOnlyLM( - model, config=model.config, generation_config=model.generation_config - ) + exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exported_program = exportable_module.export( input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index 0e0f1d005f35..51bd943cf916 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -299,9 +299,7 @@ def test_export_static_cache(self): # Static Cache + export from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - exportable_module = TorchExportableModuleForDecoderOnlyLM( - model, config=model.config, generation_config=model.generation_config - ) + exportable_module = TorchExportableModuleForDecoderOnlyLM(model) strict = version.parse(torch.__version__) != version.parse( "2.7.0" ) # Due to https://github.com/pytorch/pytorch/issues/150994 diff --git a/tests/models/qwen3/test_modeling_qwen3.py b/tests/models/qwen3/test_modeling_qwen3.py index 7fa18310ffc1..205228073e19 100644 --- a/tests/models/qwen3/test_modeling_qwen3.py +++ b/tests/models/qwen3/test_modeling_qwen3.py @@ -292,9 +292,7 @@ def test_export_static_cache(self): # Static Cache + export from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - exportable_module = TorchExportableModuleForDecoderOnlyLM( - model, config=model.config, generation_config=model.generation_config - ) + exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exported_program = exportable_module.export( input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), diff --git a/tests/models/smollm3/test_modeling_smollm3.py b/tests/models/smollm3/test_modeling_smollm3.py index 80baaf9fd15b..f855e0b36a5f 100644 --- a/tests/models/smollm3/test_modeling_smollm3.py +++ b/tests/models/smollm3/test_modeling_smollm3.py @@ -219,9 +219,7 @@ def test_export_static_cache(self): # Static Cache + export strict = is_torch_greater_or_equal("2.7.0") # Due to https://github.com/pytorch/pytorch/issues/150994 - exported_program = convert_and_export_with_cache( - model, config=model.config, generation_config=model.generation_config, strict=strict - ) + exported_program = convert_and_export_with_cache(model, strict=strict) ep_generated_ids = TorchExportableModuleWithStaticCache.generate( exported_program=exported_program, prompt_token_ids=prompt_token_ids, 
max_new_tokens=max_new_tokens ) diff --git a/tests/test_executorch.py b/tests/test_executorch.py index f36b0fc739e5..0e33253c08f1 100644 --- a/tests/test_executorch.py +++ b/tests/test_executorch.py @@ -56,7 +56,9 @@ def test_static_cache_module_forward(self): cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"}, ) - module = TorchExportableModuleWithStaticCache(self.model, self.model.config, generation_config) + # Set generation config on model + self.model.generation_config = generation_config + module = TorchExportableModuleWithStaticCache(self.model) # Test with input_ids eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits @@ -80,7 +82,9 @@ def test_hybrid_cache_module_forward(self): cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"}, ) - module = TorchExportableModuleWithHybridCache(self.model, config, generation_config) + # Set generation config on model + self.model.generation_config = generation_config + module = TorchExportableModuleWithHybridCache(self.model) # Test with input_ids eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index ca46cd7d788c..74b19395a67f 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -813,9 +813,7 @@ def test_static_cache_exportability(self): from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM - exportable_module = TorchExportableModuleForDecoderOnlyLM( - model, config=model.config, generation_config=model.generation_config - ) + exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exported_program = exportable_module.export( input_ids=input_ids, cache_position=cache_position, @@ -843,10 +841,10 @@ def test_hybrid_cache_exportability(self): model.eval() max_batch_size = 1 max_cache_len = 23 - # Create generation config for the hybrid cache model + # Set generation config on the model for the hybrid cache model from transformers.generation.configuration_utils import GenerationConfig - generation_config = GenerationConfig( + model.generation_config = GenerationConfig( use_cache=True, cache_implementation="hybrid", max_length=max_cache_len, @@ -856,9 +854,7 @@ def test_hybrid_cache_exportability(self): "device": model.device, }, ) - exportable_module = TorchExportableModuleForDecoderOnlyLM( - model, config=model.config, generation_config=generation_config - ) + exportable_module = TorchExportableModuleForDecoderOnlyLM(model) exported_program = exportable_module.export( input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device), cache_position=torch.tensor([0], dtype=torch.long, device=model.device), From e21cf4309f2e33eb567b38eae328040027310367 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 5 Aug 2025 09:38:45 -0700 Subject: [PATCH 11/11] Ruff check --- src/transformers/integrations/executorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 536cb5994c8b..7b7742e29386 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -50,7 +50,7 @@ def __init__( super().__init__() config = model.config.get_text_config() - generation_config = model.generation_config + _generation_config = model.generation_config if not hasattr(config, "use_cache") or config.use_cache is False: raise 
ValueError("The model must have caching enabled to be performant.")