diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index aac409073f..92c352344f 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -810,7 +810,7 @@ class PubSub: """ PUBLISH_MESSAGE_TYPES = ("message", "pmessage") - UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe") + UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe", "sunsubscribe") HEALTH_CHECK_MESSAGE = "redis-py-health-check" def __init__( @@ -852,6 +852,8 @@ def __init__( self.pending_unsubscribe_channels = set() self.patterns = {} self.pending_unsubscribe_patterns = set() + self.shard_channels = {} + self.pending_unsubscribe_shard_channels = set() self._lock = asyncio.Lock() async def __aenter__(self): @@ -880,6 +882,8 @@ async def aclose(self): self.pending_unsubscribe_channels = set() self.patterns = {} self.pending_unsubscribe_patterns = set() + self.shard_channels = {} + self.pending_unsubscribe_shard_channels = set() @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close") async def close(self) -> None: @@ -898,6 +902,7 @@ async def on_connect(self, connection: Connection): # before passing them to [p]subscribe. self.pending_unsubscribe_channels.clear() self.pending_unsubscribe_patterns.clear() + self.pending_unsubscribe_shard_channels.clear() if self.channels: channels = {} for k, v in self.channels.items(): @@ -908,11 +913,17 @@ async def on_connect(self, connection: Connection): for k, v in self.patterns.items(): patterns[self.encoder.decode(k, force=True)] = v await self.psubscribe(**patterns) + if self.shard_channels: + shard_channels = { + self.encoder.decode(k, force=True): v + for k, v in self.shard_channels.items() + } + await self.ssubscribe(**shard_channels) @property def subscribed(self): """Indicates if there are subscriptions to any channels or patterns""" - return bool(self.channels or self.patterns) + return bool(self.channels or self.patterns or self.shard_channels) async def execute_command(self, *args: EncodableT): """Execute a publish/subscribe command""" @@ -1091,6 +1102,40 @@ def unsubscribe(self, *args) -> Awaitable: self.pending_unsubscribe_channels.update(channels) return self.execute_command("UNSUBSCRIBE", *parsed_args) + def ssubscribe(self, *args, target_node=None, **kwargs) -> Awaitable: + """ + Subscribes the client to the specified shard channels. + Channels supplied as keyword arguments expect a channel name as the key + and a callable as the value. A channel's callable will be invoked automatically + when a message is received on that channel rather than producing a message via + ``listen()`` or ``get_sharded_message()``. + """ + if args: + args = list_or_args(args[0], args[1:]) + new_s_channels = dict.fromkeys(args) + new_s_channels.update(kwargs) + ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys()) + # update the s_channels dict AFTER we send the command. we don't want to + # subscribe twice to these channels, once for the command and again + # for the reconnection. + new_s_channels = self._normalize_keys(new_s_channels) + self.shard_channels.update(new_s_channels) + self.pending_unsubscribe_shard_channels.difference_update(new_s_channels) + return ret_val + + def sunsubscribe(self, *args, target_node=None) -> Awaitable: + """ + Unsubscribe from the supplied shard_channels. If empty, unsubscribe from + all shard_channels + """ + if args: + args = list_or_args(args[0], args[1:]) + s_channels = self._normalize_keys(dict.fromkeys(args)) + else: + s_channels = self.shard_channels + self.pending_unsubscribe_shard_channels.update(s_channels) + return self.execute_command("SUNSUBSCRIBE", *args) + async def listen(self) -> AsyncIterator: """Listen for messages on channels this client has been subscribed to""" while self.subscribed: @@ -1160,6 +1205,11 @@ async def handle_message(self, response, ignore_subscribe_messages=False): if pattern in self.pending_unsubscribe_patterns: self.pending_unsubscribe_patterns.remove(pattern) self.patterns.pop(pattern, None) + elif message_type == "sunsubscribe": + s_channel = response[1] + if s_channel in self.pending_unsubscribe_shard_channels: + self.pending_unsubscribe_shard_channels.remove(s_channel) + self.shard_channels.pop(s_channel, None) else: channel = response[1] if channel in self.pending_unsubscribe_channels: @@ -1172,6 +1222,8 @@ async def handle_message(self, response, ignore_subscribe_messages=False): handler = self.patterns.get(message["pattern"], None) else: handler = self.channels.get(message["channel"], None) + if handler is None: + handler = self.shard_channels.get(message["channel"], None) if handler: if inspect.iscoroutinefunction(handler): await handler(message) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index e8434d04a5..32944e1856 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -31,7 +31,7 @@ _RedisCallbacksRESP2, _RedisCallbacksRESP3, ) -from redis.asyncio.client import ResponseCallbackT +from redis.asyncio.client import PubSub, ResponseCallbackT from redis.asyncio.connection import Connection, SSLConnection, parse_url from redis.asyncio.lock import Lock from redis.asyncio.retry import Retry @@ -51,6 +51,7 @@ parse_cluster_slots, ) from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands +from redis.commands.helpers import list_or_args from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.credentials import CredentialProvider from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher @@ -889,6 +890,27 @@ def pipeline( return ClusterPipeline(self, transaction) + def pubsub( + self, + node: Optional["ClusterNode"] = None, + host: Optional[str] = None, + port: Optional[int] = None, + **kwargs: Any, + ) -> "ClusterPubSub": + """ + Create and return a ClusterPubSub instance. + + Allows passing a ClusterNode, or host&port, to get a pubsub instance + connected to the specified node + + :param node: ClusterNode to connect to + :param host: Host of the node to connect to + :param port: Port of the node to connect to + :param kwargs: Additional keyword arguments + :return: ClusterPubSub instance + """ + return ClusterPubSub(self, node=node, host=host, port=port, **kwargs) + def lock( self, name: KeyT, @@ -2394,3 +2416,294 @@ async def discard(self): async def unlink(self, *names): return self.execute_command("UNLINK", *names) + + +class ClusterPubSub(PubSub): + """ + Async cluster implementation for pub/sub. + + IMPORTANT: before using ClusterPubSub, read about the known limitations + with pubsub in Cluster mode and learn how to workaround them: + https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html + """ + + def __init__( + self, + redis_cluster: "RedisCluster", + node: Optional["ClusterNode"] = None, + host: Optional[str] = None, + port: Optional[int] = None, + push_handler_func: Optional[Callable] = None, + event_dispatcher: Optional[EventDispatcher] = None, + **kwargs: Any, + ) -> None: + """ + When a pubsub instance is created without specifying a node, a single + node will be transparently chosen for the pubsub connection on the + first command execution. The node will be determined by: + 1. Hashing the channel name in the request to find its keyslot + 2. Selecting a node that handles the keyslot: If read_from_replicas is + set to true or load_balancing_strategy is set, a replica can be selected. + + :param redis_cluster: RedisCluster instance + :param node: ClusterNode to connect to + :param host: Host of the node to connect to + :param port: Port of the node to connect to + :param push_handler_func: Optional push handler function + :param event_dispatcher: Optional event dispatcher + :param kwargs: Additional keyword arguments + """ + self.node = None + self.set_pubsub_node(redis_cluster, node, host, port) + + # Create connection pool if node is specified + if self.node is not None: + from redis.asyncio.connection import ConnectionPool + connection_pool = ConnectionPool( + connection_class=self.node.connection_class, + **self.node.connection_kwargs + ) + else: + connection_pool = None + + self.cluster = redis_cluster + self.node_pubsub_mapping: Dict[str, PubSub] = {} + self._pubsubs_generator = self._pubsubs_generator() + if event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + else: + self._event_dispatcher = event_dispatcher + super().__init__( + connection_pool=connection_pool, + encoder=redis_cluster.encoder, + push_handler_func=push_handler_func, + event_dispatcher=self._event_dispatcher, + **kwargs, + ) + + def set_pubsub_node( + self, + cluster: "RedisCluster", + node: Optional["ClusterNode"] = None, + host: Optional[str] = None, + port: Optional[int] = None, + ) -> None: + """ + The pubsub node will be set according to the passed node, host and port + When none of the node, host, or port are specified - the node is set + to None and will be determined by the keyslot of the channel in the + first command to be executed. + RedisClusterException will be thrown if the passed node does not exist + in the cluster. + If host is passed without port, or vice versa, a DataError will be + thrown. + """ + if node is not None: + # node is passed by the user + self._raise_on_invalid_node(cluster, node, node.host, node.port) + pubsub_node = node + elif host is not None and port is not None: + # host and port passed by the user + node = cluster.get_node(host=host, port=port) + self._raise_on_invalid_node(cluster, node, host, port) + pubsub_node = node + elif host is not None or port is not None: + # only one of host and port is specified + raise DataError("Specify both host and port") + else: + # nothing specified by the user + pubsub_node = None + self.node = pubsub_node + + def _get_node_pubsub(self, node: "ClusterNode") -> PubSub: + """Get or create a PubSub instance for the given node.""" + try: + return self.node_pubsub_mapping[node.name] + except KeyError: + # Create a simple connection pool-like interface for the node + from redis.asyncio.connection import ConnectionPool + + # Create a minimal connection pool for this node + connection_pool = ConnectionPool( + connection_class=node.connection_class, + **node.connection_kwargs + ) + + pubsub = PubSub( + connection_pool=connection_pool, + encoder=self.cluster.encoder, + push_handler_func=self.push_handler_func, + event_dispatcher=self._event_dispatcher, + ) + self.node_pubsub_mapping[node.name] = pubsub + return pubsub + + def _sharded_message_generator(self) -> Optional[Dict[str, Any]]: + """Generate messages from shard channels across all nodes.""" + for _ in range(len(self.node_pubsub_mapping)): + pubsub = next(self._pubsubs_generator) + # Check if pubsub has async get_message method + if hasattr(pubsub, 'get_message') and callable(pubsub.get_message): + # This would need to be adapted for async usage + # For now, we'll return None to avoid blocking + pass + return None + + def _pubsubs_generator(self) -> Generator[PubSub, None, None]: + """Generator that yields PubSub instances in round-robin fashion.""" + while True: + yield from self.node_pubsub_mapping.values() + + async def get_sharded_message( + self, + ignore_subscribe_messages: bool = False, + timeout: float = 0.0, + target_node: Optional["ClusterNode"] = None, + ) -> Optional[Dict[str, Any]]: + """ + Get a message from shard channels. + + :param ignore_subscribe_messages: Whether to ignore subscribe messages + :param timeout: Timeout for message retrieval + :param target_node: Specific node to get message from + :return: Message dictionary or None + """ + if target_node: + pubsub = self.node_pubsub_mapping.get(target_node.name) + if pubsub: + # Note: This would need proper async implementation + # For now returning None as placeholder + message = None + else: + message = None + else: + message = self._sharded_message_generator() + + if message is None: + return None + elif str_if_bytes(message["type"]) == "sunsubscribe": + if message["channel"] in self.pending_unsubscribe_shard_channels: + self.pending_unsubscribe_shard_channels.remove(message["channel"]) + self.shard_channels.pop(message["channel"], None) + node = self.cluster.get_node_from_key(message["channel"]) + if node and node.name in self.node_pubsub_mapping: + pubsub = self.node_pubsub_mapping[node.name] + if not pubsub.subscribed: + self.node_pubsub_mapping.pop(node.name) + + if not self.channels and not self.patterns and not self.shard_channels: + # There are no subscriptions anymore + pass + + if self.ignore_subscribe_messages or ignore_subscribe_messages: + return None + return message + + async def ssubscribe(self, *args: Any, **kwargs: Any) -> None: + """ + Subscribe to shard channels. + + :param args: Channel names + :param kwargs: Channel names with handlers + """ + if args: + args = list_or_args(args[0], args[1:]) + s_channels = dict.fromkeys(args) + s_channels.update(kwargs) + + for s_channel, handler in s_channels.items(): + node = self.cluster.get_node_from_key(s_channel) + if node: + pubsub = self._get_node_pubsub(node) + if handler: + await pubsub.ssubscribe(**{s_channel: handler}) + else: + await pubsub.ssubscribe(s_channel) + self.shard_channels.update(pubsub.shard_channels) + self.pending_unsubscribe_shard_channels.difference_update( + self._normalize_keys({s_channel: None}) + ) + + async def sunsubscribe(self, *args: Any) -> None: + """ + Unsubscribe from shard channels. + + :param args: Channel names to unsubscribe from. If empty, unsubscribe from all. + """ + if args: + args = list_or_args(args[0], args[1:]) + else: + args = list(self.shard_channels.keys()) + + for s_channel in args: + node = self.cluster.get_node_from_key(s_channel) + if node and node.name in self.node_pubsub_mapping: + pubsub = self.node_pubsub_mapping[node.name] + await pubsub.sunsubscribe(s_channel) + self.pending_unsubscribe_shard_channels.update( + pubsub.pending_unsubscribe_shard_channels + ) + + def get_redis_connection(self) -> Optional[Connection]: + """ + Get the Redis connection of the pubsub connected node. + """ + if self.node is not None: + return self.node.acquire_connection() + return None + + async def aclose(self) -> None: + """ + Disconnect the pubsub connection. + """ + if self.connection: + await self.connection.disconnect() + for pubsub in self.node_pubsub_mapping.values(): + await pubsub.aclose() + await super().aclose() + + def _raise_on_invalid_node( + self, + redis_cluster: "RedisCluster", + node: Optional["ClusterNode"], + host: Optional[str], + port: Optional[int], + ) -> None: + """ + Raise an exception if the node is invalid. + """ + if node is None: + raise RedisClusterException( + f"Node {host}:{port} does not exist in cluster" + ) + + async def execute_command(self, *args: Any, **kwargs: Any) -> Any: + """ + Execute a command on the appropriate cluster node. + """ + # For shard commands, route to appropriate node + command = args[0].upper() if args else "" + if command in ("SSUBSCRIBE", "SUNSUBSCRIBE", "SPUBLISH"): + if len(args) > 1: + channel = args[1] + node = self.cluster.get_node_from_key(channel) + if node: + pubsub = self._get_node_pubsub(node) + return await pubsub.execute_command(*args, **kwargs) + + # For other commands, use the set node or default behavior + if self.node: + pubsub = self._get_node_pubsub(self.node) + return await pubsub.execute_command(*args, **kwargs) + else: + # Use parent's execute_command if no specific node is set + return await super().execute_command(*args, **kwargs) + + def _normalize_keys(self, data: Dict[Any, Any]) -> Dict[bytes, Any]: + """ + Normalize keys to bytes for internal storage. + """ + return { + self.encoder.encode(key): value + for key, value in data.items() + } diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 25f487fe4c..f6d155579b 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -3234,3 +3234,145 @@ async def test_validating_self_signed_string_certificate( ssl_keyfile=self.client_key, ) as rc: assert await rc.ping() + + +@pytest.mark.onlycluster +class TestClusterPubSub: + """ + Test ClusterPubSub with shard channels functionality + """ + + async def wait_for_message(self, pubsub, timeout=0.2, ignore_subscribe_messages=False): + """Helper method to wait for a message with timeout""" + import asyncio + now = asyncio.get_running_loop().time() + timeout = now + timeout + while now < timeout: + message = await pubsub.get_message( + ignore_subscribe_messages=ignore_subscribe_messages + ) + if message is not None: + return message + await asyncio.sleep(0.01) + now = asyncio.get_running_loop().time() + return None + + def make_message(self, type, channel, data, pattern=None): + """Helper method to create expected message format""" + return { + "type": type, + "pattern": pattern and pattern.encode("utf-8") or None, + "channel": channel and channel.encode("utf-8") or None, + "data": data.encode("utf-8") if isinstance(data, str) else data, + } + + @skip_if_server_version_lt("7.0.0") + async def test_cluster_pubsub_creation(self, r): + """Test basic ClusterPubSub creation""" + pubsub = r.pubsub() + assert pubsub is not None + assert hasattr(pubsub, 'ssubscribe') + assert hasattr(pubsub, 'sunsubscribe') + assert hasattr(pubsub, 'get_sharded_message') + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_cluster_pubsub_with_node(self, r): + """Test ClusterPubSub creation with specific node""" + nodes = r.get_nodes() + if nodes: + node = nodes[0] + pubsub = r.pubsub(node=node) + assert pubsub.node == node + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_cluster_pubsub_with_host_port(self, r): + """Test ClusterPubSub creation with host and port""" + nodes = r.get_nodes() + if nodes: + node = nodes[0] + pubsub = r.pubsub(host=node.host, port=node.port) + assert pubsub.node.host == node.host + assert pubsub.node.port == node.port + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_shard_channel_subscribe_unsubscribe(self, r): + """Test shard channel subscribe and unsubscribe""" + pubsub = r.pubsub() + + try: + # Test channels that map to different nodes + channels = ["shard_test_1", "shard_test_2", "shard_test_3"] + + # Subscribe to shard channels + for channel in channels: + await pubsub.ssubscribe(channel) + + # Verify subscription messages (implementation dependent) + for channel in channels: + # This is a basic test - in practice, subscription messages + # may vary depending on the cluster implementation + pass + + # Unsubscribe from shard channels + for channel in channels: + await pubsub.sunsubscribe(channel) + + finally: + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_shard_channel_attributes(self, r): + """Test shard channel attributes""" + pubsub = r.pubsub() + + try: + # Initially no shard channels + assert not pubsub.shard_channels + assert not pubsub.pending_unsubscribe_shard_channels + + # Subscribe to a shard channel + await pubsub.ssubscribe("test_shard_attr") + + # Should have shard channel information + # Note: The exact behavior may depend on implementation details + + finally: + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_shard_channel_with_handler(self, r): + """Test shard channel subscription with message handler""" + pubsub = r.pubsub() + + try: + received_messages = [] + + def message_handler(message): + received_messages.append(message) + + # Subscribe with handler + await pubsub.ssubscribe(test_handler_channel=message_handler) + + # This test verifies that the handler mechanism is properly set up + # Actual message delivery testing would require a live Redis cluster + + finally: + await pubsub.aclose() + + @skip_if_server_version_lt("7.0.0") + async def test_invalid_node_raises_exception(self, r): + """Test that invalid node raises appropriate exception""" + with pytest.raises(RedisClusterException): + r.pubsub(host="invalid_host", port=9999) + + @skip_if_server_version_lt("7.0.0") + async def test_partial_host_port_raises_exception(self, r): + """Test that providing only host or port raises DataError""" + with pytest.raises(DataError): + r.pubsub(host="localhost") # Missing port + + with pytest.raises(DataError): + r.pubsub(port=7000) # Missing host diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index b281cb1281..071b106ede 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -71,6 +71,15 @@ def make_subscribe_test_data(pubsub, type): "unsub_func": pubsub.unsubscribe, "keys": ["foo", "bar", "uni" + chr(4456) + "code"], } + elif type == "shard_channel": + return { + "p": pubsub, + "sub_type": "ssubscribe", + "unsub_type": "sunsubscribe", + "sub_func": pubsub.ssubscribe, + "unsub_func": pubsub.sunsubscribe, + "keys": ["foo", "bar", "uni" + chr(4456) + "code"], + } elif type == "pattern": return { "p": pubsub, @@ -118,6 +127,12 @@ async def test_pattern_subscribe_unsubscribe(self, pubsub): kwargs = make_subscribe_test_data(pubsub, "pattern") await self._test_subscribe_unsubscribe(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + async def test_shard_channel_subscribe_unsubscribe(self, pubsub): + kwargs = make_subscribe_test_data(pubsub, "shard_channel") + await self._test_subscribe_unsubscribe(**kwargs) + @pytest.mark.onlynoncluster async def _test_resubscribe_on_reconnection( self, p, sub_type, unsub_type, sub_func, unsub_func, keys @@ -215,6 +230,12 @@ async def test_subscribe_property_with_patterns(self, pubsub): kwargs = make_subscribe_test_data(pubsub, "pattern") await self._test_subscribed_property(**kwargs) + @pytest.mark.onlynoncluster + @skip_if_server_version_lt("7.0.0") + async def test_subscribe_property_with_shard_channels(self, pubsub): + kwargs = make_subscribe_test_data(pubsub, "shard_channel") + await self._test_subscribed_property(**kwargs) + async def test_aclosing(self, r: redis.Redis): p = r.pubsub() async with aclosing(p):