5
5
#
6
6
# -----------------------------------------------------------------------------
7
7
8
+ import importlib
9
+ import platform
10
+ import sys
8
11
from pathlib import Path
9
12
from typing import Dict , List , Optional , Union
10
13
from warnings import warn
11
14
12
15
import numpy as np
13
16
14
- try :
15
- import qaicrt
16
- except ImportError :
17
- import platform
18
- import sys
19
17
20
- sys .path .append (f"/opt/qti-aic/dev/lib/{ platform .machine ()} " )
21
- import qaicrt
22
-
23
- try :
24
- import QAicApi_pb2 as aicapi
25
- except ImportError :
26
- import sys
27
-
28
- sys .path .append ("/opt/qti-aic/dev/python" )
29
- import QAicApi_pb2 as aicapi
18
+ class QAICInferenceSession :
19
+ _qaicrt = None
20
+ _aicapi = None
30
21
31
- aic_to_np_dtype_mapping = {
32
- aicapi .FLOAT_TYPE : np .dtype (np .float32 ),
33
- aicapi .FLOAT_16_TYPE : np .dtype (np .float16 ),
34
- aicapi .INT8_Q_TYPE : np .dtype (np .int8 ),
35
- aicapi .UINT8_Q_TYPE : np .dtype (np .uint8 ),
36
- aicapi .INT16_Q_TYPE : np .dtype (np .int16 ),
37
- aicapi .INT32_Q_TYPE : np .dtype (np .int32 ),
38
- aicapi .INT32_I_TYPE : np .dtype (np .int32 ),
39
- aicapi .INT64_I_TYPE : np .dtype (np .int64 ),
40
- aicapi .INT8_TYPE : np .dtype (np .int8 ),
41
- }
22
+ @property
23
+ def qaicrt (self ):
24
+ if QAICInferenceSession ._qaicrt is None :
25
+ try :
26
+ QAICInferenceSession ._qaicrt = importlib .import_module ("qaicrt" )
27
+ except ImportError :
28
+ sys .path .append (f"/opt/qti-aic/dev/lib/{ platform .machine ()} " )
29
+ QAICInferenceSession ._qaicrt = importlib .import_module ("qaicrt" )
30
+ return QAICInferenceSession ._qaicrt
42
31
32
+ @property
33
+ def aicapi (self ):
34
+ if QAICInferenceSession ._aicapi is None :
35
+ try :
36
+ QAICInferenceSession ._aicapi = importlib .import_module ("QAicApi_pb2" )
37
+ except ImportError :
38
+ sys .path .append ("/opt/qti-aic/dev/python" )
39
+ QAICInferenceSession ._aicapi = importlib .import_module ("QAicApi_pb2" )
40
+ return QAICInferenceSession ._aicapi
43
41
44
- class QAICInferenceSession :
45
42
def __init__ (
46
43
self ,
47
44
qpc_path : Union [Path , str ],
@@ -58,59 +55,81 @@ def __init__(
58
55
:activate: bool. If false, activation will be disabled. Default=True.
59
56
:enable_debug_logs: bool. If True, It will enable debug logs. Default=False.
60
57
"""
58
+
59
+ # Build the dtype map one time, not on every property access
60
+ self .aic_to_np_dtype_mapping = {
61
+ self .aicapi .FLOAT_TYPE : np .dtype (np .float32 ),
62
+ self .aicapi .FLOAT_16_TYPE : np .dtype (np .float16 ),
63
+ self .aicapi .INT8_Q_TYPE : np .dtype (np .int8 ),
64
+ self .aicapi .UINT8_Q_TYPE : np .dtype (np .uint8 ),
65
+ self .aicapi .INT16_Q_TYPE : np .dtype (np .int16 ),
66
+ self .aicapi .INT32_Q_TYPE : np .dtype (np .int32 ),
67
+ self .aicapi .INT32_I_TYPE : np .dtype (np .int32 ),
68
+ self .aicapi .INT64_I_TYPE : np .dtype (np .int64 ),
69
+ self .aicapi .INT8_TYPE : np .dtype (np .int8 ),
70
+ }
71
+
61
72
# Load QPC
62
73
if device_ids is not None :
63
- devices = qaicrt .QIDList (device_ids )
64
- self .context = qaicrt .Context (devices )
65
- self .queue = qaicrt .Queue (self .context , device_ids [0 ])
74
+ devices = self . qaicrt .QIDList (device_ids )
75
+ self .context = self . qaicrt .Context (devices )
76
+ self .queue = self . qaicrt .Queue (self .context , device_ids [0 ])
66
77
else :
67
- self .context = qaicrt .Context ()
68
- self .queue = qaicrt .Queue (self .context , 0 ) # Async API
78
+ self .context = self .qaicrt .Context ()
79
+ self .queue = self .qaicrt .Queue (self .context , 0 ) # Async API
80
+
69
81
if enable_debug_logs :
70
- if self .context .setLogLevel (qaicrt .QLogLevel .QL_DEBUG ) != qaicrt .QStatus .QS_SUCCESS :
82
+ if self .context .setLogLevel (self . qaicrt .QLogLevel .QL_DEBUG ) != self . qaicrt .QStatus .QS_SUCCESS :
71
83
raise RuntimeError ("Failed to setLogLevel" )
72
- qpc = qaicrt .Qpc (str (qpc_path ))
84
+
85
+ qpc = self .qaicrt .Qpc (str (qpc_path ))
86
+
73
87
# Load IO Descriptor
74
- iodesc = aicapi .IoDesc ()
88
+ iodesc = self . aicapi .IoDesc ()
75
89
status , iodesc_data = qpc .getIoDescriptor ()
76
- if status != qaicrt .QStatus .QS_SUCCESS :
90
+ if status != self . qaicrt .QStatus .QS_SUCCESS :
77
91
raise RuntimeError ("Failed to getIoDescriptor" )
78
92
iodesc .ParseFromString (bytes (iodesc_data ))
93
+
79
94
self .allowed_shapes = [
80
- [(aic_to_np_dtype_mapping [x .type ].itemsize , list (x .dims )) for x in allowed_shape .shapes ]
95
+ [(self . aic_to_np_dtype_mapping [x .type ].itemsize , list (x .dims )) for x in allowed_shape .shapes ]
81
96
for allowed_shape in iodesc .allowed_shapes
82
97
]
83
98
self .bindings = iodesc .selected_set .bindings
84
99
self .binding_index_map = {binding .name : binding .index for binding in self .bindings }
100
+
85
101
# Create and load Program
86
- prog_properties = qaicrt .QAicProgramProperties ()
102
+ prog_properties = self . qaicrt .QAicProgramProperties ()
87
103
prog_properties .SubmitRetryTimeoutMs = 60_000
88
104
if device_ids and len (device_ids ) > 1 :
89
105
prog_properties .devMapping = ":" .join (map (str , device_ids ))
90
- self .program = qaicrt .Program (self .context , None , qpc , prog_properties )
91
- if self .program .load () != qaicrt .QStatus .QS_SUCCESS :
106
+
107
+ self .program = self .qaicrt .Program (self .context , None , qpc , prog_properties )
108
+ if self .program .load () != self .qaicrt .QStatus .QS_SUCCESS :
92
109
raise RuntimeError ("Failed to load program" )
110
+
93
111
if activate :
94
112
self .activate ()
113
+
95
114
# Create input qbuffers and buf_dims
96
- self .qbuffers = [qaicrt .QBuffer (bytes (binding .size )) for binding in self .bindings ]
97
- self .buf_dims = qaicrt .BufferDimensionsVecRef (
98
- [(aic_to_np_dtype_mapping [binding .type ].itemsize , list (binding .dims )) for binding in self .bindings ]
115
+ self .qbuffers = [self . qaicrt .QBuffer (bytes (binding .size )) for binding in self .bindings ]
116
+ self .buf_dims = self . qaicrt .BufferDimensionsVecRef (
117
+ [(self . aic_to_np_dtype_mapping [binding .type ].itemsize , list (binding .dims )) for binding in self .bindings ]
99
118
)
100
119
101
120
@property
def input_names(self) -> List[str]:
    """Names of the bindings whose ``dir`` is ``BUFFER_IO_TYPE_INPUT``.

    The names are returned in the same order as ``self.bindings``.
    """
    names = []
    for entry in self.bindings:
        if entry.dir == self.aicapi.BUFFER_IO_TYPE_INPUT:
            names.append(entry.name)
    return names
104
123
105
124
@property
def output_names(self) -> List[str]:
    """Names of the bindings whose ``dir`` is ``BUFFER_IO_TYPE_OUTPUT``.

    The names are returned in the same order as ``self.bindings``.
    """
    names = []
    for entry in self.bindings:
        if entry.dir == self.aicapi.BUFFER_IO_TYPE_OUTPUT:
            names.append(entry.name)
    return names
108
127
109
128
def activate(self):
    """Activate the program and build its execution object.

    Calls ``self.program.activate()`` and then stores a new
    ``qaicrt.ExecObj`` built from ``self.context`` and ``self.program``
    on ``self.execObj``.
    """
    self.program.activate()
    exec_obj_factory = self.qaicrt.ExecObj
    self.execObj = exec_obj_factory(self.context, self.program)
114
133
115
134
def deactivate (self ):
116
135
"""Deactivate qpc"""
@@ -131,7 +150,7 @@ def set_buffers(self, buffers: Dict[str, np.ndarray]):
131
150
warn (f'Buffer: "{ buffer_name } " not found' )
132
151
continue
133
152
buffer_index = self .binding_index_map [buffer_name ]
134
- self .qbuffers [buffer_index ] = qaicrt .QBuffer (buffer .tobytes ())
153
+ self .qbuffers [buffer_index ] = self . qaicrt .QBuffer (buffer .tobytes ())
135
154
self .buf_dims [buffer_index ] = (
136
155
buffer .itemsize ,
137
156
buffer .shape if len (buffer .shape ) > 0 else (1 ,),
@@ -157,21 +176,19 @@ def run(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
157
176
Return:
158
177
:Dict[str, np.ndarray]:
159
178
"""
160
- # Set inputs
179
+
161
180
self .set_buffers (inputs )
162
- if self .execObj .setData (self .qbuffers , self .buf_dims ) != qaicrt .QStatus .QS_SUCCESS :
181
+ if self .execObj .setData (self .qbuffers , self .buf_dims ) != self . qaicrt .QStatus .QS_SUCCESS :
163
182
raise MemoryError ("Failed to setData" )
164
- # # Run with sync API
165
- # if self.execObj.run(self.qbuffers) != qaicrt.QStatus.QS_SUCCESS:
166
- # Run with async API
167
- if self .queue .enqueue (self .execObj ) != qaicrt .QStatus .QS_SUCCESS :
183
+
184
+ if self .queue .enqueue (self .execObj ) != self .qaicrt .QStatus .QS_SUCCESS :
168
185
raise MemoryError ("Failed to enqueue" )
169
- if self .execObj .waitForCompletion () != qaicrt .QStatus .QS_SUCCESS :
186
+
187
+ if self .execObj .waitForCompletion () != self .qaicrt .QStatus .QS_SUCCESS :
170
188
error_message = "Failed to run"
171
- # Print additional error messages for unmatched dimension error
189
+
172
190
if self .allowed_shapes :
173
- error_message += "\n \n "
174
- error_message += '(Only if "No matching dimension found" error is present above)'
191
+ error_message += "\n \n (Only if 'No matching dimension found' error is present above)"
175
192
error_message += "\n Allowed shapes:"
176
193
for i , allowed_shape in enumerate (self .allowed_shapes ):
177
194
error_message += f"\n { i } \n "
@@ -189,18 +206,18 @@ def run(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
189
206
continue
190
207
error_message += f"{ binding .name } :\t { elemsize } \t { shape } \n "
191
208
raise ValueError (error_message )
192
- # Get output buffers
209
+
193
210
status , output_qbuffers = self .execObj .getData ()
194
- if status != qaicrt .QStatus .QS_SUCCESS :
211
+ if status != self . qaicrt .QStatus .QS_SUCCESS :
195
212
raise MemoryError ("Failed to getData" )
196
- # Build output
213
+
197
214
outputs = {}
198
215
for output_name in self .output_names :
199
216
buffer_index = self .binding_index_map [output_name ]
200
217
if self .qbuffers [buffer_index ].size == 0 :
201
218
continue
202
219
outputs [output_name ] = np .frombuffer (
203
220
bytes (output_qbuffers [buffer_index ]),
204
- aic_to_np_dtype_mapping [self .bindings [buffer_index ].type ],
221
+ self . aic_to_np_dtype_mapping [self .bindings [buffer_index ].type ],
205
222
).reshape (self .buf_dims [buffer_index ][1 ])
206
223
return outputs
0 commit comments