5
5
#
6
6
# ----------------------------------------------------------------------------
7
7
8
- import hashlib
9
8
import inspect
10
9
import logging
11
10
import shutil
22
21
from QEfficient .base .pytorch_transforms import PytorchTransform
23
22
from QEfficient .compile .qnn_compiler import compile as qnn_compile
24
23
from QEfficient .generation .cloud_infer import QAICInferenceSession
25
- from QEfficient .utils import constants , create_json , dump_qconfig , generate_mdp_partition_config , load_json
26
- from QEfficient .utils .cache import QEFF_HOME , to_hashable
24
+ from QEfficient .utils import (
25
+ constants ,
26
+ create_json ,
27
+ create_model_params ,
28
+ dump_qconfig ,
29
+ export_wrapper ,
30
+ generate_mdp_partition_config ,
31
+ hash_dict_params ,
32
+ load_json ,
33
+ )
27
34
28
35
logger = logging .getLogger (__name__ )
29
36
@@ -45,12 +52,16 @@ class QEFFBaseModel(ABC):
45
52
def _transform_names (cls ) -> List [str ]:
46
53
return [x .__name__ for x in cls ._pytorch_transforms + cls ._onnx_transforms ]
47
54
48
- def __init__ (self , model : torch .nn .Module ) -> None :
55
+ def __init__ (self , model : torch .nn .Module , ** kwargs ) -> None :
49
56
super ().__init__ ()
50
57
self .model = model
58
+ self .hash_params = create_model_params (self , ** kwargs )
51
59
self .onnx_path : Optional [str ] = None
52
60
self .qpc_path : Optional [str ] = None
53
61
self .qpc_session : Optional [QAICInferenceSession ] = None
62
+ self .model_architecture = (
63
+ (arch := getattr (self .model .config , "architectures" , None )) and len (arch ) > 0 and arch [0 ]
64
+ ) or None
54
65
55
66
# Apply the transformations
56
67
any_transformed = False
@@ -67,10 +78,6 @@ def __init__(self, model: torch.nn.Module) -> None:
67
78
@abstractmethod
68
79
def model_name (self ) -> str : ...
69
80
70
- @property
71
- @abstractmethod
72
- def model_hash (self ) -> str : ...
73
-
74
81
@abstractmethod
75
82
def export (self , export_dir : Optional [str ] = None ) -> Path :
76
83
"""
@@ -114,6 +121,7 @@ def compile(self, *args, **kwargs) -> Path:
114
121
:str: Path of the compiled ``qpc`` package.
115
122
"""
116
123
124
+ @export_wrapper
117
125
def _export (
118
126
self ,
119
127
example_inputs : Dict [str , torch .Tensor ],
@@ -134,8 +142,6 @@ def _export(
134
142
:onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class.
135
143
:export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model.
136
144
"""
137
- export_dir = Path (export_dir or (QEFF_HOME / self .model_name ))
138
- export_dir = export_dir .with_name (export_dir .name + "-" + self .model_hash )
139
145
onnx_path = export_dir / f"{ self .model_name } .onnx"
140
146
if onnx_path .is_file ():
141
147
self .onnx_path = onnx_path
@@ -299,23 +305,16 @@ def _compile(
299
305
else :
300
306
mdp_ts_json = None
301
307
302
- compile_hash = hashlib .sha256 (to_hashable (command ))
303
-
304
- if specializations is not None :
305
- compile_hash .update (to_hashable (specializations ))
306
-
307
- if custom_io is not None :
308
- compile_hash .update (to_hashable (custom_io ))
309
-
310
- if num_speculative_tokens :
311
- compile_hash .update (to_hashable ({"num_speculative_tokens" : num_speculative_tokens }))
312
-
313
- # Hash the MDP partition config and the number of devices.
314
- compile_hash .update (to_hashable (mdp_ts_json ))
315
- compile_hash .update (to_hashable ({"mdp_ts_num_devices" : mdp_ts_num_devices }))
308
+ compile_hash_params = {
309
+ "command" : command ,
310
+ "specializations" : specializations ,
311
+ "custom_io" : custom_io ,
312
+ "mdp_ts_num_devices" : mdp_ts_num_devices ,
313
+ "mdp_ts_json" : mdp_ts_json ,
314
+ "num_speculative_tokens" : num_speculative_tokens ,
315
+ }
316
+ compile_hash = hash_dict_params (compile_hash_params )
316
317
317
- # Check if already compiled
318
- compile_hash = compile_hash .hexdigest ()[:16 ]
319
318
compile_dir = qpc_path .with_name (qpc_path .name + "-" + compile_hash )
320
319
qpc_path = compile_dir / "qpc"
321
320
qpc_path .mkdir (parents = True , exist_ok = True )
@@ -366,6 +365,10 @@ def _compile(
366
365
]
367
366
)
368
367
)
368
+ # Dump JSON file with hashed parameters
369
+ hashed_compile_params_path = compile_dir / "hashed_compile_params.json"
370
+ create_json (hashed_compile_params_path , compile_hash_params )
371
+ logger .info ("Hashed parameters exported successfully." )
369
372
370
373
self .qpc_path = qpc_path
371
374
0 commit comments