@@ -282,6 +282,13 @@ def should_reconnect(self):
282
282
"""
283
283
pass
284
284
285
+ @abstractmethod
286
+ def get_resolved_ip (self ):
287
+ """
288
+ Get resolved ip address for the connection.
289
+ """
290
+ pass
291
+
285
292
@abstractmethod
286
293
def update_current_socket_timeout (self , relax_timeout : Optional [float ] = None ):
287
294
"""
@@ -421,32 +428,16 @@ def __init__(
421
428
parser_class = _RESP3Parser
422
429
self .set_parser (parser_class )
423
430
424
- if maintenance_events_config and maintenance_events_config .enabled :
425
- if maintenance_events_pool_handler :
426
- maintenance_events_pool_handler .set_connection (self )
427
- self ._parser .set_node_moving_push_handler (
428
- maintenance_events_pool_handler .handle_event
429
- )
430
- self ._maintenance_event_connection_handler = (
431
- MaintenanceEventConnectionHandler (self , maintenance_events_config )
432
- )
433
- self ._parser .set_maintenance_push_handler (
434
- self ._maintenance_event_connection_handler .handle_event
435
- )
431
+ self .maintenance_events_config = maintenance_events_config
432
+
433
+ # Set up maintenance events if enabled
434
+ self ._configure_maintenance_events (
435
+ maintenance_events_pool_handler ,
436
+ orig_host_address ,
437
+ orig_socket_timeout ,
438
+ orig_socket_connect_timeout ,
439
+ )
436
440
437
- self .orig_host_address = (
438
- orig_host_address if orig_host_address else self .host
439
- )
440
- self .orig_socket_timeout = (
441
- orig_socket_timeout if orig_socket_timeout else self .socket_timeout
442
- )
443
- self .orig_socket_connect_timeout = (
444
- orig_socket_connect_timeout
445
- if orig_socket_connect_timeout
446
- else self .socket_connect_timeout
447
- )
448
- else :
449
- self ._maintenance_event_connection_handler = None
450
441
self ._should_reconnect = False
451
442
self .maintenance_state = maintenance_state
452
443
@@ -505,6 +496,46 @@ def set_parser(self, parser_class):
505
496
"""
506
497
self ._parser = parser_class (socket_read_size = self ._socket_read_size )
507
498
499
+ def _configure_maintenance_events (
500
+ self ,
501
+ maintenance_events_pool_handler = None ,
502
+ orig_host_address = None ,
503
+ orig_socket_timeout = None ,
504
+ orig_socket_connect_timeout = None ,
505
+ ):
506
+ """Enable maintenance events by setting up handlers and storing original connection parameters."""
507
+ if (
508
+ not self .maintenance_events_config
509
+ or not self .maintenance_events_config .enabled
510
+ ):
511
+ self ._maintenance_event_connection_handler = None
512
+ return
513
+
514
+ # Set up pool handler if available
515
+ if maintenance_events_pool_handler :
516
+ self ._parser .set_node_moving_push_handler (
517
+ maintenance_events_pool_handler .handle_event
518
+ )
519
+
520
+ # Set up connection handler
521
+ self ._maintenance_event_connection_handler = MaintenanceEventConnectionHandler (
522
+ self , self .maintenance_events_config
523
+ )
524
+ self ._parser .set_maintenance_push_handler (
525
+ self ._maintenance_event_connection_handler .handle_event
526
+ )
527
+
528
+ # Store original connection parameters
529
+ self .orig_host_address = orig_host_address if orig_host_address else self .host
530
+ self .orig_socket_timeout = (
531
+ orig_socket_timeout if orig_socket_timeout else self .socket_timeout
532
+ )
533
+ self .orig_socket_connect_timeout = (
534
+ orig_socket_connect_timeout
535
+ if orig_socket_connect_timeout
536
+ else self .socket_connect_timeout
537
+ )
538
+
508
539
def set_maintenance_event_pool_handler (
509
540
self , maintenance_event_pool_handler : MaintenanceEventPoolHandler
510
541
):
@@ -652,6 +683,39 @@ def on_connect_check_health(self, check_health: bool = True):
652
683
):
653
684
raise ConnectionError ("Invalid RESP version" )
654
685
686
+ # Send maintenance notifications handshake if RESP3 is active and maintenance events are enabled
687
+ # and we have a host to determine the endpoint type from
688
+ if (
689
+ self .protocol not in [2 , "2" ]
690
+ and self .maintenance_events_config
691
+ and self .maintenance_events_config .enabled
692
+ and self ._maintenance_event_connection_handler
693
+ and hasattr (self , "host" )
694
+ ):
695
+ try :
696
+ endpoint_type = self .maintenance_events_config .get_endpoint_type (
697
+ self .host , self
698
+ )
699
+ self .send_command (
700
+ "CLIENT" ,
701
+ "MAINT_NOTIFICATIONS" ,
702
+ "ON" ,
703
+ "moving-endpoint-type" ,
704
+ endpoint_type .value ,
705
+ check_health = check_health ,
706
+ )
707
+ response = self .read_response ()
708
+ if str_if_bytes (response ) != "OK" :
709
+ raise ConnectionError (
710
+ "The server doesn't support maintenance notifications"
711
+ )
712
+ except Exception as e :
713
+ # Log warning but don't fail the connection
714
+ import logging
715
+
716
+ logger = logging .getLogger (__name__ )
717
+ logger .warning (f"Failed to enable maintenance notifications: { e } " )
718
+
655
719
# if a client_name is given, set it
656
720
if self .client_name :
657
721
self .send_command (
@@ -888,6 +952,56 @@ def re_auth(self):
888
952
self .read_response ()
889
953
self ._re_auth_token = None
890
954
955
+ def get_resolved_ip (self ) -> Optional [str ]:
956
+ """
957
+ Extract the resolved IP address from an
958
+ established connection or resolve it from the host.
959
+
960
+ First tries to get the actual IP from the socket (most accurate),
961
+ then falls back to DNS resolution if needed.
962
+
963
+ Args:
964
+ connection: The connection object to extract the IP from
965
+
966
+ Returns:
967
+ str: The resolved IP address, or None if it cannot be determined
968
+ """
969
+
970
+ # Method 1: Try to get the actual IP from the established socket connection
971
+ # This is most accurate as it shows the exact IP being used
972
+ try :
973
+ if self ._sock is not None :
974
+ peer_addr = self ._sock .getpeername ()
975
+ if peer_addr and len (peer_addr ) >= 1 :
976
+ # For TCP sockets, peer_addr is typically (host, port) tuple
977
+ # Return just the host part
978
+ return peer_addr [0 ]
979
+ except (AttributeError , OSError ):
980
+ # Socket might not be connected or getpeername() might fail
981
+ pass
982
+
983
+ # Method 2: Fallback to DNS resolution of the host
984
+ # This is less accurate but works when socket is not available
985
+ try :
986
+ host = getattr (self , "host" , "localhost" )
987
+ port = getattr (self , "port" , 6379 )
988
+ if host :
989
+ # Use getaddrinfo to resolve the hostname to IP
990
+ # This mimics what the connection would do during _connect()
991
+ addr_info = socket .getaddrinfo (
992
+ host , port , socket .AF_UNSPEC , socket .SOCK_STREAM
993
+ )
994
+ if addr_info :
995
+ # Return the IP from the first result
996
+ # addr_info[0] is (family, socktype, proto, canonname, sockaddr)
997
+ # sockaddr[0] is the IP address
998
+ return addr_info [0 ][4 ][0 ]
999
+ except (AttributeError , OSError , socket .gaierror ):
1000
+ # DNS resolution might fail
1001
+ pass
1002
+
1003
+ return None
1004
+
891
1005
@property
892
1006
def maintenance_state (self ) -> MaintenanceState :
893
1007
return self ._maintenance_state
0 commit comments