Commit 62da12e

Clean up
1 parent eda53a4 commit 62da12e

File tree: 2 files changed (+69 -18 lines)

src/transformers/integrations/executorch.py (47 additions, 15 deletions)
```diff
@@ -16,6 +16,7 @@
 import torch
 
 from ..cache_utils import DynamicCache, EncoderDecoderCache, HybridCache, StaticCache
+from ..configuration_utils import PretrainedConfig
 from ..generation.configuration_utils import GenerationConfig
 from ..masking_utils import (
     ALL_MASK_ATTENTION_FUNCTIONS,
```
```diff
@@ -47,7 +48,7 @@ def __init__(
 
         Args:
             model (`PreTrainedModel`): The pretrained model to wrap.
-            config (`PreTrainedConfig`): The pretrained text config for the decoder model.
+            config (`PretrainedConfig`): The pretrained text config for the decoder model.
             generation_config (`GenerationConfig`): The generation config for the model.
             max_batch_size (int): Maximum batch size for the cache.
             max_cache_len (int): Maximum sequence length for the cache.
```
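
The docstring fix above matches the real class name: the class importable from `transformers.configuration_utils` is spelled `PretrainedConfig` (unlike `PreTrainedModel`, which capitalizes the T). A quick sanity check, assuming the transformers version this commit targets:

```python
# Sanity check (not part of the commit): the name added to the imports at
# the top of executorch.py resolves to the actual config base class.
from transformers.configuration_utils import PretrainedConfig

print(PretrainedConfig.__name__)  # -> "PretrainedConfig"
```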
```diff
@@ -82,7 +83,7 @@ def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-        cache_position: torch.Tensor,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Forward pass of the module, which is compatible with the ExecuTorch llm runner.
```
````diff
@@ -114,16 +115,50 @@ def export(
 
         Args:
             input_ids (`Optional[torch.Tensor]`):
-                Tensor representing current input token id to the module. If this and inputs_embeds are not provided, a default tensor will be used.
+                Tensor representing current input token id to the module. Must specify either this or inputs_embeds.
             inputs_embeds (`Optional[torch.Tensor]`):
-                Tensor representing current input embeddings to the module.
+                Tensor representing current input embeddings to the module. Must specify either this or input_ids.
             cache_position (`Optional[torch.Tensor]`):
                 Tensor representing current input position in the cache. If not provided, a default tensor will be used.
             dynamic_shapes (`Optional[dict]`):
                 Dynamic shapes to use for export if specified.
             strict(`Optional[bool]`):
                 Flag to instruct `torch.export` to use `torchdynamo`.
+
+        Returns:
+            torch.export.ExportedProgram: The exported program that can be used for inference.
+
+        Examples:
+            Export with input_ids:
+            ```python
+            # Prepare inputs
+            input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long, device=model.device)
+            cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model.device)
+
+            # Export
+            exported = exportable_module.export(
+                input_ids=input_ids,
+                cache_position=cache_position
+            )
+            ```
+
+            Export with inputs_embeds:
+            ```python
+            # Prepare embeddings
+            inputs_embeds = torch.randn(1, 3, 768, device=model.device)  # batch_size=1, seq_len=3, hidden_size=768
+            cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model.device)
+
+            # Export
+            exported = exportable_module.export(
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position
+            )
+            ```
         """
+        # Validate inputs early for fail-fast behavior
+        if not input_ids ^ inputs_embeds:
+            raise ValueError("Need to specify either input_ids or inputs_embeds.")
+
         if hasattr(self.model, "base_model_prefix"):
             base = getattr(self.model, self.model.base_model_prefix, self.model)
             model_device = base.device
@@ -135,9 +170,6 @@ def export(
                 "TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default."
             )
 
-        if not input_ids ^ inputs_embeds:
-            raise ValueError("Need to specify either input_ids or inputs_embeds.")
-
         example_cache_position = (
             cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device)
         )
```` 
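
One note on the guard that moved to the top of `export`: it applies Python's `^` operator directly to the tensor arguments. An identity comparison against `None` expresses the same exclusive-or intent without invoking tensor bitwise operators. A minimal sketch of that alternative, with a hypothetical helper name that is not part of this commit:

```python
import torch
from typing import Optional


def check_exactly_one(
    input_ids: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> None:
    # Hypothetical helper, not from this commit: exactly one of the two
    # inputs must be provided. Comparing identities against None keeps
    # the check independent of tensor dtypes and shapes.
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("Need to specify either input_ids or inputs_embeds.")


check_exactly_one(input_ids=torch.tensor([[1, 2, 3]]))  # passes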
```diff
@@ -293,7 +325,7 @@ def __init__(
         Args:
             model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
                 enabled and use a 'static' caching implementation.
-            config (`PreTrainedConfig`): The pretrained text config for the model.
+            config (`PretrainedConfig`): The pretrained text config for the model.
             generation_config (`GenerationConfig`): The generation config for the model.
 
         Raises:
```
```diff
@@ -340,8 +372,8 @@ def __init__(
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        cache_position: torch.Tensor = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
     ):
         """
         Forward pass of the module, which is compatible with the ExecuTorch runtime.
```
```diff
@@ -448,7 +480,7 @@ def __init__(
 
         Args:
             model (`PreTrainedModel`): The pretrained model to wrap.
-            config (`PreTrainedConfig`): The pretrained text config for the model.
+            config (`PretrainedConfig`): The pretrained text config for the model.
             generation_config (`GenerationConfig`): The generation config for the model.
             max_batch_size (int): Maximum batch size for the cache.
             max_cache_len (int): Maximum sequence length for the cache.
```
```diff
@@ -482,8 +514,8 @@ def __init__(
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        cache_position: torch.Tensor = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Forward pass of the module, which is compatible with the ExecuTorch llm runner.
```
```diff
@@ -523,7 +555,7 @@ def forward(
 
 def convert_and_export_with_cache(
     model: PreTrainedModel,
-    config: PreTrainedConfig,
+    config: PretrainedConfig,
     generation_config: GenerationConfig,
     example_input_ids: Optional[torch.Tensor] = None,
     example_cache_position: Optional[torch.Tensor] = None,
@@ -536,7 +568,7 @@ def convert_and_export_with_cache(
 
     Args:
         model (`PreTrainedModel`): The pretrained model to be exported.
-        config (`PreTrainedConfig`): The pretrained text config for the decoder model.
+        config (`PretrainedConfig`): The pretrained text config for the decoder model.
        generation_config (`GenerationConfig`): The generation config for the model.
        example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
        example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`.
```
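
For orientation, a call site for `convert_and_export_with_cache` with the signature shown above might look like the following sketch. The model checkpoint and generation settings are illustrative assumptions, not part of this commit:

```python
# Illustrative sketch: exporting a small decoder-only model with a
# static cache. Checkpoint choice and cache sizes are assumptions.
import torch
from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.integrations.executorch import convert_and_export_with_cache

model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
model.eval()

generation_config = GenerationConfig(
    use_cache=True,
    cache_implementation="static",
    cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
)

exported_program = convert_and_export_with_cache(
    model,
    config=model.config,
    generation_config=generation_config,
    example_input_ids=torch.tensor([[1]], dtype=torch.long),
    example_cache_position=torch.tensor([0], dtype=torch.long),
)
```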

tests/utils/test_cache_utils.py (22 additions, 3 deletions)
```diff
@@ -813,7 +813,9 @@ def test_static_cache_exportability(self):
 
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
 
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
         exported_program = exportable_module.export(
             input_ids=input_ids,
             cache_position=cache_position,
```
```diff
@@ -841,8 +843,25 @@ def test_hybrid_cache_exportability(self):
         model.eval()
         max_batch_size = 1
         max_cache_len = 23
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)
-        exported_program = exportable_module.export()
+        # Create generation config for the hybrid cache model
+        from transformers.generation.configuration_utils import GenerationConfig
+        generation_config = GenerationConfig(
+            use_cache=True,
+            cache_implementation="hybrid",
+            max_length=max_cache_len,
+            cache_config={
+                "batch_size": max_batch_size,
+                "max_cache_len": max_cache_len,
+                "device": model.device,
+            },
+        )
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device)
+        )
         n_g_key_caches = n_g_value_caches = 0
         for buffer_name, buffer in exported_program.named_buffers():
             if buffer_name.startswith("key_cache"):
```
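
Beyond inspecting named buffers as the test does, an `ExportedProgram` can also be exercised eagerly through its `.module()` view. A minimal sketch, assuming `exportable_module` was constructed as in the tests above:

```python
# Minimal sketch (assumptions: `exportable_module` built as in the tests,
# and forward returning logits). torch.export.ExportedProgram exposes a
# callable nn.Module via .module().
import torch

input_ids = torch.tensor([[1]], dtype=torch.long)
cache_position = torch.tensor([0], dtype=torch.long)

exported_program = exportable_module.export(
    input_ids=input_ids,
    cache_position=cache_position,
)

# Run one decode step through the exported graph.
logits = exported_program.module()(
    input_ids=input_ids,
    cache_position=cache_position,
)
print(logits.shape)  # typically (batch_size, seq_len, vocab_size)
```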
