@@ -237,6 +237,7 @@ def __init__(
237
237
encoding : str = "utf-8" ,
238
238
encoding_errors : str = "strict" ,
239
239
decode_responses : bool = False ,
240
+ check_server_ready : bool = False ,
240
241
parser_class = DefaultParser ,
241
242
socket_read_size : int = 65536 ,
242
243
health_check_interval : int = 0 ,
@@ -303,6 +304,7 @@ def __init__(
303
304
self .redis_connect_func = redis_connect_func
304
305
self .encoder = Encoder (encoding , encoding_errors , decode_responses )
305
306
self .handshake_metadata = None
307
+ self .check_server_ready = check_server_ready
306
308
self ._sock = None
307
309
self ._socket_read_size = socket_read_size
308
310
self .set_parser (parser_class )
@@ -386,17 +388,17 @@ def connect_check_health(
386
388
return
387
389
try :
388
390
if retry_socket_connect :
389
- sock = self .retry .call_with_retry (
390
- lambda : self ._connect (), lambda error : self .disconnect (error )
391
+ self .retry .call_with_retry (
392
+ lambda : self ._connect_check_server_ready (),
393
+ lambda error : self .disconnect (error ),
391
394
)
392
395
else :
393
- sock = self ._connect ()
396
+ self ._connect_check_server_ready ()
394
397
except socket .timeout :
395
398
raise TimeoutError ("Timeout connecting to server" )
396
399
except OSError as e :
397
400
raise ConnectionError (self ._error_message (e ))
398
401
399
- self ._sock = sock
400
402
try :
401
403
if self .redis_connect_func is None :
402
404
# Use the default on_connect function
@@ -418,8 +420,27 @@ def connect_check_health(
418
420
if callback :
419
421
callback (self )
420
422
423
+ def _connect_check_server_ready (self ):
424
+ self ._connect ()
425
+
426
+ # Doing handshake since connect and send operations work even when Redis is not ready
427
+ if self .check_server_ready :
428
+ try :
429
+ self .send_command ("PING" , check_health = False )
430
+
431
+ response = str_if_bytes (self ._sock .recv (1024 ))
432
+ if not (response .startswith ("+PONG" ) or response .startswith ("-NOAUTH" )):
433
+ raise ResponseError (f"Invalid PING response: { response } " )
434
+ except (ConnectionResetError , ResponseError ) as err :
435
+ try :
436
+ self ._sock .shutdown (socket .SHUT_RDWR ) # ensure a clean close
437
+ except OSError :
438
+ pass
439
+ self ._sock .close ()
440
+ raise ConnectionError (self ._error_message (err ))
441
+
421
442
@abstractmethod
422
- def _connect (self ):
443
+ def _connect (self ) -> None :
423
444
pass
424
445
425
446
@abstractmethod
@@ -758,7 +779,7 @@ def repr_pieces(self):
758
779
pieces .append (("client_name" , self .client_name ))
759
780
return pieces
760
781
761
- def _connect (self ):
782
+ def _connect (self ) -> None :
762
783
"Create a TCP socket connection"
763
784
# we want to mimic what socket.create_connection does to support
764
785
# ipv4/ipv6, but we want to set options prior to calling
@@ -788,7 +809,8 @@ def _connect(self):
788
809
789
810
# set the socket_timeout now that we're connected
790
811
sock .settimeout (self .socket_timeout )
791
- return sock
812
+ self ._sock = sock
813
+ return
792
814
793
815
except OSError as _ :
794
816
err = _
@@ -1101,15 +1123,15 @@ def __init__(
1101
1123
self .ssl_ciphers = ssl_ciphers
1102
1124
super ().__init__ (** kwargs )
1103
1125
1104
- def _connect (self ):
1126
+ def _connect (self ) -> None :
1105
1127
"""
1106
1128
Wrap the socket with SSL support, handling potential errors.
1107
1129
"""
1108
- sock = super ()._connect ()
1130
+ super ()._connect ()
1109
1131
try :
1110
- return self ._wrap_socket_with_ssl (sock )
1132
+ self . _sock = self ._wrap_socket_with_ssl (self . _sock )
1111
1133
except (OSError , RedisError ):
1112
- sock .close ()
1134
+ self . _sock .close ()
1113
1135
raise
1114
1136
1115
1137
def _wrap_socket_with_ssl (self , sock ):
@@ -1206,7 +1228,7 @@ def repr_pieces(self):
1206
1228
pieces .append (("client_name" , self .client_name ))
1207
1229
return pieces
1208
1230
1209
- def _connect (self ):
1231
+ def _connect (self ) -> None :
1210
1232
"Create a Unix domain socket connection"
1211
1233
sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
1212
1234
sock .settimeout (self .socket_connect_timeout )
@@ -1221,7 +1243,7 @@ def _connect(self):
1221
1243
sock .close ()
1222
1244
raise
1223
1245
sock .settimeout (self .socket_timeout )
1224
- return sock
1246
+ self . _sock = sock
1225
1247
1226
1248
def _host_error (self ):
1227
1249
return self .path
0 commit comments