diff --git a/.agent/cluster_pubsub_connection_extraction_analysis.md b/.agent/cluster_pubsub_connection_extraction_analysis.md new file mode 100644 index 0000000000..0fa11b552b --- /dev/null +++ b/.agent/cluster_pubsub_connection_extraction_analysis.md @@ -0,0 +1,338 @@ +# Cluster PubSub: Connection Extraction Analysis & Refactoring Plan + +Scope: `redis.cluster.ClusterPubSub` (sync) and `redis.asyncio.cluster.ClusterPubSub` +(async). This document analyses how each implementation currently obtains the +connection(s) used to send `SUBSCRIBE` / `PSUBSCRIBE` / `SSUBSCRIBE` commands and +proposes a plan to align both with the resources already exposed by +`NodesManager` / `ClusterNode`, following the pattern used in +`redis.keyspace_notifications` / `redis.asyncio.keyspace_notifications`. + +--- + +## 1. Current state — Sync (`redis/cluster.py`) + +### 1.1 Underlying cluster resources available to `ClusterPubSub` + +| Resource | Type | Provided by | +| --- | --- | --- | +| `ClusterNode.redis_connection` | `redis.Redis` | `NodesManager.create_redis_connections()` (lazy: via `RedisCluster.get_redis_connection(node)`) | +| `ClusterNode.redis_connection.connection_pool` | `ConnectionPool` | Created in `NodesManager.create_redis_node()` | +| `RedisCluster.get_redis_connection(node)` | `redis.Redis` | Ensures node has a `redis_connection`, returns it | +| `RedisCluster.get_primaries()` / `get_random_node()` | `list[ClusterNode]` / `ClusterNode` | Topology accessors on `NodesManager` | + +Each sync `ClusterNode` owns a full `Redis` instance whose `ConnectionPool` +is the authoritative pool for that node. All normal commands go through it. + +### 1.2 Primary (non-sharded) pubsub path + +File: `redis/cluster.py`, class `ClusterPubSub`. + +- `__init__` (L2598–L2641): if a node was supplied, eagerly pulls the node's + pool: + ````python + connection_pool = ( + None if self.node is None + else redis_cluster.get_redis_connection(self.node).connection_pool + ) + ```` + That pool is forwarded to `PubSub.__init__`. +- `execute_command` (L2692–L2732): if `self.connection is None`: + - Picks a node via `cluster.nodes_manager.get_node_from_slot(...)` (keyslot + of the first channel) or `cluster.get_random_node()`. + - Sets `self.connection_pool = cluster.get_redis_connection(node).connection_pool`. + - Calls `self.connection = self.connection_pool.get_connection()` — the + connection is then held for the entire subscribed lifetime (standard + `PubSub` contract). + +The sync implementation therefore **already reuses the node's backing +`ConnectionPool`** and its connection lifecycle (creation, retry config, +events, maintenance notifications) is controlled by `NodesManager`. + +### 1.3 Sharded pubsub path (`ssubscribe`/`sunsubscribe`) + +`_get_node_pubsub(node)` (L2734–L2742): +````python +pubsub = node.redis_connection.pubsub(push_handler_func=...) +```` +It delegates to the node's `Redis` client `.pubsub()`, which internally +constructs a `PubSub` bound to that same `node.redis_connection.connection_pool`. +So sharded pubsubs also transitively reuse the node-level pool. + +### 1.4 Disconnect + +`disconnect()` (L2834–L2841) closes `self.connection` and every shard +pubsub's `connection`. No pool-level release call — the connections are +effectively owned by the pubsub for its lifetime (consistent with `PubSub`). + +--- + +## 2. Current state — Async (`redis/asyncio/cluster.py`) + +### 2.1 Underlying cluster resources available to `ClusterPubSub` + +The async `ClusterNode` (L1398–L1615) has a **different model** than sync: + +| Attribute / method | Role | +| --- | --- | +| `connection_class`, `connection_kwargs`, `max_connections` | Config for connections | +| `_connections: list[Connection]`, `_free: deque[Connection]` | Node-owned connection pool | +| `acquire_connection()` | Pop from `_free` or create a new `Connection` up to `max_connections` | +| `release(connection)` | Return connection to `_free` | +| `get_encoder()` | Build an `Encoder` from `connection_kwargs` | +| `disconnect_if_needed(conn)` | Lazy reconnect support for maintenance | + +There is **no** `redis_connection` attribute, and **no** `ConnectionPool` +object — the node itself *is* the pool. All regular commands +(`execute_command`, `execute_pipeline`) go through +`acquire_connection()` / `release()`. + +### 2.2 Primary (non-sharded) pubsub path + +File: `redis/asyncio/cluster.py`, class `ClusterPubSub`. + +- `__init__` (L3064–L3115): if a node was supplied, builds a **brand new** + `ConnectionPool` from the node's kwargs: + ````python + connection_pool = ConnectionPool( + connection_class=self.node.connection_class, + **self.node.connection_kwargs, + ) + ```` + This pool is passed to `PubSub.__init__`. It is completely disjoint from + the node's own `_connections` / `_free` pool. +- `execute_command` (L3333–L3376): if `self.connection_pool is None`, picks a + node (keyslot / random) and again constructs a fresh `ConnectionPool` + with `node.connection_kwargs`, then delegates to + `super().execute_command(...)` which lazily acquires a connection from + that newly created pool. + +### 2.3 Sharded pubsub path + +`_get_node_pubsub(node)` (L3159–L3176) repeats the pattern: +````python +connection_pool = ConnectionPool( + connection_class=node.connection_class, **node.connection_kwargs +) +pubsub = PubSub(connection_pool=connection_pool, ...) +```` +Every sharded node gets yet another detached `ConnectionPool`. + +### 2.4 Consequences + +1. **Resource duplication**: each `ClusterPubSub` (and each shard) creates a + parallel `ConnectionPool` that bypasses the node's `_free` queue and + `max_connections` budget. +2. **Event / maintenance-notification divergence**: connections opened via + the detached pool do not participate in the node's + `update_active_connections_for_reconnect`, `disconnect_if_needed`, or + event dispatcher wiring that `ClusterNode` owns. +3. **Credentials / retry / on_connect drift risk**: because the detached + pool is reconstructed from `connection_kwargs` only, any adjustments + made on the node after construction (e.g. retry rebinding) are not + reflected. +4. **Inconsistency with sync**: the sync path centralises pubsub through + the node's authoritative pool; async does not. + +--- + +## 3. Reference implementation — keyspace notifications + +`ClusterKeyspaceNotifications` / `AsyncClusterKeyspaceNotifications` already +solve the exact same problem for keyspace events. They serve as the +template for the refactor. + +### 3.1 Sync (`redis/keyspace_notifications.py`) + +`ClusterKeyspaceNotifications._ensure_node_pubsub(node)` (L1342–L1351): +````python +redis_conn = self.cluster.get_redis_connection(node) +pubsub = redis_conn.pubsub(ignore_subscribe_messages=False) +self._node_pubsubs[node.name] = pubsub +```` +It asks `RedisCluster.get_redis_connection(node)` (which lazily materialises +`node.redis_connection` through `NodesManager.create_redis_connections`) and +calls `.pubsub()` on it, inheriting the node's `ConnectionPool`. Nodes are +enumerated via `self.cluster.get_primaries()`. + +### 3.2 Async (`redis/asyncio/keyspace_notifications.py`) + +`_ClusterNodePoolAdapter` (L68–L102) is a minimal object that implements the +tiny `ConnectionPool` surface `PubSub` needs, backed by the node itself: + +````python +class _ClusterNodePoolAdapter: + def __init__(self, node: ClusterNode) -> None: + self._node = node + self.connection_kwargs = node.connection_kwargs + def get_encoder(self) -> Encoder: + return self._node.get_encoder() + async def get_connection(self, ...): + connection = self._node.acquire_connection() + await connection.connect() + return connection + async def release(self, connection) -> None: + self._node.release(connection) +```` + +`AsyncClusterKeyspaceNotifications._ensure_node_pubsub(node)` (L634–L649) +wraps the node with this adapter and feeds it as +`connection_pool=` to `PubSub(...)`. No duplicate `ConnectionPool` is +created; `PubSub.aclose()` already disconnects the connection before +calling `release`, so the socket never re-enters the node's free queue in a +subscribed state. + +This pattern: +- reuses the node's `_free` / `max_connections` budget, +- respects the node's event / maintenance machinery, +- keeps `ClusterPubSub` free of connection-construction logic. + +--- + +## 4. Plan — migrate `ClusterPubSub` to `NodesManager`-owned resources + +Goal: make both `ClusterPubSub` implementations obtain their connections +*exclusively* through the `NodesManager` / `ClusterNode` surface used by +normal cluster commands and by keyspace notifications. No new +public APIs, no behaviour changes visible to callers. + +### 4.1 Async `ClusterPubSub` — primary work + +This is the implementation that actually needs to change. + +1. **Reuse the existing adapter.** Import + `redis.asyncio.keyspace_notifications._ClusterNodePoolAdapter` (or, if + cross-module import is undesirable, relocate it to a neutral module — + e.g. `redis/asyncio/_cluster_pool_adapter.py` — and import it from both + `keyspace_notifications` and `cluster`). Preferred option: move it to a + neutral module to avoid a circular dependency between `cluster` and + `keyspace_notifications`. + +2. **Replace `ConnectionPool(...)` creation in `__init__`** (L3093–L3100): + ````python + if self.node is not None: + connection_pool = _ClusterNodePoolAdapter(self.node) + else: + connection_pool = None + ```` + The rest of `super().__init__(connection_pool=..., ...)` is unchanged; + `PubSub` only calls `get_connection` / `release` / `get_encoder` on the + pool, all of which the adapter satisfies. + +3. **Replace `ConnectionPool(...)` creation in `execute_command`** + (L3354–L3373): + ````python + if self.connection is None: + if self.connection_pool is None: + # ...node selection unchanged... + self.node = node + self.connection_pool = _ClusterNodePoolAdapter(node) + return await super().execute_command(*args, **kwargs) + ```` + +4. **Replace `ConnectionPool(...)` creation in `_get_node_pubsub`** + (L3159–L3176): + ````python + def _get_node_pubsub(self, node: "ClusterNode") -> PubSub: + try: + return self.node_pubsub_mapping[node.name] + except KeyError: + pubsub = PubSub( + connection_pool=_ClusterNodePoolAdapter(node), + encoder=self.cluster.encoder, + push_handler_func=self.push_handler_func, + event_dispatcher=self._event_dispatcher, + ) + self.node_pubsub_mapping[node.name] = pubsub + return pubsub + ```` + +5. **`aclose` / `get_redis_connection`**: unchanged. `PubSub.aclose()` + disconnects `self.connection` before calling + `connection_pool.release(connection)`, which the adapter forwards to + `ClusterNode.release(...)`; the socket re-enters the node's free queue + in a disconnected state, matching the guarantee already documented on + `_ClusterNodePoolAdapter`. + +6. **Encoder parity**: today the async `ClusterPubSub` passes + `encoder=redis_cluster.encoder`; after the change the adapter would + expose `node.get_encoder()` instead for any `PubSub` code path that + uses `connection_pool.get_encoder()`. The explicit `encoder=` argument + should be kept to preserve current behaviour (cluster-wide encoder). + +### 4.2 Sync `ClusterPubSub` — smaller adjustments + +The sync implementation already uses the node's own `ConnectionPool` via +`redis_cluster.get_redis_connection(node).connection_pool`, so it is +essentially compliant. The remaining improvements are about *symmetry* +and *delegation*, not correctness: + +1. **Centralise node-pubsub creation through `get_redis_connection`.** + In `_get_node_pubsub` (L2734–L2742) the current code reaches into + `node.redis_connection` directly; if a node has not yet been + materialised (e.g. just discovered on topology refresh), `redis_connection` + can be `None`. Replace with: + ````python + redis_conn = self.cluster.get_redis_connection(node) + pubsub = redis_conn.pubsub(push_handler_func=self.push_handler_func) + ```` + matching `ClusterKeyspaceNotifications._ensure_node_pubsub` and + guaranteeing the lazy `NodesManager.create_redis_connections([node])` + path is taken. + +2. **No change to the primary path** (`__init__` / `execute_command`): + both already use `cluster.get_redis_connection(node).connection_pool`. + +3. **`disconnect()` review**: iterating + `self.node_pubsub_mapping.values()` and calling + `pubsub.connection.disconnect()` is fine, but should tolerate + `pubsub.connection is None` (a shard pubsub that has not yet sent a + command). This is a pre-existing latent bug, not introduced by the + refactor; fix it as part of this change for parity with async + `aclose()` which already tolerates it. + +### 4.3 Cross-cutting tasks + +1. **Relocate `_ClusterNodePoolAdapter`** to a shared neutral module + (e.g. `redis/asyncio/cluster_pool_adapter.py`) and re-export from + `redis.asyncio.keyspace_notifications` for backwards compatibility. + Rationale: avoid `redis.asyncio.cluster` importing from + `redis.asyncio.keyspace_notifications` (the dependency direction today + is the opposite). + +2. **Import hygiene**: remove `ConnectionPool` import from + `redis/asyncio/cluster.py` if it becomes unused after the refactor + (check `ClusterPipeline` and `RedisCluster` first — likely still used). + +3. **Tests**: + - Add unit tests mirroring + `tests/test_asyncio/test_keyspace_notifications.py::test_receives_notification_from_any_node` + for `ClusterPubSub` to assert that the pubsub's `connection_pool` + after lazy node selection is a `_ClusterNodePoolAdapter` wrapping the + target `ClusterNode` (and not a detached `ConnectionPool`). + - Add a test that `_get_node_pubsub(node).connection_pool._node is node`. + - Add a test verifying that after `aclose()` the connection is + returned to `node._free` in a disconnected state (so subsequent + regular commands do not reuse a subscribed socket). + - For the sync side, add a test that `_get_node_pubsub` materialises + `node.redis_connection` when it is `None` + (via `cluster.get_redis_connection`). + +4. **Sync/async parity self-check** (per `.agent/instructions.md`): + - Public API unchanged: `ClusterPubSub.__init__`, `execute_command`, + `ssubscribe`, `sunsubscribe`, `get_redis_connection`, `disconnect` / + `aclose` keep their signatures and semantics. + - Both implementations now obtain connections exclusively via + `NodesManager`-owned resources (`node.redis_connection.connection_pool` + for sync, `_ClusterNodePoolAdapter(node)` for async). + - No new dependencies; only an internal module move. + +### 4.4 Out-of-scope (explicitly not changed) + +- The `PubSub` base class contract (single long-lived connection per + subscriber). +- Topology-refresh / auto-resubscribe behaviour. `ClusterPubSub` currently + has **no** auto-re-subscribe on topology change (unlike + `ClusterKeyspaceNotifications`); adding that is a separate feature and + is intentionally excluded from this refactor. +- Routing rules for `SSUBSCRIBE` / `SUNSUBSCRIBE` / `SPUBLISH` + (keyslot-based), which remain exactly as they are today. diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 25604526a4..6a69185c0d 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -1038,7 +1038,11 @@ async def aclose(self): return async with self._lock: if self.connection: - await self.connection.disconnect() + # Use nowait=True to avoid awaiting StreamWriter.wait_closed(), + # which can deadlock when a concurrent reader task (e.g. one + # running pubsub.run() or get_message(block=True)) still holds + # the transport. See https://github.com/redis/redis-py/issues/3941 + await self.connection.disconnect(nowait=True) self.connection.deregister_connect_callback(self.on_connect) await self.connection_pool.release(self.connection) self.connection = None diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index e9644efca4..40c728b273 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -43,7 +43,6 @@ from redis.asyncio.connection import ( AbstractConnection, Connection, - ConnectionPool, SSLConnection, parse_url, ) @@ -3052,6 +3051,43 @@ async def unlink(self, *names): return self.execute_command("UNLINK", *names) +class _ClusterNodePoolAdapter: + """Thin adapter exposing the :class:`ConnectionPool` interface that + :class:`PubSub` requires, backed by a :class:`ClusterNode`'s own + connection pool. + + Connections are acquired from the node via + :meth:`ClusterNode.acquire_connection` and returned via + :meth:`ClusterNode.release`. :meth:`PubSub.aclose` already + disconnects the connection *before* calling :meth:`release`, so the + connection is returned to the node's free-queue in a disconnected + state — guaranteeing that a subscribed socket is never silently + reused for regular commands. + """ + + def __init__(self, node: "ClusterNode") -> None: + self._node = node + self.connection_kwargs = node.connection_kwargs + + # -- methods used by PubSub ------------------------------------------------ + + def get_encoder(self) -> Encoder: + return self._node.get_encoder() + + async def get_connection( + self, command_name: Optional[str] = None, *keys: Any, **options: Any + ) -> Any: + connection = self._node.acquire_connection() + await connection.connect() + return connection + + async def release(self, connection: Any) -> None: + # PubSub.aclose() disconnects the connection before calling + # release(), so it is safe to put it back in the node's free + # queue – it will reconnect lazily on next use. + self._node.release(connection) + + class ClusterPubSub(PubSub): """ Async cluster implementation for pub/sub. @@ -3090,12 +3126,10 @@ def __init__( self.node = None self.set_pubsub_node(redis_cluster, node, host, port) - # Create connection pool if node is specified + # Borrow the node's own connection pool via an adapter rather than + # creating a second, detached ConnectionPool for pubsub. if self.node is not None: - connection_pool = ConnectionPool( - connection_class=self.node.connection_class, - **self.node.connection_kwargs, - ) + connection_pool = _ClusterNodePoolAdapter(self.node) else: connection_pool = None @@ -3161,13 +3195,8 @@ def _get_node_pubsub(self, node: "ClusterNode") -> PubSub: try: return self.node_pubsub_mapping[node.name] except KeyError: - # Create a minimal connection pool for this node - connection_pool = ConnectionPool( - connection_class=node.connection_class, **node.connection_kwargs - ) - pubsub = PubSub( - connection_pool=connection_pool, + connection_pool=_ClusterNodePoolAdapter(node), encoder=self.cluster.encoder, push_handler_func=self.push_handler_func, event_dispatcher=self._event_dispatcher, @@ -3367,10 +3396,7 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any: # Get a random node node = self.cluster.get_random_node() self.node = node - self.connection_pool = ConnectionPool( - connection_class=node.connection_class, - **node.connection_kwargs, - ) + self.connection_pool = _ClusterNodePoolAdapter(node) # Now we have a connection_pool, use parent's execute_command return await super().execute_command(*args, **kwargs) diff --git a/redis/asyncio/keyspace_notifications.py b/redis/asyncio/keyspace_notifications.py index 8fce0f79b9..7bad8197cc 100644 --- a/redis/asyncio/keyspace_notifications.py +++ b/redis/asyncio/keyspace_notifications.py @@ -45,9 +45,8 @@ from collections.abc import AsyncIterator, Awaitable, Callable from typing import Any -from redis._parsers.encoders import Encoder from redis.asyncio.client import PubSub, Redis -from redis.asyncio.cluster import ClusterNode, RedisCluster +from redis.asyncio.cluster import ClusterNode, RedisCluster, _ClusterNodePoolAdapter from redis.exceptions import ( ConnectionError, RedisError, @@ -65,43 +64,6 @@ logger = logging.getLogger(__name__) -class _ClusterNodePoolAdapter: - """Thin adapter exposing the :class:`ConnectionPool` interface that - :class:`PubSub` requires, backed by a :class:`ClusterNode`'s own - connection pool. - - Connections are acquired from the node via - :meth:`ClusterNode.acquire_connection` and returned via - :meth:`ClusterNode.release`. :meth:`PubSub.aclose` already - disconnects the connection *before* calling :meth:`release`, so the - connection is returned to the node's free-queue in a disconnected - state — guaranteeing that a subscribed socket is never silently - reused for regular commands. - """ - - def __init__(self, node: ClusterNode) -> None: - self._node = node - self.connection_kwargs = node.connection_kwargs - - # -- methods used by PubSub ------------------------------------------------ - - def get_encoder(self) -> Encoder: - return self._node.get_encoder() - - async def get_connection( - self, command_name: str | None = None, *keys: Any, **options: Any - ) -> Any: - connection = self._node.acquire_connection() - await connection.connect() - return connection - - async def release(self, connection: Any) -> None: - # PubSub.aclose() disconnects the connection before calling - # release(), so it is safe to put it back in the node's free - # queue – it will reconnect lazily on next use. - self._node.release(connection) - - # Type alias for handlers that can be sync or async AsyncHandlerT = Callable[[KeyNotification], None | Awaitable[None]] diff --git a/redis/cluster.py b/redis/cluster.py index 171e586efb..fa45375c4e 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2735,9 +2735,8 @@ def _get_node_pubsub(self, node): try: return self.node_pubsub_mapping[node.name] except KeyError: - pubsub = node.redis_connection.pubsub( - push_handler_func=self.push_handler_func - ) + redis_connection = self.cluster.get_redis_connection(node) + pubsub = redis_connection.pubsub(push_handler_func=self.push_handler_func) self.node_pubsub_mapping[node.name] = pubsub return pubsub @@ -2838,7 +2837,8 @@ def disconnect(self): if self.connection: self.connection.disconnect() for pubsub in self.node_pubsub_mapping.values(): - pubsub.connection.disconnect() + if pubsub.connection: + pubsub.connection.disconnect() class ClusterPipeline(RedisCluster): diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 0bf663800b..a57c4d90aa 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -56,9 +56,6 @@ from ..ssl_utils import get_tls_certificates from .compat import aclosing -pytestmark = pytest.mark.onlycluster - - default_host = "127.0.0.1" default_port = 7000 default_cluster_slots = [ @@ -306,6 +303,7 @@ def ok_response(self, *args, **options): assert fetched_node.server_type == PRIMARY +@pytest.mark.onlycluster class TestRedisClusterObj: """ Tests for the RedisCluster class @@ -1093,6 +1091,7 @@ def address_remap(address): assert n_used > 1 +@pytest.mark.onlycluster class TestClusterRedisCommands: """ Tests for RedisCluster unique commands @@ -2430,6 +2429,7 @@ async def test_hotkeys_cluster(self, r: RedisCluster) -> None: await r.hotkeys_stop() +@pytest.mark.onlycluster class TestNodesManager: """ Tests for the NodesManager class @@ -2843,6 +2843,7 @@ async def test_move_node_to_end_of_cached_nodes_single_node(self) -> None: assert nodes_cache_names == [node1.name] +@pytest.mark.fixed_client class TestClusterNodeConnectionHandling: """Tests for ClusterNode connection handling methods.""" @@ -2986,6 +2987,7 @@ async def test_disconnect_if_needed_skips_when_no_reconnect_needed(self) -> None conn.disconnect.assert_not_called() +@pytest.mark.fixed_client class TestClusterConnectionErrorHandling: """Tests for cluster connection error handling behavior.""" @@ -3103,6 +3105,7 @@ def cmd_init_mock( disconnect_free.assert_called() +@pytest.mark.onlycluster class TestClusterPipeline: """Tests for the ClusterPipeline class.""" @@ -3458,6 +3461,7 @@ async def test_pipeline_with_default_node_error_command(self, create_redis): assert result[0] == cmd_count +@pytest.mark.onlycluster @pytest.mark.ssl class TestSSL: """ @@ -4533,3 +4537,74 @@ def message_handler(message): assert received_messages[0]["data"] == b"test message" finally: await pubsub.aclose() + + +@pytest.mark.fixed_client +class TestClusterPubSubWithMocks: + """ + Unit tests for async ClusterPubSub that do not require a running cluster. + """ + + def _make_pubsub(self, cluster_mock, node=None): + """Create a ClusterPubSub with the provided cluster mock.""" + from redis._parsers import Encoder + from redis.asyncio.cluster import ClusterPubSub + + cluster_mock.encoder = Encoder("utf-8", "strict", False) + return ClusterPubSub(cluster_mock, node=node) + + async def test_init_with_node_uses_adapter(self) -> None: + """ + __init__ with a node must wrap it in _ClusterNodePoolAdapter instead + of creating a detached ConnectionPool. + """ + from redis.asyncio.cluster import _ClusterNodePoolAdapter + + node = ClusterNode("127.0.0.1", 7000) + cluster = Mock() + cluster.get_node.return_value = node + + pubsub = self._make_pubsub(cluster, node=node) + + assert isinstance(pubsub.connection_pool, _ClusterNodePoolAdapter) + assert pubsub.connection_pool._node is node + assert pubsub.connection_pool.connection_kwargs is node.connection_kwargs + + async def test_init_without_node_has_no_connection_pool(self) -> None: + """ + __init__ without a node must defer connection_pool creation until + the first command selects a node. + """ + cluster = Mock() + pubsub = self._make_pubsub(cluster) + + assert pubsub.node is None + assert pubsub.connection_pool is None + + async def test_get_node_pubsub_uses_adapter(self) -> None: + """ + _get_node_pubsub must build a PubSub backed by a + _ClusterNodePoolAdapter wrapping the ClusterNode. + """ + from redis.asyncio.cluster import _ClusterNodePoolAdapter + + cluster = Mock() + pubsub = self._make_pubsub(cluster) + node = ClusterNode("127.0.0.1", 7000) + + shard_pubsub = pubsub._get_node_pubsub(node) + + assert isinstance(shard_pubsub.connection_pool, _ClusterNodePoolAdapter) + assert shard_pubsub.connection_pool._node is node + assert pubsub.node_pubsub_mapping[node.name] is shard_pubsub + + async def test_get_node_pubsub_caches_by_node_name(self) -> None: + """Repeated calls must not re-materialise the shard PubSub.""" + cluster = Mock() + pubsub = self._make_pubsub(cluster) + node = ClusterNode("127.0.0.1", 7000) + + first = pubsub._get_node_pubsub(node) + second = pubsub._get_node_pubsub(node) + + assert first is second diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 90d69f75ce..2538e64bb9 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -1470,3 +1470,67 @@ async def test_handle_message_does_not_record_metric_for_pong_type( ) mock_record.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.fixed_client +class TestPubSubAcloseWithMocks: + """ + Regression tests for https://github.com/redis/redis-py/issues/3941 — + PubSub.aclose() must not hang inside StreamWriter.wait_closed() when a + concurrent reader task still owns the pubsub connection's transport. + """ + + def _make_pubsub(self): + pool = MagicMock() + pool.get_encoder = MagicMock( + return_value=MagicMock( + decode_responses=False, + encode=lambda v: v.encode() if isinstance(v, str) else v, + ) + ) + pool.release = AsyncMock() + pubsub = PubSub(connection_pool=pool) + connection = MagicMock() + connection.disconnect = AsyncMock() + connection.deregister_connect_callback = MagicMock() + pubsub.connection = connection + return pubsub, pool, connection + + async def test_aclose_disconnects_with_nowait(self): + """aclose() must call connection.disconnect(nowait=True) to avoid + awaiting StreamWriter.wait_closed(), which can deadlock when another + task is blocked in parse_response() on the same socket.""" + pubsub, pool, connection = self._make_pubsub() + + await pubsub.aclose() + + connection.disconnect.assert_awaited_once_with(nowait=True) + connection.deregister_connect_callback.assert_called_once_with( + pubsub.on_connect + ) + pool.release.assert_awaited_once_with(connection) + assert pubsub.connection is None + + async def test_aclose_does_not_hang_when_wait_closed_would_block(self): + """ + End-to-end regression: even if the underlying StreamWriter.wait_closed() + would hang forever (as happens when a concurrent reader still holds the + transport), aclose() returns promptly because it passes nowait=True to + disconnect(). + """ + pubsub, _pool, connection = self._make_pubsub() + + async def fake_disconnect(nowait: bool = False, **_): + # Simulates AbstractConnection.disconnect(): if nowait=False we + # would hang awaiting wait_closed(); with nowait=True we return + # immediately. + if not nowait: + await asyncio.Event().wait() + + connection.disconnect = AsyncMock(side_effect=fake_disconnect) + + async with async_timeout(2): + await pubsub.aclose() + + connection.disconnect.assert_awaited_once_with(nowait=True) diff --git a/tests/test_cluster.py b/tests/test_cluster.py index b8537aad8a..f908ede80e 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -3216,6 +3216,81 @@ def test_move_node_to_end_of_cached_nodes_single_node(self): assert nodes_cache_names == [node1.name] +@pytest.mark.fixed_client +class TestClusterPubSubWithMocks: + """ + Unit tests for ClusterPubSub that do not require a running cluster. + """ + + def _make_pubsub(self, cluster_mock): + """Create a ClusterPubSub with no node set, using the provided cluster mock.""" + from redis._parsers import Encoder + from redis.cluster import ClusterPubSub + + cluster_mock.encoder = Encoder("utf-8", "strict", False) + return ClusterPubSub(cluster_mock) + + def test_get_node_pubsub_uses_cluster_get_redis_connection(self): + """ + _get_node_pubsub must route through cluster.get_redis_connection(node) + so newly-discovered nodes are materialised via NodesManager. + """ + cluster = Mock() + redis_conn = Mock() + shard_pubsub = Mock() + redis_conn.pubsub.return_value = shard_pubsub + cluster.get_redis_connection.return_value = redis_conn + + pubsub = self._make_pubsub(cluster) + node = ClusterNode("127.0.0.1", 7000) + + result = pubsub._get_node_pubsub(node) + + cluster.get_redis_connection.assert_called_once_with(node) + redis_conn.pubsub.assert_called_once_with(push_handler_func=None) + assert result is shard_pubsub + assert pubsub.node_pubsub_mapping[node.name] is shard_pubsub + + def test_get_node_pubsub_caches_by_node_name(self): + """Repeated calls must not re-materialise the shard PubSub.""" + cluster = Mock() + redis_conn = Mock() + redis_conn.pubsub.return_value = Mock() + cluster.get_redis_connection.return_value = redis_conn + + pubsub = self._make_pubsub(cluster) + node = ClusterNode("127.0.0.1", 7000) + + first = pubsub._get_node_pubsub(node) + second = pubsub._get_node_pubsub(node) + + assert first is second + cluster.get_redis_connection.assert_called_once_with(node) + redis_conn.pubsub.assert_called_once() + + def test_disconnect_tolerates_shard_pubsub_with_no_connection(self): + """ + disconnect() must skip shard PubSub entries whose connection has not + been materialised yet (pubsub.connection is None). + """ + cluster = Mock() + pubsub = self._make_pubsub(cluster) + + pending_pubsub = Mock() + pending_pubsub.connection = None + active_pubsub = Mock() + active_pubsub.connection = Mock() + + pubsub.node_pubsub_mapping = { + "127.0.0.1:7000": pending_pubsub, + "127.0.0.1:7001": active_pubsub, + } + + pubsub.disconnect() + + active_pubsub.connection.disconnect.assert_called_once() + + @pytest.mark.onlycluster class TestClusterPubSubObject: """