Skip to content

Commit 6dada58

Browse files
Sam Daniel ThangarajanSamDanielThangarajan
authored andcommitted
add graceful shutdown for sessions
1 parent 807e48c commit 6dada58

File tree

5 files changed

+373
-20
lines changed

5 files changed

+373
-20
lines changed

src/nasdaq_protocols/common/message_queue.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextlib
23
from contextlib import asynccontextmanager
34
from itertools import count
45
from typing import Any, Callable, Awaitable
@@ -26,19 +27,25 @@ class DispatchableMessageQueue(Stoppable):
2627
on_msg_coro: DispatcherCoro = None
2728
_closed: bool = attrs.field(init=False, default=False)
2829
_msg_queue: asyncio.Queue = attrs.field(init=False, default=None)
30+
_buffer_msg_queue: asyncio.Queue | None = attrs.field(init=False, default=None)
2931
_recv_task: asyncio.Task = attrs.field(init=False, default=None)
3032
_dispatcher_task: asyncio.Task = attrs.field(init=False, default=None)
3133

3234
def __attrs_post_init__(self):
3335
self._msg_queue = asyncio.Queue()
3436
self.start_dispatching(self.on_msg_coro)
3537

38+
def __len__(self) -> int:
39+
"""Return the number of entries in the queue."""
40+
return self._msg_queue.qsize()
41+
3642
async def put(self, msg: Any) -> None:
3743
"""
3844
put an entry into the queue.
3945
:param msg: Any
4046
"""
41-
await self._msg_queue.put(msg)
47+
queue = self._buffer_msg_queue if self._buffer_msg_queue else self._msg_queue
48+
await queue.put(msg)
4249

4350
async def get(self):
4451
"""get an entry from the queue.
@@ -59,7 +66,8 @@ def put_nowait(self, msg: Any):
5966
put an entry into the queue.
6067
:param msg: Any
6168
"""
62-
self._msg_queue.put_nowait(msg)
69+
queue = self._buffer_msg_queue if self._buffer_msg_queue else self._msg_queue
70+
queue.put_nowait(msg)
6371

6472
def get_nowait(self) -> Any | None:
6573
"""
@@ -118,6 +126,27 @@ def start_dispatching(self, on_msg_coro: DispatcherCoro) -> None:
118126
self._dispatcher_task = asyncio.create_task(self._start_dispatching(), name=f'{self.session_id}-dispatcher')
119127
self.log.debug('%s> queue dispatcher started.', self.session_id)
120128

129+
@contextlib.asynccontextmanager
130+
async def buffer_until_drained(self, discard_buffer: bool = False):
131+
"""Async context manager that waits until the buffer is drained."""
132+
if self._buffer_msg_queue:
133+
raise StateError('Already blocking new messages, cannot nest block_until_empty')
134+
135+
self._buffer_msg_queue = asyncio.Queue()
136+
137+
try:
138+
while not self._msg_queue.empty():
139+
await asyncio.sleep(0.0001)
140+
yield
141+
finally:
142+
if not discard_buffer:
143+
if self.is_dispatching():
144+
async with self.pause_dispatching():
145+
self._msg_queue = self._buffer_msg_queue
146+
else:
147+
self._msg_queue = self._buffer_msg_queue
148+
self._buffer_msg_queue = None
149+
121150
async def stop(self) -> None:
122151
"""
123152
Stop the queue.
@@ -137,8 +166,7 @@ async def _start_dispatching(self):
137166
counter = count(1)
138167
while True:
139168
try:
140-
msg = await self._msg_queue.get()
141-
await self.on_msg_coro(msg)
169+
await self.on_msg_coro(await self._msg_queue.get())
142170
self.log.debug('%s> dispatched message %s', self.session_id, next(counter))
143171
except asyncio.CancelledError:
144172
break

src/nasdaq_protocols/common/session.py

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22
import asyncio
3+
import contextlib
34
from typing import Any, Callable, Coroutine, Generic, Type, TypeVar
45
from itertools import count
56

@@ -108,17 +109,36 @@ class Reader(Stoppable):
108109
on_msg_coro: OnMsgCoro = attrs.field(validator=Validators.not_none())
109110
on_close_coro: OnCloseCoro = attrs.field(validator=Validators.not_none())
110111
_buffer: bytearray = attrs.field(init=False, factory=bytearray)
112+
_drain_buffer: bytearray | None = attrs.field(init=False, default=None)
111113
_task: asyncio.Task = attrs.field(init=False, default=None)
112114
_stopped: bool = attrs.field(init=False, default=False)
115+
_drain_mode: asyncio.Event | None = attrs.field(init=False, default=None)
113116

114117
def __attrs_post_init__(self):
115118
self._task = asyncio.create_task(self._process(), name=f'reader:{self.session_id}')
116119

120+
@contextlib.asynccontextmanager
121+
async def buffer_until_drained(self, discard_buffer: bool = False):
122+
"""Async context manager that waits until the buffer is drained."""
123+
if self._drain_mode:
124+
raise StateError('Already draining, cannot nest buffer_until_drained')
125+
self._drain_mode = asyncio.Event()
126+
self._drain_buffer = bytearray()
127+
try:
128+
await self._drain_mode.wait()
129+
yield
130+
finally:
131+
self._drain_mode = None
132+
if not discard_buffer:
133+
self._buffer.extend(self._drain_buffer)
134+
self._drain_buffer = None
135+
117136
def on_data(self, data: bytes):
118137
self.log.debug('%s> on_data: existing = %s, received = %s', self.session_id, self._buffer, data)
119138
if len(data) == 0:
120139
return
121-
self._buffer.extend(data)
140+
buffer = self._drain_buffer if self._drain_buffer is not None else self._buffer
141+
buffer.extend(data)
122142

123143
async def stop(self):
124144
if self._stopped:
@@ -133,8 +153,13 @@ def is_stopped(self):
133153

134154
async def _process(self):
135155
while not self._stopped:
136-
if len(self._buffer) > 0:
156+
len_before = len_after = len(self._buffer)
157+
if len_before > 0:
137158
await self._process_1()
159+
len_after = len(self._buffer)
160+
if self._drain_mode and not self._drain_mode.is_set() and len_after == len_before:
161+
self._drain_mode.set()
162+
138163
await asyncio.sleep(0.0001)
139164

140165
async def _process_1(self):
@@ -200,6 +225,7 @@ class AsyncSession(asyncio.Protocol, abc.ABC, Generic[T]):
200225
on_msg_coro: OnMsgCoro = attrs.field(kw_only=True, default=None)
201226
on_close_coro: OnCloseCoro = attrs.field(kw_only=True, default=None)
202227
dispatch_on_connect: bool = attrs.field(kw_only=True, default=True)
228+
graceful_shutdown: bool = attrs.field(kw_only=True, default=True)
203229
_reader: Reader = attrs.field(init=False, default=None)
204230
_transport: asyncio.Transport = attrs.field(init=False, default=None)
205231
_closed: bool = attrs.field(init=False, default=False)
@@ -242,7 +268,7 @@ def is_closed(self) -> bool:
242268
"""
243269
return self._closed
244270

245-
def initiate_close(self) -> None:
271+
def initiate_close(self, drain: bool = False) -> None:
246272
"""
247273
Initiate close of the session.
248274
An asynchronous task is created which will close the session and all its
@@ -253,22 +279,38 @@ def initiate_close(self) -> None:
253279
"""
254280
if self._closed or self._closing_task:
255281
return
256-
self._closing_task = asyncio.create_task(self.close(), name=f'asyncsession-close:{self.session_id}')
257282

258-
async def close(self):
283+
if drain:
284+
name = f'asyncsession-close-drain:{self.session_id}'
285+
else:
286+
name = f'asyncsession-close:{self.session_id}'
287+
288+
self._closing_task = asyncio.create_task(self.close(drain), name=name)
289+
290+
@contextlib.asynccontextmanager
291+
async def buffer_until_drained(self, discard_buffer: bool = False):
292+
"""Async context manager that waits until both the reader and message queue are drained."""
293+
async with self._reader.buffer_until_drained(discard_buffer=discard_buffer):
294+
async with self._msg_queue.buffer_until_drained(discard_buffer=discard_buffer):
295+
yield
296+
297+
async def close(self, drain: bool = False):
259298
"""
260299
Close the session, the session cannot be used after this call.
261300
"""
262301
if not self._closed:
263302
self._closed = True
264-
await stop_task([
265-
self._msg_queue,
266-
self._local_hb_monitor,
267-
self._remote_hb_monitor,
268-
self._reader
269-
])
270303
if self._transport:
271304
self._transport.close()
305+
306+
await stop_task([self._local_hb_monitor, self._remote_hb_monitor])
307+
308+
if drain:
309+
async with self.buffer_until_drained(discard_buffer=True):
310+
self.log.debug('%s> close: drained session.', self.session_id)
311+
312+
await stop_task([self._msg_queue,self._reader])
313+
272314
if self.on_close_coro:
273315
await self.on_close_coro()
274316

@@ -336,7 +378,7 @@ def connection_lost(self, exc):
336378
:meta private:
337379
"""
338380
self.log.debug('%s> connection lost', self.session_id)
339-
self.initiate_close()
381+
self.initiate_close(self.graceful_shutdown)
340382

341383
async def on_message(self, msg):
342384
await self._msg_queue.put(msg)

tests/reader_app_tests.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ async def reader__stop__reader_is_stopped(**kwargs):
5858

5959
assert handler.closed.is_set()
6060

61+
# Double stop should be no-op
62+
await common.stop_task(reader)
63+
6164

6265
@reader_test
6366
async def reader__one_msg_per_packet__msg_is_read(**kwargs):
@@ -148,6 +151,51 @@ async def reader__empty_data__no_effect(**kwargs):
148151
await reader.stop()
149152

150153

154+
@reader_test
155+
async def reader__buffer_until_drained__buffers_not_discarded(**kwargs):
156+
handler, reader, input_factory, output_factory = all_test_params(**kwargs)
157+
158+
discard_buffers_modes = [False, True]
159+
for discard_buffer in discard_buffers_modes:
160+
reader.on_data(input_factory(1))
161+
async with reader.buffer_until_drained(discard_buffer=discard_buffer):
162+
first = await handler.received_messages.get()
163+
assert first == output_factory(1)
164+
165+
# Send another message while buffering
166+
reader.on_data(input_factory(2))
167+
await asyncio.sleep(0.1)
168+
169+
# Ensure no new messages are received while buffering
170+
try:
171+
second = await asyncio.wait_for(handler.received_messages.get(), timeout=0.5)
172+
assert False, f'Unexpected message received: {second}'
173+
except asyncio.TimeoutError:
174+
pass
175+
176+
if not discard_buffer:
177+
# After exiting the context, the buffered message should be processed
178+
second = await asyncio.wait_for(handler.received_messages.get(), timeout=0.5)
179+
assert second == output_factory(2)
180+
else:
181+
with pytest.raises(asyncio.TimeoutError):
182+
await asyncio.wait_for(handler.received_messages.get(), timeout=0.5)
183+
184+
await reader.stop()
185+
186+
187+
@reader_test
188+
async def reader__buffer_until_drained__nested_call_raises_error(**kwargs):
189+
handler, reader, input_factory, output_factory = all_test_params(**kwargs)
190+
191+
async with reader.buffer_until_drained():
192+
with pytest.raises(RuntimeError):
193+
async with reader.buffer_until_drained():
194+
pass
195+
196+
await reader.stop()
197+
198+
151199
@pytest.fixture(scope='function', params=READER_TESTS)
152200
async def reader_clientapp_common_tests(request, handler):
153201
async def _test(reader_factory, input_factory, output_factory):

0 commit comments

Comments
 (0)