Skip to content

Commit 97db940

Browse files
committed
Applying review comments
1 parent 1388cb9 commit 97db940

File tree

4 files changed

+28
-63
lines changed

4 files changed

+28
-63
lines changed

redis/_parsers/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def handle_push_response(self, response, **kwargs):
209209
and self.invalidation_push_handler_func
210210
):
211211
return self.invalidation_push_handler_func(response)
212+
212213
if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func:
213214
# Expected message format is: MOVING <seq_number> <time> <endpoint>
214215
id = response[1]
@@ -280,6 +281,7 @@ async def handle_push_response(self, response, **kwargs):
280281
and self.invalidation_push_handler_func
281282
):
282283
return await self.invalidation_push_handler_func(response)
284+
283285
if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func:
284286
# push notification from enterprise cluster for node moving
285287
id = response[1]

redis/connection.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -431,15 +431,13 @@ def __init__(
431431
self.maintenance_events_config = maintenance_events_config
432432

433433
# Set up maintenance events if enabled
434-
if maintenance_events_config and maintenance_events_config.enabled:
435-
self._enable_maintenance_events(
436-
maintenance_events_pool_handler,
437-
orig_host_address,
438-
orig_socket_timeout,
439-
orig_socket_connect_timeout,
440-
)
441-
else:
442-
self._maintenance_event_connection_handler = None
434+
self._configure_maintenance_events(
435+
maintenance_events_pool_handler,
436+
orig_host_address,
437+
orig_socket_timeout,
438+
orig_socket_connect_timeout,
439+
)
440+
443441
self._should_reconnect = False
444442
self.maintenance_state = maintenance_state
445443

@@ -498,7 +496,7 @@ def set_parser(self, parser_class):
498496
"""
499497
self._parser = parser_class(socket_read_size=self._socket_read_size)
500498

501-
def _enable_maintenance_events(
499+
def _configure_maintenance_events(
502500
self,
503501
maintenance_events_pool_handler=None,
504502
orig_host_address=None,
@@ -510,6 +508,7 @@ def _enable_maintenance_events(
510508
not self.maintenance_events_config
511509
or not self.maintenance_events_config.enabled
512510
):
511+
self._maintenance_event_connection_handler = None
513512
return
514513

515514
# Set up pool handler if available
@@ -702,7 +701,7 @@ def on_connect_check_health(self, check_health: bool = True):
702701
"MAINT_NOTIFICATIONS",
703702
"ON",
704703
"moving-endpoint-type",
705-
endpoint_type,
704+
endpoint_type.value,
706705
check_health=check_health,
707706
)
708707
response = self.read_response()

redis/maintenance_events.py

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,18 @@ class MaintenanceState(enum.Enum):
1616
MAINTENANCE = "maintenance"
1717

1818

19-
class EndpointType:
20-
"""Constants for valid endpoint types used in CLIENT MAINT_NOTIFICATIONS command."""
19+
class EndpointType(enum.Enum):
20+
"""Valid endpoint types used in CLIENT MAINT_NOTIFICATIONS command."""
2121

2222
INTERNAL_IP = "internal-ip"
2323
INTERNAL_FQDN = "internal-fqdn"
2424
EXTERNAL_IP = "external-ip"
2525
EXTERNAL_FQDN = "external-fqdn"
2626
NONE = "none"
2727

28-
@classmethod
29-
def get_valid_types(cls):
30-
"""Return a set of all valid endpoint types."""
31-
return {
32-
cls.INTERNAL_IP,
33-
cls.INTERNAL_FQDN,
34-
cls.EXTERNAL_IP,
35-
cls.EXTERNAL_FQDN,
36-
cls.NONE,
37-
}
28+
def __str__(self):
29+
"""Return the string value of the enum."""
30+
return self.value
3831

3932

4033
if TYPE_CHECKING:
@@ -438,7 +431,7 @@ def __init__(
438431
enabled: bool = True,
439432
proactive_reconnect: bool = True,
440433
relax_timeout: Optional[Number] = 20,
441-
endpoint_type: Optional[str] = None,
434+
endpoint_type: Optional[EndpointType] = None,
442435
):
443436
"""
444437
Initialize a new MaintenanceEventsConfig.
@@ -450,8 +443,7 @@ def __init__(
450443
Defaults to True.
451444
relax_timeout (Number): The relax timeout to use for the connection during maintenance.
452445
If -1 is provided - the relax timeout is disabled. Defaults to 20.
453-
endpoint_type (Optional[str]): Override for the endpoint type to use in CLIENT MAINT_NOTIFICATIONS.
454-
Must be one of: 'internal-ip', 'internal-fqdn', 'external-ip', 'external-fqdn', 'none'.
446+
endpoint_type (Optional[EndpointType]): Override for the endpoint type to use in CLIENT MAINT_NOTIFICATIONS.
455447
If None, the endpoint type will be automatically determined based on the host and TLS configuration.
456448
Defaults to None.
457449
@@ -461,19 +453,6 @@ def __init__(
461453
self.enabled = enabled
462454
self.relax_timeout = relax_timeout
463455
self.proactive_reconnect = proactive_reconnect
464-
465-
# Validate endpoint_type if provided
466-
if (
467-
endpoint_type is not None
468-
and endpoint_type not in EndpointType.get_valid_types()
469-
):
470-
valid_types = ", ".join(
471-
f"'{t}'" for t in sorted(EndpointType.get_valid_types())
472-
)
473-
raise ValueError(
474-
f"Invalid endpoint_type '{endpoint_type}'. Must be one of: {valid_types}"
475-
)
476-
477456
self.endpoint_type = endpoint_type
478457

479458
def __repr__(self) -> str:
@@ -497,7 +476,9 @@ def is_relax_timeouts_enabled(self) -> bool:
497476
"""
498477
return self.relax_timeout != -1
499478

500-
def get_endpoint_type(self, host: str, connection: "ConnectionInterface") -> str:
479+
def get_endpoint_type(
480+
self, host: str, connection: "ConnectionInterface"
481+
) -> EndpointType:
501482
"""
502483
Determine the appropriate endpoint type for CLIENT MAINT_NOTIFICATIONS command.
503484

tests/test_maintenance_events.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -674,23 +674,11 @@ class TestEndpointType:
674674

675675
def test_endpoint_type_constants(self):
676676
"""Test that the EndpointType constants are correct."""
677-
assert EndpointType.INTERNAL_IP == "internal-ip"
678-
assert EndpointType.INTERNAL_FQDN == "internal-fqdn"
679-
assert EndpointType.EXTERNAL_IP == "external-ip"
680-
assert EndpointType.EXTERNAL_FQDN == "external-fqdn"
681-
assert EndpointType.NONE == "none"
682-
683-
def test_get_valid_types(self):
684-
"""Test that get_valid_types returns the expected set."""
685-
valid_types = EndpointType.get_valid_types()
686-
expected_types = {
687-
"internal-ip",
688-
"internal-fqdn",
689-
"external-ip",
690-
"external-fqdn",
691-
"none",
692-
}
693-
assert valid_types == expected_types
677+
assert EndpointType.INTERNAL_IP.value == "internal-ip"
678+
assert EndpointType.INTERNAL_FQDN.value == "internal-fqdn"
679+
assert EndpointType.EXTERNAL_IP.value == "external-ip"
680+
assert EndpointType.EXTERNAL_FQDN.value == "external-fqdn"
681+
assert EndpointType.NONE.value == "none"
694682

695683

696684
class TestMaintenanceEventsConfigEndpointType:
@@ -724,15 +712,10 @@ def get_resolved_ip(self):
724712

725713
def test_config_validation_valid_endpoint_types(self):
726714
"""Test that MaintenanceEventsConfig accepts valid endpoint types."""
727-
for endpoint_type in EndpointType.get_valid_types():
715+
for endpoint_type in EndpointType:
728716
config = MaintenanceEventsConfig(endpoint_type=endpoint_type)
729717
assert config.endpoint_type == endpoint_type
730718

731-
def test_config_validation_invalid_endpoint_type(self):
732-
"""Test that MaintenanceEventsConfig raises ValueError for invalid endpoint type."""
733-
with pytest.raises(ValueError, match="Invalid endpoint_type"):
734-
MaintenanceEventsConfig(endpoint_type="invalid-type")
735-
736719
def test_config_validation_none_endpoint_type(self):
737720
"""Test that MaintenanceEventsConfig accepts None as endpoint type."""
738721
config = MaintenanceEventsConfig(endpoint_type=None)

0 commit comments

Comments
 (0)