From b26a03ef543f022b6dd8b0db8293006a549bff2c Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 27 Aug 2025 16:42:32 +0300 Subject: [PATCH 1/9] Extract additional interfaces and abstract classes --- redis/multidb/circuit.py | 82 ++++++----- redis/multidb/client.py | 25 ++-- redis/multidb/command_executor.py | 152 ++++++++++---------- redis/multidb/config.py | 8 +- redis/multidb/database.py | 100 +++++++------ redis/multidb/event.py | 13 +- redis/multidb/failover.py | 11 +- redis/multidb/failure_detector.py | 1 - tests/test_multidb/conftest.py | 12 +- tests/test_multidb/test_circuit.py | 4 +- tests/test_multidb/test_client.py | 4 +- tests/test_multidb/test_config.py | 10 +- tests/test_multidb/test_failure_detector.py | 12 +- 13 files changed, 225 insertions(+), 209 deletions(-) diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 79c8a5f379..221dc556a3 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -45,8 +45,49 @@ def database(self, database): """Set database associated with this circuit.""" pass +class BaseCircuitBreaker(CircuitBreaker): + """ + Base implementation of Circuit Breaker interface. + """ + def __init__(self, cb: pybreaker.CircuitBreaker): + self._cb = cb + self._state_pb_mapper = { + State.CLOSED: self._cb.close, + State.OPEN: self._cb.open, + State.HALF_OPEN: self._cb.half_open, + } + self._database = None + + @property + def grace_period(self) -> float: + return self._cb.reset_timeout + + @grace_period.setter + def grace_period(self, grace_period: float): + self._cb.reset_timeout = grace_period + + @property + def state(self) -> State: + return State(value=self._cb.state.name) + + @state.setter + def state(self, state: State): + self._state_pb_mapper[state]() + + @property + def database(self): + return self._database + + @database.setter + def database(self, database): + self._database = database + +class SyncCircuitBreaker(CircuitBreaker): + """ + Synchronous implementation of Circuit Breaker interface. + """ @abstractmethod - def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]): """Callback called when the state of the circuit changes.""" pass @@ -54,7 +95,7 @@ class PBListener(pybreaker.CircuitBreakerListener): """Wrapper for callback to be compatible with pybreaker implementation.""" def __init__( self, - cb: Callable[[CircuitBreaker, State, State], None], + cb: Callable[[SyncCircuitBreaker, State, State], None], database, ): """ @@ -75,8 +116,7 @@ def state_change(self, cb, old_state, new_state): new_state = State(value=new_state.name) self._cb(cb, old_state, new_state) - -class PBCircuitBreakerAdapter(CircuitBreaker): +class PBCircuitBreakerAdapter(SyncCircuitBreaker, BaseCircuitBreaker): def __init__(self, cb: pybreaker.CircuitBreaker): """ Initialize a PBCircuitBreakerAdapter instance. @@ -87,38 +127,8 @@ def __init__(self, cb: pybreaker.CircuitBreaker): Args: cb: A pybreaker CircuitBreaker instance to be adapted. 
""" - self._cb = cb - self._state_pb_mapper = { - State.CLOSED: self._cb.close, - State.OPEN: self._cb.open, - State.HALF_OPEN: self._cb.half_open, - } - self._database = None - - @property - def grace_period(self) -> float: - return self._cb.reset_timeout - - @grace_period.setter - def grace_period(self, grace_period: float): - self._cb.reset_timeout = grace_period - - @property - def state(self) -> State: - return State(value=self._cb.state.name) - - @state.setter - def state(self, state: State): - self._state_pb_mapper[state]() - - @property - def database(self): - return self._database - - @database.setter - def database(self, database): - self._database = database + super().__init__(cb) - def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]): listener = PBListener(cb, self.database) self._cb.add_listener(listener) \ No newline at end of file diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 56342a7a53..8a0e006977 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -1,15 +1,12 @@ import threading -import socket from typing import List, Any, Callable, Optional from redis.background import BackgroundScheduler -from redis.client import PubSubWorkerThread -from redis.exceptions import ConnectionError, TimeoutError from redis.commands import RedisModuleCommands, CoreCommands from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD -from redis.multidb.circuit import State as CBState, CircuitBreaker -from redis.multidb.database import Database, AbstractDatabase, Databases +from redis.multidb.circuit import State as CBState, SyncCircuitBreaker +from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failure_detector import FailureDetector from redis.multidb.healthcheck import HealthCheck @@ -92,7 +89,7 @@ def get_databases(self) -> Databases: """ return self._databases - def set_active_database(self, database: AbstractDatabase) -> None: + def set_active_database(self, database: SyncDatabase) -> None: """ Promote one of the existing databases to become an active. """ @@ -115,7 +112,7 @@ def set_active_database(self, database: AbstractDatabase) -> None: raise NoValidDatabaseException('Cannot set active database, database is unhealthy') - def add_database(self, database: AbstractDatabase): + def add_database(self, database: SyncDatabase): """ Adds a new database to the database list. 
""" @@ -129,7 +126,7 @@ def add_database(self, database: AbstractDatabase): self._databases.add(database, database.weight) self._change_active_database(database, highest_weighted_db) - def _change_active_database(self, new_database: AbstractDatabase, highest_weight_database: AbstractDatabase): + def _change_active_database(self, new_database: SyncDatabase, highest_weight_database: SyncDatabase): if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: self.command_executor.active_database = new_database @@ -143,7 +140,7 @@ def remove_database(self, database: Database): if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: self.command_executor.active_database = highest_weighted_db - def update_database_weight(self, database: AbstractDatabase, weight: float): + def update_database_weight(self, database: SyncDatabase, weight: float): """ Updates a database from the database list. """ @@ -210,7 +207,7 @@ def pubsub(self, **kwargs): return PubSub(self, **kwargs) - def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Exception], None] = None) -> None: + def _check_db_health(self, database: SyncDatabase, on_error: Callable[[Exception], None] = None) -> None: """ Runs health checks on the given database until first failure. """ @@ -247,7 +244,7 @@ def _check_databases_health(self, on_error: Callable[[Exception], None] = None): for database, _ in self._databases: self._check_db_health(database, on_error) - def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): + def _on_circuit_state_change_callback(self, circuit: SyncCircuitBreaker, old_state: CBState, new_state: CBState): if new_state == CBState.HALF_OPEN: self._check_db_health(circuit.database) return @@ -255,7 +252,7 @@ def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: if old_state == CBState.CLOSED and new_state == CBState.OPEN: self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) -def _half_open_circuit(circuit: CircuitBreaker): +def _half_open_circuit(circuit: SyncCircuitBreaker): circuit.state = CBState.HALF_OPEN @@ -450,8 +447,8 @@ def run_in_thread( exception_handler: Optional[Callable] = None, sharded_pubsub: bool = False, ) -> "PubSubWorkerThread": - return self._client.command_executor.execute_pubsub_run_in_thread( - sleep_time=sleep_time, + return self._client.command_executor.execute_pubsub_run( + sleep_time, daemon=daemon, exception_handler=exception_handler, pubsub=self, diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 094230a31d..364c0a07ea 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import List, Optional, Callable +from typing import List, Optional, Callable, Any from redis.client import Pipeline, PubSub, PubSubWorkerThread from redis.event import EventDispatcherInterface, OnCommandsFailEvent from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL -from redis.multidb.database import Database, AbstractDatabase, Databases +from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.circuit import State as CBState from redis.multidb.event import RegisterCommandFailure, ActiveDatabaseChanged, ResubscribeOnActiveDatabaseChanged from redis.multidb.failover import FailoverStrategy @@ -17,15 +17,40 @@ 
class CommandExecutor(ABC): @property @abstractmethod - def failure_detectors(self) -> List[FailureDetector]: - """Returns a list of failure detectors.""" + def auto_fallback_interval(self) -> float: + """Returns auto-fallback interval.""" pass + @auto_fallback_interval.setter @abstractmethod - def add_failure_detector(self, failure_detector: FailureDetector) -> None: - """Adds new failure detector to the list of failure detectors.""" + def auto_fallback_interval(self, auto_fallback_interval: float) -> None: + """Sets auto-fallback interval.""" pass +class BaseCommandExecutor(CommandExecutor): + def __init__( + self, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + ): + self._auto_fallback_interval = auto_fallback_interval + self._next_fallback_attempt: datetime + + @property + def auto_fallback_interval(self) -> float: + return self._auto_fallback_interval + + @auto_fallback_interval.setter + def auto_fallback_interval(self, auto_fallback_interval: int) -> None: + self._auto_fallback_interval = auto_fallback_interval + + def _schedule_next_fallback(self) -> None: + if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: + return + + self._next_fallback_attempt = datetime.now() + timedelta(seconds=self._auto_fallback_interval) + +class SyncCommandExecutor(CommandExecutor): + @property @abstractmethod def databases(self) -> Databases: @@ -34,19 +59,25 @@ def databases(self) -> Databases: @property @abstractmethod - def active_database(self) -> Optional[Database]: - """Returns currently active database.""" + def failure_detectors(self) -> List[FailureDetector]: + """Returns a list of failure detectors.""" pass - @active_database.setter @abstractmethod - def active_database(self, database: AbstractDatabase) -> None: - """Sets currently active database.""" + def add_failure_detector(self, failure_detector: FailureDetector) -> None: + """Adds a new failure detector to the list of failure detectors.""" pass + @property @abstractmethod - def pubsub(self, **kwargs): - """Initializes a PubSub object on a currently active database""" + def active_database(self) -> Optional[Database]: + """Returns currently active database.""" + pass + + @active_database.setter + @abstractmethod + def active_database(self, database: SyncDatabase) -> None: + """Sets the currently active database.""" pass @property @@ -69,30 +100,41 @@ def failover_strategy(self) -> FailoverStrategy: @property @abstractmethod - def auto_fallback_interval(self) -> float: - """Returns auto-fallback interval.""" + def command_retry(self) -> Retry: + """Returns command retry object.""" pass - @auto_fallback_interval.setter @abstractmethod - def auto_fallback_interval(self, auto_fallback_interval: float) -> None: - """Sets auto-fallback interval.""" + def pubsub(self, **kwargs): + """Initializes a PubSub object on a currently active database""" pass - @property @abstractmethod - def command_retry(self) -> Retry: - """Returns command retry object.""" + def execute_command(self, *args, **options): + """Executes a command and returns the result.""" pass @abstractmethod - def execute_command(self, *args, **options): - """Executes a command and returns the result.""" + def execute_pipeline(self, command_stack: tuple): + """Executes a stack of commands in pipeline.""" pass + @abstractmethod + def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + """Executes a transaction block wrapped in callback.""" + pass -class DefaultCommandExecutor(CommandExecutor): + @abstractmethod + 
def execute_pubsub_method(self, method_name: str, *args, **kwargs): + """Executes a given method on active pub/sub.""" + pass + @abstractmethod + def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + """Executes pub/sub run in a thread.""" + pass + +class DefaultCommandExecutor(SyncCommandExecutor, BaseCommandExecutor): def __init__( self, failure_detectors: List[FailureDetector], @@ -113,22 +155,26 @@ def __init__( event_dispatcher: Interface for dispatching events auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database """ + super().__init__(auto_fallback_interval) + for fd in failure_detectors: fd.set_command_executor(command_executor=self) - self._failure_detectors = failure_detectors self._databases = databases + self._failure_detectors = failure_detectors self._command_retry = command_retry self._failover_strategy = failover_strategy self._event_dispatcher = event_dispatcher - self._auto_fallback_interval = auto_fallback_interval - self._next_fallback_attempt: datetime self._active_database: Optional[Database] = None self._active_pubsub: Optional[PubSub] = None self._active_pubsub_kwargs = {} self._setup_event_dispatcher() self._schedule_next_fallback() + @property + def databases(self) -> Databases: + return self._databases + @property def failure_detectors(self) -> List[FailureDetector]: return self._failure_detectors @@ -136,20 +182,16 @@ def failure_detectors(self) -> List[FailureDetector]: def add_failure_detector(self, failure_detector: FailureDetector) -> None: self._failure_detectors.append(failure_detector) - @property - def databases(self) -> Databases: - return self._databases - @property def command_retry(self) -> Retry: return self._command_retry @property - def active_database(self) -> Optional[AbstractDatabase]: + def active_database(self) -> Optional[SyncDatabase]: return self._active_database @active_database.setter - def active_database(self, database: AbstractDatabase) -> None: + def active_database(self, database: SyncDatabase) -> None: old_active = self._active_database self._active_database = database @@ -170,25 +212,13 @@ def active_pubsub(self, pubsub: PubSub) -> None: def failover_strategy(self) -> FailoverStrategy: return self._failover_strategy - @property - def auto_fallback_interval(self) -> float: - return self._auto_fallback_interval - - @auto_fallback_interval.setter - def auto_fallback_interval(self, auto_fallback_interval: int) -> None: - self._auto_fallback_interval = auto_fallback_interval - def execute_command(self, *args, **options): - """Executes a command and returns the result.""" def callback(): return self._active_database.client.execute_command(*args, **options) return self._execute_with_failure_detection(callback, args) def execute_pipeline(self, command_stack: tuple): - """ - Executes a stack of commands in pipeline. - """ def callback(): with self._active_database.client.pipeline() as pipe: for command, options in command_stack: @@ -199,18 +229,12 @@ def callback(): return self._execute_with_failure_detection(callback, command_stack) def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): - """ - Executes a transaction block wrapped in callback. - """ def callback(): return self._active_database.client.transaction(transaction, *watches, **options) return self._execute_with_failure_detection(callback) def pubsub(self, **kwargs): - """ - Initializes a PubSub object on a currently active database. 
- """ def callback(): if self._active_pubsub is None: self._active_pubsub = self._active_database.client.pubsub(**kwargs) @@ -220,31 +244,15 @@ def callback(): return self._execute_with_failure_detection(callback) def execute_pubsub_method(self, method_name: str, *args, **kwargs): - """ - Executes given method on active pub/sub. - """ def callback(): method = getattr(self.active_pubsub, method_name) return method(*args, **kwargs) return self._execute_with_failure_detection(callback, *args) - def execute_pubsub_run_in_thread( - self, - pubsub, - sleep_time: float = 0.0, - daemon: bool = False, - exception_handler: Optional[Callable] = None, - sharded_pubsub: bool = False, - ) -> "PubSubWorkerThread": + def execute_pubsub_run(self, sleep_time, **kwargs) -> "PubSubWorkerThread": def callback(): - return self._active_pubsub.run_in_thread( - sleep_time, - daemon=daemon, - exception_handler=exception_handler, - pubsub=pubsub, - sharded_pubsub=sharded_pubsub - ) + return self._active_pubsub.run_in_thread(sleep_time, **kwargs) return self._execute_with_failure_detection(callback) @@ -280,12 +288,6 @@ def _check_active_database(self): self.active_database = self._failover_strategy.database self._schedule_next_fallback() - def _schedule_next_fallback(self) -> None: - if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL: - return - - self._next_fallback_attempt = datetime.now() + timedelta(seconds=self._auto_fallback_interval) - def _setup_event_dispatcher(self): """ Registers necessary listeners. diff --git a/redis/multidb/config.py b/redis/multidb/config.py index 5555baec44..a966ec329a 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -9,7 +9,7 @@ from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcher, EventDispatcherInterface -from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.circuit import PBCircuitBreakerAdapter, SyncCircuitBreaker from redis.multidb.database import Database, Databases from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ @@ -44,7 +44,7 @@ class DatabaseConfig: client_kwargs (dict): Additional parameters for the database client connection. from_url (Optional[str]): Redis URL way of connecting to the database. from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. - circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation. + circuit (Optional[SyncCircuitBreaker]): Custom circuit breaker implementation. grace_period (float): Grace period after which we need to check if the circuit could be closed again. health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used on public Redis Enterprise endpoints. 
@@ -57,11 +57,11 @@ class DatabaseConfig: client_kwargs: dict = field(default_factory=dict) from_url: Optional[str] = None from_pool: Optional[ConnectionPool] = None - circuit: Optional[CircuitBreaker] = None + circuit: Optional[SyncCircuitBreaker] = None grace_period: float = DEFAULT_GRACE_PERIOD health_check_url: Optional[str] = None - def default_circuit_breaker(self) -> CircuitBreaker: + def default_circuit_breaker(self) -> SyncCircuitBreaker: circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) return PBCircuitBreakerAdapter(circuit_breaker) diff --git a/redis/multidb/database.py b/redis/multidb/database.py index b03e77bd70..75a662d904 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -5,65 +5,92 @@ from redis import RedisCluster from redis.data_structure import WeightedList -from redis.multidb.circuit import CircuitBreaker +from redis.multidb.circuit import SyncCircuitBreaker from redis.typing import Number class AbstractDatabase(ABC): @property @abstractmethod - def client(self) -> Union[redis.Redis, RedisCluster]: - """The underlying redis client.""" + def weight(self) -> float: + """The weight of this database in compare to others. Used to determine the database failover to.""" pass - @client.setter + @weight.setter @abstractmethod - def client(self, client: Union[redis.Redis, RedisCluster]): - """Set the underlying redis client.""" + def weight(self, weight: float): + """Set the weight of this database in compare to others.""" pass @property @abstractmethod - def weight(self) -> float: - """The weight of this database in compare to others. Used to determine the database failover to.""" + def health_check_url(self) -> Optional[str]: + """Health check URL associated with the current database.""" pass - @weight.setter + @health_check_url.setter @abstractmethod - def weight(self, weight: float): - """Set the weight of this database in compare to others.""" + def health_check_url(self, health_check_url: Optional[str]): + """Set the health check URL associated with the current database.""" pass +class BaseDatabase(AbstractDatabase): + def __init__( + self, + weight: float, + health_check_url: Optional[str] = None, + ): + self._weight = weight + self._health_check_url = health_check_url + + @property + def weight(self) -> float: + return self._weight + + @weight.setter + def weight(self, weight: float): + self._weight = weight + + @property + def health_check_url(self) -> Optional[str]: + return self._health_check_url + + @health_check_url.setter + def health_check_url(self, health_check_url: Optional[str]): + self._health_check_url = health_check_url + +class SyncDatabase(AbstractDatabase): + """Database with an underlying synchronous redis client.""" @property @abstractmethod - def circuit(self) -> CircuitBreaker: - """Circuit breaker for the current database.""" + def client(self) -> Union[redis.Redis, RedisCluster]: + """The underlying redis client.""" pass - @circuit.setter + @client.setter @abstractmethod - def circuit(self, circuit: CircuitBreaker): - """Set the circuit breaker for the current database.""" + def client(self, client: Union[redis.Redis, RedisCluster]): + """Set the underlying redis client.""" pass @property @abstractmethod - def health_check_url(self) -> Optional[str]: - """Health check URL associated with the current database.""" + def circuit(self) -> SyncCircuitBreaker: + """Circuit breaker for the current database.""" pass - @health_check_url.setter + @circuit.setter @abstractmethod - def health_check_url(self, 
health_check_url: Optional[str]): - """Set the health check URL associated with the current database.""" + def circuit(self, circuit: SyncCircuitBreaker): + """Set the circuit breaker for the current database.""" pass -Databases = WeightedList[tuple[AbstractDatabase, Number]] +Databases = WeightedList[tuple[SyncDatabase, Number]] -class Database(AbstractDatabase): +class Database(BaseDatabase, SyncDatabase): def __init__( self, client: Union[redis.Redis, RedisCluster], - circuit: CircuitBreaker, + circuit: SyncCircuitBreaker, weight: float, health_check_url: Optional[str] = None, ): @@ -79,8 +106,7 @@ def __init__( self._client = client self._cb = circuit self._cb.database = self - self._weight = weight - self._health_check_url = health_check_url + super().__init__(weight, health_check_url) @property def client(self) -> Union[redis.Redis, RedisCluster]: @@ -91,25 +117,9 @@ def client(self, client: Union[redis.Redis, RedisCluster]): self._client = client @property - def weight(self) -> float: - return self._weight - - @weight.setter - def weight(self, weight: float): - self._weight = weight - - @property - def circuit(self) -> CircuitBreaker: + def circuit(self) -> SyncCircuitBreaker: return self._cb @circuit.setter - def circuit(self, circuit: CircuitBreaker): - self._cb = circuit - - @property - def health_check_url(self) -> Optional[str]: - return self._health_check_url - - @health_check_url.setter - def health_check_url(self, health_check_url: Optional[str]): - self._health_check_url = health_check_url + def circuit(self, circuit: SyncCircuitBreaker): + self._cb = circuit \ No newline at end of file diff --git a/redis/multidb/event.py b/redis/multidb/event.py index 2598bc4d06..bca9482347 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -1,8 +1,7 @@ from typing import List from redis.event import EventListenerInterface, OnCommandsFailEvent -from redis.multidb.config import Databases -from redis.multidb.database import AbstractDatabase +from redis.multidb.database import SyncDatabase from redis.multidb.failure_detector import FailureDetector class ActiveDatabaseChanged: @@ -11,8 +10,8 @@ class ActiveDatabaseChanged: """ def __init__( self, - old_database: AbstractDatabase, - new_database: AbstractDatabase, + old_database: SyncDatabase, + new_database: SyncDatabase, command_executor, **kwargs ): @@ -22,11 +21,11 @@ def __init__( self._kwargs = kwargs @property - def old_database(self) -> AbstractDatabase: + def old_database(self) -> SyncDatabase: return self._old_database @property - def new_database(self) -> AbstractDatabase: + def new_database(self) -> SyncDatabase: return self._new_database @property @@ -39,7 +38,7 @@ def kwargs(self): class ResubscribeOnActiveDatabaseChanged(EventListenerInterface): """ - Re-subscribe currently active pub/sub to a new active database. + Re-subscribe the currently active pub / sub to a new active database. 
""" def listen(self, event: ActiveDatabaseChanged): old_pubsub = event.command_executor.active_pubsub diff --git a/redis/multidb/failover.py b/redis/multidb/failover.py index d6cf198678..fd08b77ecd 100644 --- a/redis/multidb/failover.py +++ b/redis/multidb/failover.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod from redis.data_structure import WeightedList -from redis.multidb.database import Databases -from redis.multidb.database import AbstractDatabase +from redis.multidb.database import Databases, SyncDatabase from redis.multidb.circuit import State as CBState from redis.multidb.exception import NoValidDatabaseException from redis.retry import Retry @@ -13,13 +12,13 @@ class FailoverStrategy(ABC): @property @abstractmethod - def database(self) -> AbstractDatabase: + def database(self) -> SyncDatabase: """Select the database according to the strategy.""" pass @abstractmethod def set_databases(self, databases: Databases) -> None: - """Set the databases strategy operates on.""" + """Set the database strategy operates on.""" pass class WeightBasedFailoverStrategy(FailoverStrategy): @@ -35,7 +34,7 @@ def __init__( self._databases = WeightedList() @property - def database(self) -> AbstractDatabase: + def database(self) -> SyncDatabase: return self._retry.call_with_retry( lambda: self._get_active_database(), lambda _: dummy_fail() @@ -44,7 +43,7 @@ def database(self) -> AbstractDatabase: def set_databases(self, databases: Databases) -> None: self._databases = databases - def _get_active_database(self) -> AbstractDatabase: + def _get_active_database(self) -> SyncDatabase: for database, _ in self._databases: if database.circuit.state == CBState.CLOSED: return database diff --git a/redis/multidb/failure_detector.py b/redis/multidb/failure_detector.py index 3280fa6c32..ef4bd35f69 100644 --- a/redis/multidb/failure_detector.py +++ b/redis/multidb/failure_detector.py @@ -24,7 +24,6 @@ class CommandFailureDetector(FailureDetector): """ Detects a failure based on a threshold of failed commands during a specific period of time. 
""" - def __init__( self, threshold: int, diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index a34ef01476..9503d79d9b 100644 --- a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -4,7 +4,7 @@ from redis import Redis from redis.data_structure import WeightedList -from redis.multidb.circuit import CircuitBreaker, State as CBState +from redis.multidb.circuit import State as CBState, SyncCircuitBreaker from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, Databases @@ -19,8 +19,8 @@ def mock_client() -> Redis: return Mock(spec=Redis) @pytest.fixture() -def mock_cb() -> CircuitBreaker: - return Mock(spec=CircuitBreaker) +def mock_cb() -> SyncCircuitBreaker: + return Mock(spec=SyncCircuitBreaker) @pytest.fixture() def mock_fd() -> FailureDetector: @@ -41,7 +41,7 @@ def mock_db(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) + mock_cb = Mock(spec=SyncCircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -55,7 +55,7 @@ def mock_db1(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) + mock_cb = Mock(spec=SyncCircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -69,7 +69,7 @@ def mock_db2(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=CircuitBreaker) + mock_cb = Mock(spec=SyncCircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) diff --git a/tests/test_multidb/test_circuit.py b/tests/test_multidb/test_circuit.py index 7dc642373b..f5f39c3f6b 100644 --- a/tests/test_multidb/test_circuit.py +++ b/tests/test_multidb/test_circuit.py @@ -1,7 +1,7 @@ import pybreaker import pytest -from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker, SyncCircuitBreaker class TestPBCircuitBreaker: @@ -39,7 +39,7 @@ def test_cb_executes_callback_on_state_changed(self): adapter = PBCircuitBreakerAdapter(cb=pb_circuit) called_count = 0 - def callback(cb: CircuitBreaker, old_state: CbState, new_state: CbState): + def callback(cb: SyncCircuitBreaker, old_state: CbState, new_state: CbState): nonlocal called_count assert old_state == CbState.CLOSED assert new_state == CbState.HALF_OPEN diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index 193980d37c..c7c15fe684 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -8,7 +8,7 @@ from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter from redis.multidb.config import DEFAULT_FAILOVER_RETRIES, \ DEFAULT_FAILOVER_BACKOFF -from redis.multidb.database import AbstractDatabase +from redis.multidb.database import SyncDatabase from redis.multidb.client import MultiDBClient from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failover import WeightBasedFailoverStrategy @@ -458,7 +458,7 @@ def test_set_active_database( assert client.set('key', 'value') == 'OK' with pytest.raises(ValueError, match='Given database is not a member of database list'): - 
client.set_active_database(Mock(spec=AbstractDatabase)) + client.set_active_database(Mock(spec=SyncDatabase)) mock_hc.check_health.return_value = False diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index 87aae701a9..e428b3ce7a 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -1,6 +1,6 @@ from unittest.mock import Mock from redis.connection import ConnectionPool -from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter +from redis.multidb.circuit import PBCircuitBreakerAdapter, SyncCircuitBreaker from redis.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig, DEFAULT_GRACE_PERIOD from redis.multidb.database import Database @@ -49,11 +49,11 @@ def test_overridden_config(self): mock_connection_pools[0].connection_kwargs = {} mock_connection_pools[1].connection_kwargs = {} mock_connection_pools[2].connection_kwargs = {} - mock_cb1 = Mock(spec=CircuitBreaker) + mock_cb1 = Mock(spec=SyncCircuitBreaker) mock_cb1.grace_period = grace_period - mock_cb2 = Mock(spec=CircuitBreaker) + mock_cb2 = Mock(spec=SyncCircuitBreaker) mock_cb2.grace_period = grace_period - mock_cb3 = Mock(spec=CircuitBreaker) + mock_cb3 = Mock(spec=SyncCircuitBreaker) mock_cb3.grace_period = grace_period mock_failure_detectors = [Mock(spec=FailureDetector), Mock(spec=FailureDetector)] mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] @@ -113,7 +113,7 @@ def test_default_config(self): def test_overridden_config(self): mock_connection_pool = Mock(spec=ConnectionPool) - mock_circuit = Mock(spec=CircuitBreaker) + mock_circuit = Mock(spec=SyncCircuitBreaker) config = DatabaseConfig( client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit diff --git a/tests/test_multidb/test_failure_detector.py b/tests/test_multidb/test_failure_detector.py index 86d6e1cd82..28687f2a11 100644 --- a/tests/test_multidb/test_failure_detector.py +++ b/tests/test_multidb/test_failure_detector.py @@ -3,7 +3,7 @@ import pytest -from redis.multidb.command_executor import CommandExecutor +from redis.multidb.command_executor import SyncCommandExecutor from redis.multidb.failure_detector import CommandFailureDetector from redis.multidb.circuit import State as CBState from redis.exceptions import ConnectionError @@ -19,7 +19,7 @@ class TestCommandFailureDetector: ) def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exceed(self, mock_db): fd = CommandFailureDetector(5, 1) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED @@ -41,7 +41,7 @@ def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exce ) def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interval_not_exceed(self, mock_db): fd = CommandFailureDetector(5, 1) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED @@ -62,7 +62,7 @@ def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interv ) def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_exceed(self, mock_db): fd = CommandFailureDetector(5, 0.3) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) 
mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED @@ -96,7 +96,7 @@ def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_e ) def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): fd = CommandFailureDetector(5, 0.3) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED @@ -128,7 +128,7 @@ def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): ) def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self, mock_db): fd = CommandFailureDetector(5, 1, error_types=[ConnectionError]) - mock_ce = Mock(spec=CommandExecutor) + mock_ce = Mock(spec=SyncCommandExecutor) mock_ce.active_database = mock_db fd.set_command_executor(mock_ce) assert mock_db.circuit.state == CBState.CLOSED From bad9bcc32a69265cf1c5709b41b4867362e8007b Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 29 Aug 2025 11:37:36 +0300 Subject: [PATCH 2/9] Added base async components --- redis/asyncio/multidb/__init__.py | 0 redis/asyncio/multidb/circuit.py | 26 +++ redis/asyncio/multidb/command_executor.py | 95 +++++++++++ redis/asyncio/multidb/database.py | 67 ++++++++ redis/asyncio/multidb/event.py | 65 ++++++++ redis/asyncio/multidb/failover.py | 51 ++++++ redis/asyncio/multidb/failure_detector.py | 29 ++++ redis/asyncio/multidb/healthcheck.py | 75 +++++++++ redis/event.py | 3 + redis/multidb/circuit.py | 4 +- redis/utils.py | 6 + tests/test_asyncio/test_multidb/__init__.py | 0 tests/test_asyncio/test_multidb/conftest.py | 59 +++++++ .../test_asyncio/test_multidb/test_circuit.py | 58 +++++++ .../test_multidb/test_failover.py | 121 ++++++++++++++ .../test_multidb/test_failure_detector.py | 153 ++++++++++++++++++ .../test_multidb/test_healthcheck.py | 48 ++++++ 17 files changed, 858 insertions(+), 2 deletions(-) create mode 100644 redis/asyncio/multidb/__init__.py create mode 100644 redis/asyncio/multidb/circuit.py create mode 100644 redis/asyncio/multidb/command_executor.py create mode 100644 redis/asyncio/multidb/database.py create mode 100644 redis/asyncio/multidb/event.py create mode 100644 redis/asyncio/multidb/failover.py create mode 100644 redis/asyncio/multidb/failure_detector.py create mode 100644 redis/asyncio/multidb/healthcheck.py create mode 100644 tests/test_asyncio/test_multidb/__init__.py create mode 100644 tests/test_asyncio/test_multidb/conftest.py create mode 100644 tests/test_asyncio/test_multidb/test_circuit.py create mode 100644 tests/test_asyncio/test_multidb/test_failover.py create mode 100644 tests/test_asyncio/test_multidb/test_failure_detector.py create mode 100644 tests/test_asyncio/test_multidb/test_healthcheck.py diff --git a/redis/asyncio/multidb/__init__.py b/redis/asyncio/multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/redis/asyncio/multidb/circuit.py b/redis/asyncio/multidb/circuit.py new file mode 100644 index 0000000000..97411e6e42 --- /dev/null +++ b/redis/asyncio/multidb/circuit.py @@ -0,0 +1,26 @@ +from abc import abstractmethod +from typing import Callable + +import pybreaker + +from redis.multidb.circuit import CircuitBreaker, State, BaseCircuitBreaker, PBCircuitBreakerAdapter + + +class AsyncCircuitBreaker(CircuitBreaker): + """Async implementation of Circuit Breaker interface.""" + + @abstractmethod + async def on_state_changed(self, cb: 
Callable[["CircuitBreaker", State, State], None]): + """Callback called when the state of the circuit changes.""" + pass + +class AsyncPBCircuitBreakerAdapter(BaseCircuitBreaker, AsyncCircuitBreaker): + """ + Async adapter for pybreaker's CircuitBreaker implementation. + """ + def __init__(self, cb: pybreaker.CircuitBreaker): + super().__init__(cb) + self._sync_cb = PBCircuitBreakerAdapter(cb) + + async def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + self._sync_cb.on_state_changed(cb) \ No newline at end of file diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py new file mode 100644 index 0000000000..18117160ee --- /dev/null +++ b/redis/asyncio/multidb/command_executor.py @@ -0,0 +1,95 @@ +from abc import abstractmethod +from typing import List, Optional, Callable, Any + +from redis.asyncio.client import PubSub, Pipeline +from redis.asyncio.multidb.database import Databases, AsyncDatabase +from redis.asyncio.multidb.failover import AsyncFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.retry import Retry +from redis.multidb.command_executor import CommandExecutor + + +class AsyncCommandExecutor(CommandExecutor): + + @property + @abstractmethod + def databases(self) -> Databases: + """Returns a list of databases.""" + pass + + @property + @abstractmethod + def failure_detectors(self) -> List[AsyncFailureDetector]: + """Returns a list of failure detectors.""" + pass + + @abstractmethod + def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None: + """Adds a new failure detector to the list of failure detectors.""" + pass + + @property + @abstractmethod + def active_database(self) -> Optional[AsyncDatabase]: + """Returns currently active database.""" + pass + + @active_database.setter + @abstractmethod + def active_database(self, database: AsyncDatabase) -> None: + """Sets the currently active database.""" + pass + + @property + @abstractmethod + def active_pubsub(self) -> Optional[PubSub]: + """Returns currently active pubsub.""" + pass + + @active_pubsub.setter + @abstractmethod + def active_pubsub(self, pubsub: PubSub) -> None: + """Sets currently active pubsub.""" + pass + + @property + @abstractmethod + def failover_strategy(self) -> AsyncFailoverStrategy: + """Returns failover strategy.""" + pass + + @property + @abstractmethod + def command_retry(self) -> Retry: + """Returns command retry object.""" + pass + + @abstractmethod + async def pubsub(self, **kwargs): + """Initializes a PubSub object on a currently active database""" + pass + + @abstractmethod + async def execute_command(self, *args, **options): + """Executes a command and returns the result.""" + pass + + @abstractmethod + async def execute_pipeline(self, command_stack: tuple): + """Executes a stack of commands in pipeline.""" + pass + + @abstractmethod + async def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + """Executes a transaction block wrapped in callback.""" + pass + + @abstractmethod + def execute_pubsub_method(self, method_name: str, *args, **kwargs): + """Executes a given method on active pub/sub.""" + pass + + @abstractmethod + def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + """Executes pub/sub run in a thread.""" + pass \ No newline at end of file diff --git a/redis/asyncio/multidb/database.py b/redis/asyncio/multidb/database.py new file mode 100644 index 0000000000..85320f3aaa --- 
/dev/null +++ b/redis/asyncio/multidb/database.py @@ -0,0 +1,67 @@ +from abc import abstractmethod +from typing import Union, Optional + +from redis.asyncio import Redis, RedisCluster +from redis.asyncio.multidb.circuit import AsyncCircuitBreaker +from redis.data_structure import WeightedList +from redis.multidb.database import AbstractDatabase, BaseDatabase +from redis.typing import Number + + +class AsyncDatabase(AbstractDatabase): + """Database with an underlying asynchronous redis client.""" + @property + @abstractmethod + def client(self) -> Union[Redis, RedisCluster]: + """The underlying redis client.""" + pass + + @client.setter + @abstractmethod + def client(self, client: Union[Redis, RedisCluster]): + """Set the underlying redis client.""" + pass + + @property + @abstractmethod + def circuit(self) -> AsyncCircuitBreaker: + """Circuit breaker for the current database.""" + pass + + @circuit.setter + @abstractmethod + def circuit(self, circuit: AsyncCircuitBreaker): + """Set the circuit breaker for the current database.""" + pass + +Databases = WeightedList[tuple[AsyncDatabase, Number]] + +class Database(BaseDatabase, AsyncDatabase): + def __init__( + self, + client: Union[Redis, RedisCluster], + circuit: AsyncCircuitBreaker, + weight: float, + health_check_url: Optional[str] = None, + ): + self._client = client + self._cb = circuit + self._cb.database = self + super().__init__(weight, health_check_url) + + @property + def client(self) -> Union[Redis, RedisCluster]: + return self._client + + @client.setter + def client(self, client: Union[Redis, RedisCluster]): + self._client = client + + @property + def circuit(self) -> AsyncCircuitBreaker: + return self._cb + + @circuit.setter + def circuit(self, circuit: AsyncCircuitBreaker): + self._cb = circuit + diff --git a/redis/asyncio/multidb/event.py b/redis/asyncio/multidb/event.py new file mode 100644 index 0000000000..ea5534ce86 --- /dev/null +++ b/redis/asyncio/multidb/event.py @@ -0,0 +1,65 @@ +from typing import List + +from redis.asyncio.multidb.database import AsyncDatabase +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.event import AsyncEventListenerInterface, AsyncOnCommandsFailEvent + + +class AsyncActiveDatabaseChanged: + """ + Event fired when an async active database has been changed. + """ + def __init__( + self, + old_database: AsyncDatabase, + new_database: AsyncDatabase, + command_executor, + **kwargs + ): + self._old_database = old_database + self._new_database = new_database + self._command_executor = command_executor + self._kwargs = kwargs + + @property + def old_database(self) -> AsyncDatabase: + return self._old_database + + @property + def new_database(self) -> AsyncDatabase: + return self._new_database + + @property + def command_executor(self): + return self._command_executor + + @property + def kwargs(self): + return self._kwargs + +class ResubscribeOnActiveDatabaseChanged(AsyncEventListenerInterface): + """ + Re-subscribe the currently active pub / sub to a new active database. + """ + async def listen(self, event: AsyncActiveDatabaseChanged): + old_pubsub = event.command_executor.active_pubsub + + if old_pubsub is not None: + # Re-assign old channels and patterns so they will be automatically subscribed on connection. 
+ new_pubsub = event.new_database.client.pubsub(**event.kwargs) + new_pubsub.channels = old_pubsub.channels + new_pubsub.patterns = old_pubsub.patterns + await new_pubsub.on_connect(None) + event.command_executor.active_pubsub = new_pubsub + await old_pubsub.close() + +class RegisterCommandFailure(AsyncEventListenerInterface): + """ + Event listener that registers command failures and passing it to the failure detectors. + """ + def __init__(self, failure_detectors: List[AsyncFailureDetector]): + self._failure_detectors = failure_detectors + + async def listen(self, event: AsyncOnCommandsFailEvent) -> None: + for failure_detector in self._failure_detectors: + await failure_detector.register_failure(event.exception, event.commands) \ No newline at end of file diff --git a/redis/asyncio/multidb/failover.py b/redis/asyncio/multidb/failover.py new file mode 100644 index 0000000000..ad7f25ce41 --- /dev/null +++ b/redis/asyncio/multidb/failover.py @@ -0,0 +1,51 @@ +from abc import abstractmethod, ABC + +from redis.asyncio.multidb.database import AsyncDatabase, Databases +from redis.multidb.circuit import State as CBState +from redis.asyncio.retry import Retry +from redis.data_structure import WeightedList +from redis.multidb.exception import NoValidDatabaseException +from redis.utils import dummy_fail_async + + +class AsyncFailoverStrategy(ABC): + + @property + @abstractmethod + async def database(self) -> AsyncDatabase: + """Select the database according to the strategy.""" + pass + + @abstractmethod + def set_databases(self, databases: Databases) -> None: + """Set the database strategy operates on.""" + pass + +class WeightBasedFailoverStrategy(AsyncFailoverStrategy): + """ + Failover strategy based on database weights. + """ + def __init__( + self, + retry: Retry + ): + self._retry = retry + self._retry.update_supported_errors([NoValidDatabaseException]) + self._databases = WeightedList() + + @property + async def database(self) -> AsyncDatabase: + return await self._retry.call_with_retry( + lambda: self._get_active_database(), + lambda _: dummy_fail_async() + ) + + def set_databases(self, databases: Databases) -> None: + self._databases = databases + + async def _get_active_database(self) -> AsyncDatabase: + for database, _ in self._databases: + if database.circuit.state == CBState.CLOSED: + return database + + raise NoValidDatabaseException('No valid database available for communication') \ No newline at end of file diff --git a/redis/asyncio/multidb/failure_detector.py b/redis/asyncio/multidb/failure_detector.py new file mode 100644 index 0000000000..8aa4752924 --- /dev/null +++ b/redis/asyncio/multidb/failure_detector.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod + +from redis.multidb.failure_detector import FailureDetector + + +class AsyncFailureDetector(ABC): + + @abstractmethod + async def register_failure(self, exception: Exception, cmd: tuple) -> None: + """Register a failure that occurred during command execution.""" + pass + + @abstractmethod + def set_command_executor(self, command_executor) -> None: + """Set the command executor for this failure.""" + pass + +class FailureDetectorAsyncWrapper(AsyncFailureDetector): + """ + Async wrapper for the failure detector. 
+ """ + def __init__(self, failure_detector: FailureDetector) -> None: + self._failure_detector = failure_detector + + async def register_failure(self, exception: Exception, cmd: tuple) -> None: + self._failure_detector.register_failure(exception, cmd) + + def set_command_executor(self, command_executor) -> None: + self._failure_detector.set_command_executor(command_executor) \ No newline at end of file diff --git a/redis/asyncio/multidb/healthcheck.py b/redis/asyncio/multidb/healthcheck.py new file mode 100644 index 0000000000..7ae7bf34de --- /dev/null +++ b/redis/asyncio/multidb/healthcheck.py @@ -0,0 +1,75 @@ +import logging +from abc import ABC, abstractmethod + +from redis.asyncio import Redis +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialWithJitterBackoff +from redis.utils import dummy_fail_async + +DEFAULT_HEALTH_CHECK_RETRIES = 3 +DEFAULT_HEALTH_CHECK_BACKOFF = ExponentialWithJitterBackoff(cap=10) + +logger = logging.getLogger(__name__) + +class HealthCheck(ABC): + + @property + @abstractmethod + def retry(self) -> Retry: + """The retry object to use for health checks.""" + pass + + @abstractmethod + async def check_health(self, database) -> bool: + """Function to determine the health status.""" + pass + +class AbstractHealthCheck(HealthCheck): + def __init__( + self, + retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) -> None: + self._retry = retry + self._retry.update_supported_errors([ConnectionRefusedError]) + + @property + def retry(self) -> Retry: + return self._retry + + @abstractmethod + async def check_health(self, database) -> bool: + pass + +class EchoHealthCheck(AbstractHealthCheck): + def __init__( + self, + retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + ) -> None: + """ + Check database healthiness by sending an echo request. + """ + super().__init__( + retry=retry, + ) + async def check_health(self, database) -> bool: + return await self._retry.call_with_retry( + lambda: self._returns_echoed_message(database), + lambda _: dummy_fail_async() + ) + + async def _returns_echoed_message(self, database) -> bool: + expected_message = ["healthcheck", b"healthcheck"] + + if isinstance(database.client, Redis): + actual_message = await database.client.execute_command("ECHO" ,"healthcheck") + return actual_message in expected_message + else: + # For a cluster checks if all nodes are healthy. + all_nodes = database.client.get_nodes() + for node in all_nodes: + actual_message = await node.redis_connection.execute_command("ECHO" ,"healthcheck") + + if actual_message not in expected_message: + return False + + return True \ No newline at end of file diff --git a/redis/event.py b/redis/event.py index 1fa66f0587..4d167442eb 100644 --- a/redis/event.py +++ b/redis/event.py @@ -271,6 +271,9 @@ def commands(self) -> tuple: def exception(self) -> Exception: return self._exception +class AsyncOnCommandsFailEvent(OnCommandsFailEvent): + pass + class ReAuthConnectionListener(EventListenerInterface): """ Listener that performs re-authentication of given connection. diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 221dc556a3..576ee27fab 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -87,7 +87,7 @@ class SyncCircuitBreaker(CircuitBreaker): Synchronous implementation of Circuit Breaker interface. 
""" @abstractmethod - def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]): + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): """Callback called when the state of the circuit changes.""" pass @@ -129,6 +129,6 @@ def __init__(self, cb: pybreaker.CircuitBreaker): """ super().__init__(cb) - def on_state_changed(self, cb: Callable[["SyncCircuitBreaker", State, State], None]): + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): listener = PBListener(cb, self.database) self._cb.add_listener(listener) \ No newline at end of file diff --git a/redis/utils.py b/redis/utils.py index 94bfab61bb..1800582e46 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -314,3 +314,9 @@ def dummy_fail(): Fake function for a Retry object if you don't need to handle each failure. """ pass + +async def dummy_fail_async(): + """ + Async fake function for a Retry object if you don't need to handle each failure. + """ + pass \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/__init__.py b/tests/test_asyncio/test_multidb/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_asyncio/test_multidb/conftest.py b/tests/test_asyncio/test_multidb/conftest.py new file mode 100644 index 0000000000..1f67e3c63c --- /dev/null +++ b/tests/test_asyncio/test_multidb/conftest.py @@ -0,0 +1,59 @@ +from unittest.mock import Mock + +import pytest + +from redis.multidb.circuit import State as CBState +from redis.asyncio import Redis +from redis.asyncio.multidb.circuit import AsyncCircuitBreaker +from redis.asyncio.multidb.database import Database + + +@pytest.fixture() +def mock_client() -> Redis: + return Mock(spec=Redis) + +@pytest.fixture() +def mock_cb() -> AsyncCircuitBreaker: + return Mock(spec=AsyncCircuitBreaker) + +@pytest.fixture() +def mock_db(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=AsyncCircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_db1(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=AsyncCircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db + +@pytest.fixture() +def mock_db2(request) -> Database: + db = Mock(spec=Database) + db.weight = request.param.get("weight", 1.0) + db.client = Mock(spec=Redis) + + cb = request.param.get("circuit", {}) + mock_cb = Mock(spec=AsyncCircuitBreaker) + mock_cb.grace_period = cb.get("grace_period", 1.0) + mock_cb.state = cb.get("state", CBState.CLOSED) + + db.circuit = mock_cb + return db \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_circuit.py b/tests/test_asyncio/test_multidb/test_circuit.py new file mode 100644 index 0000000000..b1080cfc7d --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_circuit.py @@ -0,0 +1,58 @@ +import pybreaker +import pytest + +from redis.asyncio.multidb.circuit import ( + AsyncPBCircuitBreakerAdapter, + State as CbState, +) +from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter + + +class TestAsyncPBCircuitBreaker: + 
@pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CbState.CLOSED}}, + ], + indirect=True, + ) + async def test_cb_correctly_configured(self, mock_db): + pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) + adapter = AsyncPBCircuitBreakerAdapter(cb=pb_circuit) + assert adapter.state == CbState.CLOSED + + adapter.state = CbState.OPEN + assert adapter.state == CbState.OPEN + + adapter.state = CbState.HALF_OPEN + assert adapter.state == CbState.HALF_OPEN + + adapter.state = CbState.CLOSED + assert adapter.state == CbState.CLOSED + + assert adapter.grace_period == 5 + adapter.grace_period = 10 + + assert adapter.grace_period == 10 + + adapter.database = mock_db + assert adapter.database == mock_db + + @pytest.mark.asyncio + async def test_cb_executes_callback_on_state_changed(self): + pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) + adapter = AsyncPBCircuitBreakerAdapter(cb=pb_circuit) + called_count = 0 + + def callback(cb: CircuitBreaker, old_state: CbState, new_state: CbState): + nonlocal called_count + assert old_state == CbState.CLOSED + assert new_state == CbState.HALF_OPEN + assert isinstance(cb, PBCircuitBreakerAdapter) + called_count += 1 + + await adapter.on_state_changed(callback) + adapter.state = CbState.HALF_OPEN + + assert called_count == 1 \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_failover.py b/tests/test_asyncio/test_multidb/test_failover.py new file mode 100644 index 0000000000..d7bc4411b6 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_failover.py @@ -0,0 +1,121 @@ +from unittest.mock import PropertyMock + +import pytest + +from redis.backoff import NoBackoff, ExponentialBackoff +from redis.data_structure import WeightedList +from redis.multidb.circuit import State as CBState +from redis.multidb.exception import NoValidDatabaseException +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy +from redis.asyncio.retry import Retry + + +class TestAsyncWeightBasedFailoverStrategy: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + ids=['all closed - highest weight', 'highest weight - open'], + indirect=True, + ) + async def test_get_valid_database(self, mock_db, mock_db1, mock_db2): + retry = Retry(NoBackoff(), 0) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + + strategy = WeightBasedFailoverStrategy(retry=retry) + strategy.set_databases(databases) + + assert await strategy.database == mock_db1 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2): + state_mock = PropertyMock( + side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.CLOSED] + ) + type(mock_db.circuit).state = state_mock + + retry = Retry(ExponentialBackoff(cap=1), 
3) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.set_databases(databases) + + assert await failover_strategy.database == mock_db + assert state_mock.call_count == 4 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_get_valid_database_throws_exception_with_retries(self, mock_db, mock_db1, mock_db2): + state_mock = PropertyMock( + side_effect=[CBState.OPEN, CBState.OPEN, CBState.OPEN, CBState.OPEN] + ) + type(mock_db.circuit).state = state_mock + + retry = Retry(ExponentialBackoff(cap=1), 3) + databases = WeightedList() + databases.add(mock_db, mock_db.weight) + databases.add(mock_db1, mock_db1.weight) + databases.add(mock_db2, mock_db2.weight) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + failover_strategy.set_databases(databases) + + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert await failover_strategy.database + + assert state_mock.call_count == 4 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock_db2): + retry = Retry(NoBackoff(), 0) + failover_strategy = WeightBasedFailoverStrategy(retry=retry) + + with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): + assert await failover_strategy.database \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_failure_detector.py b/tests/test_asyncio/test_multidb/test_failure_detector.py new file mode 100644 index 0000000000..3c1eb4fabd --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_failure_detector.py @@ -0,0 +1,153 @@ +import asyncio +from unittest.mock import Mock + +import pytest + +from redis.asyncio.multidb.command_executor import AsyncCommandExecutor +from redis.asyncio.multidb.failure_detector import FailureDetectorAsyncWrapper +from redis.multidb.circuit import State as CBState +from redis.multidb.failure_detector import CommandFailureDetector + + +class TestFailureDetectorAsyncWrapper: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_open_circuit_on_threshold_exceed_and_interval_not_exceed(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 1)) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert 
mock_db.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_do_not_open_circuit_if_threshold_not_exceed_and_interval_not_exceed(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 1)) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_do_not_open_circuit_on_threshold_exceed_and_interval_exceed(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 0.3)) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.1) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.1) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.1) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.1) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + + # 4 more failures as the last one already refreshed timer + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_refresh_timer_on_expired_duration(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 0.3)) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await asyncio.sleep(0.4) + + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + 
@pytest.mark.parametrize( + 'mock_db', + [ + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + ], + indirect=True, + ) + async def test_failure_detector_open_circuit_on_specific_exception_threshold_exceed(self, mock_db): + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(5, 1, error_types=[ConnectionError])) + mock_ce = Mock(spec=AsyncCommandExecutor) + mock_ce.active_database = mock_db + fd.set_command_executor(mock_ce) + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + await fd.register_failure(Exception(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.CLOSED + + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + await fd.register_failure(ConnectionError(), ('SET', 'key1', 'value1')) + + assert mock_db.circuit.state == CBState.OPEN \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_healthcheck.py b/tests/test_asyncio/test_multidb/test_healthcheck.py new file mode 100644 index 0000000000..fd5c8ec3f0 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_healthcheck.py @@ -0,0 +1,48 @@ +import pytest +from mock.mock import AsyncMock + +from redis.asyncio.multidb.database import Database +from redis.asyncio.multidb.healthcheck import EchoHealthCheck +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialBackoff +from redis.multidb.circuit import State as CBState +from redis.exceptions import ConnectionError + + +class TestEchoHealthCheck: + + @pytest.mark.asyncio + async def test_database_is_healthy_on_echo_response(self, mock_client, mock_cb): + """ + Mocking responses to mix error and actual responses to ensure that health check retry + according to given configuration. + """ + mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'healthcheck']) + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9) + + assert await hc.check_health(db) == True + assert mock_client.execute_command.call_count == 3 + + @pytest.mark.asyncio + async def test_database_is_unhealthy_on_incorrect_echo_response(self, mock_client, mock_cb): + """ + Mocking responses to mix error and actual responses to ensure that health check retry + according to given configuration. 
+ """ + mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'wrong']) + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9) + + assert await hc.check_health(db) == False + assert mock_client.execute_command.call_count == 3 + + @pytest.mark.asyncio + async def test_database_close_circuit_on_successful_healthcheck(self, mock_client, mock_cb): + mock_client.execute_command = AsyncMock(side_effect=[ConnectionError, ConnectionError, 'healthcheck']) + mock_cb.state = CBState.HALF_OPEN + hc = EchoHealthCheck(Retry(backoff=ExponentialBackoff(cap=1.0), retries=3)) + db = Database(mock_client, mock_cb, 0.9) + + assert await hc.check_health(db) == True + assert mock_client.execute_command.call_count == 3 \ No newline at end of file From ae42bea09a097855e05fbf62ca757a85df53af5a Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 2 Sep 2025 11:45:57 +0300 Subject: [PATCH 3/9] Added command executor --- redis/asyncio/multidb/command_executor.py | 184 +++++++++++++++++- redis/asyncio/multidb/failover.py | 2 - redis/event.py | 7 +- tests/test_asyncio/test_multidb/conftest.py | 29 ++- .../test_multidb/test_command_executor.py | 165 ++++++++++++++++ .../test_multidb/test_failover.py | 8 +- 6 files changed, 378 insertions(+), 17 deletions(-) create mode 100644 tests/test_asyncio/test_multidb/test_command_executor.py diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index 18117160ee..af10a00988 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -1,12 +1,18 @@ from abc import abstractmethod +from datetime import datetime from typing import List, Optional, Callable, Any from redis.asyncio.client import PubSub, Pipeline -from redis.asyncio.multidb.database import Databases, AsyncDatabase +from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database +from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged, RegisterCommandFailure, \ + ResubscribeOnActiveDatabaseChanged from redis.asyncio.multidb.failover import AsyncFailoverStrategy from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.multidb.circuit import State as CBState from redis.asyncio.retry import Retry -from redis.multidb.command_executor import CommandExecutor +from redis.event import EventDispatcherInterface, AsyncOnCommandsFailEvent +from redis.multidb.command_executor import CommandExecutor, BaseCommandExecutor +from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL class AsyncCommandExecutor(CommandExecutor): @@ -34,9 +40,8 @@ def active_database(self) -> Optional[AsyncDatabase]: """Returns currently active database.""" pass - @active_database.setter @abstractmethod - def active_database(self, database: AsyncDatabase) -> None: + async def set_active_database(self, database: AsyncDatabase) -> None: """Sets the currently active database.""" pass @@ -85,11 +90,176 @@ async def execute_transaction(self, transaction: Callable[[Pipeline], None], *wa pass @abstractmethod - def execute_pubsub_method(self, method_name: str, *args, **kwargs): + async def execute_pubsub_method(self, method_name: str, *args, **kwargs): """Executes a given method on active pub/sub.""" pass @abstractmethod - def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: """Executes pub/sub run in a thread.""" - pass \ No newline at 
end of file + pass + + +class DefaultCommandExecutor(BaseCommandExecutor, AsyncCommandExecutor): + def __init__( + self, + failure_detectors: List[AsyncFailureDetector], + databases: Databases, + command_retry: Retry, + failover_strategy: AsyncFailoverStrategy, + event_dispatcher: EventDispatcherInterface, + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL, + ): + """ + Initialize the DefaultCommandExecutor instance. + + Args: + failure_detectors: List of failure detector instances to monitor database health + databases: Collection of available databases to execute commands on + command_retry: Retry policy for failed command execution + failover_strategy: Strategy for handling database failover + event_dispatcher: Interface for dispatching events + auto_fallback_interval: Time interval in seconds between attempts to fall back to a primary database + """ + super().__init__(auto_fallback_interval) + + for fd in failure_detectors: + fd.set_command_executor(command_executor=self) + + self._databases = databases + self._failure_detectors = failure_detectors + self._command_retry = command_retry + self._failover_strategy = failover_strategy + self._event_dispatcher = event_dispatcher + self._active_database: Optional[Database] = None + self._active_pubsub: Optional[PubSub] = None + self._active_pubsub_kwargs = {} + self._setup_event_dispatcher() + self._schedule_next_fallback() + + @property + def databases(self) -> Databases: + return self._databases + + @property + def failure_detectors(self) -> List[AsyncFailureDetector]: + return self._failure_detectors + + def add_failure_detector(self, failure_detector: AsyncFailureDetector) -> None: + self._failure_detectors.append(failure_detector) + + @property + def active_database(self) -> Optional[AsyncDatabase]: + return self._active_database + + async def set_active_database(self, database: AsyncDatabase) -> None: + old_active = self._active_database + self._active_database = database + + if old_active is not None and old_active is not database: + await self._event_dispatcher.dispatch_async( + AsyncActiveDatabaseChanged(old_active, self._active_database, self, **self._active_pubsub_kwargs) + ) + + @property + def active_pubsub(self) -> Optional[PubSub]: + return self._active_pubsub + + @active_pubsub.setter + def active_pubsub(self, pubsub: PubSub) -> None: + self._active_pubsub = pubsub + + @property + def failover_strategy(self) -> AsyncFailoverStrategy: + return self._failover_strategy + + @property + def command_retry(self) -> Retry: + return self._command_retry + + async def pubsub(self, **kwargs): + async def callback(): + if self._active_pubsub is None: + self._active_pubsub = self._active_database.client.pubsub(**kwargs) + self._active_pubsub_kwargs = kwargs + return None + + return await self._execute_with_failure_detection(callback) + + async def execute_command(self, *args, **options): + async def callback(): + return await self._active_database.client.execute_command(*args, **options) + + return await self._execute_with_failure_detection(callback, args) + + async def execute_pipeline(self, command_stack: tuple): + async def callback(): + with self._active_database.client.pipeline() as pipe: + for command, options in command_stack: + await pipe.execute_command(*command, **options) + + return await pipe.execute() + + return await self._execute_with_failure_detection(callback, command_stack) + + async def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + async def callback(): + return 
await self._active_database.client.transaction(transaction, *watches, **options) + + return await self._execute_with_failure_detection(callback) + + async def execute_pubsub_method(self, method_name: str, *args, **kwargs): + async def callback(): + method = getattr(self.active_pubsub, method_name) + return await method(*args, **kwargs) + + return await self._execute_with_failure_detection(callback, *args) + + async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any: + async def callback(): + return await self._active_pubsub.run(poll_timeout=sleep_time, **kwargs) + + return await self._execute_with_failure_detection(callback) + + async def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()): + """ + Execute a commands execution callback with failure detection. + """ + async def wrapper(): + # On each retry we need to check active database as it might change. + await self._check_active_database() + return await callback() + + return await self._command_retry.call_with_retry( + lambda: wrapper(), + lambda error: self._on_command_fail(error, *cmds), + ) + + async def _check_active_database(self): + """ + Checks if active a database needs to be updated. + """ + if ( + self._active_database is None + or self._active_database.circuit.state != CBState.CLOSED + or ( + self._auto_fallback_interval != DEFAULT_AUTO_FALLBACK_INTERVAL + and self._next_fallback_attempt <= datetime.now() + ) + ): + await self.set_active_database(await self._failover_strategy.database()) + self._schedule_next_fallback() + + async def _on_command_fail(self, error, *args): + await self._event_dispatcher.dispatch_async(AsyncOnCommandsFailEvent(args, error)) + + def _setup_event_dispatcher(self): + """ + Registers necessary listeners. + """ + failure_listener = RegisterCommandFailure(self._failure_detectors) + resubscribe_listener = ResubscribeOnActiveDatabaseChanged() + self._event_dispatcher.register_listeners({ + AsyncOnCommandsFailEvent: [failure_listener], + AsyncActiveDatabaseChanged: [resubscribe_listener], + }) \ No newline at end of file diff --git a/redis/asyncio/multidb/failover.py b/redis/asyncio/multidb/failover.py index ad7f25ce41..a2ed427e05 100644 --- a/redis/asyncio/multidb/failover.py +++ b/redis/asyncio/multidb/failover.py @@ -10,7 +10,6 @@ class AsyncFailoverStrategy(ABC): - @property @abstractmethod async def database(self) -> AsyncDatabase: """Select the database according to the strategy.""" @@ -33,7 +32,6 @@ def __init__( self._retry.update_supported_errors([NoValidDatabaseException]) self._databases = WeightedList() - @property async def database(self) -> AsyncDatabase: return await self._retry.call_with_retry( lambda: self._get_active_database(), diff --git a/redis/event.py b/redis/event.py index 4d167442eb..8327ec5f76 100644 --- a/redis/event.py +++ b/redis/event.py @@ -43,7 +43,10 @@ async def dispatch_async(self, event: object): pass @abstractmethod - def register_listeners(self, mappings: Dict[Type[object], List[EventListenerInterface]]): + def register_listeners( + self, + mappings: Dict[Type[object], List[Union[EventListenerInterface, AsyncEventListenerInterface]]] + ): """Register additional listeners.""" pass @@ -99,7 +102,7 @@ def dispatch(self, event: object): listener.listen(event) async def dispatch_async(self, event: object): - with self._async_lock: + async with self._async_lock: listeners = self._event_listeners_mapping.get(type(event), []) for listener in listeners: diff --git a/tests/test_asyncio/test_multidb/conftest.py 
b/tests/test_asyncio/test_multidb/conftest.py index 1f67e3c63c..0c4e427264 100644 --- a/tests/test_asyncio/test_multidb/conftest.py +++ b/tests/test_asyncio/test_multidb/conftest.py @@ -2,10 +2,14 @@ import pytest +from redis.asyncio.multidb.failover import AsyncFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import HealthCheck +from redis.data_structure import WeightedList from redis.multidb.circuit import State as CBState from redis.asyncio import Redis from redis.asyncio.multidb.circuit import AsyncCircuitBreaker -from redis.asyncio.multidb.database import Database +from redis.asyncio.multidb.database import Database, Databases @pytest.fixture() @@ -16,6 +20,18 @@ def mock_client() -> Redis: def mock_cb() -> AsyncCircuitBreaker: return Mock(spec=AsyncCircuitBreaker) +@pytest.fixture() +def mock_fd() -> AsyncFailureDetector: + return Mock(spec=AsyncFailureDetector) + +@pytest.fixture() +def mock_fs() -> AsyncFailoverStrategy: + return Mock(spec=AsyncFailoverStrategy) + +@pytest.fixture() +def mock_hc() -> HealthCheck: + return Mock(spec=HealthCheck) + @pytest.fixture() def mock_db(request) -> Database: db = Mock(spec=Database) @@ -56,4 +72,13 @@ def mock_db2(request) -> Database: mock_cb.state = cb.get("state", CBState.CLOSED) db.circuit = mock_cb - return db \ No newline at end of file + return db + + +def create_weighted_list(*databases) -> Databases: + dbs = WeightedList() + + for db in databases: + dbs.add(db, db.weight) + + return dbs \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_command_executor.py b/tests/test_asyncio/test_multidb/test_command_executor.py new file mode 100644 index 0000000000..3f64e6aa0b --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_command_executor.py @@ -0,0 +1,165 @@ +import asyncio +from unittest.mock import AsyncMock + +import pytest + +from redis.asyncio.multidb.failure_detector import FailureDetectorAsyncWrapper +from redis.event import EventDispatcher +from redis.exceptions import ConnectionError +from redis.asyncio.multidb.command_executor import DefaultCommandExecutor +from redis.asyncio.retry import Retry +from redis.backoff import NoBackoff +from redis.multidb.circuit import State as CBState +from redis.multidb.failure_detector import CommandFailureDetector +from tests.test_asyncio.test_multidb.conftest import create_weighted_list + + +class TestDefaultCommandExecutor: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_on_active_database(self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + mock_db2.client.execute_command = AsyncMock(return_value='OK2') + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + command_retry=Retry(NoBackoff(), 0) + ) + + await executor.set_active_database(mock_db1) + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + + await executor.set_active_database(mock_db2) + assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + assert 
mock_ed.register_listeners.call_count == 1 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_automatically_select_active_database( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + mock_db2.client.execute_command = AsyncMock(return_value='OK2') + mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2]) + type(mock_fs).database = mock_selector + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + command_retry=Retry(NoBackoff(), 0) + ) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + mock_db1.circuit.state = CBState.OPEN + + assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + assert mock_ed.register_listeners.call_count == 1 + assert mock_selector.call_count == 2 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_fallback_to_another_db_after_fallback_interval( + self, mock_db, mock_db1, mock_db2, mock_fd, mock_fs, mock_ed + ): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + mock_db2.client.execute_command = AsyncMock(return_value='OK2') + mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2, mock_db1]) + type(mock_fs).database = mock_selector + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[mock_fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=mock_ed, + auto_fallback_interval=0.1, + command_retry=Retry(NoBackoff(), 0) + ) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + mock_db1.weight = 0.1 + await asyncio.sleep(0.15) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + mock_db1.weight = 0.7 + await asyncio.sleep(0.15) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + assert mock_ed.register_listeners.call_count == 1 + assert mock_selector.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_db,mock_db1,mock_db2', + [ + ( + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_fallback_to_another_db_after_failure_detection( + self, mock_db, mock_db1, mock_db2, mock_fs + ): + mock_db1.client.execute_command = AsyncMock(side_effect=['OK1', ConnectionError, ConnectionError, ConnectionError, 'OK1']) + mock_db2.client.execute_command = AsyncMock(side_effect=['OK2', ConnectionError, ConnectionError, ConnectionError]) + mock_selector = AsyncMock(side_effect=[mock_db1, mock_db2, mock_db1]) + type(mock_fs).database = mock_selector + threshold = 3 + fd = FailureDetectorAsyncWrapper(CommandFailureDetector(threshold, 1)) + ed = EventDispatcher() + databases = 
create_weighted_list(mock_db, mock_db1, mock_db2) + + executor = DefaultCommandExecutor( + failure_detectors=[fd], + databases=databases, + failover_strategy=mock_fs, + event_dispatcher=ed, + auto_fallback_interval=0.1, + command_retry=Retry(NoBackoff(), threshold), + ) + fd.set_command_executor(command_executor=executor) + + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + assert await executor.execute_command('SET', 'key', 'value') == 'OK2' + assert await executor.execute_command('SET', 'key', 'value') == 'OK1' + assert mock_selector.call_count == 3 \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_failover.py b/tests/test_asyncio/test_multidb/test_failover.py index d7bc4411b6..f692c40643 100644 --- a/tests/test_asyncio/test_multidb/test_failover.py +++ b/tests/test_asyncio/test_multidb/test_failover.py @@ -39,7 +39,7 @@ async def test_get_valid_database(self, mock_db, mock_db1, mock_db2): strategy = WeightBasedFailoverStrategy(retry=retry) strategy.set_databases(databases) - assert await strategy.database == mock_db1 + assert await strategy.database() == mock_db1 @pytest.mark.asyncio @pytest.mark.parametrize( @@ -67,7 +67,7 @@ async def test_get_valid_database_with_retries(self, mock_db, mock_db1, mock_db2 failover_strategy = WeightBasedFailoverStrategy(retry=retry) failover_strategy.set_databases(databases) - assert await failover_strategy.database == mock_db + assert await failover_strategy.database() == mock_db assert state_mock.call_count == 4 @pytest.mark.asyncio @@ -97,7 +97,7 @@ async def test_get_valid_database_throws_exception_with_retries(self, mock_db, m failover_strategy.set_databases(databases) with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): - assert await failover_strategy.database + assert await failover_strategy.database() assert state_mock.call_count == 4 @@ -118,4 +118,4 @@ async def test_throws_exception_on_empty_databases(self, mock_db, mock_db1, mock failover_strategy = WeightBasedFailoverStrategy(retry=retry) with pytest.raises(NoValidDatabaseException, match='No valid database available for communication'): - assert await failover_strategy.database \ No newline at end of file + assert await failover_strategy.database() \ No newline at end of file From 8fc74b96c20e2cfcd4160217afbf05dced520375 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Tue, 2 Sep 2025 12:53:16 +0300 Subject: [PATCH 4/9] Added recurring background tasks with event loop only --- redis/asyncio/multidb/config.py | 169 ++++++++++++++++++++++++++++++++ redis/background.py | 52 +++++++++- tests/test_background.py | 33 +++++++ 3 files changed, 251 insertions(+), 3 deletions(-) create mode 100644 redis/asyncio/multidb/config.py diff --git a/redis/asyncio/multidb/config.py b/redis/asyncio/multidb/config.py new file mode 100644 index 0000000000..9b3588aa06 --- /dev/null +++ b/redis/asyncio/multidb/config.py @@ -0,0 +1,169 @@ +from dataclasses import dataclass, field +from typing import Optional, List, Type, Union + +import pybreaker + +from redis.asyncio import ConnectionPool, Redis, RedisCluster +from redis.asyncio.multidb.circuit import AsyncCircuitBreaker, AsyncPBCircuitBreakerAdapter +from redis.asyncio.multidb.database import Databases, Database +from redis.asyncio.multidb.failover import AsyncFailoverStrategy, WeightBasedFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector, FailureDetectorAsyncWrapper +from redis.asyncio.multidb.healthcheck import 
HealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, DEFAULT_HEALTH_CHECK_BACKOFF, \ + EchoHealthCheck +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff +from redis.data_structure import WeightedList +from redis.event import EventDispatcherInterface, EventDispatcher +from redis.multidb.failure_detector import CommandFailureDetector + +DEFAULT_GRACE_PERIOD = 5.0 +DEFAULT_HEALTH_CHECK_INTERVAL = 5 +DEFAULT_FAILURES_THRESHOLD = 3 +DEFAULT_FAILURES_DURATION = 2 +DEFAULT_FAILOVER_RETRIES = 3 +DEFAULT_FAILOVER_BACKOFF = ExponentialWithJitterBackoff(cap=3) +DEFAULT_AUTO_FALLBACK_INTERVAL = -1 + +def default_event_dispatcher() -> EventDispatcherInterface: + return EventDispatcher() + +@dataclass +class DatabaseConfig: + """ + Dataclass representing the configuration for a database connection. + + This class is used to store configuration settings for a database connection, + including client options, connection sourcing details, circuit breaker settings, + and cluster-specific properties. It provides a structure for defining these + attributes and allows for the creation of customized configurations for various + database setups. + + Attributes: + weight (float): Weight of the database to define the active one. + client_kwargs (dict): Additional parameters for the database client connection. + from_url (Optional[str]): Redis URL way of connecting to the database. + from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. + circuit (Optional[SyncCircuitBreaker]): Custom circuit breaker implementation. + grace_period (float): Grace period after which we need to check if the circuit could be closed again. + health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used + on public Redis Enterprise endpoints. + + Methods: + default_circuit_breaker: + Generates and returns a default CircuitBreaker instance adapted for use. + """ + weight: float = 1.0 + client_kwargs: dict = field(default_factory=dict) + from_url: Optional[str] = None + from_pool: Optional[ConnectionPool] = None + circuit: Optional[AsyncCircuitBreaker] = None + grace_period: float = DEFAULT_GRACE_PERIOD + health_check_url: Optional[str] = None + + def default_circuit_breaker(self) -> AsyncCircuitBreaker: + circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) + return AsyncPBCircuitBreakerAdapter(circuit_breaker) + +@dataclass +class MultiDbConfig: + """ + Configuration class for managing multiple database connections in a resilient and fail-safe manner. + + Attributes: + databases_config: A list of database configurations. + client_class: The client class used to manage database connections. + command_retry: Retry strategy for executing database commands. + failure_detectors: Optional list of additional failure detectors for monitoring database failures. + failure_threshold: Threshold for determining database failure. + failures_interval: Time interval for tracking database failures. + health_checks: Optional list of additional health checks performed on databases. + health_check_interval: Time interval for executing health checks. + health_check_retries: Number of retry attempts for performing health checks. + health_check_backoff: Backoff strategy for health check retries. + failover_strategy: Optional strategy for handling database failover scenarios. + failover_retries: Number of retries allowed for failover operations. + failover_backoff: Backoff strategy for failover retries. 
+ auto_fallback_interval: Time interval to trigger automatic fallback. + event_dispatcher: Interface for dispatching events related to database operations. + + Methods: + databases: + Retrieves a collection of database clients managed by weighted configurations. + Initializes database clients based on the provided configuration and removes + redundant retry objects for lower-level clients to rely on global retry logic. + + default_failure_detectors: + Returns the default list of failure detectors used to monitor database failures. + + default_health_checks: + Returns the default list of health checks used to monitor database health + with specific retry and backoff strategies. + + default_failover_strategy: + Provides the default failover strategy used for handling failover scenarios + with defined retry and backoff configurations. + """ + databases_config: List[DatabaseConfig] + client_class: Type[Union[Redis, RedisCluster]] = Redis + command_retry: Retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ) + failure_detectors: Optional[List[AsyncFailureDetector]] = None + failure_threshold: int = DEFAULT_FAILURES_THRESHOLD + failures_interval: float = DEFAULT_FAILURES_DURATION + health_checks: Optional[List[HealthCheck]] = None + health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL + health_check_retries: int = DEFAULT_HEALTH_CHECK_RETRIES + health_check_backoff: AbstractBackoff = DEFAULT_HEALTH_CHECK_BACKOFF + failover_strategy: Optional[AsyncFailoverStrategy] = None + failover_retries: int = DEFAULT_FAILOVER_RETRIES + failover_backoff: AbstractBackoff = DEFAULT_FAILOVER_BACKOFF + auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL + event_dispatcher: EventDispatcherInterface = field(default_factory=default_event_dispatcher) + + def databases(self) -> Databases: + databases = WeightedList() + + for database_config in self.databases_config: + # The retry object is not used in the lower level clients, so we can safely remove it. + # We rely on command_retry in terms of global retries. 
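+            # Each underlying client is therefore constructed with
+            # Retry(retries=0, backoff=NoBackoff()); failed commands are never
+            # retried by the individual client, only by the executor-level command_retry.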
+ database_config.client_kwargs.update({"retry": Retry(retries=0, backoff=NoBackoff())}) + + if database_config.from_url: + client = self.client_class.from_url(database_config.from_url, **database_config.client_kwargs) + elif database_config.from_pool: + database_config.from_pool.set_retry(Retry(retries=0, backoff=NoBackoff())) + client = self.client_class.from_pool(connection_pool=database_config.from_pool) + else: + client = self.client_class(**database_config.client_kwargs) + + circuit = database_config.default_circuit_breaker() \ + if database_config.circuit is None else database_config.circuit + databases.add( + Database( + client=client, + circuit=circuit, + weight=database_config.weight, + health_check_url=database_config.health_check_url + ), + database_config.weight + ) + + return databases + + def default_failure_detectors(self) -> List[AsyncFailureDetector]: + return [ + FailureDetectorAsyncWrapper( + CommandFailureDetector(threshold=self.failure_threshold, duration=self.failures_interval) + ), + ] + + def default_health_checks(self) -> List[HealthCheck]: + return [ + EchoHealthCheck(retry=Retry(retries=self.health_check_retries, backoff=self.health_check_backoff)), + ] + + def default_failover_strategy(self) -> AsyncFailoverStrategy: + return WeightBasedFailoverStrategy( + retry=Retry(retries=self.failover_retries, backoff=self.failover_backoff), + ) \ No newline at end of file diff --git a/redis/background.py b/redis/background.py index 6466649859..ce43cbfa7a 100644 --- a/redis/background.py +++ b/redis/background.py @@ -1,6 +1,7 @@ import asyncio import threading -from typing import Callable +from typing import Callable, Coroutine, Any + class BackgroundScheduler: """ @@ -45,7 +46,35 @@ def run_recurring( ) thread.start() - def _call_later(self, loop: asyncio.AbstractEventLoop, delay: float, callback: Callable, *args): + async def run_recurring_async( + self, + interval: float, + coro: Callable[..., Coroutine[Any, Any, Any]], + *args + ): + """ + Runs recurring coroutine with given interval in seconds in the current event loop. + To be used only from an async context. No additional threads are created. + """ + loop = asyncio.get_running_loop() + wrapped = _async_to_sync_wrapper(loop, coro, *args) + + def tick(): + # Schedule the coroutine + wrapped() + # Schedule next tick + self._next_timer = loop.call_later(interval, tick) + + # Schedule first tick + self._next_timer = loop.call_later(interval, tick) + + def _call_later( + self, + loop: asyncio.AbstractEventLoop, + delay: float, + callback: Callable, + *args + ): self._next_timer = loop.call_later(delay, callback, *args) def _call_later_recurring( @@ -86,4 +115,21 @@ def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop, call_soon """ asyncio.set_event_loop(event_loop) event_loop.call_soon(call_soon_cb, event_loop, *args) - event_loop.run_forever() \ No newline at end of file + event_loop.run_forever() + +def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs): + """ + Wraps an asynchronous function so it can be used with loop.call_later. + + :param loop: The event loop in which the coroutine will be executed. + :param coro_func: The coroutine function to wrap. + :param args: Positional arguments to pass to the coroutine function. + :param kwargs: Keyword arguments to pass to the coroutine function. + :return: A regular function suitable for loop.call_later. 
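+
+    A minimal usage sketch (my_coro stands in for any coroutine function):
+
+        wrapped = _async_to_sync_wrapper(loop, my_coro, "arg", 42)
+        loop.call_later(1.0, wrapped)  # schedules my_coro("arg", 42) roughly 1s later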
+ """ + + def wrapped(): + # Schedule the coroutine in the event loop + asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop) + + return wrapped \ No newline at end of file diff --git a/tests/test_background.py b/tests/test_background.py index 4b3a5377c1..ba62e5bdd9 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -1,3 +1,4 @@ +import asyncio from time import sleep import pytest @@ -57,4 +58,36 @@ def callback(arg1: str, arg2: int): sleep(timeout) + assert execute_counter == call_count + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "interval,timeout,call_count", + [ + (0.012, 0.04, 3), + (0.035, 0.04, 1), + (0.045, 0.04, 0), + ] + ) + async def test_run_recurring_async(self, interval, timeout, call_count): + execute_counter = 0 + one = 'arg1' + two = 9999 + + async def callback(arg1: str, arg2: int): + nonlocal execute_counter + nonlocal one + nonlocal two + + execute_counter += 1 + + assert arg1 == one + assert arg2 == two + + scheduler = BackgroundScheduler() + await scheduler.run_recurring_async(interval, callback, one, two) + assert execute_counter == 0 + + await asyncio.sleep(timeout) + assert execute_counter == call_count \ No newline at end of file From 97c3cde72dad63a46e00f6e6425e690206a2daaa Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Wed, 3 Sep 2025 12:21:00 +0300 Subject: [PATCH 5/9] Added MultiDBClient --- redis/asyncio/multidb/circuit.py | 26 - redis/asyncio/multidb/client.py | 237 +++++++++ redis/asyncio/multidb/command_executor.py | 1 + redis/asyncio/multidb/config.py | 10 +- redis/asyncio/multidb/database.py | 12 +- redis/multidb/circuit.py | 13 +- redis/multidb/client.py | 6 +- redis/multidb/config.py | 8 +- redis/multidb/database.py | 12 +- tests/test_asyncio/test_multidb/conftest.py | 38 +- .../test_asyncio/test_multidb/test_circuit.py | 58 --- .../test_asyncio/test_multidb/test_client.py | 471 ++++++++++++++++++ tests/test_multidb/conftest.py | 12 +- tests/test_multidb/test_circuit.py | 4 +- tests/test_multidb/test_client.py | 4 - tests/test_multidb/test_config.py | 10 +- 16 files changed, 784 insertions(+), 138 deletions(-) delete mode 100644 redis/asyncio/multidb/circuit.py create mode 100644 redis/asyncio/multidb/client.py delete mode 100644 tests/test_asyncio/test_multidb/test_circuit.py create mode 100644 tests/test_asyncio/test_multidb/test_client.py diff --git a/redis/asyncio/multidb/circuit.py b/redis/asyncio/multidb/circuit.py deleted file mode 100644 index 97411e6e42..0000000000 --- a/redis/asyncio/multidb/circuit.py +++ /dev/null @@ -1,26 +0,0 @@ -from abc import abstractmethod -from typing import Callable - -import pybreaker - -from redis.multidb.circuit import CircuitBreaker, State, BaseCircuitBreaker, PBCircuitBreakerAdapter - - -class AsyncCircuitBreaker(CircuitBreaker): - """Async implementation of Circuit Breaker interface.""" - - @abstractmethod - async def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): - """Callback called when the state of the circuit changes.""" - pass - -class AsyncPBCircuitBreakerAdapter(BaseCircuitBreaker, AsyncCircuitBreaker): - """ - Async adapter for pybreaker's CircuitBreaker implementation. 
- """ - def __init__(self, cb: pybreaker.CircuitBreaker): - super().__init__(cb) - self._sync_cb = PBCircuitBreakerAdapter(cb) - - async def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): - self._sync_cb.on_state_changed(cb) \ No newline at end of file diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py new file mode 100644 index 0000000000..dbf03a3ef4 --- /dev/null +++ b/redis/asyncio/multidb/client.py @@ -0,0 +1,237 @@ +import asyncio +from typing import Callable, Optional, Coroutine, Any + +from redis.asyncio.multidb.command_executor import DefaultCommandExecutor +from redis.asyncio.multidb.database import AsyncDatabase, Databases +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import HealthCheck +from redis.multidb.circuit import State as CBState, CircuitBreaker +from redis.asyncio.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD +from redis.background import BackgroundScheduler +from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands +from redis.multidb.exception import NoValidDatabaseException + + +class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands): + """ + Client that operates on multiple logical Redis databases. + Should be used in Active-Active database setups. + """ + def __init__(self, config: MultiDbConfig): + self._databases = config.databases() + self._health_checks = config.default_health_checks() + + if config.health_checks is not None: + self._health_checks.extend(config.health_checks) + + self._health_check_interval = config.health_check_interval + self._failure_detectors = config.default_failure_detectors() + + if config.failure_detectors is not None: + self._failure_detectors.extend(config.failure_detectors) + + self._failover_strategy = config.default_failover_strategy() \ + if config.failover_strategy is None else config.failover_strategy + self._failover_strategy.set_databases(self._databases) + self._auto_fallback_interval = config.auto_fallback_interval + self._event_dispatcher = config.event_dispatcher + self._command_retry = config.command_retry + self._command_retry.update_supported_errors([ConnectionRefusedError]) + self.command_executor = DefaultCommandExecutor( + failure_detectors=self._failure_detectors, + databases=self._databases, + command_retry=self._command_retry, + failover_strategy=self._failover_strategy, + event_dispatcher=self._event_dispatcher, + auto_fallback_interval=self._auto_fallback_interval, + ) + self.initialized = False + self._hc_lock = asyncio.Lock() + self._bg_scheduler = BackgroundScheduler() + self._config = config + + async def initialize(self): + """ + Perform initialization of databases to define their initial state. + """ + async def raise_exception_on_failed_hc(error): + raise error + + # Initial databases check to define initial state + await self._check_databases_health(on_error=raise_exception_on_failed_hc) + + # Starts recurring health checks on the background. + await self._bg_scheduler.run_recurring_async( + self._health_check_interval, + self._check_databases_health, + ) + + is_active_db_found = False + + for database, weight in self._databases: + # Set on state changed callback for each circuit. 
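+            # A transition to HALF_OPEN triggers an immediate health check of that
+            # database, while a CLOSED -> OPEN transition schedules the circuit to be
+            # half-opened again after the default grace period
+            # (see _on_circuit_state_change_callback below).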
+ database.circuit.on_state_changed(self._on_circuit_state_change_callback) + + # Set states according to a weights and circuit state + if database.circuit.state == CBState.CLOSED and not is_active_db_found: + await self.command_executor.set_active_database(database) + is_active_db_found = True + + if not is_active_db_found: + raise NoValidDatabaseException('Initial connection failed - no active database found') + + self.initialized = True + + def get_databases(self) -> Databases: + """ + Returns a sorted (by weight) list of all databases. + """ + return self._databases + + async def set_active_database(self, database: AsyncDatabase) -> None: + """ + Promote one of the existing databases to become an active. + """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + break + + if not exists: + raise ValueError('Given database is not a member of database list') + + await self._check_db_health(database) + + if database.circuit.state == CBState.CLOSED: + highest_weighted_db, _ = self._databases.get_top_n(1)[0] + await self.command_executor.set_active_database(database) + return + + raise NoValidDatabaseException('Cannot set active database, database is unhealthy') + + async def add_database(self, database: AsyncDatabase): + """ + Adds a new database to the database list. + """ + for existing_db, _ in self._databases: + if existing_db == database: + raise ValueError('Given database already exists') + + await self._check_db_health(database) + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.add(database, database.weight) + await self._change_active_database(database, highest_weighted_db) + + async def _change_active_database(self, new_database: AsyncDatabase, highest_weight_database: AsyncDatabase): + if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED: + await self.command_executor.set_active_database(new_database) + + async def remove_database(self, database: AsyncDatabase): + """ + Removes a database from the database list. + """ + weight = self._databases.remove(database) + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + + if highest_weight <= weight and highest_weighted_db.circuit.state == CBState.CLOSED: + await self.command_executor.set_active_database(highest_weighted_db) + + async def update_database_weight(self, database: AsyncDatabase, weight: float): + """ + Updates a database from the database list. + """ + exists = None + + for existing_db, _ in self._databases: + if existing_db == database: + exists = True + break + + if not exists: + raise ValueError('Given database is not a member of database list') + + highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0] + self._databases.update_weight(database, weight) + database.weight = weight + await self._change_active_database(database, highest_weighted_db) + + def add_failure_detector(self, failure_detector: AsyncFailureDetector): + """ + Adds a new failure detector to the database. + """ + self._failure_detectors.append(failure_detector) + + async def add_health_check(self, healthcheck: HealthCheck): + """ + Adds a new health check to the database. + """ + async with self._hc_lock: + self._health_checks.append(healthcheck) + + async def execute_command(self, *args, **options): + """ + Executes a single command and return its result. 
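+
+        The client is initialized lazily on first use, so an explicit initialize()
+        call is optional. A usage sketch (hostnames are placeholders):
+
+            config = MultiDbConfig(databases_config=[
+                DatabaseConfig(client_kwargs={"host": "db-east.example.com"}, weight=1.0),
+                DatabaseConfig(client_kwargs={"host": "db-west.example.com"}, weight=0.5),
+            ])
+            client = MultiDBClient(config)
+            await client.execute_command("SET", "key", "value")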
+ """ + if not self.initialized: + await self.initialize() + + return await self.command_executor.execute_command(*args, **options) + + async def _check_databases_health( + self, + on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, + ): + """ + Runs health checks as a recurring task. + Runs health checks against all databases. + """ + for database, _ in self._databases: + async with self._hc_lock: + await self._check_db_health(database, on_error) + + async def _check_db_health( + self, + database: AsyncDatabase, + on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, + ) -> None: + """ + Runs health checks on the given database until first failure. + """ + is_healthy = True + + # Health check will setup circuit state + for health_check in self._health_checks: + if not is_healthy: + # If one of the health checks failed, it's considered unhealthy + break + + try: + is_healthy = await health_check.check_health(database) + + if not is_healthy and database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + elif is_healthy and database.circuit.state != CBState.CLOSED: + database.circuit.state = CBState.CLOSED + except Exception as e: + if database.circuit.state != CBState.OPEN: + database.circuit.state = CBState.OPEN + is_healthy = False + + if on_error: + await on_error(e) + + def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): + loop = asyncio.get_running_loop() + + if new_state == CBState.HALF_OPEN: + asyncio.create_task(self._check_db_health(circuit.database)) + return + + if old_state == CBState.CLOSED and new_state == CBState.OPEN: + loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) + +def _half_open_circuit(circuit: CircuitBreaker): + circuit.state = CBState.HALF_OPEN \ No newline at end of file diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index af10a00988..22aef83118 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -248,6 +248,7 @@ async def _check_active_database(self): ) ): await self.set_active_database(await self._failover_strategy.database()) + print("Active database now with weight {}", format(self._active_database.weight)) self._schedule_next_fallback() async def _on_command_fail(self, error, *args): diff --git a/redis/asyncio/multidb/config.py b/redis/asyncio/multidb/config.py index 9b3588aa06..b5f4a0658d 100644 --- a/redis/asyncio/multidb/config.py +++ b/redis/asyncio/multidb/config.py @@ -4,7 +4,6 @@ import pybreaker from redis.asyncio import ConnectionPool, Redis, RedisCluster -from redis.asyncio.multidb.circuit import AsyncCircuitBreaker, AsyncPBCircuitBreakerAdapter from redis.asyncio.multidb.database import Databases, Database from redis.asyncio.multidb.failover import AsyncFailoverStrategy, WeightBasedFailoverStrategy from redis.asyncio.multidb.failure_detector import AsyncFailureDetector, FailureDetectorAsyncWrapper @@ -14,6 +13,7 @@ from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcherInterface, EventDispatcher +from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter from redis.multidb.failure_detector import CommandFailureDetector DEFAULT_GRACE_PERIOD = 5.0 @@ -43,7 +43,7 @@ class DatabaseConfig: client_kwargs (dict): Additional parameters for the database client connection. 
from_url (Optional[str]): Redis URL way of connecting to the database. from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. - circuit (Optional[SyncCircuitBreaker]): Custom circuit breaker implementation. + circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation. grace_period (float): Grace period after which we need to check if the circuit could be closed again. health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used on public Redis Enterprise endpoints. @@ -56,13 +56,13 @@ class DatabaseConfig: client_kwargs: dict = field(default_factory=dict) from_url: Optional[str] = None from_pool: Optional[ConnectionPool] = None - circuit: Optional[AsyncCircuitBreaker] = None + circuit: Optional[CircuitBreaker] = None grace_period: float = DEFAULT_GRACE_PERIOD health_check_url: Optional[str] = None - def default_circuit_breaker(self) -> AsyncCircuitBreaker: + def default_circuit_breaker(self) -> CircuitBreaker: circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) - return AsyncPBCircuitBreakerAdapter(circuit_breaker) + return PBCircuitBreakerAdapter(circuit_breaker) @dataclass class MultiDbConfig: diff --git a/redis/asyncio/multidb/database.py b/redis/asyncio/multidb/database.py index 85320f3aaa..6afbbbf5ea 100644 --- a/redis/asyncio/multidb/database.py +++ b/redis/asyncio/multidb/database.py @@ -2,8 +2,8 @@ from typing import Union, Optional from redis.asyncio import Redis, RedisCluster -from redis.asyncio.multidb.circuit import AsyncCircuitBreaker from redis.data_structure import WeightedList +from redis.multidb.circuit import CircuitBreaker from redis.multidb.database import AbstractDatabase, BaseDatabase from redis.typing import Number @@ -24,13 +24,13 @@ def client(self, client: Union[Redis, RedisCluster]): @property @abstractmethod - def circuit(self) -> AsyncCircuitBreaker: + def circuit(self) -> CircuitBreaker: """Circuit breaker for the current database.""" pass @circuit.setter @abstractmethod - def circuit(self, circuit: AsyncCircuitBreaker): + def circuit(self, circuit: CircuitBreaker): """Set the circuit breaker for the current database.""" pass @@ -40,7 +40,7 @@ class Database(BaseDatabase, AsyncDatabase): def __init__( self, client: Union[Redis, RedisCluster], - circuit: AsyncCircuitBreaker, + circuit: CircuitBreaker, weight: float, health_check_url: Optional[str] = None, ): @@ -58,10 +58,10 @@ def client(self, client: Union[Redis, RedisCluster]): self._client = client @property - def circuit(self) -> AsyncCircuitBreaker: + def circuit(self) -> CircuitBreaker: return self._cb @circuit.setter - def circuit(self, circuit: AsyncCircuitBreaker): + def circuit(self, circuit: CircuitBreaker): self._cb = circuit diff --git a/redis/multidb/circuit.py b/redis/multidb/circuit.py index 576ee27fab..8f904c0e4b 100644 --- a/redis/multidb/circuit.py +++ b/redis/multidb/circuit.py @@ -45,6 +45,11 @@ def database(self, database): """Set database associated with this circuit.""" pass + @abstractmethod + def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): + """Callback called when the state of the circuit changes.""" + pass + class BaseCircuitBreaker(CircuitBreaker): """ Base implementation of Circuit Breaker interface. @@ -82,10 +87,6 @@ def database(self): def database(self, database): self._database = database -class SyncCircuitBreaker(CircuitBreaker): - """ - Synchronous implementation of Circuit Breaker interface. 
- """ @abstractmethod def on_state_changed(self, cb: Callable[["CircuitBreaker", State, State], None]): """Callback called when the state of the circuit changes.""" @@ -95,7 +96,7 @@ class PBListener(pybreaker.CircuitBreakerListener): """Wrapper for callback to be compatible with pybreaker implementation.""" def __init__( self, - cb: Callable[[SyncCircuitBreaker, State, State], None], + cb: Callable[[CircuitBreaker, State, State], None], database, ): """ @@ -116,7 +117,7 @@ def state_change(self, cb, old_state, new_state): new_state = State(value=new_state.name) self._cb(cb, old_state, new_state) -class PBCircuitBreakerAdapter(SyncCircuitBreaker, BaseCircuitBreaker): +class PBCircuitBreakerAdapter(BaseCircuitBreaker): def __init__(self, cb: pybreaker.CircuitBreaker): """ Initialize a PBCircuitBreakerAdapter instance. diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 8a0e006977..71e079346a 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -5,7 +5,7 @@ from redis.commands import RedisModuleCommands, CoreCommands from redis.multidb.command_executor import DefaultCommandExecutor from redis.multidb.config import MultiDbConfig, DEFAULT_GRACE_PERIOD -from redis.multidb.circuit import State as CBState, SyncCircuitBreaker +from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.multidb.database import Database, Databases, SyncDatabase from redis.multidb.exception import NoValidDatabaseException from redis.multidb.failure_detector import FailureDetector @@ -244,7 +244,7 @@ def _check_databases_health(self, on_error: Callable[[Exception], None] = None): for database, _ in self._databases: self._check_db_health(database, on_error) - def _on_circuit_state_change_callback(self, circuit: SyncCircuitBreaker, old_state: CBState, new_state: CBState): + def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: CBState, new_state: CBState): if new_state == CBState.HALF_OPEN: self._check_db_health(circuit.database) return @@ -252,7 +252,7 @@ def _on_circuit_state_change_callback(self, circuit: SyncCircuitBreaker, old_sta if old_state == CBState.CLOSED and new_state == CBState.OPEN: self._bg_scheduler.run_once(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) -def _half_open_circuit(circuit: SyncCircuitBreaker): +def _half_open_circuit(circuit: CircuitBreaker): circuit.state = CBState.HALF_OPEN diff --git a/redis/multidb/config.py b/redis/multidb/config.py index a966ec329a..fc349ed04b 100644 --- a/redis/multidb/config.py +++ b/redis/multidb/config.py @@ -9,7 +9,7 @@ from redis.backoff import ExponentialWithJitterBackoff, AbstractBackoff, NoBackoff from redis.data_structure import WeightedList from redis.event import EventDispatcher, EventDispatcherInterface -from redis.multidb.circuit import PBCircuitBreakerAdapter, SyncCircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker from redis.multidb.database import Database, Databases from redis.multidb.failure_detector import FailureDetector, CommandFailureDetector from redis.multidb.healthcheck import HealthCheck, EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ @@ -44,7 +44,7 @@ class DatabaseConfig: client_kwargs (dict): Additional parameters for the database client connection. from_url (Optional[str]): Redis URL way of connecting to the database. from_pool (Optional[ConnectionPool]): A pre-configured connection pool to use. - circuit (Optional[SyncCircuitBreaker]): Custom circuit breaker implementation. 
+ circuit (Optional[CircuitBreaker]): Custom circuit breaker implementation. grace_period (float): Grace period after which we need to check if the circuit could be closed again. health_check_url (Optional[str]): URL for health checks. Cluster FQDN is typically used on public Redis Enterprise endpoints. @@ -57,11 +57,11 @@ class DatabaseConfig: client_kwargs: dict = field(default_factory=dict) from_url: Optional[str] = None from_pool: Optional[ConnectionPool] = None - circuit: Optional[SyncCircuitBreaker] = None + circuit: Optional[CircuitBreaker] = None grace_period: float = DEFAULT_GRACE_PERIOD health_check_url: Optional[str] = None - def default_circuit_breaker(self) -> SyncCircuitBreaker: + def default_circuit_breaker(self) -> CircuitBreaker: circuit_breaker = pybreaker.CircuitBreaker(reset_timeout=self.grace_period) return PBCircuitBreakerAdapter(circuit_breaker) diff --git a/redis/multidb/database.py b/redis/multidb/database.py index 75a662d904..9c2ffe3552 100644 --- a/redis/multidb/database.py +++ b/redis/multidb/database.py @@ -5,7 +5,7 @@ from redis import RedisCluster from redis.data_structure import WeightedList -from redis.multidb.circuit import SyncCircuitBreaker +from redis.multidb.circuit import CircuitBreaker from redis.typing import Number class AbstractDatabase(ABC): @@ -74,13 +74,13 @@ def client(self, client: Union[redis.Redis, RedisCluster]): @property @abstractmethod - def circuit(self) -> SyncCircuitBreaker: + def circuit(self) -> CircuitBreaker: """Circuit breaker for the current database.""" pass @circuit.setter @abstractmethod - def circuit(self, circuit: SyncCircuitBreaker): + def circuit(self, circuit: CircuitBreaker): """Set the circuit breaker for the current database.""" pass @@ -90,7 +90,7 @@ class Database(BaseDatabase, SyncDatabase): def __init__( self, client: Union[redis.Redis, RedisCluster], - circuit: SyncCircuitBreaker, + circuit: CircuitBreaker, weight: float, health_check_url: Optional[str] = None, ): @@ -117,9 +117,9 @@ def client(self, client: Union[redis.Redis, RedisCluster]): self._client = client @property - def circuit(self) -> SyncCircuitBreaker: + def circuit(self) -> CircuitBreaker: return self._cb @circuit.setter - def circuit(self, circuit: SyncCircuitBreaker): + def circuit(self, circuit: CircuitBreaker): self._cb = circuit \ No newline at end of file diff --git a/tests/test_asyncio/test_multidb/conftest.py b/tests/test_asyncio/test_multidb/conftest.py index 0c4e427264..0ac231cf52 100644 --- a/tests/test_asyncio/test_multidb/conftest.py +++ b/tests/test_asyncio/test_multidb/conftest.py @@ -2,13 +2,14 @@ import pytest +from redis.asyncio.multidb.config import MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_AUTO_FALLBACK_INTERVAL, \ + DatabaseConfig from redis.asyncio.multidb.failover import AsyncFailoverStrategy from redis.asyncio.multidb.failure_detector import AsyncFailureDetector from redis.asyncio.multidb.healthcheck import HealthCheck from redis.data_structure import WeightedList -from redis.multidb.circuit import State as CBState +from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.asyncio import Redis -from redis.asyncio.multidb.circuit import AsyncCircuitBreaker from redis.asyncio.multidb.database import Database, Databases @@ -17,8 +18,8 @@ def mock_client() -> Redis: return Mock(spec=Redis) @pytest.fixture() -def mock_cb() -> AsyncCircuitBreaker: - return Mock(spec=AsyncCircuitBreaker) +def mock_cb() -> CircuitBreaker: + return Mock(spec=CircuitBreaker) @pytest.fixture() def mock_fd() -> 
AsyncFailureDetector: @@ -39,7 +40,7 @@ def mock_db(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=AsyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -53,7 +54,7 @@ def mock_db1(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=AsyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -67,13 +68,36 @@ def mock_db2(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=AsyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) db.circuit = mock_cb return db +@pytest.fixture() +def mock_multi_db_config( + request, mock_fd, mock_fs, mock_hc, mock_ed +) -> MultiDbConfig: + hc_interval = request.param.get('hc_interval', None) + if hc_interval is None: + hc_interval = DEFAULT_HEALTH_CHECK_INTERVAL + + auto_fallback_interval = request.param.get('auto_fallback_interval', None) + if auto_fallback_interval is None: + auto_fallback_interval = DEFAULT_AUTO_FALLBACK_INTERVAL + + config = MultiDbConfig( + databases_config=[Mock(spec=DatabaseConfig)], + failure_detectors=[mock_fd], + health_check_interval=hc_interval, + failover_strategy=mock_fs, + auto_fallback_interval=auto_fallback_interval, + event_dispatcher=mock_ed + ) + + return config + def create_weighted_list(*databases) -> Databases: dbs = WeightedList() diff --git a/tests/test_asyncio/test_multidb/test_circuit.py b/tests/test_asyncio/test_multidb/test_circuit.py deleted file mode 100644 index b1080cfc7d..0000000000 --- a/tests/test_asyncio/test_multidb/test_circuit.py +++ /dev/null @@ -1,58 +0,0 @@ -import pybreaker -import pytest - -from redis.asyncio.multidb.circuit import ( - AsyncPBCircuitBreakerAdapter, - State as CbState, -) -from redis.multidb.circuit import CircuitBreaker, PBCircuitBreakerAdapter - - -class TestAsyncPBCircuitBreaker: - @pytest.mark.asyncio - @pytest.mark.parametrize( - 'mock_db', - [ - {'weight': 0.7, 'circuit': {'state': CbState.CLOSED}}, - ], - indirect=True, - ) - async def test_cb_correctly_configured(self, mock_db): - pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) - adapter = AsyncPBCircuitBreakerAdapter(cb=pb_circuit) - assert adapter.state == CbState.CLOSED - - adapter.state = CbState.OPEN - assert adapter.state == CbState.OPEN - - adapter.state = CbState.HALF_OPEN - assert adapter.state == CbState.HALF_OPEN - - adapter.state = CbState.CLOSED - assert adapter.state == CbState.CLOSED - - assert adapter.grace_period == 5 - adapter.grace_period = 10 - - assert adapter.grace_period == 10 - - adapter.database = mock_db - assert adapter.database == mock_db - - @pytest.mark.asyncio - async def test_cb_executes_callback_on_state_changed(self): - pb_circuit = pybreaker.CircuitBreaker(reset_timeout=5) - adapter = AsyncPBCircuitBreakerAdapter(cb=pb_circuit) - called_count = 0 - - def callback(cb: CircuitBreaker, old_state: CbState, new_state: CbState): - nonlocal called_count - assert old_state == CbState.CLOSED - assert new_state == CbState.HALF_OPEN - assert isinstance(cb, PBCircuitBreakerAdapter) - called_count += 1 - - await adapter.on_state_changed(callback) - adapter.state = CbState.HALF_OPEN - - assert called_count == 1 \ 
No newline at end of file diff --git a/tests/test_asyncio/test_multidb/test_client.py b/tests/test_asyncio/test_multidb/test_client.py new file mode 100644 index 0000000000..c2fe914e9f --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_client.py @@ -0,0 +1,471 @@ +import asyncio +from unittest.mock import patch, AsyncMock, Mock + +import pybreaker +import pytest + +from redis.asyncio.multidb.client import MultiDBClient +from redis.asyncio.multidb.config import DEFAULT_FAILOVER_RETRIES, DEFAULT_FAILOVER_BACKOFF +from redis.asyncio.multidb.database import AsyncDatabase +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy +from redis.asyncio.multidb.failure_detector import AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ + DEFAULT_HEALTH_CHECK_BACKOFF, HealthCheck +from redis.asyncio.retry import Retry +from redis.event import EventDispatcher, AsyncOnCommandsFailEvent +from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter +from redis.multidb.exception import NoValidDatabaseException +from tests.test_asyncio.test_multidb.conftest import create_weighted_list + + +class TestMultiDbClient: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_against_correct_db_on_successful_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.CLOSED + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_command_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command = AsyncMock(return_value='OK1') + + mock_hc.check_health.side_effect = [False, True, True] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED 
+ assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'OK', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'error', 'healthcheck', 'OK1'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'error', 'error'] + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + assert await client.set('key', 'value') == 'OK1' + await asyncio.sleep(0.15) + assert await client.set('key', 'value') == 'OK2' + await asyncio.sleep(0.1) + assert await client.set('key', 'value') == 'OK' + await asyncio.sleep(0.1) + assert await client.set('key', 'value') == 'OK1' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_command_auto_fallback_to_highest_weight_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'healthcheck', 'healthcheck'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'OK1', 'error', 'healthcheck', 'healthcheck', 'OK1'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'OK2', 'healthcheck', 'healthcheck', 'healthcheck'] + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.auto_fallback_interval = 0.2 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + 
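The behaviour these client tests exercise reduces to one rule: route commands to the highest-weight database whose circuit is closed, and fall over to the next weight when that circuit opens. A rough sketch of that selection rule, with stand-in types rather than the patch's WeightedList and WeightBasedFailoverStrategy:

from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

class DemoCBState(Enum):
    CLOSED = "closed"
    OPEN = "open"

@dataclass
class DemoDatabase:
    name: str
    weight: float
    circuit_state: DemoCBState

def pick_active(databases: List[DemoDatabase]) -> Optional[DemoDatabase]:
    # Highest weight wins; databases with an open circuit are skipped.
    for db in sorted(databases, key=lambda d: d.weight, reverse=True):
        if db.circuit_state is DemoCBState.CLOSED:
            return db
    return None

dbs = [
    DemoDatabase("db", 0.2, DemoCBState.CLOSED),
    DemoDatabase("db1", 0.7, DemoCBState.CLOSED),
    DemoDatabase("db2", 0.5, DemoCBState.CLOSED),
]
assert pick_active(dbs).name == "db1"    # highest weight is active

dbs[1].circuit_state = DemoCBState.OPEN  # db1 becomes unhealthy
assert pick_active(dbs).name == "db2"    # next-highest healthy weight takes over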
client = MultiDBClient(mock_multi_db_config) + assert await client.set('key', 'value') == 'OK1' + await asyncio.sleep(0.15) + assert await client.set('key', 'value') == 'OK2' + await asyncio.sleep(0.22) + assert await client.set('key', 'value') == 'OK1' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + {'weight': 0.5, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_command_throws_exception_on_failed_initialization( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises(NoValidDatabaseException, match='Initial connection failed - no active database found'): + await client.set('key', 'value') + assert mock_hc.check_health.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_database_throws_exception_on_same_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_hc.check_health.return_value = False + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + with pytest.raises(ValueError, match='Given database already exists'): + await client.add_database(mock_db) + assert mock_hc.check_health.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_database_makes_new_database_active( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert await client.set('key', 'value') == 'OK2' + assert mock_hc.check_health.call_count == 2 + + await client.add_database(mock_db1) + assert mock_hc.check_health.call_count == 3 + + assert await client.set('key', 'value') == 'OK1' + + @pytest.mark.asyncio 
+ @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_remove_highest_weighted_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + await client.remove_database(mock_db1) + assert await client.set('key', 'value') == 'OK2' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_update_database_weight_to_be_highest( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db2.client.execute_command.return_value = 'OK2' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + await client.update_database_weight(mock_db2, 0.8) + assert mock_db2.weight == 0.8 + + assert await client.set('key', 'value') == 'OK2' + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_new_failure_detector( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_multi_db_config.event_dispatcher = EventDispatcher() + mock_fd = mock_multi_db_config.failure_detectors[0] + + # Event fired if command against mock_db1 would fail + command_fail_event = AsyncOnCommandsFailEvent( + commands=('SET', 'key', 'value'), + exception=Exception(), + ) + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') 
== 'OK1' + assert mock_hc.check_health.call_count == 3 + + # Simulate failing command events that lead to a failure detection + for i in range(5): + await mock_multi_db_config.event_dispatcher.dispatch_async(command_fail_event) + + assert mock_fd.register_failure.call_count == 5 + + another_fd = Mock(spec=AsyncFailureDetector) + client.add_failure_detector(another_fd) + + # Simulate failing command events that lead to a failure detection + for i in range(5): + await mock_multi_db_config.event_dispatcher.dispatch_async(command_fail_event) + + assert mock_fd.register_failure.call_count == 10 + assert another_fd.register_failure.call_count == 5 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_add_new_health_check( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + another_hc = Mock(spec=HealthCheck) + another_hc.check_health.return_value = True + + await client.add_health_check(another_hc) + await client._check_db_health(mock_db1) + + assert mock_hc.check_health.call_count == 4 + assert another_hc.check_health.call_count == 1 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_set_active_database( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.execute_command.return_value = 'OK1' + mock_db.client.execute_command.return_value = 'OK' + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + assert await client.set('key', 'value') == 'OK1' + assert mock_hc.check_health.call_count == 3 + + await client.set_active_database(mock_db) + assert await client.set('key', 'value') == 'OK' + + with pytest.raises(ValueError, match='Given database is not a member of database list'): + await client.set_active_database(Mock(spec=AsyncDatabase)) + + mock_hc.check_health.return_value = False + + with pytest.raises(NoValidDatabaseException, match='Cannot set active database, database is unhealthy'): + await client.set_active_database(mock_db1) \ No newline at end of file diff --git a/tests/test_multidb/conftest.py b/tests/test_multidb/conftest.py index 9503d79d9b..0c082f0f17 100644 --- 
a/tests/test_multidb/conftest.py +++ b/tests/test_multidb/conftest.py @@ -4,7 +4,7 @@ from redis import Redis from redis.data_structure import WeightedList -from redis.multidb.circuit import State as CBState, SyncCircuitBreaker +from redis.multidb.circuit import State as CBState, CircuitBreaker from redis.multidb.config import MultiDbConfig, DatabaseConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, Databases @@ -19,8 +19,8 @@ def mock_client() -> Redis: return Mock(spec=Redis) @pytest.fixture() -def mock_cb() -> SyncCircuitBreaker: - return Mock(spec=SyncCircuitBreaker) +def mock_cb() -> CircuitBreaker: + return Mock(spec=CircuitBreaker) @pytest.fixture() def mock_fd() -> FailureDetector: @@ -41,7 +41,7 @@ def mock_db(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=SyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -55,7 +55,7 @@ def mock_db1(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=SyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) @@ -69,7 +69,7 @@ def mock_db2(request) -> Database: db.client = Mock(spec=Redis) cb = request.param.get("circuit", {}) - mock_cb = Mock(spec=SyncCircuitBreaker) + mock_cb = Mock(spec=CircuitBreaker) mock_cb.grace_period = cb.get("grace_period", 1.0) mock_cb.state = cb.get("state", CBState.CLOSED) diff --git a/tests/test_multidb/test_circuit.py b/tests/test_multidb/test_circuit.py index f5f39c3f6b..7dc642373b 100644 --- a/tests/test_multidb/test_circuit.py +++ b/tests/test_multidb/test_circuit.py @@ -1,7 +1,7 @@ import pybreaker import pytest -from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker, SyncCircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, State as CbState, CircuitBreaker class TestPBCircuitBreaker: @@ -39,7 +39,7 @@ def test_cb_executes_callback_on_state_changed(self): adapter = PBCircuitBreakerAdapter(cb=pb_circuit) called_count = 0 - def callback(cb: SyncCircuitBreaker, old_state: CbState, new_state: CbState): + def callback(cb: CircuitBreaker, old_state: CbState, new_state: CbState): nonlocal called_count assert old_state == CbState.CLOSED assert new_state == CbState.HALF_OPEN diff --git a/tests/test_multidb/test_client.py b/tests/test_multidb/test_client.py index c7c15fe684..d352c1da92 100644 --- a/tests/test_multidb/test_client.py +++ b/tests/test_multidb/test_client.py @@ -166,13 +166,9 @@ def test_execute_command_auto_fallback_to_highest_weight_db( client = MultiDBClient(mock_multi_db_config) assert client.set('key', 'value') == 'OK1' - sleep(0.15) - assert client.set('key', 'value') == 'OK2' - sleep(0.22) - assert client.set('key', 'value') == 'OK1' @pytest.mark.parametrize( diff --git a/tests/test_multidb/test_config.py b/tests/test_multidb/test_config.py index e428b3ce7a..1ea63a0e14 100644 --- a/tests/test_multidb/test_config.py +++ b/tests/test_multidb/test_config.py @@ -1,6 +1,6 @@ from unittest.mock import Mock from redis.connection import ConnectionPool -from redis.multidb.circuit import PBCircuitBreakerAdapter, SyncCircuitBreaker +from redis.multidb.circuit import PBCircuitBreakerAdapter, CircuitBreaker from redis.multidb.config import MultiDbConfig, 
DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_AUTO_FALLBACK_INTERVAL, DatabaseConfig, DEFAULT_GRACE_PERIOD from redis.multidb.database import Database @@ -49,11 +49,11 @@ def test_overridden_config(self): mock_connection_pools[0].connection_kwargs = {} mock_connection_pools[1].connection_kwargs = {} mock_connection_pools[2].connection_kwargs = {} - mock_cb1 = Mock(spec=SyncCircuitBreaker) + mock_cb1 = Mock(spec=CircuitBreaker) mock_cb1.grace_period = grace_period - mock_cb2 = Mock(spec=SyncCircuitBreaker) + mock_cb2 = Mock(spec=CircuitBreaker) mock_cb2.grace_period = grace_period - mock_cb3 = Mock(spec=SyncCircuitBreaker) + mock_cb3 = Mock(spec=CircuitBreaker) mock_cb3.grace_period = grace_period mock_failure_detectors = [Mock(spec=FailureDetector), Mock(spec=FailureDetector)] mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] @@ -113,7 +113,7 @@ def test_default_config(self): def test_overridden_config(self): mock_connection_pool = Mock(spec=ConnectionPool) - mock_circuit = Mock(spec=SyncCircuitBreaker) + mock_circuit = Mock(spec=CircuitBreaker) config = DatabaseConfig( client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit From e376544c55f97d7b3ca336402b42920c59afc33a Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 4 Sep 2025 10:24:20 +0300 Subject: [PATCH 6/9] Added scenario and config tests --- redis/asyncio/multidb/client.py | 4 +- redis/asyncio/multidb/command_executor.py | 1 - redis/event.py | 5 +- .../test_asyncio/test_multidb/test_config.py | 125 ++++++++++++++++++ tests/test_asyncio/test_scenario/__init__.py | 0 tests/test_asyncio/test_scenario/conftest.py | 88 ++++++++++++ .../test_scenario/test_active_active.py | 59 +++++++++ 7 files changed, 278 insertions(+), 4 deletions(-) create mode 100644 tests/test_asyncio/test_multidb/test_config.py create mode 100644 tests/test_asyncio/test_scenario/__init__.py create mode 100644 tests/test_asyncio/test_scenario/conftest.py create mode 100644 tests/test_asyncio/test_scenario/test_active_active.py diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index dbf03a3ef4..73eafd9026 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -61,10 +61,10 @@ async def raise_exception_on_failed_hc(error): await self._check_databases_health(on_error=raise_exception_on_failed_hc) # Starts recurring health checks on the background. 
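The hunk that follows replaces a direct await of the recurring health-check coroutine with asyncio.create_task. For reference, this is the generic difference between the two forms in plain asyncio; the recurring() coroutine below is a stand-in, not the scheduler from this patch:

import asyncio

async def recurring(interval: float) -> None:
    # Stand-in for a recurring background job such as a health-check loop.
    while True:
        await asyncio.sleep(interval)

async def main() -> None:
    # "await recurring(0.1)" would run the loop inline and never return.
    # create_task schedules it in the background and returns a handle
    # that can be cancelled later.
    task = asyncio.create_task(recurring(0.1))
    await asyncio.sleep(0.3)
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass

asyncio.run(main())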
- await self._bg_scheduler.run_recurring_async( + asyncio.create_task(self._bg_scheduler.run_recurring_async( self._health_check_interval, self._check_databases_health, - ) + )) is_active_db_found = False diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index 22aef83118..af10a00988 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -248,7 +248,6 @@ async def _check_active_database(self): ) ): await self.set_active_database(await self._failover_strategy.database()) - print("Active database now with weight {}", format(self._active_database.weight)) self._schedule_next_fallback() async def _on_command_fail(self, error, *args): diff --git a/redis/event.py b/redis/event.py index 8327ec5f76..de38e1a069 100644 --- a/redis/event.py +++ b/redis/event.py @@ -108,7 +108,10 @@ async def dispatch_async(self, event: object): for listener in listeners: await listener.listen(event) - def register_listeners(self, event_listeners: Dict[Type[object], List[EventListenerInterface]]): + def register_listeners( + self, + event_listeners: Dict[Type[object], List[Union[EventListenerInterface, AsyncEventListenerInterface]]] + ): with self._lock: for event_type in event_listeners: if event_type in self._event_listeners_mapping: diff --git a/tests/test_asyncio/test_multidb/test_config.py b/tests/test_asyncio/test_multidb/test_config.py new file mode 100644 index 0000000000..64760740a1 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_config.py @@ -0,0 +1,125 @@ +from unittest.mock import Mock + +from redis.asyncio import ConnectionPool +from redis.asyncio.multidb.config import DatabaseConfig, MultiDbConfig, DEFAULT_GRACE_PERIOD, \ + DEFAULT_HEALTH_CHECK_INTERVAL, DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.asyncio.multidb.database import Database +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy, AsyncFailoverStrategy +from redis.asyncio.multidb.failure_detector import FailureDetectorAsyncWrapper, AsyncFailureDetector +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, HealthCheck +from redis.asyncio.retry import Retry +from redis.multidb.circuit import CircuitBreaker + + +class TestMultiDbConfig: + def test_default_config(self): + db_configs = [ + DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0), + DatabaseConfig(client_kwargs={'host': 'host2', 'port': 'port2'}, weight=0.9), + DatabaseConfig(client_kwargs={'host': 'host3', 'port': 'port3'}, weight=0.8), + ] + + config = MultiDbConfig( + databases_config=db_configs + ) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db, weight in databases: + assert isinstance(db, Database) + assert weight == db_configs[i].weight + assert db.circuit.grace_period == DEFAULT_GRACE_PERIOD + assert db.client.get_retry() is not config.command_retry + i+=1 + + assert len(config.default_failure_detectors()) == 1 + assert isinstance(config.default_failure_detectors()[0], FailureDetectorAsyncWrapper) + assert len(config.default_health_checks()) == 1 + assert isinstance(config.default_health_checks()[0], EchoHealthCheck) + assert config.health_check_interval == DEFAULT_HEALTH_CHECK_INTERVAL + assert isinstance(config.default_failover_strategy(), WeightBasedFailoverStrategy) + assert config.auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL + assert isinstance(config.command_retry, Retry) + + def test_overridden_config(self): + grace_period 
= 2 + mock_connection_pools = [Mock(spec=ConnectionPool), Mock(spec=ConnectionPool), Mock(spec=ConnectionPool)] + mock_connection_pools[0].connection_kwargs = {} + mock_connection_pools[1].connection_kwargs = {} + mock_connection_pools[2].connection_kwargs = {} + mock_cb1 = Mock(spec=CircuitBreaker) + mock_cb1.grace_period = grace_period + mock_cb2 = Mock(spec=CircuitBreaker) + mock_cb2.grace_period = grace_period + mock_cb3 = Mock(spec=CircuitBreaker) + mock_cb3.grace_period = grace_period + mock_failure_detectors = [Mock(spec=AsyncFailureDetector), Mock(spec=AsyncFailureDetector)] + mock_health_checks = [Mock(spec=HealthCheck), Mock(spec=HealthCheck)] + health_check_interval = 10 + mock_failover_strategy = Mock(spec=AsyncFailoverStrategy) + auto_fallback_interval = 10 + db_configs = [ + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[0]}, weight=1.0, circuit=mock_cb1 + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[1]}, weight=0.9, circuit=mock_cb2 + ), + DatabaseConfig( + client_kwargs={"connection_pool": mock_connection_pools[2]}, weight=0.8, circuit=mock_cb3 + ), + ] + + config = MultiDbConfig( + databases_config=db_configs, + failure_detectors=mock_failure_detectors, + health_checks=mock_health_checks, + health_check_interval=health_check_interval, + failover_strategy=mock_failover_strategy, + auto_fallback_interval=auto_fallback_interval, + ) + + assert config.databases_config == db_configs + databases = config.databases() + assert len(databases) == 3 + + i = 0 + for db, weight in databases: + assert isinstance(db, Database) + assert weight == db_configs[i].weight + assert db.client.connection_pool == mock_connection_pools[i] + assert db.circuit.grace_period == grace_period + i+=1 + + assert len(config.failure_detectors) == 2 + assert config.failure_detectors[0] == mock_failure_detectors[0] + assert config.failure_detectors[1] == mock_failure_detectors[1] + assert len(config.health_checks) == 2 + assert config.health_checks[0] == mock_health_checks[0] + assert config.health_checks[1] == mock_health_checks[1] + assert config.health_check_interval == health_check_interval + assert config.failover_strategy == mock_failover_strategy + assert config.auto_fallback_interval == auto_fallback_interval + +class TestDatabaseConfig: + def test_default_config(self): + config = DatabaseConfig(client_kwargs={'host': 'host1', 'port': 'port1'}, weight=1.0) + + assert config.client_kwargs == {'host': 'host1', 'port': 'port1'} + assert config.weight == 1.0 + assert isinstance(config.default_circuit_breaker(), CircuitBreaker) + + def test_overridden_config(self): + mock_connection_pool = Mock(spec=ConnectionPool) + mock_circuit = Mock(spec=CircuitBreaker) + + config = DatabaseConfig( + client_kwargs={'connection_pool': mock_connection_pool}, weight=1.0, circuit=mock_circuit + ) + + assert config.client_kwargs == {'connection_pool': mock_connection_pool} + assert config.weight == 1.0 + assert config.circuit == mock_circuit \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/__init__.py b/tests/test_asyncio/test_scenario/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py new file mode 100644 index 0000000000..312712ba05 --- /dev/null +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -0,0 +1,88 @@ +import os + +import pytest + +from redis.asyncio import Redis +from redis.asyncio.multidb.client import 
MultiDBClient +from redis.asyncio.multidb.config import DEFAULT_FAILURES_THRESHOLD, DEFAULT_HEALTH_CHECK_INTERVAL, DatabaseConfig, \ + MultiDbConfig +from redis.asyncio.multidb.event import AsyncActiveDatabaseChanged +from redis.asyncio.retry import Retry +from redis.backoff import ExponentialBackoff +from redis.event import AsyncEventListenerInterface, EventDispatcher +from tests.test_scenario.conftest import get_endpoint_config, extract_cluster_fqdn +from tests.test_scenario.fault_injector_client import FaultInjectorClient + + +class CheckActiveDatabaseChangedListener(AsyncEventListenerInterface): + def __init__(self): + self.is_changed_flag = False + + async def listen(self, event: AsyncActiveDatabaseChanged): + self.is_changed_flag = True + +@pytest.fixture() +def fault_injector_client(): + url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324") + return FaultInjectorClient(url) + +@pytest.fixture() +def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]: + client_class = request.param.get('client_class', Redis) + + if client_class == Redis: + endpoint_config = get_endpoint_config('re-active-active') + else: + endpoint_config = get_endpoint_config('re-active-active-oss-cluster') + + username = endpoint_config.get('username', None) + password = endpoint_config.get('password', None) + failure_threshold = request.param.get('failure_threshold', DEFAULT_FAILURES_THRESHOLD) + command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=2, base=0.05), retries=10)) + + # Retry configuration different for health checks as initial health check require more time in case + # if infrastructure wasn't restored from the previous test. + health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL) + event_dispatcher = EventDispatcher() + listener = CheckActiveDatabaseChangedListener() + event_dispatcher.register_listeners({ + AsyncActiveDatabaseChanged: [listener], + }) + db_configs = [] + + db_config = DatabaseConfig( + weight=1.0, + from_url=endpoint_config['endpoints'][0], + client_kwargs={ + 'username': username, + 'password': password, + 'decode_responses': True, + }, + health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][0]) + ) + db_configs.append(db_config) + + db_config1 = DatabaseConfig( + weight=0.9, + from_url=endpoint_config['endpoints'][1], + client_kwargs={ + 'username': username, + 'password': password, + 'decode_responses': True, + }, + health_check_url=extract_cluster_fqdn(endpoint_config['endpoints'][1]) + ) + db_configs.append(db_config1) + + config = MultiDbConfig( + client_class=client_class, + databases_config=db_configs, + command_retry=command_retry, + failure_threshold=failure_threshold, + health_check_retries=3, + health_check_interval=health_check_interval, + event_dispatcher=event_dispatcher, + health_check_backoff=ExponentialBackoff(cap=5, base=0.5), + ) + + return MultiDBClient(config), listener, endpoint_config \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py new file mode 100644 index 0000000000..833bb0776f --- /dev/null +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -0,0 +1,59 @@ +import asyncio +import logging +from time import sleep + +import pytest + +from tests.test_scenario.fault_injector_client import ActionRequest, ActionType + +logger = logging.getLogger(__name__) + +async def 
trigger_network_failure_action(fault_injector_client, config, event: asyncio.Event = None): + action_request = ActionRequest( + action_type=ActionType.NETWORK_FAILURE, + parameters={"bdb_id": config['bdb_id'], "delay": 2, "cluster_index": 0} + ) + + result = fault_injector_client.trigger_action(action_request) + status_result = fault_injector_client.get_action_status(result['action_id']) + + while status_result['status'] != "success": + await asyncio.sleep(0.1) + status_result = fault_injector_client.get_action_status(result['action_id']) + logger.info(f"Waiting for action to complete. Status: {status_result['status']}") + + if event: + event.set() + + logger.info(f"Action completed. Status: {status_result['status']}") + +class TestActiveActive: + + def teardown_method(self, method): + # Timeout so the cluster could recover from network failure. + sleep(5) + + @pytest.mark.parametrize( + "r_multi_db", + [{"failure_threshold": 2}], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + + # Client initialized on the first command. + await r_multi_db.set('key', 'value') + + # Execute commands before network failure + while not event.is_set(): + assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) + + # Execute commands until database failover + while not listener.is_changed_flag: + assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) \ No newline at end of file From 57f6d8bb82bbb50213e8f3f238264f30436cc6a0 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 4 Sep 2025 12:11:39 +0300 Subject: [PATCH 7/9] Added pipeline and transaction support for MultiDBClient --- redis/asyncio/multidb/client.py | 114 ++++++- redis/asyncio/multidb/command_executor.py | 24 +- .../test_multidb/test_pipeline.py | 321 ++++++++++++++++++ tests/test_asyncio/test_scenario/conftest.py | 8 +- .../test_scenario/test_active_active.py | 130 +++++++ 5 files changed, 585 insertions(+), 12 deletions(-) create mode 100644 tests/test_asyncio/test_multidb/test_pipeline.py diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index 73eafd9026..1025c4b37b 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -1,5 +1,5 @@ import asyncio -from typing import Callable, Optional, Coroutine, Any +from typing import Callable, Optional, Coroutine, Any, List, Union, Awaitable from redis.asyncio.multidb.command_executor import DefaultCommandExecutor from redis.asyncio.multidb.database import AsyncDatabase, Databases @@ -10,6 +10,7 @@ from redis.background import BackgroundScheduler from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands from redis.multidb.exception import NoValidDatabaseException +from redis.typing import KeyT class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands): @@ -49,6 +50,19 @@ def __init__(self, config: MultiDbConfig): self._hc_lock = asyncio.Lock() self._bg_scheduler = BackgroundScheduler() self._config = config + self._hc_task = None + self._half_open_state_task = None + + async def __aenter__(self: "MultiDBClient") -> "MultiDBClient": + if not self.initialized: + await self.initialize() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + if self._hc_task: + self._hc_task.cancel() + if self._half_open_state_task: + 
self._half_open_state_task.cancel() async def initialize(self): """ @@ -61,7 +75,7 @@ async def raise_exception_on_failed_hc(error): await self._check_databases_health(on_error=raise_exception_on_failed_hc) # Starts recurring health checks on the background. - asyncio.create_task(self._bg_scheduler.run_recurring_async( + self._hc_task = asyncio.create_task(self._bg_scheduler.run_recurring_async( self._health_check_interval, self._check_databases_health, )) @@ -180,6 +194,34 @@ async def execute_command(self, *args, **options): return await self.command_executor.execute_command(*args, **options) + def pipeline(self): + """ + Enters into pipeline mode of the client. + """ + return Pipeline(self) + + async def transaction( + self, + func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], + *watches: KeyT, + shard_hint: Optional[str] = None, + value_from_callable: bool = False, + watch_delay: Optional[float] = None, + ): + """ + Executes callable as transaction. + """ + if not self.initialized: + await self.initialize() + + return await self.command_executor.execute_transaction( + func, + *watches, + shard_hint=shard_hint, + value_from_callable=value_from_callable, + watch_delay=watch_delay, + ) + async def _check_databases_health( self, on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, @@ -227,11 +269,75 @@ def _on_circuit_state_change_callback(self, circuit: CircuitBreaker, old_state: loop = asyncio.get_running_loop() if new_state == CBState.HALF_OPEN: - asyncio.create_task(self._check_db_health(circuit.database)) + self._half_open_state_task = asyncio.create_task(self._check_db_health(circuit.database)) return if old_state == CBState.CLOSED and new_state == CBState.OPEN: loop.call_later(DEFAULT_GRACE_PERIOD, _half_open_circuit, circuit) def _half_open_circuit(circuit: CircuitBreaker): - circuit.state = CBState.HALF_OPEN \ No newline at end of file + circuit.state = CBState.HALF_OPEN + +class Pipeline(AsyncRedisModuleCommands, AsyncCoreCommands): + """ + Pipeline implementation for multiple logical Redis databases. + """ + def __init__(self, client: MultiDBClient): + self._command_stack = [] + self._client = client + + async def __aenter__(self: "Pipeline") -> "Pipeline": + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.reset() + await self._client.__aexit__(exc_type, exc_value, traceback) + + def __await__(self): + return self._async_self().__await__() + + async def _async_self(self): + return self + + def __len__(self) -> int: + return len(self._command_stack) + + def __bool__(self) -> bool: + """Pipeline instances should always evaluate to True""" + return True + + async def reset(self) -> None: + self._command_stack = [] + + async def aclose(self) -> None: + """Close the pipeline""" + await self.reset() + + def pipeline_execute_command(self, *args, **options) -> "Pipeline": + """ + Stage a command to be executed when execute() is next called + + Returns the current Pipeline object back so commands can be + chained together, such as: + + pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') + + At some other point, you can then run: pipe.execute(), + which will execute all commands queued in the pipe. 
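Taken together, the client and executor changes in this patch support the call pattern below, which the tests added further down exercise against mocked databases; the hosts and keys here are placeholders, not values from the patch:

import asyncio

from redis.asyncio.multidb.client import MultiDBClient
from redis.asyncio.multidb.config import DatabaseConfig, MultiDbConfig

async def main() -> None:
    config = MultiDbConfig(
        databases_config=[
            DatabaseConfig(client_kwargs={"host": "db-a.example.com", "port": 6379}, weight=1.0),
            DatabaseConfig(client_kwargs={"host": "db-b.example.com", "port": 6379}, weight=0.9),
        ],
    )

    async with MultiDBClient(config) as client:
        # Pipeline mode: commands are staged locally and sent on execute().
        async with client.pipeline() as pipe:
            pipe.set("key1", "value1")
            pipe.get("key1")
            print(await pipe.execute())

        # Transaction mode: the callable stages commands on the pipe it receives.
        async def callback(pipe):
            pipe.set("key2", "value2")
            pipe.get("key2")

        await client.transaction(callback)

asyncio.run(main())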
+ """ + self._command_stack.append((args, options)) + return self + + def execute_command(self, *args, **kwargs): + """Adds a command to the stack""" + return self.pipeline_execute_command(*args, **kwargs) + + async def execute(self) -> List[Any]: + """Execute all the commands in the current pipeline""" + if not self._client.initialized: + await self._client.initialize() + + try: + return await self._client.command_executor.execute_pipeline(tuple(self._command_stack)) + finally: + await self.reset() \ No newline at end of file diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index af10a00988..4133dba394 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -1,6 +1,6 @@ from abc import abstractmethod from datetime import datetime -from typing import List, Optional, Callable, Any +from typing import List, Optional, Callable, Any, Union, Awaitable from redis.asyncio.client import PubSub, Pipeline from redis.asyncio.multidb.database import Databases, AsyncDatabase, Database @@ -13,6 +13,7 @@ from redis.event import EventDispatcherInterface, AsyncOnCommandsFailEvent from redis.multidb.command_executor import CommandExecutor, BaseCommandExecutor from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL +from redis.typing import KeyT class AsyncCommandExecutor(CommandExecutor): @@ -194,17 +195,30 @@ async def callback(): async def execute_pipeline(self, command_stack: tuple): async def callback(): - with self._active_database.client.pipeline() as pipe: + async with self._active_database.client.pipeline() as pipe: for command, options in command_stack: - await pipe.execute_command(*command, **options) + pipe.execute_command(*command, **options) return await pipe.execute() return await self._execute_with_failure_detection(callback, command_stack) - async def execute_transaction(self, transaction: Callable[[Pipeline], None], *watches, **options): + async def execute_transaction( + self, + func: Callable[["Pipeline"], Union[Any, Awaitable[Any]]], + *watches: KeyT, + shard_hint: Optional[str] = None, + value_from_callable: bool = False, + watch_delay: Optional[float] = None, + ): async def callback(): - return await self._active_database.client.transaction(transaction, *watches, **options) + return await self._active_database.client.transaction( + func, + *watches, + shard_hint=shard_hint, + value_from_callable=value_from_callable, + watch_delay=watch_delay + ) return await self._execute_with_failure_detection(callback) diff --git a/tests/test_asyncio/test_multidb/test_pipeline.py b/tests/test_asyncio/test_multidb/test_pipeline.py new file mode 100644 index 0000000000..5af2e3e864 --- /dev/null +++ b/tests/test_asyncio/test_multidb/test_pipeline.py @@ -0,0 +1,321 @@ +import asyncio +from unittest.mock import Mock, AsyncMock, patch + +import pybreaker +import pytest + +from redis.asyncio.client import Pipeline +from redis.asyncio.multidb.client import MultiDBClient +from redis.asyncio.multidb.config import DEFAULT_FAILOVER_RETRIES +from redis.asyncio.multidb.failover import WeightBasedFailoverStrategy +from redis.asyncio.multidb.healthcheck import EchoHealthCheck, DEFAULT_HEALTH_CHECK_RETRIES, \ + DEFAULT_HEALTH_CHECK_BACKOFF +from redis.asyncio.retry import Retry +from redis.multidb.circuit import State as CBState, PBCircuitBreakerAdapter +from redis.multidb.config import DEFAULT_FAILOVER_BACKOFF +from tests.test_asyncio.test_multidb.conftest import create_weighted_list + + +def mock_pipe() -> 
Pipeline: + mock_pipe = Mock(spec=Pipeline) + mock_pipe.__aenter__ = AsyncMock(return_value=mock_pipe) + mock_pipe.__aexit__ = AsyncMock(return_value=None) + return mock_pipe + +class TestPipeline: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_executes_pipeline_against_correct_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + pipe = mock_pipe() + pipe.execute.return_value = ['OK1', 'value1'] + mock_db1.client.pipeline.return_value = pipe + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + pipe = client.pipeline() + pipe.set('key1', 'value1') + pipe.get('key1') + + assert await pipe.execute() == ['OK1', 'value1'] + assert mock_hc.check_health.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_pipeline_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + pipe = mock_pipe() + pipe.execute.return_value = ['OK1', 'value1'] + mock_db1.client.pipeline.return_value = pipe + + mock_hc.check_health.side_effect = [False, True, True] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + async with client.pipeline() as pipe: + pipe.set('key1', 'value1') + pipe.get('key1') + + assert await pipe.execute() == ['OK1', 'value1'] + assert mock_hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_pipeline_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + 
mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] + + pipe = mock_pipe() + pipe.execute.return_value = ['OK', 'value'] + mock_db.client.pipeline.return_value = pipe + + pipe1 = mock_pipe() + pipe1.execute.return_value = ['OK1', 'value'] + mock_db1.client.pipeline.return_value = pipe1 + + pipe2 = mock_pipe() + pipe2.execute.return_value = ['OK2', 'value'] + mock_db2.client.pipeline.return_value = pipe2 + + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + + async with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert await pipe.execute() == ['OK1', 'value'] + + await asyncio.sleep(0.15) + + async with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert await pipe.execute() == ['OK2', 'value'] + + await asyncio.sleep(0.1) + + async with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert await pipe.execute() == ['OK', 'value'] + + await asyncio.sleep(0.1) + + async with client.pipeline() as pipe: + pipe.set('key1', 'value') + pipe.get('key1') + + assert await pipe.execute() == ['OK1', 'value'] + +class TestTransaction: + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_executes_transaction_against_correct_db( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.transaction.return_value = ['OK1', 'value1'] + + mock_hc.check_health.return_value = True + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + async def callback(pipe: Pipeline): + pipe.set('key1', 'value1') + pipe.get('key1') + + assert await client.transaction(callback) == ['OK1', 'value1'] + assert mock_hc.check_health.call_count == 3 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.OPEN}}, + ), + ], + indirect=True, + ) + async def test_execute_transaction_against_correct_db_and_closed_circuit( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2, mock_hc + ): + databases = 
create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]): + mock_db1.client.transaction.return_value = ['OK1', 'value1'] + + mock_hc.check_health.side_effect = [False, True, True] + + client = MultiDBClient(mock_multi_db_config) + assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1 + + async def callback(pipe: Pipeline): + pipe.set('key1', 'value1') + pipe.get('key1') + + assert await client.transaction(callback) == ['OK1', 'value1'] + assert mock_hc.check_health.call_count == 3 + + assert mock_db.circuit.state == CBState.CLOSED + assert mock_db1.circuit.state == CBState.CLOSED + assert mock_db2.circuit.state == CBState.OPEN + + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'mock_multi_db_config,mock_db, mock_db1, mock_db2', + [ + ( + {}, + {'weight': 0.2, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.7, 'circuit': {'state': CBState.CLOSED}}, + {'weight': 0.5, 'circuit': {'state': CBState.CLOSED}}, + ), + ], + indirect=True, + ) + async def test_execute_transaction_against_correct_db_on_background_health_check_determine_active_db_unhealthy( + self, mock_multi_db_config, mock_db, mock_db1, mock_db2 + ): + cb = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb.database = mock_db + mock_db.circuit = cb + + cb1 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb1.database = mock_db1 + mock_db1.circuit = cb1 + + cb2 = PBCircuitBreakerAdapter(pybreaker.CircuitBreaker(reset_timeout=5)) + cb2.database = mock_db2 + mock_db2.circuit = cb2 + + databases = create_weighted_list(mock_db, mock_db1, mock_db2) + + with patch.object(mock_multi_db_config,'databases',return_value=databases), \ + patch.object(mock_multi_db_config,'default_health_checks', return_value=[EchoHealthCheck( + retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF) + )]): + mock_db.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'healthcheck', 'error'] + mock_db1.client.execute_command.side_effect = ['healthcheck', 'error', 'error', 'healthcheck'] + mock_db2.client.execute_command.side_effect = ['healthcheck', 'healthcheck', 'error', 'error'] + + mock_db.client.transaction.return_value = ['OK', 'value'] + mock_db1.client.transaction.return_value = ['OK1', 'value'] + mock_db2.client.transaction.return_value = ['OK2', 'value'] + + mock_multi_db_config.health_check_interval = 0.1 + mock_multi_db_config.failover_strategy = WeightBasedFailoverStrategy( + retry=Retry(retries=DEFAULT_FAILOVER_RETRIES, backoff=DEFAULT_FAILOVER_BACKOFF) + ) + + client = MultiDBClient(mock_multi_db_config) + + async def callback(pipe: Pipeline): + pipe.set('key1', 'value1') + pipe.get('key1') + + assert await client.transaction(callback) == ['OK1', 'value'] + await asyncio.sleep(0.15) + assert await client.transaction(callback) == ['OK2', 'value'] + await asyncio.sleep(0.1) + assert await client.transaction(callback) == ['OK', 'value'] + await asyncio.sleep(0.1) + assert await client.transaction(callback) == ['OK1', 'value'] \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/conftest.py b/tests/test_asyncio/test_scenario/conftest.py index 312712ba05..18bc8f1417 100644 --- a/tests/test_asyncio/test_scenario/conftest.py +++ b/tests/test_asyncio/test_scenario/conftest.py @@ -1,6 +1,7 @@ import os import pytest +import pytest_asyncio from redis.asyncio 
import Redis from redis.asyncio.multidb.client import MultiDBClient @@ -26,8 +27,8 @@ def fault_injector_client(): url = os.getenv("FAULT_INJECTION_API_URL", "http://127.0.0.1:20324") return FaultInjectorClient(url) -@pytest.fixture() -def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]: +@pytest_asyncio.fixture() +async def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener, dict]: client_class = request.param.get('client_class', Redis) if client_class == Redis: @@ -85,4 +86,5 @@ def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListen health_check_backoff=ExponentialBackoff(cap=5, base=0.5), ) - return MultiDBClient(config), listener, endpoint_config \ No newline at end of file + async with MultiDBClient(config) as client: + return client, listener, endpoint_config \ No newline at end of file diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index 833bb0776f..76db322253 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -4,6 +4,7 @@ import pytest +from redis.asyncio.client import Pipeline from tests.test_scenario.fault_injector_client import ActionRequest, ActionType logger = logging.getLogger(__name__) @@ -33,6 +34,7 @@ def teardown_method(self, method): # Timeout so the cluster could recover from network failure. sleep(5) + @pytest.mark.asyncio @pytest.mark.parametrize( "r_multi_db", [{"failure_threshold": 2}], @@ -56,4 +58,132 @@ async def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_in # Execute commands until database failover while not listener.is_changed_flag: assert await r_multi_db.get('key') == 'value' + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [{"failure_threshold": 2}], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + + # Client initialized on first pipe execution. 
+ async with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + + # Execute pipeline before network failure + while not event.is_set(): + async with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) + + # Execute pipeline until database failover + for _ in range(5): + async with r_multi_db.pipeline() as pipe: + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [{"failure_threshold": 2}], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + + # Client initialized on first pipe execution. + pipe = r_multi_db.pipeline() + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + + # Execute pipeline before network failure + while not event.is_set(): + pipe = r_multi_db.pipeline() + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) + + # Execute pipeline until database failover + for _ in range(5): + pipe = r_multi_db.pipeline() + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + assert await pipe.execute() == [True, True, True, 'value1', 'value2', 'value3'] + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [{"failure_threshold": 2}], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + + async def callback(pipe: Pipeline): + pipe.set('{hash}key1', 'value1') + pipe.set('{hash}key2', 'value2') + pipe.set('{hash}key3', 'value3') + pipe.get('{hash}key1') + pipe.get('{hash}key2') + pipe.get('{hash}key3') + + # Client initialized on first transaction execution. 
+ await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3'] + + # Execute transaction before network failure + while not event.is_set(): + await r_multi_db.transaction(callback) + await asyncio.sleep(0.5) + + # Execute transaction until database failover + while not listener.is_changed_flag: + await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3'] await asyncio.sleep(0.5) \ No newline at end of file From 25eebb96824fdbbb56edc0957cd910408315b11c Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 4 Sep 2025 16:01:58 +0300 Subject: [PATCH 8/9] Added pub/sub support for MultiDBClient --- redis/asyncio/client.py | 12 +- redis/asyncio/multidb/client.py | 135 +++++++++++++++++- redis/asyncio/multidb/command_executor.py | 12 +- redis/multidb/client.py | 8 +- .../test_scenario/test_active_active.py | 42 +++++- 5 files changed, 191 insertions(+), 18 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index aac409073f..4c000bd2e7 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -1191,6 +1191,7 @@ async def run( *, exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None, poll_timeout: float = 1.0, + pubsub = None ) -> None: """Process pub/sub messages using registered callbacks. @@ -1215,9 +1216,14 @@ async def run( await self.connect() while True: try: - await self.get_message( - ignore_subscribe_messages=True, timeout=poll_timeout - ) + if pubsub is None: + await self.get_message( + ignore_subscribe_messages=True, timeout=poll_timeout + ) + else: + await pubsub.get_message( + ignore_subscribe_messages=True, timeout=poll_timeout + ) except asyncio.CancelledError: raise except BaseException as e: diff --git a/redis/asyncio/multidb/client.py b/redis/asyncio/multidb/client.py index 1025c4b37b..7c0bef4f6e 100644 --- a/redis/asyncio/multidb/client.py +++ b/redis/asyncio/multidb/client.py @@ -1,6 +1,7 @@ import asyncio from typing import Callable, Optional, Coroutine, Any, List, Union, Awaitable +from redis.asyncio.client import PubSubHandler from redis.asyncio.multidb.command_executor import DefaultCommandExecutor from redis.asyncio.multidb.database import AsyncDatabase, Databases from redis.asyncio.multidb.failure_detector import AsyncFailureDetector @@ -10,7 +11,7 @@ from redis.background import BackgroundScheduler from redis.commands import AsyncRedisModuleCommands, AsyncCoreCommands from redis.multidb.exception import NoValidDatabaseException -from redis.typing import KeyT +from redis.typing import KeyT, EncodableT, ChannelT class MultiDBClient(AsyncRedisModuleCommands, AsyncCoreCommands): @@ -222,6 +223,17 @@ async def transaction( watch_delay=watch_delay, ) + async def pubsub(self, **kwargs): + """ + Return a Publish/Subscribe object. With this object, you can + subscribe to channels and listen for messages that get published to + them. + """ + if not self.initialized: + await self.initialize() + + return PubSub(self, **kwargs) + async def _check_databases_health( self, on_error: Optional[Callable[[Exception], Coroutine[Any, Any, None]]] = None, @@ -340,4 +352,123 @@ async def execute(self) -> List[Any]: try: return await self._client.command_executor.execute_pipeline(tuple(self._command_stack)) finally: - await self.reset() \ No newline at end of file + await self.reset() + +class PubSub: + """ + PubSub object for multi database client. + """ + def __init__(self, client: MultiDBClient, **kwargs): + """Initialize the PubSub object for a multi-database client. 
+ + Args: + client: MultiDBClient instance to use for pub/sub operations + **kwargs: Additional keyword arguments to pass to the underlying pubsub implementation + """ + + self._client = client + self._client.command_executor.pubsub(**kwargs) + + async def __aenter__(self) -> "PubSub": + return self + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + await self.aclose() + + async def aclose(self): + return await self._client.command_executor.execute_pubsub_method('aclose') + + @property + def subscribed(self) -> bool: + return self._client.command_executor.active_pubsub.subscribed + + async def execute_command(self, *args: EncodableT): + return await self._client.command_executor.execute_pubsub_method('execute_command', *args) + + async def psubscribe(self, *args: ChannelT, **kwargs: PubSubHandler): + """ + Subscribe to channel patterns. Patterns supplied as keyword arguments + expect a pattern name as the key and a callable as the value. A + pattern's callable will be invoked automatically when a message is + received on that pattern rather than producing a message via + ``listen()``. + """ + return await self._client.command_executor.execute_pubsub_method( + 'psubscribe', + *args, + **kwargs + ) + + async def punsubscribe(self, *args: ChannelT): + """ + Unsubscribe from the supplied patterns. If empty, unsubscribe from + all patterns. + """ + return await self._client.command_executor.execute_pubsub_method( + 'punsubscribe', + *args + ) + + async def subscribe(self, *args: ChannelT, **kwargs: Callable): + """ + Subscribe to channels. Channels supplied as keyword arguments expect + a channel name as the key and a callable as the value. A channel's + callable will be invoked automatically when a message is received on + that channel rather than producing a message via ``listen()`` or + ``get_message()``. + """ + return await self._client.command_executor.execute_pubsub_method( + 'subscribe', + *args, + **kwargs + ) + + async def unsubscribe(self, *args): + """ + Unsubscribe from the supplied channels. If empty, unsubscribe from + all channels + """ + return await self._client.command_executor.execute_pubsub_method( + 'unsubscribe', + *args + ) + + async def get_message( + self, ignore_subscribe_messages: bool = False, timeout: Optional[float] = 0.0 + ): + """ + Get the next message if one is available, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number or None to wait indefinitely. + """ + return await self._client.command_executor.execute_pubsub_method( + 'get_message', + ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + ) + + async def run( + self, + *, + exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None, + poll_timeout: float = 1.0, + ) -> None: + """Process pub/sub messages using registered callbacks. + + This is the equivalent of :py:meth:`redis.PubSub.run_in_thread` in + redis-py, but it is a coroutine. 
To launch it as a separate task, use + ``asyncio.create_task``: + + >>> task = asyncio.create_task(pubsub.run()) + + To shut it down, use asyncio cancellation: + + >>> task.cancel() + >>> await task + """ + return await self._client.command_executor.execute_pubsub_run( + exception_handler=exception_handler, + sleep_time=poll_timeout, + pubsub=self + ) \ No newline at end of file diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index 4133dba394..f7ae0e717b 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -178,14 +178,10 @@ def failover_strategy(self) -> AsyncFailoverStrategy: def command_retry(self) -> Retry: return self._command_retry - async def pubsub(self, **kwargs): - async def callback(): - if self._active_pubsub is None: - self._active_pubsub = self._active_database.client.pubsub(**kwargs) - self._active_pubsub_kwargs = kwargs - return None - - return await self._execute_with_failure_detection(callback) + def pubsub(self, **kwargs): + if self._active_pubsub is None: + self._active_pubsub = self._active_database.client.pubsub(**kwargs) + self._active_pubsub_kwargs = kwargs async def execute_command(self, *args, **options): async def callback(): diff --git a/redis/multidb/client.py b/redis/multidb/client.py index 71e079346a..e6b815c76f 100644 --- a/redis/multidb/client.py +++ b/redis/multidb/client.py @@ -337,9 +337,6 @@ def __init__(self, client: MultiDBClient, **kwargs): def __enter__(self) -> "PubSub": return self - def __exit__(self, exc_type, exc_value, traceback) -> None: - self.reset() - def __del__(self) -> None: try: # if this object went out of scope prior to shutting down @@ -350,7 +347,7 @@ def __del__(self) -> None: pass def reset(self) -> None: - pass + return self._client.command_executor.execute_pubsub_method('reset') def close(self) -> None: self.reset() @@ -359,6 +356,9 @@ def close(self) -> None: def subscribed(self) -> bool: return self._client.command_executor.active_pubsub.subscribed + def execute_command(self, *args): + return self._client.command_executor.execute_pubsub_method('execute_command', *args) + def psubscribe(self, *args, **kwargs): """ Subscribe to channel patterns. 
Patterns supplied as keyword arguments diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index 76db322253..93068f6756 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -1,4 +1,5 @@ import asyncio +import json import logging from time import sleep @@ -186,4 +187,43 @@ async def callback(pipe: Pipeline): # Execute transaction until database failover while not listener.is_changed_flag: await r_multi_db.transaction(callback) == [True, True, True, 'value1', 'value2', 'value3'] - await asyncio.sleep(0.5) \ No newline at end of file + await asyncio.sleep(0.5) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "r_multi_db", + [{"failure_threshold": 2}], + indirect=True + ) + @pytest.mark.timeout(50) + async def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client): + r_multi_db, listener, config = r_multi_db + + event = asyncio.Event() + asyncio.create_task(trigger_network_failure_action(fault_injector_client,config,event)) + + data = json.dumps({'message': 'test'}) + messages_count = 0 + + async def handler(message): + nonlocal messages_count + messages_count += 1 + + pubsub = await r_multi_db.pubsub() + + # Assign a handler and run in a separate thread. + await pubsub.subscribe(**{'test-channel': handler}) + task = asyncio.create_task(pubsub.run(poll_timeout=0.1)) + + # Execute publish before network failure + while not event.is_set(): + await r_multi_db.publish('test-channel', data) + await asyncio.sleep(0.5) + + # Execute publish until database failover + while not listener.is_changed_flag: + await r_multi_db.publish('test-channel', data) + await asyncio.sleep(0.5) + + task.cancel() + assert messages_count > 1 \ No newline at end of file From a82d8e7d0bc6552099c4eb691d0ac00011c3fc4c Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Fri, 5 Sep 2025 10:43:16 +0300 Subject: [PATCH 9/9] Added check for couroutines methods for pub/sub --- redis/asyncio/multidb/command_executor.py | 6 +++++- tests/test_asyncio/test_scenario/test_active_active.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/redis/asyncio/multidb/command_executor.py b/redis/asyncio/multidb/command_executor.py index f7ae0e717b..7133955740 100644 --- a/redis/asyncio/multidb/command_executor.py +++ b/redis/asyncio/multidb/command_executor.py @@ -1,4 +1,5 @@ from abc import abstractmethod +from asyncio import iscoroutinefunction from datetime import datetime from typing import List, Optional, Callable, Any, Union, Awaitable @@ -221,7 +222,10 @@ async def callback(): async def execute_pubsub_method(self, method_name: str, *args, **kwargs): async def callback(): method = getattr(self.active_pubsub, method_name) - return await method(*args, **kwargs) + if iscoroutinefunction(method): + return await method(*args, **kwargs) + else: + return method(*args, **kwargs) return await self._execute_with_failure_detection(callback, *args) diff --git a/tests/test_asyncio/test_scenario/test_active_active.py b/tests/test_asyncio/test_scenario/test_active_active.py index 93068f6756..4d61434d8a 100644 --- a/tests/test_asyncio/test_scenario/test_active_active.py +++ b/tests/test_asyncio/test_scenario/test_active_active.py @@ -226,4 +226,5 @@ async def handler(message): await asyncio.sleep(0.5) task.cancel() + await pubsub.unsubscribe('test-channel') is True assert messages_count > 1 \ No newline at end of file