Skip to content

Commit 193219d

Browse files
Hitless upgrade: Adding handshake command to enable the notifications after connection is established (#3735)
1 parent 83f84cd commit 193219d

File tree

5 files changed

+509
-82
lines changed

5 files changed

+509
-82
lines changed

redis/_parsers/base.py

Lines changed: 74 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import sys
23
from abc import ABC
34
from asyncio import IncompleteReadError, StreamReader, TimeoutError
@@ -56,6 +57,8 @@
5657
"Client sent AUTH, but no password is set": AuthenticationError,
5758
}
5859

60+
logger = logging.getLogger(__name__)
61+
5962

6063
class BaseParser(ABC):
6164
EXCEPTION_CLASSES = {
@@ -199,31 +202,42 @@ def handle_push_response(self, response, **kwargs):
199202
*_MOVING_MESSAGE,
200203
):
201204
return self.pubsub_push_handler_func(response)
202-
if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func:
203-
return self.invalidation_push_handler_func(response)
204-
if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func:
205-
# TODO: PARSE latest format when available
206-
host, port = response[2].decode().split(":")
207-
ttl = response[1]
208-
id = 1 # Hardcoded value until the notification starts including the id
209-
notification = NodeMovingEvent(id, host, port, ttl)
210-
return self.node_moving_push_handler_func(notification)
211-
if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
212-
if msg_type in _MIGRATING_MESSAGE:
213-
# TODO: PARSE latest format when available
214-
ttl = response[1]
215-
id = 2 # Hardcoded value until the notification starts including the id
216-
notification = NodeMigratingEvent(id, ttl)
217-
elif msg_type in _MIGRATED_MESSAGE:
218-
# TODO: PARSE latest format when available
219-
id = 3 # Hardcoded value until the notification starts including the id
220-
notification = NodeMigratedEvent(id)
221-
else:
205+
206+
try:
207+
if (
208+
msg_type in _INVALIDATION_MESSAGE
209+
and self.invalidation_push_handler_func
210+
):
211+
return self.invalidation_push_handler_func(response)
212+
213+
if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func:
214+
# Expected message format is: MOVING <seq_number> <time> <endpoint>
215+
id = response[1]
216+
ttl = response[2]
217+
host, port = response[3].decode().split(":")
218+
notification = NodeMovingEvent(id, host, port, ttl)
219+
return self.node_moving_push_handler_func(notification)
220+
221+
if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
222222
notification = None
223-
if notification is not None:
224-
return self.maintenance_push_handler_func(notification)
225-
else:
226-
return None
223+
224+
if msg_type in _MIGRATING_MESSAGE:
225+
# Expected message format is: MIGRATING <seq_number> <time> <shard_id-s>
226+
id = response[1]
227+
ttl = response[2]
228+
notification = NodeMigratingEvent(id, ttl)
229+
elif msg_type in _MIGRATED_MESSAGE:
230+
id = response[1]
231+
notification = NodeMigratedEvent(id)
232+
233+
if notification is not None:
234+
return self.maintenance_push_handler_func(notification)
235+
except Exception as e:
236+
logger.error(
237+
"Error handling {} message ({}): {}".format(msg_type, response, e)
238+
)
239+
240+
return None
227241

228242
def set_pubsub_push_handler(self, pubsub_push_handler_func):
229243
self.pubsub_push_handler_func = pubsub_push_handler_func
@@ -252,34 +266,49 @@ async def handle_pubsub_push_response(self, response):
252266

253267
async def handle_push_response(self, response, **kwargs):
254268
"""Handle push responses asynchronously"""
269+
255270
msg_type = response[0]
256271
if msg_type not in (
257272
*_INVALIDATION_MESSAGE,
258273
*_MAINTENANCE_MESSAGES,
259274
*_MOVING_MESSAGE,
260275
):
261276
return await self.pubsub_push_handler_func(response)
262-
if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func:
263-
return await self.invalidation_push_handler_func(response)
264-
if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func:
265-
# push notification from enterprise cluster for node moving
266-
# TODO: PARSE latest format when available
267-
host, port = response[2].split(":")
268-
ttl = response[1]
269-
id = 1 # Hardcoded value for async parser
270-
notification = NodeMovingEvent(id, host, port, ttl)
271-
return await self.node_moving_push_handler_func(notification)
272-
if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
273-
if msg_type in _MIGRATING_MESSAGE:
274-
# TODO: PARSE latest format when available
275-
ttl = response[1]
276-
id = 2 # Hardcoded value for async parser
277-
notification = NodeMigratingEvent(id, ttl)
278-
elif msg_type in _MIGRATED_MESSAGE:
279-
# TODO: PARSE latest format when available
280-
id = 3 # Hardcoded value for async parser
281-
notification = NodeMigratedEvent(id)
282-
return await self.maintenance_push_handler_func(notification)
277+
278+
try:
279+
if (
280+
msg_type in _INVALIDATION_MESSAGE
281+
and self.invalidation_push_handler_func
282+
):
283+
return await self.invalidation_push_handler_func(response)
284+
285+
if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func:
286+
# push notification from enterprise cluster for node moving
287+
id = response[1]
288+
ttl = response[2]
289+
host, port = response[3].split(":")
290+
notification = NodeMovingEvent(id, host, port, ttl)
291+
return await self.node_moving_push_handler_func(notification)
292+
293+
if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
294+
notification = None
295+
296+
if msg_type in _MIGRATING_MESSAGE:
297+
id = response[1]
298+
ttl = response[2]
299+
notification = NodeMigratingEvent(id, ttl)
300+
elif msg_type in _MIGRATED_MESSAGE:
301+
id = response[1]
302+
notification = NodeMigratedEvent(id)
303+
304+
if notification is not None:
305+
return await self.maintenance_push_handler_func(notification)
306+
except Exception as e:
307+
logger.error(
308+
"Error handling {} message ({}): {}".format(msg_type, response, e)
309+
)
310+
311+
return None
283312

284313
def set_pubsub_push_handler(self, pubsub_push_handler_func):
285314
"""Set the pubsub push handler function"""

redis/connection.py

Lines changed: 139 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,13 @@ def should_reconnect(self):
282282
"""
283283
pass
284284

285+
@abstractmethod
286+
def get_resolved_ip(self):
287+
"""
288+
Get resolved ip address for the connection.
289+
"""
290+
pass
291+
285292
@abstractmethod
286293
def update_current_socket_timeout(self, relax_timeout: Optional[float] = None):
287294
"""
@@ -421,32 +428,16 @@ def __init__(
421428
parser_class = _RESP3Parser
422429
self.set_parser(parser_class)
423430

424-
if maintenance_events_config and maintenance_events_config.enabled:
425-
if maintenance_events_pool_handler:
426-
maintenance_events_pool_handler.set_connection(self)
427-
self._parser.set_node_moving_push_handler(
428-
maintenance_events_pool_handler.handle_event
429-
)
430-
self._maintenance_event_connection_handler = (
431-
MaintenanceEventConnectionHandler(self, maintenance_events_config)
432-
)
433-
self._parser.set_maintenance_push_handler(
434-
self._maintenance_event_connection_handler.handle_event
435-
)
431+
self.maintenance_events_config = maintenance_events_config
432+
433+
# Set up maintenance events if enabled
434+
self._configure_maintenance_events(
435+
maintenance_events_pool_handler,
436+
orig_host_address,
437+
orig_socket_timeout,
438+
orig_socket_connect_timeout,
439+
)
436440

437-
self.orig_host_address = (
438-
orig_host_address if orig_host_address else self.host
439-
)
440-
self.orig_socket_timeout = (
441-
orig_socket_timeout if orig_socket_timeout else self.socket_timeout
442-
)
443-
self.orig_socket_connect_timeout = (
444-
orig_socket_connect_timeout
445-
if orig_socket_connect_timeout
446-
else self.socket_connect_timeout
447-
)
448-
else:
449-
self._maintenance_event_connection_handler = None
450441
self._should_reconnect = False
451442
self.maintenance_state = maintenance_state
452443

@@ -505,6 +496,46 @@ def set_parser(self, parser_class):
505496
"""
506497
self._parser = parser_class(socket_read_size=self._socket_read_size)
507498

499+
def _configure_maintenance_events(
500+
self,
501+
maintenance_events_pool_handler=None,
502+
orig_host_address=None,
503+
orig_socket_timeout=None,
504+
orig_socket_connect_timeout=None,
505+
):
506+
"""Enable maintenance events by setting up handlers and storing original connection parameters."""
507+
if (
508+
not self.maintenance_events_config
509+
or not self.maintenance_events_config.enabled
510+
):
511+
self._maintenance_event_connection_handler = None
512+
return
513+
514+
# Set up pool handler if available
515+
if maintenance_events_pool_handler:
516+
self._parser.set_node_moving_push_handler(
517+
maintenance_events_pool_handler.handle_event
518+
)
519+
520+
# Set up connection handler
521+
self._maintenance_event_connection_handler = MaintenanceEventConnectionHandler(
522+
self, self.maintenance_events_config
523+
)
524+
self._parser.set_maintenance_push_handler(
525+
self._maintenance_event_connection_handler.handle_event
526+
)
527+
528+
# Store original connection parameters
529+
self.orig_host_address = orig_host_address if orig_host_address else self.host
530+
self.orig_socket_timeout = (
531+
orig_socket_timeout if orig_socket_timeout else self.socket_timeout
532+
)
533+
self.orig_socket_connect_timeout = (
534+
orig_socket_connect_timeout
535+
if orig_socket_connect_timeout
536+
else self.socket_connect_timeout
537+
)
538+
508539
def set_maintenance_event_pool_handler(
509540
self, maintenance_event_pool_handler: MaintenanceEventPoolHandler
510541
):
@@ -652,6 +683,39 @@ def on_connect_check_health(self, check_health: bool = True):
652683
):
653684
raise ConnectionError("Invalid RESP version")
654685

686+
# Send maintenance notifications handshake if RESP3 is active and maintenance events are enabled
687+
# and we have a host to determine the endpoint type from
688+
if (
689+
self.protocol not in [2, "2"]
690+
and self.maintenance_events_config
691+
and self.maintenance_events_config.enabled
692+
and self._maintenance_event_connection_handler
693+
and hasattr(self, "host")
694+
):
695+
try:
696+
endpoint_type = self.maintenance_events_config.get_endpoint_type(
697+
self.host, self
698+
)
699+
self.send_command(
700+
"CLIENT",
701+
"MAINT_NOTIFICATIONS",
702+
"ON",
703+
"moving-endpoint-type",
704+
endpoint_type.value,
705+
check_health=check_health,
706+
)
707+
response = self.read_response()
708+
if str_if_bytes(response) != "OK":
709+
raise ConnectionError(
710+
"The server doesn't support maintenance notifications"
711+
)
712+
except Exception as e:
713+
# Log warning but don't fail the connection
714+
import logging
715+
716+
logger = logging.getLogger(__name__)
717+
logger.warning(f"Failed to enable maintenance notifications: {e}")
718+
655719
# if a client_name is given, set it
656720
if self.client_name:
657721
self.send_command(
@@ -888,6 +952,56 @@ def re_auth(self):
888952
self.read_response()
889953
self._re_auth_token = None
890954

955+
def get_resolved_ip(self) -> Optional[str]:
956+
"""
957+
Extract the resolved IP address from an
958+
established connection or resolve it from the host.
959+
960+
First tries to get the actual IP from the socket (most accurate),
961+
then falls back to DNS resolution if needed.
962+
963+
Args:
964+
connection: The connection object to extract the IP from
965+
966+
Returns:
967+
str: The resolved IP address, or None if it cannot be determined
968+
"""
969+
970+
# Method 1: Try to get the actual IP from the established socket connection
971+
# This is most accurate as it shows the exact IP being used
972+
try:
973+
if self._sock is not None:
974+
peer_addr = self._sock.getpeername()
975+
if peer_addr and len(peer_addr) >= 1:
976+
# For TCP sockets, peer_addr is typically (host, port) tuple
977+
# Return just the host part
978+
return peer_addr[0]
979+
except (AttributeError, OSError):
980+
# Socket might not be connected or getpeername() might fail
981+
pass
982+
983+
# Method 2: Fallback to DNS resolution of the host
984+
# This is less accurate but works when socket is not available
985+
try:
986+
host = getattr(self, "host", "localhost")
987+
port = getattr(self, "port", 6379)
988+
if host:
989+
# Use getaddrinfo to resolve the hostname to IP
990+
# This mimics what the connection would do during _connect()
991+
addr_info = socket.getaddrinfo(
992+
host, port, socket.AF_UNSPEC, socket.SOCK_STREAM
993+
)
994+
if addr_info:
995+
# Return the IP from the first result
996+
# addr_info[0] is (family, socktype, proto, canonname, sockaddr)
997+
# sockaddr[0] is the IP address
998+
return addr_info[0][4][0]
999+
except (AttributeError, OSError, socket.gaierror):
1000+
# DNS resolution might fail
1001+
pass
1002+
1003+
return None
1004+
8911005
@property
8921006
def maintenance_state(self) -> MaintenanceState:
8931007
return self._maintenance_state

0 commit comments

Comments
 (0)