
Commit a2c29fb

Revert config/generation_config changes
1 parent 8ea7821 commit a2c29fb

15 files changed: +41 -94 lines changed

src/transformers/integrations/executorch.py

Lines changed: 18 additions & 45 deletions
@@ -38,43 +38,33 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-        config: Optional[PretrainedConfig] = None,
-        generation_config: Optional[GenerationConfig] = None,
     ):
         """
         Initializes the exportable module with `HybridCache`.

         Args:
             model (`PreTrainedModel`): The pretrained model to wrap.
-            config (`PretrainedConfig`): The pretrained text config for the decoder model.
-                If not specified will try to resolve with the model's config.
-            generation_config (`GenerationConfig`): The generation config for the model.
-                If not specified will try to resolve with the model's generation config.

         Raises:
             ValueError: If the model is configured with a unsupported cache implementation.
         """
         super().__init__()

-        if not config:
-            config = model.config
-        if not generation_config:
-            generation_config = model.generation_config
+        config = model.config.text_config()
+        generation_config = model.generation_config

         if not hasattr(config, "use_cache") or config.use_cache is False:
             raise ValueError("The model must have caching enabled to be performant.")

         if hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None:
-            self.model = TorchExportableModuleWithHybridCache(
-                model, config=config, generation_config=generation_config
-            )
+            self.model = TorchExportableModuleWithHybridCache(model)
         else:
             # If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
             # there is only 1 type of layers, so export will use `StaticCache` by default.
             logging.info(
                 "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
             )
-            self.model = TorchExportableModuleWithStaticCache(model, config, generation_config)
+            self.model = TorchExportableModuleWithStaticCache(model)
         # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
         ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
         ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
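
As the hunk above shows, the wrapper now picks the cache type on its own: `TorchExportableModuleWithHybridCache` when the text config declares `layer_types` and a non-null `sliding_window`, `TorchExportableModuleWithStaticCache` otherwise. A minimal sketch of that check, illustrative only (the checkpoint name is a placeholder, and `AutoConfig` is used here merely to obtain a config object):

from transformers import AutoConfig

# Placeholder checkpoint; any decoder-only config works for this check.
config = AutoConfig.from_pretrained("google/gemma-2-2b")

# Mirrors the condition in TorchExportableModuleForDecoderOnlyLM.__init__ above.
uses_hybrid_cache = hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None
print("HybridCache" if uses_hybrid_cache else "StaticCache")
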
@@ -316,24 +306,23 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-        config: PretrainedConfig,
-        generation_config: GenerationConfig,
     ):
         """
         Initializes the wrapper module with the pretrained model.

         Args:
             model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
-            enabled and use a 'static' caching implementation.
-            config (`PretrainedConfig`): The pretrained text config for the model.
-            generation_config (`GenerationConfig`): The generation config for the model.
+                enabled and use a 'static' caching implementation.

         Raises:
             AssertionError: If the pretrained model does not have caching enabled or if it does
                 not use a 'static' caching implementation in `model.generation_config`.
         """
         super().__init__()

+        config = model.config.text_config()
+        generation_config = model.generation_config
+
         # Sanity checks
         if generation_config is None:
             raise AssertionError(
@@ -354,13 +343,11 @@ def __init__(
             )

         self.model = model
-        self.config = config
-        self.generation_config = generation_config
         self.static_cache = StaticCache(
             config=config,
-            max_batch_size=self.generation_config.cache_config.get("batch_size"),
-            max_cache_len=self.generation_config.cache_config.get("max_cache_len"),
-            device=self.generation_config.cache_config.get("device"),
+            max_batch_size=generation_config.cache_config.get("batch_size"),
+            max_cache_len=generation_config.cache_config.get("max_cache_len"),
+            device=generation_config.cache_config.get("device"),
             dtype=self.model.dtype,
         )

@@ -471,26 +458,20 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-        config: PretrainedConfig,
-        generation_config: GenerationConfig,
     ):
         """
         Initializes the exportable module with `HybridCache`.

         Args:
             model (`PreTrainedModel`): The pretrained model to wrap.
-            config (`PretrainedConfig`): The pretrained text config for the model.
-            generation_config (`GenerationConfig`): The generation config for the model.
-            max_batch_size (int): Maximum batch size for the cache.
-            max_cache_len (int): Maximum sequence length for the cache.

         Raises:
             AssertionError: If the model doesn't have the expected configuration for HybridCache.
         """
         super().__init__()
         self.model = model
-        self.config = config
-        self.generation_config = generation_config
+        config = model.config.text_config()
+        generation_config = model.generation_config

         # Verify the model is configured for HybridCache
         if not config.use_cache:
@@ -499,9 +480,9 @@ def __init__(
         # Initialize the HybridCache
         self.cache = HybridCache(
             config=config,
-            max_batch_size=self.generation_config.cache_config.get("batch_size"),
-            max_cache_len=self.generation_config.cache_config.get("max_cache_len"),
-            device=self.generation_config.cache_config.get("device"),
+            max_batch_size=generation_config.cache_config.get("batch_size"),
+            max_cache_len=generation_config.cache_config.get("max_cache_len"),
+            device=generation_config.cache_config.get("device"),
             dtype=self.model.dtype,
         )

@@ -543,8 +524,6 @@ def forward(

 def convert_and_export_with_cache(
     model: PreTrainedModel,
-    config: PretrainedConfig,
-    generation_config: GenerationConfig,
     example_input_ids: Optional[torch.Tensor] = None,
     example_cache_position: Optional[torch.Tensor] = None,
     dynamic_shapes: Optional[dict] = None,
@@ -556,8 +535,6 @@ def convert_and_export_with_cache(

     Args:
         model (`PreTrainedModel`): The pretrained model to be exported.
-        config (`PretrainedConfig`): The pretrained text config for the decoder model.
-        generation_config (`GenerationConfig`): The generation config for the model.
         example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
         example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`.
         dynamic_shapes(`Optional[dict]`): Dynamic shapes used by `torch.export`.
@@ -591,7 +568,7 @@ def convert_and_export_with_cache(

         if is_torch_greater_or_equal("2.6.0"):
             exported_program = torch.export.export(
-                TorchExportableModuleWithStaticCache(model=model, config=config, generation_config=generation_config),
+                TorchExportableModuleWithStaticCache(model),
                 args=(),
                 kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
                 dynamic_shapes=dynamic_shapes,
@@ -609,11 +586,7 @@ def convert_and_export_with_cache(
             # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
             # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
             exported_program = torch.export._trace._export(
-                TorchExportableModuleWithStaticCache(
-                    model=model,
-                    config=config,
-                    generation_config=generation_config,
-                ),
+                TorchExportableModuleWithStaticCache(model),
                 args=(),
                 kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
                 pre_dispatch=False,

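With the config/generation_config parameters reverted, callers pass only the model; everything else is read from `model.config` and `model.generation_config`. A minimal usage sketch of the simplified entry points (the checkpoint name and the `cache_config` values are placeholders, and populating `generation_config.cache_config` this way is an assumption about how batch size, cache length, and device reach the constructors above; the call patterns mirror the test updates below):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.integrations.executorch import (
    TorchExportableModuleForDecoderOnlyLM,
    TorchExportableModuleWithStaticCache,
    convert_and_export_with_cache,
)

model_id = "meta-llama/Llama-3.2-1B"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()
# Assumption: the constructors above read batch_size/max_cache_len/device from
# model.generation_config.cache_config, so it must be populated beforehand.
model.generation_config.cache_config = {"batch_size": 1, "max_cache_len": 64, "device": "cpu"}

tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt_token_ids = tokenizer("Hello", return_tensors="pt").input_ids

# Wrapper path, as in the gemma/llama/olmo/phi3 tests below: only the model is passed.
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
    input_ids=prompt_token_ids,
    cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)

# One-shot helper path, as in the cohere2/exaone4/olmo2 tests below.
exported_program = convert_and_export_with_cache(model)
generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=16
)
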
tests/models/cohere2/test_modeling_cohere2.py

Lines changed: 1 addition & 3 deletions
@@ -275,9 +275,7 @@ def test_export_static_cache(self):
         max_new_tokens = 30 - prompt_token_ids.shape[-1]

         # Static Cache + export
-        exported_program = convert_and_export_with_cache(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exported_program = convert_and_export_with_cache(model)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/exaone4/test_modeling_exaone4.py

Lines changed: 1 addition & 3 deletions
@@ -400,9 +400,7 @@ def test_export_static_cache(self):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

         # Static Cache + export
-        exported_program = convert_and_export_with_cache(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exported_program = convert_and_export_with_cache(model)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/gemma/test_modeling_gemma.py

Lines changed: 1 addition & 3 deletions
@@ -459,9 +459,7 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

-        exportable_module = TorchExportableModuleForDecoderOnlyLM(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
             input_ids=prompt_token_ids,
             cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),

tests/models/gemma2/test_modeling_gemma2.py

Lines changed: 2 additions & 6 deletions
@@ -364,9 +364,7 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

-        exportable_module = TorchExportableModuleForDecoderOnlyLM(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
             input_ids=prompt_token_ids,
             cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
@@ -393,9 +391,7 @@ def test_export_hybrid_cache(self):

         # Export + HybridCache
         model.eval()
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
             input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
             cache_position=torch.tensor([0], dtype=torch.long, device=model.device),

tests/models/gemma3/test_modeling_gemma3.py

Lines changed: 1 addition & 3 deletions
@@ -808,9 +808,7 @@ def test_export_text_only_with_hybrid_cache(self):

         # Export + HybridCache
         model.eval()
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
             input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
             cache_position=torch.tensor([0], dtype=torch.long, device=model.device),

tests/models/llama/test_modeling_llama.py

Lines changed: 1 addition & 3 deletions
@@ -352,9 +352,7 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

-        exportable_module = TorchExportableModuleForDecoderOnlyLM(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
             input_ids=prompt_token_ids,
             cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),

tests/models/olmo/test_modeling_olmo.py

Lines changed: 1 addition & 3 deletions
@@ -383,9 +383,7 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

-        exportable_module = TorchExportableModuleForDecoderOnlyLM(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
             input_ids=prompt_token_ids,
             cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),

tests/models/olmo2/test_modeling_olmo2.py

Lines changed: 1 addition & 3 deletions
@@ -383,9 +383,7 @@ def test_export_static_cache(self):
         self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)

         # Static Cache + export
-        exported_program = convert_and_export_with_cache(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exported_program = convert_and_export_with_cache(model)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/phi3/test_modeling_phi3.py

Lines changed: 1 addition & 3 deletions
@@ -416,9 +416,7 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

-        exportable_module = TorchExportableModuleForDecoderOnlyLM(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
             input_ids=prompt_token_ids,
             cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
