1616import torch
1717
1818from ..cache_utils import DynamicCache , EncoderDecoderCache , HybridCache , StaticCache
19+ from ..configuration_utils import PretrainedConfig
1920from ..generation .configuration_utils import GenerationConfig
2021from ..masking_utils import (
2122 ALL_MASK_ATTENTION_FUNCTIONS ,
@@ -47,7 +48,7 @@ def __init__(
4748
4849 Args:
4950 model (`PreTrainedModel`): The pretrained model to wrap.
50- config (`PreTrainedConfig `): The pretrained text config for the decoder model.
51+ config (`PretrainedConfig`): The pretrained text config for the decoder model.
5152 generation_config (`GenerationConfig`): The generation config for the model.
5253 max_batch_size (int): Maximum batch size for the cache.
5354 max_cache_len (int): Maximum sequence length for the cache.
@@ -82,7 +83,7 @@ def forward(
8283 self ,
8384 input_ids : Optional [torch .Tensor ] = None ,
8485 inputs_embeds : Optional [torch .Tensor ] = None ,
85- cache_position : torch .Tensor ,
86+ cache_position : Optional [ torch .Tensor ] = None ,
8687 ) -> torch .Tensor :
8788 """
8889 Forward pass of the module, which is compatible with the ExecuTorch llm runner.
@@ -114,16 +115,50 @@ def export(
114115
115116 Args:
116117 input_ids (`Optional[torch.Tensor]`):
117- Tensor representing current input token id to the module. If this and inputs_embeds are not provided, a default tensor will be used .
118+ Tensor representing current input token id to the module. Must specify either this or inputs_embeds.
118119 inputs_embeds (`Optional[torch.Tensor]`):
119- Tensor representing current input embeddings to the module.
120+ Tensor representing current input embeddings to the module. Must specify either this or input_ids.
120121 cache_position (`Optional[torch.Tensor]`):
121122 Tensor representing current input position in the cache. If not provided, a default tensor will be used.
122123 dynamic_shapes (`Optional[dict]`):
123124 Dynamic shapes to use for export if specified.
124125 strict(`Optional[bool]`):
125126 Flag to instruct `torch.export` to use `torchdynamo`.
127+
128+ Returns:
129+ torch.export.ExportedProgram: The exported program that can be used for inference.
130+
131+ Examples:
132+ Export with input_ids:
133+ ```python
134+ # Prepare inputs
135+ input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long, device=model.device)
136+ cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model.device)
137+
138+ # Export
139+ exported = exportable_module.export(
140+ input_ids=input_ids,
141+ cache_position=cache_position
142+ )
143+ ```
144+
145+ Export with inputs_embeds:
146+ ```python
147+ # Prepare embeddings
148+ inputs_embeds = torch.randn(1, 3, 768, device=model.device) # batch_size=1, seq_len=3, hidden_size=768
149+ cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model.device)
150+
151+ # Export
152+ exported = exportable_module.export(
153+ inputs_embeds=inputs_embeds,
154+ cache_position=cache_position
155+ )
156+ ```
126157 """
158+ # Validate inputs early for fail-fast behavior
159+ if not input_ids ^ inputs_embeds :
160+ raise ValueError ("Need to specify either input_ids or inputs_embeds." )
161+
127162 if hasattr (self .model , "base_model_prefix" ):
128163 base = getattr (self .model , self .model .base_model_prefix , self .model )
129164 model_device = base .device
@@ -135,9 +170,6 @@ def export(
135170 "TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default."
136171 )
137172
138- if not input_ids ^ inputs_embeds :
139- raise ValueError ("Need to specify either input_ids or inputs_embeds." )
140-
141173 example_cache_position = (
142174 cache_position if cache_position is not None else torch .tensor ([0 ], dtype = torch .long , device = model_device )
143175 )
@@ -293,7 +325,7 @@ def __init__(
293325 Args:
294326 model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
295327 enabled and use a 'static' caching implementation.
296- config (`PreTrainedConfig `): The pretrained text config for the model.
328+ config (`PretrainedConfig`): The pretrained text config for the model.
297329 generation_config (`GenerationConfig`): The generation config for the model.
298330
299331 Raises:
@@ -340,8 +372,8 @@ def __init__(
340372 def forward (
341373 self ,
342374 input_ids : Optional [torch .LongTensor ] = None ,
343- inputs_embeds : Optional [torch .FloatTensor ] = None ,
344- cache_position : torch .Tensor = None ,
375+ inputs_embeds : Optional [torch .Tensor ] = None ,
376+ cache_position : Optional [ torch .Tensor ] = None ,
345377 ):
346378 """
347379 Forward pass of the module, which is compatible with the ExecuTorch runtime.
@@ -448,7 +480,7 @@ def __init__(
448480
449481 Args:
450482 model (`PreTrainedModel`): The pretrained model to wrap.
451- config (`PreTrainedConfig `): The pretrained text config for the model.
483+ config (`PretrainedConfig`): The pretrained text config for the model.
452484 generation_config (`GenerationConfig`): The generation config for the model.
453485 max_batch_size (int): Maximum batch size for the cache.
454486 max_cache_len (int): Maximum sequence length for the cache.
@@ -482,8 +514,8 @@ def __init__(
482514 def forward (
483515 self ,
484516 input_ids : Optional [torch .LongTensor ] = None ,
485- inputs_embeds : Optional [torch .FloatTensor ] = None ,
486- cache_position : torch .Tensor = None ,
517+ inputs_embeds : Optional [torch .Tensor ] = None ,
518+ cache_position : Optional [ torch .Tensor ] = None ,
487519 ) -> torch .Tensor :
488520 """
489521 Forward pass of the module, which is compatible with the ExecuTorch llm runner.
@@ -523,7 +555,7 @@ def forward(
523555
524556def convert_and_export_with_cache (
525557 model : PreTrainedModel ,
526- config : PreTrainedConfig ,
558+ config : PretrainedConfig ,
527559 generation_config : GenerationConfig ,
528560 example_input_ids : Optional [torch .Tensor ] = None ,
529561 example_cache_position : Optional [torch .Tensor ] = None ,
@@ -536,7 +568,7 @@ def convert_and_export_with_cache(
536568
537569 Args:
538570 model (`PreTrainedModel`): The pretrained model to be exported.
539- config (`PreTrainedConfig `): The pretrained text config for the decoder model.
571+ config (`PretrainedConfig`): The pretrained text config for the decoder model.
540572 generation_config (`GenerationConfig`): The generation config for the model.
541573 example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
542574 example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`.
0 commit comments