Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,7 @@ def __init__(
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
self.startup_nodes = startup_nodes
self.nodes_manager = NodesManager(
startup_nodes=startup_nodes,
from_url=from_url,
Expand Down
33 changes: 28 additions & 5 deletions redis/http/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,30 @@ def __init__(
auth_basic: Optional[Tuple[str, str]] = None, # (username, password)
user_agent: str = DEFAULT_USER_AGENT,
) -> None:
self.base_url = base_url.rstrip("/") + "/" if base_url and not base_url.endswith("/") else base_url
"""
Initialize a new HTTP client instance.

Args:
base_url: Base URL for all requests. Will be prefixed to all paths.
headers: Default headers to include in all requests.
timeout: Default timeout in seconds for requests.
retry: Retry configuration for failed requests.
verify_tls: Whether to verify TLS certificates.
ca_file: Path to CA certificate file for TLS verification.
ca_path: Path to a directory containing CA certificates.
ca_data: CA certificate data as string or bytes.
client_cert_file: Path to client certificate for mutual TLS.
client_key_file: Path to a client private key for mutual TLS.
client_key_password: Password for an encrypted client private key.
auth_basic: Tuple of (username, password) for HTTP basic auth.
user_agent: User-Agent header value for requests.

The client supports both regular HTTPS with server verification and mutual TLS
authentication. For server verification, provide CA certificate information via
ca_file, ca_path or ca_data. For mutual TLS, additionally provide a client
certificate and key via client_cert_file and client_key_file.
"""
self.base_url = base_url.rstrip() + "/" if base_url and not base_url.endswith("/") else base_url
self._default_headers = {k.lower(): v for k, v in (headers or {}).items()}
self.timeout = timeout
self.retry = retry
Expand Down Expand Up @@ -147,7 +170,7 @@ def post(
self,
path: str,
*,
json_body: Any = None,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None,
headers: Optional[Mapping[str, str]] = None,
Expand All @@ -168,7 +191,7 @@ def put(
self,
path: str,
*,
json_body: Any = None,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None,
headers: Optional[Mapping[str, str]] = None,
Expand All @@ -189,7 +212,7 @@ def patch(
self,
path: str,
*,
json_body: Any = None,
json_body: Optional[Any] = None,
data: Optional[Union[bytes, str]] = None,
params: Optional[Mapping[str, Union[str, int, float, bool, None, list, tuple]]] = None,
headers: Optional[Mapping[str, str]] = None,
Expand Down Expand Up @@ -317,7 +340,7 @@ def _json_call(
return resp.json()
return resp

def _prepare_body(self, *, json_body: Any = None, data: Optional[Union[bytes, str]] = None) -> Optional[Union[bytes, str]]:
def _prepare_body(self, *, json_body: Optional[Any] = None, data: Optional[Union[bytes, str]] = None) -> Optional[Union[bytes, str]]:
if json_body is not None and data is not None:
raise ValueError("Provide either json_body or data, not both.")
if json_body is not None:
Expand Down
4 changes: 2 additions & 2 deletions redis/multidb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def __init__(self, config: MultiDbConfig):
self._databases = config.databases()
self._health_checks = config.default_health_checks()

if config.additional_health_checks is not None:
self._health_checks.extend(config.additional_health_checks)
if config.health_checks is not None:
self._health_checks.extend(config.health_checks)

self._health_check_interval = config.health_check_interval
self._failure_detectors = config.default_failure_detectors()
Expand Down
4 changes: 2 additions & 2 deletions redis/multidb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class MultiDbConfig:
failure_detectors: Optional list of additional failure detectors for monitoring database failures.
failure_threshold: Threshold for determining database failure.
failures_interval: Time interval for tracking database failures.
additional_health_checks: Optional list of health checks performed on databases.
health_checks: Optional list of additional health checks performed on databases.
health_check_interval: Time interval for executing health checks.
health_check_retries: Number of retry attempts for performing health checks.
health_check_backoff: Backoff strategy for health check retries.
Expand Down Expand Up @@ -88,7 +88,7 @@ class MultiDbConfig:
failure_detectors: Optional[List[FailureDetector]] = None
failure_threshold: int = DEFAULT_FAILURES_THRESHOLD
failures_interval: float = DEFAULT_FAILURES_DURATION
additional_health_checks: Optional[List[HealthCheck]] = None
health_checks: Optional[List[HealthCheck]] = None
health_check_interval: float = DEFAULT_HEALTH_CHECK_INTERVAL
health_check_retries: int = DEFAULT_HEALTH_CHECK_RETRIES
health_check_backoff: AbstractBackoff = DEFAULT_HEALTH_CHECK_BACKOFF
Expand Down
21 changes: 12 additions & 9 deletions redis/multidb/healthcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import abstractmethod, ABC
from typing import Optional, Tuple, Union

from redis import Redis
from redis.backoff import ExponentialWithJitterBackoff
from redis.http.http_client import DEFAULT_TIMEOUT, HttpClient
from redis.retry import Retry
Expand All @@ -28,7 +29,7 @@ def check_health(self, database) -> bool:
class AbstractHealthCheck(HealthCheck):
def __init__(
self,
retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF)
retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF)
) -> None:
self._retry = retry
self._retry.update_supported_errors([ConnectionRefusedError])
Expand All @@ -45,7 +46,7 @@ def check_health(self, database) -> bool:
class EchoHealthCheck(AbstractHealthCheck):
def __init__(
self,
retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF),
retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF)
) -> None:
"""
Check database healthiness by sending an echo request.
Expand Down Expand Up @@ -83,9 +84,8 @@ class LagAwareHealthCheck(AbstractHealthCheck):
"""
def __init__(
self,
retry=Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF),
retry: Retry = Retry(retries=DEFAULT_HEALTH_CHECK_RETRIES, backoff=DEFAULT_HEALTH_CHECK_BACKOFF),
rest_api_port: int = 9443,
availability_lag_tolerance: int = 100,
timeout: float = DEFAULT_TIMEOUT,
auth_basic: Optional[Tuple[str, str]] = None,
verify_tls: bool = True,
Expand All @@ -104,7 +104,6 @@ def __init__(
Args:
retry: Retry configuration for health checks
rest_api_port: Port number for Redis Enterprise REST API (default: 9443)
availability_lag_tolerance: Maximum acceptable lag in milliseconds (default: 100)
timeout: Request timeout in seconds (default: DEFAULT_TIMEOUT)
auth_basic: Tuple of (username, password) for basic authentication
verify_tls: Whether to verify TLS certificates (default: True)
Expand All @@ -115,7 +114,6 @@ def __init__(
client_key_file: Path to client private key file for mutual TLS
client_key_password: Password for encrypted client private key
"""

super().__init__(
retry=retry,
)
Expand All @@ -132,12 +130,17 @@ def __init__(
client_key_password=client_key_password
)
self._rest_api_port = rest_api_port
self._availability_lag_tolerance = availability_lag_tolerance

def check_health(self, database) -> bool:
client = database.client
db_host = client.get_connection_kwargs()['host']

if isinstance(client, Redis):
db_host = client.get_connection_kwargs()['host']
else:
db_host = client.startup_nodes[0].host

base_url = f"https://{db_host}:{self._rest_api_port}"
print(base_url)
self._http_client.base_url = base_url

# Find bdb matching to the current database host
Expand All @@ -158,7 +161,7 @@ def check_health(self, database) -> bool:
logger.warning("LagAwareHealthCheck failed: Couldn't find a matching bdb")
raise ValueError("Could not find a matching bdb")

url = f"/v1/bdbs/{matching_bdb['uid']}/availability?availability_lag_tolerance_ms={self._availability_lag_tolerance}"
url = f"/v1/local/bdbs/{matching_bdb['uid']}/endpoint/availability"
self._http_client.get(url, expect_json=False)

# Status checked in an http client, otherwise HttpError will be raised
Expand Down
29 changes: 9 additions & 20 deletions tests/test_multidb/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def test_execute_command_against_correct_db_on_successful_initialization(
with patch.object(mock_multi_db_config,'databases',return_value=databases), \
patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]):
mock_db1.client.execute_command.return_value = 'OK1'

mock_hc.check_health.return_value = True

client = MultiDBClient(mock_multi_db_config)
Expand Down Expand Up @@ -71,6 +72,7 @@ def test_execute_command_against_correct_db_and_closed_circuit(
with patch.object(mock_multi_db_config,'databases',return_value=databases), \
patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]):
mock_db1.client.execute_command.return_value = 'OK1'

mock_hc.check_health.side_effect = [False, True, True]

client = MultiDBClient(mock_multi_db_config)
Expand Down Expand Up @@ -199,11 +201,8 @@ def test_execute_command_throws_exception_on_failed_initialization(

with pytest.raises(NoValidDatabaseException, match='Initial connection failed - no active database found'):
client.set('key', 'value')
assert mock_hc.check_health.call_count == 3

assert mock_db.state == DBState.DISCONNECTED
assert mock_db1.state == DBState.DISCONNECTED
assert mock_db2.state == DBState.DISCONNECTED
assert mock_hc.check_health.call_count == 3

@pytest.mark.parametrize(
'mock_multi_db_config,mock_db, mock_db1, mock_db2',
Expand Down Expand Up @@ -254,6 +253,7 @@ def test_add_database_makes_new_database_active(
patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]):
mock_db1.client.execute_command.return_value = 'OK1'
mock_db2.client.execute_command.return_value = 'OK2'

mock_hc.check_health.return_value = True

client = MultiDBClient(mock_multi_db_config)
Expand All @@ -262,11 +262,7 @@ def test_add_database_makes_new_database_active(
assert client.set('key', 'value') == 'OK2'
assert mock_hc.check_health.call_count == 2

assert mock_db.state == DBState.PASSIVE
assert mock_db2.state == DBState.ACTIVE

client.add_database(mock_db1)

assert mock_hc.check_health.call_count == 3

assert client.set('key', 'value') == 'OK1'
Expand All @@ -292,6 +288,7 @@ def test_remove_highest_weighted_database(
patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]):
mock_db1.client.execute_command.return_value = 'OK1'
mock_db2.client.execute_command.return_value = 'OK2'

mock_hc.check_health.return_value = True

client = MultiDBClient(mock_multi_db_config)
Expand All @@ -300,10 +297,6 @@ def test_remove_highest_weighted_database(
assert client.set('key', 'value') == 'OK1'
assert mock_hc.check_health.call_count == 3

assert mock_db.state == DBState.PASSIVE
assert mock_db1.state == DBState.ACTIVE
assert mock_db2.state == DBState.PASSIVE

client.remove_database(mock_db1)

assert client.set('key', 'value') == 'OK2'
Expand All @@ -329,6 +322,7 @@ def test_update_database_weight_to_be_highest(
patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]):
mock_db1.client.execute_command.return_value = 'OK1'
mock_db2.client.execute_command.return_value = 'OK2'

mock_hc.check_health.return_value = True

client = MultiDBClient(mock_multi_db_config)
Expand All @@ -337,10 +331,6 @@ def test_update_database_weight_to_be_highest(
assert client.set('key', 'value') == 'OK1'
assert mock_hc.check_health.call_count == 3

assert mock_db.state == DBState.PASSIVE
assert mock_db1.state == DBState.ACTIVE
assert mock_db2.state == DBState.PASSIVE

client.update_database_weight(mock_db2, 0.8)
assert mock_db2.weight == 0.8

Expand Down Expand Up @@ -374,6 +364,7 @@ def test_add_new_failure_detector(
commands=('SET', 'key', 'value'),
exception=Exception(),
)

mock_hc.check_health.return_value = True

client = MultiDBClient(mock_multi_db_config)
Expand Down Expand Up @@ -417,6 +408,7 @@ def test_add_new_health_check(
with patch.object(mock_multi_db_config,'databases',return_value=databases), \
patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]):
mock_db1.client.execute_command.return_value = 'OK1'

mock_hc.check_health.return_value = True

client = MultiDBClient(mock_multi_db_config)
Expand Down Expand Up @@ -454,17 +446,14 @@ def test_set_active_database(
patch.object(mock_multi_db_config,'default_health_checks', return_value=[mock_hc]):
mock_db1.client.execute_command.return_value = 'OK1'
mock_db.client.execute_command.return_value = 'OK'

mock_hc.check_health.return_value = True

client = MultiDBClient(mock_multi_db_config)
assert mock_multi_db_config.failover_strategy.set_databases.call_count == 1
assert client.set('key', 'value') == 'OK1'
assert mock_hc.check_health.call_count == 3

assert mock_db.state == DBState.PASSIVE
assert mock_db1.state == DBState.ACTIVE
assert mock_db2.state == DBState.PASSIVE

client.set_active_database(mock_db)
assert client.set('key', 'value') == 'OK'

Expand Down
8 changes: 4 additions & 4 deletions tests/test_multidb/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_overridden_config(self):
config = MultiDbConfig(
databases_config=db_configs,
failure_detectors=mock_failure_detectors,
additional_health_checks=mock_health_checks,
health_checks=mock_health_checks,
health_check_interval=health_check_interval,
failover_strategy=mock_failover_strategy,
auto_fallback_interval=auto_fallback_interval,
Expand All @@ -96,9 +96,9 @@ def test_overridden_config(self):
assert len(config.failure_detectors) == 2
assert config.failure_detectors[0] == mock_failure_detectors[0]
assert config.failure_detectors[1] == mock_failure_detectors[1]
assert len(config.additional_health_checks) == 2
assert config.additional_health_checks[0] == mock_health_checks[0]
assert config.additional_health_checks[1] == mock_health_checks[1]
assert len(config.health_checks) == 2
assert config.health_checks[0] == mock_health_checks[0]
assert config.health_checks[1] == mock_health_checks[1]
assert config.health_check_interval == health_check_interval
assert config.failover_strategy == mock_failover_strategy
assert config.auto_fallback_interval == auto_fallback_interval
Expand Down
15 changes: 8 additions & 7 deletions tests/test_multidb/test_healthcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import pytest

from redis.backoff import ExponentialBackoff
from redis.multidb.database import Database
from redis.multidb.healthcheck import EchoHealthCheck
from redis.http.http_client import HttpError
from redis.multidb.database import Database, State
from redis.multidb.healthcheck import EchoHealthCheck, LagAwareHealthCheck
from redis.multidb.circuit import State as CBState
from redis.exceptions import ConnectionError
Expand Down Expand Up @@ -78,7 +79,7 @@ def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, moc
# Inject our mocked http client
hc._http_client = mock_http

db = Database(mock_client, mock_cb, 1.0, State.ACTIVE)
db = Database(mock_client, mock_cb, 1.0)

assert hc.check_health(db) is True
# Base URL must be set correctly
Expand All @@ -88,7 +89,7 @@ def test_database_is_healthy_when_bdb_matches_by_dns_name(self, mock_client, moc
first_call = mock_http.get.call_args_list[0]
second_call = mock_http.get.call_args_list[1]
assert first_call.args[0] == "/v1/bdbs"
assert second_call.args[0] == "/v1/bdbs/bdb-1/availability"
assert second_call.args[0] == "/v1/local/bdbs/bdb-1/endpoint/availability"
assert second_call.kwargs.get("expect_json") is False

def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb):
Expand Down Expand Up @@ -116,11 +117,11 @@ def test_database_is_healthy_when_bdb_matches_by_addr(self, mock_client, mock_cb
)
hc._http_client = mock_http

db = Database(mock_client, mock_cb, 1.0, State.ACTIVE)
db = Database(mock_client, mock_cb, 1.0)

assert hc.check_health(db) is True
assert mock_http.get.call_count == 2
assert mock_http.get.call_args_list[1].args[0] == "/v1/bdbs/bdb-42/availability"
assert mock_http.get.call_args_list[1].args[0] == "/v1/local/bdbs/bdb-42/endpoint/availability"

def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb):
"""
Expand All @@ -141,7 +142,7 @@ def test_raises_value_error_when_no_matching_bdb(self, mock_client, mock_cb):
)
hc._http_client = mock_http

db = Database(mock_client, mock_cb, 1.0, State.ACTIVE)
db = Database(mock_client, mock_cb, 1.0)

with pytest.raises(ValueError, match="Could not find a matching bdb"):
hc.check_health(db)
Expand Down Expand Up @@ -169,7 +170,7 @@ def test_propagates_http_error_from_availability(self, mock_client, mock_cb):
)
hc._http_client = mock_http

db = Database(mock_client, mock_cb, 1.0, State.ACTIVE)
db = Database(mock_client, mock_cb, 1.0)

with pytest.raises(HttpError, match="busy") as e:
hc.check_health(db)
Expand Down
Loading