Skip to content

Commit 119ec97

Browse files
Add Redis readiness verification (#3555)
1 parent 03f4125 commit 119ec97

File tree

11 files changed

+476
-127
lines changed

11 files changed

+476
-127
lines changed

redis/asyncio/client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def __init__(
229229
encoding: str = "utf-8",
230230
encoding_errors: str = "strict",
231231
decode_responses: bool = False,
232+
check_server_ready: bool = False,
232233
retry_on_timeout: bool = False,
233234
retry: Retry = Retry(
234235
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
@@ -276,6 +277,10 @@ def __init__(
276277
277278
When 'connection_pool' is provided - the retry configuration of the
278279
provided pool will be used.
280+
281+
Args:
282+
check_server_ready: if `True`, an extra handshake is performed by sending a PING command, since
283+
connect and send operations work even when Redis server is not ready.
279284
"""
280285
kwargs: Dict[str, Any]
281286
if event_dispatcher is None:
@@ -310,6 +315,7 @@ def __init__(
310315
"encoding": encoding,
311316
"encoding_errors": encoding_errors,
312317
"decode_responses": decode_responses,
318+
"check_server_ready": check_server_ready,
313319
"retry_on_error": retry_on_error,
314320
"retry": copy.deepcopy(retry),
315321
"max_connections": max_connections,

redis/asyncio/cluster.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ def __init__(
289289
encoding_errors: str = "strict",
290290
decode_responses: bool = False,
291291
# Connection related kwargs
292+
check_server_ready: bool = False,
292293
health_check_interval: float = 0,
293294
socket_connect_timeout: Optional[float] = None,
294295
socket_keepalive: bool = False,
@@ -342,6 +343,7 @@ def __init__(
342343
"encoding_errors": encoding_errors,
343344
"decode_responses": decode_responses,
344345
# Connection related kwargs
346+
"check_server_ready": check_server_ready,
345347
"health_check_interval": health_check_interval,
346348
"socket_connect_timeout": socket_connect_timeout,
347349
"socket_keepalive": socket_keepalive,

redis/asyncio/connection.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def __init__(
148148
encoding_errors: str = "strict",
149149
decode_responses: bool = False,
150150
parser_class: Type[BaseParser] = DefaultParser,
151+
check_server_ready: bool = False,
151152
socket_read_size: int = 65536,
152153
health_check_interval: float = 0,
153154
client_name: Optional[str] = None,
@@ -204,6 +205,7 @@ def __init__(
204205
self.health_check_interval = health_check_interval
205206
self.next_health_check: float = -1
206207
self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
208+
self.check_server_ready = check_server_ready
207209
self.redis_connect_func = redis_connect_func
208210
self._reader: Optional[asyncio.StreamReader] = None
209211
self._writer: Optional[asyncio.StreamWriter] = None
@@ -303,11 +305,13 @@ async def connect_check_health(
303305
try:
304306
if retry_socket_connect:
305307
await self.retry.call_with_retry(
306-
lambda: self._connect(), lambda error: self.disconnect()
308+
lambda: self._connect_check_server_ready(),
309+
lambda error: self.disconnect(),
307310
)
308311
else:
309-
await self._connect()
312+
await self._connect_check_server_ready()
310313
except asyncio.CancelledError:
314+
self._close()
311315
raise # in 3.7 and earlier, this is an Exception, not BaseException
312316
except (socket.timeout, asyncio.TimeoutError):
313317
raise TimeoutError("Timeout connecting to server")
@@ -342,6 +346,33 @@ async def connect_check_health(
342346
if task and inspect.isawaitable(task):
343347
await task
344348

349+
async def _connect_check_server_ready(self):
350+
await self._connect()
351+
352+
# Doing handshake since connect and send operations work even when Redis is not ready
353+
if self.check_server_ready:
354+
try:
355+
await self.send_command("PING", check_health=False)
356+
357+
if self.socket_timeout is not None:
358+
async with async_timeout(self.socket_timeout):
359+
response = str_if_bytes(await self._reader.read(1024))
360+
else:
361+
response = str_if_bytes(await self._reader.read(1024))
362+
363+
if not (response.startswith("+PONG") or response.startswith("-NOAUTH")):
364+
raise ResponseError(f"Invalid PING response: {response}")
365+
except (
366+
socket.timeout,
367+
asyncio.TimeoutError,
368+
ResponseError,
369+
ConnectionResetError,
370+
) as e:
371+
# `socket_keepalive_options` might contain invalid options
372+
# causing an error. Do not leave the connection open.
373+
self._close()
374+
raise ConnectionError(self._error_message(e))
375+
345376
@abstractmethod
346377
async def _connect(self):
347378
pass
@@ -531,8 +562,7 @@ async def send_packed_command(
531562
self._send_packed_command(command), self.socket_timeout
532563
)
533564
else:
534-
self._writer.writelines(command)
535-
await self._writer.drain()
565+
await self._send_packed_command(command)
536566
except asyncio.TimeoutError:
537567
await self.disconnect(nowait=True)
538568
raise TimeoutError("Timeout writing to socket") from None
@@ -775,7 +805,7 @@ async def _connect(self):
775805
except (OSError, TypeError):
776806
# `socket_keepalive_options` might contain invalid options
777807
# causing an error. Do not leave the connection open.
778-
writer.close()
808+
self._close()
779809
raise
780810

781811
def _host_error(self) -> str:
@@ -936,7 +966,6 @@ async def _connect(self):
936966
reader, writer = await asyncio.open_unix_connection(path=self.path)
937967
self._reader = reader
938968
self._writer = writer
939-
await self.on_connect()
940969

941970
def _host_error(self) -> str:
942971
return self.path

redis/client.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def __init__(
215215
encoding: str = "utf-8",
216216
encoding_errors: str = "strict",
217217
decode_responses: bool = False,
218+
check_server_ready: bool = False,
218219
retry_on_timeout: bool = False,
219220
retry: Retry = Retry(
220221
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
@@ -272,10 +273,11 @@ def __init__(
272273
provided pool will be used.
273274
274275
Args:
275-
276-
single_connection_client:
277-
if `True`, connection pool is not used. In that case `Redis`
278-
instance use is not thread safe.
276+
check_server_ready: if `True`, an extra handshake is performed by sending a PING command, since
277+
connect and send operations work even when Redis server is not ready.
278+
single_connection_client:
279+
if `True`, connection pool is not used. In that case `Redis`
280+
instance use is not thread safe.
279281
"""
280282
if event_dispatcher is None:
281283
self._event_dispatcher = EventDispatcher()
@@ -292,6 +294,7 @@ def __init__(
292294
"encoding": encoding,
293295
"encoding_errors": encoding_errors,
294296
"decode_responses": decode_responses,
297+
"check_server_ready": check_server_ready,
295298
"retry_on_error": retry_on_error,
296299
"retry": copy.deepcopy(retry),
297300
"max_connections": max_connections,

redis/connection.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ def __init__(
332332
encoding: str = "utf-8",
333333
encoding_errors: str = "strict",
334334
decode_responses: bool = False,
335+
check_server_ready: bool = False,
335336
parser_class=DefaultParser,
336337
socket_read_size: int = 65536,
337338
health_check_interval: int = 0,
@@ -408,6 +409,7 @@ def __init__(
408409
self.redis_connect_func = redis_connect_func
409410
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
410411
self.handshake_metadata = None
412+
self.check_server_ready = check_server_ready
411413
self._sock = None
412414
self._socket_read_size = socket_read_size
413415
self._connect_callbacks = []
@@ -571,17 +573,17 @@ def connect_check_health(
571573
return
572574
try:
573575
if retry_socket_connect:
574-
sock = self.retry.call_with_retry(
575-
lambda: self._connect(), lambda error: self.disconnect(error)
576+
self.retry.call_with_retry(
577+
lambda: self._connect_check_server_ready(),
578+
lambda error: self.disconnect(error),
576579
)
577580
else:
578-
sock = self._connect()
581+
self._connect_check_server_ready()
579582
except socket.timeout:
580583
raise TimeoutError("Timeout connecting to server")
581584
except OSError as e:
582585
raise ConnectionError(self._error_message(e))
583586

584-
self._sock = sock
585587
try:
586588
if self.redis_connect_func is None:
587589
# Use the default on_connect function
@@ -603,8 +605,27 @@ def connect_check_health(
603605
if callback:
604606
callback(self)
605607

608+
def _connect_check_server_ready(self):
609+
self._connect()
610+
611+
# Doing handshake since connect and send operations work even when Redis is not ready
612+
if self.check_server_ready:
613+
try:
614+
self.send_command("PING", check_health=False)
615+
616+
response = str_if_bytes(self._sock.recv(1024))
617+
if not (response.startswith("+PONG") or response.startswith("-NOAUTH")):
618+
raise ResponseError(f"Invalid PING response: {response}")
619+
except (ConnectionResetError, ResponseError) as err:
620+
try:
621+
self._sock.shutdown(socket.SHUT_RDWR) # ensure a clean close
622+
except OSError:
623+
pass
624+
self._sock.close()
625+
raise ConnectionError(self._error_message(err))
626+
606627
@abstractmethod
607-
def _connect(self):
628+
def _connect(self) -> None:
608629
pass
609630

610631
@abstractmethod
@@ -1083,7 +1104,7 @@ def repr_pieces(self):
10831104
pieces.append(("client_name", self.client_name))
10841105
return pieces
10851106

1086-
def _connect(self):
1107+
def _connect(self) -> None:
10871108
"Create a TCP socket connection"
10881109
# we want to mimic what socket.create_connection does to support
10891110
# ipv4/ipv6, but we want to set options prior to calling
@@ -1114,7 +1135,8 @@ def _connect(self):
11141135

11151136
# set the socket_timeout now that we're connected
11161137
sock.settimeout(self.socket_timeout)
1117-
return sock
1138+
self._sock = sock
1139+
return
11181140

11191141
except OSError as _:
11201142
err = _
@@ -1427,15 +1449,15 @@ def __init__(
14271449
self.ssl_ciphers = ssl_ciphers
14281450
super().__init__(**kwargs)
14291451

1430-
def _connect(self):
1452+
def _connect(self) -> None:
14311453
"""
14321454
Wrap the socket with SSL support, handling potential errors.
14331455
"""
1434-
sock = super()._connect()
1456+
super()._connect()
14351457
try:
1436-
return self._wrap_socket_with_ssl(sock)
1458+
self._sock = self._wrap_socket_with_ssl(self._sock)
14371459
except (OSError, RedisError):
1438-
sock.close()
1460+
self._sock.close()
14391461
raise
14401462

14411463
def _wrap_socket_with_ssl(self, sock):
@@ -1532,7 +1554,7 @@ def repr_pieces(self):
15321554
pieces.append(("client_name", self.client_name))
15331555
return pieces
15341556

1535-
def _connect(self):
1557+
def _connect(self) -> None:
15361558
"Create a Unix domain socket connection"
15371559
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
15381560
sock.settimeout(self.socket_connect_timeout)
@@ -1547,7 +1569,7 @@ def _connect(self):
15471569
sock.close()
15481570
raise
15491571
sock.settimeout(self.socket_timeout)
1550-
return sock
1572+
self._sock = sock
15511573

15521574
def _host_error(self):
15531575
return self.path

tests/test_asyncio/test_cluster.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ async def test_reading_with_load_balancing_strategies(
729729
Connection,
730730
send_command=mock.DEFAULT,
731731
read_response=mock.DEFAULT,
732-
_connect=mock.DEFAULT,
732+
_connect_check_server_ready=mock.DEFAULT,
733733
can_read_destructive=mock.DEFAULT,
734734
on_connect=mock.DEFAULT,
735735
) as mocks:
@@ -761,7 +761,7 @@ def execute_command_mock_third(self, *args, **options):
761761
execute_command.side_effect = execute_command_mock_first
762762
mocks["send_command"].return_value = True
763763
mocks["read_response"].return_value = "OK"
764-
mocks["_connect"].return_value = True
764+
mocks["_connect_check_server_ready"].return_value = True
765765
mocks["can_read_destructive"].return_value = False
766766
mocks["on_connect"].return_value = True
767767

@@ -3117,13 +3117,19 @@ async def execute_command(self, *args, **kwargs):
31173117

31183118
return _create_client
31193119

3120+
@pytest.mark.parametrize("check_server_ready", [True, False])
31203121
async def test_ssl_connection_without_ssl(
3121-
self, create_client: Callable[..., Awaitable[RedisCluster]]
3122+
self, create_client: Callable[..., Awaitable[RedisCluster]], check_server_ready
31223123
) -> None:
31233124
with pytest.raises(RedisClusterException) as e:
3124-
await create_client(mocked=False, ssl=False)
3125+
await create_client(
3126+
mocked=False, ssl=False, check_server_ready=check_server_ready
3127+
)
31253128
e = e.value.__cause__
3126-
assert "Connection closed by server" in str(e)
3129+
if check_server_ready:
3130+
assert "Invalid PING response" in str(e)
3131+
else:
3132+
assert "Connection closed by server" in str(e)
31273133

31283134
async def test_ssl_with_invalid_cert(
31293135
self, create_client: Callable[..., Awaitable[RedisCluster]]

0 commit comments

Comments
 (0)