@@ -38,43 +38,33 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-        config: Optional[PretrainedConfig] = None,
-        generation_config: Optional[GenerationConfig] = None,
     ):
         """
         Initializes the exportable module with `HybridCache`.

         Args:
             model (`PreTrainedModel`): The pretrained model to wrap.
-            config (`PretrainedConfig`): The pretrained text config for the decoder model.
-                If not specified will try to resolve with the model's config.
-            generation_config (`GenerationConfig`): The generation config for the model.
-                If not specified will try to resolve with the model's generation config.

         Raises:
             ValueError: If the model is configured with an unsupported cache implementation.
5650 """
5751 super ().__init__ ()
5852
59- if not config :
60- config = model .config
61- if not generation_config :
62- generation_config = model .generation_config
53+ config = model .config .get_text_config ()
54+ generation_config = model .generation_config
6355
6456 if not hasattr (config , "use_cache" ) or config .use_cache is False :
6557 raise ValueError ("The model must have caching enabled to be performant." )
6658
6759 if hasattr (config , "layer_types" ) and getattr (config , "sliding_window" , None ) is not None :
68- self .model = TorchExportableModuleWithHybridCache (
69- model , config = config , generation_config = generation_config
70- )
60+ self .model = TorchExportableModuleWithHybridCache (model )
7161 else :
7262 # If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
7363 # there is only 1 type of layers, so export will use `StaticCache` by default.
7464 logging .info (
7565 "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
7666 )
77- self .model = TorchExportableModuleWithStaticCache (model , config , generation_config )
67+ self .model = TorchExportableModuleWithStaticCache (model )
7868 # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
7969 ALL_MASK_ATTENTION_FUNCTIONS .register ("sdpa_without_vmap" , sdpa_mask_without_vmap )
8070 ALL_ATTENTION_FUNCTIONS .register ("sdpa_without_vmap" , ALL_ATTENTION_FUNCTIONS ["sdpa" ])
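
With this change the wrapper takes only the model; the text config and generation config are resolved internally from `model.config.get_text_config()` and `model.generation_config`. A minimal usage sketch under assumed settings (the checkpoint name and the cache_config values are illustrative, not taken from this diff):

    import torch
    from transformers import AutoModelForCausalLM, GenerationConfig
    from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

    # Assumption: any decoder-only checkpoint with caching enabled works; the name is illustrative.
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
    model.generation_config = GenerationConfig(
        cache_implementation="static",
        cache_config={"batch_size": 1, "max_cache_len": 128, "device": "cpu"},  # assumed sizes
    )
    # config/generation_config arguments are gone; both are read from `model` itself.
    exportable = TorchExportableModuleForDecoderOnlyLM(model)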
@@ -171,25 +161,23 @@ def export(
             )

         if input_ids is not None:
-            if cache_position is None:
-                cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model_device)
-            exported_program = torch.export.export(
-                self.model,
-                args=(),
-                kwargs={"input_ids": input_ids, "cache_position": cache_position},
-                dynamic_shapes=dynamic_shapes,
-                strict=strict if strict is not None else True,
-            )
+            input_kwargs = {
+                "input_ids": input_ids,
+                "cache_position": cache_position if cache_position is not None else torch.arange(input_ids.shape[-1], dtype=torch.long, device=model_device),
+            }
         else:  # inputs_embeds
-            if cache_position is None:
-                cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model_device)
-            exported_program = torch.export.export(
-                self.model,
-                args=(),
-                kwargs={"inputs_embeds": inputs_embeds, "cache_position": cache_position},
-                dynamic_shapes=dynamic_shapes,
-                strict=strict if strict is not None else True,
-            )
+            input_kwargs = {
+                "inputs_embeds": inputs_embeds,
+                "cache_position": cache_position if cache_position is not None else torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model_device),
+            }
+
+        exported_program = torch.export.export(
+            self.model,
+            args=(),
+            kwargs=input_kwargs,
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
+        )

         return exported_program

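
Both input branches now build `input_kwargs` and fall through to a single `torch.export.export` call, and `cache_position` defaults to `torch.arange(seq_len)` on the model's device when omitted. A sketch of a call, reusing the `exportable` wrapper from the sketch above (assumes a recent torch with `torch.export`):

    example_input_ids = torch.tensor([[1]], dtype=torch.long)
    # cache_position is left out here, so the default arange(seq_len) path is exercised.
    exported = exportable.export(input_ids=example_input_ids)
    print(type(exported))  # expected: torch.export.ExportedProgram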
@@ -316,24 +304,23 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-        config: PretrainedConfig,
-        generation_config: GenerationConfig,
     ):
         """
         Initializes the wrapper module with the pretrained model.

         Args:
             model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
-                enabled and use a 'static' caching implementation.
-            config (`PretrainedConfig`): The pretrained text config for the model.
-            generation_config (`GenerationConfig`): The generation config for the model.
+                enabled and use a 'static' caching implementation.

         Raises:
             AssertionError: If the pretrained model does not have caching enabled or if it does
                 not use a 'static' caching implementation in `model.generation_config`.
         """
         super().__init__()

+        config = model.config.get_text_config()
+        generation_config = model.generation_config
+
         # Sanity checks
         if generation_config is None:
             raise AssertionError(
@@ -354,13 +341,11 @@ def __init__(
             )

         self.model = model
-        self.config = config
-        self.generation_config = generation_config
         self.static_cache = StaticCache(
             config=config,
-            max_batch_size=self.generation_config.cache_config.get("batch_size"),
-            max_cache_len=self.generation_config.cache_config.get("max_cache_len"),
-            device=self.generation_config.cache_config.get("device"),
+            max_batch_size=generation_config.cache_config.get("batch_size"),
+            max_cache_len=generation_config.cache_config.get("max_cache_len"),
+            device=generation_config.cache_config.get("device"),
             dtype=self.model.dtype,
         )

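
The `StaticCache` held by the wrapper is now sized entirely from `model.generation_config.cache_config`. Roughly, it is equivalent to the standalone construction below (the cache_config values are whatever was set on the model, e.g. the assumed ones from the first sketch):

    from transformers import StaticCache

    cache_cfg = model.generation_config.cache_config  # e.g. {"batch_size": 1, "max_cache_len": 128, "device": "cpu"}
    static_cache = StaticCache(
        config=model.config.get_text_config(),
        max_batch_size=cache_cfg.get("batch_size"),
        max_cache_len=cache_cfg.get("max_cache_len"),
        device=cache_cfg.get("device"),
        dtype=model.dtype,
    )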
@@ -471,26 +456,20 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-        config: PretrainedConfig,
-        generation_config: GenerationConfig,
     ):
         """
         Initializes the exportable module with `HybridCache`.

         Args:
             model (`PreTrainedModel`): The pretrained model to wrap.
-            config (`PretrainedConfig`): The pretrained text config for the model.
-            generation_config (`GenerationConfig`): The generation config for the model.
-            max_batch_size (int): Maximum batch size for the cache.
-            max_cache_len (int): Maximum sequence length for the cache.

         Raises:
             AssertionError: If the model doesn't have the expected configuration for HybridCache.
         """
         super().__init__()
         self.model = model
-        self.config = config
-        self.generation_config = generation_config
+        config = model.config.get_text_config()
+        generation_config = model.generation_config

         # Verify the model is configured for HybridCache
         if not config.use_cache:
@@ -499,9 +478,9 @@ def __init__(
         # Initialize the HybridCache
         self.cache = HybridCache(
             config=config,
-            max_batch_size=self.generation_config.cache_config.get("batch_size"),
-            max_cache_len=self.generation_config.cache_config.get("max_cache_len"),
-            device=self.generation_config.cache_config.get("device"),
+            max_batch_size=generation_config.cache_config.get("batch_size"),
+            max_cache_len=generation_config.cache_config.get("max_cache_len"),
+            device=generation_config.cache_config.get("device"),
             dtype=self.model.dtype,
         )

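
`TorchExportableModuleForDecoderOnlyLM` routes to this hybrid wrapper only when the text config declares `layer_types` and a non-null `sliding_window`. A small check mirroring that dispatch, reusing `model` from the earlier sketches:

    text_config = model.config.get_text_config()
    uses_hybrid = hasattr(text_config, "layer_types") and getattr(text_config, "sliding_window", None) is not None
    print("HybridCache" if uses_hybrid else "StaticCache")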
@@ -543,8 +522,6 @@ def forward(

 def convert_and_export_with_cache(
     model: PreTrainedModel,
-    config: PretrainedConfig,
-    generation_config: GenerationConfig,
     example_input_ids: Optional[torch.Tensor] = None,
     example_cache_position: Optional[torch.Tensor] = None,
     dynamic_shapes: Optional[dict] = None,
@@ -556,8 +533,6 @@ def convert_and_export_with_cache(

     Args:
         model (`PreTrainedModel`): The pretrained model to be exported.
-        config (`PretrainedConfig`): The pretrained text config for the decoder model.
-        generation_config (`GenerationConfig`): The generation config for the model.
         example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
         example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`.
         dynamic_shapes (`Optional[dict]`): Dynamic shapes used by `torch.export`.
@@ -591,7 +566,7 @@ def convert_and_export_with_cache(

     if is_torch_greater_or_equal("2.6.0"):
         exported_program = torch.export.export(
-            TorchExportableModuleWithStaticCache(model=model, config=config, generation_config=generation_config),
+            TorchExportableModuleWithStaticCache(model),
             args=(),
             kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
             dynamic_shapes=dynamic_shapes,
@@ -609,11 +584,7 @@ def convert_and_export_with_cache(
         # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
         # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
         exported_program = torch.export._trace._export(
-            TorchExportableModuleWithStaticCache(
-                model=model,
-                config=config,
-                generation_config=generation_config,
-            ),
+            TorchExportableModuleWithStaticCache(model),
             args=(),
             kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
             pre_dispatch=False,
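
With `config` and `generation_config` dropped from its signature, a call to `convert_and_export_with_cache` now looks roughly like this (the example tensors are illustrative; the model is assumed to carry a static `cache_config` on its generation config, as in the first sketch):

    from transformers.integrations.executorch import convert_and_export_with_cache

    exported_program = convert_and_export_with_cache(
        model,
        example_input_ids=torch.tensor([[1]], dtype=torch.long),
        example_cache_position=torch.tensor([0], dtype=torch.long),
    )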