Commit 8b06186
Revert config/generation_config changes
1 parent 8ea7821 commit 8b06186

15 files changed: +57 −112 lines

src/transformers/integrations/executorch.py

Lines changed: 34 additions & 63 deletions
@@ -38,43 +38,33 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-        config: Optional[PretrainedConfig] = None,
-        generation_config: Optional[GenerationConfig] = None,
     ):
         """
         Initializes the exportable module with `HybridCache`.

         Args:
             model (`PreTrainedModel`): The pretrained model to wrap.
-            config (`PretrainedConfig`): The pretrained text config for the decoder model.
-                If not specified will try to resolve with the model's config.
-            generation_config (`GenerationConfig`): The generation config for the model.
-                If not specified will try to resolve with the model's generation config.

         Raises:
             ValueError: If the model is configured with a unsupported cache implementation.
         """
         super().__init__()

-        if not config:
-            config = model.config
-        if not generation_config:
-            generation_config = model.generation_config
+        config = model.config.get_text_config()
+        generation_config = model.generation_config

         if not hasattr(config, "use_cache") or config.use_cache is False:
             raise ValueError("The model must have caching enabled to be performant.")

         if hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None:
-            self.model = TorchExportableModuleWithHybridCache(
-                model, config=config, generation_config=generation_config
-            )
+            self.model = TorchExportableModuleWithHybridCache(model)
         else:
             # If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
             # there is only 1 type of layers, so export will use `StaticCache` by default.
             logging.info(
                 "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
             )
-            self.model = TorchExportableModuleWithStaticCache(model, config, generation_config)
+            self.model = TorchExportableModuleWithStaticCache(model)
         # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
         ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
         ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
@@ -171,25 +161,23 @@ def export(
             )

         if input_ids is not None:
-            if cache_position is None:
-                cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model_device)
-            exported_program = torch.export.export(
-                self.model,
-                args=(),
-                kwargs={"input_ids": input_ids, "cache_position": cache_position},
-                dynamic_shapes=dynamic_shapes,
-                strict=strict if strict is not None else True,
-            )
+            input_kwargs = {
+                "input_ids": input_ids,
+                "cache_position": cache_position if cache_position is not None else torch.arange(input_ids.shape[-1], dtype=torch.long, device=model_device)
+            }
         else:  # inputs_embeds
-            if cache_position is None:
-                cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model_device)
-            exported_program = torch.export.export(
-                self.model,
-                args=(),
-                kwargs={"inputs_embeds": inputs_embeds, "cache_position": cache_position},
-                dynamic_shapes=dynamic_shapes,
-                strict=strict if strict is not None else True,
-            )
+            input_kwargs = {
+                "inputs_embeds": inputs_embeds,
+                "cache_position": cache_position if cache_position is not None else torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model_device)
+            }
+
+        exported_program = torch.export.export(
+            self.model,
+            args=(),
+            kwargs=input_kwargs,
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
+        )

         return exported_program

@@ -316,24 +304,23 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-        config: PretrainedConfig,
-        generation_config: GenerationConfig,
     ):
         """
         Initializes the wrapper module with the pretrained model.

         Args:
             model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
-                enabled and use a 'static' caching implementation.
-            config (`PretrainedConfig`): The pretrained text config for the model.
-            generation_config (`GenerationConfig`): The generation config for the model.
+                enabled and use a 'static' caching implementation.

         Raises:
             AssertionError: If the pretrained model does not have caching enabled or if it does
                 not use a 'static' caching implementation in `model.generation_config`.
         """
         super().__init__()

+        config = model.config.get_text_config()
+        generation_config = model.generation_config
+
         # Sanity checks
         if generation_config is None:
             raise AssertionError(
@@ -354,13 +341,11 @@ def __init__(
             )

         self.model = model
-        self.config = config
-        self.generation_config = generation_config
         self.static_cache = StaticCache(
             config=config,
-            max_batch_size=self.generation_config.cache_config.get("batch_size"),
-            max_cache_len=self.generation_config.cache_config.get("max_cache_len"),
-            device=self.generation_config.cache_config.get("device"),
+            max_batch_size=generation_config.cache_config.get("batch_size"),
+            max_cache_len=generation_config.cache_config.get("max_cache_len"),
+            device=generation_config.cache_config.get("device"),
             dtype=self.model.dtype,
         )

@@ -471,26 +456,20 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-        config: PretrainedConfig,
-        generation_config: GenerationConfig,
     ):
         """
         Initializes the exportable module with `HybridCache`.

         Args:
             model (`PreTrainedModel`): The pretrained model to wrap.
-            config (`PretrainedConfig`): The pretrained text config for the model.
-            generation_config (`GenerationConfig`): The generation config for the model.
-            max_batch_size (int): Maximum batch size for the cache.
-            max_cache_len (int): Maximum sequence length for the cache.

         Raises:
             AssertionError: If the model doesn't have the expected configuration for HybridCache.
         """
         super().__init__()
         self.model = model
-        self.config = config
-        self.generation_config = generation_config
+        config = model.config.get_text_config()
+        generation_config = model.generation_config

         # Verify the model is configured for HybridCache
         if not config.use_cache:
@@ -499,9 +478,9 @@ def __init__(
         # Initialize the HybridCache
         self.cache = HybridCache(
             config=config,
-            max_batch_size=self.generation_config.cache_config.get("batch_size"),
-            max_cache_len=self.generation_config.cache_config.get("max_cache_len"),
-            device=self.generation_config.cache_config.get("device"),
+            max_batch_size=generation_config.cache_config.get("batch_size"),
+            max_cache_len=generation_config.cache_config.get("max_cache_len"),
+            device=generation_config.cache_config.get("device"),
             dtype=self.model.dtype,
         )

@@ -543,8 +522,6 @@ def forward(

 def convert_and_export_with_cache(
     model: PreTrainedModel,
-    config: PretrainedConfig,
-    generation_config: GenerationConfig,
     example_input_ids: Optional[torch.Tensor] = None,
     example_cache_position: Optional[torch.Tensor] = None,
     dynamic_shapes: Optional[dict] = None,
@@ -556,8 +533,6 @@ def convert_and_export_with_cache(

     Args:
         model (`PreTrainedModel`): The pretrained model to be exported.
-        config (`PretrainedConfig`): The pretrained text config for the decoder model.
-        generation_config (`GenerationConfig`): The generation config for the model.
         example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
         example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`.
         dynamic_shapes(`Optional[dict]`): Dynamic shapes used by `torch.export`.
@@ -591,7 +566,7 @@ def convert_and_export_with_cache(

     if is_torch_greater_or_equal("2.6.0"):
         exported_program = torch.export.export(
-            TorchExportableModuleWithStaticCache(model=model, config=config, generation_config=generation_config),
+            TorchExportableModuleWithStaticCache(model),
             args=(),
             kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
             dynamic_shapes=dynamic_shapes,
@@ -609,11 +584,7 @@ def convert_and_export_with_cache(
         # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
         # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
         exported_program = torch.export._trace._export(
-            TorchExportableModuleWithStaticCache(
-                model=model,
-                config=config,
-                generation_config=generation_config,
-            ),
+            TorchExportableModuleWithStaticCache(model),
             args=(),
             kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
             pre_dispatch=False,
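
Note on usage after this revert: both wrapper modules and convert_and_export_with_cache now resolve the text config via model.config.get_text_config() and the generation config via model.generation_config, so callers pass only the model. The cache sizes must therefore already live in model.generation_config.cache_config before the wrapper is constructed. A minimal sketch of the new call pattern follows; the checkpoint name and cache sizes are illustrative assumptions, not part of this commit.

import torch
from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

# Hypothetical decoder-only checkpoint; any causal LM with caching enabled follows the same pattern.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
model.eval()

# The wrappers read batch_size/max_cache_len/device from model.generation_config.cache_config,
# so set them on the model up front (the values below are assumptions for this sketch).
model.generation_config = GenerationConfig(
    use_cache=True,
    cache_implementation="static",
    cache_config={"batch_size": 1, "max_cache_len": 128, "device": "cpu"},
)

# No config/generation_config arguments anymore; the module pulls both from the model.
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
    input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
    cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)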

tests/models/cohere2/test_modeling_cohere2.py

Lines changed: 1 addition & 3 deletions
@@ -275,9 +275,7 @@ def test_export_static_cache(self):
         max_new_tokens = 30 - prompt_token_ids.shape[-1]

         # Static Cache + export
-        exported_program = convert_and_export_with_cache(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exported_program = convert_and_export_with_cache(model)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/exaone4/test_modeling_exaone4.py

Lines changed: 1 addition & 3 deletions
@@ -400,9 +400,7 @@ def test_export_static_cache(self):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

         # Static Cache + export
-        exported_program = convert_and_export_with_cache(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exported_program = convert_and_export_with_cache(model)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/gemma/test_modeling_gemma.py

Lines changed: 1 addition & 3 deletions
@@ -459,9 +459,7 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

-        exportable_module = TorchExportableModuleForDecoderOnlyLM(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
             input_ids=prompt_token_ids,
             cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),

tests/models/gemma2/test_modeling_gemma2.py

Lines changed: 2 additions & 6 deletions
@@ -364,9 +364,7 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

-        exportable_module = TorchExportableModuleForDecoderOnlyLM(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
             input_ids=prompt_token_ids,
             cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
@@ -393,9 +391,7 @@ def test_export_hybrid_cache(self):

         # Export + HybridCache
         model.eval()
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
             input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
             cache_position=torch.tensor([0], dtype=torch.long, device=model.device),

tests/models/gemma3/test_modeling_gemma3.py

Lines changed: 1 addition & 3 deletions
@@ -808,9 +808,7 @@ def test_export_text_only_with_hybrid_cache(self):

         # Export + HybridCache
         model.eval()
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
             input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
             cache_position=torch.tensor([0], dtype=torch.long, device=model.device),

tests/models/llama/test_modeling_llama.py

Lines changed: 1 addition & 3 deletions
@@ -352,9 +352,7 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

-        exportable_module = TorchExportableModuleForDecoderOnlyLM(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
             input_ids=prompt_token_ids,
             cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),

tests/models/olmo/test_modeling_olmo.py

Lines changed: 1 addition & 3 deletions
@@ -383,9 +383,7 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

-        exportable_module = TorchExportableModuleForDecoderOnlyLM(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
             input_ids=prompt_token_ids,
             cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),

tests/models/olmo2/test_modeling_olmo2.py

Lines changed: 1 addition & 3 deletions
@@ -383,9 +383,7 @@ def test_export_static_cache(self):
         self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)

         # Static Cache + export
-        exported_program = convert_and_export_with_cache(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exported_program = convert_and_export_with_cache(model)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/phi3/test_modeling_phi3.py

Lines changed: 1 addition & 3 deletions
@@ -416,9 +416,7 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

-        exportable_module = TorchExportableModuleForDecoderOnlyLM(
-            model, config=model.config, generation_config=model.generation_config
-        )
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
         exported_program = exportable_module.export(
             input_ids=prompt_token_ids,
             cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
