diff --git a/.gitignore b/.gitignore
index 013870b..2823f34 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,5 @@ test.db
 venv/
 build/
 dist/
+.idea/
+.vscode/
diff --git a/README.md b/README.md
index 959b34d..be36329 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,7 @@ Python 3.8+
 * `Broadcast('memory://')`
 * `Broadcast("redis://localhost:6379")`
 * `Broadcast("redis-stream://localhost:6379")`
+* `Broadcast("redis-stream-cached://localhost:6379")`
 * `Broadcast("postgres://localhost:5432/broadcaster")`
 * `Broadcast("kafka://localhost:9092")`
 
diff --git a/broadcaster/__init__.py b/broadcaster/__init__.py
index 0bcd9d2..f5db5d4 100644
--- a/broadcaster/__init__.py
+++ b/broadcaster/__init__.py
@@ -1,4 +1,5 @@
-from ._base import Broadcast, Event
+from ._base import Broadcast
+from ._event import Event
 from .backends.base import BroadcastBackend
 
 __version__ = "0.3.1"
diff --git a/broadcaster/_base.py b/broadcaster/_base.py
index a63b22b..2bdd3cd 100644
--- a/broadcaster/_base.py
+++ b/broadcaster/_base.py
@@ -5,20 +5,12 @@
 from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, cast
 from urllib.parse import urlparse
 
-if TYPE_CHECKING:  # pragma: no cover
-    from broadcaster.backends.base import BroadcastBackend
-
-
-class Event:
-    def __init__(self, channel: str, message: str) -> None:
-        self.channel = channel
-        self.message = message
+from broadcaster.backends.base import BroadcastCacheBackend
 
-    def __eq__(self, other: object) -> bool:
-        return isinstance(other, Event) and self.channel == other.channel and self.message == other.message
+from ._event import Event
 
-    def __repr__(self) -> str:
-        return f"Event(channel={self.channel!r}, message={self.message!r})"
+if TYPE_CHECKING:  # pragma: no cover
+    from broadcaster.backends.base import BroadcastBackend
 
 
 class Unsubscribed(Exception):
@@ -43,6 +35,11 @@ def _create_backend(self, url: str) -> BroadcastBackend:
 
             return RedisStreamBackend(url)
 
+        elif parsed_url.scheme == "redis-stream-cached":
+            from broadcaster.backends.redis import RedisStreamCachedBackend
+
+            return RedisStreamCachedBackend(url)
+
         elif parsed_url.scheme in ("postgres", "postgresql"):
             from broadcaster.backends.postgres import PostgresBackend
 
@@ -87,7 +84,7 @@ async def publish(self, channel: str, message: Any) -> None:
         await self._backend.publish(channel, message)
 
     @asynccontextmanager
-    async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]:
+    async def subscribe(self, channel: str, history: int | None = None) -> AsyncIterator[Subscriber]:
         queue: asyncio.Queue[Event | None] = asyncio.Queue()
 
         try:
@@ -95,7 +92,19 @@ async def subscribe(self, channel: str) -> AsyncIterator[Subscriber]:
                 await self._backend.subscribe(channel)
                 self._subscribers[channel] = {queue}
             else:
-                self._subscribers[channel].add(queue)
+                if isinstance(self._backend, BroadcastCacheBackend):
+                    # pause the listener so replayed history is enqueued before newer live messages
+                    self._backend._ready.clear()
+                    try:
+                        current_id = await self._backend.get_current_channel_id(channel)
+                        for message in await self._backend.get_history_messages(channel, current_id, history):
+                            queue.put_nowait(message)
+                        self._subscribers[channel].add(queue)
+                    finally:
+                        # wake up the listener once the history has been enqueued
+                        self._backend._ready.set()
+                else:
+                    self._subscribers[channel].add(queue)
 
             yield Subscriber(queue)
         finally:
diff --git a/broadcaster/_event.py b/broadcaster/_event.py
new file mode 100644
index 0000000..65436cb
--- /dev/null
+++ b/broadcaster/_event.py
@@ -0,0 +1,10 @@
+class Event:
+    def __init__(self, channel: str, message: str) -> None:
+        self.channel = channel
+        self.message = message
+
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, Event) and self.channel == other.channel and self.message == other.message
+
+    def __repr__(self) -> str:
+        return f"Event(channel={self.channel!r}, message={self.message!r})"
diff --git a/broadcaster/backends/base.py b/broadcaster/backends/base.py
index 7017df3..1a27ef8 100644
--- a/broadcaster/backends/base.py
+++ b/broadcaster/backends/base.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
+import asyncio
 from typing import Any
 
-from .._base import Event
+from .._event import Event
 
 
 class BroadcastBackend:
@@ -24,3 +27,23 @@ async def publish(self, channel: str, message: Any) -> None:
 
     async def next_published(self) -> Event:
         raise NotImplementedError()
+
+
+class BroadcastCacheBackend(BroadcastBackend):
+    """Backend that can replay already published messages to late subscribers."""
+
+    # cleared by Broadcast.subscribe to pause the listener while history is replayed
+    _ready: asyncio.Event
+
+    async def get_current_channel_id(self, channel: str) -> str | bytes | memoryview | int:
+        """Return the id of the last message generated for ``channel``."""
+        raise NotImplementedError()
+
+    async def get_history_messages(
+        self,
+        channel: str,
+        msg_id: int | bytes | str | memoryview,
+        count: int | None = None,
+    ) -> list[Event]:
+        """Return at most ``count`` stored events up to and including ``msg_id``, oldest first."""
+        raise NotImplementedError()
diff --git a/broadcaster/backends/kafka.py b/broadcaster/backends/kafka.py
index f09dca1..065a347 100644
--- a/broadcaster/backends/kafka.py
+++ b/broadcaster/backends/kafka.py
@@ -6,7 +6,7 @@
 
 from aiokafka import AIOKafkaConsumer, AIOKafkaProducer
 
-from .._base import Event
+from .._event import Event
 from .base import BroadcastBackend
 
 
diff --git a/broadcaster/backends/memory.py b/broadcaster/backends/memory.py
index bfd0c44..2704124 100644
--- a/broadcaster/backends/memory.py
+++ b/broadcaster/backends/memory.py
@@ -3,7 +3,7 @@
 import asyncio
 import typing
 
-from .._base import Event
+from .._event import Event
 from .base import BroadcastBackend
 
 
diff --git a/broadcaster/backends/postgres.py b/broadcaster/backends/postgres.py
index 7769962..d0bd42c 100644
--- a/broadcaster/backends/postgres.py
+++ b/broadcaster/backends/postgres.py
@@ -3,7 +3,7 @@
 
 import asyncpg
 
-from .._base import Event
+from .._event import Event
 from .base import BroadcastBackend
 
 
diff --git a/broadcaster/backends/redis.py b/broadcaster/backends/redis.py
index 1be4195..a4d0347 100644
--- a/broadcaster/backends/redis.py
+++ b/broadcaster/backends/redis.py
@@ -5,8 +5,8 @@
 
 from redis import asyncio as redis
 
-from .._base import Event
-from .base import BroadcastBackend
+from .._event import Event
+from .base import BroadcastBackend, BroadcastCacheBackend
 
 
 class RedisBackend(BroadcastBackend):
@@ -88,14 +88,20 @@ async def subscribe(self, channel: str) -> None:
 
     async def unsubscribe(self, channel: str) -> None:
         self.streams.pop(channel, None)
+        if not self.streams:
+            self._ready.clear()
 
     async def publish(self, channel: str, message: typing.Any) -> None:
         await self._producer.xadd(channel, {"message": message})
 
     async def wait_for_messages(self) -> list[StreamMessageType]:
-        await self._ready.wait()
         messages = None
         while not messages:
+            if not self.streams:
+                # 1. save cpu usage
+                # 2. redis raises an exception when self.streams is empty
+                self._ready.clear()
+                await self._ready.wait()
             messages = await self._consumer.xread(self.streams, count=1, block=100)
         return messages
 
@@ -108,3 +114,83 @@ async def next_published(self) -> Event:
             channel=stream.decode("utf-8"),
             message=message.get(b"message", b"").decode("utf-8"),
         )
+
+
+class RedisStreamCachedBackend(BroadcastCacheBackend):
+    """Redis stream backend that can replay stored stream entries to new subscribers."""
+
+    def __init__(self, url: str):
+        url = url.replace("redis-stream-cached", "redis", 1)
+        self.streams: dict[bytes | str | memoryview, int | bytes | str | memoryview] = {}
+        self._ready = asyncio.Event()
+        self._producer = redis.Redis.from_url(url)
+        self._consumer = redis.Redis.from_url(url)
+
+    async def connect(self) -> None:
+        pass
+
+    async def disconnect(self) -> None:
+        await self._producer.aclose()
+        await self._consumer.aclose()
+
+    async def subscribe(self, channel: str) -> None:
+        # read from beginning
+        last_id = "0"
+        self.streams[channel] = last_id
+        self._ready.set()
+
+    async def unsubscribe(self, channel: str) -> None:
+        self.streams.pop(channel, None)
+        if not self.streams:
+            self._ready.clear()
+
+    async def publish(self, channel: str, message: typing.Any) -> None:
+        await self._producer.xadd(channel, {"message": message})
+
+    async def wait_for_messages(self) -> list[StreamMessageType]:
+        messages = None
+        while not messages:
+            if not self.streams:
+                # 1. save cpu usage
+                # 2. redis raises an exception when self.streams is empty
+                self._ready.clear()
+            # also blocks while Broadcast.subscribe replays history (it clears _ready)
+            await self._ready.wait()
+            messages = await self._consumer.xread(self.streams, count=1, block=100)
+        return messages
+
+    async def next_published(self) -> Event:
+        messages = await self.wait_for_messages()
+        stream, events = messages[0]
+        _msg_id, message = events[0]
+        channel = stream.decode("utf-8")
+        if channel in self.streams:
+            # do not re-add a channel that was unsubscribed while xread was in flight
+            self.streams[channel] = _msg_id.decode("utf-8")
+        return Event(
+            channel=channel,
+            message=message.get(b"message", b"").decode("utf-8"),
+        )
+
+    async def get_current_channel_id(self, channel: str) -> int | bytes | str | memoryview:
+        try:
+            info = await self._consumer.xinfo_stream(channel)
+            last_id: int | bytes | str | memoryview = info["last-generated-id"]
+        except redis.ResponseError:
+            last_id = "0"
+        return last_id
+
+    async def get_history_messages(
+        self,
+        channel: str,
+        msg_id: int | bytes | str | memoryview,
+        count: int | None = None,
+    ) -> list[Event]:
+        messages = await self._consumer.xrevrange(channel, max=msg_id, count=count)
+        return [
+            Event(
+                channel=channel,
+                message=message.get(b"message", b"").decode("utf-8"),
+            )
+            for _, message in reversed(messages or [])
+        ]
diff --git a/example/app.py b/example/app.py
index a201221..e012eef 100644
--- a/example/app.py
+++ b/example/app.py
@@ -51,5 +51,7 @@ async def chatroom_ws_sender(websocket):
 
 
 app = Starlette(
-    routes=routes, on_startup=[broadcast.connect], on_shutdown=[broadcast.disconnect],
+    routes=routes,
+    on_startup=[broadcast.connect],
+    on_shutdown=[broadcast.disconnect],
 )
diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py
index a8bd3eb..19e9881 100644
--- a/tests/test_broadcast.py
+++ b/tests/test_broadcast.py
@@ -71,6 +71,31 @@ async def test_redis_stream():
             assert event.message == "hello"
 
 
+@pytest.mark.asyncio
+async def test_redis_stream_cache():
+    messages = ["hello", "I'm cached"]
+    async with Broadcast("redis-stream-cached://localhost:6379") as broadcast:
+        await broadcast.publish("chatroom_cached", messages[0])
+        await broadcast.publish("chatroom_cached", messages[1])
+        await broadcast.publish("chatroom_cached", "quit")
+        sub1_messages = []
+        async with broadcast.subscribe("chatroom_cached") as subscriber:
+            async for event in subscriber:
+                if event:
+                    if event.message == "quit":
+                        break
+                    sub1_messages.append(event.message)
+        sub2_messages = []
+        async with broadcast.subscribe("chatroom_cached") as subscriber:
+            async for event in subscriber:
+                if event:
+                    if event.message == "quit":
+                        break
+                    sub2_messages.append(event.message)
+
+    assert sub1_messages == sub2_messages == messages
+
+
 @pytest.mark.asyncio
 async def test_postgres():
     async with Broadcast("postgres://postgres:postgres@localhost:5432/broadcaster") as broadcast: