
Commit 8ea7821

pr review
1 parent bcb30e9 commit 8ea7821

2 files changed: +137 additions, -13 deletions


src/transformers/integrations/executorch.py

Lines changed: 12 additions & 13 deletions
@@ -38,26 +38,29 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-        config: PretrainedConfig,
-        generation_config: GenerationConfig,
-        max_batch_size: int = 1,
-        max_cache_len: int = 4096,
+        config: Optional[PretrainedConfig] = None,
+        generation_config: Optional[GenerationConfig] = None,
     ):
         """
         Initializes the exportable module with `HybridCache`.
 
         Args:
             model (`PreTrainedModel`): The pretrained model to wrap.
             config (`PretrainedConfig`): The pretrained text config for the decoder model.
+                If not specified, will try to resolve with the model's config.
             generation_config (`GenerationConfig`): The generation config for the model.
-            max_batch_size (int): Maximum batch size for the cache.
-            max_cache_len (int): Maximum sequence length for the cache.
+                If not specified, will try to resolve with the model's generation config.
 
         Raises:
             ValueError: If the model is configured with an unsupported cache implementation.
         """
         super().__init__()
 
+        if not config:
+            config = model.config
+        if not generation_config:
+            generation_config = model.generation_config
+
         if not hasattr(config, "use_cache") or config.use_cache is False:
             raise ValueError("The model must have caching enabled to be performant.")
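
With this change, `config` and `generation_config` become optional and, when omitted, are resolved from the wrapped model. A minimal sketch of the new call pattern, reusing the tiny checkpoint and cache settings from the test file below (`wrapper` is just an illustrative variable name):

    from transformers import AutoModelForCausalLM, GenerationConfig
    from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
    model.eval()
    # The wrapper still requires caching to be enabled with a supported cache implementation.
    model.generation_config = GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
    )
    # config and generation_config no longer need to be passed explicitly;
    # they default to model.config and model.generation_config.
    wrapper = TorchExportableModuleForDecoderOnlyLM(model)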

@@ -167,13 +170,9 @@ def export(
                 "TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default."
             )
 
-        example_cache_position = (
-            cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=model_device)
-        )
-
         if input_ids is not None:
             if cache_position is None:
-                cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long)
+                cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model_device)
             exported_program = torch.export.export(
                 self.model,
                 args=(),
@@ -183,11 +182,11 @@ def export(
             )
         else:  # inputs_embeds
             if cache_position is None:
-                cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long)
+                cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model_device)
             exported_program = torch.export.export(
                 self.model,
                 args=(),
-                kwargs={"inputs_embeds": inputs_embeds, "cache_position": example_cache_position},
+                kwargs={"inputs_embeds": inputs_embeds, "cache_position": cache_position},
                 dynamic_shapes=dynamic_shapes,
                 strict=strict if strict is not None else True,
             )
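
These hunks drop the separate `example_cache_position` and, when `cache_position` is not supplied, build it on the resolved model device instead of a default CPU tensor. A minimal export sketch (construction repeated so the snippet stands alone; shapes mirror the test file below):

    import torch
    from transformers import AutoModelForCausalLM, GenerationConfig
    from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
    model.eval()
    model.generation_config = GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
    )
    wrapper = TorchExportableModuleForDecoderOnlyLM(model)

    input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
    cache_position = torch.arange(3, dtype=torch.long)
    # cache_position may also be omitted: export() now creates
    # torch.arange(seq_len) on the model's device by default.
    exported = wrapper.export(input_ids=input_ids, cache_position=cache_position)
    logits = exported.module()(input_ids=input_ids, cache_position=cache_position)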

tests/test_executorch.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from transformers import AutoModelForCausalLM, set_seed
+from transformers.generation.configuration_utils import GenerationConfig
+from transformers.integrations.executorch import (
+    TorchExportableModuleForDecoderOnlyLM,
+    TorchExportableModuleWithHybridCache,
+    TorchExportableModuleWithStaticCache,
+)
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
+from transformers.testing_utils import require_torch
+
+
+@require_torch
+class ExecutorchTest(unittest.TestCase):
+    def setUp(self):
+        if not is_torch_greater_or_equal_than_2_3:
+            self.skipTest("torch >= 2.3 is required")
+
+        set_seed(0)
+        self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+        self.model.eval()
+
+        # Create generation config with static cache for the model
+        self.model.generation_config = GenerationConfig(
+            use_cache=True,
+            cache_implementation="static",
+            cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
+        )
+
+        self.input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
+        self.inputs_embeds = torch.randn(1, 3, self.model.config.hidden_size)
+        self.cache_position = torch.arange(3, dtype=torch.long)
+
+    def test_static_cache_module_forward(self):
+        """Test TorchExportableModuleWithStaticCache forward with both input types"""
+        generation_config = GenerationConfig(
+            use_cache=True,
+            cache_implementation="static",
+            cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
+        )
+
+        module = TorchExportableModuleWithStaticCache(self.model, self.model.config, generation_config)
+
+        # Test with input_ids
+        eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
+        wrapped_output_ids = module.forward(input_ids=self.input_ids, cache_position=self.cache_position)
+        torch.testing.assert_close(eager_output_ids, wrapped_output_ids, atol=1e-4, rtol=1e-4)
+
+        # Test with inputs_embeds
+        eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
+        wrapped_output_embeds = module.forward(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
+        torch.testing.assert_close(eager_output_embeds, wrapped_output_embeds, atol=1e-4, rtol=1e-4)
+
+    def test_hybrid_cache_module_forward(self):
+        """Test TorchExportableModuleWithHybridCache forward with both input types"""
+        config = self.model.config
+        config.sliding_window = 16
+        config.layer_types = ["full_attention"] * config.num_hidden_layers
+
+        generation_config = GenerationConfig(
+            use_cache=True,
+            cache_implementation="hybrid",
+            cache_config={"batch_size": 1, "max_cache_len": 32, "device": "cpu"},
+        )
+
+        module = TorchExportableModuleWithHybridCache(self.model, config, generation_config)
+
+        # Test with input_ids
+        eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
+        wrapped_output_ids = module.forward(input_ids=self.input_ids, cache_position=self.cache_position)
+        torch.testing.assert_close(eager_output_ids, wrapped_output_ids, atol=1e-4, rtol=1e-4)
+
+        # Test with inputs_embeds
+        eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
+        wrapped_output_embeds = module.forward(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
+        torch.testing.assert_close(eager_output_embeds, wrapped_output_embeds, atol=1e-4, rtol=1e-4)
+
+    def test_decoder_only_lm_export_validation(self):
+        """Test TorchExportableModuleForDecoderOnlyLM export validation"""
+        module = TorchExportableModuleForDecoderOnlyLM(self.model)
+
+        # Should fail with both input_ids and inputs_embeds
+        with self.assertRaises(ValueError):
+            module.export(input_ids=self.input_ids, inputs_embeds=self.inputs_embeds)
+
+        # Should fail with neither
+        with self.assertRaises(ValueError):
+            module.export()
+
+    def test_decoder_only_lm_export(self):
+        """Test TorchExportableModuleForDecoderOnlyLM export with both input types"""
+        module = TorchExportableModuleForDecoderOnlyLM(self.model)
+
+        # Test export with input_ids
+        exported_program_ids = module.export(input_ids=self.input_ids, cache_position=self.cache_position)
+        eager_output_ids = self.model(input_ids=self.input_ids, use_cache=False).logits
+        exported_output_ids = exported_program_ids.module()(
+            input_ids=self.input_ids, cache_position=self.cache_position
+        )
+        torch.testing.assert_close(eager_output_ids, exported_output_ids, atol=1e-4, rtol=1e-4)
+
+        # Test export with inputs_embeds
+        exported_program_embeds = module.export(inputs_embeds=self.inputs_embeds, cache_position=self.cache_position)
+        eager_output_embeds = self.model(inputs_embeds=self.inputs_embeds, use_cache=False).logits
+        exported_output_embeds = exported_program_embeds.module()(
+            inputs_embeds=self.inputs_embeds, cache_position=self.cache_position
+        )
+        torch.testing.assert_close(eager_output_embeds, exported_output_embeds, atol=1e-4, rtol=1e-4)
