@@ -29,7 +29,7 @@ class Broadcast:
29
29
def __init__ (self , url : str | None = None , * , backend : BroadcastBackend | None = None ) -> None :
30
30
assert url or backend , "Either `url` or `backend` must be provided."
31
31
self ._backend = backend or self ._create_backend (cast (str , url ))
32
- self ._subscribers : dict [str , set [asyncio .Queue [Event | None ]]] = {}
32
+ self ._subscribers : dict [str , set [asyncio .Queue [Event | BaseException | None ]]] = {}
33
33
34
34
def _create_backend (self , url : str ) -> BroadcastBackend :
35
35
parsed_url = urlparse (url )
@@ -69,10 +69,19 @@ async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
69
69
async def connect (self ) -> None :
70
70
await self ._backend .connect ()
71
71
self ._listener_task = asyncio .create_task (self ._listener ())
72
+ self ._listener_task .add_done_callback (self .drop )
73
+
74
+ def drop (self , task : asyncio .Task [None ]) -> None :
75
+ exc = task .exception ()
76
+ for queues in self ._subscribers .values ():
77
+ for queue in queues :
78
+ queue .put_nowait (exc )
72
79
73
80
async def disconnect (self ) -> None :
74
81
if self ._listener_task .done ():
75
- self ._listener_task .result ()
82
+ exc = self ._listener_task .exception ()
83
+ if exc is None :
84
+ self ._listener_task .result ()
76
85
else :
77
86
self ._listener_task .cancel ()
78
87
await self ._backend .disconnect ()
@@ -88,7 +97,7 @@ async def publish(self, channel: str, message: Any) -> None:
88
97
89
98
@asynccontextmanager
90
99
async def subscribe (self , channel : str ) -> AsyncIterator [Subscriber ]:
91
- queue : asyncio .Queue [Event | None ] = asyncio .Queue ()
100
+ queue : asyncio .Queue [Event | BaseException | None ] = asyncio .Queue ()
92
101
93
102
try :
94
103
if not self ._subscribers .get (channel ):
@@ -107,7 +116,7 @@ async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]:
107
116
108
117
109
118
class Subscriber :
110
- def __init__ (self , queue : asyncio .Queue [Event | None ]) -> None :
119
+ def __init__ (self , queue : asyncio .Queue [Event | BaseException | None ]) -> None :
111
120
self ._queue = queue
112
121
113
122
async def __aiter__ (self ) -> AsyncGenerator [Event | None , None ]:
@@ -119,6 +128,8 @@ async def __aiter__(self) -> AsyncGenerator[Event | None, None]:
119
128
120
129
async def get (self ) -> Event :
121
130
item = await self ._queue .get ()
131
+ if isinstance (item , BaseException ):
132
+ raise item
122
133
if item is None :
123
134
raise Unsubscribed ()
124
135
return item
0 commit comments