From a8ba5ce992a292e6ac75ac60b23003ac79ed2867 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 24 Jul 2025 18:05:59 +0300 Subject: [PATCH] Adding handling of FAILING_OVER and FAILED_OVER events/push notifications --- redis/maintenance_events.py | 116 ++++++++++++++- tests/test_connection_pool.py | 1 - tests/test_maintenance_events.py | 171 +++++++++++++++++++--- tests/test_maintenance_events_handling.py | 124 +++++++++++++++- 4 files changed, 379 insertions(+), 33 deletions(-) diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index d4b4e06231..f99ad37397 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -12,6 +12,7 @@ class MaintenanceState(enum.Enum): NONE = "none" MOVING = "moving" MIGRATING = "migrating" + FAILING_OVER = "failing_over" if TYPE_CHECKING: @@ -261,6 +262,105 @@ def __hash__(self) -> int: return hash((self.__class__, self.id)) +class NodeFailingOverEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node is in the process of failing over. + + This event is received when a node starts a failover process during + cluster maintenance operations or when handling node failures. + + Args: + id (int): Unique identifier for this event + ttl (int): Time-to-live in seconds for this notification + """ + + def __init__(self, id: int, ttl: int): + super().__init__(id, ttl) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.monotonic()) + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeFailingOverEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeFailingOverEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + + +class NodeFailedOverEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node has completed a failover. + + This event is received when a node has finished the failover process + during cluster maintenance operations or after handling node failures. + + Args: + id (int): Unique identifier for this event + """ + + DEFAULT_TTL = 5 + + def __init__(self, id: int): + super().__init__(id, NodeFailedOverEvent.DEFAULT_TTL) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.monotonic()) + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeFailedOverEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeFailedOverEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + + class MaintenanceEventsConfig: """ Configuration class for maintenance events handling behaviour. Events are received through @@ -446,24 +546,28 @@ def __init__( def handle_event(self, event: MaintenanceEvent): if isinstance(event, NodeMigratingEvent): - return self.handle_migrating_event(event) + return self.handle_maintenance_start_event(MaintenanceState.MIGRATING) elif isinstance(event, NodeMigratedEvent): - return self.handle_migration_completed_event(event) + return self.handle_maintenance_completed_event() + elif isinstance(event, NodeFailingOverEvent): + return self.handle_maintenance_start_event(MaintenanceState.FAILING_OVER) + elif isinstance(event, NodeFailedOverEvent): + return self.handle_maintenance_completed_event() else: logging.error(f"Unhandled event type: {event}") - def handle_migrating_event(self, notification: NodeMigratingEvent): + def handle_maintenance_start_event(self, maintenance_state: MaintenanceState): if ( self.connection.maintenance_state == MaintenanceState.MOVING or not self.config.is_relax_timeouts_enabled() ): return - self.connection.maintenance_state = MaintenanceState.MIGRATING + self.connection.maintenance_state = maintenance_state self.connection.set_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) # extend the timeout for all created connections self.connection.update_current_socket_timeout(self.config.relax_timeout) - def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): + def handle_maintenance_completed_event(self): # Only reset timeouts if state is not MOVING and relax timeouts are enabled if ( self.connection.maintenance_state == MaintenanceState.MOVING @@ -471,7 +575,7 @@ def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): ): return self.connection.reset_tmp_settings(reset_relax_timeout=True) - # Node migration completed - reset the connection + # Maintenance completed - reset the connection # timeouts by providing -1 as the relax timeout self.connection.update_current_socket_timeout(-1) self.connection.maintenance_state = MaintenanceState.NONE diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 282aec567d..1eb68d3775 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -9,7 +9,6 @@ import redis from redis.cache import CacheConfig from redis.connection import CacheProxyConnection, Connection, to_bool -from redis.maintenance_events import MaintenanceState from redis.utils import SSL_AVAILABLE from .conftest import ( diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py index 3eb648f079..a59b834a4e 100644 --- a/tests/test_maintenance_events.py +++ b/tests/test_maintenance_events.py @@ -7,9 +7,12 @@ NodeMovingEvent, NodeMigratingEvent, NodeMigratedEvent, + NodeFailingOverEvent, + NodeFailedOverEvent, MaintenanceEventsConfig, MaintenanceEventPoolHandler, MaintenanceEventConnectionHandler, + MaintenanceState, ) @@ -281,6 +284,84 @@ def test_equality_and_hash(self): assert hash(event1) != hash(event3) +class TestNodeFailingOverEvent: + """Test the NodeFailingOverEvent class.""" + + def test_init(self): + """Test NodeFailingOverEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeFailingOverEvent(id=1, ttl=5) + assert event.id == 1 + assert event.ttl == 5 + assert event.creation_time == 1000 + + def test_repr(self): + """Test NodeFailingOverEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeFailingOverEvent(id=1, ttl=5) + + with patch("time.monotonic", return_value=1002): # 2 seconds later + repr_str = repr(event) + assert "NodeFailingOverEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=3.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeFailingOverEvent.""" + event1 = NodeFailingOverEvent(id=1, ttl=5) + event2 = NodeFailingOverEvent(id=1, ttl=10) # Same id, different ttl + event3 = NodeFailingOverEvent(id=2, ttl=5) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + +class TestNodeFailedOverEvent: + """Test the NodeFailedOverEvent class.""" + + def test_init(self): + """Test NodeFailedOverEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeFailedOverEvent(id=1) + assert event.id == 1 + assert event.ttl == NodeFailedOverEvent.DEFAULT_TTL + assert event.creation_time == 1000 + + def test_default_ttl(self): + """Test that DEFAULT_TTL is used correctly.""" + assert NodeFailedOverEvent.DEFAULT_TTL == 5 + event = NodeFailedOverEvent(id=1) + assert event.ttl == 5 + + def test_repr(self): + """Test NodeFailedOverEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeFailedOverEvent(id=1) + + with patch("time.monotonic", return_value=1001): # 1 second later + repr_str = repr(event) + assert "NodeFailedOverEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=4.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeFailedOverEvent.""" + event1 = NodeFailedOverEvent(id=1) + event2 = NodeFailedOverEvent(id=1) # Same id + event3 = NodeFailedOverEvent(id=2) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + class TestMaintenanceEventsConfig: """Test the MaintenanceEventsConfig class.""" @@ -477,19 +558,41 @@ def test_handle_event_migrating(self): """Test handling of NodeMigratingEvent.""" event = NodeMigratingEvent(id=1, ttl=5) - with patch.object(self.handler, "handle_migrating_event") as mock_handle: + with patch.object( + self.handler, "handle_maintenance_start_event" + ) as mock_handle: self.handler.handle_event(event) - mock_handle.assert_called_once_with(event) + mock_handle.assert_called_once_with(MaintenanceState.MIGRATING) def test_handle_event_migrated(self): """Test handling of NodeMigratedEvent.""" event = NodeMigratedEvent(id=1) with patch.object( - self.handler, "handle_migration_completed_event" + self.handler, "handle_maintenance_completed_event" ) as mock_handle: self.handler.handle_event(event) - mock_handle.assert_called_once_with(event) + mock_handle.assert_called_once_with() + + def test_handle_event_failing_over(self): + """Test handling of NodeFailingOverEvent.""" + event = NodeFailingOverEvent(id=1, ttl=5) + + with patch.object( + self.handler, "handle_maintenance_start_event" + ) as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(MaintenanceState.FAILING_OVER) + + def test_handle_event_failed_over(self): + """Test handling of NodeFailedOverEvent.""" + event = NodeFailedOverEvent(id=1) + + with patch.object( + self.handler, "handle_maintenance_completed_event" + ) as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with() def test_handle_event_unknown_type(self): """Test handling of unknown event type.""" @@ -500,43 +603,71 @@ def test_handle_event_unknown_type(self): result = self.handler.handle_event(event) assert result is None - def test_handle_migrating_event_disabled(self): - """Test migrating event handling when relax timeouts are disabled.""" + def test_handle_maintenance_start_event_disabled(self): + """Test maintenance start event handling when relax timeouts are disabled.""" config = MaintenanceEventsConfig(relax_timeout=-1) handler = MaintenanceEventConnectionHandler(self.mock_connection, config) - event = NodeMigratingEvent(id=1, ttl=5) - result = handler.handle_migrating_event(event) + result = handler.handle_maintenance_start_event(MaintenanceState.MIGRATING) assert result is None self.mock_connection.update_current_socket_timeout.assert_not_called() - def test_handle_migrating_event_success(self): - """Test successful migrating event handling.""" - event = NodeMigratingEvent(id=1, ttl=5) + def test_handle_maintenance_start_event_moving_state(self): + """Test maintenance start event handling when connection is in MOVING state.""" + self.mock_connection.maintenance_state = MaintenanceState.MOVING - self.handler.handle_migrating_event(event) + result = self.handler.handle_maintenance_start_event(MaintenanceState.MIGRATING) + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + def test_handle_maintenance_start_event_migrating_success(self): + """Test successful maintenance start event handling for migrating.""" + self.mock_connection.maintenance_state = MaintenanceState.NONE + + self.handler.handle_maintenance_start_event(MaintenanceState.MIGRATING) + + assert self.mock_connection.maintenance_state == MaintenanceState.MIGRATING self.mock_connection.update_current_socket_timeout.assert_called_once_with(20) self.mock_connection.set_tmp_settings.assert_called_once_with( tmp_relax_timeout=20 ) - def test_handle_migration_completed_event_disabled(self): - """Test migration completed event handling when relax timeouts are disabled.""" + def test_handle_maintenance_start_event_failing_over_success(self): + """Test successful maintenance start event handling for failing over.""" + self.mock_connection.maintenance_state = MaintenanceState.NONE + + self.handler.handle_maintenance_start_event(MaintenanceState.FAILING_OVER) + + assert self.mock_connection.maintenance_state == MaintenanceState.FAILING_OVER + self.mock_connection.update_current_socket_timeout.assert_called_once_with(20) + self.mock_connection.set_tmp_settings.assert_called_once_with( + tmp_relax_timeout=20 + ) + + def test_handle_maintenance_completed_event_disabled(self): + """Test maintenance completed event handling when relax timeouts are disabled.""" config = MaintenanceEventsConfig(relax_timeout=-1) handler = MaintenanceEventConnectionHandler(self.mock_connection, config) - event = NodeMigratedEvent(id=1) - result = handler.handle_migration_completed_event(event) + result = handler.handle_maintenance_completed_event() assert result is None self.mock_connection.update_current_socket_timeout.assert_not_called() - def test_handle_migration_completed_event_success(self): - """Test successful migration completed event handling.""" - event = NodeMigratedEvent(id=1) + def test_handle_maintenance_completed_event_moving_state(self): + """Test maintenance completed event handling when connection is in MOVING state.""" + self.mock_connection.maintenance_state = MaintenanceState.MOVING + + result = self.handler.handle_maintenance_completed_event() + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + + def test_handle_maintenance_completed_event_success(self): + """Test successful maintenance completed event handling.""" + self.mock_connection.maintenance_state = MaintenanceState.MIGRATING - self.handler.handle_migration_completed_event(event) + self.handler.handle_maintenance_completed_event() + assert self.mock_connection.maintenance_state == MaintenanceState.NONE self.mock_connection.update_current_socket_timeout.assert_called_once_with(-1) self.mock_connection.reset_tmp_settings.assert_called_once_with( reset_relax_timeout=True diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index dc4e850a50..b6fc2116c7 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -15,9 +15,11 @@ from redis.maintenance_events import ( MaintenanceEventsConfig, NodeMigratingEvent, + NodeMigratedEvent, + NodeFailingOverEvent, + NodeFailedOverEvent, MaintenanceEventPoolHandler, NodeMovingEvent, - NodeMigratedEvent, ) @@ -69,6 +71,22 @@ def send(self, data): # Format: >1\r\n$8\r\nMIGRATED\r\n (1 element: MIGRATED) migrated_push = ">1\r\n$8\r\nMIGRATED\r\n" response = migrated_push.encode() + response + elif ( + b"key_receive_failing_over_" in data + or b"key_receive_failing_over" in data + ): + # FAILING_OVER push message before SET key_receive_failing_over_X response + # Format: >2\r\n$12\r\nFAILING_OVER\r\n:10\r\n (2 elements: FAILING_OVER, ttl) + failing_over_push = ">2\r\n$12\r\nFAILING_OVER\r\n:10\r\n" + response = failing_over_push.encode() + response + elif ( + b"key_receive_failed_over_" in data + or b"key_receive_failed_over" in data + ): + # FAILED_OVER push message before SET key_receive_failed_over_X response + # Format: >1\r\n$11\r\nFAILED_OVER\r\n (1 element: FAILED_OVER) + failed_over_push = ">1\r\n$11\r\nFAILED_OVER\r\n" + response = failed_over_push.encode() + response elif b"key_receive_moving_" in data: # MOVING push message before SET key_receive_moving_X response # Format: >3\r\n$6\r\nMOVING\r\n:15\r\n+localhost:6379\r\n (3 elements: MOVING, ttl, host:port) @@ -91,6 +109,10 @@ def send(self, data): self.pending_responses.append(b"$6\r\nvalue2\r\n") elif b"key_receive_migrated" in data: self.pending_responses.append(b"$6\r\nvalue3\r\n") + elif b"key_receive_failing_over" in data: + self.pending_responses.append(b"$6\r\nvalue4\r\n") + elif b"key_receive_failed_over" in data: + self.pending_responses.append(b"$6\r\nvalue5\r\n") elif b"key1" in data: self.pending_responses.append(b"$6\r\nvalue1\r\n") else: @@ -719,13 +741,14 @@ def test_migration_related_events_handling_integration(self, pool_class): @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_migrating_event_with_disabled_relax_timeout(self, pool_class): """ - Test migrating event handling when relax timeout is disabled. + Test maintenance events handling when relax timeout is disabled. This test validates that when relax_timeout is disabled (-1): - 1. MIGRATING events are received and processed + 1. MIGRATING, MIGRATED, FAILING_OVER, and FAILED_OVER events are received and processed 2. No timeout updates are applied to connections - 3. Socket timeouts remain unchanged during migration events + 3. Socket timeouts remain unchanged during all maintenance events 4. Tests both ConnectionPool and BlockingConnectionPool implementations + 5. Tests the complete lifecycle: MIGRATING -> MIGRATED -> FAILING_OVER -> FAILED_OVER """ # Create config with disabled relax timeout disabled_config = MaintenanceEventsConfig( @@ -768,6 +791,57 @@ def test_migrating_event_with_disabled_relax_timeout(self, pool_class): f"Command 3 (GET key1) failed. Expected: {expected_value3}, Got: {result3}" ) + # Command 4: This SET command will receive MIGRATED push message before response + key_migrated = "key_receive_migrated" + value_migrated = "value3" + result4 = test_redis_client.set(key_migrated, value_migrated) + + # Validate Command 4 result + assert result4 is True, "Command 4 (SET key_receive_migrated) failed" + + # Validate timeout is still NOT updated after MIGRATED (relax is disabled) + self._validate_current_timeout(None) + + # Command 5: This SET command will receive FAILING_OVER push message before response + key_failing_over = "key_receive_failing_over" + value_failing_over = "value4" + result5 = test_redis_client.set(key_failing_over, value_failing_over) + + # Validate Command 5 result + assert result5 is True, "Command 5 (SET key_receive_failing_over) failed" + + # Validate timeout is still NOT updated after FAILING_OVER (relax is disabled) + self._validate_current_timeout(None) + + # Command 6: Another command to verify timeout remains unchanged during failover + result6 = test_redis_client.get(key_failing_over) + + # Validate Command 6 result + expected_value6 = value_failing_over.encode() + assert result6 == expected_value6, ( + f"Command 6 (GET key_receive_failing_over) failed. Expected: {expected_value6}, Got: {result6}" + ) + + # Command 7: This SET command will receive FAILED_OVER push message before response + key_failed_over = "key_receive_failed_over" + value_failed_over = "value5" + result7 = test_redis_client.set(key_failed_over, value_failed_over) + + # Validate Command 7 result + assert result7 is True, "Command 7 (SET key_receive_failed_over) failed" + + # Validate timeout is still NOT updated after FAILED_OVER (relax is disabled) + self._validate_current_timeout(None) + + # Command 8: Final command to verify timeout remains unchanged after all events + result8 = test_redis_client.get(key_failed_over) + + # Validate Command 8 result + expected_value8 = value_failed_over.encode() + assert result8 == expected_value8, ( + f"Command 8 (GET key_receive_failed_over) failed. Expected: {expected_value8}, Got: {result8}" + ) + # Verify maintenance events were processed correctly # The key is that we have at least 1 socket and all operations succeeded assert len(self.mock_sockets) >= 1, ( @@ -1327,7 +1401,7 @@ def worker(idx): def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): """ Test moving configs are not lost if the per connection events get picked up after moving is handled. - MOVING → MIGRATING → MIGRATED → MOVED + MOVING → MIGRATING → MIGRATED → FAILING_OVER → FAILED_OVER → MOVED Checks the state after each event for all connections and for new connections created during each state. """ # Setup @@ -1418,7 +1492,45 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ) - # 4. MOVED event (simulate timer expiry) + # 4. FAILING_OVER event (simulate direct connection handler call) + for conn in in_use_connections: + conn._maintenance_event_connection_handler.handle_event( + NodeFailingOverEvent(id=3, ttl=1) + ) + # State should not change for connections that are in MOVING state + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + + # 5. FAILED_OVER event (simulate direct connection handler call) + for conn in in_use_connections: + conn._maintenance_event_connection_handler.handle_event( + NodeFailedOverEvent(id=3) + ) + # State should not change for connections that are in MOVING state + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + + # 6. MOVED event (simulate timer expiry) pool_handler.handle_node_moved_event(moving_event) self._validate_in_use_connections_state( in_use_connections,