Skip to content

Commit a7a6f49

Browse files
author
Sam Daniel Thangarajan
committed
add graceful shutdown for sessions
1 parent 807e48c commit a7a6f49

File tree

2 files changed

+55
-3
lines changed

2 files changed

+55
-3
lines changed

src/nasdaq_protocols/common/session.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,17 @@ class Reader(Stoppable):
107107
session_id: Any = attrs.field(validator=Validators.not_none())
108108
on_msg_coro: OnMsgCoro = attrs.field(validator=Validators.not_none())
109109
on_close_coro: OnCloseCoro = attrs.field(validator=Validators.not_none())
110+
_close_when_drained: bool = attrs.field(init=False, default=False)
110111
_buffer: bytearray = attrs.field(init=False, factory=bytearray)
111112
_task: asyncio.Task = attrs.field(init=False, default=None)
112113
_stopped: bool = attrs.field(init=False, default=False)
113114

114115
def __attrs_post_init__(self):
115116
self._task = asyncio.create_task(self._process(), name=f'reader:{self.session_id}')
116117

118+
def close_when_drained(self):
119+
self._close_when_drained = True
120+
117121
def on_data(self, data: bytes):
118122
self.log.debug('%s> on_data: existing = %s, received = %s', self.session_id, self._buffer, data)
119123
if len(data) == 0:
@@ -133,8 +137,14 @@ def is_stopped(self):
133137

134138
async def _process(self):
135139
while not self._stopped:
136-
if len(self._buffer) > 0:
140+
len_before = len(self._buffer)
141+
if len_before > 0:
137142
await self._process_1()
143+
if self._close_when_drained and len(self._buffer) == len_before:
144+
self.log.debug('%s> buffer drained, closing reader.', self.session_id)
145+
await self.stop()
146+
return
147+
138148
await asyncio.sleep(0.0001)
139149

140150
async def _process_1(self):
@@ -200,6 +210,7 @@ class AsyncSession(asyncio.Protocol, abc.ABC, Generic[T]):
200210
on_msg_coro: OnMsgCoro = attrs.field(kw_only=True, default=None)
201211
on_close_coro: OnCloseCoro = attrs.field(kw_only=True, default=None)
202212
dispatch_on_connect: bool = attrs.field(kw_only=True, default=True)
213+
graceful_shutdown: bool = attrs.field(kw_only=True, default=True)
203214
_reader: Reader = attrs.field(init=False, default=None)
204215
_transport: asyncio.Transport = attrs.field(init=False, default=None)
205216
_closed: bool = attrs.field(init=False, default=False)
@@ -255,6 +266,13 @@ def initiate_close(self) -> None:
255266
return
256267
self._closing_task = asyncio.create_task(self.close(), name=f'asyncsession-close:{self.session_id}')
257268

269+
async def _graceful_close(self):
270+
await stop_task([
271+
self._local_hb_monitor,
272+
self._remote_hb_monitor,
273+
])
274+
self._reader.close_when_drained()
275+
258276
async def close(self):
259277
"""
260278
Close the session, the session cannot be used after this call.
@@ -336,7 +354,10 @@ def connection_lost(self, exc):
336354
:meta private:
337355
"""
338356
self.log.debug('%s> connection lost', self.session_id)
339-
self.initiate_close()
357+
if self.graceful_shutdown:
358+
if self._closed or self._closing_task:
359+
return
360+
self._closing_task = asyncio.create_task(self._graceful_close(), name=f'asyncsession-graceful-close:{self.session_id}')
340361

341362
async def on_message(self, msg):
342363
await self._msg_queue.put(msg)

tests/test_common_asyncsession.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
@attrs.define(auto_attribs=True)
1414
class SampleTestReader(common.Reader):
1515
def deserialize(self) -> Any:
16-
return self._buffer.decode('ascii'), False, False
16+
result = self._buffer.decode('ascii')
17+
self._buffer = bytearray()
18+
return result, False, False
1719

1820
@staticmethod
1921
def create(session_id, on_msg, on_close):
@@ -39,6 +41,7 @@ async def send_heartbeat(self):
3941
pass
4042

4143
async def on_close(self):
44+
LOG.warning("Session %s closed", self.session_id)
4245
self.closed.set()
4346

4447

@@ -103,3 +106,31 @@ async def test__asyncsession__missed_remote_heartbeats__session_is_closed(mock_s
103106

104107
# test client session is closed
105108
assert client_session.is_closed()
109+
110+
111+
async def test__asyncsession__stream_after_connect_and_close__client_session_able_to_read_all_messages(mock_server_session):
112+
event_loop = asyncio.get_running_loop()
113+
port, server_session = mock_server_session
114+
115+
# server sends 100 messages after connection and then closes the session
116+
def stream_messages_and_close(session, _):
117+
for i in range(100000):
118+
session.send(f'msg{i}|'.encode('ascii'))
119+
session.send(f'end-of-stream'.encode('ascii'))
120+
session.close()
121+
server_session.when_connect().do(stream_messages_and_close)
122+
123+
session_ = SampleTestClientSession(session_id=common.SessionId())
124+
_, session_ = await event_loop.create_connection(lambda: session_, '127.0.0.1', port=port)
125+
126+
found_end_of_stream = False
127+
while not session_.closed.is_set() and not found_end_of_stream:
128+
msg = await asyncio.wait_for(session_.received.get(), 1)
129+
if 'end-of-stream' in msg:
130+
found_end_of_stream = True
131+
132+
await session_.closed.wait()
133+
134+
assert found_end_of_stream
135+
assert session_.closed.is_set()
136+
assert session_.is_closed()

0 commit comments

Comments
 (0)