@@ -332,6 +332,7 @@ def __init__(
332
332
encoding : str = "utf-8" ,
333
333
encoding_errors : str = "strict" ,
334
334
decode_responses : bool = False ,
335
+ check_server_ready : bool = False ,
335
336
parser_class = DefaultParser ,
336
337
socket_read_size : int = 65536 ,
337
338
health_check_interval : int = 0 ,
@@ -408,6 +409,7 @@ def __init__(
408
409
self .redis_connect_func = redis_connect_func
409
410
self .encoder = Encoder (encoding , encoding_errors , decode_responses )
410
411
self .handshake_metadata = None
412
+ self .check_server_ready = check_server_ready
411
413
self ._sock = None
412
414
self ._socket_read_size = socket_read_size
413
415
self ._connect_callbacks = []
@@ -571,17 +573,17 @@ def connect_check_health(
571
573
return
572
574
try :
573
575
if retry_socket_connect :
574
- sock = self .retry .call_with_retry (
575
- lambda : self ._connect (), lambda error : self .disconnect (error )
576
+ self .retry .call_with_retry (
577
+ lambda : self ._connect_check_server_ready (),
578
+ lambda error : self .disconnect (error ),
576
579
)
577
580
else :
578
- sock = self ._connect ()
581
+ self ._connect_check_server_ready ()
579
582
except socket .timeout :
580
583
raise TimeoutError ("Timeout connecting to server" )
581
584
except OSError as e :
582
585
raise ConnectionError (self ._error_message (e ))
583
586
584
- self ._sock = sock
585
587
try :
586
588
if self .redis_connect_func is None :
587
589
# Use the default on_connect function
@@ -603,8 +605,27 @@ def connect_check_health(
603
605
if callback :
604
606
callback (self )
605
607
608
+ def _connect_check_server_ready (self ):
609
+ self ._connect ()
610
+
611
+ # Doing handshake since connect and send operations work even when Redis is not ready
612
+ if self .check_server_ready :
613
+ try :
614
+ self .send_command ("PING" , check_health = False )
615
+
616
+ response = str_if_bytes (self ._sock .recv (1024 ))
617
+ if not (response .startswith ("+PONG" ) or response .startswith ("-NOAUTH" )):
618
+ raise ResponseError (f"Invalid PING response: { response } " )
619
+ except (ConnectionResetError , ResponseError ) as err :
620
+ try :
621
+ self ._sock .shutdown (socket .SHUT_RDWR ) # ensure a clean close
622
+ except OSError :
623
+ pass
624
+ self ._sock .close ()
625
+ raise ConnectionError (self ._error_message (err ))
626
+
606
627
@abstractmethod
607
- def _connect (self ):
628
+ def _connect (self ) -> None :
608
629
pass
609
630
610
631
@abstractmethod
@@ -1083,7 +1104,7 @@ def repr_pieces(self):
1083
1104
pieces .append (("client_name" , self .client_name ))
1084
1105
return pieces
1085
1106
1086
- def _connect (self ):
1107
+ def _connect (self ) -> None :
1087
1108
"Create a TCP socket connection"
1088
1109
# we want to mimic what socket.create_connection does to support
1089
1110
# ipv4/ipv6, but we want to set options prior to calling
@@ -1114,7 +1135,8 @@ def _connect(self):
1114
1135
1115
1136
# set the socket_timeout now that we're connected
1116
1137
sock .settimeout (self .socket_timeout )
1117
- return sock
1138
+ self ._sock = sock
1139
+ return
1118
1140
1119
1141
except OSError as _ :
1120
1142
err = _
@@ -1427,15 +1449,15 @@ def __init__(
1427
1449
self .ssl_ciphers = ssl_ciphers
1428
1450
super ().__init__ (** kwargs )
1429
1451
1430
- def _connect (self ):
1452
+ def _connect (self ) -> None :
1431
1453
"""
1432
1454
Wrap the socket with SSL support, handling potential errors.
1433
1455
"""
1434
- sock = super ()._connect ()
1456
+ super ()._connect ()
1435
1457
try :
1436
- return self ._wrap_socket_with_ssl (sock )
1458
+ self . _sock = self ._wrap_socket_with_ssl (self . _sock )
1437
1459
except (OSError , RedisError ):
1438
- sock .close ()
1460
+ self . _sock .close ()
1439
1461
raise
1440
1462
1441
1463
def _wrap_socket_with_ssl (self , sock ):
@@ -1532,7 +1554,7 @@ def repr_pieces(self):
1532
1554
pieces .append (("client_name" , self .client_name ))
1533
1555
return pieces
1534
1556
1535
- def _connect (self ):
1557
+ def _connect (self ) -> None :
1536
1558
"Create a Unix domain socket connection"
1537
1559
sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
1538
1560
sock .settimeout (self .socket_connect_timeout )
@@ -1547,7 +1569,7 @@ def _connect(self):
1547
1569
sock .close ()
1548
1570
raise
1549
1571
sock .settimeout (self .socket_timeout )
1550
- return sock
1572
+ self . _sock = sock
1551
1573
1552
1574
def _host_error (self ):
1553
1575
return self .path
0 commit comments