Skip to content

Commit 16dbbe7

Browse files
committed
fixed isort changes
Signed-off-by: Sharvari Medhe <[email protected]>
1 parent d020b88 commit 16dbbe7

File tree

2 files changed

+113
-98
lines changed

2 files changed

+113
-98
lines changed

QEfficient/__init__.py

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,50 @@
88
import os
99
import warnings
1010

11-
from QEfficient.utils import custom_format_warning
12-
1311
# For faster downloads via hf_transfer
1412
# This code is put above import statements as this needs to be executed before
15-
# hf_transfer is imported (will happen on line 15 via leading imports)
13+
# hf_transfer is imported (will happen on line 14 via leading imports)
1614
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
15+
1716
# Placeholder for all non-transformer models registered in QEfficient
1817
import QEfficient.utils.model_registery # noqa: F401
18+
from QEfficient.base import (
19+
QEFFAutoModel,
20+
QEFFAutoModelForCausalLM,
21+
QEFFAutoModelForImageTextToText,
22+
QEFFAutoModelForSpeechSeq2Seq,
23+
QEFFCommonLoader,
24+
)
25+
from QEfficient.compile.compile_helper import compile
26+
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
27+
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
28+
from QEfficient.peft import QEffAutoPeftModelForCausalLM
29+
from QEfficient.transformers.transform import transform
30+
from QEfficient.utils import custom_format_warning
1931
from QEfficient.utils.logging_utils import logger
2032

2133
# Use a custom warning format for a better logging experience
2234
warnings.formatwarning = custom_format_warning
2335

36+
# Package version. Note: the imports above are no longer conditional on the QAIC SDK; its absence only triggers a warning below.
37+
__version__ = "0.0.1.dev0"
38+
39+
# Users can use QEfficient.export for exporting models to ONNX
40+
export = qualcomm_efficient_converter
41+
42+
__all__ = [
43+
"transform",
44+
"export",
45+
"compile",
46+
"cloud_ai_100_exec_kv",
47+
"QEFFAutoModel",
48+
"QEFFAutoModelForCausalLM",
49+
"QEffAutoPeftModelForCausalLM",
50+
"QEFFAutoModelForImageTextToText",
51+
"QEFFAutoModelForSpeechSeq2Seq",
52+
"QEFFCommonLoader",
53+
]
54+
2455

2556
def check_qaic_sdk():
2657
"""Check if QAIC SDK is installed"""
@@ -36,38 +67,5 @@ def check_qaic_sdk():
3667
return False
3768

3869

39-
# Conditionally import QAIC-related modules if the SDK is installed
40-
__version__ = "0.0.1.dev0"
41-
42-
if check_qaic_sdk():
43-
from QEfficient.base import (
44-
QEFFAutoModel,
45-
QEFFAutoModelForCausalLM,
46-
QEFFAutoModelForImageTextToText,
47-
QEFFAutoModelForSpeechSeq2Seq,
48-
QEFFCommonLoader,
49-
)
50-
from QEfficient.compile.compile_helper import compile
51-
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
52-
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
53-
from QEfficient.peft import QEffAutoPeftModelForCausalLM
54-
from QEfficient.transformers.transform import transform
55-
56-
# Users can use QEfficient.export for exporting models to ONNX
57-
export = qualcomm_efficient_converter
58-
59-
__all__ = [
60-
"transform",
61-
"export",
62-
"compile",
63-
"cloud_ai_100_exec_kv",
64-
"QEFFAutoModel",
65-
"QEFFAutoModelForCausalLM",
66-
"QEffAutoPeftModelForCausalLM",
67-
"QEFFAutoModelForImageTextToText",
68-
"QEFFAutoModelForSpeechSeq2Seq",
69-
"QEFFCommonLoader",
70-
]
71-
72-
else:
70+
if not check_qaic_sdk():
7371
logger.warning("QAIC SDK is not installed, eager mode features won't be available!")

QEfficient/generation/cloud_infer.py

Lines changed: 78 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,40 @@
55
#
66
# -----------------------------------------------------------------------------
77

8+
import importlib
9+
import platform
10+
import sys
811
from pathlib import Path
912
from typing import Dict, List, Optional, Union
1013
from warnings import warn
1114

1215
import numpy as np
1316

14-
try:
15-
import qaicrt
16-
except ImportError:
17-
import platform
18-
import sys
1917

20-
sys.path.append(f"/opt/qti-aic/dev/lib/{platform.machine()}")
21-
import qaicrt
22-
23-
try:
24-
import QAicApi_pb2 as aicapi
25-
except ImportError:
26-
import sys
27-
28-
sys.path.append("/opt/qti-aic/dev/python")
29-
import QAicApi_pb2 as aicapi
18+
class QAICInferenceSession:
19+
_qaicrt = None
20+
_aicapi = None
3021

31-
aic_to_np_dtype_mapping = {
32-
aicapi.FLOAT_TYPE: np.dtype(np.float32),
33-
aicapi.FLOAT_16_TYPE: np.dtype(np.float16),
34-
aicapi.INT8_Q_TYPE: np.dtype(np.int8),
35-
aicapi.UINT8_Q_TYPE: np.dtype(np.uint8),
36-
aicapi.INT16_Q_TYPE: np.dtype(np.int16),
37-
aicapi.INT32_Q_TYPE: np.dtype(np.int32),
38-
aicapi.INT32_I_TYPE: np.dtype(np.int32),
39-
aicapi.INT64_I_TYPE: np.dtype(np.int64),
40-
aicapi.INT8_TYPE: np.dtype(np.int8),
41-
}
22+
@property
23+
def qaicrt(self):
24+
if QAICInferenceSession._qaicrt is None:
25+
try:
26+
QAICInferenceSession._qaicrt = importlib.import_module("qaicrt")
27+
except ImportError:
28+
sys.path.append(f"/opt/qti-aic/dev/lib/{platform.machine()}")
29+
QAICInferenceSession._qaicrt = importlib.import_module("qaicrt")
30+
return QAICInferenceSession._qaicrt
4231

32+
@property
33+
def aicapi(self):
34+
if QAICInferenceSession._aicapi is None:
35+
try:
36+
QAICInferenceSession._aicapi = importlib.import_module("QAicApi_pb2")
37+
except ImportError:
38+
sys.path.append("/opt/qti-aic/dev/python")
39+
QAICInferenceSession._aicapi = importlib.import_module("QAicApi_pb2")
40+
return QAICInferenceSession._aicapi
4341

44-
class QAICInferenceSession:
4542
def __init__(
4643
self,
4744
qpc_path: Union[Path, str],
@@ -58,59 +55,81 @@ def __init__(
5855
:activate: bool. If false, activation will be disabled. Default=True.
5956
:enable_debug_logs: bool. If True, It will enable debug logs. Default=False.
6057
"""
58+
59+
# Build the dtype map one time, not on every property access
60+
self.aic_to_np_dtype_mapping = {
61+
self.aicapi.FLOAT_TYPE: np.dtype(np.float32),
62+
self.aicapi.FLOAT_16_TYPE: np.dtype(np.float16),
63+
self.aicapi.INT8_Q_TYPE: np.dtype(np.int8),
64+
self.aicapi.UINT8_Q_TYPE: np.dtype(np.uint8),
65+
self.aicapi.INT16_Q_TYPE: np.dtype(np.int16),
66+
self.aicapi.INT32_Q_TYPE: np.dtype(np.int32),
67+
self.aicapi.INT32_I_TYPE: np.dtype(np.int32),
68+
self.aicapi.INT64_I_TYPE: np.dtype(np.int64),
69+
self.aicapi.INT8_TYPE: np.dtype(np.int8),
70+
}
71+
6172
# Load QPC
6273
if device_ids is not None:
63-
devices = qaicrt.QIDList(device_ids)
64-
self.context = qaicrt.Context(devices)
65-
self.queue = qaicrt.Queue(self.context, device_ids[0])
74+
devices = self.qaicrt.QIDList(device_ids)
75+
self.context = self.qaicrt.Context(devices)
76+
self.queue = self.qaicrt.Queue(self.context, device_ids[0])
6677
else:
67-
self.context = qaicrt.Context()
68-
self.queue = qaicrt.Queue(self.context, 0) # Async API
78+
self.context = self.qaicrt.Context()
79+
self.queue = self.qaicrt.Queue(self.context, 0) # Async API
80+
6981
if enable_debug_logs:
70-
if self.context.setLogLevel(qaicrt.QLogLevel.QL_DEBUG) != qaicrt.QStatus.QS_SUCCESS:
82+
if self.context.setLogLevel(self.qaicrt.QLogLevel.QL_DEBUG) != self.qaicrt.QStatus.QS_SUCCESS:
7183
raise RuntimeError("Failed to setLogLevel")
72-
qpc = qaicrt.Qpc(str(qpc_path))
84+
85+
qpc = self.qaicrt.Qpc(str(qpc_path))
86+
7387
# Load IO Descriptor
74-
iodesc = aicapi.IoDesc()
88+
iodesc = self.aicapi.IoDesc()
7589
status, iodesc_data = qpc.getIoDescriptor()
76-
if status != qaicrt.QStatus.QS_SUCCESS:
90+
if status != self.qaicrt.QStatus.QS_SUCCESS:
7791
raise RuntimeError("Failed to getIoDescriptor")
7892
iodesc.ParseFromString(bytes(iodesc_data))
93+
7994
self.allowed_shapes = [
80-
[(aic_to_np_dtype_mapping[x.type].itemsize, list(x.dims)) for x in allowed_shape.shapes]
95+
[(self.aic_to_np_dtype_mapping[x.type].itemsize, list(x.dims)) for x in allowed_shape.shapes]
8196
for allowed_shape in iodesc.allowed_shapes
8297
]
8398
self.bindings = iodesc.selected_set.bindings
8499
self.binding_index_map = {binding.name: binding.index for binding in self.bindings}
100+
85101
# Create and load Program
86-
prog_properties = qaicrt.QAicProgramProperties()
102+
prog_properties = self.qaicrt.QAicProgramProperties()
87103
prog_properties.SubmitRetryTimeoutMs = 60_000
88104
if device_ids and len(device_ids) > 1:
89105
prog_properties.devMapping = ":".join(map(str, device_ids))
90-
self.program = qaicrt.Program(self.context, None, qpc, prog_properties)
91-
if self.program.load() != qaicrt.QStatus.QS_SUCCESS:
106+
107+
self.program = self.qaicrt.Program(self.context, None, qpc, prog_properties)
108+
if self.program.load() != self.qaicrt.QStatus.QS_SUCCESS:
92109
raise RuntimeError("Failed to load program")
110+
93111
if activate:
94112
self.activate()
113+
95114
# Create input qbuffers and buf_dims
96-
self.qbuffers = [qaicrt.QBuffer(bytes(binding.size)) for binding in self.bindings]
97-
self.buf_dims = qaicrt.BufferDimensionsVecRef(
98-
[(aic_to_np_dtype_mapping[binding.type].itemsize, list(binding.dims)) for binding in self.bindings]
115+
self.qbuffers = [self.qaicrt.QBuffer(bytes(binding.size)) for binding in self.bindings]
116+
self.buf_dims = self.qaicrt.BufferDimensionsVecRef(
117+
[(self.aic_to_np_dtype_mapping[binding.type].itemsize, list(binding.dims)) for binding in self.bindings]
99118
)
100119

101120
@property
102121
def input_names(self) -> List[str]:
103-
return [binding.name for binding in self.bindings if binding.dir == aicapi.BUFFER_IO_TYPE_INPUT]
122+
return [binding.name for binding in self.bindings if binding.dir == self.aicapi.BUFFER_IO_TYPE_INPUT]
104123

105124
@property
106125
def output_names(self) -> List[str]:
107-
return [binding.name for binding in self.bindings if binding.dir == aicapi.BUFFER_IO_TYPE_OUTPUT]
126+
return [binding.name for binding in self.bindings if binding.dir == self.aicapi.BUFFER_IO_TYPE_OUTPUT]
108127

109128
def activate(self):
110129
"""Activate qpc"""
111130

112131
self.program.activate()
113-
self.execObj = qaicrt.ExecObj(self.context, self.program)
132+
self.execObj = self.qaicrt.ExecObj(self.context, self.program)
114133

115134
def deactivate(self):
116135
"""Deactivate qpc"""
@@ -131,7 +150,7 @@ def set_buffers(self, buffers: Dict[str, np.ndarray]):
131150
warn(f'Buffer: "{buffer_name}" not found')
132151
continue
133152
buffer_index = self.binding_index_map[buffer_name]
134-
self.qbuffers[buffer_index] = qaicrt.QBuffer(buffer.tobytes())
153+
self.qbuffers[buffer_index] = self.qaicrt.QBuffer(buffer.tobytes())
135154
self.buf_dims[buffer_index] = (
136155
buffer.itemsize,
137156
buffer.shape if len(buffer.shape) > 0 else (1,),
@@ -157,21 +176,19 @@ def run(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
157176
Return:
158177
:Dict[str, np.ndarray]:
159178
"""
160-
# Set inputs
179+
161180
self.set_buffers(inputs)
162-
if self.execObj.setData(self.qbuffers, self.buf_dims) != qaicrt.QStatus.QS_SUCCESS:
181+
if self.execObj.setData(self.qbuffers, self.buf_dims) != self.qaicrt.QStatus.QS_SUCCESS:
163182
raise MemoryError("Failed to setData")
164-
# # Run with sync API
165-
# if self.execObj.run(self.qbuffers) != qaicrt.QStatus.QS_SUCCESS:
166-
# Run with async API
167-
if self.queue.enqueue(self.execObj) != qaicrt.QStatus.QS_SUCCESS:
183+
184+
if self.queue.enqueue(self.execObj) != self.qaicrt.QStatus.QS_SUCCESS:
168185
raise MemoryError("Failed to enqueue")
169-
if self.execObj.waitForCompletion() != qaicrt.QStatus.QS_SUCCESS:
186+
187+
if self.execObj.waitForCompletion() != self.qaicrt.QStatus.QS_SUCCESS:
170188
error_message = "Failed to run"
171-
# Print additional error messages for unmatched dimension error
189+
172190
if self.allowed_shapes:
173-
error_message += "\n\n"
174-
error_message += '(Only if "No matching dimension found" error is present above)'
191+
error_message += "\n\n(Only if 'No matching dimension found' error is present above)"
175192
error_message += "\nAllowed shapes:"
176193
for i, allowed_shape in enumerate(self.allowed_shapes):
177194
error_message += f"\n{i}\n"
@@ -189,18 +206,18 @@ def run(self, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
189206
continue
190207
error_message += f"{binding.name}:\t{elemsize}\t{shape}\n"
191208
raise ValueError(error_message)
192-
# Get output buffers
209+
193210
status, output_qbuffers = self.execObj.getData()
194-
if status != qaicrt.QStatus.QS_SUCCESS:
211+
if status != self.qaicrt.QStatus.QS_SUCCESS:
195212
raise MemoryError("Failed to getData")
196-
# Build output
213+
197214
outputs = {}
198215
for output_name in self.output_names:
199216
buffer_index = self.binding_index_map[output_name]
200217
if self.qbuffers[buffer_index].size == 0:
201218
continue
202219
outputs[output_name] = np.frombuffer(
203220
bytes(output_qbuffers[buffer_index]),
204-
aic_to_np_dtype_mapping[self.bindings[buffer_index].type],
221+
self.aic_to_np_dtype_mapping[self.bindings[buffer_index].type],
205222
).reshape(self.buf_dims[buffer_index][1])
206223
return outputs

0 commit comments

Comments
 (0)