Skip to content

Commit 4bcbf68

Browse files
committed
Merge branch 'main' into dualQpcParamFix
2 parents c0f9514 + f214e43 commit 4bcbf68

File tree

6 files changed

+248
-14
lines changed

6 files changed

+248
-14
lines changed

QEfficient/base/modeling_qeff.py

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#
66
# ----------------------------------------------------------------------------
77

8+
import gc
89
import inspect
910
import logging
1011
import shutil
@@ -63,6 +64,9 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
6364
(arch := getattr(self.model.config, "architectures", None)) and len(arch) > 0 and arch[0]
6465
) or None
6566

67+
# Flag for checking if weights are offloaded
68+
self._is_weights_offloaded: bool = False
69+
6670
# Apply the transformations
6771
any_transformed = False
6872
for transform in self._pytorch_transforms:
@@ -74,6 +78,44 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
7478
else:
7579
logger.info(f"Pytorch transforms applied to model: {self.model_name}")
7680

81+
def _offload_model_weights(self, offload_pt_weights) -> bool:
    """
    Release the PyTorch weights by moving ``self.model`` to the meta device.

    The call does nothing when ``offload_pt_weights`` is falsy or when
    ``self._is_weights_offloaded`` is already set.

    Args:
        :offload_pt_weights (bool): Whether the caller asked for the weights to be offloaded.

    Returns:
        bool: True if this call offloaded the weights, False if it was skipped or failed.
    """
    # Nothing to do: offloading was not requested, or it already happened.
    if not offload_pt_weights or self._is_weights_offloaded:
        return False

    try:
        meta_model = self.model.to_empty(device="meta")
        self.model = meta_model
        self._is_weights_offloaded = True
        logger.info("Model weights offloaded to meta device")

        # Trigger a collection so memory that is no longer referenced can be reclaimed promptly.
        gc.collect()
        logger.info("PyTorch weights cleared after export")
    except Exception as e:
        # Best effort: a failure is reported through the return value, not raised.
        logger.error(f"Failed to offload model weights: {e}")
        return False

    return True
103+
104+
def _model_offloaded_check(self) -> None:
    """
    Refuse to continue when the model's weights are no longer available.

    Raises:
        RuntimeError: If ``self._is_weights_offloaded`` is set, or if any parameter
            of ``self.model`` reports ``is_meta``.
    """
    weights_gone = self._is_weights_offloaded
    if not weights_gone:
        # The flag may be unset while parameters already sit on the meta device,
        # so inspect the parameters themselves as well.
        weights_gone = any(p.is_meta for p in self.model.parameters())

    if not weights_gone:
        return

    error_msg = (
        "Cannot re-export model: weights have been offloaded to save memory. "
        "To re-export, please create a new model instance using from_pretrained() method."
    )
    logger.error(error_msg)
    raise RuntimeError(error_msg)
118+
77119
@property
78120
@abstractmethod
79121
def model_name(self) -> str: ...
@@ -130,9 +172,15 @@ def _export(
130172
export_kwargs: Optional[Dict[str, any]] = None,
131173
onnx_transform_kwargs: Optional[Dict[str, any]] = None,
132174
export_dir: Optional[str] = None,
175+
offload_pt_weights: bool = True,
133176
) -> str:
134177
"""
135-
Export the Pytorch model to ONNX.
178+
Export the PyTorch model to ONNX and apply ONNX transforms
179+
180+
This method:
181+
1. Exports PyTorch model to ONNX using torch.onnx.export
182+
2. Clears PyTorch weights after export (only when offload_pt_weights is True)
183+
3. Applies ONNX transforms with reduced memory footprint
136184
137185
Args:
138186
:example_inputs (dict): Sample inputs to trace the model.
@@ -141,18 +189,30 @@ def _export(
141189
:export_kwargs (dict): Additional arguments to be passed to `torch.onnx.export`.
142190
:onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class.
143191
:export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model.
192+
:offload_pt_weights (bool): If True, offload PyTorch model weights to meta device
193+
after successful export to reduce memory usage. Set to False if you need to
194+
keep weights for further operations. Defaults to True.
195+
Note:
196+
Once weights are offloaded, the model cannot be re-exported. Create a new
197+
instance using from_pretrained() for re-export.
198+
144199
"""
145200
onnx_path = export_dir / f"{self.model_name}.onnx"
201+
202+
# Return early if ONNX already exists
146203
if onnx_path.is_file():
147204
self.onnx_path = onnx_path
148205
return onnx_path
149206

207+
# Refuse to export if the model is on the meta device or its weights were already offloaded
208+
self._model_offloaded_check()
209+
210+
# Setup temporary paths
150211
tmp_onnx_dir = export_dir / "onnx_tmp"
151212
tmp_onnx_path = tmp_onnx_dir / f"{self.model_name}.onnx"
152213
tmp_onnx_dir.mkdir(parents=True, exist_ok=True)
153214

154215
# Create input_names from example_inputs
155-
156216
input_names = []
157217
for param in inspect.signature(self.model.forward).parameters:
158218
if param in example_inputs:
@@ -188,7 +248,9 @@ def _export(
188248
opset_version=constants.ONNX_EXPORT_OPSET,
189249
**export_kwargs,
190250
)
191-
logger.info("Pytorch export successful")
251+
logger.info("PyTorch export successful")
252+
253+
_ = self._offload_model_weights(offload_pt_weights)
192254

193255
model = onnx.load(tmp_onnx_path, load_external_data=False)
194256
transform_kwargs = {
@@ -200,17 +262,17 @@ def _export(
200262

201263
for transform in self._onnx_transforms:
202264
model, transformed = transform.apply(model, **transform_kwargs)
265+
203266
model.metadata_props.append(
204267
onnx.StringStringEntryProto(key="qeff_transforms", value=",".join(self._transform_names()))
205268
)
206269
logger.info("ONNX transforms applied")
207270

208271
onnx.save(model, onnx_path)
209-
logger.info("Transformed onnx saved")
272+
logger.info("Transformed ONNX saved")
210273

211274
except Exception as e:
212-
logger.error(f"ONNX export (or) ONNXTransforms failed: {e}")
213-
275+
logger.error(f"ONNX export or transforms failed: {e}")
214276
raise e
215277

216278
finally:
@@ -230,7 +292,7 @@ def _compile(
230292
custom_io: Optional[Dict[str, str]] = None,
231293
mdp_ts_num_devices: int = 1,
232294
num_speculative_tokens: Optional[int] = None,
233-
mxfp6_matmul: bool = constants.DEFAULT_AIC_MXPF6_MATMUL,
295+
mxfp6_matmul: bool = constants.DEFAULT_AIC_MXFP6_MATMUL,
234296
enable_qnn: Optional[bool] = False,
235297
qnn_config: Optional[str] = None,
236298
**compiler_options,

QEfficient/peft/auto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def generate(
287287

288288
generation_config = generation_config or self.model.generation_config
289289
generation_config, model_kwargs = self.model._prepare_generation_config(generation_config, **kwargs)
290-
self.model._prepare_special_tokens(generation_config)
290+
self.model._prepare_special_tokens(generation_config, device="cpu")
291291
if generation_config.do_sample:
292292
raise NotImplementedError("do_sample=True not supported currently")
293293
if generation_config.num_beams > 1:

QEfficient/transformers/models/modeling_auto.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,10 @@ def __init__(self, model: nn.modules, **kwargs):
433433
self.model = model.get_qeff_vision_encoder()
434434
self.hash_params["qeff_auto_class"] = self.__class__.__name__
435435

436-
def export(self, inputs, output_names, dynamic_axes, export_dir=None):
437-
return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)
436+
def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True):
437+
return self._export(
438+
inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights
439+
)
438440

439441
def compile(
440442
self,
@@ -488,8 +490,10 @@ def __init__(self, model, **kwargs):
488490
self.model = model.get_qeff_language_decoder()
489491
self.hash_params["qeff_auto_class"] = self.__class__.__name__
490492

491-
def export(self, inputs, output_names, dynamic_axes, export_dir=None):
492-
return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)
493+
def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True):
494+
return self._export(
495+
inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights
496+
)
493497

494498
def compile(
495499
self,
@@ -583,14 +587,18 @@ def export(
583587
inputs = self.model.get_dummy_inputs(kv_offload=True)
584588
dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True)
585589
output_names = self.model.get_output_names(kv_offload=True)
590+
586591
self.vision_model.export(
587592
inputs["vision"],
588593
output_names["vision"],
589594
dynamic_axes["vision"],
590595
export_dir=export_dir,
596+
offload_pt_weights=False,
597+
)
598+
self.lang_model.export(
599+
inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir=export_dir, offload_pt_weights=True
591600
)
592601

593-
self.lang_model.export(inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir=export_dir)
594602
return self.onnx_path
595603

596604
def compile(

QEfficient/utils/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
# Compiler defaults
2929
DEFAULT_AIC_NUM_CORES = 16
30-
DEFAULT_AIC_MXPF6_MATMUL = False
30+
DEFAULT_AIC_MXFP6_MATMUL = False
3131
# Hashing defaults
3232
HASH_HEXDIGEST_STR_LEN = 16
3333
KWARGS_INCLUSION_LIST = [
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# -----------------------------------------------------------------------------
2+
#
3+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
#
6+
# -----------------------------------------------------------------------------
7+
8+
import pytest
9+
from transformers import AutoConfig, AutoModelForCausalLM
10+
11+
from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
12+
13+
# Simple test config for memory reduction testing
14+
# Overrides applied on top of the "gpt2" configuration used by the tests in this module.
_TEST_CONFIG_OVERRIDES = {
    "max_position_embeddings": 256,
    "num_hidden_layers": 2,
    "num_attention_heads": 4,
    "hidden_size": 128,
    "intermediate_size": 512,
    "vocab_size": 127,
    "num_key_value_heads": 2,
}
test_config = AutoConfig.for_model("gpt2", **_TEST_CONFIG_OVERRIDES)

# Extra keyword arguments forwarded to AutoModelForCausalLM.from_config().
model_kwargs = {"attn_implementation": "eager"}
26+
27+
28+
@pytest.fixture
def tmp_cache(tmp_path, monkeypatch):
    """Point ``QEfficient.utils._utils.QEFF_HOME`` at pytest's per-test ``tmp_path`` and yield that path."""
    qeff_home_target = "QEfficient.utils._utils.QEFF_HOME"
    # monkeypatch undoes the patch automatically when the test finishes.
    monkeypatch.setattr(qeff_home_target, tmp_path)
    yield tmp_path
32+
33+
34+
def test_offload_weights_method():
    """Check ``_offload_model_weights`` both when offloading is requested and when it is declined."""

    def build_wrapper():
        # Each scenario gets its own instance so one offload cannot affect the other.
        base_model = AutoModelForCausalLM.from_config(test_config, **model_kwargs)
        return QEFFAutoModelForCausalLM(base_model, continuous_batching=False)

    def meta_flags(wrapper):
        return [param.is_meta for param in wrapper.model.parameters()]

    # Scenario 1: offloading requested.
    offloading = build_wrapper()

    # A fresh wrapper still holds real weights.
    assert not offloading._is_weights_offloaded
    assert not any(meta_flags(offloading))

    success = offloading._offload_model_weights(offload_pt_weights=True)
    assert success
    assert offloading._is_weights_offloaded
    assert all(meta_flags(offloading))

    # Scenario 2: offloading declined, so nothing may change.
    keeping = build_wrapper()

    success = keeping._offload_model_weights(offload_pt_weights=False)
    assert not success
    assert not keeping._is_weights_offloaded
    assert not any(meta_flags(keeping))
58+
59+
60+
def test_re_export_behavior_with_offloaded_weights(tmp_cache):
    """Verify that a second export raises once the weights have been offloaded."""
    import os

    base_model = AutoModelForCausalLM.from_config(test_config, **model_kwargs)
    qeff_model = QEFFAutoModelForCausalLM(base_model, continuous_batching=False)

    # The initial export must produce an ONNX path.
    qeff_model.export()
    assert qeff_model.onnx_path is not None

    # Request offloading explicitly; afterwards the flag must be set.
    qeff_model._offload_model_weights(offload_pt_weights=True)
    assert qeff_model._is_weights_offloaded

    # Delete the exported file so the next export() cannot return the existing ONNX early.
    exported_file = qeff_model.onnx_path
    os.remove(exported_file)
    qeff_model.onnx_path = None

    # With the weights gone, exporting again has to raise.
    with pytest.raises(RuntimeError, match="weights have been offloaded"):
        qeff_model.export()
82+
83+
84+
def test_vlm_dual_qpc_memory_offload_behavior():
    """Check the asymmetric offload pattern of a dual-QPC export using mock models only."""

    class _MockExportable:
        """Mock that records whether ``export`` was asked to offload its weights."""

        export_result = None

        def __init__(self):
            self._is_weights_offloaded = False

        def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True):
            if offload_pt_weights:
                self._is_weights_offloaded = True
            return self.export_result

    # Mock vision model (expected to keep its weights)
    class MockVisionModel(_MockExportable):
        export_result = "vision_export_path"

    # Mock language model (expected to offload its weights)
    class MockLangModel(_MockExportable):
        export_result = "lang_export_path"

    vision_model, lang_model = MockVisionModel(), MockLangModel()

    # Simulate the dual-QPC export: vision keeps its weights, language offloads.
    vision_model.export({}, [], {}, offload_pt_weights=False)
    lang_model.export({}, [], {}, offload_pt_weights=True)

    # Only the language model may report offloaded weights.
    assert not vision_model._is_weights_offloaded
    assert lang_model._is_weights_offloaded
118+
119+
120+
def test_vlm_single_qpc_memory_offload_behavior():
    """Check the single-QPC offload pattern, with and without offloading, using mock models only."""

    class MockParam:
        def __init__(self, is_meta=False):
            self.is_meta = is_meta

    class MockModel:
        def __init__(self):
            self._params = [MockParam(is_meta=False)]

        def parameters(self):
            return self._params

    class MockSingleQPCModel:
        def __init__(self):
            self._is_weights_offloaded = False
            self.model = MockModel()

        def _offload_model_weights(self):
            # Mark the wrapper and every mock parameter as offloaded.
            self._is_weights_offloaded = True
            for mock_param in self.model.parameters():
                mock_param.is_meta = True
            return True

        def export(self, export_dir=None, offload_pt_weights=True):
            if offload_pt_weights:
                self._offload_model_weights()
            return "single_qpc_export_path"

    def export_with(offload):
        # Build a fresh mock, export it, and report the per-parameter meta flags.
        wrapper = MockSingleQPCModel()
        wrapper.export(offload_pt_weights=offload)
        return wrapper, [p.is_meta for p in wrapper.model.parameters()]

    # offload_pt_weights=True: the flag is set and every parameter is on meta.
    offloaded_model, flags = export_with(True)
    assert offloaded_model._is_weights_offloaded
    assert all(flags)

    # offload_pt_weights=False: nothing is touched.
    kept_model, flags = export_with(False)
    assert not kept_model._is_weights_offloaded
    assert not any(flags)

tests/transformers/models/test_causal_lm_models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@ def test_causal_lm_export_with_deprecated_api(model_name):
282282
tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name)
283283
qeff_model = QEFFAutoModelForCausalLM(model, model_name=model_name, pretrained_model_name_or_path=model_name)
284284
new_api_onnx_model_path = qeff_model.export()
285+
286+
# Reload the model: the export above moved the previous instance's weights to the meta device
287+
model, _ = load_causal_lm_model(model_name, n_layer=1)
288+
qeff_model = QEFFAutoModelForCausalLM(model, model_name=model_name, pretrained_model_name_or_path=model_name)
285289
_, old_api_onnx_model_path = qualcomm_efficient_converter(
286290
model_name=model_name, model_kv=qeff_model, tokenizer=tokenizer
287291
)

0 commit comments

Comments
 (0)