Skip to content
Merged
12 changes: 9 additions & 3 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,7 @@ async def run(
*,
exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
poll_timeout: float = 1.0,
pubsub = None
) -> None:
"""Process pub/sub messages using registered callbacks.

Expand All @@ -1215,9 +1216,14 @@ async def run(
await self.connect()
while True:
try:
await self.get_message(
ignore_subscribe_messages=True, timeout=poll_timeout
)
if pubsub is None:
await self.get_message(
ignore_subscribe_messages=True, timeout=poll_timeout
)
else:
await pubsub.get_message(
ignore_subscribe_messages=True, timeout=poll_timeout
)
except asyncio.CancelledError:
raise
except BaseException as e:
Expand Down
135 changes: 133 additions & 2 deletions redis/asyncio/multidb/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from typing import Callable, Optional, Coroutine, Any, List, Union, Awaitable

from redis.asyncio.client import PubSubHandler
from redis.asyncio.multidb.command_executor import DefaultCommandExecutor
from redis.asyncio.multidb.database import AsyncDatabase, Databases
from redis.asyncio.multidb.failure_detector import AsyncFailureDetector
Expand All @@ -10,7 +11,7 @@
from redis.background import BackgroundScheduler
from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands
from redis.multidb.exception import NoValidDatabaseException
from redis.typing import KeyT
from redis.typing import KeyT, EncodableT, ChannelT


class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands):
Expand Down Expand Up @@ -222,6 +223,17 @@ async def transaction(
watch_delay=watch_delay,
)

async def pubsub(self, **kwargs):
"""
Return a Publish/Subscribe object. With this object, you can
subscribe to channels and listen for messages that get published to
them.
"""
if not self.initialized:
await self.initialize()

return PubSub(self, **kwargs)

async def _check_databases_health(
self,
on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None,
Expand Down Expand Up @@ -340,4 +352,123 @@ async def execute(self) -> List[Any]:
try:
return await self._client.command_executor.execute_pipeline(tuple(self._command_stack))
finally:
await self.reset()
await self.reset()

class PubSub:
"""
PubSub object for multi database client.
"""
def __init__(self, client: MultiDBClient, **kwargs):
"""Initialize the PubSub object for a multi-database client.

Args:
client: MultiDBClient instance to use for pub/sub operations
**kwargs: Additional keyword arguments to pass to the underlying pubsub implementation
"""

self._client = client
self._client.command_executor.pubsub(**kwargs)

async def __aenter__(self) -> "PubSub":
return self

async def __aexit__(self, exc_type, exc_value, traceback) -> None:
await self.aclose()

async def aclose(self):
return await self._client.command_executor.execute_pubsub_method('aclose')

@property
def subscribed(self) -> bool:
return self._client.command_executor.active_pubsub.subscribed

async def execute_command(self, *args: EncodableT):
return await self._client.command_executor.execute_pubsub_method('execute_command', *args)

async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler):
"""
Subscribe to channel patterns. Patterns supplied as keyword arguments
expect a pattern name as the key and a callable as the value. A
pattern's callable will be invoked automatically when a message is
received on that pattern rather than producing a message via
``listen()``.
"""
return await self._client.command_executor.execute_pubsub_method(
'psubscribe',
*args,
**kwargs
)

async def punsubscribe(self, *args: ChannelT):
"""
Unsubscribe from the supplied patterns. If empty, unsubscribe from
all patterns.
"""
return await self._client.command_executor.execute_pubsub_method(
'punsubscribe',
*args
)

async def subscribe(self, *args: ChannelT, **kwargs: Callable):
"""
Subscribe to channels. Channels supplied as keyword arguments expect
a channel name as the key and a callable as the value. A channel's
callable will be invoked automatically when a message is received on
that channel rather than producing a message via ``listen()`` or
``get_message()``.
"""
return await self._client.command_executor.execute_pubsub_method(
'subscribe',
*args,
**kwargs
)

async def unsubscribe(self, *args):
"""
Unsubscribe from the supplied channels. If empty, unsubscribe from
all channels
"""
return await self._client.command_executor.execute_pubsub_method(
'unsubscribe',
*args
)

async def get_message(
self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0
):
"""
Get the next message if one is available, otherwise None.

If timeout is specified, the system will wait for `timeout` seconds
before returning. Timeout should be specified as a floating point
number or None to wait indefinitely.
"""
return await self._client.command_executor.execute_pubsub_method(
'get_message',
ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout
)

async def run(
self,
*,
exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
poll_timeout: float = 1.0,
) -> None:
"""Process pub/sub messages using registered callbacks.

This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in
redis-py, but it is a coroutine. To launch it as a separate task, use
``asyncio.create_task``:

>>> task = asyncio.create_task(pubsub.run())

To shut it down, use asyncio cancellation:

>>> task.cancel()
>>> await task
"""
return await self._client.command_executor.execute_pubsub_run(
exception_handler=exception_handler,
sleep_time=poll_timeout,
pubsub=self
)
18 changes: 9 additions & 9 deletions redis/asyncio/multidb/command_executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import abstractmethod
from asyncio import iscoroutinefunction
from datetime import datetime
from typing import List, Optional, Callable, Any, Union, Awaitable

Expand Down Expand Up @@ -178,14 +179,10 @@ def failover_strategy(self) -> AsyncFailoverStrategy:
def command_retry(self) -> Retry:
return self._command_retry

async def pubsub(self, **kwargs):
async def callback():
if self._active_pubsub is None:
self._active_pubsub = self._active_database.client.pubsub(**kwargs)
self._active_pubsub_kwargs = kwargs
return None

return await self._execute_with_failure_detection(callback)
def pubsub(self, **kwargs):
if self._active_pubsub is None:
self._active_pubsub = self._active_database.client.pubsub(**kwargs)
self._active_pubsub_kwargs = kwargs

async def execute_command(self, *args, **options):
async def callback():
Expand Down Expand Up @@ -225,7 +222,10 @@ async def callback():
async def execute_pubsub_method(self, method_name: str, *args, **kwargs):
async def callback():
method = getattr(self.active_pubsub, method_name)
return await method(*args, **kwargs)
if iscoroutinefunction(method):
return await method(*args, **kwargs)
else:
return method(*args, **kwargs)

return await self._execute_with_failure_detection(callback, *args)

Expand Down
8 changes: 4 additions & 4 deletions redis/multidb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,6 @@ def __init__(self, client: MultiDBClient, **kwargs):
def __enter__(self) -> "PubSub":
return self

def __exit__(self, exc_type, exc_value, traceback) -> None:
self.reset()

def __del__(self) -> None:
try:
# if this object went out of scope prior to shutting down
Expand All @@ -350,7 +347,7 @@ def __del__(self) -> None:
pass

def reset(self) -> None:
pass
return self._client.command_executor.execute_pubsub_method('reset')

def close(self) -> None:
self.reset()
Expand All @@ -359,6 +356,9 @@ def close(self) -> None:
def subscribed(self) -> bool:
return self._client.command_executor.active_pubsub.subscribed

def execute_command(self, *args):
return self._client.command_executor.execute_pubsub_method('execute_command', *args)

def psubscribe(self, *args, **kwargs):
"""
Subscribe to channel patterns. Patterns supplied as keyword arguments
Expand Down
43 changes: 42 additions & 1 deletion tests/test_asyncio/test_scenario/test_active_active.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json
import logging
from time import sleep

Expand Down Expand Up @@ -186,4 +187,44 @@ async def callback(pipe: Pipeline):
# Execute transaction until database failover
while not listener.is_changed_flag:
await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3']
await asyncio.sleep(0.5)
await asyncio.sleep(0.5)

@pytest.mark.asyncio
@pytest.mark.parametrize(
"r_multi_db",
[{"failure_threshold": 2}],
indirect=True
)
@pytest.mark.timeout(50)
async def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client):
r_multi_db, listener, config = r_multi_db

event = asyncio.Event()
asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event))

data = json.dumps({'message': 'test'})
messages_count = 0

async def handler(message):
nonlocal messages_count
messages_count += 1

pubsub = await r_multi_db.pubsub()

# Assign a handler and run in a separate thread.
await pubsub.subscribe(**{'test-channel': handler})
task = asyncio.create_task(pubsub.run(poll_timeout=0.1))

# Execute publish before network failure
while not event.is_set():
await r_multi_db.publish('test-channel', data)
await asyncio.sleep(0.5)

# Execute publish until database failover
while not listener.is_changed_flag:
await r_multi_db.publish('test-channel', data)
await asyncio.sleep(0.5)

task.cancel()
await pubsub.unsubscribe('test-channel') is True
assert messages_count > 1