Skip to content

Commit b384b19

Browse files
committed
Added locking mechanism and updated the tests
Signed-off-by: Rishin Raj <[email protected]>
1 parent 3f695d1 commit b384b19

12 files changed

+165
-90
lines changed

QEfficient/utils/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
QEFF_DIR = os.path.dirname(UTILS_DIR)
1313
ROOT_DIR = os.path.dirname(QEFF_DIR)
1414
QEFF_CACHE_DIR_NAME = "qeff_cache"
15+
LOCK_DIR = "/tmp/device_locks"
1516

1617
ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1
1718
ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32

QEfficient/utils/device_utils.py

Lines changed: 105 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,71 +5,152 @@
55
#
66
# -----------------------------------------------------------------------------
77

8+
import fcntl
89
import math
10+
import os
911
import re
1012
import subprocess
1113
import time
14+
from typing import Optional
1215

13-
from QEfficient.utils.constants import Constants
16+
from QEfficient.utils.constants import LOCK_DIR, Constants
1417
from QEfficient.utils.logging_utils import logger
1518

1619

17-
def is_device_available(stdout: str) -> bool:
20+
def is_device_loaded(stdout: str) -> bool:
    """
    Tell whether the given text reports at least one loaded network.

    Args:
        stdout (str): Text to search for a ``Networks Loaded:<number>`` entry.

    Returns:
        bool: True when the entry is present and its number is greater than
        zero, False when the entry is missing or the number is zero.
    """
    found = re.search(r"Networks Loaded:(\d+)", stdout)
    if not found:
        return False

    try:
        loaded_networks = int(found.group(1))
    except (ValueError, AttributeError):
        return False
    return loaded_networks > 0
2327

2428

29+
def release_device_lock(lock_file):
    """
    Unlock and close a lock file previously locked with ``fcntl.flock``.

    Errors are logged and never propagated, so this is safe to call from
    cleanup paths.

    Args:
        lock_file: Open file object holding the lock.
    """
    try:
        try:
            fcntl.flock(lock_file, fcntl.LOCK_UN)
        finally:
            # Close even when the explicit unlock fails so the descriptor is
            # not leaked; closing the descriptor normally drops the lock as
            # well (see flock(2)).
            lock_file.close()

    except Exception as e:
        logger.error(f"Error releasing lock: {e}")
36+
37+
2538
def get_device_count():
    """
    Count the devices listed by ``/opt/qti-aic/tools/qaic-util -q``.

    Returns:
        int: The highest ``QID <n>`` value found in the tool output plus one,
        or 0 when the output has no ``QID`` entry or the tool could not be run.
    """
    command = ["/opt/qti-aic/tools/qaic-util", "-q"]

    try:
        result = subprocess.run(command, capture_output=True, text=True)
    except OSError:
        # The command must go through a %s placeholder; passing it as a bare
        # extra argument makes the logging call fail to format the record.
        logger.warning("ERROR while fetching the device: %s", command)
        return 0

    qids = re.findall(r"QID (\d+)", result.stdout)
    return max(map(int, qids)) + 1 if qids else 0
3449

3550

36-
def get_available_device_id(max_retry_count: int = 50, wait_time: int = 5) -> list[int] | None:
51+
def ensure_lock_dir(lock_dir: Optional[str] = None):
    """
    Create the directory that holds the device lock files if it is missing.

    Args:
        lock_dir (str, optional): Directory to create. Defaults to ``LOCK_DIR``
            so the function can be called without arguments.
    """
    if lock_dir is None:
        lock_dir = LOCK_DIR

    # exist_ok avoids the race between an existence check and the creation
    # when several processes start at the same time.
    os.makedirs(lock_dir, exist_ok=True)
54+
55+
56+
def acquire_device_lock(retry_interval: int = 10, retry_duration: int = 300) -> Optional[object]:
    """
    Attempt to acquire a non-blocking exclusive lock on the device-check lock file.

    The lock file is ``device_check.lock`` inside ``LOCK_DIR``. While the lock
    is held elsewhere, the attempt is repeated every ``retry_interval`` seconds
    until ``retry_duration`` seconds have elapsed.

    Args:
        retry_interval (int): Seconds to sleep between attempts.
        retry_duration (int): Total seconds to keep trying before giving up.

    Returns:
        file object if lock is acquired, else None (timeout or unexpected error).
    """
    # Pass the directory explicitly; ensure_lock_dir takes it as a parameter.
    ensure_lock_dir(LOCK_DIR)
    lock_file_path = os.path.join(LOCK_DIR, "device_check.lock")
    start_time = time.time()

    while (time.time() - start_time) < retry_duration:
        lock_file = None

        try:
            lock_file = open(lock_file_path, "w")
            fcntl.flock(lock_file, fcntl.LOCK_EX | fcntl.LOCK_NB)
            logger.debug("Lock acquired for device check")
            return lock_file

        except BlockingIOError:
            # Lock is currently held elsewhere: drop our handle and retry.
            if lock_file is not None:
                lock_file.close()
            logger.debug(f"Device check is locked. Retrying in {retry_interval} seconds...")
            time.sleep(retry_interval)

        except Exception as e:
            # Do not leak the descriptor when giving up on an unexpected error.
            if lock_file is not None:
                lock_file.close()
            logger.error(f"Unexpected error acquiring lock for device check: {e}")
            return None

    logger.warning(f"Failed to acquire lock for device check after {retry_duration} seconds.")
    return None
90+
91+
92+
def __fetch_device_id(device_count):
    """
    Query device IDs ``0 .. device_count - 1`` in order and return the first match.

    Each ID is queried with ``qaic-util -q -d <id>``. An ID is skipped when the
    output says the ID was not found, contains ``Status:Error``, or fails the
    ``is_device_loaded`` check.

    Args:
        device_count (int): Number of device IDs to probe.

    Returns:
        list[int] | None: Single-element list with the selected device ID, or
        None when no ID qualifies or the query raised an error.
    """
    for device_id in range(device_count):
        try:
            query = subprocess.run(
                ["/opt/qti-aic/tools/qaic-util", "-q", "-d", str(device_id)],
                capture_output=True,
                text=True,
            )
            output = query.stdout

            if "Failed to find requested device ID" in output:
                logger.warning(f"Device ID {device_id} not found.")
            elif "Status:Error" in output or not is_device_loaded(output):
                logger.debug(f"Device {device_id} is not available.")
            else:
                logger.info(f"Device ID {device_id} is available and locked.")
                return [device_id]

        except subprocess.TimeoutExpired:
            # NOTE: subprocess.run is called without a timeout here, so this
            # handler is not expected to trigger; the scan moves to the next ID.
            logger.error(f"Timeout while querying device {device_id}.")
        except OSError as e:
            logger.error(f"OSError while querying device {device_id}: {e}")
            return None
        except Exception as e:
            logger.exception(f"Unexpected error while checking device {device_id}: {e}")
            return None

    return None
117+
118+
119+
def get_available_device_id(retry_duration: int = 300, wait_time: int = 5) -> Optional[list[int]]:
    """
    Find an available Cloud AI 100 device ID using file-based locking.

    The device scan runs while holding the lock returned by
    ``acquire_device_lock``. The lock is released before this function
    returns, so it serialises the scan itself rather than the later use of
    the returned device.

    Args:
        retry_duration (int): Total seconds to keep rescanning the devices.
        wait_time (int): Seconds to wait between scans.

    Returns:
        list[int] | None: List containing available device ID, or None if not found.
    """
    device_count = get_device_count()

    if device_count == 0:
        logger.warning("No Cloud AI 100 devices found or platform SDK not installed.")
        return None

    lock_file = acquire_device_lock()

    if lock_file:
        try:
            start_time = time.time()

            while (time.time() - start_time) < retry_duration:
                device_id = __fetch_device_id(device_count)

                if device_id:
                    return device_id

                time.sleep(wait_time)
        finally:
            # Release on every exit path, including errors raised mid-scan.
            release_device_lock(lock_file)

    logger.warning("No available device found after all retries.")
    return None
74155

75156

tests/peft/lora/test_lora_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from QEfficient import QEffAutoPeftModelForCausalLM
1818
from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM
1919
from QEfficient.utils import load_hf_tokenizer
20+
from QEfficient.utils.device_utils import get_available_device_id
2021

2122
configs = [
2223
pytest.param(
@@ -235,6 +236,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate(
235236
tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name),
236237
prompts=prompts,
237238
prompt_to_adapter_mapping=["adapter_0", "adapter_1", "adapter_0", "base"],
239+
device_ids=get_available_device_id(),
238240
)
239241

240242

@@ -260,4 +262,5 @@ def test_auto_lora_model_for_causal_lm_cb_compile_generate(base_model_name, adap
260262
tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name),
261263
prompts=prompts,
262264
prompt_to_adapter_mapping=["adapter_0", "adapter_1", "adapter_0", "base"],
265+
device_ids=get_available_device_id(),
263266
)

tests/peft/test_peft_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from transformers import AutoConfig, AutoModelForCausalLM
1717

1818
from QEfficient import QEffAutoPeftModelForCausalLM
19+
from QEfficient.utils.device_utils import get_available_device_id
1920

2021
configs = [
2122
pytest.param(
@@ -181,6 +182,7 @@ def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_con
181182
axis=1,
182183
),
183184
max_new_tokens=10,
185+
device_ids=get_available_device_id(),
184186
)
185187

186188
start = perf_counter()

tests/text_generation/test_text_generation.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,6 @@ def test_generate_text_stream(
7272
qeff_model = QEFFAutoModelForCausalLM(model_hf)
7373

7474
qeff_model.export()
75-
device_id = get_available_device_id()
76-
77-
if not device_id:
78-
pytest.skip("No available devices to run model on Cloud AI 100")
7975

8076
qpc_path = qeff_model.compile(
8177
prefill_seq_len=prompt_len,
@@ -86,7 +82,9 @@ def test_generate_text_stream(
8682
full_batch_size=full_batch_size,
8783
)
8884

89-
exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR, generation_len=max_gen_len)
85+
exec_info = qeff_model.generate(
86+
tokenizer, prompts=Constants.INPUT_STR, generation_len=max_gen_len, device_ids=get_available_device_id()
87+
)
9088
cloud_ai_100_tokens = exec_info.generated_ids[0] # Because we always run for single input and single batch size
9189
cloud_ai_100_output = [tokenizer.decode(token, skip_special_tokens=True) for token in cloud_ai_100_tokens[0]]
9290

@@ -100,7 +98,7 @@ def test_generate_text_stream(
10098
for decoded_tokens in text_generator.generate_stream_tokens(Constants.INPUT_STR, generation_len=max_gen_len):
10199
stream_tokens.extend(decoded_tokens)
102100

103-
assert cloud_ai_100_output == stream_tokens, (
104-
f"Deviation in output observed while comparing regular execution and streamed output: {cloud_ai_100_output} != {stream_tokens}"
105-
)
101+
assert (
102+
cloud_ai_100_output == stream_tokens
103+
), f"Deviation in output observed while comparing regular execution and streamed output: {cloud_ai_100_output} != {stream_tokens}"
106104
assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))

tests/transformers/models/test_causal_lm_models.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -127,19 +127,16 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
127127

128128
pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model)
129129

130-
assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
131-
"Tokens don't match for HF PyTorch model output and KV PyTorch model output"
132-
)
130+
assert (
131+
pytorch_hf_tokens == pytorch_kv_tokens
132+
).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output"
133133

134134
onnx_model_path = qeff_model.export()
135135
ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm)
136136
gen_len = ort_tokens.shape[-1]
137137

138138
assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output."
139139

140-
if not get_available_device_id():
141-
pytest.skip("No available devices to run model on Cloud AI 100")
142-
143140
qpc_path = qeff_model.compile(
144141
prefill_seq_len=prompt_len,
145142
ctx_len=ctx_len,
@@ -151,18 +148,18 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
151148
enable_qnn=enable_qnn,
152149
qnn_config=qnn_config,
153150
)
154-
exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
151+
exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR, device_ids=get_available_device_id())
155152
cloud_ai_100_tokens = exec_info.generated_ids[0][
156153
:, :gen_len
157154
] # Because we always run for single input and single batch size
158155
if prefill_only:
159-
assert (ort_tokens[0][0] == cloud_ai_100_tokens[0][0]).all(), (
160-
"prefill run output tokens don't match for ONNXRT output and Cloud AI 100 output."
161-
)
156+
assert (
157+
ort_tokens[0][0] == cloud_ai_100_tokens[0][0]
158+
).all(), "prefill run output tokens don't match for ONNXRT output and Cloud AI 100 output."
162159
else:
163-
assert (ort_tokens == cloud_ai_100_tokens).all(), (
164-
"Tokens don't match for ONNXRT output and Cloud AI 100 output."
165-
)
160+
assert (
161+
ort_tokens == cloud_ai_100_tokens
162+
).all(), "Tokens don't match for ONNXRT output and Cloud AI 100 output."
166163
assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
167164
if prefill_only is not None:
168165
return
@@ -188,9 +185,6 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
188185
)
189186
onnx_model_path = qeff_model.export()
190187

191-
if not get_available_device_id():
192-
pytest.skip("No available devices to run model on Cloud AI 100")
193-
194188
# TODO: add prefill_only tests
195189
qpc_path = qeff_model.compile(
196190
prefill_seq_len=prompt_len,
@@ -203,7 +197,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
203197
enable_qnn=enable_qnn,
204198
qnn_config=qnn_config,
205199
)
206-
exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts)
200+
exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts, device_ids=get_available_device_id())
207201

208202
assert all(
209203
[
@@ -239,9 +233,9 @@ def test_causal_lm_export_with_deprecated_api(model_name):
239233
new_api_ort_tokens = api_runner.run_kv_model_on_ort(new_api_onnx_model_path)
240234
old_api_ort_tokens = api_runner.run_kv_model_on_ort(old_api_onnx_model_path)
241235

242-
assert (new_api_ort_tokens == old_api_ort_tokens).all(), (
243-
"New API output does not match old API output for ONNX export function"
244-
)
236+
assert (
237+
new_api_ort_tokens == old_api_ort_tokens
238+
).all(), "New API output does not match old API output for ONNX export function"
245239

246240

247241
@pytest.mark.on_qaic

tests/transformers/models/test_embedding_models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from QEfficient.transformers.models.modeling_auto import QEFFAutoModel
1717
from QEfficient.utils._utils import create_json
1818
from QEfficient.utils.constants import Constants, QnnConstants
19+
from QEfficient.utils.device_utils import get_available_device_id
1920

2021
embed_test_models = [
2122
# model_name, architecture
@@ -48,7 +49,7 @@ def check_embed_pytorch_vs_ort_vs_ai100(
4849
pt_embeddings = pt_outputs[0][0].detach().numpy()
4950
# Pytorch transformed model
5051
qeff_model = QEFFAutoModel(pt_model, pretrained_model_name_or_path=model_name)
51-
qeff_pt_outputs = qeff_model.generate(inputs=inputs, runtime_ai100=False)
52+
qeff_pt_outputs = qeff_model.generate(inputs=inputs, runtime_ai100=False, device_ids=get_available_device_id())
5253
qeff_pt_embeddings = qeff_pt_outputs[0][0].detach().numpy()
5354
mad = np.mean(np.abs(pt_embeddings - qeff_pt_embeddings))
5455
print("Mad for PyTorch and PyTorch transformed qeff_model is ", mad)
@@ -78,7 +79,7 @@ def check_embed_pytorch_vs_ort_vs_ai100(
7879
enable_qnn=enable_qnn,
7980
qnn_config=qnn_config,
8081
)
81-
ai100_output = qeff_model.generate(inputs=inputs)
82+
ai100_output = qeff_model.generate(inputs=inputs, device_ids=get_available_device_id())
8283

8384
# Compare ONNX and AI 100 outputs
8485
mad = np.mean(np.abs(ai100_output - onnx_outputs[0]))

0 commit comments

Comments
 (0)