@@ -38,43 +38,33 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
3838 def __init__ (
3939 self ,
4040 model : PreTrainedModel ,
41- config : Optional [PretrainedConfig ] = None ,
42- generation_config : Optional [GenerationConfig ] = None ,
4341 ):
4442 """
4543 Initializes the exportable module with `HybridCache`.
4644
4745 Args:
4846 model (`PreTrainedModel`): The pretrained model to wrap.
49- config (`PretrainedConfig`): The pretrained text config for the decoder model.
50- If not specified will try to resolve with the model's config.
51- generation_config (`GenerationConfig`): The generation config for the model.
52- If not specified will try to resolve with the model's generation config.
5347
5448 Raises:
5549 ValueError: If the model is configured with a unsupported cache implementation.
5650 """
5751 super ().__init__ ()
5852
59- if not config :
60- config = model .config
61- if not generation_config :
62- generation_config = model .generation_config
53+ config = model .config .text_config ()
54+ generation_config = model .generation_config
6355
6456 if not hasattr (config , "use_cache" ) or config .use_cache is False :
6557 raise ValueError ("The model must have caching enabled to be performant." )
6658
6759 if hasattr (config , "layer_types" ) and getattr (config , "sliding_window" , None ) is not None :
68- self .model = TorchExportableModuleWithHybridCache (
69- model , config = config , generation_config = generation_config
70- )
60+ self .model = TorchExportableModuleWithHybridCache (model )
7161 else :
7262 # If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
7363 # there is only 1 type of layers, so export will use `StaticCache` by default.
7464 logging .info (
7565 "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
7666 )
77- self .model = TorchExportableModuleWithStaticCache (model , config , generation_config )
67+ self .model = TorchExportableModuleWithStaticCache (model )
7868 # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
7969 ALL_MASK_ATTENTION_FUNCTIONS .register ("sdpa_without_vmap" , sdpa_mask_without_vmap )
8070 ALL_ATTENTION_FUNCTIONS .register ("sdpa_without_vmap" , ALL_ATTENTION_FUNCTIONS ["sdpa" ])
@@ -316,24 +306,23 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
316306 def __init__ (
317307 self ,
318308 model : PreTrainedModel ,
319- config : PretrainedConfig ,
320- generation_config : GenerationConfig ,
321309 ):
322310 """
323311 Initializes the wrapper module with the pretrained model.
324312
325313 Args:
326314 model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
327- enabled and use a 'static' caching implementation.
328- config (`PretrainedConfig`): The pretrained text config for the model.
329- generation_config (`GenerationConfig`): The generation config for the model.
315+ enabled and use a 'static' caching implementation.
330316
331317 Raises:
332318 AssertionError: If the pretrained model does not have caching enabled or if it does
333319 not use a 'static' caching implementation in `model.generation_config`.
334320 """
335321 super ().__init__ ()
336322
323+ config = model .config .text_config ()
324+ generation_config = model .generation_config
325+
337326 # Sanity checks
338327 if generation_config is None :
339328 raise AssertionError (
@@ -354,13 +343,11 @@ def __init__(
354343 )
355344
356345 self .model = model
357- self .config = config
358- self .generation_config = generation_config
359346 self .static_cache = StaticCache (
360347 config = config ,
361- max_batch_size = self . generation_config .cache_config .get ("batch_size" ),
362- max_cache_len = self . generation_config .cache_config .get ("max_cache_len" ),
363- device = self . generation_config .cache_config .get ("device" ),
348+ max_batch_size = generation_config .cache_config .get ("batch_size" ),
349+ max_cache_len = generation_config .cache_config .get ("max_cache_len" ),
350+ device = generation_config .cache_config .get ("device" ),
364351 dtype = self .model .dtype ,
365352 )
366353
@@ -471,26 +458,20 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
471458 def __init__ (
472459 self ,
473460 model : PreTrainedModel ,
474- config : PretrainedConfig ,
475- generation_config : GenerationConfig ,
476461 ):
477462 """
478463 Initializes the exportable module with `HybridCache`.
479464
480465 Args:
481466 model (`PreTrainedModel`): The pretrained model to wrap.
482- config (`PretrainedConfig`): The pretrained text config for the model.
483- generation_config (`GenerationConfig`): The generation config for the model.
484- max_batch_size (int): Maximum batch size for the cache.
485- max_cache_len (int): Maximum sequence length for the cache.
486467
487468 Raises:
488469 AssertionError: If the model doesn't have the expected configuration for HybridCache.
489470 """
490471 super ().__init__ ()
491472 self .model = model
492- self . config = config
493- self . generation_config = generation_config
473+ config = model . config . text_config ()
474+ generation_config = model . generation_config
494475
495476 # Verify the model is configured for HybridCache
496477 if not config .use_cache :
@@ -499,9 +480,9 @@ def __init__(
499480 # Initialize the HybridCache
500481 self .cache = HybridCache (
501482 config = config ,
502- max_batch_size = self . generation_config .cache_config .get ("batch_size" ),
503- max_cache_len = self . generation_config .cache_config .get ("max_cache_len" ),
504- device = self . generation_config .cache_config .get ("device" ),
483+ max_batch_size = generation_config .cache_config .get ("batch_size" ),
484+ max_cache_len = generation_config .cache_config .get ("max_cache_len" ),
485+ device = generation_config .cache_config .get ("device" ),
505486 dtype = self .model .dtype ,
506487 )
507488
@@ -543,8 +524,6 @@ def forward(
543524
544525def convert_and_export_with_cache (
545526 model : PreTrainedModel ,
546- config : PretrainedConfig ,
547- generation_config : GenerationConfig ,
548527 example_input_ids : Optional [torch .Tensor ] = None ,
549528 example_cache_position : Optional [torch .Tensor ] = None ,
550529 dynamic_shapes : Optional [dict ] = None ,
@@ -556,8 +535,6 @@ def convert_and_export_with_cache(
556535
557536 Args:
558537 model (`PreTrainedModel`): The pretrained model to be exported.
559- config (`PretrainedConfig`): The pretrained text config for the decoder model.
560- generation_config (`GenerationConfig`): The generation config for the model.
561538 example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
562539 example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`.
563540 dynamic_shapes(`Optional[dict]`): Dynamic shapes used by `torch.export`.
@@ -591,7 +568,7 @@ def convert_and_export_with_cache(
591568
592569 if is_torch_greater_or_equal ("2.6.0" ):
593570 exported_program = torch .export .export (
594- TorchExportableModuleWithStaticCache (model = model , config = config , generation_config = generation_config ),
571+ TorchExportableModuleWithStaticCache (model ),
595572 args = (),
596573 kwargs = {"input_ids" : example_input_ids , "cache_position" : example_cache_position },
597574 dynamic_shapes = dynamic_shapes ,
@@ -609,11 +586,7 @@ def convert_and_export_with_cache(
609586 # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
610587 # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
611588 exported_program = torch .export ._trace ._export (
612- TorchExportableModuleWithStaticCache (
613- model = model ,
614- config = config ,
615- generation_config = generation_config ,
616- ),
589+ TorchExportableModuleWithStaticCache (model ),
617590 args = (),
618591 kwargs = {"input_ids" : example_input_ids , "cache_position" : example_cache_position },
619592 pre_dispatch = False ,
0 commit comments