
Commit 68f21b8

AI changes the rest of the call sites
1 parent 35bb9a4 commit 68f21b8

File tree

12 files changed (+68, -22 lines)


tests/models/cohere2/test_modeling_cohere2.py

Lines changed: 3 additions & 1 deletion

@@ -275,7 +275,9 @@ def test_export_static_cache(self):
         max_new_tokens = 30 - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        exported_program = convert_and_export_with_cache(
+            model, config=model.config, generation_config=model.generation_config
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
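
The Cohere2, EXAONE 4, and OLMo 2 tests go through the convert_and_export_with_cache helper, which now receives the config and generation config explicitly at every call site. A minimal end-to-end sketch of the updated pattern, assuming a causal LM checkpoint and a tokenized prompt; the checkpoint id, prompt text, and token budget below are illustrative placeholders, not taken from the diff:

# Sketch only: mirrors the updated convert_and_export_with_cache call sites in this commit.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.integrations.executorch import (
    TorchExportableModuleWithStaticCache,
    convert_and_export_with_cache,
)

tokenizer = AutoTokenizer.from_pretrained("<checkpoint-id>")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("<checkpoint-id>")
prompt_token_ids = tokenizer("Simple sentence", return_tensors="pt").input_ids
max_new_tokens = 20  # illustrative budget

# The config and generation config are now passed explicitly alongside the model.
exported_program = convert_and_export_with_cache(
    model, config=model.config, generation_config=model.generation_config
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=exported_program,
    prompt_token_ids=prompt_token_ids,
    max_new_tokens=max_new_tokens,
)
print(tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True))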

tests/models/exaone4/test_modeling_exaone4.py

Lines changed: 3 additions & 1 deletion

@@ -400,7 +400,9 @@ def test_export_static_cache(self):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        exported_program = convert_and_export_with_cache(
+            model, config=model.config, generation_config=model.generation_config
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/gemma/test_modeling_gemma.py

Lines changed: 6 additions & 2 deletions

@@ -459,8 +459,12 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
 
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device)
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
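
The Gemma, Llama, OLMo, and Phi-3 tests instead wrap the model in TorchExportableModuleForDecoderOnlyLM; the wrapper's constructor now also takes the config and generation config, and export() is given example inputs (the prompt tokens plus a matching range of cache positions). A minimal sketch of that variant, reusing the model, prompt_token_ids, and max_new_tokens names assumed in the sketch above:

# Sketch only: mirrors the TorchExportableModuleForDecoderOnlyLM call sites in this commit.
import torch
from transformers.integrations.executorch import (
    TorchExportableModuleForDecoderOnlyLM,
    TorchExportableModuleWithStaticCache,
)

exportable_module = TorchExportableModuleForDecoderOnlyLM(
    model, config=model.config, generation_config=model.generation_config
)
# export() is now traced on concrete example inputs: the prompt tokens and an
# arange of cache positions of the same length, on the model's device.
exported_program = exportable_module.export(
    input_ids=prompt_token_ids,
    cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=exported_program,
    prompt_token_ids=prompt_token_ids,
    max_new_tokens=max_new_tokens,
)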

tests/models/gemma2/test_modeling_gemma2.py

Lines changed: 13 additions & 4 deletions

@@ -364,8 +364,12 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
 
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device)
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

@@ -388,8 +392,13 @@ def test_export_hybrid_cache(self):
 
         # Export + HybridCache
         model.eval()
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device)
+        )
 
         # Test generation with the exported model
         prompt = "What is the capital of France?"
tests/models/gemma3/test_modeling_gemma3.py

Lines changed: 7 additions & 2 deletions

@@ -808,8 +808,13 @@ def test_export_text_only_with_hybrid_cache(self):
 
         # Export + HybridCache
         model.eval()
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device)
+        )
         logging.info(f"\nExported program: {exported_program}")
 
         # Test generation with the exported model

tests/models/llama/test_modeling_llama.py

Lines changed: 6 additions & 2 deletions

@@ -352,8 +352,12 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
 
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device)
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/olmo/test_modeling_olmo.py

Lines changed: 6 additions & 2 deletions

@@ -383,8 +383,12 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
 
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device)
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/olmo2/test_modeling_olmo2.py

Lines changed: 3 additions & 1 deletion

@@ -383,7 +383,9 @@ def test_export_static_cache(self):
         self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        exported_program = convert_and_export_with_cache(
+            model, config=model.config, generation_config=model.generation_config
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/phi3/test_modeling_phi3.py

Lines changed: 6 additions & 2 deletions

@@ -416,8 +416,12 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
 
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device)
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/qwen2/test_modeling_qwen2.py

Lines changed: 6 additions & 2 deletions

@@ -299,11 +299,15 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
 
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
         strict = version.parse(torch.__version__) != version.parse(
             "2.7.0"
         )  # Due to https://github.com/pytorch/pytorch/issues/150994
-        exported_program = exportable_module.export(strict=strict)
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device), strict=strict
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
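
The Qwen2 test also gates strict export on the installed torch version (disabled only for 2.7.0, per the comment pointing at pytorch/pytorch#150994) and now threads the example inputs through the same export() call. A minimal sketch of that gate, with the same assumed names as above:

# Sketch only: skip strict export on torch 2.7.0, as the Qwen2 test does
# (see https://github.com/pytorch/pytorch/issues/150994).
import torch
from packaging import version

strict = version.parse(torch.__version__) != version.parse("2.7.0")
exported_program = exportable_module.export(
    input_ids=prompt_token_ids,
    cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
    strict=strict,
)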
