@@ -107,13 +107,17 @@ class Reader(Stoppable):
107107 session_id : Any = attrs .field (validator = Validators .not_none ())
108108 on_msg_coro : OnMsgCoro = attrs .field (validator = Validators .not_none ())
109109 on_close_coro : OnCloseCoro = attrs .field (validator = Validators .not_none ())
110+ _close_when_drained : bool = attrs .field (init = False , default = False )
110111 _buffer : bytearray = attrs .field (init = False , factory = bytearray )
111112 _task : asyncio .Task = attrs .field (init = False , default = None )
112113 _stopped : bool = attrs .field (init = False , default = False )
113114
114115 def __attrs_post_init__ (self ):
115116 self ._task = asyncio .create_task (self ._process (), name = f'reader:{ self .session_id } ' )
116117
118+ def close_when_drained (self ):
119+ self ._close_when_drained = True
120+
117121 def on_data (self , data : bytes ):
118122 self .log .debug ('%s> on_data: existing = %s, received = %s' , self .session_id , self ._buffer , data )
119123 if len (data ) == 0 :
@@ -133,8 +137,14 @@ def is_stopped(self):
133137
134138 async def _process (self ):
135139 while not self ._stopped :
136- if len (self ._buffer ) > 0 :
140+ len_before = len (self ._buffer )
141+ if len_before > 0 :
137142 await self ._process_1 ()
143+ if self ._close_when_drained and len (self ._buffer ) == len_before :
144+ self .log .debug ('%s> buffer drained, closing reader.' , self .session_id )
145+ await self .stop ()
146+ return
147+
138148 await asyncio .sleep (0.0001 )
139149
140150 async def _process_1 (self ):
@@ -200,6 +210,7 @@ class AsyncSession(asyncio.Protocol, abc.ABC, Generic[T]):
200210 on_msg_coro : OnMsgCoro = attrs .field (kw_only = True , default = None )
201211 on_close_coro : OnCloseCoro = attrs .field (kw_only = True , default = None )
202212 dispatch_on_connect : bool = attrs .field (kw_only = True , default = True )
213+ graceful_shutdown : bool = attrs .field (kw_only = True , default = True )
203214 _reader : Reader = attrs .field (init = False , default = None )
204215 _transport : asyncio .Transport = attrs .field (init = False , default = None )
205216 _closed : bool = attrs .field (init = False , default = False )
@@ -255,6 +266,13 @@ def initiate_close(self) -> None:
255266 return
256267 self ._closing_task = asyncio .create_task (self .close (), name = f'asyncsession-close:{ self .session_id } ' )
257268
269+ async def _graceful_close (self ):
270+ await stop_task ([
271+ self ._local_hb_monitor ,
272+ self ._remote_hb_monitor ,
273+ ])
274+ self ._reader .close_when_drained ()
275+
258276 async def close (self ):
259277 """
260278 Close the session, the session cannot be used after this call.
@@ -336,7 +354,10 @@ def connection_lost(self, exc):
336354 :meta private:
337355 """
338356 self .log .debug ('%s> connection lost' , self .session_id )
339- self .initiate_close ()
357+ if self .graceful_shutdown :
358+ if self ._closed or self ._closing_task :
359+ return
360+ self ._closing_task = asyncio .create_task (self ._graceful_close (), name = f'asyncsession-graceful-close:{ self .session_id } ' )
340361
341362 async def on_message (self , msg ):
342363 await self ._msg_queue .put (msg )
0 commit comments