diff --git a/valkey/_parsers/libvalkey.py b/valkey/_parsers/libvalkey.py index bf91c82c..698bef0d 100644 --- a/valkey/_parsers/libvalkey.py +++ b/valkey/_parsers/libvalkey.py @@ -15,7 +15,6 @@ from .socket import ( NONBLOCKING_EXCEPTION_ERROR_NUMBERS, NONBLOCKING_EXCEPTIONS, - SENTINEL, SERVER_CLOSED_CONNECTION_ERROR, ) @@ -80,9 +79,11 @@ def can_read(self, timeout): return self.read_from_socket(timeout=timeout, raise_on_timeout=False) return True - def read_from_socket(self, timeout=SENTINEL, raise_on_timeout=True): + def read_from_socket( + self, timeout: Optional[float] = None, raise_on_timeout: bool = True + ): sock = self._sock - custom_timeout = timeout is not SENTINEL + custom_timeout = timeout is not None try: if custom_timeout: sock.settimeout(timeout) diff --git a/valkey/_parsers/socket.py b/valkey/_parsers/socket.py index 8147243b..d9fbacf2 100644 --- a/valkey/_parsers/socket.py +++ b/valkey/_parsers/socket.py @@ -2,12 +2,14 @@ import io import socket from io import SEEK_END -from typing import Optional, Union +from typing import Optional from ..exceptions import ConnectionError, TimeoutError from ..utils import SSL_AVAILABLE -NONBLOCKING_EXCEPTION_ERROR_NUMBERS = {BlockingIOError: errno.EWOULDBLOCK} +NONBLOCKING_EXCEPTION_ERROR_NUMBERS: dict[type[OSError], int] = { + BlockingIOError: errno.EWOULDBLOCK +} if SSL_AVAILABLE: import ssl @@ -21,19 +23,19 @@ NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." -SENTINEL = object() SYM_CRLF = b"\r\n" class SocketBuffer: def __init__( - self, socket: socket.socket, socket_read_size: int, socket_timeout: float + self, sock: socket.socket, socket_read_size: int, socket_timeout: float ): - self._sock = socket + self._sock = sock self.socket_read_size = socket_read_size self.socket_timeout = socket_timeout self._buffer = io.BytesIO() + self._closed = False def unread_bytes(self) -> int: """ @@ -47,13 +49,13 @@ def unread_bytes(self) -> int: def _read_from_socket( self, length: Optional[int] = None, - timeout: Union[float, object] = SENTINEL, + timeout: Optional[float] = None, raise_on_timeout: Optional[bool] = True, ) -> bool: sock = self._sock socket_read_size = self.socket_read_size marker = 0 - custom_timeout = timeout is not SENTINEL + custom_timeout = timeout is not None buf = self._buffer current_pos = buf.tell() @@ -92,11 +94,14 @@ def _read_from_socket( sock.settimeout(self.socket_timeout) def can_read(self, timeout: float) -> bool: - return bool(self.unread_bytes()) or self._read_from_socket( - timeout=timeout, raise_on_timeout=False + return not self._closed and ( + bool(self.unread_bytes()) + or self._read_from_socket(timeout=timeout, raise_on_timeout=False) ) def read(self, length: int) -> bytes: + if self._closed: + raise ConnectionError("Socket is closed") length = length + 2 # make sure to read the \r\n terminator # BufferIO will return less than requested if buffer is short data = self._buffer.read(length) @@ -108,6 +113,8 @@ def read(self, length: int) -> bytes: return data[:-2] def readline(self) -> bytes: + if self._closed: + raise ConnectionError("Socket is closed") buf = self._buffer data = buf.readline() while not data.endswith(SYM_CRLF): @@ -151,6 +158,7 @@ def purge(self) -> None: def close(self) -> None: try: self._buffer.close() + self._sock.close() except Exception: # issue #633 suggests the purge/close somehow raised a # BadFileDescriptor error. Perhaps the client ran out of @@ -158,5 +166,4 @@ def close(self) -> None: # any error being raised from purge/close since we're # removing the reference to the instance below. pass - self._buffer = None - self._sock = None + self._closed = True