Skip to content

Commit d5f5603

Browse files
authored
Merge branch 'main' into sdk_dependency
Signed-off-by: Mamta Singh <[email protected]>
2 parents 2888a8b + d020b88 commit d5f5603

File tree

16 files changed

+522
-279
lines changed

16 files changed

+522
-279
lines changed

QEfficient/base/modeling_qeff.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#
66
# ----------------------------------------------------------------------------
77

8-
import hashlib
98
import inspect
109
import logging
1110
import shutil
@@ -22,8 +21,16 @@
2221
from QEfficient.base.pytorch_transforms import PytorchTransform
2322
from QEfficient.compile.qnn_compiler import compile as qnn_compile
2423
from QEfficient.generation.cloud_infer import QAICInferenceSession
25-
from QEfficient.utils import constants, create_json, dump_qconfig, generate_mdp_partition_config, load_json
26-
from QEfficient.utils.cache import QEFF_HOME, to_hashable
24+
from QEfficient.utils import (
25+
constants,
26+
create_json,
27+
create_model_params,
28+
dump_qconfig,
29+
export_wrapper,
30+
generate_mdp_partition_config,
31+
hash_dict_params,
32+
load_json,
33+
)
2734

2835
logger = logging.getLogger(__name__)
2936

@@ -45,12 +52,16 @@ class QEFFBaseModel(ABC):
4552
def _transform_names(cls) -> List[str]:
4653
return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]
4754

48-
def __init__(self, model: torch.nn.Module) -> None:
55+
def __init__(self, model: torch.nn.Module, **kwargs) -> None:
4956
super().__init__()
5057
self.model = model
58+
self.hash_params = create_model_params(self, **kwargs)
5159
self.onnx_path: Optional[str] = None
5260
self.qpc_path: Optional[str] = None
5361
self.qpc_session: Optional[QAICInferenceSession] = None
62+
self.model_architecture = (
63+
(arch := getattr(self.model.config, "architectures", None)) and len(arch) > 0 and arch[0]
64+
) or None
5465

5566
# Apply the transformations
5667
any_transformed = False
@@ -67,10 +78,6 @@ def __init__(self, model: torch.nn.Module) -> None:
6778
@abstractmethod
6879
def model_name(self) -> str: ...
6980

70-
@property
71-
@abstractmethod
72-
def model_hash(self) -> str: ...
73-
7481
@abstractmethod
7582
def export(self, export_dir: Optional[str] = None) -> Path:
7683
"""
@@ -114,6 +121,7 @@ def compile(self, *args, **kwargs) -> Path:
114121
:str: Path of the compiled ``qpc`` package.
115122
"""
116123

124+
@export_wrapper
117125
def _export(
118126
self,
119127
example_inputs: Dict[str, torch.Tensor],
@@ -134,8 +142,6 @@ def _export(
134142
:onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class.
135143
:export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model.
136144
"""
137-
export_dir = Path(export_dir or (QEFF_HOME / self.model_name))
138-
export_dir = export_dir.with_name(export_dir.name + "-" + self.model_hash)
139145
onnx_path = export_dir / f"{self.model_name}.onnx"
140146
if onnx_path.is_file():
141147
self.onnx_path = onnx_path
@@ -299,23 +305,16 @@ def _compile(
299305
else:
300306
mdp_ts_json = None
301307

302-
compile_hash = hashlib.sha256(to_hashable(command))
303-
304-
if specializations is not None:
305-
compile_hash.update(to_hashable(specializations))
306-
307-
if custom_io is not None:
308-
compile_hash.update(to_hashable(custom_io))
309-
310-
if num_speculative_tokens:
311-
compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens}))
312-
313-
# Hash the MDP partition config and the number of devices.
314-
compile_hash.update(to_hashable(mdp_ts_json))
315-
compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices}))
308+
compile_hash_params = {
309+
"command": command,
310+
"specializations": specializations,
311+
"custom_io": custom_io,
312+
"mdp_ts_num_devices": mdp_ts_num_devices,
313+
"mdp_ts_json": mdp_ts_json,
314+
"num_speculative_tokens": num_speculative_tokens,
315+
}
316+
compile_hash = hash_dict_params(compile_hash_params)
316317

317-
# Check if already compiled
318-
compile_hash = compile_hash.hexdigest()[:16]
319318
compile_dir = qpc_path.with_name(qpc_path.name + "-" + compile_hash)
320319
qpc_path = compile_dir / "qpc"
321320
qpc_path.mkdir(parents=True, exist_ok=True)
@@ -366,6 +365,10 @@ def _compile(
366365
]
367366
)
368367
)
368+
# Dump JSON file with hashed parameters
369+
hashed_compile_params_path = compile_dir / "hashed_compile_params.json"
370+
create_json(hashed_compile_params_path, compile_hash_params)
371+
logger.info("Hashed parameters exported successfully.")
369372

370373
self.qpc_path = qpc_path
371374

QEfficient/compile/qnn_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
from typing import Dict, List, Optional
1313

1414
from QEfficient.utils._utils import create_json, execute_command, load_json
15-
from QEfficient.utils.cache import to_hashable
1615
from QEfficient.utils.constants import QnnConstants
1716
from QEfficient.utils.generate_qnn_network_specialization_config import (
1817
generate_data_format_config,
1918
generate_qnn_specialization,
2019
)
20+
from QEfficient.utils.hash_utils import to_hashable
2121
from QEfficient.utils.logging_utils import logger
2222

2323

QEfficient/peft/auto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform
2828
from QEfficient.utils import constants
2929
from QEfficient.utils._utils import get_padding_shape_from_config
30-
from QEfficient.utils.cache import to_hashable
30+
from QEfficient.utils.hash_utils import to_hashable
3131

3232
logger = logging.getLogger(__name__)
3333

QEfficient/peft/lora/auto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from QEfficient import QEFFAutoModelForCausalLM
1919
from QEfficient.peft.lora.pytorch_transforms import LoraModelInputsTransform, TargetModulesTransform
2020
from QEfficient.utils import constants, get_padding_shape_from_config
21-
from QEfficient.utils.cache import to_hashable
21+
from QEfficient.utils.hash_utils import to_hashable
2222
from QEfficient.utils.logging_utils import logger
2323

2424

0 commit comments

Comments
 (0)