@@ -364,8 +364,12 @@ def test_export_static_cache(self):
         # Static Cache + export
         from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
 
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=prompt_token_ids, cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device)
+        )
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
@@ -388,8 +392,13 @@ def test_export_hybrid_cache(self):
 
         # Export + HybridCache
         model.eval()
-        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
-        exported_program = exportable_module.export()
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(
+            model, config=model.config, generation_config=model.generation_config
+        )
+        exported_program = exportable_module.export(
+            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
+            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
+        )
 
         # Test generation with the exported model
         prompt = "What is the capital of France?"
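For reference, here is the updated call pattern pulled out of the tests into a standalone sketch. The checkpoint id and prompt are placeholders; the wrapper now takes `config` and `generation_config` explicitly, and `export()` takes example `input_ids` and `cache_position`, presumably so `torch.export` has concrete inputs to trace against. Everything beyond what the diff above shows is an assumption:

```python
# Minimal sketch of the updated export API, assuming a decoder-only
# checkpoint loaded via AutoModelForCausalLM; model_id is illustrative.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

model_id = "..."  # placeholder: any decoder-only checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()

prompt_token_ids = tokenizer("What is the capital of France?", return_tensors="pt").input_ids

# The wrapper now receives the configs explicitly rather than reading them off the model.
exportable_module = TorchExportableModuleForDecoderOnlyLM(
    model, config=model.config, generation_config=model.generation_config
)

# export() now requires example inputs: the prompt tokens and their cache positions.
exported_program = exportable_module.export(
    input_ids=prompt_token_ids,
    cache_position=torch.arange(prompt_token_ids.shape[-1], dtype=torch.long, device=model.device),
)
```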