5
5
#
6
6
# ----------------------------------------------------------------------------
7
7
8
- import hashlib
9
8
import inspect
10
9
import logging
11
10
import shutil
22
21
from QEfficient .base .pytorch_transforms import PytorchTransform
23
22
from QEfficient .compile .qnn_compiler import compile as qnn_compile
24
23
from QEfficient .generation .cloud_infer import QAICInferenceSession
25
- from QEfficient .utils import constants , create_json , dump_qconfig , generate_mdp_partition_config , load_json
26
- from QEfficient .utils .cache import QEFF_HOME , to_hashable
24
+ from QEfficient .utils import (
25
+ constants ,
26
+ create_json ,
27
+ create_model_params ,
28
+ dump_qconfig ,
29
+ export_wrapper ,
30
+ generate_mdp_partition_config ,
31
+ hash_dict_params ,
32
+ load_json ,
33
+ )
27
34
28
35
logger = logging .getLogger (__name__ )
29
36
@@ -45,12 +52,16 @@ class QEFFBaseModel(ABC):
45
52
def _transform_names (cls ) -> List [str ]:
46
53
return [x .__name__ for x in cls ._pytorch_transforms + cls ._onnx_transforms ]
47
54
48
- def __init__ (self , model : torch .nn .Module ) -> None :
55
+ def __init__ (self , model : torch .nn .Module , ** kwargs ) -> None :
49
56
super ().__init__ ()
50
57
self .model = model
58
+ self .hash_params = create_model_params (self , ** kwargs )
51
59
self .onnx_path : Optional [str ] = None
52
60
self .qpc_path : Optional [str ] = None
53
61
self .qpc_session : Optional [QAICInferenceSession ] = None
62
+ self .model_architecture = (
63
+ (arch := getattr (self .model .config , "architectures" , None )) and len (arch ) > 0 and arch [0 ]
64
+ ) or None
54
65
55
66
# Apply the transformations
56
67
any_transformed = False
@@ -67,10 +78,6 @@ def __init__(self, model: torch.nn.Module) -> None:
67
78
@abstractmethod
68
79
def model_name (self ) -> str : ...
69
80
70
- @property
71
- @abstractmethod
72
- def model_hash (self ) -> str : ...
73
-
74
81
@abstractmethod
75
82
def export (self , export_dir : Optional [str ] = None ) -> Path :
76
83
"""
@@ -114,6 +121,7 @@ def compile(self, *args, **kwargs) -> Path:
114
121
:str: Path of the compiled ``qpc`` package.
115
122
"""
116
123
124
+ @export_wrapper
117
125
def _export (
118
126
self ,
119
127
example_inputs : Dict [str , torch .Tensor ],
@@ -134,8 +142,6 @@ def _export(
134
142
:onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class.
135
143
:export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model.
136
144
"""
137
- export_dir = Path (export_dir or (QEFF_HOME / self .model_name ))
138
- export_dir = export_dir .with_name (export_dir .name + "-" + self .model_hash )
139
145
onnx_path = export_dir / f"{ self .model_name } .onnx"
140
146
if onnx_path .is_file ():
141
147
self .onnx_path = onnx_path
@@ -304,23 +310,16 @@ def _compile(
304
310
else :
305
311
mdp_ts_json = None
306
312
307
- compile_hash = hashlib .sha256 (to_hashable (command ))
308
-
309
- if specializations is not None :
310
- compile_hash .update (to_hashable (specializations ))
311
-
312
- if custom_io is not None :
313
- compile_hash .update (to_hashable (custom_io ))
314
-
315
- if num_speculative_tokens :
316
- compile_hash .update (to_hashable ({"num_speculative_tokens" : num_speculative_tokens }))
317
-
318
- # Hash the MDP partition config and the number of devices.
319
- compile_hash .update (to_hashable (mdp_ts_json ))
320
- compile_hash .update (to_hashable ({"mdp_ts_num_devices" : mdp_ts_num_devices }))
313
+ compile_hash_params = {
314
+ "command" : command ,
315
+ "specializations" : specializations ,
316
+ "custom_io" : custom_io ,
317
+ "mdp_ts_num_devices" : mdp_ts_num_devices ,
318
+ "mdp_ts_json" : mdp_ts_json ,
319
+ "num_speculative_tokens" : num_speculative_tokens ,
320
+ }
321
+ compile_hash = hash_dict_params (compile_hash_params )
321
322
322
- # Check if already compiled
323
- compile_hash = compile_hash .hexdigest ()[:16 ]
324
323
compile_dir = qpc_path .with_name (qpc_path .name + "-" + compile_hash )
325
324
qpc_path = compile_dir / "qpc"
326
325
qpc_path .mkdir (parents = True , exist_ok = True )
@@ -371,6 +370,10 @@ def _compile(
371
370
]
372
371
)
373
372
)
373
+ # Dump JSON file with hashed parameters
374
+ hashed_compile_params_path = compile_dir / "hashed_compile_params.json"
375
+ create_json (hashed_compile_params_path , compile_hash_params )
376
+ logger .info ("Hashed parameters exported successfully." )
374
377
375
378
self .qpc_path = qpc_path
376
379
0 commit comments