Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit d72180e

Browse files
committed
Handle Redis pub/sub subscribe errors
1 parent a422d8a commit d72180e

File tree

3 files changed

+44
-6
lines changed

3 files changed

+44
-6
lines changed

broadcaster/_base.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class Broadcast:
2929
def __init__(self, url: str | None = None, *, backend: BroadcastBackend | None = None) -> None:
3030
assert url or backend, "Either `url` or `backend` must be provided."
3131
self._backend = backend or self._create_backend(cast(str, url))
32-
self._subscribers: dict[str, set[asyncio.Queue[Event | None]]] = {}
32+
self._subscribers: dict[str, set[asyncio.Queue[Event | BaseException | None]]] = {}
3333

3434
def _create_backend(self, url: str) -> BroadcastBackend:
3535
parsed_url = urlparse(url)
@@ -69,10 +69,19 @@ async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
6969
async def connect(self) -> None:
7070
await self._backend.connect()
7171
self._listener_task = asyncio.create_task(self._listener())
72+
self._listener_task.add_done_callback(self.drop)
73+
74+
def drop(self, task: asyncio.Task[None]) -> None:
75+
exc = task.exception()
76+
for queues in self._subscribers.values():
77+
for queue in queues:
78+
queue.put_nowait(exc)
7279

7380
async def disconnect(self) -> None:
7481
if self._listener_task.done():
75-
self._listener_task.result()
82+
exc = self._listener_task.exception()
83+
if exc is None:
84+
self._listener_task.result()
7685
else:
7786
self._listener_task.cancel()
7887
await self._backend.disconnect()
@@ -88,7 +97,7 @@ async def publish(self, channel: str, message: Any) -> None:
8897

8998
@asynccontextmanager
9099
async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]:
91-
queue: asyncio.Queue[Event | None] = asyncio.Queue()
100+
queue: asyncio.Queue[Event | BaseException | None] = asyncio.Queue()
92101

93102
try:
94103
if not self._subscribers.get(channel):
@@ -107,7 +116,7 @@ async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]:
107116

108117

109118
class Subscriber:
110-
def __init__(self, queue: asyncio.Queue[Event | None]) -> None:
119+
def __init__(self, queue: asyncio.Queue[Event | BaseException | None]) -> None:
111120
self._queue = queue
112121

113122
async def __aiter__(self) -> AsyncGenerator[Event | None, None]:
@@ -119,6 +128,8 @@ async def __aiter__(self) -> AsyncGenerator[Event | None, None]:
119128

120129
async def get(self) -> Event:
121130
item = await self._queue.get()
131+
if isinstance(item, BaseException):
132+
raise item
122133
if item is None:
123134
raise Unsubscribed()
124135
return item

broadcaster/backends/redis.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@ def __init__(self, url: str):
1414
self._conn = redis.Redis.from_url(url)
1515
self._pubsub = self._conn.pubsub()
1616
self._ready = asyncio.Event()
17-
self._queue: asyncio.Queue[Event] = asyncio.Queue()
17+
self._queue: asyncio.Queue[Event | BaseException | None] = asyncio.Queue()
1818
self._listener: asyncio.Task[None] | None = None
1919

2020
async def connect(self) -> None:
2121
self._listener = asyncio.create_task(self._pubsub_listener())
22+
self._listener.add_done_callback(self.drop)
2223
await self._pubsub.connect()
2324

2425
async def disconnect(self) -> None:
@@ -27,6 +28,10 @@ async def disconnect(self) -> None:
2728
if self._listener is not None:
2829
self._listener.cancel()
2930

31+
def drop(self, task: asyncio.Task[None]) -> None:
32+
exc = task.exception()
33+
self._queue.put_nowait(exc)
34+
3035
async def subscribe(self, channel: str) -> None:
3136
self._ready.set()
3237
await self._pubsub.subscribe(channel)
@@ -38,7 +43,12 @@ async def publish(self, channel: str, message: typing.Any) -> None:
3843
await self._conn.publish(channel, message)
3944

4045
async def next_published(self) -> Event:
41-
return await self._queue.get()
46+
result = await self._queue.get()
47+
if result is None:
48+
raise RuntimeError
49+
if isinstance(result, BaseException):
50+
raise result
51+
return result
4252

4353
async def _pubsub_listener(self) -> None:
4454
# redis-py does not listen to the pubsub connection if there are no channels subscribed

tests/test_broadcast.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import typing
55

66
import pytest
7+
import redis
78

89
from broadcaster import Broadcast, BroadcastBackend, Event
910
from broadcaster.backends.kafka import KafkaBackend
@@ -56,6 +57,22 @@ async def test_redis():
5657
assert event.message == "hello"
5758

5859

60+
@pytest.mark.asyncio
61+
async def test_redis_disconnect():
62+
with pytest.raises(redis.ConnectionError) as exc:
63+
async with Broadcast("redis://localhost:6379") as broadcast:
64+
async with broadcast.subscribe("chatroom") as subscriber:
65+
await broadcast.publish("chatroom", "hello")
66+
await broadcast._backend._conn.connection_pool.aclose() # type: ignore[attr-defined]
67+
event = await subscriber.get()
68+
assert event.channel == "chatroom"
69+
assert event.message == "hello"
70+
await subscriber.get()
71+
assert False
72+
73+
assert exc.value.args == ("Connection closed by server.",)
74+
75+
5976
@pytest.mark.asyncio
6077
async def test_redis_stream():
6178
async with Broadcast("redis-stream://localhost:6379") as broadcast:

0 commit comments

Comments
 (0)