11import abc
22import asyncio
3+ import contextlib
34from typing import Any , Callable , Coroutine , Generic , Type , TypeVar
45from itertools import count
56
@@ -108,17 +109,36 @@ class Reader(Stoppable):
108109 on_msg_coro : OnMsgCoro = attrs .field (validator = Validators .not_none ())
109110 on_close_coro : OnCloseCoro = attrs .field (validator = Validators .not_none ())
110111 _buffer : bytearray = attrs .field (init = False , factory = bytearray )
112+ _drain_buffer : bytearray | None = attrs .field (init = False , default = None )
111113 _task : asyncio .Task = attrs .field (init = False , default = None )
112114 _stopped : bool = attrs .field (init = False , default = False )
115+ _drain_mode : asyncio .Event | None = attrs .field (init = False , default = None )
113116
114117 def __attrs_post_init__ (self ):
115118 self ._task = asyncio .create_task (self ._process (), name = f'reader:{ self .session_id } ' )
116119
120+ @contextlib .asynccontextmanager
121+ async def buffer_until_drained (self , discard_buffer : bool = False ):
122+ """Async context manager that waits until the buffer is drained."""
123+ if self ._drain_mode :
124+ raise StateError ('Already draining, cannot nest buffer_until_drained' )
125+ self ._drain_mode = asyncio .Event ()
126+ self ._drain_buffer = bytearray ()
127+ try :
128+ await self ._drain_mode .wait ()
129+ yield
130+ finally :
131+ self ._drain_mode = None
132+ if not discard_buffer :
133+ self ._buffer .extend (self ._drain_buffer )
134+ self ._drain_buffer = None
135+
117136 def on_data (self , data : bytes ):
118137 self .log .debug ('%s> on_data: existing = %s, received = %s' , self .session_id , self ._buffer , data )
119138 if len (data ) == 0 :
120139 return
121- self ._buffer .extend (data )
140+ buffer = self ._drain_buffer if self ._drain_buffer is not None else self ._buffer
141+ buffer .extend (data )
122142
123143 async def stop (self ):
124144 if self ._stopped :
@@ -133,8 +153,13 @@ def is_stopped(self):
133153
134154 async def _process (self ):
135155 while not self ._stopped :
136- if len (self ._buffer ) > 0 :
156+ len_before = len_after = len (self ._buffer )
157+ if len_before > 0 :
137158 await self ._process_1 ()
159+ len_after = len (self ._buffer )
160+ if self ._drain_mode and not self ._drain_mode .is_set () and len_after == len_before :
161+ self ._drain_mode .set ()
162+
138163 await asyncio .sleep (0.0001 )
139164
140165 async def _process_1 (self ):
@@ -200,6 +225,7 @@ class AsyncSession(asyncio.Protocol, abc.ABC, Generic[T]):
200225 on_msg_coro : OnMsgCoro = attrs .field (kw_only = True , default = None )
201226 on_close_coro : OnCloseCoro = attrs .field (kw_only = True , default = None )
202227 dispatch_on_connect : bool = attrs .field (kw_only = True , default = True )
228+ graceful_shutdown : bool = attrs .field (kw_only = True , default = True )
203229 _reader : Reader = attrs .field (init = False , default = None )
204230 _transport : asyncio .Transport = attrs .field (init = False , default = None )
205231 _closed : bool = attrs .field (init = False , default = False )
@@ -242,7 +268,7 @@ def is_closed(self) -> bool:
242268 """
243269 return self ._closed
244270
245- def initiate_close (self ) -> None :
271+ def initiate_close (self , drain : bool = False ) -> None :
246272 """
247273 Initiate close of the session.
248274 An asynchronous task is created which will close the session and all its
@@ -253,22 +279,38 @@ def initiate_close(self) -> None:
253279 """
254280 if self ._closed or self ._closing_task :
255281 return
256- self ._closing_task = asyncio .create_task (self .close (), name = f'asyncsession-close:{ self .session_id } ' )
257282
258- async def close (self ):
283+ if drain :
284+ name = f'asyncsession-close-drain:{ self .session_id } '
285+ else :
286+ name = f'asyncsession-close:{ self .session_id } '
287+
288+ self ._closing_task = asyncio .create_task (self .close (drain ), name = name )
289+
290+ @contextlib .asynccontextmanager
291+ async def buffer_until_drained (self , discard_buffer : bool = False ):
292+ """Async context manager that waits until both the reader and message queue are drained."""
293+ async with self ._reader .buffer_until_drained (discard_buffer = discard_buffer ):
294+ async with self ._msg_queue .buffer_until_drained (discard_buffer = discard_buffer ):
295+ yield
296+
297+ async def close (self , drain : bool = False ):
259298 """
260299 Close the session, the session cannot be used after this call.
261300 """
262301 if not self ._closed :
263302 self ._closed = True
264- await stop_task ([
265- self ._msg_queue ,
266- self ._local_hb_monitor ,
267- self ._remote_hb_monitor ,
268- self ._reader
269- ])
270303 if self ._transport :
271304 self ._transport .close ()
305+
306+ await stop_task ([self ._local_hb_monitor , self ._remote_hb_monitor ])
307+
308+ if drain :
309+ async with self .buffer_until_drained (discard_buffer = True ):
310+ self .log .debug ('%s> close: drained session.' , self .session_id )
311+
312+ await stop_task ([self ._msg_queue ,self ._reader ])
313+
272314 if self .on_close_coro :
273315 await self .on_close_coro ()
274316
@@ -336,7 +378,7 @@ def connection_lost(self, exc):
336378 :meta private:
337379 """
338380 self .log .debug ('%s> connection lost' , self .session_id )
339- self .initiate_close ()
381+ self .initiate_close (self . graceful_shutdown )
340382
341383 async def on_message (self , msg ):
342384 await self ._msg_queue .put (msg )
0 commit comments