6
6
from dataclasses import dataclass
7
7
from queue import Empty as QueueEmpty
8
8
from queue import Queue
9
+ from threading import Event
9
10
from typing import (
10
11
Any ,
11
12
Generic ,
@@ -109,9 +110,11 @@ async def resolve(
109
110
...
110
111
111
112
async def get_request_time (
112
- self , times_queue : Queue [WorkerProcessRequestTime ]
113
+ self ,
114
+ times_queue : Queue [WorkerProcessRequestTime ],
115
+ timeout : Optional [int ] = None ,
113
116
) -> WorkerProcessRequestTime :
114
- return await asyncio .to_thread (times_queue .get ) # type: ignore[attr-defined]
117
+ return await asyncio .to_thread (times_queue .get , timeout = timeout ) # type: ignore[attr-defined]
115
118
116
119
async def send_result (
117
120
self ,
@@ -181,6 +184,7 @@ async def resolve_scheduler_request(
181
184
def process_loop_asynchronous (
182
185
self ,
183
186
queues : MPQueues [RequestT , ResponseT ],
187
+ stop_event : Event ,
184
188
prioritize_sessions : bool ,
185
189
max_concurrency : int ,
186
190
process_id : int ,
@@ -189,7 +193,7 @@ async def _process_runner():
189
193
lock = asyncio .Semaphore (max_concurrency )
190
194
pending_sessions : list [RequestSession [RequestT , ResponseT ]] = []
191
195
192
- while True : # TODO: Exit condition
196
+ while True :
193
197
await asyncio .sleep (0 ) # Yield control to the event loop
194
198
await lock .acquire ()
195
199
@@ -201,13 +205,16 @@ async def _process_runner():
201
205
else queues .requests .get_nowait ()
202
206
)
203
207
dequeued_time = time .time ()
204
- request_times = await self .get_request_time (queues .times )
208
+ request_times = await self .get_request_time (queues .times , 5 )
205
209
except (QueueEmpty , IndexError ):
206
210
# Requeue the session if we don't have a next time yet
207
211
if request_session is not None :
208
212
pending_sessions .append (request_session )
209
213
lock .release ()
210
- continue
214
+ if stop_event .is_set ():
215
+ return # Exit if stop event is set
216
+ else :
217
+ continue
211
218
212
219
async def wait_then_requeue (
213
220
session : RequestSession [RequestT , ResponseT ],
@@ -309,13 +316,15 @@ async def prepare_multiprocessing(self):
309
316
def process_loop_asynchronous (
310
317
self ,
311
318
queues : MPQueues [GenerationRequest , ResponseSummary ],
319
+ stop_event : Event ,
312
320
prioritize_sessions : bool ,
313
321
max_concurrency : int ,
314
322
process_id : int ,
315
323
):
316
324
asyncio .run (self .backend .validate ())
317
325
super ().process_loop_asynchronous (
318
326
queues = queues ,
327
+ stop_event = stop_event ,
319
328
prioritize_sessions = prioritize_sessions ,
320
329
max_concurrency = max_concurrency ,
321
330
process_id = process_id ,
0 commit comments