Skip to content

Commit 51d24ba

Browse files
committed
Applying review comments
1 parent e4a8646 commit 51d24ba

File tree

3 files changed

+30
-31
lines changed

3 files changed

+30
-31
lines changed

redis/maintenance_events.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
class MaintenanceState(enum.Enum):
1212
NONE = "none"
1313
MOVING = "moving"
14-
MIGRATING = "migrating"
14+
MAINTENANCE = "maintenance"
1515
FAILING_OVER = "failing_over"
1616

1717

@@ -557,23 +557,32 @@ def handle_node_moved_event(self, event: NodeMovingEvent):
557557

558558

559559
class MaintenanceEventConnectionHandler:
560+
# 1 = "starting maintenance" events, 0 = "completed maintenance" events
561+
_EVENT_TYPES: dict[type["MaintenanceEvent"], int] = {
562+
NodeMigratingEvent: 1,
563+
NodeFailingOverEvent: 1,
564+
NodeMigratedEvent: 0,
565+
NodeFailedOverEvent: 0,
566+
}
567+
560568
def __init__(
561569
self, connection: "ConnectionInterface", config: MaintenanceEventsConfig
562570
) -> None:
563571
self.connection = connection
564572
self.config = config
565573

566574
def handle_event(self, event: MaintenanceEvent):
567-
if isinstance(event, NodeMigratingEvent):
568-
return self.handle_maintenance_start_event(MaintenanceState.MIGRATING)
569-
elif isinstance(event, NodeMigratedEvent):
570-
return self.handle_maintenance_completed_event()
571-
elif isinstance(event, NodeFailingOverEvent):
572-
return self.handle_maintenance_start_event(MaintenanceState.FAILING_OVER)
573-
elif isinstance(event, NodeFailedOverEvent):
574-
return self.handle_maintenance_completed_event()
575-
else:
575+
# get the event type by checking its class in the _EVENT_TYPES dict
576+
event_type = self._EVENT_TYPES.get(event.__class__)
577+
578+
if event_type is None:
576579
logging.error(f"Unhandled event type: {event}")
580+
return
581+
582+
if event_type:
583+
self.handle_maintenance_start_event(MaintenanceState.MAINTENANCE)
584+
else:
585+
self.handle_maintenance_completed_event()
577586

578587
def handle_maintenance_start_event(self, maintenance_state: MaintenanceState):
579588
if (

tests/test_maintenance_events.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ def test_handle_event_migrating(self):
562562
self.handler, "handle_maintenance_start_event"
563563
) as mock_handle:
564564
self.handler.handle_event(event)
565-
mock_handle.assert_called_once_with(MaintenanceState.MIGRATING)
565+
mock_handle.assert_called_once_with(MaintenanceState.MAINTENANCE)
566566

567567
def test_handle_event_migrated(self):
568568
"""Test handling of NodeMigratedEvent."""
@@ -582,7 +582,7 @@ def test_handle_event_failing_over(self):
582582
self.handler, "handle_maintenance_start_event"
583583
) as mock_handle:
584584
self.handler.handle_event(event)
585-
mock_handle.assert_called_once_with(MaintenanceState.FAILING_OVER)
585+
mock_handle.assert_called_once_with(MaintenanceState.MAINTENANCE)
586586

587587
def test_handle_event_failed_over(self):
588588
"""Test handling of NodeFailedOverEvent."""
@@ -608,37 +608,27 @@ def test_handle_maintenance_start_event_disabled(self):
608608
config = MaintenanceEventsConfig(relax_timeout=-1)
609609
handler = MaintenanceEventConnectionHandler(self.mock_connection, config)
610610

611-
result = handler.handle_maintenance_start_event(MaintenanceState.MIGRATING)
611+
result = handler.handle_maintenance_start_event(MaintenanceState.MAINTENANCE)
612612
assert result is None
613613
self.mock_connection.update_current_socket_timeout.assert_not_called()
614614

615615
def test_handle_maintenance_start_event_moving_state(self):
616616
"""Test maintenance start event handling when connection is in MOVING state."""
617617
self.mock_connection.maintenance_state = MaintenanceState.MOVING
618618

619-
result = self.handler.handle_maintenance_start_event(MaintenanceState.MIGRATING)
619+
result = self.handler.handle_maintenance_start_event(
620+
MaintenanceState.MAINTENANCE
621+
)
620622
assert result is None
621623
self.mock_connection.update_current_socket_timeout.assert_not_called()
622624

623-
def test_handle_maintenance_start_event_migrating_success(self):
625+
def test_handle_maintenance_start_event_success(self):
624626
"""Test successful maintenance start event handling for migrating."""
625627
self.mock_connection.maintenance_state = MaintenanceState.NONE
626628

627-
self.handler.handle_maintenance_start_event(MaintenanceState.MIGRATING)
628-
629-
assert self.mock_connection.maintenance_state == MaintenanceState.MIGRATING
630-
self.mock_connection.update_current_socket_timeout.assert_called_once_with(20)
631-
self.mock_connection.set_tmp_settings.assert_called_once_with(
632-
tmp_relax_timeout=20
633-
)
634-
635-
def test_handle_maintenance_start_event_failing_over_success(self):
636-
"""Test successful maintenance start event handling for failing over."""
637-
self.mock_connection.maintenance_state = MaintenanceState.NONE
638-
639-
self.handler.handle_maintenance_start_event(MaintenanceState.FAILING_OVER)
629+
self.handler.handle_maintenance_start_event(MaintenanceState.MAINTENANCE)
640630

641-
assert self.mock_connection.maintenance_state == MaintenanceState.FAILING_OVER
631+
assert self.mock_connection.maintenance_state == MaintenanceState.MAINTENANCE
642632
self.mock_connection.update_current_socket_timeout.assert_called_once_with(20)
643633
self.mock_connection.set_tmp_settings.assert_called_once_with(
644634
tmp_relax_timeout=20
@@ -663,7 +653,7 @@ def test_handle_maintenance_completed_event_moving_state(self):
663653

664654
def test_handle_maintenance_completed_event_success(self):
665655
"""Test successful maintenance completed event handling."""
666-
self.mock_connection.maintenance_state = MaintenanceState.MIGRATING
656+
self.mock_connection.maintenance_state = MaintenanceState.MAINTENANCE
667657

668658
self.handler.handle_maintenance_completed_event()
669659

tests/test_maintenance_events_handling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1808,7 +1808,7 @@ def test_migrating_after_moving_multiple_proxies(self, pool_class):
18081808
conn_event_handler = conn._maintenance_event_connection_handler
18091809
conn_event_handler.handle_event(NodeMigratingEvent(id=3, ttl=1))
18101810
# validate connection is in MIGRATING state
1811-
assert conn.maintenance_state == MaintenanceState.MIGRATING
1811+
assert conn.maintenance_state == MaintenanceState.MAINTENANCE
18121812
assert conn.socket_timeout == self.config.relax_timeout
18131813

18141814
# Send MIGRATED event to con with ip = key3

0 commit comments

Comments
 (0)