Skip to content

Commit fc26d10

Browse files
author
Sam Daniel Thangarajan
committed
more work in progress
1 parent 9d6b068 commit fc26d10

File tree

5 files changed

+283
-52
lines changed

5 files changed

+283
-52
lines changed

src/nasdaq_protocols/common/message_queue.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextlib
23
from contextlib import asynccontextmanager
34
from itertools import count
45
from typing import Any, Callable, Awaitable
@@ -26,19 +27,25 @@ class DispatchableMessageQueue(Stoppable):
2627
on_msg_coro: DispatcherCoro = None
2728
_closed: bool = attrs.field(init=False, default=False)
2829
_msg_queue: asyncio.Queue = attrs.field(init=False, default=None)
30+
_buffer_msg_queue: asyncio.Queue | None = attrs.field(init=False, default=None)
2931
_recv_task: asyncio.Task = attrs.field(init=False, default=None)
3032
_dispatcher_task: asyncio.Task = attrs.field(init=False, default=None)
3133

3234
def __attrs_post_init__(self):
3335
self._msg_queue = asyncio.Queue()
3436
self.start_dispatching(self.on_msg_coro)
3537

38+
def __len__(self) -> int:
39+
"""Return the number of entries in the queue."""
40+
return self._msg_queue.qsize()
41+
3642
async def put(self, msg: Any) -> None:
3743
"""
3844
put an entry into the queue.
3945
:param msg: Any
4046
"""
41-
await self._msg_queue.put(msg)
47+
queue = self._buffer_msg_queue if self._buffer_msg_queue else self._msg_queue
48+
await queue.put(msg)
4249

4350
async def get(self):
4451
"""get an entry from the queue.
@@ -59,7 +66,8 @@ def put_nowait(self, msg: Any):
5966
put an entry into the queue.
6067
:param msg: Any
6168
"""
62-
self._msg_queue.put_nowait(msg)
69+
queue = self._buffer_msg_queue if self._buffer_msg_queue else self._msg_queue
70+
queue.put_nowait(msg)
6371

6472
def get_nowait(self) -> Any | None:
6573
"""
@@ -118,6 +126,27 @@ def start_dispatching(self, on_msg_coro: DispatcherCoro) -> None:
118126
self._dispatcher_task = asyncio.create_task(self._start_dispatching(), name=f'{self.session_id}-dispatcher')
119127
self.log.debug('%s> queue dispatcher started.', self.session_id)
120128

129+
@contextlib.asynccontextmanager
130+
async def buffer_until_drained(self, discard_buffer: bool = False):
131+
"""Async context manager that waits until the buffer is drained."""
132+
if self._buffer_msg_queue:
133+
raise StateError('Already blocking new messages, cannot nest block_until_empty')
134+
135+
self._buffer_msg_queue = asyncio.Queue()
136+
137+
try:
138+
while not self._msg_queue.empty():
139+
await asyncio.sleep(0.0001)
140+
yield
141+
finally:
142+
if not discard_buffer:
143+
if self.is_dispatching():
144+
async with self.pause_dispatching():
145+
self._msg_queue = self._buffer_msg_queue
146+
else:
147+
self._msg_queue = self._buffer_msg_queue
148+
self._buffer_msg_queue = None
149+
121150
async def stop(self) -> None:
122151
"""
123152
Stop the queue.
@@ -137,8 +166,7 @@ async def _start_dispatching(self):
137166
counter = count(1)
138167
while True:
139168
try:
140-
msg = await self._msg_queue.get()
141-
await self.on_msg_coro(msg)
169+
await self.on_msg_coro(await self._msg_queue.get())
142170
self.log.debug('%s> dispatched message %s', self.session_id, next(counter))
143171
except asyncio.CancelledError:
144172
break

src/nasdaq_protocols/common/session.py

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22
import asyncio
3+
import contextlib
34
from typing import Any, Callable, Coroutine, Generic, Type, TypeVar
45
from itertools import count
56

@@ -107,22 +108,37 @@ class Reader(Stoppable):
107108
session_id: Any = attrs.field(validator=Validators.not_none())
108109
on_msg_coro: OnMsgCoro = attrs.field(validator=Validators.not_none())
109110
on_close_coro: OnCloseCoro = attrs.field(validator=Validators.not_none())
110-
_close_when_drained: bool = attrs.field(init=False, default=False)
111111
_buffer: bytearray = attrs.field(init=False, factory=bytearray)
112+
_drain_buffer: bytearray | None = attrs.field(init=False, default=None)
112113
_task: asyncio.Task = attrs.field(init=False, default=None)
113114
_stopped: bool = attrs.field(init=False, default=False)
115+
_drain_mode: asyncio.Event | None = attrs.field(init=False, default=None)
114116

115117
def __attrs_post_init__(self):
116118
self._task = asyncio.create_task(self._process(), name=f'reader:{self.session_id}')
117119

118-
def close_when_drained(self):
119-
self._close_when_drained = True
120+
@contextlib.asynccontextmanager
121+
async def buffer_until_drained(self, discard_buffer: bool = False):
122+
"""Async context manager that waits until the buffer is drained."""
123+
if self._drain_mode:
124+
raise StateError('Already draining, cannot nest buffer_until_drained')
125+
self._drain_mode = asyncio.Event()
126+
self._drain_buffer = bytearray()
127+
try:
128+
await self._drain_mode.wait()
129+
yield
130+
finally:
131+
self._drain_mode = None
132+
if not discard_buffer:
133+
self._buffer.extend(self._drain_buffer)
134+
self._drain_buffer = None
120135

121136
def on_data(self, data: bytes):
122137
self.log.debug('%s> on_data: existing = %s, received = %s', self.session_id, self._buffer, data)
123138
if len(data) == 0:
124139
return
125-
self._buffer.extend(data)
140+
buffer = self._drain_buffer if self._drain_buffer is not None else self._buffer
141+
buffer.extend(data)
126142

127143
async def stop(self):
128144
if self._stopped:
@@ -137,13 +153,12 @@ def is_stopped(self):
137153

138154
async def _process(self):
139155
while not self._stopped:
140-
len_before = len(self._buffer)
156+
len_before = len_after = len(self._buffer)
141157
if len_before > 0:
142158
await self._process_1()
143-
if self._close_when_drained and len(self._buffer) == len_before:
144-
self.log.debug('%s> buffer drained, closing reader.', self.session_id)
145-
await self.stop()
146-
return
159+
len_after = len(self._buffer)
160+
if self._drain_mode and not self._drain_mode.is_set() and len_after == len_before:
161+
self._drain_mode.set()
147162

148163
await asyncio.sleep(0.0001)
149164

@@ -253,7 +268,7 @@ def is_closed(self) -> bool:
253268
"""
254269
return self._closed
255270

256-
def initiate_close(self) -> None:
271+
def initiate_close(self, drain: bool = False) -> None:
257272
"""
258273
Initiate close of the session.
259274
An asynchronous task is created which will close the session and all its
@@ -264,29 +279,38 @@ def initiate_close(self) -> None:
264279
"""
265280
if self._closed or self._closing_task:
266281
return
267-
self._closing_task = asyncio.create_task(self.close(), name=f'asyncsession-close:{self.session_id}')
268282

269-
async def _graceful_close(self):
270-
await stop_task([
271-
self._local_hb_monitor,
272-
self._remote_hb_monitor,
273-
])
274-
self._reader.close_when_drained()
283+
if drain:
284+
name = f'asyncsession-close-drain:{self.session_id}'
285+
else:
286+
name = f'asyncsession-close:{self.session_id}'
287+
288+
self._closing_task = asyncio.create_task(self.close(drain), name=name)
289+
290+
@contextlib.asynccontextmanager
291+
async def buffer_until_drained(self, discard_buffer: bool = False):
292+
"""Async context manager that waits until both the reader and message queue are drained."""
293+
async with self._reader.buffer_until_drained(discard_buffer=discard_buffer):
294+
async with self._msg_queue.buffer_until_drained(discard_buffer=discard_buffer):
295+
yield
275296

276-
async def close(self):
297+
async def close(self, drain: bool = False):
277298
"""
278299
Close the session, the session cannot be used after this call.
279300
"""
280301
if not self._closed:
281302
self._closed = True
282-
await stop_task([
283-
self._msg_queue,
284-
self._local_hb_monitor,
285-
self._remote_hb_monitor,
286-
self._reader
287-
])
288303
if self._transport:
289304
self._transport.close()
305+
306+
await stop_task([self._local_hb_monitor, self._remote_hb_monitor])
307+
308+
if drain:
309+
async with self.buffer_until_drained(discard_buffer=True):
310+
self.log.debug('%s> close: drained session.', self.session_id)
311+
312+
await stop_task([self._msg_queue,self._reader])
313+
290314
if self.on_close_coro:
291315
await self.on_close_coro()
292316

@@ -354,12 +378,7 @@ def connection_lost(self, exc):
354378
:meta private:
355379
"""
356380
self.log.debug('%s> connection lost', self.session_id)
357-
if self.graceful_shutdown:
358-
if self._closed or self._closing_task:
359-
return
360-
self._closing_task = asyncio.create_task(self._graceful_close(), name=f'asyncsession-graceful-close:{self.session_id}')
361-
else:
362-
self.initiate_close()
381+
self.initiate_close(self.graceful_shutdown)
363382

364383
async def on_message(self, msg):
365384
await self._msg_queue.put(msg)

tests/reader_app_tests.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,37 @@ async def reader__empty_data__no_effect(**kwargs):
148148
await reader.stop()
149149

150150

151+
@reader_test
152+
async def reader__buffer_until_drained__buffers_not_discarded(**kwargs):
153+
handler, reader, input_factory, output_factory = all_test_params(**kwargs)
154+
155+
discard_buffers_modes = [False, True]
156+
for discard_buffer in discard_buffers_modes:
157+
reader.on_data(input_factory(1))
158+
async with reader.buffer_until_drained(discard_buffer=discard_buffer):
159+
first = await handler.received_messages.get()
160+
assert first == output_factory(1)
161+
162+
# Send another message while buffering
163+
reader.on_data(input_factory(2))
164+
await asyncio.sleep(0.1)
165+
166+
# Ensure no new messages are received while buffering
167+
try:
168+
second = await asyncio.wait_for(handler.received_messages.get(), timeout=0.5)
169+
assert False, f'Unexpected message received: {second}'
170+
except asyncio.TimeoutError:
171+
pass
172+
173+
if not discard_buffer:
174+
# After exiting the context, the buffered message should be processed
175+
second = await asyncio.wait_for(handler.received_messages.get(), timeout=0.5)
176+
assert second == output_factory(2)
177+
else:
178+
with pytest.raises(asyncio.TimeoutError):
179+
await asyncio.wait_for(handler.received_messages.get(), timeout=0.5)
180+
181+
151182
@pytest.fixture(scope='function', params=READER_TESTS)
152183
async def reader_clientapp_common_tests(request, handler):
153184
async def _test(reader_factory, input_factory, output_factory):

tests/test_common_asyncsession.py

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import logging
3+
import socket
34
from typing import Any
45

56
import attrs
@@ -12,31 +13,50 @@
1213

1314
@attrs.define(auto_attribs=True)
1415
class SampleTestReader(common.Reader):
16+
separator: bytes = attrs.field(kw_only=True, default=None)
17+
1518
def deserialize(self) -> Any:
16-
result = self._buffer.decode('ascii')
17-
self._buffer = bytearray()
19+
idx, sep_len = len(self._buffer), len(self.separator) if self.separator else 0
20+
if self.separator:
21+
idx = self._buffer.find(self.separator)
22+
if idx == -1:
23+
return None, False, False
24+
25+
result = self._buffer[:idx].decode('ascii')
26+
self._buffer = self._buffer[idx + sep_len:]
1827
return result, False, False
1928

2029
@staticmethod
2130
def create(session_id, on_msg, on_close):
2231
return SampleTestReader(session_id, on_msg, on_close)
2332

33+
@staticmethod
34+
def creator(separator: bytes):
35+
def _creator(session_id, on_msg, on_close):
36+
return SampleTestReader(session_id, on_msg, on_close, separator=separator)
37+
return _creator
38+
2439

2540
@attrs.define(auto_attribs=True)
2641
class SampleTestClientSession(common.AsyncSession):
2742
session_id: Any = attrs.field(validator=common.Validators.not_none())
2843
reader_factory: common.ReaderFactory = SampleTestReader.create
2944
received: asyncio.Queue = attrs.field(init=False, factory=asyncio.Queue)
3045
closed: asyncio.Event = attrs.field(init=False, factory=asyncio.Event)
46+
slow_client: bool = attrs.field(kw_only=True, default=False)
3147

3248
def __attrs_post_init__(self):
33-
self.on_msg_coro = self.received.put
49+
self.on_msg_coro = self.slow_on_msg if self.slow_client else self.received.put
3450
self.on_close_coro = self.on_close
3551
super().__attrs_post_init__()
3652

3753
def send_msg(self, data: str):
3854
self._transport.write(data.encode('ascii'))
3955

56+
async def slow_on_msg(self, msg):
57+
await self.received.put(msg)
58+
await asyncio.sleep(0.01)
59+
4060
async def send_heartbeat(self):
4161
pass
4262

@@ -135,32 +155,56 @@ def stream_messages_and_close(session, _):
135155
assert session_.closed.is_set()
136156
assert session_.is_closed()
137157

138-
async def test__asyncsession__stream_after_connect_and_close__no_graceful_close_in_client_session(mock_server_session):
158+
159+
@pytest.mark.parametrize('graceful_shutdown', [False, True])
160+
@pytest.mark.parametrize('num_messages', [1, 10, 100])
161+
@pytest.mark.parametrize('slow_client', [False, True])
162+
async def test__asyncsession__graceful_shutdown__enables_client_to_read_all_messages(mock_server_session, graceful_shutdown, num_messages, slow_client):
139163
event_loop = asyncio.get_running_loop()
140164
port, server_session = mock_server_session
141165

142-
# server sends 100000 messages after connection and then closes the session
143-
def stream_messages_and_close(session, _):
144-
for i in range(100000):
145-
session.send(f'msg{i}|'.encode('ascii'))
146-
session.send(f'end-of-stream'.encode('ascii'))
147-
session.close()
148-
server_session.when_connect().do(stream_messages_and_close)
166+
# server sends n messages after connection and then closes the session
167+
server_session.when_connect().do(get_streamer_close_function(num_messages))
149168

150-
session_ = SampleTestClientSession(session_id=common.SessionId(), graceful_shutdown=False)
151-
_, session_ = await event_loop.create_connection(lambda: session_, '127.0.0.1', port=port)
169+
session_ = SampleTestClientSession(
170+
session_id=common.SessionId(),
171+
reader_factory=SampleTestReader.creator(separator=b'|'),
172+
graceful_shutdown=graceful_shutdown,
173+
slow_client=slow_client
174+
)
175+
176+
_1, _2 = await event_loop.create_connection(
177+
lambda: session_, '127.0.0.1', port=port
178+
)
152179

153180
found_end_of_stream = False
154-
while not session_.closed.is_set() and not found_end_of_stream:
181+
while not found_end_of_stream:
155182
try:
156183
msg = await asyncio.wait_for(session_.received.get(), 0.1)
184+
if 'end-of-stream' == msg:
185+
found_end_of_stream = True
186+
except asyncio.TimeoutError:
187+
if session_.closed.is_set():
188+
break
157189
except Exception:
158190
break
159-
if 'end-of-stream' in msg:
160-
found_end_of_stream = True
161191

192+
task_id = asyncio.current_task().get_name()
193+
LOG.info('Waiting for session to close, from task %s', task_id)
162194
await session_.closed.wait()
195+
LOG.info('Session closed')
163196

164197
assert session_.closed.is_set()
165198
assert session_.is_closed()
166-
assert found_end_of_stream == False
199+
200+
assert found_end_of_stream == graceful_shutdown, \
201+
"Graceful shutdown should allow reading all messages"
202+
203+
204+
def get_streamer_close_function(message_count: int):
205+
def stream_messages_and_close(session, _):
206+
for i in range(message_count):
207+
session.send(f'msg{i}|'.encode('ascii'))
208+
session.send(f'end-of-stream|'.encode('ascii'))
209+
session.close()
210+
return stream_messages_and_close

0 commit comments

Comments
 (0)