diff --git a/redis/cluster.py b/redis/cluster.py index 2fd4625e6b..dc91209ed2 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -691,6 +691,7 @@ def __init__( self._event_dispatcher = EventDispatcher() else: self._event_dispatcher = event_dispatcher + self.startup_nodes = startup_nodes self.nodes_manager = NodesManager( startup_nodes=startup_nodes, from_url=from_url, diff --git a/redis/http/__init__.py b/redis/http/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/http/http_client.py b/redis/http/http_client.py new file mode 100644 index 0000000000..0a2de2e44c --- /dev/null +++ b/redis/http/http_client.py @@ -0,0 +1,412 @@ +from __future__ import annotations + +import base64 +import json +import ssl +import gzip +import zlib +from dataclasses import dataclass +from typing import Any, Dict, Mapping, Optional, Tuple, Union +from urllib.parse import urlencode, urljoin +from urllib.request import Request, urlopen +from urllib.error import URLError, HTTPError + + +__all__ = [ + "HttpClient", + "HttpResponse", + "HttpError", + "DEFAULT_TIMEOUT" +] + +from redis.backoff import ExponentialWithJitterBackoff +from redis.retry import Retry +from redis.utils import dummy_fail + +DEFAULT_USER_AGENT = "HttpClient/1.0 (+https://example.invalid)" +DEFAULT_TIMEOUT = 30.0 +RETRY_STATUS_CODES = {429, 500, 502, 503, 504} + + +@dataclass +class HttpResponse: + status: int + headers: Dict[str, str] + url: str + content: bytes + + def text(self, encoding: Optional[str] = None) -> str: + enc = encoding or self._get_encoding() + return self.content.decode(enc, errors="replace") + + def json(self) -> Any: + return json.loads(self.text(encoding=self._get_encoding())) + + def _get_encoding(self) -> str: + # Try to infer encoding from headers; default to utf-8 + ctype = self.headers.get("content-type", "") + # Example: application/json; charset=utf-8 + for part in ctype.split(";"): + p = part.strip() + if p.lower().startswith("charset="): + return p.split("=", 1)[1].strip() or "utf-8" + return "utf-8" + + +class HttpError(Exception): + def __init__(self, status: int, url: str, message: Optional[str] = None): + self.status = status + self.url = url + self.message = message or f"HTTP {status} for {url}" + super().__init__(self.message) + + +class HttpClient: + """ + A lightweight HTTP client for REST API calls. + """ + def __init__( + self, + base_url: str = "", + *, + headers: Optional[Mapping[str, str]] = None, + timeout: float = DEFAULT_TIMEOUT, + retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ), + verify_tls: bool = True, + # TLS verification (server) options + ca_file: Optional[str] = None, + ca_path: Optional[str] = None, + ca_data: Optional[Union[str, bytes]] = None, + # Mutual TLS (client cert) options + client_cert_file: Optional[str] = None, + client_key_file: Optional[str] = None, + client_key_password: Optional[str] = None, + auth_basic: Optional[Tuple[str, str]] = None, # (username, password) + user_agent: str = DEFAULT_USER_AGENT, + ) -> None: + """ + Initialize a new HTTP client instance. + + Args: + base_url: Base URL for all requests. Will be prefixed to all paths. + headers: Default headers to include in all requests. + timeout: Default timeout in seconds for requests. + retry: Retry configuration for failed requests. + verify_tls: Whether to verify TLS certificates. + ca_file: Path to CA certificate file for TLS verification. + ca_path: Path to a directory containing CA certificates. + ca_data: CA certificate data as string or bytes. + client_cert_file: Path to client certificate for mutual TLS. + client_key_file: Path to a client private key for mutual TLS. + client_key_password: Password for an encrypted client private key. + auth_basic: Tuple of (username, password) for HTTP basic auth. + user_agent: User-Agent header value for requests. + + The client supports both regular HTTPS with server verification and mutual TLS + authentication. For server verification, provide CA certificate information via + ca_file, ca_path or ca_data. For mutual TLS, additionally provide a client + certificate and key via client_cert_file and client_key_file. + """ + self.base_url = base_url.rstrip() + "/" if base_url and not base_url.endswith("/") else base_url + self._default_headers = {k.lower(): v for k, v in (headers or {}).items()} + self.timeout = timeout + self.retry = retry + self.retry.update_supported_errors((HTTPError, URLError, ssl.SSLError)) + self.verify_tls = verify_tls + + # TLS settings + self.ca_file = ca_file + self.ca_path = ca_path + self.ca_data = ca_data + self.client_cert_file = client_cert_file + self.client_key_file = client_key_file + self.client_key_password = client_key_password + + self.auth_basic = auth_basic + self.user_agent = user_agent + + # Public JSON-centric helpers + def get( + self, + path: str, + *, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + return self._json_call( + "GET", + path, + params=params, + headers=headers, + timeout=timeout, + body=None, + expect_json=expect_json + ) + + def delete( + self, + path: str, + *, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + return self._json_call( + "DELETE", + path, + params=params, + headers=headers, + timeout=timeout, + body=None, + expect_json=expect_json + ) + + def post( + self, + path: str, + *, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + return self._json_call( + "POST", + path, + params=params, + headers=headers, + timeout=timeout, + body=self._prepare_body(json_body=json_body, data=data), + expect_json=expect_json + ) + + def put( + self, + path: str, + *, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + return self._json_call( + "PUT", + path, + params=params, + headers=headers, + timeout=timeout, + body=self._prepare_body(json_body=json_body, data=data), + expect_json=expect_json + ) + + def patch( + self, + path: str, + *, + json_body: Optional[Any] = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + return self._json_call( + "PATCH", + path, + params=params, + headers=headers, + timeout=timeout, + body=self._prepare_body(json_body=json_body, data=data), + expect_json=expect_json + ) + + # Low-level request + def request( + self, + method: str, + path: str, + *, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + body: Optional[Union[bytes, str]] = None, + timeout: Optional[float] = None, + ) -> HttpResponse: + url = self._build_url(path, params) + all_headers = self._prepare_headers(headers, body) + data = body.encode("utf-8") if isinstance(body, str) else body + + req = Request(url=url, method=method.upper(), data=data, headers=all_headers) + + context: Optional[ssl.SSLContext] = None + if url.lower().startswith("https"): + if self.verify_tls: + # Use provided CA material if any; fall back to system defaults + context = ssl.create_default_context( + cafile=self.ca_file, + capath=self.ca_path, + cadata=self.ca_data, + ) + # Load client certificate for mTLS if configured + if self.client_cert_file: + context.load_cert_chain( + certfile=self.client_cert_file, + keyfile=self.client_key_file, + password=self.client_key_password, + ) + else: + # Verification disabled + context = ssl.create_default_context() + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + try: + return self.retry.call_with_retry( + lambda: self._make_request(req, context=context, timeout=timeout), + lambda _: dummy_fail(), + lambda error: self._is_retryable_http_error(error), + ) + except HTTPError as e: + # Read error body, build response, and decide on retry + err_body = b"" + try: + err_body = e.read() + except Exception: + pass + headers_map = {k.lower(): v for k, v in (e.headers or {}).items()} + err_body = self._maybe_decompress(err_body, headers_map) + status = getattr(e, "code", 0) or 0 + response = HttpResponse( + status=status, + headers=headers_map, + url=url, + content=err_body, + ) + return response + + def _make_request( + self, + request: Request, + context: Optional[ssl.SSLContext] = None, + timeout: Optional[float] = None, + ): + with urlopen(request, timeout=timeout or self.timeout, context=context) as resp: + raw = resp.read() + headers_map = {k.lower(): v for k, v in resp.headers.items()} + raw = self._maybe_decompress(raw, headers_map) + return HttpResponse( + status=resp.status, + headers=headers_map, + url=resp.geturl(), + content=raw, + ) + + def _is_retryable_http_error(self, error: Exception) -> bool: + if isinstance(error, HTTPError): + return self._should_retry_status(error.code) + return False + + # Internal utilities + def _json_call( + self, + method: str, + path: str, + *, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + body: Optional[Union[bytes, str]] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + resp = self.request( + method=method, + path=path, + params=params, + headers=headers, + body=body, + timeout=timeout, + ) + if not (200 <= resp.status < 400): + raise HttpError(resp.status, resp.url, resp.text()) + if expect_json: + return resp.json() + return resp + + def _prepare_body(self, *, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None) -> Optional[Union[bytes, str]]: + if json_body is not None and data is not None: + raise ValueError("Provide either json_body or data, not both.") + if json_body is not None: + return json.dumps(json_body, ensure_ascii=False, separators=(",", ":")) + return data + + def _build_url( + self, + path: str, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, + ) -> str: + url = urljoin(self.base_url or "", path) + if params: + # urlencode with doseq=True supports list/tuple values + query = urlencode({k: v for k, v in params.items() if v is not None}, doseq=True) + separator = "&" if ("?" in url) else "?" + url = f"{url}{separator}{query}" if query else url + return url + + def _prepare_headers(self, headers: Optional[Mapping[str, str]], body: Optional[Union[bytes, str]]) -> Dict[str, str]: + # Start with defaults + prepared: Dict[str, str] = {} + prepared.update(self._default_headers) + + # Standard defaults for JSON REST usage + prepared.setdefault("accept", "application/json") + prepared.setdefault("user-agent", self.user_agent) + # We will send gzip accept-encoding; handle decompression manually + prepared.setdefault("accept-encoding", "gzip, deflate") + + # If we have a string body and content-type not specified, assume JSON + if body is not None and isinstance(body, str): + prepared.setdefault("content-type", "application/json; charset=utf-8") + + # Basic authentication if provided and not overridden + if self.auth_basic and "authorization" not in prepared: + user, pwd = self.auth_basic + token = base64.b64encode(f"{user}:{pwd}".encode("utf-8")).decode("ascii") + prepared["authorization"] = f"Basic {token}" + + # Merge per-call headers (case-insensitive) + if headers: + for k, v in headers.items(): + prepared[k.lower()] = v + + # urllib expects header keys in canonical capitalization sometimes; but it’s tolerant. + # We'll return as provided; urllib will handle it. + return prepared + + def _should_retry_status(self, status: int) -> bool: + return status in RETRY_STATUS_CODES + + def _maybe_decompress(self, content: bytes, headers: Mapping[str, str]) -> bytes: + if not content: + return content + encoding = (headers.get("content-encoding") or "").lower() + try: + if "gzip" in encoding: + return gzip.decompress(content) + if "deflate" in encoding: + # Try raw deflate, then zlib-wrapped + try: + return zlib.decompress(content, -zlib.MAX_WBITS) + except zlib.error: + return zlib.decompress(content) + except Exception: + # If decompression fails, return original bytes + return content + return content \ No newline at end of file diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 2f87024f20..56342a7a53 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -230,7 +230,7 @@ def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Excep database.circuit.state = CBState.OPEN elif is_healthy and database.circuit.state != CBState.CLOSED: database.circuit.state = CBState.CLOSED - except (ConnectionError, TimeoutError, socket.timeout, ConnectionRefusedError, ValueError) as e: + except Exception as e: if database.circuit.state != CBState.OPEN: database.circuit.state = CBState.OPEN is_healthy = False diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 4bacc2c680..5555baec44 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -12,14 +12,13 @@ from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter from redis.multidb.database import Database, Databases from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector -from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ + DEFAULT_HEALTH_CHECK_BACKOFF from redis.multidb.failover import FailoverStrategy, WeightBasedFailoverStrategy from redis.retry import Retry DEFAULT_GRACE_PERIOD = 5.0 DEFAULT_HEALTH_CHECK_INTERVAL = 5 -DEFAULT_HEALTH_CHECK_RETRIES = 3 -DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) DEFAULT_FAILURES_THRESHOLD = 3 DEFAULT_FAILURES_DURATION = 2 DEFAULT_FAILOVER_RETRIES = 3 @@ -31,12 +30,36 @@ def default_event_dispatcher() -> EventDispatcherInterface: @dataclass class DatabaseConfig: + """ + Dataclass representing the configuration for a database connection. + + This class is used to store configuration settings for a database connection, + including client options, connection sourcing details, circuit breaker settings, + and cluster-specific properties. It provides a structure for defining these + attributes and allows for the creation of customized configurations for various + database setups. + + Attributes: + weight (float): Weight of the database to define the active one. + client_kwargs (dict): Additional parameters for the database client connection. + from_url (Optional[str]): Redis URL way of connecting to the database. + from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. + circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation. + grace_period (float): Grace period after which we need to check if the circuit could be closed again. + health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used + on public Redis Enterprise endpoints. + + Methods: + default_circuit_breaker: + Generates and returns a default CircuitBreaker instance adapted for use. + """ weight: float = 1.0 client_kwargs: dict = field(default_factory=dict) from_url: Optional[str] = None from_pool: Optional[ConnectionPool] = None circuit: Optional[CircuitBreaker] = None grace_period: float = DEFAULT_GRACE_PERIOD + health_check_url: Optional[str] = None def default_circuit_breaker(self) -> CircuitBreaker: circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) @@ -118,7 +141,12 @@ def databases(self) -> Databases: circuit = database_config.default_circuit_breaker() \ if database_config.circuit is None else database_config.circuit databases.add( - Database(client=client, circuit=circuit, weight=database_config.weight), + Database( + client=client, + circuit=circuit, + weight=database_config.weight, + health_check_url=database_config.health_check_url + ), database_config.weight ) diff --git a/redis/multidb/database.py b/redis/multidb/database.py index 3253ffa093..b03e77bd70 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -1,7 +1,7 @@ import redis from abc import ABC, abstractmethod from enum import Enum -from typing import Union +from typing import Union, Optional from redis import RedisCluster from redis.data_structure import WeightedList @@ -45,6 +45,18 @@ def circuit(self, circuit: CircuitBreaker): """Set the circuit breaker for the current database.""" pass + @property + @abstractmethod + def health_check_url(self) -> Optional[str]: + """Health check URL associated with the current database.""" + pass + + @health_check_url.setter + @abstractmethod + def health_check_url(self, health_check_url: Optional[str]): + """Set the health check URL associated with the current database.""" + pass + Databases = WeightedList[tuple[AbstractDatabase, Number]] class Database(AbstractDatabase): @@ -52,7 +64,8 @@ def __init__( self, client: Union[redis.Redis, RedisCluster], circuit: CircuitBreaker, - weight: float + weight: float, + health_check_url: Optional[str] = None, ): """ Initialize a new Database instance. @@ -61,12 +74,13 @@ def __init__( client: Underlying Redis client instance for database operations circuit: Circuit breaker for handling database failures weight: Weight value used for database failover prioritization - state: Initial database state, defaults to DISCONNECTED + health_check_url: Health check URL associated with the current database """ self._client = client self._cb = circuit self._cb.database = self self._weight = weight + self._health_check_url = health_check_url @property def client(self) -> Union[redis.Redis, RedisCluster]: @@ -91,3 +105,11 @@ def circuit(self) -> CircuitBreaker: @circuit.setter def circuit(self, circuit: CircuitBreaker): self._cb = circuit + + @property + def health_check_url(self) -> Optional[str]: + return self._health_check_url + + @health_check_url.setter + def health_check_url(self, health_check_url: Optional[str]): + self._health_check_url = health_check_url diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py index 541f3413dc..d6cf198678 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -6,6 +6,7 @@ from redis.multidb.circuit import State as CBState from redis.multidb.exception import NoValidDatabaseException from redis.retry import Retry +from redis.utils import dummy_fail class FailoverStrategy(ABC): @@ -37,7 +38,7 @@ def __init__( def database(self) -> AbstractDatabase: return self._retry.call_with_retry( lambda: self._get_active_database(), - lambda _: self._dummy_fail() + lambda _: dummy_fail() ) def set_databases(self, databases: Databases) -> None: @@ -49,6 +50,3 @@ def _get_active_database(self) -> AbstractDatabase: return database raise NoValidDatabaseException('No valid database available for communication') - - def _dummy_fail(self): - pass diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index cca220dc3f..63ba415334 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -1,9 +1,17 @@ +import logging from abc import abstractmethod, ABC +from typing import Optional, Tuple, Union -import redis from redis import Redis +from redis.backoff import ExponentialWithJitterBackoff +from redis.http.http_client import DEFAULT_TIMEOUT, HttpClient from redis.retry import Retry +from redis.utils import dummy_fail +DEFAULT_HEALTH_CHECK_RETRIES = 3 +DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) + +logger = logging.getLogger(__name__) class HealthCheck(ABC): @@ -21,7 +29,7 @@ def check_health(self, database) -> bool: class AbstractHealthCheck(HealthCheck): def __init__( self, - retry: Retry, + retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) ) -> None: self._retry = retry self._retry.update_supported_errors([ConnectionRefusedError]) @@ -37,8 +45,8 @@ def check_health(self, database) -> bool: class EchoHealthCheck(AbstractHealthCheck): def __init__( - self, - retry: Retry, + self, + retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) ) -> None: """ Check database healthiness by sending an echo request. @@ -49,7 +57,7 @@ def __init__( def check_health(self, database) -> bool: return self._retry.call_with_retry( lambda: self._returns_echoed_message(database), - lambda _: self._dummy_fail() + lambda _: dummy_fail() ) def _returns_echoed_message(self, database) -> bool: @@ -69,5 +77,94 @@ def _returns_echoed_message(self, database) -> bool: return True - def _dummy_fail(self): - pass \ No newline at end of file +class LagAwareHealthCheck(AbstractHealthCheck): + """ + Health check available for Redis Enterprise deployments. + Verify via REST API that the database is healthy based on different lags. + """ + def __init__( + self, + retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF), + rest_api_port: int = 9443, + timeout: float = DEFAULT_TIMEOUT, + auth_basic: Optional[Tuple[str, str]] = None, + verify_tls: bool = True, + # TLS verification (server) options + ca_file: Optional[str] = None, + ca_path: Optional[str] = None, + ca_data: Optional[Union[str, bytes]] = None, + # Mutual TLS (client cert) options + client_cert_file: Optional[str] = None, + client_key_file: Optional[str] = None, + client_key_password: Optional[str] = None, + ): + """ + Initialize LagAwareHealthCheck with the specified parameters. + + Args: + retry: Retry configuration for health checks + rest_api_port: Port number for Redis Enterprise REST API (default: 9443) + timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT) + auth_basic: Tuple of (username, password) for basic authentication + verify_tls: Whether to verify TLS certificates (default: True) + ca_file: Path to CA certificate file for TLS verification + ca_path: Path to CA certificates directory for TLS verification + ca_data: CA certificate data as string or bytes + client_cert_file: Path to client certificate file for mutual TLS + client_key_file: Path to client private key file for mutual TLS + client_key_password: Password for encrypted client private key + """ + super().__init__( + retry=retry, + ) + self._http_client = HttpClient( + timeout=timeout, + auth_basic=auth_basic, + retry=self.retry, + verify_tls=verify_tls, + ca_file=ca_file, + ca_path=ca_path, + ca_data=ca_data, + client_cert_file=client_cert_file, + client_key_file=client_key_file, + client_key_password=client_key_password + ) + self._rest_api_port = rest_api_port + + def check_health(self, database) -> bool: + if database.health_check_url is None: + raise ValueError( + "Database health check url is not set. Please check DatabaseConfig for the current database." + ) + + if isinstance(database.client, Redis): + db_host = database.client.get_connection_kwargs()["host"] + else: + db_host = database.client.startup_nodes[0].host + + base_url = f"{database.health_check_url}:{self._rest_api_port}" + self._http_client.base_url = base_url + + # Find bdb matching to the current database host + matching_bdb = None + for bdb in self._http_client.get("/v1/bdbs"): + for endpoint in bdb["endpoints"]: + if endpoint['dns_name'] == db_host: + matching_bdb = bdb + break + + # In case if the host was set as public IP + for addr in endpoint['addr']: + if addr == db_host: + matching_bdb = bdb + break + + if matching_bdb is None: + logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb") + raise ValueError("Could not find a matching bdb") + + url = f"/v1/local/bdbs/{matching_bdb['uid']}/endpoint/availability" + self._http_client.get(url, expect_json=False) + + # Status checked in an http client, otherwise HttpError will be raised + return True diff --git a/redis/retry.py b/redis/retry.py index c93f34e65f..7989b41742 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -1,6 +1,6 @@ import socket from time import sleep -from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar, Optional from redis.exceptions import ConnectionError, TimeoutError @@ -73,6 +73,7 @@ def call_with_retry( self, do: Callable[[], T], fail: Callable[[Exception], Any], + is_retryable: Optional[Callable[[Exception], bool]] = None ) -> T: """ Execute an operation that might fail and returns its result, or @@ -86,6 +87,8 @@ def call_with_retry( try: return do() except self._supported_errors as error: + if is_retryable and not is_retryable(error): + raise failures += 1 fail(error) if self._retries >= 0 and failures > self._retries: diff --git a/redis/utils.py b/redis/utils.py index 715913e914..94bfab61bb 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -308,3 +308,9 @@ def truncate_text(txt, max_length=100): return textwrap.shorten( text=txt, width=max_length, placeholder="...", break_long_words=True ) + +def dummy_fail(): + """ + Fake function for a Retry object if you don't need to handle each failure. + """ + pass diff --git a/tests/test_http/__init__.py b/tests/test_http/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_http/test_http_client.py b/tests/test_http/test_http_client.py new file mode 100644 index 0000000000..9a6d28ecd4 --- /dev/null +++ b/tests/test_http/test_http_client.py @@ -0,0 +1,324 @@ +import json +import gzip +from io import BytesIO +from typing import Any, Dict +from urllib.error import HTTPError +from urllib.parse import urlparse, parse_qs + +import pytest + +from redis.backoff import ExponentialWithJitterBackoff +from redis.http.http_client import HttpClient, HttpError +from redis.retry import Retry + + +class FakeResponse: + def __init__(self, *, status: int, headers: Dict[str, str], url: str, content: bytes): + self.status = status + self.headers = headers + self._url = url + self._content = content + + def read(self) -> bytes: + return self._content + + def geturl(self) -> str: + return self._url + + # Support context manager used by urlopen + def __enter__(self) -> "FakeResponse": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + +class TestHttpClient: + def test_get_returns_parsed_json_and_uses_timeout(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/items" + params = {"limit": 5, "q": "hello world"} + expected_url = f"{base_url}{path}?limit=5&q=hello+world" + payload: Dict[str, Any] = {"items": [1, 2, 3], "ok": True} + content = json.dumps(payload).encode("utf-8") + + captured_kwargs = {} + + def fake_urlopen(request, *, timeout=None, context=None): + # Capture call details for assertions + captured_kwargs["timeout"] = timeout + captured_kwargs["context"] = context + # Assert the request was constructed correctly + assert getattr(request, "method", "").upper() == "GET" + assert request.full_url == expected_url + # Return a successful response + return FakeResponse( + status=200, + headers={"Content-Type": "application/json; charset=utf-8"}, + url=expected_url, + content=content, + ) + + # Patch the urlopen used inside HttpClient + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + + # Act + result = client.get(path, params=params, timeout=12.34) # default expect_json=True + + # Assert + assert result == payload + assert pytest.approx(captured_kwargs["timeout"], rel=1e-6) == 12.34 + # HTTPS -> a context should be provided (created by ssl.create_default_context) + assert captured_kwargs["context"] is not None + + def test_get_handles_gzip_response(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "gzip-endpoint" + expected_url = f"{base_url}{path}" + payload = {"message": "compressed ok"} + raw = json.dumps(payload).encode("utf-8") + gzipped = gzip.compress(raw) + + def fake_urlopen(request, *, timeout=None, context=None): + # Return gzipped content with appropriate header + return FakeResponse( + status=200, + headers={ + "Content-Type": "application/json; charset=utf-8", + "Content-Encoding": "gzip", + }, + url=expected_url, + content=gzipped, + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + + # Act + result = client.get(path) # expect_json=True by default + + # Assert + assert result == payload + + def test_get_retries_on_retryable_http_errors_and_succeeds(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange: configure limited retries so we can assert attempts + retry_policy = Retry(backoff=ExponentialWithJitterBackoff(base=0, cap=0), + retries=2) # 2 retries -> up to 3 attempts + base_url = "https://api.example.com/" + path = "sometimes-busy" + expected_url = f"{base_url}{path}" + payload = {"ok": True} + success_content = json.dumps(payload).encode("utf-8") + + call_count = {"n": 0} + + def make_http_error(url: str, code: int, body: bytes = b"busy"): + # Provide a file-like object for .read() when HttpClient tries to read error content + fp = BytesIO(body) + return HTTPError(url=url, code=code, msg="Service Unavailable", hdrs={"Content-Type": "text/plain"}, fp=fp) + + def flaky_urlopen(request, *, timeout=None, context=None): + call_count["n"] += 1 + # Fail with a retryable status (503) for the first two calls, then succeed + if call_count["n"] <= 2: + raise make_http_error(expected_url, 503) + return FakeResponse( + status=200, + headers={"Content-Type": "application/json; charset=utf-8"}, + url=expected_url, + content=success_content, + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", flaky_urlopen) + + client = HttpClient(base_url=base_url, retry=retry_policy) + + # Act + result = client.get(path) + + # Assert: should have retried twice (total 3 attempts) and finally returned parsed JSON + assert result == payload + assert call_count["n"] == retry_policy.get_retries() + 1 + + def test_post_sends_json_body_and_parses_response(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/create" + expected_url = f"{base_url}{path}" + send_payload = {"a": 1, "b": "x"} + recv_payload = {"id": 10, "ok": True} + recv_content = json.dumps(recv_payload, separators=(",", ":")).encode("utf-8") + + def fake_urlopen(request, *, timeout=None, context=None): + # Verify method, URL and headers + assert getattr(request, "method", "").upper() == "POST" + assert request.full_url == expected_url + # Content-Type should be auto-set for string JSON body + assert request.headers.get("Content-type") == "application/json; charset=utf-8" + # Body should be already UTF-8 encoded JSON with no spaces + assert request.data == json.dumps(send_payload, ensure_ascii=False, separators=(",", ":")).encode("utf-8") + return FakeResponse( + status=200, + headers={"Content-Type": "application/json; charset=utf-8"}, + url=expected_url, + content=recv_content, + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + + # Act + result = client.post(path, json_body=send_payload) + + # Assert + assert result == recv_payload + + def test_post_with_raw_data_and_custom_headers(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "upload" + expected_url = f"{base_url}{path}" + raw_data = b"\x00\x01BINARY" + custom_headers = {"Content-type": "application/octet-stream", "X-extra": "1"} + recv_payload = {"status": "ok"} + + def fake_urlopen(request, *, timeout=None, context=None): + assert getattr(request, "method", "").upper() == "POST" + assert request.full_url == expected_url + # Ensure our provided headers are present + assert request.headers.get("Content-type") == "application/octet-stream" + assert request.headers.get("X-extra") == "1" + assert request.data == raw_data + return FakeResponse( + status=200, + headers={"Content-Type": "application/json"}, + url=expected_url, + content=json.dumps(recv_payload).encode("utf-8"), + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + # Act + result = client.post(path, data=raw_data, headers=custom_headers) + + # Assert + assert result == recv_payload + + def test_delete_returns_http_response_when_expect_json_false(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/resource/42" + expected_url = f"{base_url}{path}" + body = b"deleted" + + def fake_urlopen(request, *, timeout=None, context=None): + assert getattr(request, "method", "").upper() == "DELETE" + assert request.full_url == expected_url + return FakeResponse( + status=204, + headers={"Content-Type": "text/plain"}, + url=expected_url, + content=body, + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + client = HttpClient(base_url=base_url) + + # Act + resp = client.delete(path, expect_json=False) + + # Assert + assert resp.status == 204 + assert resp.url == expected_url + assert resp.content == body + + def test_put_raises_http_error_on_non_success(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/update/1" + expected_url = f"{base_url}{path}" + + def make_http_error(url: str, code: int, body: bytes = b"not found"): + fp = BytesIO(body) + return HTTPError(url=url, code=code, msg="Not Found", hdrs={"Content-Type": "text/plain"}, fp=fp) + + def fake_urlopen(request, *, timeout=None, context=None): + raise make_http_error(expected_url, 404) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + client = HttpClient(base_url=base_url) + + # Act / Assert + with pytest.raises(HttpError) as exc: + client.put(path, json_body={"x": 1}) + assert exc.value.status == 404 + assert exc.value.url == expected_url + + def test_patch_with_params_encodes_query(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/edit" + params = {"tag": ["a", "b"], "q": "hello world"} + + captured_url = {"u": None} + + def fake_urlopen(request, *, timeout=None, context=None): + captured_url["u"] = request.full_url + return FakeResponse( + status=200, + headers={"Content-Type": "application/json"}, + url=request.full_url, + content=json.dumps({"ok": True}).encode("utf-8"), + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + client.patch(path, params=params) # We don't care about response here + + # Assert query parameters regardless of ordering + parsed = urlparse(captured_url["u"]) + qs = parse_qs(parsed.query) + assert qs["q"] == ["hello world"] + assert qs["tag"] == ["a", "b"] + + def test_request_low_level_headers_auth_and_timeout_default(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange: use plain HTTP to verify no TLS context, and check default timeout used + base_url = "http://example.com/" + path = "ping" + captured = {"timeout": None, "context": "unset", "headers": None, "method": None} + + def fake_urlopen(request, *, timeout=None, context=None): + captured["timeout"] = timeout + captured["context"] = context + captured["headers"] = dict(request.headers) + captured["method"] = getattr(request, "method", "").upper() + return FakeResponse( + status=200, + headers={"Content-Type": "application/json"}, + url=request.full_url, + content=json.dumps({"pong": True}).encode("utf-8"), + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url, auth_basic=("user", "pass")) + resp = client.request("GET", path) + + # Assert + assert resp.status == 200 + assert captured["method"] == "GET" + assert captured["context"] is None # no TLS for http + assert pytest.approx(captured["timeout"], rel=1e-6) == client.timeout # default used + # Check some default headers and Authorization presence + headers = {k.lower(): v for k, v in captured["headers"].items()} + assert "authorization" in headers and headers["authorization"].startswith("Basic ") + assert headers.get("accept") == "application/json" + assert "gzip" in headers.get("accept-encoding", "").lower() + assert "user-agent" in headers \ No newline at end of file diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 37ee9b3fd3..193980d37c 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -6,14 +6,15 @@ from redis.event import EventDispatcher, OnCommandsFailEvent from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter -from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ +from redis.multidb.config import DEFAULT_FAILOVER_RETRIES, \ DEFAULT_FAILOVER_BACKOFF from redis.multidb.database import AbstractDatabase from redis.multidb.client import MultiDBClient from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failover import WeightBasedFailoverStrategy from redis.multidb.failure_detector import FailureDetector -from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ + DEFAULT_HEALTH_CHECK_BACKOFF from redis.retry import Retry from tests.test_multidb.conftest import create_weighted_list diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index 08bd8ab0c4..bc71fdb57d 100644 --- a/tests/test_multidb/test_healthcheck.py +++ b/tests/test_multidb/test_healthcheck.py @@ -1,6 +1,12 @@ +from unittest.mock import MagicMock + +import pytest + from redis.backoff import ExponentialBackoff from redis.multidb.database import Database from redis.multidb.healthcheck import EchoHealthCheck +from redis.http.http_client import HttpError +from redis.multidb.healthcheck import EchoHealthCheck, LagAwareHealthCheck from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError from redis.retry import Retry @@ -38,4 +44,137 @@ def test_database_close_circuit_on_successful_healthcheck(self, mock_client, moc db = Database(mock_client, mock_cb, 0.9) assert hc.check_health(db) == True - assert mock_client.execute_command.call_count == 3 \ No newline at end of file + assert mock_client.execute_command.call_count == 3 + + +class TestLagAwareHealthCheck: + def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, mock_cb): + """ + Ensures health check succeeds when /v1/bdbs contains an endpoint whose dns_name + matches database host, and availability endpoint returns success. + """ + host = "db1.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + # Mock HttpClient used inside LagAwareHealthCheck + mock_http = MagicMock() + mock_http.get.side_effect = [ + # First call: list of bdbs + [ + { + "uid": "bdb-1", + "endpoints": [ + {"dns_name": host, "addr": ["10.0.0.1", "10.0.0.2"]}, + ], + } + ], + # Second call: availability check (no JSON expected) + None, + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + rest_api_port=1234, + ) + # Inject our mocked http client + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + assert hc.check_health(db) is True + # Base URL must be set correctly + assert hc._http_client.base_url == f"https://healthcheck.example.com:1234" + # Calls: first to list bdbs, then to availability + assert mock_http.get.call_count == 2 + first_call = mock_http.get.call_args_list[0] + second_call = mock_http.get.call_args_list[1] + assert first_call.args[0] == "/v1/bdbs" + assert second_call.args[0] == "/v1/local/bdbs/bdb-1/endpoint/availability" + assert second_call.kwargs.get("expect_json") is False + + def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb): + """ + Ensures health check succeeds when endpoint addr list contains the database host. + """ + host_ip = "203.0.113.5" + mock_client.get_connection_kwargs.return_value = {"host": host_ip} + + mock_http = MagicMock() + mock_http.get.side_effect = [ + [ + { + "uid": "bdb-42", + "endpoints": [ + {"dns_name": "not-matching.example.com", "addr": [host_ip]}, + ], + } + ], + None, + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + ) + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + assert hc.check_health(db) is True + assert mock_http.get.call_count == 2 + assert mock_http.get.call_args_list[1].args[0] == "/v1/local/bdbs/bdb-42/endpoint/availability" + + def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): + """ + Ensures health check raises ValueError when there's no bdb matching the database host. + """ + host = "db2.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + mock_http = MagicMock() + # Return bdbs that do not match host by dns_name nor addr + mock_http.get.return_value = [ + {"uid": "a", "endpoints": [{"dns_name": "other.example.com", "addr": ["10.0.0.9"]}]}, + {"uid": "b", "endpoints": [{"dns_name": "another.example.com", "addr": ["10.0.0.10"]}]}, + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + ) + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + with pytest.raises(ValueError, match="Could not find a matching bdb"): + hc.check_health(db) + + # Only the listing call should have happened + mock_http.get.assert_called_once_with("/v1/bdbs") + + def test_propagates_http_error_from_availability(self, mock_client, mock_cb): + """ + Ensures that any HTTP error raised by the availability endpoint is propagated. + """ + host = "db3.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + mock_http = MagicMock() + # First: list bdbs -> match by dns_name + mock_http.get.side_effect = [ + [{"uid": "bdb-err", "endpoints": [{"dns_name": host, "addr": []}]}], + # Second: availability -> raise HttpError + HttpError(url=f"https://{host}:9443/v1/bdbs/bdb-err/availability", status=503, message="busy"), + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + ) + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") + + with pytest.raises(HttpError, match="busy") as e: + hc.check_health(db) + assert e.status == 503 + + # Ensure both calls were attempted + assert mock_http.get.call_count == 2 \ No newline at end of file diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py index f0d2a0dbe3..6e7c344d85 100644 --- a/tests/test_multidb/test_pipeline.py +++ b/tests/test_multidb/test_pipeline.py @@ -4,15 +4,13 @@ import pybreaker import pytest -from redis.event import EventDispatcher -from redis.exceptions import ConnectionError from redis.client import Pipeline from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.client import MultiDBClient -from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ - DEFAULT_FAILOVER_BACKOFF, DEFAULT_FAILURES_THRESHOLD +from redis.multidb.config import DEFAULT_FAILOVER_RETRIES, \ + DEFAULT_FAILOVER_BACKOFF from redis.multidb.failover import WeightBasedFailoverStrategy -from redis.multidb.healthcheck import EchoHealthCheck +from redis.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF from redis.retry import Retry from tests.test_multidb.conftest import create_weighted_list diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index 4182962fb1..a0f19e1a87 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -1,5 +1,7 @@ import json import os +import re +from urllib.parse import urlparse import pytest @@ -73,7 +75,8 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen 'username': username, 'password': password, 'decode_responses': True, - } + }, + health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][0]) ) db_configs.append(db_config) @@ -84,7 +87,8 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen 'username': username, 'password': password, 'decode_responses': True, - } + }, + health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][1]) ) db_configs.append(db_config1) @@ -93,10 +97,29 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen databases_config=db_configs, command_retry=command_retry, failure_threshold=failure_threshold, + health_check_retries=3, health_check_interval=health_check_interval, event_dispatcher=event_dispatcher, health_check_backoff=ExponentialBackoff(cap=5, base=0.5), - health_check_retries=3, ) - return MultiDBClient(config), listener, endpoint_config \ No newline at end of file + return MultiDBClient(config), listener, endpoint_config + + +def extract_cluster_fqdn(url): + """ + Extract Cluster FQDN from Redis URL + """ + # Parse the URL + parsed = urlparse(url) + + # Extract hostname and port + hostname = parsed.hostname + port = parsed.port + + # Remove the 'redis-XXXX.' prefix using regex + # This pattern matches 'redis-' followed by digits and a dot + cleaned_hostname = re.sub(r'^redis-\d+\.', '', hostname) + + # Reconstruct the URL + return f"https://{cleaned_hostname}" \ No newline at end of file diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index 967fa43cdb..44c57e6b99 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -1,5 +1,6 @@ import json import logging +import os import threading from time import sleep @@ -7,6 +8,7 @@ from redis import Redis, RedisCluster from redis.client import Pipeline +from redis.multidb.healthcheck import LagAwareHealthCheck from tests.test_scenario.fault_injector_client import ActionRequest, ActionType logger = logging.getLogger(__name__) @@ -70,6 +72,48 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector assert r_multi_db.get('key') == 'value' sleep(0.5) + @pytest.mark.parametrize( + "r_multi_db", + [ + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, + ], + ids=["standalone", "cluster"], + indirect=True + ) + @pytest.mark.timeout(50) + def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client,config,event) + ) + + env0_username = os.getenv('ENV0_USERNAME') + env0_password = os.getenv('ENV0_PASSWORD') + + # Adding additional health check to the client. + r_multi_db.add_health_check( + LagAwareHealthCheck(verify_tls=False, auth_basic=(env0_username,env0_password)) + ) + + # Client initialized on the first command. + r_multi_db.set('key', 'value') + thread.start() + + # Execute commands before network failure + while not event.is_set(): + assert r_multi_db.get('key') == 'value' + sleep(0.5) + + # Execute commands after network failure + while not listener.is_changed_flag: + assert r_multi_db.get('key') == 'value' + sleep(0.5) + @pytest.mark.parametrize( "r_multi_db", [ @@ -268,7 +312,7 @@ def handler(message): sleep(0.5) pubsub_thread.stop() - assert messages_count > 5 + assert messages_count > 2 @pytest.mark.parametrize( "r_multi_db", @@ -318,4 +362,4 @@ def handler(message): sleep(0.5) pubsub_thread.stop() - assert messages_count > 5 \ No newline at end of file + assert messages_count > 2 \ No newline at end of file