diff --git a/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py b/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py new file mode 100644 index 00000000..5d346ac6 --- /dev/null +++ b/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py @@ -0,0 +1,526 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import copy +import weakref +from queue import Queue +from threading import Thread +from time import perf_counter_ns, sleep +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Set + +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.host_availability import HostAvailability +from aws_advanced_python_wrapper.plugin import (CanReleaseResources, Plugin, + PluginFactory) +from aws_advanced_python_wrapper.utils.atomic import (AtomicBoolean, + AtomicReference) +from aws_advanced_python_wrapper.utils.concurrent import ConcurrentDict +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.messages import Messages +from aws_advanced_python_wrapper.utils.notifications import ( + ConnectionEvent, OldConnectionSuggestedAction) +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) +from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils +from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \ + SlidingExpirationCacheWithCleanupThread +from aws_advanced_python_wrapper.utils.telemetry.telemetry import ( + TelemetryCounter, TelemetryFactory, TelemetryTraceLevel) + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.driver_dialect import DriverDialect + from aws_advanced_python_wrapper.hostinfo import HostInfo + from aws_advanced_python_wrapper.pep249 import Connection + from aws_advanced_python_wrapper.plugin_service import PluginService + +logger = Logger(__name__) + + +class HostMonitoringV2Plugin(Plugin, CanReleaseResources): + _SUBSCRIBED_METHODS: Set[str] = {"*"} + + def __init__(self, plugin_service, props): + dialect: DriverDialect = plugin_service.driver_dialect + if not dialect.supports_abort_connection(): + raise AwsWrapperError(Messages.get_formatted( + "HostMonitoringV2Plugin.ConfigurationNotSupported", type(dialect).__name__)) + + self._properties: Properties = props + self._plugin_service: PluginService = plugin_service + self._monitoring_host_info: Optional[HostInfo] = None + self._rds_utils: RdsUtils = RdsUtils() + self._monitor_service: MonitorServiceV2 = MonitorServiceV2(plugin_service) + self._failure_detection_time_ms = WrapperProperties.FAILURE_DETECTION_TIME_MS.get_int(self._properties) + self._failure_detection_interval_ms = WrapperProperties.FAILURE_DETECTION_INTERVAL_MS.get_int(self._properties) + self._failure_detection_count = WrapperProperties.FAILURE_DETECTION_COUNT.get_int(self._properties) + self._failure_detection_enabled = 
WrapperProperties.FAILURE_DETECTION_ENABLED.get_bool(self._properties) + + @property + def subscribed_methods(self) -> Set[str]: + return HostMonitoringV2Plugin._SUBSCRIBED_METHODS + + def connect( + self, + target_driver_func: Callable, + driver_dialect: DriverDialect, + host_info: HostInfo, + props: Properties, + is_initial_connection: bool, + connect_func: Callable) -> Connection: + connection = connect_func() + if connection: + rds_type = self._rds_utils.identify_rds_type(host_info.host) + if rds_type.is_rds_cluster: + host_info.reset_aliases() + self._plugin_service.fill_aliases(connection, host_info) + return connection + + def execute(self, target: object, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any: + if self._plugin_service.current_connection is None: + raise AwsWrapperError(Messages.get_formatted("HostMonitoringV2Plugin.ConnectionNone", method_name)) + if not self._failure_detection_enabled or not self._plugin_service.is_network_bound_method(method_name): + return execute_func() + monitor_context = None + result = None + + try: + logger.debug("HostMonitoringV2Plugin.ActivatedMonitoring", method_name) + monitor_context = self._monitor_service.start_monitoring( + self._plugin_service.current_connection, + self._get_monitoring_host_info(), + self._properties, + self._failure_detection_time_ms, + self._failure_detection_interval_ms, + self._failure_detection_count + ) + result = execute_func() + finally: + if monitor_context is not None: + self._monitor_service.stop_monitoring(monitor_context, self._plugin_service.current_connection) + logger.debug("HostMonitoringV2Plugin.MonitoringDeactivated", method_name) + + return result + + def notify_connection_changed(self, changes: Set[ConnectionEvent]) -> OldConnectionSuggestedAction: + if ConnectionEvent.CONNECTION_OBJECT_CHANGED in changes: + self._monitoring_host_info = None + + return OldConnectionSuggestedAction.NO_OPINION + + def _get_monitoring_host_info(self) -> HostInfo: + if self._monitoring_host_info is None: + current_host_info = self._plugin_service.current_host_info + if current_host_info is None: + raise AwsWrapperError(Messages.get("HostMonitoringV2Plugin.HostInfoNone")) + self._monitoring_host_info = current_host_info + rds_url_type = self._rds_utils.identify_rds_type(self._monitoring_host_info.url) + + try: + if not rds_url_type.is_rds_cluster: + return self._monitoring_host_info + logger.debug("HostMonitoringV2Plugin.ClusterEndpointHostInfo") + current_connection = self._plugin_service.current_connection + self._monitoring_host_info = self._plugin_service.identify_connection(current_connection) + if self._monitoring_host_info is None: + raise AwsWrapperError( + Messages.get_formatted( + "HostMonitoringV2Plugin.UnableToIdentifyConnection", + current_host_info.host, + self._plugin_service.host_list_provider)) + self._plugin_service.fill_aliases(current_connection, self._monitoring_host_info) + except Exception as e: + message = "HostMonitoringV2Plugin.ErrorIdentifyingConnection" + logger.debug(message, e) + raise AwsWrapperError(Messages.get_formatted(message, e)) from e + return self._monitoring_host_info + + def release_resources(self): + if self._monitor_service is not None: + self._monitor_service.release_resources() + + self._monitor_service = None + + +class HostMonitoringV2PluginFactory(PluginFactory): + + def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + return HostMonitoringV2Plugin(plugin_service, props) + + +class MonitoringContext: + """ + 
Monitoring context for each connection.
+    This contains each connection's criteria for whether a server should be considered unhealthy.
+    The context is shared between the main thread and the monitor thread.
+    """
+
+    def __init__(self, connection: Connection):
+        self._connection_to_abort: AtomicReference = AtomicReference(weakref.ref(connection))
+        self._host_unhealthy: AtomicBoolean = AtomicBoolean(False)
+
+    def set_host_unhealthy(self) -> None:
+        self._host_unhealthy.set(True)
+
+    def should_abort(self) -> bool:
+        connection_weak_ref = self._connection_to_abort.get()
+        return self._host_unhealthy.get() and connection_weak_ref is not None and connection_weak_ref() is not None
+
+    def set_inactive(self) -> None:
+        self._connection_to_abort.set(None)
+
+    def get_connection(self) -> Optional[Connection]:
+        connection_weak_ref = self._connection_to_abort.get()
+        if connection_weak_ref is not None:
+            return connection_weak_ref()
+        else:
+            return None
+
+    def is_active(self) -> bool:
+        connection_weak_ref = self._connection_to_abort.get()
+        if connection_weak_ref is not None:
+            return connection_weak_ref() is not None
+        return False
+
+
+class HostMonitorV2:
+    """
+    This class uses a background thread to monitor a particular server with one or more active :py:class:`Connection`
+    objects. It performs periodic health checks and aborts connections when the server becomes unhealthy.
+    """
+    _THREAD_SLEEP_NANO = 100_000_000
+    _MONITORING_PROPERTY_PREFIX = "monitoring-"
+    _QUERY = "SELECT 1"
+
+    def __init__(
+            self,
+            plugin_service: PluginService,
+            host_info: HostInfo,
+            props: Properties,
+            failure_detection_time_ms: int,
+            failure_detection_interval_ms: int,
+            failure_detection_count: int,
+            aborted_connection_counter: TelemetryCounter):
+        self._plugin_service: PluginService = plugin_service
+        self._host_info: HostInfo = host_info
+        self._props: Properties = props
+        self._telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory()
+        self._failure_detection_time_ns: int = failure_detection_time_ms * 10**6
+        self._failure_detection_interval_ns: int = failure_detection_interval_ms * 10**6
+        self._failure_detection_count: int = failure_detection_count
+        self._aborted_connection_counter: TelemetryCounter = aborted_connection_counter
+
+        self._active_contexts: Queue = Queue()
+        self._new_contexts: ConcurrentDict[float, Queue] = ConcurrentDict()
+        self._is_stopped: AtomicBoolean = AtomicBoolean(False)
+        self._is_unhealthy: bool = False
+        self._failure_count: int = 0
+        self._invalid_host_start_time_ns: int = 0
+        self._monitoring_connection: Optional[Connection] = None
+        self._driver_dialect: DriverDialect = self._plugin_service.driver_dialect
+
+        self._monitor_run_thread: Thread = Thread(daemon=True, name="HostMonitoringThreadRun", target=self.run)
+        self._monitor_run_thread.start()
+        self._monitor_new_context_thread: Thread = Thread(daemon=True, name="HostMonitoringThreadNewContextRun",
+                                                          target=self._new_context_run)
+        self._monitor_new_context_thread.start()
+
+    def can_dispose(self) -> bool:
+        return self._active_contexts.empty() and len(self._new_contexts.items()) == 0
+
+    @property
+    def is_stopped(self):
+        return self._is_stopped.get()
+
+    def stop(self):
+        self._is_stopped.set(True)
+
+    def start_monitoring(self, context: MonitoringContext):
+        if self.is_stopped:
+            logger.warning("HostMonitorV2.MonitorIsStopped", self._host_info.host)
+
+        current_time_ns = self.get_current_time_ns()
+        start_monitoring_time_ns = self._round_ns_to_seconds(current_time_ns + self._failure_detection_time_ns)
+        weak_ref = weakref.ref(context)
+        queue = self._new_contexts.compute_if_absent(start_monitoring_time_ns, lambda _: Queue())
+        if queue is not None:
+            queue.put(weak_ref)
+
+    def _round_ns_to_seconds(self, nano_seconds):
+        return (nano_seconds // 1_000_000_000) * 1_000_000_000
+
+    def get_current_time_ns(self) -> float:
+        return float(perf_counter_ns())
+
+    def _new_context_run(self) -> None:
+        logger.debug("HostMonitorV2.StartMonitoringThreadNewContext", self._host_info.host)
+
+        try:
+            while not self.is_stopped:
+                current_time_ns = self.get_current_time_ns()
+
+                processed_keys = []
+                keys = list(self._new_contexts.keys())
+                for key in keys:
+                    queue: Optional[Queue] = self._new_contexts.get(key)
+                    if queue is None:
+                        continue
+                    if key > current_time_ns:
+                        continue
+                    processed_keys.append(key)
+                    while not queue.empty():
+                        context_weak_ref = queue.get()
+                        if context_weak_ref is None:
+                            continue
+                        # Dereference the weak reference only once; the context may be
+                        # garbage collected between successive dereferences.
+                        context = context_weak_ref()
+                        if context is not None and context.is_active():
+                            self._active_contexts.put(context_weak_ref)
+
+                for key in processed_keys:
+                    self._new_contexts.remove(key)
+
+                sleep(1)
+        except InterruptedError:
+            pass
+        except Exception as ex:
+            logger.debug("HostMonitorV2.ExceptionDuringMonitoringStop", self._host_info.host, ex)
+
+        logger.debug("HostMonitorV2.StopMonitoringThreadNewContext", self._host_info.host)
+
+    def run(self) -> None:
+        logger.debug("HostMonitorV2.StartMonitoringThread", self._host_info.host)
+
+        try:
+            while not self.is_stopped:
+                if self._active_contexts.empty() and not self._is_unhealthy:
+                    sleep(HostMonitorV2._THREAD_SLEEP_NANO / 1_000_000_000)
+                    continue
+
+                status_check_start_time_ns: float = self.get_current_time_ns()
+                is_valid: bool = self.check_connection_status()
+                status_check_end_time_ns: float = self.get_current_time_ns()
+
+                self._update_host_health_status(is_valid, status_check_start_time_ns, status_check_end_time_ns)
+
+                if self._is_unhealthy:
+                    self._plugin_service.set_availability(self._host_info.as_aliases(), HostAvailability.UNAVAILABLE)
+
+                temp_active_contexts = []
+                while not self._active_contexts.empty():
+                    monitor_context_weak_ref = self._active_contexts.get()
+                    if self.is_stopped:
+                        break
+
+                    if monitor_context_weak_ref is None:
+                        continue
+
+                    monitor_context = monitor_context_weak_ref()
+                    if monitor_context is None:
+                        continue
+
+                    if self._is_unhealthy:
+                        # The host is unhealthy: abort the connection and deactivate the context.
+                        monitor_context.set_host_unhealthy()
+                        connection_to_abort = monitor_context.get_connection()
+                        if connection_to_abort is not None:
+                            self.abort_connection(connection_to_abort)
+                            if self._aborted_connection_counter is not None:
+                                self._aborted_connection_counter.inc()
+                        monitor_context.set_inactive()
+                    elif monitor_context.is_active():
+                        temp_active_contexts.append(monitor_context)
+
+                for active_context in temp_active_contexts:
+                    self._active_contexts.put(weakref.ref(active_context))
+
+                delay_ns = self._failure_detection_interval_ns - (status_check_end_time_ns - status_check_start_time_ns)
+                if delay_ns < self._THREAD_SLEEP_NANO:
+                    delay_ns = self._THREAD_SLEEP_NANO
+
+                sleep(delay_ns / 1_000_000_000)
+
+        except InterruptedError:
+            pass
+        except Exception as ex:
+            logger.debug("HostMonitorV2.ExceptionDuringMonitoringStop", self._host_info.host, ex)
+        finally:
+            self.stop()
+            if self._monitoring_connection is not None:
+                try:
+                    self._driver_dialect.abort_connection(self._monitoring_connection)
+                except AwsWrapperError as ex:
+                    logger.debug(ex)
+
+            logger.debug("HostMonitorV2.StopMonitoringThread", self._host_info.host)
+
+    def check_connection_status(self) -> bool:
+        connect_telemetry_context = self._telemetry_factory.open_telemetry_context("connection status check",
+                                                                                   TelemetryTraceLevel.FORCE_TOP_LEVEL)
+
+        if connect_telemetry_context is not None:
+            connect_telemetry_context.set_attribute("url", self._host_info.url)
+
+        try:
+            if self._monitoring_connection is None or self._driver_dialect.is_closed(self._monitoring_connection):
+                monitoring_properties = copy.deepcopy(self._props)
+                # Iterate over a snapshot of the keys; the dict is mutated inside the loop.
+                for property_key in list(monitoring_properties.keys()):
+                    if property_key.startswith(self._MONITORING_PROPERTY_PREFIX):
+                        monitoring_properties[property_key[len(self._MONITORING_PROPERTY_PREFIX):]] = \
+                            monitoring_properties[property_key]
+                        monitoring_properties.pop(property_key, None)
+
+                logger.debug("HostMonitorV2.OpeningMonitoringConnection", self._host_info.url)
+                self._monitoring_connection = self._plugin_service.force_connect(self._host_info, monitoring_properties)
+                logger.debug("HostMonitorV2.OpenedMonitoringConnection", self._host_info.url)
+                return True
+
+            valid_timeout = ((self._failure_detection_interval_ns - self._THREAD_SLEEP_NANO) / 2) / 1_000_000_000
+            return self._is_host_available(self._monitoring_connection, valid_timeout)
+        except Exception:
+            return False
+        finally:
+            if connect_telemetry_context is not None:
+                connect_telemetry_context.close_context()
+
+    def _is_host_available(self, conn: Connection, timeout_sec: float) -> bool:
+        try:
+            self._execute_conn_check(conn, timeout_sec)
+            return True
+        except Exception:
+            # Treat any error raised by the check query (timeouts, dropped connections,
+            # etc.) as the host being unavailable.
+            return False
+
+    def _execute_conn_check(self, conn: Connection, timeout_sec: float):
+        driver_dialect = self._plugin_service.driver_dialect
+        with conn.cursor() as cursor:
+            query = HostMonitorV2._QUERY
+            driver_dialect.execute("Cursor.execute", lambda: cursor.execute(query), query, exec_timeout=timeout_sec)
+            cursor.fetchone()
+
+    def _update_host_health_status(
+            self,
+            connection_valid: bool,
+            status_check_start_ns: float,
+            status_check_end_ns: float) -> None:
+        if not connection_valid:
+            self._failure_count += 1
+
+            if self._invalid_host_start_time_ns == 0:
+                self._invalid_host_start_time_ns = int(status_check_start_ns)
+
+            invalid_host_duration_ns = status_check_end_ns - self._invalid_host_start_time_ns
+            max_invalid_host_duration_ns = (
+                self._failure_detection_interval_ns * max(0, self._failure_detection_count - 1))
+
+            if invalid_host_duration_ns >= max_invalid_host_duration_ns:
+                logger.debug("HostMonitorV2.HostDead", self._host_info.host)
+                self._is_unhealthy = True
+                return
+
+            logger.debug("HostMonitorV2.HostNotResponding", self._host_info.host, self._failure_count)
+            return
+
+        if self._failure_count > 0:
+            # The host is back alive.
+            logger.debug("HostMonitorV2.HostAlive", self._host_info.host)
+
+        self._failure_count = 0
+        self._invalid_host_start_time_ns = 0
+        self._is_unhealthy = False
+
+    def abort_connection(self, connection: Connection) -> None:
+        try:
+            if self._driver_dialect.is_closed(connection):
+                return
+            self._driver_dialect.abort_connection(connection)
+        except Exception as ex:
+            logger.debug("HostMonitorV2.ExceptionAbortingConnection", ex)
+
+    def close(self) -> None:
+        self.stop()
+
+
+class MonitorServiceV2:
+    # 1 minute, in nanoseconds.
+    _CACHE_CLEANUP_NANO = 1 * 60 * 1_000_000_000
+
+    _monitors: ClassVar[SlidingExpirationCacheWithCleanupThread[str, HostMonitorV2]] = \
+        SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NANO,
+                                                should_dispose_func=lambda monitor: monitor.can_dispose(),
+                                                item_disposal_func=lambda monitor: monitor.close())
+
+    def __init__(self, plugin_service: PluginService):
+        self._plugin_service: PluginService = plugin_service
+
+        telemetry_factory = self._plugin_service.get_telemetry_factory()
+        self._aborted_connections_counter = telemetry_factory.create_counter("efm2.connections.aborted")
+
+    def start_monitoring(
+            self,
+            conn: Connection,
+            host_info: HostInfo,
+            props: Properties,
+            failure_detection_time_ms: int,
+            failure_detection_interval_ms: int,
+            failure_detection_count: int) -> MonitoringContext:
+        monitor = self.get_monitor(conn, host_info, props, failure_detection_time_ms, failure_detection_interval_ms,
+                                   failure_detection_count)
+        context = MonitoringContext(conn)
+        if monitor is not None:
+            monitor.start_monitoring(context)
+        return context
+
+    def stop_monitoring(self, context: MonitoringContext, connection_to_abort: Connection):
+        should_abort = context.should_abort()
+        context.set_inactive()
+        if should_abort:
+            try:
+                self._plugin_service.driver_dialect.abort_connection(connection_to_abort)
+                if self._aborted_connections_counter is not None:
+                    self._aborted_connections_counter.inc()
+            except AwsWrapperError as ex:
+                logger.debug("MonitorServiceV2.ExceptionAbortingConnection", ex)
+
+    def release_resources(self):
+        pass
+
+    def get_monitor(self,
+                    conn: Connection,
+                    host_info: HostInfo,
+                    props: Properties,
+                    failure_detection_time_ms: int,
+                    failure_detection_interval_ms: int,
+                    failure_detection_count: int) -> Optional[HostMonitorV2]:
+        monitor_key = "{}:{}:{}:{}".format(
+            failure_detection_time_ms,
+            failure_detection_interval_ms,
+            failure_detection_count,
+            host_info.host
+        )
+
+        cache_expiration_ns = int(WrapperProperties.MONITOR_DISPOSAL_TIME_MS.get_float(props) * 10**6)
+        return self._monitors.compute_if_absent(monitor_key,
+                                                lambda k: HostMonitorV2(self._plugin_service,
+                                                                        host_info,
+                                                                        props,
+                                                                        failure_detection_time_ms,
+                                                                        failure_detection_interval_ms,
+                                                                        failure_detection_count,
+                                                                        self._aborted_connections_counter),
+                                                cache_expiration_ns)
diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py
index b58f8ea3..f3dc5fbc 100644
--- a/aws_advanced_python_wrapper/plugin_service.py
+++ b/aws_advanced_python_wrapper/plugin_service.py
@@ -75,6 +75,8 @@
                                                        HostListProviderService, StaticHostListProvider)
 from aws_advanced_python_wrapper.host_monitoring_plugin import \
     HostMonitoringPluginFactory
+from aws_advanced_python_wrapper.host_monitoring_v2_plugin import \
+    HostMonitoringV2PluginFactory
 from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
 from aws_advanced_python_wrapper.iam_plugin import IamAuthPluginFactory
 from aws_advanced_python_wrapper.plugin import CanReleaseResources
@@ -755,6 +757,7 @@ class PluginManager(CanReleaseResources):
         "aws_secrets_manager": AwsSecretsManagerPluginFactory,
         "aurora_connection_tracker": AuroraConnectionTrackerPluginFactory,
         "host_monitoring": HostMonitoringPluginFactory,
+        "host_monitoring_v2": HostMonitoringV2PluginFactory,
         "failover": FailoverPluginFactory,
         "read_write_splitting": ReadWriteSplittingPluginFactory,
         "fastest_response_strategy": FastestResponseStrategyPluginFactory,
@@ -783,6 +786,7 @@ class PluginManager(CanReleaseResources):
         ReadWriteSplittingPluginFactory: 300,
         FailoverPluginFactory: 400,
         HostMonitoringPluginFactory: 500,
+        HostMonitoringV2PluginFactory: 510,
         BlueGreenPluginFactory: 550,
         FastestResponseStrategyPluginFactory: 600,
         IamAuthPluginFactory: 700,
diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties
index 9d690871..c7be1412 100644
--- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties
+++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties
@@ -176,6 +176,28 @@ HostMonitoringPlugin.ConfigurationNotSupported=[HostMonitoringPlugin] Aborting c
 HostMonitoringPlugin.UnableToIdentifyConnection=[HostMonitoringPlugin] Unable to identify the connected database instance: '{}', please ensure the correct host list provider is specified. The host list provider in use is: '{}'.
 HostMonitoringPlugin.UnavailableHost=[HostMonitoringPlugin] Host '{}' is unavailable.
+HostMonitoringV2Plugin.ActivatedMonitoring=[HostMonitoringV2Plugin] Executing method '{}', monitoring is activated.
+HostMonitoringV2Plugin.ClusterEndpointHostInfo=[HostMonitoringV2Plugin] The HostInfo to monitor is associated with a cluster endpoint. The plugin will attempt to identify the connected database instance.
+HostMonitoringV2Plugin.ErrorIdentifyingConnection=[HostMonitoringV2Plugin] An error occurred while identifying the connected database instance: '{}'.
+HostMonitoringV2Plugin.MonitoringDeactivated=[HostMonitoringV2Plugin] Monitoring deactivated for method '{}'.
+HostMonitoringV2Plugin.ConnectionNone=[HostMonitoringV2Plugin] Attempted to execute method '{}' but the current connection is None.
+HostMonitoringV2Plugin.HostInfoNone=[HostMonitoringV2Plugin] Could not find HostInfo to monitor for the current connection.
+HostMonitoringV2Plugin.ConfigurationNotSupported=[HostMonitoringV2Plugin] Aborting connections from a separate thread is not supported for the detected driver dialect: '{}'. The EFM V2 plugin requires this feature to be supported.
+HostMonitoringV2Plugin.UnableToIdentifyConnection=[HostMonitoringV2Plugin] Unable to identify the connected database instance: '{}', please ensure the correct host list provider is specified. The host list provider in use is: '{}'.
+HostMonitorV2.ExceptionDuringMonitoringStop=[HostMonitorV2] Stopping monitoring after unhandled exception was thrown in monitoring thread for host '{}'. Exception: '{}'
+HostMonitorV2.MonitorIsStopped=[HostMonitorV2] Monitoring was already stopped for host '{}'.
+HostMonitorV2.StartMonitoringThreadNewContext=[HostMonitorV2] Start monitoring thread for checking new contexts for '{}'.
+HostMonitorV2.StopMonitoringThreadNewContext=[HostMonitorV2] Stop monitoring thread for checking new contexts for '{}'.
+HostMonitorV2.StartMonitoringThread=[HostMonitorV2] Start monitoring thread for '{}'.
+HostMonitorV2.OpeningMonitoringConnection=[HostMonitorV2] Opening a monitoring connection to '{}'.
+HostMonitorV2.OpenedMonitoringConnection=[HostMonitorV2] Opened monitoring connection: '{}'.
+HostMonitorV2.ExceptionAbortingConnection=[HostMonitorV2] Exception while aborting connection: '{}'.
+HostMonitorV2.HostDead=[HostMonitorV2] Host '{}' is *dead*.
+HostMonitorV2.HostNotResponding=[HostMonitorV2] Host '{}' is not *responding*. Failure count: '{}'.
+HostMonitorV2.HostAlive=[HostMonitorV2] Host '{}' is *alive*.
+HostMonitorV2.StopMonitoringThread=[HostMonitorV2] Stop monitoring thread for '{}'.
+MonitorServiceV2.ExceptionAbortingConnection=[MonitorServiceV2] Exception while aborting connection: '{}'.
+
 HostSelector.NoEligibleHost=[HostSelector] No Eligible Hosts Found.
 HostSelector.NoHostsMatchingRole=[HostSelector] No hosts were found matching the requested role: '{}'.
 
@@ -193,15 +215,15 @@ LimitlessPlugin.UnsupportedDialectOrDatabase=[LimitlessPlugin] Unsupported diale
 LimitlessQueryHelper.UnsupportedDialectOrDatabase=[LimitlessQueryHelper] Unsupported dialect '{}' encountered. Please ensure JDBC connection parameters are correct, and refer to the documentation to ensure that the connecting database is compatible with the Limitless Connection Plugin.
 
-LimitlessRouterMonitor.errorDuringMonitoringStop=[LimitlessRouterMonitor] Stopping monitoring after unhandled error was thrown in Limitless Router Monitoring thread for node {}. Error: {}
-LimitlessRouterMonitor.InterruptedErrorDuringMonitoring=[LimitlessRouterMonitor] Limitless Router Monitoring thread for node {} was interrupted.
+LimitlessRouterMonitor.errorDuringMonitoringStop=[LimitlessRouterMonitor] Stopping monitoring after unhandled error was thrown in Limitless Router Monitoring thread for host {}. Error: {}
+LimitlessRouterMonitor.InterruptedErrorDuringMonitoring=[LimitlessRouterMonitor] Limitless Router Monitoring thread for host {} was interrupted.
 LimitlessRouterMonitor.InvalidQuery=[LimitlessRouterMonitor] Limitless Connection Plugin has encountered an error obtaining Limitless Router endpoints. Please ensure that you are connecting to an Aurora Limitless Database Shard Group Endpoint URL.
 LimitlessRouterMonitor.InvalidRouterLoad=[LimitlessRouterMonitor] Invalid load metric value of '{}' from the transaction router query aurora_limitless_router_endpoints() for transaction router '{}'. The load metric value must be a decimal value between 0 and 1. Host weight be assigned a default weight of 1.
 LimitlessRouterMonitor.GetNetworkTimeoutError=[LimitlessRouterMonitor] An error occurred while getting the connection network timeout: {}
 LimitlessRouterMonitor.OpeningConnection=[LimitlessRouterMonitor] Opening Limitless Router Monitor connection to '{}'.
 LimitlessRouterMonitor.OpenedConnection=[LimitlessRouterMonitor] Opened Limitless Router Monitor connection: {}.
-LimitlessRouterMonitor.Running=[LimitlessRouterMonitor] Limitless Router Monitor thread running on node {}.
-LimitlessRouterMonitor.Stopped=[LimitlessRouterMonitor] Limitless Router Monitor thread stopped on node {}.
+LimitlessRouterMonitor.Running=[LimitlessRouterMonitor] Limitless Router Monitor thread running on host {}.
+LimitlessRouterMonitor.Stopped=[LimitlessRouterMonitor] Limitless Router Monitor thread stopped on host {}.
 
 LimitlessRouterService.ConnectWithHost=[LimitlessRouterService] Connecting to host {}.
 LimitlessRouterService.ErrorClosingMonitor=[LimitlessRouterService] An error occurred while closing Limitless Router Monitor: {}
diff --git a/aws_advanced_python_wrapper/utils/atomic.py b/aws_advanced_python_wrapper/utils/atomic.py
index 42d7961e..01e61c4e 100644
--- a/aws_advanced_python_wrapper/utils/atomic.py
+++ b/aws_advanced_python_wrapper/utils/atomic.py
@@ -13,6 +13,9 @@
 # limitations under the License.
from threading import Lock +from typing import Generic, TypeVar + +T = TypeVar('T') class AtomicInt: @@ -59,3 +62,31 @@ def compare_and_set(self, expected_value: int, new_value: int) -> bool: self._value = new_value return True return False + + +class AtomicBoolean: + def __init__(self, initial_value: bool): + self._value: bool = initial_value + self._lock: Lock = Lock() + + def get(self) -> bool: + with self._lock: + return self._value + + def set(self, value: bool) -> None: + with self._lock: + self._value = value + + +class AtomicReference(Generic[T]): + def __init__(self, initial_value: T): + self._value: T = initial_value + self._lock: Lock = Lock() + + def get(self) -> T: + with self._lock: + return self._value + + def set(self, new_value: T) -> None: + with self._lock: + self._value = new_value diff --git a/aws_advanced_python_wrapper/utils/properties.py b/aws_advanced_python_wrapper/utils/properties.py index 45aded90..0b75b6ff 100644 --- a/aws_advanced_python_wrapper/utils/properties.py +++ b/aws_advanced_python_wrapper/utils/properties.py @@ -73,7 +73,7 @@ def set(self, props: Properties, value: Any): class WrapperProperties: - DEFAULT_PLUGINS = "aurora_connection_tracker,failover,host_monitoring" + DEFAULT_PLUGINS = "aurora_connection_tracker,failover,host_monitoring_v2" _DEFAULT_TOKEN_EXPIRATION_SEC = 15 * 60 PROFILE_NAME = WrapperProperty("profile_name", "Driver configuration profile name", None) diff --git a/docs/using-the-python-driver/UsingThePythonDriver.md b/docs/using-the-python-driver/UsingThePythonDriver.md index 7a566809..c33e4e95 100644 --- a/docs/using-the-python-driver/UsingThePythonDriver.md +++ b/docs/using-the-python-driver/UsingThePythonDriver.md @@ -64,7 +64,7 @@ The AWS Advanced Python Driver has several built-in plugins that are available t | Plugin name | Plugin Code | Database Compatibility | Description | Additional Required Dependencies | |--------------------------------------------------------------------------------------------------------|-----------------------------|--------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------| | [Failover Connection Plugin](./using-plugins/UsingTheFailoverPlugin.md) | `failover` | Aurora | Enables the failover functionality supported by Amazon Aurora clusters. Prevents opening a wrong connection to an old writer host dues to stale DNS after failover event. This plugin is enabled by default. | None | -| [Host Monitoring Connection Plugin](./using-plugins/UsingTheHostMonitoringPlugin.md) | `host_monitoring` | Aurora | Enables enhanced host connection failure monitoring, allowing faster failure detection rates. This plugin is enabled by default. | None | +| [Host Monitoring Connection Plugin](./using-plugins/UsingTheHostMonitoringPlugin.md) | `host_monitoring_v2` or `host_monitoring` | Aurora | Enables enhanced host connection failure monitoring, allowing faster failure detection rates. This plugin is enabled by default. 
| None | | [IAM Authentication Connection Plugin](./using-plugins/UsingTheIamAuthenticationPlugin.md) | `iam` | Any database | Enables users to connect to their Amazon Aurora clusters using AWS Identity and Access Management (IAM). | [Boto3 - AWS SDK for Python](https://aws.amazon.com/sdk-for-python/) | | [AWS Secrets Manager Connection Plugin](./using-plugins/UsingTheAwsSecretsManagerPlugin.md) | `aws_secrets_manager` | Any database | Enables fetching database credentials from the AWS Secrets Manager service. | [Boto3 - AWS SDK for Python](https://aws.amazon.com/sdk-for-python/) | | [Federated Authentication Connection Plugin](./using-plugins/UsingTheFederatedAuthenticationPlugin.md) | `federated_auth` | Any database | Enables users to authenticate via Federated Identity and then database access via IAM. | [Boto3 - AWS SDK for Python](https://aws.amazon.com/sdk-for-python/) | diff --git a/docs/using-the-python-driver/using-plugins/UsingTheFastestResponseStrategyPlugin.md b/docs/using-the-python-driver/using-plugins/UsingTheFastestResponseStrategyPlugin.md index 5aeb43b9..55283968 100644 --- a/docs/using-the-python-driver/using-plugins/UsingTheFastestResponseStrategyPlugin.md +++ b/docs/using-the-python-driver/using-plugins/UsingTheFastestResponseStrategyPlugin.md @@ -13,7 +13,7 @@ The plugin can be loaded by adding the plugin code `fastest_response_strategy` t ```python params = { - "plugins": "read_write_splitting,fastest_response_strategy,failover,host_monitoring", + "plugins": "read_write_splitting,fastest_response_strategy,failover,host_monitoring_v2", "reader_response_strategy": "fastest_response" # Add other connection properties below... } diff --git a/docs/using-the-python-driver/using-plugins/UsingTheHostMonitoringPlugin.md b/docs/using-the-python-driver/using-plugins/UsingTheHostMonitoringPlugin.md index e9c996c3..a1a19ad4 100644 --- a/docs/using-the-python-driver/using-plugins/UsingTheHostMonitoringPlugin.md +++ b/docs/using-the-python-driver/using-plugins/UsingTheHostMonitoringPlugin.md @@ -79,3 +79,19 @@ conn = AwsWrapperConnection.connect( > We recommend you either disable the Host Monitoring Connection Plugin or avoid using RDS Proxy endpoints when the Host Monitoring Connection Plugin is active. > > Although using RDS Proxy endpoints with the AWS Advanced Python Driver with Enhanced Failure Monitoring doesn't cause any critical issues, we don't recommend this approach. The main reason is that RDS Proxy transparently re-routes requests to a single database instance. RDS Proxy decides which database instance is used based on many criteria (on a per-request basis). Switching between different instances makes the Host Monitoring Connection Plugin useless in terms of instance health monitoring because the plugin will be unable to identify which instance it's connected to, and which one it's monitoring. This could result in false positive failure detections. At the same time, the plugin will still proactively monitor network connectivity to RDS Proxy endpoints and report outages back to a user application if they occur. + +# Host Monitoring Plugin v2 + +Host Monitoring Plugin v2, also known as `host_monitoring_v2`, is an alternative implementation of enhanced failure monitoring and it is functionally equivalent to the Host Monitoring Plugin described above. Both plugins share the same set of [configuration parameters](#enhanced-failure-monitoring-parameters). The `host_monitoring_v2` plugin is designed to be a drop-in replacement for the `host_monitoring` plugin. 
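+
+For example, the plugin can be loaded by adding the plugin code `host_monitoring_v2` to the `plugins` parameter, following the same pattern as the other plugin examples in this documentation:
+
+```python
+params = {
+    "plugins": "failover,host_monitoring_v2",
+    # Add other connection properties below...
+}
+```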
+
+The `host_monitoring_v2` plugin can be used in any scenario where the `host_monitoring` plugin is mentioned. This plugin is enabled by default. The original EFM plugin can still be used by specifying `host_monitoring` in the `plugins` parameter.
+
+> [!NOTE]\
+> Since these are two separate plugins, users may decide to load them together on a single connection. While this should not have any negative side effects, doing so is redundant; use either the `host_monitoring_v2` plugin or the `host_monitoring` plugin, not both.
+
+The `host_monitoring_v2` plugin is designed to address [some of the issues](https://github.com/aws/aws-advanced-jdbc-wrapper/issues/675) that have been reported by multiple users. The following changes have been made:
+- Used weak references to ease garbage collection
+- Split the monitoring logic into two separate threads to increase overall monitoring stability
+- Reviewed the locks for the monitoring context
+- Reviewed and redesigned the stopping of idle monitoring threads
+- Reviewed and simplified the monitoring logic
diff --git a/docs/using-the-python-driver/using-plugins/UsingTheReadWriteSplittingPlugin.md b/docs/using-the-python-driver/using-plugins/UsingTheReadWriteSplittingPlugin.md
index 678dea00..9d1bdc08 100644
--- a/docs/using-the-python-driver/using-plugins/UsingTheReadWriteSplittingPlugin.md
+++ b/docs/using-the-python-driver/using-plugins/UsingTheReadWriteSplittingPlugin.md
@@ -8,7 +8,7 @@ The Read/Write Splitting Plugin is not loaded by default. To load the plugin, in
 
 ```python
 params = {
-    "plugins": "read_write_splitting,failover,host_monitoring",
+    "plugins": "read_write_splitting,failover,host_monitoring_v2",
     # Add other connection properties below...
 }
 
diff --git a/tests/integration/container/test_aurora_failover.py b/tests/integration/container/test_aurora_failover.py
index 71524bee..db8dfbe9 100644
--- a/tests/integration/container/test_aurora_failover.py
+++ b/tests/integration/container/test_aurora_failover.py
@@ -116,6 +116,7 @@ def test_fail_from_writer_to_new_writer_fail_on_connection_bound_object_invocati
         assert aurora_utility.is_db_instance_writer(current_connection_id) is True
         assert current_connection_id != initial_writer_id
 
+    @pytest.mark.parametrize("plugins", ["failover,host_monitoring", "failover,host_monitoring_v2"])
     @enable_on_features([TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED,
                          TestEnvironmentFeatures.ABORT_CONNECTION_SUPPORTED])
     def test_fail_from_reader_to_writer(
@@ -124,12 +125,13 @@ def test_fail_from_reader_to_writer(
             test_environment: TestEnvironment,
             test_driver: TestDriver,
             conn_utils,
             proxied_props,
-            aurora_utility):
+            aurora_utility,
+            plugins):
         target_driver_connect = DriverHelper.get_connect_func(test_driver)
         reader: TestInstanceInfo = test_environment.get_proxy_instances()[1]
         writer_id: str = test_environment.get_proxy_writer().get_instance_id()
 
-        proxied_props["plugins"] = "failover,host_monitoring"
+        proxied_props["plugins"] = plugins
         with AwsWrapperConnection.connect(
                 target_driver_connect,
                 **conn_utils.get_proxy_connect_params(reader.get_host()),
diff --git a/tests/integration/container/test_basic_connectivity.py b/tests/integration/container/test_basic_connectivity.py
index 6745eac6..aa70e108 100644
--- a/tests/integration/container/test_basic_connectivity.py
+++ b/tests/integration/container/test_basic_connectivity.py
@@ -126,15 +126,16 @@ def test_proxied_wrapper_connection_failed(
         # That is expected exception. Test pass.
assert True + @pytest.mark.parametrize("plugins", ["host_monitoring", "host_monitoring_v2"]) @enable_on_num_instances(min_instances=2) @enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) @enable_on_features([TestEnvironmentFeatures.ABORT_CONNECTION_SUPPORTED]) - def test_wrapper_connection_reader_cluster_with_efm_enabled(self, test_driver: TestDriver, conn_utils): + def test_wrapper_connection_reader_cluster_with_efm_enabled(self, test_driver: TestDriver, conn_utils, plugins): target_driver_connect = DriverHelper.get_connect_func(test_driver) conn = AwsWrapperConnection.connect( target_driver_connect, **conn_utils.get_connect_params(conn_utils.reader_cluster_host), - plugins="host_monitoring", connect_timeout=10) + plugins=plugins, connect_timeout=10) cursor = conn.cursor() cursor.execute("SELECT 1") result = cursor.fetchone() diff --git a/tests/integration/container/test_failover_performance.py b/tests/integration/container/test_failover_performance.py index 1c1a54ca..a2cde066 100644 --- a/tests/integration/container/test_failover_performance.py +++ b/tests/integration/container/test_failover_performance.py @@ -18,7 +18,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from logging import getLogger from time import perf_counter_ns, sleep -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import pytest @@ -131,8 +131,9 @@ def props(self): return props + @pytest.mark.parametrize("plugins", ["host_monitoring", "host_monitoring_v2"]) def test_failure_detection_time_efm(self, test_environment: TestEnvironment, test_driver: TestDriver, conn_utils, - props: Properties): + props: Properties, plugins): enhanced_failure_monitoring_perf_data_list: List[PerfStatBase] = [] target_driver_connect_func = DriverHelper.get_connect_func(test_driver) try: @@ -147,7 +148,7 @@ def test_failure_detection_time_efm(self, test_environment: TestEnvironment, tes WrapperProperties.FAILURE_DETECTION_TIME_MS.set(props, str(detection_time)) WrapperProperties.FAILURE_DETECTION_INTERVAL_MS.set(props, str(detection_interval)) WrapperProperties.FAILURE_DETECTION_COUNT.set(props, str(detection_count)) - WrapperProperties.PLUGINS.set(props, "host_monitoring") + WrapperProperties.PLUGINS.set(props, plugins) data: PerfStatMonitoring = PerfStatMonitoring() self._measure_performance(test_environment, target_driver_connect_func, conn_utils, sleep_delay_sec, props, data) @@ -159,11 +160,13 @@ def test_failure_detection_time_efm(self, test_environment: TestEnvironment, tes PerformanceUtil.write_perf_data_to_file( f"/app/tests/integration/container/reports/" f"DbEngine_{test_environment.get_engine()}_" + f"Plugins_{plugins}_" f"FailureDetectionPerformanceResults_EnhancedMonitoringEnabled.csv", TestPerformance.PERF_STAT_MONITORING_HEADER, enhanced_failure_monitoring_perf_data_list) + @pytest.mark.parametrize("plugins", ["failover,host_monitoring", "failover,host_monitoring_v2"]) def test_failure_detection_time_failover_and_efm(self, test_environment: TestEnvironment, test_driver: TestDriver, conn_utils, - props: Properties): + props: Properties, plugins): enhanced_failure_monitoring_perf_data_list: List[PerfStatBase] = [] try: for i in range(len(TestPerformance.failure_detection_time_params)): @@ -177,7 +180,7 @@ def test_failure_detection_time_failover_and_efm(self, test_environment: TestEnv WrapperProperties.FAILURE_DETECTION_TIME_MS.set(props, str(detection_time)) 
WrapperProperties.FAILURE_DETECTION_INTERVAL_MS.set(props, str(detection_interval)) WrapperProperties.FAILURE_DETECTION_COUNT.set(props, str(detection_count)) - WrapperProperties.PLUGINS.set(props, "failover,host_monitoring") + WrapperProperties.PLUGINS.set(props, plugins) WrapperProperties.FAILOVER_TIMEOUT_SEC.set(props, TestPerformance.PERF_FAILOVER_TIMEOUT_SEC) WrapperProperties.FAILOVER_MODE.set(props, "strict_reader") @@ -191,6 +194,7 @@ def test_failure_detection_time_failover_and_efm(self, test_environment: TestEnv PerformanceUtil.write_perf_data_to_file( f"/app/tests/integration/container/reports/" f"DbEngine_{test_environment.get_engine()}_" + f"Plugins_{plugins}_" f"FailureDetectionPerformanceResults_FailoverAndEnhancedMonitoringEnabled.csv", TestPerformance.PERF_STAT_MONITORING_HEADER, enhanced_failure_monitoring_perf_data_list) @@ -205,12 +209,14 @@ def _measure_performance( query: str = "SELECT pg_sleep(600)" downtime: AtomicInt = AtomicInt() elapsed_times: List[int] = [] - connection_str = conn_utils.get_proxy_conn_string(test_environment.get_proxy_writer().get_host()) for _ in range(TestPerformance.REPEAT_TIMES): downtime.set(0) - with self._open_connect_with_retry(connect_func, connection_str, props) as aws_conn, ThreadPoolExecutor() as executor: + with self._open_connect_with_retry(connect_func, + conn_utils.get_proxy_connect_params( + test_environment.get_proxy_writer().get_host()), + props) as aws_conn, ThreadPoolExecutor() as executor: try: futures = [ executor.submit(self._stop_network_thread, test_environment, sleep_delay_sec, downtime), @@ -236,21 +242,21 @@ def _measure_performance( data.max_failure_detection_time_millis = PerformanceUtil.to_millis(max_val) data.avg_failure_detection_time_millis = PerformanceUtil.to_millis(avg_val) - def _open_connect_with_retry(self, connect_func, conn_str: str, props: Properties): + def _open_connect_with_retry(self, connect_func, connect_params: Dict[str, Any], props: Properties): connection_attempts: int = 0 conn: Optional[Connection] = None while conn is None and connection_attempts < 10: try: conn = AwsWrapperConnection.connect( connect_func, - conn_str, + **connect_params, **props) except Exception as e: TestPerformance.logger.debug("OpenConnectionFailed", str(e)) connection_attempts += 1 if conn is None: - pytest.fail(f"Unable to connect to {conn_str}") + pytest.fail(f"Unable to connect to {connect_params}") return conn def _stop_network_thread(self, test_environment: TestEnvironment, sleep_delay_seconds: int, downtime: AtomicInt): diff --git a/tests/integration/container/test_read_write_splitting.py b/tests/integration/container/test_read_write_splitting.py index de75badb..f0bf34aa 100644 --- a/tests/integration/container/test_read_write_splitting.py +++ b/tests/integration/container/test_read_write_splitting.py @@ -346,14 +346,15 @@ def test_failover_to_new_writer__switch_read_only( current_id = rds_utils.query_instance_id(conn) assert new_writer_id == current_id + @pytest.mark.parametrize("plugins", ["read_write_splitting,failover,host_monitoring_v2"]) @enable_on_features([TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED, TestEnvironmentFeatures.ABORT_CONNECTION_SUPPORTED]) @enable_on_num_instances(min_instances=3) @disable_on_engines([DatabaseEngine.MYSQL]) def test_failover_to_new_reader__switch_read_only( self, test_environment: TestEnvironment, test_driver: TestDriver, - proxied_failover_props, conn_utils, rds_utils): - WrapperProperties.PLUGINS.set(proxied_failover_props, 
"read_write_splitting,failover,host_monitoring") + proxied_failover_props, conn_utils, rds_utils, plugins): + WrapperProperties.PLUGINS.set(proxied_failover_props, plugins) WrapperProperties.FAILOVER_MODE.set(proxied_failover_props, "reader-or-writer") target_driver_connect = DriverHelper.get_connect_func(test_driver) @@ -394,14 +395,16 @@ def test_failover_to_new_reader__switch_read_only( current_id = rds_utils.query_instance_id(conn) assert other_reader_id == current_id + @pytest.mark.parametrize("plugins", ["read_write_splitting,failover,host_monitoring", + "read_write_splitting,failover,host_monitoring_v2"]) @enable_on_features([TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED, TestEnvironmentFeatures.ABORT_CONNECTION_SUPPORTED]) @enable_on_num_instances(min_instances=3) @disable_on_engines([DatabaseEngine.MYSQL]) def test_failover_reader_to_writer__switch_read_only( self, test_environment: TestEnvironment, test_driver: TestDriver, - proxied_failover_props, conn_utils, rds_utils): - WrapperProperties.PLUGINS.set(proxied_failover_props, "read_write_splitting,failover,host_monitoring") + proxied_failover_props, conn_utils, rds_utils, plugins): + WrapperProperties.PLUGINS.set(proxied_failover_props, plugins) target_driver_connect = DriverHelper.get_connect_func(test_driver) conn = AwsWrapperConnection.connect( target_driver_connect, **conn_utils.get_proxy_connect_params(), **proxied_failover_props) @@ -513,17 +516,19 @@ def test_pooled_connection__cluster_url_failover( new_driver_conn = conn.target_connection assert initial_driver_conn is not new_driver_conn + @pytest.mark.parametrize("plugins", ["read_write_splitting,failover,host_monitoring", + "read_write_splitting,failover,host_monitoring_v2"]) @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED, TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED, TestEnvironmentFeatures.ABORT_CONNECTION_SUPPORTED]) @disable_on_engines([DatabaseEngine.MYSQL]) def test_pooled_connection__failover_failed( self, test_environment: TestEnvironment, test_driver: TestDriver, - rds_utils, conn_utils, proxied_failover_props): + rds_utils, conn_utils, proxied_failover_props, plugins): writer_host = test_environment.get_writer().get_host() provider = SqlAlchemyPooledConnectionProvider(lambda _, __: {"pool_size": 1}, None, lambda host_info, props: writer_host in host_info.host) ConnectionProviderManager.set_connection_provider(provider) - WrapperProperties.PLUGINS.set(proxied_failover_props, "read_write_splitting,failover,host_monitoring") + WrapperProperties.PLUGINS.set(proxied_failover_props, plugins) WrapperProperties.FAILOVER_TIMEOUT_SEC.set(proxied_failover_props, "1") WrapperProperties.FAILURE_DETECTION_TIME_MS.set(proxied_failover_props, "1000") WrapperProperties.FAILURE_DETECTION_COUNT.set(proxied_failover_props, "1") diff --git a/tests/integration/container/test_read_write_splitting_performance.py b/tests/integration/container/test_read_write_splitting_performance.py index 9765f6bd..7c566e70 100644 --- a/tests/integration/container/test_read_write_splitting_performance.py +++ b/tests/integration/container/test_read_write_splitting_performance.py @@ -196,13 +196,14 @@ def _measure_performance( connect_func, conn_utils, read_write_plugin_props: Properties) -> Result: - switch_to_reader_elapsed_times: List[int] = [] switch_to_writer_elapsed_times: List[int] = [] - conn_str = conn_utils.get_conn_string(test_environment.get_writer().get_host()) for _ in range(TestReadWriteSplittingPerformance.REPEAT_TIMES): - with 
AwsWrapperConnection.connect(connect_func, conn_str, **read_write_plugin_props) as aws_conn: + with AwsWrapperConnection.connect(connect_func, + **conn_utils.get_connect_params( + test_environment.get_writer().get_host()), + **read_write_plugin_props) as aws_conn: ConnectTimePlugin.reset_connect_time() ExecuteTimePlugin.reset_execute_time() diff --git a/tests/unit/test_host_monitor_v2_plugin.py b/tests/unit/test_host_monitor_v2_plugin.py new file mode 100644 index 00000000..194faadd --- /dev/null +++ b/tests/unit/test_host_monitor_v2_plugin.py @@ -0,0 +1,342 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from queue import Queue +from unittest.mock import MagicMock, patch + +import psycopg +import pytest + +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.host_monitoring_v2_plugin import ( + HostMonitoringV2Plugin, HostMonitorV2, MonitoringContext) +from aws_advanced_python_wrapper.hostinfo import HostInfo +from aws_advanced_python_wrapper.pep249 import Error +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) +from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ + TelemetryCounter + +FAILURE_DETECTION_TIME_MS = 1000 +FAILURE_DETECTION_INTERVAL_MS = 5000 +FAILURE_DETECTION_COUNT = 3 + + +@pytest.fixture +def mock_driver_dialect(mocker): + driver_dialect_mock = mocker.MagicMock() + driver_dialect_mock.supports_socket_timeout.return_value = True + driver_dialect_mock.is_closed.return_value = True + return driver_dialect_mock + + +@pytest.fixture +def mock_plugin_service(mocker, mock_driver_dialect, mock_conn, host_info): + service_mock = mocker.MagicMock() + service_mock.current_connection = mock_conn + service_mock.current_host_info = host_info + type(service_mock).driver_dialect = mocker.PropertyMock(return_value=mock_driver_dialect) + return service_mock + + +@pytest.fixture +def mock_conn(mocker): + return mocker.MagicMock(spec=psycopg.Connection) + + +@pytest.fixture +def mock_execute_func(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def monitoring_context(mock_conn): + return MonitoringContext(mock_conn) + + +@pytest.fixture +def mock_monitor_service(mocker, monitoring_context): + monitor_service_mock = mocker.MagicMock() + monitor_service_mock.start_monitoring.return_value = monitoring_context + return monitor_service_mock + + +@pytest.fixture +def mock_telemetry_counter(): + return MagicMock(spec=TelemetryCounter) + + +@pytest.fixture +def host_info(): + return HostInfo("my_database.cluster-xyz.us-east-2.rds.amazonaws.com") + + +@pytest.fixture +def props(): + return Properties() + + +@pytest.fixture +def host_monitor( + mock_plugin_service, + host_info, + props, + mock_telemetry_counter): + return HostMonitorV2( + plugin_service=mock_plugin_service, + host_info=host_info, + props=props, + failure_detection_time_ms=FAILURE_DETECTION_TIME_MS, + 
        failure_detection_interval_ms=FAILURE_DETECTION_INTERVAL_MS,
+        failure_detection_count=FAILURE_DETECTION_COUNT,
+        aborted_connection_counter=mock_telemetry_counter
+    )
+
+
+@pytest.fixture
+def plugin(mock_plugin_service, props, mock_monitor_service):
+    return init_plugin(mock_plugin_service, props, mock_monitor_service)
+
+
+def init_plugin(plugin_service, props, mock_monitor_service):
+    plugin = HostMonitoringV2Plugin(plugin_service, props)
+    plugin._monitor_service = mock_monitor_service
+    return plugin
+
+
+def test_init_abort_connection_not_supported(mock_plugin_service, mock_driver_dialect, props):
+    mock_driver_dialect.supports_abort_connection.return_value = False
+
+    with pytest.raises(AwsWrapperError):
+        HostMonitoringV2Plugin(mock_plugin_service, props)
+
+
+def test_execute_null_connection_info(
+        mocker, plugin, mock_plugin_service, mock_monitor_service, props, mock_execute_func, mock_conn):
+    mock_plugin_service.current_connection = None
+    with pytest.raises(AwsWrapperError):
+        plugin.execute(mocker.MagicMock(), "Cursor.execute", mock_execute_func, "SELECT 1")
+    mock_execute_func.assert_not_called()
+
+    mock_plugin_service.current_connection = mock_conn
+    mock_plugin_service.current_host_info = None
+    with pytest.raises(AwsWrapperError):
+        plugin.execute(mocker.MagicMock(), "Cursor.execute", mock_execute_func, "SELECT 1")
+    mock_execute_func.assert_not_called()
+
+
+def test_execute_monitoring_disabled(mocker, mock_plugin_service, props, mock_monitor_service, mock_execute_func):
+    WrapperProperties.FAILURE_DETECTION_ENABLED.set(props, "False")
+    plugin = init_plugin(mock_plugin_service, props, mock_monitor_service)
+    plugin.execute(mocker.MagicMock(), "Cursor.execute", mock_execute_func, "SELECT 1")
+
+    mock_monitor_service.start_monitoring.assert_not_called()
+    mock_monitor_service.stop_monitoring.assert_not_called()
+    mock_execute_func.assert_called_once()
+
+
+def test_execute_non_network_method(mocker, plugin, mock_plugin_service, mock_monitor_service, mock_execute_func):
+    # Non-network-bound methods should bypass monitoring entirely.
+    mock_plugin_service.is_network_bound_method.return_value = False
+    plugin.execute(mocker.MagicMock(), "Connection.cancel", mock_execute_func)
+
+    mock_monitor_service.start_monitoring.assert_not_called()
+    mock_monitor_service.stop_monitoring.assert_not_called()
+    mock_execute_func.assert_called_once()
+
+
+def test_execute_monitoring_enabled(mocker, plugin, mock_monitor_service, mock_execute_func):
+    plugin.execute(mocker.MagicMock(), "Cursor.execute", mock_execute_func, "SELECT 1")
+
+    mock_monitor_service.start_monitoring.assert_called_once()
+    mock_monitor_service.stop_monitoring.assert_called_once()
+    mock_execute_func.assert_called_once()
+
+
+def test_connect(mocker, plugin, host_info, props, mock_conn, mock_plugin_service):
+    mock_connect_func = mocker.MagicMock()
+    mock_connect_func.return_value = mock_conn
+
+    connection = plugin.connect(
+        mocker.MagicMock(), mocker.MagicMock(), host_info, props, True, mock_connect_func)
+    mock_plugin_service.fill_aliases.assert_called_once()
+    assert mock_conn == connection
+
+
+def test_get_monitoring_host_info_errors(mocker, plugin, mock_plugin_service):
+    mock_plugin_service.identify_connection.return_value = None
+
+    with pytest.raises(AwsWrapperError):
+        plugin._get_monitoring_host_info()
+
+    expected_exception = Error()
+    mock_plugin_service.identify_connection.return_value = mocker.MagicMock()
+    mock_plugin_service.identify_connection.side_effect = expected_exception
+    with pytest.raises(AwsWrapperError) as exc_info:
+        plugin._get_monitoring_host_info()
+
+    assert expected_exception == exc_info.value.__cause__
+
+
+def test_release_resources(plugin, mock_monitor_service):
+    plugin.release_resources()
+    mock_monitor_service.release_resources.assert_called_once()
+    assert plugin._monitor_service is None
+
+
+def test_set_host_unhealthy(monitoring_context):
+    assert monitoring_context._host_unhealthy.get() is False
+    monitoring_context.set_host_unhealthy()
+    assert monitoring_context._host_unhealthy.get() is True
+
+
+def test_should_abort_when_healthy(monitoring_context):
+    assert monitoring_context.should_abort() is False
+
+
+def test_should_abort_when_unhealthy(monitoring_context):
+    monitoring_context.set_host_unhealthy()
+    assert monitoring_context.should_abort() is True
+
+
+def test_should_abort_when_inactive(monitoring_context):
+    monitoring_context.set_host_unhealthy()
+    monitoring_context.set_inactive()
+    assert monitoring_context.should_abort() is False
+
+
+def test_set_inactive(monitoring_context):
+    monitoring_context.set_inactive()
+    assert monitoring_context.get_connection() is None
+    assert monitoring_context.is_active() is False
+
+
+def test_get_connection_when_active(monitoring_context, mock_conn):
+    assert monitoring_context.get_connection() is mock_conn
+
+
+def test_get_connection_when_inactive(monitoring_context):
+    monitoring_context.set_inactive()
+    assert monitoring_context.get_connection() is None
+
+
+def test_is_active_when_active(monitoring_context):
+    assert monitoring_context.is_active() is True
+
+
+def test_is_active_when_inactive(monitoring_context):
+    monitoring_context.set_inactive()
+    assert monitoring_context.is_active() is False
+
+
+def test_can_dispose_non_empty_active_contexts(host_monitor):
+    assert host_monitor.can_dispose() is True
+
+    host_monitor._active_contexts.put(MagicMock())
+    assert host_monitor.can_dispose() is False
+
+
+def test_can_dispose_non_empty_new_contexts(host_monitor):
+    assert host_monitor.can_dispose() is True
+
+    host_monitor._new_contexts.compute_if_absent(1, lambda key: Queue())
+    assert host_monitor.can_dispose() is False
+
+
+def test_is_stopped(host_monitor):
+    assert host_monitor.is_stopped is False
+    host_monitor._is_stopped.set(True)
+    assert host_monitor.is_stopped is True
+
+
+def test_stop(host_monitor):
+    host_monitor.stop()
+    assert host_monitor.is_stopped is True
+
+
+@patch.object(HostMonitorV2, '_is_host_available')
+def test_check_connection_status_new_connection(
+        mock_is_host_available,
+        host_monitor,
+        mock_plugin_service,
+        mock_conn):
+    mock_plugin_service.force_connect.return_value = mock_conn
+    mock_is_host_available.return_value = True
+
+    result = host_monitor.check_connection_status()
+
+    assert result is True
+    mock_plugin_service.force_connect.assert_called_once()
+    assert host_monitor._monitoring_connection == mock_conn
+    mock_is_host_available.assert_not_called()
+
+
+@patch.object(HostMonitorV2, '_execute_conn_check')
+def test_is_host_available_success(mock_execute, host_monitor, mock_conn):
+    mock_execute.return_value = True
+    assert host_monitor._is_host_available(mock_conn, 1.0) is True
+
+
+@patch.object(HostMonitorV2, '_execute_conn_check')
+def test_is_host_available_timeout(mock_execute, host_monitor, mock_conn):
+    mock_execute.side_effect = TimeoutError()
+    assert host_monitor._is_host_available(mock_conn, 1.0) is False
+
+
+@patch.object(HostMonitorV2, '_execute_conn_check')
+def test_is_host_available_operational_error(mock_execute, host_monitor, mock_conn):
+    mock_execute.side_effect = psycopg.OperationalError()
+    assert host_monitor._is_host_available(mock_conn, 1.0) is False
+
+
+def test_execute_conn_check(mocker, host_monitor, mock_conn, mock_plugin_service):
+    mock_cursor = mocker.MagicMock()
+    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
+
+    host_monitor._execute_conn_check(mock_conn, 1.0)
+
+    mock_plugin_service.driver_dialect.execute.assert_called_once()
+    mock_cursor.fetchone.assert_called_once()
+
+
+def test_update_host_health_status_multiple_failures(host_monitor):
+    host_monitor._failure_count = 2
+    host_monitor._invalid_host_start_time_ns = 10 ** 9
+
+    host_monitor._update_host_health_status(False, 10 ** 9, 11 * 10 ** 9)
+
+    assert host_monitor._is_unhealthy is True
+
+
+def test_update_host_health_status_recovery(host_monitor):
+    host_monitor._failure_count = 1
+    host_monitor._invalid_host_start_time_ns = 10 ** 9
+
+    host_monitor._update_host_health_status(True, 2 * 10 ** 9, 2.1 * 10 ** 9)
+
+    assert host_monitor._failure_count == 0
+    assert host_monitor._invalid_host_start_time_ns == 0
+    assert host_monitor._is_unhealthy is False
+
+
+def test_abort_connection_success(host_monitor, mock_conn, mock_driver_dialect):
+    # The fixture reports connections as closed; report this one as open so the
+    # abort path actually runs.
+    mock_driver_dialect.is_closed.return_value = False
+
+    host_monitor.abort_connection(mock_conn)
+
+    mock_driver_dialect.abort_connection.assert_called_once_with(mock_conn)
+
+
+@patch('aws_advanced_python_wrapper.host_monitoring_v2_plugin.logger')
+def test_abort_connection_failure(mock_logger, host_monitor, mock_conn, mock_driver_dialect):
+    mock_driver_dialect.is_closed.return_value = False
+    mock_driver_dialect.abort_connection.side_effect = AwsWrapperError("Error")
+
+    host_monitor.abort_connection(mock_conn)
+
+    mock_logger.debug.assert_called()