Skip to content

Commit 67aee8c

Browse files
authored
Merge branch 'ps_add_fail_over_events_handling' into hitless_handshake
2 parents 8a6402f + 66c1fe0 commit 67aee8c

File tree

3 files changed

+42
-69
lines changed

3 files changed

+42
-69
lines changed

redis/maintenance_events.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@
1313
class MaintenanceState(enum.Enum):
1414
NONE = "none"
1515
MOVING = "moving"
16-
MIGRATING = "migrating"
17-
FAILING_OVER = "failing_over"
18-
16+
MAINTENANCE = "maintenance"
1917

18+
2019
class EndpointType:
2120
"""Constants for valid endpoint types used in CLIENT MAINT_NOTIFICATIONS command."""
2221

@@ -26,6 +25,7 @@ class EndpointType:
2625
EXTERNAL_FQDN = "external-fqdn"
2726
NONE = "none"
2827

28+
2929
@classmethod
3030
def get_valid_types(cls):
3131
"""Return a set of all valid endpoint types."""
@@ -37,7 +37,7 @@ def get_valid_types(cls):
3737
cls.NONE,
3838
}
3939

40-
40+
4141
if TYPE_CHECKING:
4242
from redis.connection import (
4343
BlockingConnectionPool,
@@ -531,6 +531,7 @@ def __init__(
531531

532532
self.endpoint_type = endpoint_type
533533

534+
534535
def __repr__(self) -> str:
535536
return (
536537
f"{self.__class__.__name__}("
@@ -748,30 +749,40 @@ def handle_node_moved_event(self, event: NodeMovingEvent):
748749

749750

750751
class MaintenanceEventConnectionHandler:
752+
# 1 = "starting maintenance" events, 0 = "completed maintenance" events
753+
_EVENT_TYPES: dict[type["MaintenanceEvent"], int] = {
754+
NodeMigratingEvent: 1,
755+
NodeFailingOverEvent: 1,
756+
NodeMigratedEvent: 0,
757+
NodeFailedOverEvent: 0,
758+
}
759+
751760
def __init__(
752761
self, connection: "ConnectionInterface", config: MaintenanceEventsConfig
753762
) -> None:
754763
self.connection = connection
755764
self.config = config
756765

757766
def handle_event(self, event: MaintenanceEvent):
758-
if isinstance(event, NodeMigratingEvent):
759-
return self.handle_maintenance_start_event(MaintenanceState.MIGRATING)
760-
elif isinstance(event, NodeMigratedEvent):
761-
return self.handle_maintenance_completed_event()
762-
elif isinstance(event, NodeFailingOverEvent):
763-
return self.handle_maintenance_start_event(MaintenanceState.FAILING_OVER)
764-
elif isinstance(event, NodeFailedOverEvent):
765-
return self.handle_maintenance_completed_event()
766-
else:
767+
# get the event type by checking its class in the _EVENT_TYPES dict
768+
event_type = self._EVENT_TYPES.get(event.__class__, None)
769+
770+
if event_type is None:
767771
logging.error(f"Unhandled event type: {event}")
772+
return
773+
774+
if event_type:
775+
self.handle_maintenance_start_event(MaintenanceState.MAINTENANCE)
776+
else:
777+
self.handle_maintenance_completed_event()
768778

769779
def handle_maintenance_start_event(self, maintenance_state: MaintenanceState):
770780
if (
771781
self.connection.maintenance_state == MaintenanceState.MOVING
772782
or not self.config.is_relax_timeouts_enabled()
773783
):
774784
return
785+
775786
self.connection.maintenance_state = maintenance_state
776787
self.connection.set_tmp_settings(tmp_relax_timeout=self.config.relax_timeout)
777788
# extend the timeout for all created connections

tests/test_maintenance_events.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44

55
from redis.connection import ConnectionInterface
6+
67
from redis.maintenance_events import (
78
MaintenanceEvent,
89
NodeMovingEvent,
@@ -564,7 +565,7 @@ def test_handle_event_migrating(self):
564565
self.handler, "handle_maintenance_start_event"
565566
) as mock_handle:
566567
self.handler.handle_event(event)
567-
mock_handle.assert_called_once_with(MaintenanceState.MIGRATING)
568+
mock_handle.assert_called_once_with(MaintenanceState.MAINTENANCE)
568569

569570
def test_handle_event_migrated(self):
570571
"""Test handling of NodeMigratedEvent."""
@@ -584,7 +585,8 @@ def test_handle_event_failing_over(self):
584585
self.handler, "handle_maintenance_start_event"
585586
) as mock_handle:
586587
self.handler.handle_event(event)
587-
mock_handle.assert_called_once_with(MaintenanceState.FAILING_OVER)
588+
mock_handle.assert_called_once_with(MaintenanceState.MAINTENANCE)
589+
588590

589591
def test_handle_event_failed_over(self):
590592
"""Test handling of NodeFailedOverEvent."""
@@ -610,37 +612,28 @@ def test_handle_maintenance_start_event_disabled(self):
610612
config = MaintenanceEventsConfig(relax_timeout=-1)
611613
handler = MaintenanceEventConnectionHandler(self.mock_connection, config)
612614

613-
result = handler.handle_maintenance_start_event(MaintenanceState.MIGRATING)
615+
result = handler.handle_maintenance_start_event(MaintenanceState.MAINTENANCE)
616+
614617
assert result is None
615618
self.mock_connection.update_current_socket_timeout.assert_not_called()
616619

617620
def test_handle_maintenance_start_event_moving_state(self):
618621
"""Test maintenance start event handling when connection is in MOVING state."""
619622
self.mock_connection.maintenance_state = MaintenanceState.MOVING
620623

621-
result = self.handler.handle_maintenance_start_event(MaintenanceState.MIGRATING)
624+
result = self.handler.handle_maintenance_start_event(
625+
MaintenanceState.MAINTENANCE
626+
)
622627
assert result is None
623628
self.mock_connection.update_current_socket_timeout.assert_not_called()
624629

625-
def test_handle_maintenance_start_event_migrating_success(self):
630+
def test_handle_maintenance_start_event_success(self):
626631
"""Test successful maintenance start event handling."""
627632
self.mock_connection.maintenance_state = MaintenanceState.NONE
628633

629-
self.handler.handle_maintenance_start_event(MaintenanceState.MIGRATING)
634+
self.handler.handle_maintenance_start_event(MaintenanceState.MAINTENANCE)
630635

631-
assert self.mock_connection.maintenance_state == MaintenanceState.MIGRATING
632-
self.mock_connection.update_current_socket_timeout.assert_called_once_with(20)
633-
self.mock_connection.set_tmp_settings.assert_called_once_with(
634-
tmp_relax_timeout=20
635-
)
636-
637-
def test_handle_maintenance_start_event_failing_over_success(self):
638-
"""Test successful maintenance start event handling for failing over."""
639-
self.mock_connection.maintenance_state = MaintenanceState.NONE
640-
641-
self.handler.handle_maintenance_start_event(MaintenanceState.FAILING_OVER)
642-
643-
assert self.mock_connection.maintenance_state == MaintenanceState.FAILING_OVER
636+
assert self.mock_connection.maintenance_state == MaintenanceState.MAINTENANCE
644637
self.mock_connection.update_current_socket_timeout.assert_called_once_with(20)
645638
self.mock_connection.set_tmp_settings.assert_called_once_with(
646639
tmp_relax_timeout=20
@@ -665,11 +658,12 @@ def test_handle_maintenance_completed_event_moving_state(self):
665658

666659
def test_handle_maintenance_completed_event_success(self):
667660
"""Test successful maintenance completed event handling."""
668-
self.mock_connection.maintenance_state = MaintenanceState.MIGRATING
661+
self.mock_connection.maintenance_state = MaintenanceState.MAINTENANCE
669662

670663
self.handler.handle_maintenance_completed_event()
671664

672665
assert self.mock_connection.maintenance_state == MaintenanceState.NONE
666+
673667
self.mock_connection.update_current_socket_timeout.assert_called_once_with(-1)
674668
self.mock_connection.reset_tmp_settings.assert_called_once_with(
675669
reset_relax_timeout=True

tests/test_maintenance_events_handling.py

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,7 +1431,8 @@ def worker(idx):
14311431
def test_moving_migrating_migrated_moved_state_transitions(self, pool_class):
14321432
"""
14331433
Test moving configs are not lost if the per connection events get picked up after moving is handled.
1434-
MOVING → MIGRATING → MIGRATED → FAILING_OVER → FAILED_OVER → MOVED
1434+
Sequence of events: MOVING, MIGRATING, MIGRATED, FAILING_OVER, FAILED_OVER, MOVED.
1435+
Note: FAILING_OVER and FAILED_OVER events do not change the connection state when already in MOVING state.
14351436
Checks the state after each event for all connections and for new connections created during each state.
14361437
"""
14371438
# Setup
@@ -1541,25 +1542,6 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class):
15411542
expected_current_peername=DEFAULT_ADDRESS.split(":")[0],
15421543
)
15431544

1544-
# 4. FAILING_OVER event (simulate direct connection handler call)
1545-
for conn in in_use_connections:
1546-
conn._maintenance_event_connection_handler.handle_event(
1547-
NodeFailingOverEvent(id=3, ttl=1)
1548-
)
1549-
# State should not change for connections that are in MOVING state
1550-
Helpers.validate_in_use_connections_state(
1551-
in_use_connections,
1552-
expected_state=MaintenanceState.MOVING,
1553-
expected_host_address=tmp_address,
1554-
expected_socket_timeout=self.config.relax_timeout,
1555-
expected_socket_connect_timeout=self.config.relax_timeout,
1556-
expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0],
1557-
expected_orig_socket_timeout=None,
1558-
expected_orig_socket_connect_timeout=None,
1559-
expected_current_socket_timeout=self.config.relax_timeout,
1560-
expected_current_peername=DEFAULT_ADDRESS.split(":")[0],
1561-
)
1562-
15631545
# 5. FAILED_OVER event (simulate direct connection handler call)
15641546
for conn in in_use_connections:
15651547
conn._maintenance_event_connection_handler.handle_event(
@@ -1579,21 +1561,6 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class):
15791561
expected_current_peername=DEFAULT_ADDRESS.split(":")[0],
15801562
)
15811563

1582-
# 6. MOVED event (simulate timer expiry)
1583-
pool_handler.handle_node_moved_event(moving_event)
1584-
Helpers.validate_in_use_connections_state(
1585-
in_use_connections,
1586-
expected_state=MaintenanceState.NONE,
1587-
expected_host_address=DEFAULT_ADDRESS.split(":")[0],
1588-
expected_socket_timeout=None,
1589-
expected_socket_connect_timeout=None,
1590-
expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0],
1591-
expected_orig_socket_timeout=None,
1592-
expected_orig_socket_connect_timeout=None,
1593-
expected_current_socket_timeout=None,
1594-
expected_current_peername=DEFAULT_ADDRESS.split(":")[0],
1595-
)
1596-
15971564
# 6. MOVED event (simulate timer expiry)
15981565
pool_handler.handle_node_moved_event(moving_event)
15991566
Helpers.validate_in_use_connections_state(
@@ -1841,7 +1808,8 @@ def test_migrating_after_moving_multiple_proxies(self, pool_class):
18411808
conn_event_handler = conn._maintenance_event_connection_handler
18421809
conn_event_handler.handle_event(NodeMigratingEvent(id=3, ttl=1))
18431810
# validate connection is in MAINTENANCE state
1844-
assert conn.maintenance_state == MaintenanceState.MIGRATING
1811+
assert conn.maintenance_state == MaintenanceState.MAINTENANCE
1812+
18451813
assert conn.socket_timeout == self.config.relax_timeout
18461814

18471815
# Send MIGRATED event to con with ip = key3

0 commit comments

Comments
 (0)