@@ -236,6 +236,7 @@ def __init__(
236
236
encoding : str = "utf-8" ,
237
237
encoding_errors : str = "strict" ,
238
238
decode_responses : bool = False ,
239
+ check_server_ready : bool = False ,
239
240
parser_class = DefaultParser ,
240
241
socket_read_size : int = 65536 ,
241
242
health_check_interval : int = 0 ,
@@ -302,6 +303,7 @@ def __init__(
302
303
self .redis_connect_func = redis_connect_func
303
304
self .encoder = Encoder (encoding , encoding_errors , decode_responses )
304
305
self .handshake_metadata = None
306
+ self .check_server_ready = check_server_ready
305
307
self ._sock = None
306
308
self ._socket_read_size = socket_read_size
307
309
self .set_parser (parser_class )
@@ -382,15 +384,15 @@ def connect_check_health(self, check_health: bool = True):
382
384
if self ._sock :
383
385
return
384
386
try :
385
- sock = self .retry .call_with_retry (
386
- lambda : self ._connect (), lambda error : self .disconnect (error )
387
+ self .retry .call_with_retry (
388
+ lambda : self ._connect_check_server_ready (),
389
+ lambda error : self .disconnect (error ),
387
390
)
388
391
except socket .timeout :
389
392
raise TimeoutError ("Timeout connecting to server" )
390
393
except OSError as e :
391
394
raise ConnectionError (self ._error_message (e ))
392
395
393
- self ._sock = sock
394
396
try :
395
397
if self .redis_connect_func is None :
396
398
# Use the default on_connect function
@@ -412,8 +414,27 @@ def connect_check_health(self, check_health: bool = True):
412
414
if callback :
413
415
callback (self )
414
416
417
+ def _connect_check_server_ready (self ):
418
+ self ._connect ()
419
+
420
+ # Doing handshake since connect and send operations work even when Redis is not ready
421
+ if self .check_server_ready :
422
+ try :
423
+ self .send_command ("PING" , check_health = False )
424
+
425
+ response = str_if_bytes (self ._sock .recv (1024 ))
426
+ if not (response .startswith ("+PONG" ) or response .startswith ("-NOAUTH" )):
427
+ raise ResponseError (f"Invalid PING response: { response } " )
428
+ except (ConnectionResetError , ResponseError ) as err :
429
+ try :
430
+ self ._sock .shutdown (socket .SHUT_RDWR ) # ensure a clean close
431
+ except OSError :
432
+ pass
433
+ self ._sock .close ()
434
+ raise ConnectionError (self ._error_message (err ))
435
+
415
436
@abstractmethod
416
- def _connect (self ):
437
+ def _connect (self ) -> None :
417
438
pass
418
439
419
440
@abstractmethod
@@ -752,7 +773,7 @@ def repr_pieces(self):
752
773
pieces .append (("client_name" , self .client_name ))
753
774
return pieces
754
775
755
- def _connect (self ):
776
+ def _connect (self ) -> None :
756
777
"Create a TCP socket connection"
757
778
# we want to mimic what socket.create_connection does to support
758
779
# ipv4/ipv6, but we want to set options prior to calling
@@ -782,7 +803,8 @@ def _connect(self):
782
803
783
804
# set the socket_timeout now that we're connected
784
805
sock .settimeout (self .socket_timeout )
785
- return sock
806
+ self ._sock = sock
807
+ return
786
808
787
809
except OSError as _ :
788
810
err = _
@@ -1095,15 +1117,15 @@ def __init__(
1095
1117
self .ssl_ciphers = ssl_ciphers
1096
1118
super ().__init__ (** kwargs )
1097
1119
1098
- def _connect (self ):
1120
+ def _connect (self ) -> None :
1099
1121
"""
1100
1122
Wrap the socket with SSL support, handling potential errors.
1101
1123
"""
1102
- sock = super ()._connect ()
1124
+ super ()._connect ()
1103
1125
try :
1104
- return self ._wrap_socket_with_ssl (sock )
1126
+ self . _sock = self ._wrap_socket_with_ssl (self . _sock )
1105
1127
except (OSError , RedisError ):
1106
- sock .close ()
1128
+ self . _sock .close ()
1107
1129
raise
1108
1130
1109
1131
def _wrap_socket_with_ssl (self , sock ):
@@ -1200,7 +1222,7 @@ def repr_pieces(self):
1200
1222
pieces .append (("client_name" , self .client_name ))
1201
1223
return pieces
1202
1224
1203
- def _connect (self ):
1225
+ def _connect (self ) -> None :
1204
1226
"Create a Unix domain socket connection"
1205
1227
sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
1206
1228
sock .settimeout (self .socket_connect_timeout )
@@ -1215,7 +1237,7 @@ def _connect(self):
1215
1237
sock .close ()
1216
1238
raise
1217
1239
sock .settimeout (self .socket_timeout )
1218
- return sock
1240
+ self . _sock = sock
1219
1241
1220
1242
def _host_error (self ):
1221
1243
return self .path
0 commit comments