12
12
import numpy as np
13
13
from onnx import ModelProto , external_data_helper , numpy_helper
14
14
15
- from QEfficient .utils .constants import ONNX_TRANSFROM_MEMORY_CLEANUP_INTERVAL
15
+ from QEfficient .utils .constants import ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL
16
16
17
17
logger = logging .getLogger (__name__ )
18
18
@@ -22,6 +22,8 @@ class OnnxTransform:
22
22
OnnxTransform is the base class for graph modifications on exported onnx.
23
23
"""
24
24
25
    # Per-model cache of external-data load status, keyed by id(model) -> bool.
    # True once all tensors are confirmed to have raw_data in memory.
    # NOTE(review): id() values can be recycled after a model is garbage-collected,
    # so entries must be removed via _cleanup_external_data_and_cache — confirm all
    # call sites do this.
    _external_data_loaded_cache = {}  # Dict[int, bool]
    def __init__(self):
        """Transforms are stateless namespaces of classmethods; instantiation is an error."""
        raise TypeError("Transform classes are not to be instantiated. Directly use the `apply` method.")
29
    @classmethod
    def _check_external_data_loaded(cls, model: ModelProto) -> bool:
        """
        Check whether the model's external tensor data is already loaded, with caching.

        :param model: The ONNX model to check
        :returns: True if external data is already loaded, False otherwise
        """
        # Use the object id as the cache key instead of the (unhashable) proto itself.
        # NOTE(review): id() keys may be reused after garbage collection; stale hits
        # are possible unless _cleanup_external_data_and_cache is called on teardown.
        model_id = id(model)
        # Return cached result if available
        if model_id in cls._external_data_loaded_cache:
            return cls._external_data_loaded_cache[model_id]

        # Scan all tensors: one that still references external storage but carries
        # no raw_data means the bulk load has not happened yet.
        for tensor in external_data_helper._get_all_tensors(model):
            if len(tensor.external_data) > 0 and not tensor.HasField("raw_data"):
                cls._external_data_loaded_cache[model_id] = False
                return False

        cls._external_data_loaded_cache[model_id] = True
        return True
66
+ @classmethod
67
+ def _load_external_data (cls , model : ModelProto , onnx_base_dir : Optional [str ] = None ):
68
+ """
69
+ Performs a bulk load of external data if it's not already loaded.
70
+ Updates the cache upon successful load.
71
+ """
72
+ model_id = id (model )
73
+ if not cls ._check_external_data_loaded (model ):
74
+ logger .info ("External data not loaded. Performing bulk load." )
75
+ external_data_helper .load_external_data_for_model (model , onnx_base_dir )
76
+ cls ._external_data_loaded_cache [model_id ] = True
77
+ else :
78
+ logger .info ("External data already loaded (or cached). Skipping bulk load." )
79
+
80
+
81
+ @classmethod
82
+ def _cleanup_external_data_and_cache (cls , model : ModelProto ):
83
+ """
84
+ Combines clearing external data from the model and its cache entry.
85
+ """
86
+ # Remove the loaded raw data from tensors
87
+ for tensor in external_data_helper ._get_all_tensors (model ):
88
+ if tensor .HasField ("raw_data" ):
89
+ tensor .ClearField ("raw_data" )
90
+
91
+ # Clear the cache entry for this model using its ID
92
+ model_id = id (model )
93
+ if model_id in cls ._external_data_loaded_cache :
94
+ del cls ._external_data_loaded_cache [model_id ]
95
+
96
+ logger .info ("External data and cache cleaned up." )
97
+
54
98
@classmethod
55
99
def _cleanup_memory (cls ):
56
100
"""
    @classmethod
    def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwargs):
        """
        Clip float32 initializers to the float16 representable range, in place.

        :param onnx_base_dir: Base directory to load tensors
        :returns: (model, transformed) — transformed is True if any tensor was clipped.
        """
        try:
            # Ensure external data is loaded efficiently (once, in bulk) BEFORE
            # iterating tensors, instead of lazily per-tensor inside the loop.
            cls._load_external_data(model, onnx_base_dir)

            finfo = np.finfo(np.float16)
            fp16_max = finfo.max
            fp16_min = finfo.min
            transformed = False

            processed_count = 0
            for tensor in external_data_helper._get_all_tensors(model):
                nptensor = numpy_helper.to_array(tensor)  # onnx_base_dir not needed: data already loaded
                if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)):
                    neg_inf_mask = np.isinf(nptensor) & (nptensor < 0)
                    clipped_tensor = np.clip(nptensor, fp16_min, fp16_max)

                    # Restore -inf values
                    # NOTE(review): only -inf is restored; +inf is clamped to fp16_max.
                    # Confirm this asymmetry is intended.
                    if neg_inf_mask.any():
                        clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor)

                    new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name)
                    tensor.CopyFrom(new_tensor)
                    transformed = True

                    # Drop large temporaries eagerly to keep peak memory down.
                    del neg_inf_mask, clipped_tensor, new_tensor

                del nptensor
                processed_count += 1

                # Periodic GC pass so large freed arrays are actually reclaimed.
                if processed_count % ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL == 0:
                    cls._cleanup_memory()

            return model, transformed
        finally:
            # Ensure cleanup happens even if an exception occurs
            cls._cleanup_memory()
154
class SplitTensorsTransform(OnnxTransform):
    """
    Split the model's external tensor data into multiple chunked files.
    """

    # NOTE(review): the signature below is reconstructed from the visible docstring
    # params and body usage (model_name, onnx_base_dir, file_chunk_size,
    # size_threshold) — confirm parameter order and default values against the file.
    @classmethod
    def apply(
        cls,
        model: ModelProto,
        *,
        model_name: str,
        onnx_base_dir: Optional[str] = None,
        file_chunk_size: int = 10 * 2**30,  # 10 GiB
        size_threshold: int = 1024,
        **kwargs,
    ):
        """
        :param model_name: Base name used for the generated external data files.
        :param onnx_base_dir: Base directory to load tensors (if not already loaded).
        :param file_chunk_size: Chunk size to split external files into.
        :param size_threshold: Only tensors greater than this threshold (in bytes) will be saved externally.
        :returns: (model, transformed) — transformed is True if any tensor was redirected.
        """
        try:
            file_num = 0
            current_file_size = 0
            transformed = False

            # Uses the shared bulk loader: either reuses the cached load performed
            # by an earlier transform (e.g. FP16ClipTransform) or loads here.
            cls._load_external_data(model, onnx_base_dir)

            processed_count = 0
            for tensor in external_data_helper._get_all_tensors(model):
                if tensor.HasField("raw_data") and ((tsize := len(tensor.raw_data)) > size_threshold):
                    transformed = True
                    current_file_size += tsize
                    # Start a new chunk file once the current one would exceed the limit.
                    if current_file_size > file_chunk_size:
                        file_num += 1
                        current_file_size = tsize
                    external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")

                processed_count += 1
                # Periodic GC pass to bound peak memory while iterating large models.
                if processed_count % ONNX_TRANSFORM_MEMORY_CLEANUP_INTERVAL == 0:
                    cls._cleanup_memory()

            return model, transformed
        finally:
            # Ensure cleanup happens even if an exception occurs
            cls._cleanup_memory()
0 commit comments