From e09d399637a58aabf7ec7561fbffd5433c5c2bf8 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 13 Aug 2025 18:33:34 +0300 Subject: [PATCH 01/13] Added LagAwareHealthcheck --- redis/http/__init__.py | 0 redis/http/http_client.py | 368 ++++++++++++++++++++++++++++ redis/multidb/config.py | 5 +- redis/multidb/failover.py | 6 +- redis/multidb/healthcheck.py | 65 ++++- redis/retry.py | 5 +- redis/utils.py | 6 + tests/test_multidb/test_client.py | 5 +- tests/test_multidb/test_pipeline.py | 8 +- 9 files changed, 447 insertions(+), 21 deletions(-) create mode 100644 redis/http/__init__.py create mode 100644 redis/http/http_client.py diff --git a/redis/http/__init__.py b/redis/http/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/http/http_client.py b/redis/http/http_client.py new file mode 100644 index 0000000000..30efee661a --- /dev/null +++ b/redis/http/http_client.py @@ -0,0 +1,368 @@ +from __future__ import annotations + +import base64 +import json +import ssl +import gzip +import zlib +from dataclasses import dataclass +from typing import Any, Dict, Mapping, Optional, Tuple, Union +from urllib.parse import urlencode, urljoin +from urllib.request import Request, urlopen +from urllib.error import URLError, HTTPError + + +__all__ = [ + "HttpClient", + "HttpResponse", + "HttpError", + "DEFAULT_TIMEOUT" +] + +from redis.backoff import ExponentialWithJitterBackoff +from redis.retry import Retry +from redis.utils import dummy_fail + +DEFAULT_USER_AGENT = "HttpClient/1.0 (+https://example.invalid)" +DEFAULT_TIMEOUT = 30.0 +RETRY_STATUS_CODES = {429, 500, 502, 503, 504} + + +@dataclass +class HttpResponse: + status: int + headers: Dict[str, str] + url: str + content: bytes + + def text(self, encoding: Optional[str] = None) -> str: + enc = encoding or self._get_encoding() + return self.content.decode(enc, errors="replace") + + def json(self) -> Any: + return json.loads(self.text(encoding=self._get_encoding())) + + def _get_encoding(self) -> 
str: + # Try to infer encoding from headers; default to utf-8 + ctype = self.headers.get("content-type", "") + # Example: application/json; charset=utf-8 + for part in ctype.split(";"): + p = part.strip() + if p.lower().startswith("charset="): + return p.split("=", 1)[1].strip() or "utf-8" + return "utf-8" + + +class HttpError(Exception): + def __init__(self, status: int, url: str, message: Optional[str] = None): + self.status = status + self.url = url + self.message = message or f"HTTP {status} for {url}" + super().__init__(self.message) + + +class HttpClient: + """ + A lightweight HTTP client for REST API calls. + """ + def __init__( + self, + base_url: str = "", + *, + headers: Optional[Mapping[str, str]] = None, + timeout: float = DEFAULT_TIMEOUT, + retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ), + verify_tls: bool = False, + auth_basic: Optional[Tuple[str, str]] = None, # (username, password) + user_agent: str = DEFAULT_USER_AGENT, + ) -> None: + self.base_url = base_url.rstrip("/") + "/" if base_url and not base_url.endswith("/") else base_url + self._default_headers = {k.lower(): v for k, v in (headers or {}).items()} + self.timeout = timeout + self.retry = retry + self.retry.update_supported_errors((HTTPError, URLError, ssl.SSLError)) + self.verify_tls = verify_tls + self.auth_basic = auth_basic + self.user_agent = user_agent + + # Public JSON-centric helpers + def get( + self, + path: str, + *, + params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + return self._json_call( + "GET", + path, + params=params, + headers=headers, + timeout=timeout, + body=None, + expect_json=expect_json + ) + + def delete( + self, + path: str, + *, + params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None, + headers: 
Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + return self._json_call( + "DELETE", + path, + params=params, + headers=headers, + timeout=timeout, + body=None, + expect_json=expect_json + ) + + def post( + self, + path: str, + *, + json_body: Any = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + return self._json_call( + "POST", + path, + params=params, + headers=headers, + timeout=timeout, + body=self._prepare_body(json_body=json_body, data=data), + expect_json=expect_json + ) + + def put( + self, + path: str, + *, + json_body: Any = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + return self._json_call( + "PUT", + path, + params=params, + headers=headers, + timeout=timeout, + body=self._prepare_body(json_body=json_body, data=data), + expect_json=expect_json + ) + + def patch( + self, + path: str, + *, + json_body: Any = None, + data: Optional[Union[bytes, str]] = None, + params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + expect_json: bool = True + ) -> Union[HttpResponse, Any]: + return self._json_call( + "PATCH", + path, + params=params, + headers=headers, + timeout=timeout, + body=self._prepare_body(json_body=json_body, data=data), + expect_json=expect_json + ) + + # Low-level request + def request( + self, + method: str, + path: str, + *, + params: Optional[Mapping[str, Union[str, int, float, 
bool, None, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + body: Optional[Union[bytes, str]] = None, + timeout: Optional[float] = None, + ) -> HttpResponse: + url = self._build_url(path, params) + all_headers = self._prepare_headers(headers, body) + data = body.encode("utf-8") if isinstance(body, str) else body + + req = Request(url=url, method=method.upper(), data=data, headers=all_headers) + + context = None + if url.lower().startswith("https"): + context = ssl.create_default_context() + if not self.verify_tls: + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + try: + return self.retry.call_with_retry( + lambda: self._make_request(req, context=context, timeout=timeout), + lambda _: dummy_fail(), + lambda error: self._is_retryable_http_error(error), + ) + except HTTPError as e: + # Read error body, build response, and decide on retry + err_body = b"" + try: + err_body = e.read() + except Exception: + pass + headers_map = {k.lower(): v for k, v in (e.headers or {}).items()} + err_body = self._maybe_decompress(err_body, headers_map) + status = getattr(e, "code", 0) or 0 + response = HttpResponse( + status=status, + headers=headers_map, + url=url, + content=err_body, + ) + return response + + def _make_request( + self, + request: Request, + context: Optional[ssl.SSLContext] = None, + timeout: Optional[float] = None, + ): + with urlopen(request, timeout=timeout or self.timeout, context=context) as resp: + raw = resp.read() + headers_map = {k.lower(): v for k, v in resp.headers.items()} + raw = self._maybe_decompress(raw, headers_map) + return HttpResponse( + status=resp.status, + headers=headers_map, + url=resp.geturl(), + content=raw, + ) + + def _is_retryable_http_error(self, error: Exception) -> bool: + if isinstance(error, HTTPError): + return self._should_retry_status(error.code) + return False + + # Internal utilities + def _json_call( + self, + method: str, + path: str, + *, + params: Optional[Mapping[str, 
Union[str, int, float, bool, None, list, tuple]]] = None, + headers: Optional[Mapping[str, str]] = None, + timeout: Optional[float] = None, + body: Optional[Union[bytes, str]] = None, + expect_json: bool = True, + ) -> Union[HttpResponse, Any]: + resp = self.request( + method=method, + path=path, + params=params, + headers=headers, + body=body, + timeout=timeout, + ) + if not (200 <= resp.status < 400): + raise HttpError(resp.status, resp.url, resp.text()) + if expect_json: + return resp.json() + return resp + + def _prepare_body(self, *, json_body: Any = None, data: Optional[Union[bytes, str]] = None) -> Optional[Union[bytes, str]]: + if json_body is not None and data is not None: + raise ValueError("Provide either json_body or data, not both.") + if json_body is not None: + return json.dumps(json_body, ensure_ascii=False, separators=(",", ":")) + return data + + def _build_url( + self, + path: str, + params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None, + ) -> str: + url = urljoin(self.base_url or "", path) + if params: + # urlencode with doseq=True supports list/tuple values + query = urlencode({k: v for k, v in params.items() if v is not None}, doseq=True) + separator = "&" if ("?" in url) else "?" 
+ url = f"{url}{separator}{query}" if query else url + return url + + def _prepare_headers(self, headers: Optional[Mapping[str, str]], body: Optional[Union[bytes, str]]) -> Dict[str, str]: + # Start with defaults + prepared: Dict[str, str] = {} + prepared.update(self._default_headers) + + # Standard defaults for JSON REST usage + prepared.setdefault("accept", "application/json") + prepared.setdefault("user-agent", self.user_agent) + # We will send gzip accept-encoding; handle decompression manually + prepared.setdefault("accept-encoding", "gzip, deflate") + + # If we have a string body and content-type not specified, assume JSON + if body is not None and isinstance(body, str): + prepared.setdefault("content-type", "application/json; charset=utf-8") + + # Basic authentication if provided and not overridden + if self.auth_basic and "authorization" not in prepared: + user, pwd = self.auth_basic + token = base64.b64encode(f"{user}:{pwd}".encode("utf-8")).decode("ascii") + prepared["authorization"] = f"Basic {token}" + + # Merge per-call headers (case-insensitive) + if headers: + for k, v in headers.items(): + prepared[k.lower()] = v + + # urllib expects header keys in canonical capitalization sometimes; but it’s tolerant. + # We'll return as provided; urllib will handle it. 
+ return prepared + + def _should_retry_status(self, status: int) -> bool: + return status in RETRY_STATUS_CODES + + def _maybe_decompress(self, content: bytes, headers: Mapping[str, str]) -> bytes: + if not content: + return content + encoding = (headers.get("content-encoding") or "").lower() + try: + if "gzip" in encoding: + return gzip.decompress(content) + if "deflate" in encoding: + # Try raw deflate, then zlib-wrapped + try: + return zlib.decompress(content, -zlib.MAX_WBITS) + except zlib.error: + return zlib.decompress(content) + except Exception: + # If decompression fails, return original bytes + return content + return content + + +# Example usage: +# client = HttpClient( +# base_url="https://api.example.com/", +# auth_basic=("username", "password"), +# ) +# data = client.get("v1/items", params={"limit": 10}) +# created = client.post("v1/items", json_body={"name": "sample"}) +# resp = client.get("v1/raw", expect_json=False) # returns HttpResponse +# print(resp.text()) \ No newline at end of file diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 64ad7c9052..dd5fcd622b 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -12,14 +12,13 @@ from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter from redis.multidb.database import Database, Databases from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector -from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ + DEFAULT_HEALTH_CHECK_BACKOFF from redis.multidb.failover import FailoverStrategy, WeightBasedFailoverStrategy from redis.retry import Retry DEFAULT_GRACE_PERIOD = 5.0 DEFAULT_HEALTH_CHECK_INTERVAL = 5 -DEFAULT_HEALTH_CHECK_RETRIES = 3 -DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) DEFAULT_FAILURES_THRESHOLD = 3 DEFAULT_FAILURES_DURATION = 2 DEFAULT_FAILOVER_RETRIES = 3 diff --git 
a/redis/multidb/failover.py b/redis/multidb/failover.py index a4c825aac1..b172a0f8bb 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -6,6 +6,7 @@ from redis.multidb.circuit import State as CBState from redis.multidb.exception import NoValidDatabaseException from redis.retry import Retry +from redis.utils import dummy_fail class FailoverStrategy(ABC): @@ -37,7 +38,7 @@ def __init__( def database(self) -> AbstractDatabase: return self._retry.call_with_retry( lambda: self._get_active_database(), - lambda _: self._dummy_fail() + lambda _: dummy_fail() ) def set_databases(self, databases: Databases) -> None: @@ -49,6 +50,3 @@ def _get_active_database(self) -> AbstractDatabase: return database raise NoValidDatabaseException('No valid database available for communication') - - def _dummy_fail(self): - pass diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 1396a1e997..6a067e5166 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -1,6 +1,16 @@ +import logging from abc import abstractmethod, ABC +from typing import Optional, Tuple + +from redis.backoff import ExponentialWithJitterBackoff +from redis.http.http_client import DEFAULT_TIMEOUT, HttpClient from redis.retry import Retry +from redis.utils import dummy_fail + +DEFAULT_HEALTH_CHECK_RETRIES = 3 +DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) +logger = logging.getLogger(__name__) class HealthCheck(ABC): @@ -18,7 +28,7 @@ def check_health(self, database) -> bool: class AbstractHealthCheck(HealthCheck): def __init__( self, - retry: Retry, + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) ) -> None: self._retry = retry self._retry.update_supported_errors([ConnectionRefusedError]) @@ -34,8 +44,8 @@ def check_health(self, database) -> bool: class EchoHealthCheck(AbstractHealthCheck): def __init__( - self, - retry: Retry, + self, + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, 
backoff=DEFAULT_HEALTH_CHECK_BACKOFF), ) -> None: """ Check database healthiness by sending an echo request. @@ -46,7 +56,7 @@ def __init__( def check_health(self, database) -> bool: return self._retry.call_with_retry( lambda: self._returns_echoed_message(database), - lambda _: self._dummy_fail() + lambda _: dummy_fail() ) def _returns_echoed_message(self, database) -> bool: @@ -54,5 +64,48 @@ def _returns_echoed_message(self, database) -> bool: actual_message = database.client.execute_command('ECHO', "healthcheck") return actual_message in expected_message - def _dummy_fail(self): - pass \ No newline at end of file +class LagAwareHealthCheck(AbstractHealthCheck): + """ + Health check available for Redis Enterprise deployments. + Verify via REST API that the database is healthy based on different lags. + """ + def __init__( + self, + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF), + rest_api_port: int = 9443, + timeout: float = DEFAULT_TIMEOUT, + auth_basic: Optional[Tuple[str, str]] = None, + ): + super().__init__( + retry=retry, + ) + self._http_client = HttpClient( + timeout=timeout, + auth_basic=auth_basic, + retry=self.retry, + ) + self._rest_api_port = rest_api_port + + def check_health(self, database) -> bool: + client = database.client + db_host = client.get_connection_kwargs()['host'] + base_url = f"https://{db_host}:{self._rest_api_port}" + self._http_client.base_url = base_url + + # Find bdb matching to current database host. 
+ matching_bdb = None + for bdb in self._http_client.get("/v1/bdbs"): + for endpoint in bdb["endpoints"]: + if endpoint['dns_name'] == db_host: + matching_bdb = bdb + break + + if matching_bdb is None: + logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb") + raise ValueError("Could not find a matching bdb") + + url = f"/v1/bdbs/{matching_bdb['uid']}/availability" + self._http_client.get(url, expect_json=False) + + # Status checked in http client, otherwise HttpError will be raised + return True diff --git a/redis/retry.py b/redis/retry.py index c93f34e65f..7989b41742 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -1,6 +1,6 @@ import socket from time import sleep -from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar, Optional from redis.exceptions import ConnectionError, TimeoutError @@ -73,6 +73,7 @@ def call_with_retry( self, do: Callable[[], T], fail: Callable[[Exception], Any], + is_retryable: Optional[Callable[[Exception], bool]] = None ) -> T: """ Execute an operation that might fail and returns its result, or @@ -86,6 +87,8 @@ def call_with_retry( try: return do() except self._supported_errors as error: + if is_retryable and not is_retryable(error): + raise failures += 1 fail(error) if self._retries >= 0 and failures > self._retries: diff --git a/redis/utils.py b/redis/utils.py index 715913e914..94bfab61bb 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -308,3 +308,9 @@ def truncate_text(txt, max_length=100): return textwrap.shorten( text=txt, width=max_length, placeholder="...", break_long_words=True ) + +def dummy_fail(): + """ + Fake function for a Retry object if you don't need to handle each failure. 
+ """ + pass diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index cf3877957f..2713e2a564 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -6,14 +6,15 @@ from redis.event import EventDispatcher, OnCommandsFailEvent from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter -from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ +from redis.multidb.config import DEFAULT_FAILOVER_RETRIES, \ DEFAULT_FAILOVER_BACKOFF from redis.multidb.database import State as DBState, AbstractDatabase from redis.multidb.client import MultiDBClient from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failover import WeightBasedFailoverStrategy from redis.multidb.failure_detector import FailureDetector -from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck +from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ + DEFAULT_HEALTH_CHECK_BACKOFF from redis.retry import Retry from tests.test_multidb.conftest import create_weighted_list diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py index 9caad235df..ab911562fc 100644 --- a/tests/test_multidb/test_pipeline.py +++ b/tests/test_multidb/test_pipeline.py @@ -4,15 +4,13 @@ import pybreaker import pytest -from redis.event import EventDispatcher -from redis.exceptions import ConnectionError from redis.client import Pipeline from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.client import MultiDBClient -from redis.multidb.config import DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, DEFAULT_FAILOVER_RETRIES, \ - DEFAULT_FAILOVER_BACKOFF, DEFAULT_FAILURES_THRESHOLD +from redis.multidb.config import DEFAULT_FAILOVER_RETRIES, \ + DEFAULT_FAILOVER_BACKOFF from redis.multidb.failover import WeightBasedFailoverStrategy 
-from redis.multidb.healthcheck import EchoHealthCheck +from redis.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF from redis.retry import Retry from tests.test_multidb.conftest import create_weighted_list From 36cdf40e572a13f266b418c763ef4dfa816d1101 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 14 Aug 2025 15:45:19 +0300 Subject: [PATCH 02/13] Added testing for LagAwareHealthCheck --- redis/http/http_client.py | 53 ++-- redis/multidb/client.py | 8 +- redis/multidb/config.py | 4 +- redis/multidb/healthcheck.py | 28 +- tests/test_http/__init__.py | 0 tests/test_http/test_http_client.py | 324 ++++++++++++++++++++++ tests/test_multidb/conftest.py | 1 - tests/test_multidb/test_client.py | 202 ++++---------- tests/test_multidb/test_config.py | 8 +- tests/test_multidb/test_healthcheck.py | 142 +++++++++- tests/test_multidb/test_pipeline.py | 98 ++----- tests/test_scenario/conftest.py | 2 +- tests/test_scenario/test_active_active.py | 43 +++ 13 files changed, 671 insertions(+), 242 deletions(-) create mode 100644 tests/test_http/__init__.py create mode 100644 tests/test_http/test_http_client.py diff --git a/redis/http/http_client.py b/redis/http/http_client.py index 30efee661a..fae68c712b 100644 --- a/redis/http/http_client.py +++ b/redis/http/http_client.py @@ -74,7 +74,15 @@ def __init__( retry: Retry = Retry( backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 ), - verify_tls: bool = False, + verify_tls: bool = True, + # TLS verification (server) options + ca_file: Optional[str] = None, + ca_path: Optional[str] = None, + ca_data: Optional[Union[str, bytes]] = None, + # Mutual TLS (client cert) options + client_cert_file: Optional[str] = None, + client_key_file: Optional[str] = None, + client_key_password: Optional[str] = None, auth_basic: Optional[Tuple[str, str]] = None, # (username, password) user_agent: str = DEFAULT_USER_AGENT, ) -> None: @@ -84,6 +92,15 @@ def __init__( self.retry = retry 
self.retry.update_supported_errors((HTTPError, URLError, ssl.SSLError)) self.verify_tls = verify_tls + + # TLS settings + self.ca_file = ca_file + self.ca_path = ca_path + self.ca_data = ca_data + self.client_cert_file = client_cert_file + self.client_key_file = client_key_file + self.client_key_password = client_key_password + self.auth_basic = auth_basic self.user_agent = user_agent @@ -206,10 +223,25 @@ def request( req = Request(url=url, method=method.upper(), data=data, headers=all_headers) - context = None + context: Optional[ssl.SSLContext] = None if url.lower().startswith("https"): - context = ssl.create_default_context() - if not self.verify_tls: + if self.verify_tls: + # Use provided CA material if any; fall back to system defaults + context = ssl.create_default_context( + cafile=self.ca_file, + capath=self.ca_path, + cadata=self.ca_data, + ) + # Load client certificate for mTLS if configured + if self.client_cert_file: + context.load_cert_chain( + certfile=self.client_cert_file, + keyfile=self.client_key_file, + password=self.client_key_password, + ) + else: + # Verification disabled + context = ssl.create_default_context() context.check_hostname = False context.verify_mode = ssl.CERT_NONE @@ -354,15 +386,4 @@ def _maybe_decompress(self, content: bytes, headers: Mapping[str, str]) -> bytes except Exception: # If decompression fails, return original bytes return content - return content - - -# Example usage: -# client = HttpClient( -# base_url="https://api.example.com/", -# auth_basic=("username", "password"), -# ) -# data = client.get("v1/items", params={"limit": 10}) -# created = client.post("v1/items", json_body={"name": "sample"}) -# resp = client.get("v1/raw", expect_json=False) # returns HttpResponse -# print(resp.text()) \ No newline at end of file + return content \ No newline at end of file diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 172017f036..b080e4a1a5 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py 
@@ -22,7 +22,11 @@ class MultiDBClient(RedisModuleCommands, CoreCommands): """ def __init__(self, config: MultiDbConfig): self._databases = config.databases() - self._health_checks = config.default_health_checks() if config.health_checks is None else config.health_checks + self._health_checks = config.default_health_checks() + + if config.additional_health_checks is not None: + self._health_checks.extend(config.additional_health_checks) + self._health_check_interval = config.health_check_interval self._failure_detectors = config.default_failure_detectors() \ if config.failure_detectors is None else config.failure_detectors @@ -233,7 +237,7 @@ def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Excep database.circuit.state = CBState.OPEN elif is_healthy and database.circuit.state != CBState.CLOSED: database.circuit.state = CBState.CLOSED - except (ConnectionError, TimeoutError, socket.timeout, ConnectionRefusedError) as e: + except Exception as e: if database.circuit.state != CBState.OPEN: database.circuit.state = CBState.OPEN is_healthy = False diff --git a/redis/multidb/config.py b/redis/multidb/config.py index dd5fcd622b..8092a4288a 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -53,7 +53,7 @@ class MultiDbConfig: failure_detectors: Optional list of failure detectors for monitoring database failures. failure_threshold: Threshold for determining database failure. failures_interval: Time interval for tracking database failures. - health_checks: Optional list of health checks performed on databases. + additional_health_checks: Optional list of health checks performed on databases. health_check_interval: Time interval for executing health checks. health_check_retries: Number of retry attempts for performing health checks. health_check_backoff: Backoff strategy for health check retries. 
@@ -88,7 +88,7 @@ class MultiDbConfig: failure_detectors: Optional[List[FailureDetector]] = None failure_threshold: int = DEFAULT_FAILURES_THRESHOLD failures_interval: float = DEFAULT_FAILURES_DURATION - health_checks: Optional[List[HealthCheck]] = None + additional_health_checks: Optional[List[HealthCheck]] = None health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL health_check_retries: int = DEFAULT_HEALTH_CHECK_RETRIES health_check_backoff: AbstractBackoff = DEFAULT_HEALTH_CHECK_BACKOFF diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 6a067e5166..5a47e4b332 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -1,6 +1,6 @@ import logging from abc import abstractmethod, ABC -from typing import Optional, Tuple +from typing import Optional, Tuple, Union from redis.backoff import ExponentialWithJitterBackoff from redis.http.http_client import DEFAULT_TIMEOUT, HttpClient @@ -75,6 +75,15 @@ def __init__( rest_api_port: int = 9443, timeout: float = DEFAULT_TIMEOUT, auth_basic: Optional[Tuple[str, str]] = None, + verify_tls: bool = True, + # TLS verification (server) options + ca_file: Optional[str] = None, + ca_path: Optional[str] = None, + ca_data: Optional[Union[str, bytes]] = None, + # Mutual TLS (client cert) options + client_cert_file: Optional[str] = None, + client_key_file: Optional[str] = None, + client_key_password: Optional[str] = None, ): super().__init__( retry=retry, @@ -83,6 +92,13 @@ def __init__( timeout=timeout, auth_basic=auth_basic, retry=self.retry, + verify_tls=verify_tls, + ca_file=ca_file, + ca_path=ca_path, + ca_data=ca_data, + client_cert_file=client_cert_file, + client_key_file=client_key_file, + client_key_password=client_key_password ) self._rest_api_port = rest_api_port @@ -92,7 +108,7 @@ def check_health(self, database) -> bool: base_url = f"https://{db_host}:{self._rest_api_port}" self._http_client.base_url = base_url - # Find bdb matching to current database host. 
+ # Find bdb matching to the current database host matching_bdb = None for bdb in self._http_client.get("/v1/bdbs"): for endpoint in bdb["endpoints"]: @@ -100,6 +116,12 @@ def check_health(self, database) -> bool: matching_bdb = bdb break + # In case if the host was set as public IP + for addr in endpoint['addr']: + if addr == db_host: + matching_bdb = bdb + break + if matching_bdb is None: logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb") raise ValueError("Could not find a matching bdb") @@ -107,5 +129,5 @@ def check_health(self, database) -> bool: url = f"/v1/bdbs/{matching_bdb['uid']}/availability" self._http_client.get(url, expect_json=False) - # Status checked in http client, otherwise HttpError will be raised + # Status checked in an http client, otherwise HttpError will be raised return True diff --git a/tests/test_http/__init__.py b/tests/test_http/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_http/test_http_client.py b/tests/test_http/test_http_client.py new file mode 100644 index 0000000000..9a6d28ecd4 --- /dev/null +++ b/tests/test_http/test_http_client.py @@ -0,0 +1,324 @@ +import json +import gzip +from io import BytesIO +from typing import Any, Dict +from urllib.error import HTTPError +from urllib.parse import urlparse, parse_qs + +import pytest + +from redis.backoff import ExponentialWithJitterBackoff +from redis.http.http_client import HttpClient, HttpError +from redis.retry import Retry + + +class FakeResponse: + def __init__(self, *, status: int, headers: Dict[str, str], url: str, content: bytes): + self.status = status + self.headers = headers + self._url = url + self._content = content + + def read(self) -> bytes: + return self._content + + def geturl(self) -> str: + return self._url + + # Support context manager used by urlopen + def __enter__(self) -> "FakeResponse": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + +class TestHttpClient: + def 
test_get_returns_parsed_json_and_uses_timeout(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/items" + params = {"limit": 5, "q": "hello world"} + expected_url = f"{base_url}{path}?limit=5&q=hello+world" + payload: Dict[str, Any] = {"items": [1, 2, 3], "ok": True} + content = json.dumps(payload).encode("utf-8") + + captured_kwargs = {} + + def fake_urlopen(request, *, timeout=None, context=None): + # Capture call details for assertions + captured_kwargs["timeout"] = timeout + captured_kwargs["context"] = context + # Assert the request was constructed correctly + assert getattr(request, "method", "").upper() == "GET" + assert request.full_url == expected_url + # Return a successful response + return FakeResponse( + status=200, + headers={"Content-Type": "application/json; charset=utf-8"}, + url=expected_url, + content=content, + ) + + # Patch the urlopen used inside HttpClient + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + + # Act + result = client.get(path, params=params, timeout=12.34) # default expect_json=True + + # Assert + assert result == payload + assert pytest.approx(captured_kwargs["timeout"], rel=1e-6) == 12.34 + # HTTPS -> a context should be provided (created by ssl.create_default_context) + assert captured_kwargs["context"] is not None + + def test_get_handles_gzip_response(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "gzip-endpoint" + expected_url = f"{base_url}{path}" + payload = {"message": "compressed ok"} + raw = json.dumps(payload).encode("utf-8") + gzipped = gzip.compress(raw) + + def fake_urlopen(request, *, timeout=None, context=None): + # Return gzipped content with appropriate header + return FakeResponse( + status=200, + headers={ + "Content-Type": "application/json; charset=utf-8", + "Content-Encoding": "gzip", + }, + url=expected_url, + 
content=gzipped, + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + + # Act + result = client.get(path) # expect_json=True by default + + # Assert + assert result == payload + + def test_get_retries_on_retryable_http_errors_and_succeeds(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange: configure limited retries so we can assert attempts + retry_policy = Retry(backoff=ExponentialWithJitterBackoff(base=0, cap=0), + retries=2) # 2 retries -> up to 3 attempts + base_url = "https://api.example.com/" + path = "sometimes-busy" + expected_url = f"{base_url}{path}" + payload = {"ok": True} + success_content = json.dumps(payload).encode("utf-8") + + call_count = {"n": 0} + + def make_http_error(url: str, code: int, body: bytes = b"busy"): + # Provide a file-like object for .read() when HttpClient tries to read error content + fp = BytesIO(body) + return HTTPError(url=url, code=code, msg="Service Unavailable", hdrs={"Content-Type": "text/plain"}, fp=fp) + + def flaky_urlopen(request, *, timeout=None, context=None): + call_count["n"] += 1 + # Fail with a retryable status (503) for the first two calls, then succeed + if call_count["n"] <= 2: + raise make_http_error(expected_url, 503) + return FakeResponse( + status=200, + headers={"Content-Type": "application/json; charset=utf-8"}, + url=expected_url, + content=success_content, + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", flaky_urlopen) + + client = HttpClient(base_url=base_url, retry=retry_policy) + + # Act + result = client.get(path) + + # Assert: should have retried twice (total 3 attempts) and finally returned parsed JSON + assert result == payload + assert call_count["n"] == retry_policy.get_retries() + 1 + + def test_post_sends_json_body_and_parses_response(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/create" + expected_url = f"{base_url}{path}" + 
send_payload = {"a": 1, "b": "x"} + recv_payload = {"id": 10, "ok": True} + recv_content = json.dumps(recv_payload, separators=(",", ":")).encode("utf-8") + + def fake_urlopen(request, *, timeout=None, context=None): + # Verify method, URL and headers + assert getattr(request, "method", "").upper() == "POST" + assert request.full_url == expected_url + # Content-Type should be auto-set for string JSON body + assert request.headers.get("Content-type") == "application/json; charset=utf-8" + # Body should be already UTF-8 encoded JSON with no spaces + assert request.data == json.dumps(send_payload, ensure_ascii=False, separators=(",", ":")).encode("utf-8") + return FakeResponse( + status=200, + headers={"Content-Type": "application/json; charset=utf-8"}, + url=expected_url, + content=recv_content, + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + + # Act + result = client.post(path, json_body=send_payload) + + # Assert + assert result == recv_payload + + def test_post_with_raw_data_and_custom_headers(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "upload" + expected_url = f"{base_url}{path}" + raw_data = b"\x00\x01BINARY" + custom_headers = {"Content-type": "application/octet-stream", "X-extra": "1"} + recv_payload = {"status": "ok"} + + def fake_urlopen(request, *, timeout=None, context=None): + assert getattr(request, "method", "").upper() == "POST" + assert request.full_url == expected_url + # Ensure our provided headers are present + assert request.headers.get("Content-type") == "application/octet-stream" + assert request.headers.get("X-extra") == "1" + assert request.data == raw_data + return FakeResponse( + status=200, + headers={"Content-Type": "application/json"}, + url=expected_url, + content=json.dumps(recv_payload).encode("utf-8"), + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = 
HttpClient(base_url=base_url) + # Act + result = client.post(path, data=raw_data, headers=custom_headers) + + # Assert + assert result == recv_payload + + def test_delete_returns_http_response_when_expect_json_false(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/resource/42" + expected_url = f"{base_url}{path}" + body = b"deleted" + + def fake_urlopen(request, *, timeout=None, context=None): + assert getattr(request, "method", "").upper() == "DELETE" + assert request.full_url == expected_url + return FakeResponse( + status=204, + headers={"Content-Type": "text/plain"}, + url=expected_url, + content=body, + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + client = HttpClient(base_url=base_url) + + # Act + resp = client.delete(path, expect_json=False) + + # Assert + assert resp.status == 204 + assert resp.url == expected_url + assert resp.content == body + + def test_put_raises_http_error_on_non_success(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/update/1" + expected_url = f"{base_url}{path}" + + def make_http_error(url: str, code: int, body: bytes = b"not found"): + fp = BytesIO(body) + return HTTPError(url=url, code=code, msg="Not Found", hdrs={"Content-Type": "text/plain"}, fp=fp) + + def fake_urlopen(request, *, timeout=None, context=None): + raise make_http_error(expected_url, 404) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + client = HttpClient(base_url=base_url) + + # Act / Assert + with pytest.raises(HttpError) as exc: + client.put(path, json_body={"x": 1}) + assert exc.value.status == 404 + assert exc.value.url == expected_url + + def test_patch_with_params_encodes_query(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + base_url = "https://api.example.com/" + path = "v1/edit" + params = {"tag": ["a", "b"], "q": "hello world"} + + captured_url = {"u": None} + 
+ def fake_urlopen(request, *, timeout=None, context=None): + captured_url["u"] = request.full_url + return FakeResponse( + status=200, + headers={"Content-Type": "application/json"}, + url=request.full_url, + content=json.dumps({"ok": True}).encode("utf-8"), + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url) + client.patch(path, params=params) # We don't care about response here + + # Assert query parameters regardless of ordering + parsed = urlparse(captured_url["u"]) + qs = parse_qs(parsed.query) + assert qs["q"] == ["hello world"] + assert qs["tag"] == ["a", "b"] + + def test_request_low_level_headers_auth_and_timeout_default(self, monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange: use plain HTTP to verify no TLS context, and check default timeout used + base_url = "http://example.com/" + path = "ping" + captured = {"timeout": None, "context": "unset", "headers": None, "method": None} + + def fake_urlopen(request, *, timeout=None, context=None): + captured["timeout"] = timeout + captured["context"] = context + captured["headers"] = dict(request.headers) + captured["method"] = getattr(request, "method", "").upper() + return FakeResponse( + status=200, + headers={"Content-Type": "application/json"}, + url=request.full_url, + content=json.dumps({"pong": True}).encode("utf-8"), + ) + + monkeypatch.setattr("redis.http.http_client.urlopen", fake_urlopen) + + client = HttpClient(base_url=base_url, auth_basic=("user", "pass")) + resp = client.request("GET", path) + + # Assert + assert resp.status == 200 + assert captured["method"] == "GET" + assert captured["context"] is None # no TLS for http + assert pytest.approx(captured["timeout"], rel=1e-6) == client.timeout # default used + # Check some default headers and Authorization presence + headers = {k.lower(): v for k, v in captured["headers"].items()} + assert "authorization" in headers and headers["authorization"].startswith("Basic ") + assert 
headers.get("accept") == "application/json" + assert "gzip" in headers.get("accept-encoding", "").lower() + assert "user-agent" in headers \ No newline at end of file diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index ad2057a118..f85e0a6fd7 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -94,7 +94,6 @@ def mock_multi_db_config( config = MultiDbConfig( databases_config=[Mock(spec=DatabaseConfig)], failure_detectors=[mock_fd], - health_checks=[mock_hc], health_check_interval=hc_interval, failover_strategy=mock_fs, auto_fallback_interval=auto_fallback_interval, diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 2713e2a564..2bdd9134c3 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -33,26 +33,19 @@ class TestMultiDbClient: indirect=True, ) def test_execute_command_against_correct_db_on_successful_initialization( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' - - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ 
-71,26 +64,19 @@ def test_execute_command_against_correct_db_on_successful_initialization( indirect=True, ) def test_execute_command_against_correct_db_and_closed_circuit( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' - - for hc in mock_multi_db_config.health_checks: - hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -125,20 +111,14 @@ def test_execute_command_against_correct_db_on_background_health_check_determine databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] 
mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.health_checks = [ - EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) - ] mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) ) @@ -169,21 +149,15 @@ def test_execute_command_auto_fallback_to_highest_weight_db( ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'healthcheck', 'healthcheck'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'healthcheck', 'healthcheck', 'OK1'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 'healthcheck', 'healthcheck'] mock_multi_db_config.health_check_interval = 0.1 mock_multi_db_config.auto_fallback_interval = 0.2 - mock_multi_db_config.health_checks = [ - EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) - ] mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) ) @@ -224,26 +198,20 @@ def test_execute_command_auto_fallback_to_highest_weight_db( indirect=True, ) def test_execute_command_throws_exception_on_failed_initialization( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, 
mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = False + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_hc.check_health.return_value = False client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 with pytest.raises(NoValidDatabaseException, match='Initial connection failed - no active database found'): client.set('key', 'value') - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert mock_db.state == DBState.DISCONNECTED assert mock_db1.state == DBState.DISCONNECTED @@ -262,26 +230,20 @@ def test_execute_command_throws_exception_on_failed_initialization( indirect=True, ) def test_add_database_throws_exception_on_same_database( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = False + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_hc.check_health.return_value = False client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 with pytest.raises(ValueError, match='Given database already exists'): client.add_database(mock_db) - - for hc in mock_multi_db_config.health_checks: - assert 
hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', @@ -296,36 +258,28 @@ def test_add_database_throws_exception_on_same_database( indirect=True, ) def test_add_database_makes_new_database_active( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' mock_db2.client.execute_command.return_value = 'OK2' - - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK2' - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 2 + assert mock_hc.check_health.call_count == 2 assert mock_db.state == DBState.PASSIVE assert mock_db2.state == DBState.ACTIVE client.add_database(mock_db1) - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert client.set('key', 'value') == 'OK1' @@ -346,28 +300,21 @@ def test_add_database_makes_new_database_active( indirect=True, ) def test_remove_highest_weighted_database( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with 
patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' mock_db2.client.execute_command.return_value = 'OK2' - - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert mock_db.state == DBState.PASSIVE assert mock_db1.state == DBState.ACTIVE @@ -393,28 +340,21 @@ def test_remove_highest_weighted_database( indirect=True, ) def test_update_database_weight_to_be_highest( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' mock_db2.client.execute_command.return_value = 'OK2' - - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert mock_db.state == DBState.PASSIVE assert mock_db1.state == DBState.ACTIVE @@ -442,15 +382,12 @@ def 
test_update_database_weight_to_be_highest( indirect=True, ) def test_add_new_failure_detector( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' mock_multi_db_config.event_dispatcher = EventDispatcher() mock_fd = mock_multi_db_config.failure_detectors[0] @@ -460,16 +397,12 @@ def test_add_new_failure_detector( commands=('SET', 'key', 'value'), exception=Exception(), ) - - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 # Simulate failing command events that lead to a failure detection for i in range(5): @@ -500,26 +433,19 @@ def test_add_new_failure_detector( indirect=True, ) def test_add_new_health_check( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' - - for hc in mock_multi_db_config.health_checks: - 
hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 another_hc = Mock(spec=HealthCheck) another_hc.check_health.return_value = True @@ -542,27 +468,20 @@ def test_add_new_health_check( indirect=True, ) def test_set_active_database( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.execute_command.return_value = 'OK1' mock_db.client.execute_command.return_value = 'OK' - - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 assert client.set('key', 'value') == 'OK1' - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert mock_db.state == DBState.PASSIVE assert mock_db1.state == DBState.ACTIVE @@ -578,8 +497,7 @@ def test_set_active_database( with pytest.raises(ValueError, match='Given database is not a member of database list'): client.set_active_database(Mock(spec=AbstractDatabase)) - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = False + mock_hc.check_health.return_value = False with pytest.raises(NoValidDatabaseException, 
match='Cannot set active database, database is unhealthy'): client.set_active_database(mock_db1) \ No newline at end of file diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index 87aae701a9..7b72f65bbf 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -75,7 +75,7 @@ def test_overridden_config(self): config = MultiDbConfig( databases_config=db_configs, failure_detectors=mock_failure_detectors, - health_checks=mock_health_checks, + additional_health_checks=mock_health_checks, health_check_interval=health_check_interval, failover_strategy=mock_failover_strategy, auto_fallback_interval=auto_fallback_interval, @@ -96,9 +96,9 @@ def test_overridden_config(self): assert len(config.failure_detectors) == 2 assert config.failure_detectors[0] == mock_failure_detectors[0] assert config.failure_detectors[1] == mock_failure_detectors[1] - assert len(config.health_checks) == 2 - assert config.health_checks[0] == mock_health_checks[0] - assert config.health_checks[1] == mock_health_checks[1] + assert len(config.additional_health_checks) == 2 + assert config.additional_health_checks[0] == mock_health_checks[0] + assert config.additional_health_checks[1] == mock_health_checks[1] assert config.health_check_interval == health_check_interval assert config.failover_strategy == mock_failover_strategy assert config.auto_fallback_interval == auto_fallback_interval diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index 9601638913..c9655bb121 100644 --- a/tests/test_multidb/test_healthcheck.py +++ b/tests/test_multidb/test_healthcheck.py @@ -1,6 +1,11 @@ +from unittest.mock import MagicMock + +import pytest + from redis.backoff import ExponentialBackoff +from redis.http.http_client import HttpError from redis.multidb.database import Database, State -from redis.multidb.healthcheck import EchoHealthCheck +from redis.multidb.healthcheck import EchoHealthCheck, 
LagAwareHealthCheck from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError from redis.retry import Retry @@ -38,4 +43,137 @@ def test_database_close_circuit_on_successful_healthcheck(self, mock_client, moc db = Database(mock_client, mock_cb, 0.9, State.ACTIVE) assert hc.check_health(db) == True - assert mock_client.execute_command.call_count == 3 \ No newline at end of file + assert mock_client.execute_command.call_count == 3 + + +class TestLagAwareHealthCheck: + def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, mock_cb): + """ + Ensures health check succeeds when /v1/bdbs contains an endpoint whose dns_name + matches database host, and availability endpoint returns success. + """ + host = "db1.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + # Mock HttpClient used inside LagAwareHealthCheck + mock_http = MagicMock() + mock_http.get.side_effect = [ + # First call: list of bdbs + [ + { + "uid": "bdb-1", + "endpoints": [ + {"dns_name": host, "addr": ["10.0.0.1", "10.0.0.2"]}, + ], + } + ], + # Second call: availability check (no JSON expected) + None, + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + rest_api_port=1234, + ) + # Inject our mocked http client + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, State.ACTIVE) + + assert hc.check_health(db) is True + # Base URL must be set correctly + assert hc._http_client.base_url == f"https://{host}:1234" + # Calls: first to list bdbs, then to availability + assert mock_http.get.call_count == 2 + first_call = mock_http.get.call_args_list[0] + second_call = mock_http.get.call_args_list[1] + assert first_call.args[0] == "/v1/bdbs" + assert second_call.args[0] == "/v1/bdbs/bdb-1/availability" + assert second_call.kwargs.get("expect_json") is False + + def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb): + """ + Ensures 
health check succeeds when endpoint addr list contains the database host. + """ + host_ip = "203.0.113.5" + mock_client.get_connection_kwargs.return_value = {"host": host_ip} + + mock_http = MagicMock() + mock_http.get.side_effect = [ + [ + { + "uid": "bdb-42", + "endpoints": [ + {"dns_name": "not-matching.example.com", "addr": [host_ip]}, + ], + } + ], + None, + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + ) + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, State.ACTIVE) + + assert hc.check_health(db) is True + assert mock_http.get.call_count == 2 + assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability" + + def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): + """ + Ensures health check raises ValueError when there's no bdb matching the database host. + """ + host = "db2.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + mock_http = MagicMock() + # Return bdbs that do not match host by dns_name nor addr + mock_http.get.return_value = [ + {"uid": "a", "endpoints": [{"dns_name": "other.example.com", "addr": ["10.0.0.9"]}]}, + {"uid": "b", "endpoints": [{"dns_name": "another.example.com", "addr": ["10.0.0.10"]}]}, + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + ) + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, State.ACTIVE) + + with pytest.raises(ValueError, match="Could not find a matching bdb"): + hc.check_health(db) + + # Only the listing call should have happened + mock_http.get.assert_called_once_with("/v1/bdbs") + + def test_propagates_http_error_from_availability(self, mock_client, mock_cb): + """ + Ensures that any HTTP error raised by the availability endpoint is propagated. 
+ host = "db3.example.com" + mock_client.get_connection_kwargs.return_value = {"host": host} + + mock_http = MagicMock() + # First: list bdbs -> match by dns_name + mock_http.get.side_effect = [ + [{"uid": "bdb-err", "endpoints": [{"dns_name": host, "addr": []}]}], + # Second: availability -> raise HttpError + HttpError(url=f"https://{host}:9443/v1/bdbs/bdb-err/availability", status=503, message="busy"), + ] + + hc = LagAwareHealthCheck( + retry=Retry(backoff=ExponentialBackoff(cap=1.0), retries=3), + ) + hc._http_client = mock_http + + db = Database(mock_client, mock_cb, 1.0, State.ACTIVE) + + with pytest.raises(HttpError, match="busy") as e: + hc.check_health(db) + assert e.value.status == 503 + + # Ensure both calls were attempted + assert mock_http.get.call_count == 2 \ No newline at end of file diff --git a/tests/test_multidb/test_pipeline.py b/tests/test_multidb/test_pipeline.py index ab911562fc..26cad795fb 100644 --- a/tests/test_multidb/test_pipeline.py +++ b/tests/test_multidb/test_pipeline.py @@ -34,21 +34,16 @@ class TestPipeline: indirect=True, ) def test_executes_pipeline_against_correct_db( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): pipe = mock_pipe() pipe.execute.return_value = ['OK1', 'value1'] mock_db1.client.pipeline.return_value = pipe - - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -58,9 +53,7 @@ def test_executes_pipeline_against_correct_db( 
pipe.get('key1') assert pipe.execute() == ['OK1', 'value1'] - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', @@ -75,21 +68,16 @@ def test_executes_pipeline_against_correct_db( indirect=True, ) def test_execute_pipeline_against_correct_db_and_closed_circuit( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): pipe = mock_pipe() pipe.execute.return_value = ['OK1', 'value1'] mock_db1.client.pipeline.return_value = pipe - - for hc in mock_multi_db_config.health_checks: - hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -99,9 +87,7 @@ def test_execute_pipeline_against_correct_db_and_closed_circuit( pipe.get('key1') assert pipe.execute() == ['OK1', 'value1'] - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -120,7 +106,7 @@ def test_execute_pipeline_against_correct_db_and_closed_circuit( indirect=True, ) def test_execute_pipeline_against_correct_db_on_background_health_check_determine_active_db_unhealthy( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): cb = 
PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) cb.database = mock_db @@ -136,11 +122,10 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] @@ -158,11 +143,6 @@ def test_execute_pipeline_against_correct_db_on_background_health_check_determin mock_db2.client.pipeline.return_value = pipe2 mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.health_checks = [ - EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) - ] mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) ) @@ -214,19 +194,14 @@ class TestTransaction: indirect=True, ) def test_executes_transaction_against_correct_db( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): 
mock_db1.client.transaction.return_value = ['OK1', 'value1'] - - for hc in mock_multi_db_config.health_checks: - hc.check_health.return_value = True + mock_hc.check_health.return_value = True client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -236,9 +211,7 @@ def callback(pipe: Pipeline): pipe.get('key1') assert client.transaction(callback) == ['OK1', 'value1'] - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 @pytest.mark.parametrize( 'mock_multi_db_config,mock_db, mock_db1, mock_db2', @@ -253,19 +226,14 @@ def callback(pipe: Pipeline): indirect=True, ) def test_execute_transaction_against_correct_db_and_closed_circuit( - self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc ): databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): mock_db1.client.transaction.return_value = ['OK1', 'value1'] - - for hc in mock_multi_db_config.health_checks: - hc.check_health.side_effect = [False, True, True] + mock_hc.check_health.side_effect = [False, True, True] client = MultiDBClient(mock_multi_db_config) assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 @@ -275,9 +243,7 @@ def callback(pipe: Pipeline): pipe.get('key1') assert client.transaction(callback) == ['OK1', 'value1'] - - for hc in mock_multi_db_config.health_checks: - assert hc.check_health.call_count == 3 + assert mock_hc.check_health.call_count == 3 assert mock_db.circuit.state == CBState.CLOSED assert mock_db1.circuit.state == CBState.CLOSED @@ -312,11 +278,10 @@ def 
test_execute_transaction_against_correct_db_on_background_health_check_deter databases = create_weighted_list(mock_db, mock_db1, mock_db2) - with patch.object( - mock_multi_db_config, - 'databases', - return_value=databases - ): + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] @@ -326,11 +291,6 @@ def test_execute_transaction_against_correct_db_on_background_health_check_deter mock_db2.client.transaction.return_value = ['OK2', 'value'] mock_multi_db_config.health_check_interval = 0.1 - mock_multi_db_config.health_checks = [ - EchoHealthCheck( - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) - ) - ] mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) ) diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index 486dc948f1..45ebddb357 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -84,7 +84,7 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen config = MultiDbConfig( databases_config=db_configs, - health_checks=health_checks, + additional_health_checks=health_checks, command_retry=command_retry, failure_threshold=failure_threshold, health_check_interval=health_check_interval, diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index 071babb6c0..f50511365f 100644 --- 
a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -1,11 +1,13 @@ import json import logging +import os import threading from time import sleep import pytest from redis.client import Pipeline +from redis.multidb.healthcheck import LagAwareHealthCheck from tests.test_scenario.conftest import get_endpoint_config from tests.test_scenario.fault_injector_client import ActionRequest, ActionType @@ -70,6 +72,47 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector assert listener.is_changed_flag == True + @pytest.mark.parametrize( + "r_multi_db", + [ + {"failure_threshold": 2} + ], + indirect=True + ) + def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_injector_client): + event = threading.Event() + thread = threading.Thread( + target=trigger_network_failure_action, + daemon=True, + args=(fault_injector_client,event) + ) + + r_multi_db, listener = r_multi_db + + env0_username = os.getenv('ENV0_USERNAME') + env0_password = os.getenv('ENV0_PASSWORD') + + # Adding additional health check to the client. + r_multi_db.add_health_check( + LagAwareHealthCheck(verify_tls=False, auth_basic=(env0_username,env0_password)) + ) + + # Client initialized on the first command. 
+ r_multi_db.set('key', 'value') + thread.start() + + # Execute commands before network failure + while not event.is_set(): + assert r_multi_db.get('key') == 'value' + sleep(0.1) + + # Execute commands after network failure + for _ in range(3): + assert r_multi_db.get('key') == 'value' + sleep(0.1) + + assert listener.is_changed_flag == True + @pytest.mark.parametrize( "r_multi_db", [ From 8551965ef7c37517063e2a4e33f9d9aa799cc33a Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 14 Aug 2025 16:38:21 +0300 Subject: [PATCH 03/13] Fixed timeouts --- tests/test_scenario/conftest.py | 4 ++-- tests/test_scenario/test_active_active.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index 45ebddb357..73dbaa7d69 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -51,7 +51,6 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen # Retry configuration different for health checks as initial health check require more time in case # if infrastructure wasn't restored from the previous test. 
- health_checks = [EchoHealthCheck(Retry(ExponentialBackoff(cap=5, base=0.5), retries=3))] health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL) event_dispatcher = EventDispatcher() listener = CheckActiveDatabaseChangedListener() @@ -84,9 +83,10 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen config = MultiDbConfig( databases_config=db_configs, - additional_health_checks=health_checks, command_retry=command_retry, failure_threshold=failure_threshold, + health_check_backoff=ExponentialBackoff(cap=0.5, base=0.05), + health_check_retries=3, health_check_interval=health_check_interval, event_dispatcher=event_dispatcher, ) diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index f50511365f..d5858f7d97 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -37,7 +37,7 @@ class TestActiveActiveStandalone: def teardown_method(self, method): # Timeout so the cluster could recover from network failure. 
- sleep(3) + sleep(4) @pytest.mark.parametrize( "r_multi_db", From 3299cd6f18f42a6b65e9b9de7d5f98542b614de5 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 18 Aug 2025 11:41:14 +0300 Subject: [PATCH 04/13] Added lag tollerance parameter --- redis/multidb/healthcheck.py | 22 +++++++++++++++- tests/test_scenario/test_active_active.py | 32 +++++++++++------------ 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 5a47e4b332..3151bc58ae 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -73,6 +73,7 @@ def __init__( self, retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF), rest_api_port: int = 9443, + availability_lag_tolerance: int = 100, timeout: float = DEFAULT_TIMEOUT, auth_basic: Optional[Tuple[str, str]] = None, verify_tls: bool = True, @@ -85,6 +86,24 @@ def __init__( client_key_file: Optional[str] = None, client_key_password: Optional[str] = None, ): + """ + Initialize LagAwareHealthCheck with the specified parameters. 
+ + Args: + retry: Retry configuration for health checks + rest_api_port: Port number for Redis Enterprise REST API (default: 9443) + availability_lag_tolerance: Maximum acceptable lag in milliseconds (default: 100) + timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT) + auth_basic: Tuple of (username, password) for basic authentication + verify_tls: Whether to verify TLS certificates (default: True) + ca_file: Path to CA certificate file for TLS verification + ca_path: Path to CA certificates directory for TLS verification + ca_data: CA certificate data as string or bytes + client_cert_file: Path to client certificate file for mutual TLS + client_key_file: Path to client private key file for mutual TLS + client_key_password: Password for encrypted client private key + """ + super().__init__( retry=retry, ) @@ -101,6 +120,7 @@ def __init__( client_key_password=client_key_password ) self._rest_api_port = rest_api_port + self._availability_lag_tolerance = availability_lag_tolerance def check_health(self, database) -> bool: client = database.client @@ -126,7 +146,7 @@ def check_health(self, database) -> bool: logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb") raise ValueError("Could not find a matching bdb") - url = f"/v1/bdbs/{matching_bdb['uid']}/availability" + url = f"/v1/bdbs/{matching_bdb['uid']}/availability?availability_lag_tolerance_ms={self._availability_lag_tolerance}" self._http_client.get(url, expect_json=False) # Status checked in an http client, otherwise HttpError will be raised diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index d5858f7d97..cebd14b577 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -17,7 +17,7 @@ def trigger_network_failure_action(fault_injector_client, event: threading.Event endpoint_config = get_endpoint_config('re-active-active') action_request = ActionRequest( 
action_type=ActionType.NETWORK_FAILURE, - parameters={"bdb_id": endpoint_config['bdb_id'], "delay": 2, "cluster_index": 0} + parameters={"bdb_id": endpoint_config['bdb_id'], "delay": 3, "cluster_index": 0} ) result = fault_injector_client.trigger_action(action_request) @@ -37,7 +37,7 @@ class TestActiveActiveStandalone: def teardown_method(self, method): # Timeout so the cluster could recover from network failure. - sleep(4) + sleep(5) @pytest.mark.parametrize( "r_multi_db", @@ -63,12 +63,12 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector # Execute commands before network failure while not event.is_set(): assert r_multi_db.get('key') == 'value' - sleep(0.1) + sleep(0.5) # Execute commands after network failure for _ in range(3): assert r_multi_db.get('key') == 'value' - sleep(0.1) + sleep(0.5) assert listener.is_changed_flag == True @@ -104,12 +104,12 @@ def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_inj # Execute commands before network failure while not event.is_set(): assert r_multi_db.get('key') == 'value' - sleep(0.1) + sleep(0.5) # Execute commands after network failure for _ in range(3): assert r_multi_db.get('key') == 'value' - sleep(0.1) + sleep(0.5) assert listener.is_changed_flag == True @@ -152,7 +152,7 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault pipe.get('{hash}key2') pipe.get('{hash}key3') assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - sleep(0.1) + sleep(0.5) # Execute pipeline after network failure for _ in range(3): @@ -164,7 +164,7 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault pipe.get('{hash}key2') pipe.get('{hash}key3') assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - sleep(0.1) + sleep(0.5) assert listener.is_changed_flag == True @@ -206,7 +206,7 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject 
pipe.get('{hash}key2') pipe.get('{hash}key3') assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - sleep(0.1) + sleep(0.5) # Execute pipeline after network failure for _ in range(3): @@ -217,7 +217,7 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_inject pipe.get('{hash}key2') pipe.get('{hash}key3') assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] - sleep(0.1) + sleep(0.5) assert listener.is_changed_flag == True @@ -253,12 +253,12 @@ def callback(pipe: Pipeline): # Execute pipeline before network failure while not event.is_set(): r_multi_db.transaction(callback) - sleep(0.1) + sleep(0.5) # Execute pipeline after network failure for _ in range(3): r_multi_db.transaction(callback) - sleep(0.1) + sleep(0.5) assert listener.is_changed_flag == True @@ -295,12 +295,12 @@ def handler(message): # Execute pipeline before network failure while not event.is_set(): r_multi_db.publish('test-channel', data) - sleep(0.1) + sleep(0.5) # Execute pipeline after network failure for _ in range(3): r_multi_db.publish('test-channel', data) - sleep(0.1) + sleep(0.5) pubsub_thread.stop() @@ -340,12 +340,12 @@ def handler(message): # Execute pipeline before network failure while not event.is_set(): r_multi_db.spublish('test-channel', data) - sleep(0.1) + sleep(0.5) # Execute pipeline after network failure for _ in range(3): r_multi_db.spublish('test-channel', data) - sleep(0.1) + sleep(0.5) pubsub_thread.stop() From 0d88c780e7af8436dc271a625ae41643a6d08b68 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 18 Aug 2025 14:53:48 +0300 Subject: [PATCH 05/13] Decreased messages_count due to increased timeouts --- tests/test_scenario/test_active_active.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index cebd14b577..c3432bc327 100644 --- a/tests/test_scenario/test_active_active.py +++ 
b/tests/test_scenario/test_active_active.py @@ -305,7 +305,7 @@ def handler(message): pubsub_thread.stop() assert listener.is_changed_flag == True - assert messages_count > 5 + assert messages_count > 2 @pytest.mark.parametrize( "r_multi_db", @@ -350,4 +350,4 @@ def handler(message): pubsub_thread.stop() assert listener.is_changed_flag == True - assert messages_count > 5 \ No newline at end of file + assert messages_count > 2 \ No newline at end of file From 229fb2b2e2d3a671744f28baaec765a45d2d6e53 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 18 Aug 2025 15:19:14 +0300 Subject: [PATCH 06/13] Added docblocks --- redis/http/http_client.py | 23 +++++++++++++++++++++++ redis/multidb/healthcheck.py | 1 - 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/redis/http/http_client.py b/redis/http/http_client.py index fae68c712b..67fa375e83 100644 --- a/redis/http/http_client.py +++ b/redis/http/http_client.py @@ -86,6 +86,29 @@ def __init__( auth_basic: Optional[Tuple[str, str]] = None, # (username, password) user_agent: str = DEFAULT_USER_AGENT, ) -> None: + """ + Initialize a new HTTP client instance. + + Args: + base_url: Base URL for all requests. Will be prefixed to all paths. + headers: Default headers to include in all requests. + timeout: Default timeout in seconds for requests. + retry: Retry configuration for failed requests. + verify_tls: Whether to verify TLS certificates. + ca_file: Path to CA certificate file for TLS verification. + ca_path: Path to a directory containing CA certificates. + ca_data: CA certificate data as string or bytes. + client_cert_file: Path to client certificate for mutual TLS. + client_key_file: Path to a client private key for mutual TLS. + client_key_password: Password for an encrypted client private key. + auth_basic: Tuple of (username, password) for HTTP basic auth. + user_agent: User-Agent header value for requests. 
+ + The client supports both regular HTTPS with server verification and mutual TLS + authentication. For server verification, provide CA certificate information via + ca_file, ca_path or ca_data. For mutual TLS, additionally provide a client + certificate and key via client_cert_file and client_key_file. + """ self.base_url = base_url.rstrip("/") + "/" if base_url and not base_url.endswith("/") else base_url self._default_headers = {k.lower(): v for k, v in (headers or {}).items()} self.timeout = timeout diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 3151bc58ae..ef3ee7b855 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -103,7 +103,6 @@ def __init__( client_key_file: Path to client private key file for mutual TLS client_key_password: Password for encrypted client private key """ - super().__init__( retry=retry, ) From 3563bead110c41fd788068cf2b632d0aa44d82e7 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Mon, 18 Aug 2025 15:46:41 +0300 Subject: [PATCH 07/13] Added missing type hints --- redis/http/http_client.py | 10 +++++----- redis/multidb/healthcheck.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/redis/http/http_client.py b/redis/http/http_client.py index 67fa375e83..bd153791b0 100644 --- a/redis/http/http_client.py +++ b/redis/http/http_client.py @@ -109,7 +109,7 @@ def __init__( ca_file, ca_path or ca_data. For mutual TLS, additionally provide a client certificate and key via client_cert_file and client_key_file. 
""" - self.base_url = base_url.rstrip("/") + "/" if base_url and not base_url.endswith("/") else base_url + self.base_url = base_url.rstrip() + "/" if base_url and not base_url.endswith("/") else base_url self._default_headers = {k.lower(): v for k, v in (headers or {}).items()} self.timeout = timeout self.retry = retry @@ -170,7 +170,7 @@ def post( self, path: str, *, - json_body: Any = None, + json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, @@ -191,7 +191,7 @@ def put( self, path: str, *, - json_body: Any = None, + json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, @@ -212,7 +212,7 @@ def patch( self, path: str, *, - json_body: Any = None, + json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, @@ -340,7 +340,7 @@ def _json_call( return resp.json() return resp - def _prepare_body(self, *, json_body: Any = None, data: Optional[Union[bytes, str]] = None) -> Optional[Union[bytes, str]]: + def _prepare_body(self, *, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None) -> Optional[Union[bytes, str]]: if json_body is not None and data is not None: raise ValueError("Provide either json_body or data, not both.") if json_body is not None: diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index ef3ee7b855..f69e804c24 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -28,7 +28,7 @@ def check_health(self, database) -> bool: class AbstractHealthCheck(HealthCheck): def __init__( self, - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, 
backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) ) -> None: self._retry = retry self._retry.update_supported_errors([ConnectionRefusedError]) @@ -45,7 +45,7 @@ def check_health(self, database) -> bool: class EchoHealthCheck(AbstractHealthCheck): def __init__( self, - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF), + retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) ) -> None: """ Check database healthiness by sending an echo request. @@ -71,7 +71,7 @@ class LagAwareHealthCheck(AbstractHealthCheck): """ def __init__( self, - retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF), + retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF), rest_api_port: int = 9443, availability_lag_tolerance: int = 100, timeout: float = DEFAULT_TIMEOUT, From d2c5756c674c57333012ecedd83feb9be7d7d163 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 20 Aug 2025 09:38:49 +0300 Subject: [PATCH 08/13] Fixed url --- redis/multidb/healthcheck.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index f69e804c24..02394470bb 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -145,7 +145,7 @@ def check_health(self, database) -> bool: logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb") raise ValueError("Could not find a matching bdb") - url = f"/v1/bdbs/{matching_bdb['uid']}/availability?availability_lag_tolerance_ms={self._availability_lag_tolerance}" + url = f"/v1/local/bdbs/{matching_bdb['uid']}/endpoint/availability?extend_check=lag&availability_lag_tolerance_ms={self._availability_lag_tolerance}" self._http_client.get(url, expect_json=False) # Status checked in an http client, otherwise HttpError will be raised 
From a57333be72a2df8dfdeeaaa328348a9e00a164a5 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 21 Aug 2025 12:53:08 +0300 Subject: [PATCH 09/13] Refactored tests, URL and cluster support --- redis/cluster.py | 1 + redis/multidb/healthcheck.py | 13 ++++++++----- tests/test_multidb/test_config.py | 8 ++++---- tests/test_multidb/test_healthcheck.py | 4 ++-- tests/test_scenario/conftest.py | 2 -- tests/test_scenario/test_active_active.py | 16 ++++++++-------- 6 files changed, 23 insertions(+), 21 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index 2fd4625e6b..dc91209ed2 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -691,6 +691,7 @@ def __init__( self._event_dispatcher = EventDispatcher() else: self._event_dispatcher = event_dispatcher + self.startup_nodes = startup_nodes self.nodes_manager = NodesManager( startup_nodes=startup_nodes, from_url=from_url, diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index ff49dd74ed..8311a23982 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -86,7 +86,6 @@ def __init__( self, retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF), rest_api_port: int = 9443, - availability_lag_tolerance: int = 100, timeout: float = DEFAULT_TIMEOUT, auth_basic: Optional[Tuple[str, str]] = None, verify_tls: bool = True, @@ -105,7 +104,6 @@ def __init__( Args: retry: Retry configuration for health checks rest_api_port: Port number for Redis Enterprise REST API (default: 9443) - availability_lag_tolerance: Maximum acceptable lag in milliseconds (default: 100) timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT) auth_basic: Tuple of (username, password) for basic authentication verify_tls: Whether to verify TLS certificates (default: True) @@ -132,12 +130,17 @@ def __init__( client_key_password=client_key_password ) self._rest_api_port = rest_api_port - self._availability_lag_tolerance = availability_lag_tolerance def 
check_health(self, database) -> bool: client = database.client - db_host = client.get_connection_kwargs()['host'] + + if isinstance(client, Redis): + db_host = client.get_connection_kwargs()['host'] + else: + db_host = client.startup_nodes[0].host + base_url = f"https://{db_host}:{self._rest_api_port}" + print(base_url) self._http_client.base_url = base_url # Find bdb matching to the current database host @@ -158,7 +161,7 @@ def check_health(self, database) -> bool: logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb") raise ValueError("Could not find a matching bdb") - url = f"/v1/local/bdbs/{matching_bdb['uid']}/endpoint/availability?extend_check=lag&availability_lag_tolerance_ms={self._availability_lag_tolerance}" + url = f"/v1/local/bdbs/{matching_bdb['uid']}/endpoint/availability" self._http_client.get(url, expect_json=False) # Status checked in an http client, otherwise HttpError will be raised diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index 7b72f65bbf..87aae701a9 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -75,7 +75,7 @@ def test_overridden_config(self): config = MultiDbConfig( databases_config=db_configs, failure_detectors=mock_failure_detectors, - additional_health_checks=mock_health_checks, + health_checks=mock_health_checks, health_check_interval=health_check_interval, failover_strategy=mock_failover_strategy, auto_fallback_interval=auto_fallback_interval, @@ -96,9 +96,9 @@ def test_overridden_config(self): assert len(config.failure_detectors) == 2 assert config.failure_detectors[0] == mock_failure_detectors[0] assert config.failure_detectors[1] == mock_failure_detectors[1] - assert len(config.additional_health_checks) == 2 - assert config.additional_health_checks[0] == mock_health_checks[0] - assert config.additional_health_checks[1] == mock_health_checks[1] + assert len(config.health_checks) == 2 + assert config.health_checks[0] == 
mock_health_checks[0] + assert config.health_checks[1] == mock_health_checks[1] assert config.health_check_interval == health_check_interval assert config.failover_strategy == mock_failover_strategy assert config.auto_fallback_interval == auto_fallback_interval diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index ed892d7515..a253ae21d9 100644 --- a/tests/test_multidb/test_healthcheck.py +++ b/tests/test_multidb/test_healthcheck.py @@ -89,7 +89,7 @@ def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, moc first_call = mock_http.get.call_args_list[0] second_call = mock_http.get.call_args_list[1] assert first_call.args[0] == "/v1/bdbs" - assert second_call.args[0] == "/v1/bdbs/bdb-1/availability" + assert second_call.args[0] == "/v1/local/bdbs/bdb-1/endpoint/availability" assert second_call.kwargs.get("expect_json") is False def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb): @@ -121,7 +121,7 @@ def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb assert hc.check_health(db) is True assert mock_http.get.call_count == 2 - assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability" + assert mock_http.get.call_args_list[1].args[0] == "/v1/local/bdbs/bdb-42/endpoint/availability" def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): """ diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index 3f93d5a949..e4c5b1f96f 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -93,12 +93,10 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen databases_config=db_configs, command_retry=command_retry, failure_threshold=failure_threshold, - health_check_backoff=ExponentialBackoff(cap=0.5, base=0.05), health_check_retries=3, health_check_interval=health_check_interval, event_dispatcher=event_dispatcher, 
health_check_backoff=ExponentialBackoff(cap=5, base=0.5), - health_check_retries=3, ) return MultiDBClient(config), listener, endpoint_config \ No newline at end of file diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index 770479ed13..44c57e6b99 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -9,7 +9,6 @@ from redis import Redis, RedisCluster from redis.client import Pipeline from redis.multidb.healthcheck import LagAwareHealthCheck -from tests.test_scenario.conftest import get_endpoint_config from tests.test_scenario.fault_injector_client import ActionRequest, ActionType logger = logging.getLogger(__name__) @@ -76,20 +75,23 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector @pytest.mark.parametrize( "r_multi_db", [ - {"failure_threshold": 2} + {"client_class": Redis, "failure_threshold": 2}, + {"client_class": RedisCluster, "failure_threshold": 2}, ], + ids=["standalone", "cluster"], indirect=True ) + @pytest.mark.timeout(50) def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + event = threading.Event() thread = threading.Thread( target=trigger_network_failure_action, daemon=True, - args=(fault_injector_client,event) + args=(fault_injector_client,config,event) ) - r_multi_db, listener = r_multi_db - env0_username = os.getenv('ENV0_USERNAME') env0_password = os.getenv('ENV0_PASSWORD') @@ -108,12 +110,10 @@ def test_multi_db_client_uses_lag_aware_health_check(self, r_multi_db, fault_inj sleep(0.5) # Execute commands after network failure - for _ in range(3): + while not listener.is_changed_flag: assert r_multi_db.get('key') == 'value' sleep(0.5) - assert listener.is_changed_flag == True - @pytest.mark.parametrize( "r_multi_db", [ From 74bcb7d9f623431197a811bd5c3c954bcd008e3c Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 21 Aug 
2025 13:13:43 +0300 Subject: [PATCH 10/13] Use primary node to send an API request --- redis/cluster.py | 1 - redis/multidb/healthcheck.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index dc91209ed2..2fd4625e6b 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -691,7 +691,6 @@ def __init__( self._event_dispatcher = EventDispatcher() else: self._event_dispatcher = event_dispatcher - self.startup_nodes = startup_nodes self.nodes_manager = NodesManager( startup_nodes=startup_nodes, from_url=from_url, diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 8311a23982..84585a9212 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -137,10 +137,9 @@ def check_health(self, database) -> bool: if isinstance(client, Redis): db_host = client.get_connection_kwargs()['host'] else: - db_host = client.startup_nodes[0].host + db_host = client.get_primaries()[0].host base_url = f"https://{db_host}:{self._rest_api_port}" - print(base_url) self._http_client.base_url = base_url # Find bdb matching to the current database host From 164c31ea6d14ba6b87ae96d48d3249143ae0f1fb Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 21 Aug 2025 13:17:02 +0300 Subject: [PATCH 11/13] Added comment about RE bug --- redis/multidb/healthcheck.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 84585a9212..1f4f0eac19 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -137,6 +137,11 @@ def check_health(self, database) -> bool: if isinstance(client, Redis): db_host = client.get_connection_kwargs()['host'] else: + # We need to use the primary node public IP here and not DNS name. 
+ # + # The bug exists in Redis Enterprise, if you reach REST API by DNS name + # the proxy will choose a random node, and if it's not a primary node it will redirect + # it to the primary node, the redirect will fail due to private IP. db_host = client.get_primaries()[0].host base_url = f"https://{db_host}:{self._rest_api_port}" From a9086047c39d950db0a0f460d185f726c39b1ca5 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 21 Aug 2025 13:20:23 +0300 Subject: [PATCH 12/13] Moved None type to the beginning --- redis/http/http_client.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/redis/http/http_client.py b/redis/http/http_client.py index bd153791b0..0a2de2e44c 100644 --- a/redis/http/http_client.py +++ b/redis/http/http_client.py @@ -132,7 +132,7 @@ def get( self, path: str, *, - params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, expect_json: bool = True @@ -151,7 +151,7 @@ def delete( self, path: str, *, - params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, expect_json: bool = True @@ -172,7 +172,7 @@ def post( *, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, - params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, expect_json: bool = True @@ -193,7 +193,7 @@ def put( *, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, - params: Optional[Mapping[str, Union[str, int, 
float, bool, None, list, tuple]]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, expect_json: bool = True @@ -214,7 +214,7 @@ def patch( *, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None, - params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, expect_json: bool = True @@ -235,7 +235,7 @@ def request( method: str, path: str, *, - params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, body: Optional[Union[bytes, str]] = None, timeout: Optional[float] = None, @@ -320,7 +320,7 @@ def _json_call( method: str, path: str, *, - params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, headers: Optional[Mapping[str, str]] = None, timeout: Optional[float] = None, body: Optional[Union[bytes, str]] = None, @@ -350,7 +350,7 @@ def _prepare_body(self, *, json_body: Optional[Any] = None, data: Optional[Union def _build_url( self, path: str, - params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None, + params: Optional[Mapping[str, Union[None, str, int, float, bool, list, tuple]]] = None, ) -> str: url = urljoin(self.base_url or "", path) if params: From 0f07a81f2cf714eebd74a01dbcd58b85006c8ddd Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 22 Aug 2025 13:55:14 +0300 Subject: [PATCH 13/13] Added health_check_url property to Database class --- redis/cluster.py | 1 + redis/multidb/config.py | 31 
+++++++++++++++++++++++++- redis/multidb/database.py | 28 ++++++++++++++++++++--- redis/multidb/healthcheck.py | 20 ++++++++--------- tests/test_multidb/test_healthcheck.py | 10 ++++----- tests/test_scenario/conftest.py | 29 +++++++++++++++++++++--- 6 files changed, 96 insertions(+), 23 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index 2fd4625e6b..dc91209ed2 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -691,6 +691,7 @@ def __init__( self._event_dispatcher = EventDispatcher() else: self._event_dispatcher = event_dispatcher + self.startup_nodes = startup_nodes self.nodes_manager = NodesManager( startup_nodes=startup_nodes, from_url=from_url, diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 39749a17ae..5555baec44 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -30,12 +30,36 @@ def default_event_dispatcher() -> EventDispatcherInterface: @dataclass class DatabaseConfig: + """ + Dataclass representing the configuration for a database connection. + + This class is used to store configuration settings for a database connection, + including client options, connection sourcing details, circuit breaker settings, + and cluster-specific properties. It provides a structure for defining these + attributes and allows for the creation of customized configurations for various + database setups. + + Attributes: + weight (float): Weight of the database to define the active one. + client_kwargs (dict): Additional parameters for the database client connection. + from_url (Optional[str]): Redis URL way of connecting to the database. + from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. + circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation. + grace_period (float): Grace period after which we need to check if the circuit could be closed again. + health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used + on public Redis Enterprise endpoints. 
+ + Methods: + default_circuit_breaker: + Generates and returns a default CircuitBreaker instance adapted for use. + """ weight: float = 1.0 client_kwargs: dict = field(default_factory=dict) from_url: Optional[str] = None from_pool: Optional[ConnectionPool] = None circuit: Optional[CircuitBreaker] = None grace_period: float = DEFAULT_GRACE_PERIOD + health_check_url: Optional[str] = None def default_circuit_breaker(self) -> CircuitBreaker: circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) @@ -117,7 +141,12 @@ def databases(self) -> Databases: circuit = database_config.default_circuit_breaker() \ if database_config.circuit is None else database_config.circuit databases.add( - Database(client=client, circuit=circuit, weight=database_config.weight), + Database( + client=client, + circuit=circuit, + weight=database_config.weight, + health_check_url=database_config.health_check_url + ), database_config.weight ) diff --git a/redis/multidb/database.py b/redis/multidb/database.py index 3253ffa093..b03e77bd70 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -1,7 +1,7 @@ import redis from abc import ABC, abstractmethod from enum import Enum -from typing import Union +from typing import Union, Optional from redis import RedisCluster from redis.data_structure import WeightedList @@ -45,6 +45,18 @@ def circuit(self, circuit: CircuitBreaker): """Set the circuit breaker for the current database.""" pass + @property + @abstractmethod + def health_check_url(self) -> Optional[str]: + """Health check URL associated with the current database.""" + pass + + @health_check_url.setter + @abstractmethod + def health_check_url(self, health_check_url: Optional[str]): + """Set the health check URL associated with the current database.""" + pass + Databases = WeightedList[tuple[AbstractDatabase, Number]] class Database(AbstractDatabase): @@ -52,7 +64,8 @@ def __init__( self, client: Union[redis.Redis, RedisCluster], circuit: CircuitBreaker, - 
weight: float + weight: float, + health_check_url: Optional[str] = None, ): """ Initialize a new Database instance. @@ -61,12 +74,13 @@ def __init__( client: Underlying Redis client instance for database operations circuit: Circuit breaker for handling database failures weight: Weight value used for database failover prioritization - state: Initial database state, defaults to DISCONNECTED + health_check_url: Health check URL associated with the current database """ self._client = client self._cb = circuit self._cb.database = self self._weight = weight + self._health_check_url = health_check_url @property def client(self) -> Union[redis.Redis, RedisCluster]: @@ -91,3 +105,11 @@ def circuit(self) -> CircuitBreaker: @circuit.setter def circuit(self, circuit: CircuitBreaker): self._cb = circuit + + @property + def health_check_url(self) -> Optional[str]: + return self._health_check_url + + @health_check_url.setter + def health_check_url(self, health_check_url: Optional[str]): + self._health_check_url = health_check_url diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index 1f4f0eac19..63ba415334 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -132,19 +132,17 @@ def __init__( self._rest_api_port = rest_api_port def check_health(self, database) -> bool: - client = database.client + if database.health_check_url is None: + raise ValueError( + "Database health check url is not set. Please check DatabaseConfig for the current database." + ) - if isinstance(client, Redis): - db_host = client.get_connection_kwargs()['host'] + if isinstance(database.client, Redis): + db_host = database.client.get_connection_kwargs()["host"] else: - # We need to use the primary node public IP here and not DNS name. 
- # - # The bug exists in Redis Enterprise, if you reach REST API by DNS name - # the proxy will choose a random node, and if it's not a primary node it will redirect - # it to the primary node, the redirect will fail due to private IP. - db_host = client.get_primaries()[0].host - - base_url = f"https://{db_host}:{self._rest_api_port}" + db_host = database.client.startup_nodes[0].host + + base_url = f"{database.health_check_url}:{self._rest_api_port}" self._http_client.base_url = base_url # Find bdb matching to the current database host diff --git a/tests/test_multidb/test_healthcheck.py b/tests/test_multidb/test_healthcheck.py index a253ae21d9..bc71fdb57d 100644 --- a/tests/test_multidb/test_healthcheck.py +++ b/tests/test_multidb/test_healthcheck.py @@ -79,11 +79,11 @@ def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, moc # Inject our mocked http client hc._http_client = mock_http - db = Database(mock_client, mock_cb, 1.0) + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") assert hc.check_health(db) is True # Base URL must be set correctly - assert hc._http_client.base_url == f"https://{host}:1234" + assert hc._http_client.base_url == f"https://healthcheck.example.com:1234" # Calls: first to list bdbs, then to availability assert mock_http.get.call_count == 2 first_call = mock_http.get.call_args_list[0] @@ -117,7 +117,7 @@ def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb ) hc._http_client = mock_http - db = Database(mock_client, mock_cb, 1.0) + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") assert hc.check_health(db) is True assert mock_http.get.call_count == 2 @@ -142,7 +142,7 @@ def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb): ) hc._http_client = mock_http - db = Database(mock_client, mock_cb, 1.0) + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") with pytest.raises(ValueError, match="Could 
not find a matching bdb"): hc.check_health(db) @@ -170,7 +170,7 @@ def test_propagates_http_error_from_availability(self, mock_client, mock_cb): ) hc._http_client = mock_http - db = Database(mock_client, mock_cb, 1.0) + db = Database(mock_client, mock_cb, 1.0, "https://healthcheck.example.com") with pytest.raises(HttpError, match="busy") as e: hc.check_health(db) diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index e4c5b1f96f..a0f19e1a87 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -1,5 +1,7 @@ import json import os +import re +from urllib.parse import urlparse import pytest @@ -73,7 +75,8 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen 'username': username, 'password': password, 'decode_responses': True, - } + }, + health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][0]) ) db_configs.append(db_config) @@ -84,7 +87,8 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen 'username': username, 'password': password, 'decode_responses': True, - } + }, + health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][1]) ) db_configs.append(db_config1) @@ -99,4 +103,23 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen health_check_backoff=ExponentialBackoff(cap=5, base=0.5), ) - return MultiDBClient(config), listener, endpoint_config \ No newline at end of file + return MultiDBClient(config), listener, endpoint_config + + +def extract_cluster_fqdn(url): + """ + Extract Cluster FQDN from Redis URL + """ + # Parse the URL + parsed = urlparse(url) + + # Extract hostname and port + hostname = parsed.hostname + port = parsed.port + + # Remove the 'redis-XXXX.' prefix using regex + # This pattern matches 'redis-' followed by digits and a dot + cleaned_hostname = re.sub(r'^redis-\d+\.', '', hostname) + + # Reconstruct the URL + return f"https://{cleaned_hostname}" \ No newline at end of file