From 037ed3c0ea28fb6ba2bf37c1d3bca475eaa31af4 Mon Sep 17 00:00:00 2001 From: leszekamz Date: Thu, 17 Jul 2025 14:11:14 -0700 Subject: [PATCH 01/12] Add support for DSQL iam authentication --- .../dsql_iam_auth_plugin_factory.py | 29 +++++++ aws_advanced_python_wrapper/iam_plugin.py | 11 ++- aws_advanced_python_wrapper/plugin_service.py | 3 + .../utils/dsql_token_utils.py | 60 ++++++++++++++ .../utils/iam_utils.py | 47 ----------- .../utils/rds_token_utils.py | 78 +++++++++++++++++++ .../utils/rds_url_type.py | 1 + aws_advanced_python_wrapper/utils/rdsutils.py | 14 ++++ .../utils/token_utils.py | 35 +++++++++ docs/examples/DSQLIamAuthentication.py | 39 ++++++++++ 10 files changed, 266 insertions(+), 51 deletions(-) create mode 100644 aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py create mode 100644 aws_advanced_python_wrapper/utils/dsql_token_utils.py create mode 100644 aws_advanced_python_wrapper/utils/rds_token_utils.py create mode 100644 aws_advanced_python_wrapper/utils/token_utils.py create mode 100644 docs/examples/DSQLIamAuthentication.py diff --git a/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py b/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py new file mode 100644 index 00000000..05ca33f8 --- /dev/null +++ b/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py @@ -0,0 +1,29 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from aws_advanced_python_wrapper.utils.dsql_token_utils import DSQLTokenUtils +from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin +from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory +from aws_advanced_python_wrapper.utils.properties import (Properties) + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.plugin_service import PluginService + +class DsqlIamAuthPluginFactory(PluginFactory): + def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + return IamAuthPlugin(plugin_service, DSQLTokenUtils()) diff --git a/aws_advanced_python_wrapper/iam_plugin.py b/aws_advanced_python_wrapper/iam_plugin.py index 1a26c58a..7128c572 100644 --- a/aws_advanced_python_wrapper/iam_plugin.py +++ b/aws_advanced_python_wrapper/iam_plugin.py @@ -18,6 +18,8 @@ from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo from aws_advanced_python_wrapper.utils.region_utils import RegionUtils +from aws_advanced_python_wrapper.utils.token_utils import TokenUtils +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils if TYPE_CHECKING: from boto3 import Session @@ -48,11 +50,12 @@ class IamAuthPlugin(Plugin): _rds_utils: RdsUtils = RdsUtils() _token_cache: Dict[str, TokenInfo] = {} - def __init__(self, plugin_service: PluginService, session: Optional[Session] = None): + def __init__(self, plugin_service: PluginService, token_utils: TokenUtils, session: Optional[Session] = None): self._plugin_service = plugin_service self._session = session self._region_utils = RegionUtils() + self._token_utils = token_utils telemetry_factory = self._plugin_service.get_telemetry_factory() self._fetch_token_counter = telemetry_factory.create_counter("iam.fetch_token.count") self._cache_size_gauge = telemetry_factory.create_gauge( @@ -102,7 +105,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl else: token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec) self._fetch_token_counter.inc() - token: str = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session) + token: str = self._token_utils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session) self._plugin_service.driver_dialect.set_password(props, token) IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry) @@ -120,7 +123,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl # Try to generate a new token and try to connect again token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec) self._fetch_token_counter.inc() - token = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session) + token = self._token_utils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session) self._plugin_service.driver_dialect.set_password(props, token) IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry) @@ -142,4 +145,4 @@ def force_connect( class IamAuthPluginFactory(PluginFactory): def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: - return IamAuthPlugin(plugin_service) + return IamAuthPlugin(plugin_service, RDSTokenUtils()) diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index 4be4fea3..b38a26b2 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -73,6 +73,7 @@ HostMonitoringPluginFactory from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole from aws_advanced_python_wrapper.iam_plugin import IamAuthPluginFactory +from aws_advanced_python_wrapper.dsql_iam_auth_plugin_factory import DsqlIamAuthPluginFactory from aws_advanced_python_wrapper.plugin import CanReleaseResources from aws_advanced_python_wrapper.read_write_splitting_plugin import \ ReadWriteSplittingPluginFactory @@ -716,6 +717,7 @@ class PluginManager(CanReleaseResources): PLUGIN_FACTORIES: Dict[str, Type[PluginFactory]] = { "iam": IamAuthPluginFactory, + "iam_dsql": DsqlIamAuthPluginFactory, "aws_secrets_manager": AwsSecretsManagerPluginFactory, "aurora_connection_tracker": AuroraConnectionTrackerPluginFactory, "host_monitoring": HostMonitoringPluginFactory, @@ -748,6 +750,7 @@ class PluginManager(CanReleaseResources): HostMonitoringPluginFactory: 500, FastestResponseStrategyPluginFactory: 600, IamAuthPluginFactory: 700, + DsqlIamAuthPluginFactory: 710, AwsSecretsManagerPluginFactory: 800, FederatedAuthPluginFactory: 900, LimitlessPluginFactory: 950, diff --git a/aws_advanced_python_wrapper/utils/dsql_token_utils.py b/aws_advanced_python_wrapper/utils/dsql_token_utils.py new file mode 100644 index 00000000..d337f1e1 --- /dev/null +++ b/aws_advanced_python_wrapper/utils/dsql_token_utils.py @@ -0,0 +1,60 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Dict, Optional +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ + TelemetryTraceLevel +from aws_advanced_python_wrapper.utils.token_utils import TokenUtils + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.plugin_service import PluginService +from boto3 import Session + +import boto3 + +logger = Logger(__name__) + +class DSQLTokenUtils(TokenUtils): + def generate_authentication_token( + self, + plugin_service: PluginService, + user: Optional[str], + host_name: Optional[str], + port: Optional[int], + region: Optional[str], + credentials: Optional[Dict[str, str]] = None, + client_session: Optional[Session] = None) -> str: + telemetry_factory = plugin_service.get_telemetry_factory() + context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED) + + try: + client = boto3.client("dsql", region_name=region) + + if user == "admin": + token = client.generate_db_connect_admin_auth_token(host_name, region) + else: + token = client.generate_db_connect_auth_token(host_name, region) + + logger.debug("IamAuthUtils.GeneratedNewAuthToken", token) + return token + except Exception as ex: + context.set_success(False) + context.set_exception(ex) + raise ex + finally: + context.close_context() \ No newline at end of file diff --git a/aws_advanced_python_wrapper/utils/iam_utils.py b/aws_advanced_python_wrapper/utils/iam_utils.py index ecb5868f..7610bce7 100644 --- a/aws_advanced_python_wrapper/utils/iam_utils.py +++ b/aws_advanced_python_wrapper/utils/iam_utils.py @@ -70,53 +70,6 @@ def get_port(props: Properties, host_info: HostInfo, dialect_default_port: int) def get_cache_key(user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str: return f"{region}:{hostname}:{port}:{user}" - @staticmethod - def generate_authentication_token( - plugin_service: PluginService, - user: Optional[str], - host_name: Optional[str], - port: Optional[int], - region: Optional[str], - credentials: Optional[Dict[str, str]] = None, - client_session: Optional[Session] = None) -> str: - telemetry_factory = plugin_service.get_telemetry_factory() - context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED) - - try: - session = client_session if client_session else boto3.Session() - - if credentials is not None: - client = session.client( - 'rds', - region_name=region, - aws_access_key_id=credentials.get('AccessKeyId'), - aws_secret_access_key=credentials.get('SecretAccessKey'), - aws_session_token=credentials.get('SessionToken') - ) - else: - client = session.client( - 'rds', - region_name=region - ) - - token = client.generate_db_auth_token( - DBHostname=host_name, - Port=port, - DBUsername=user - ) - - client.close() - - logger.debug("IamAuthUtils.GeneratedNewAuthToken", token) - return token - except Exception as ex: - context.set_success(False) - context.set_exception(ex) - raise ex - finally: - context.close_context() - - class TokenInfo: @property def token(self): diff --git a/aws_advanced_python_wrapper/utils/rds_token_utils.py b/aws_advanced_python_wrapper/utils/rds_token_utils.py new file mode 100644 index 00000000..a312a71d --- /dev/null +++ b/aws_advanced_python_wrapper/utils/rds_token_utils.py @@ -0,0 +1,78 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Dict, Optional +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ + TelemetryTraceLevel +from aws_advanced_python_wrapper.utils.token_utils import TokenUtils + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.plugin_service import PluginService +from boto3 import Session + +import boto3 + +logger = Logger(__name__) + +class RDSTokenUtils(TokenUtils): + def generate_authentication_token( + self, + plugin_service: PluginService, + user: Optional[str], + host_name: Optional[str], + port: Optional[int], + region: Optional[str], + credentials: Optional[Dict[str, str]] = None, + client_session: Optional[Session] = None) -> str: + + telemetry_factory = plugin_service.get_telemetry_factory() + context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED) + + try: + session = client_session if client_session else boto3.Session() + + if credentials is not None: + client = session.client( + 'rds', + region_name=region, + aws_access_key_id=credentials.get('AccessKeyId'), + aws_secret_access_key=credentials.get('SecretAccessKey'), + aws_session_token=credentials.get('SessionToken') + ) + else: + client = session.client( + 'rds', + region_name=region + ) + + token = client.generate_db_auth_token( + DBHostname=host_name, + Port=port, + DBUsername=user + ) + + client.close() + + logger.debug("IamAuthUtils.GeneratedNewAuthToken", token) + return token + except Exception as ex: + context.set_success(False) + context.set_exception(ex) + raise ex + finally: + context.close_context() diff --git a/aws_advanced_python_wrapper/utils/rds_url_type.py b/aws_advanced_python_wrapper/utils/rds_url_type.py index 7226c33c..911c25c6 100644 --- a/aws_advanced_python_wrapper/utils/rds_url_type.py +++ b/aws_advanced_python_wrapper/utils/rds_url_type.py @@ -34,4 +34,5 @@ def __init__(self, is_rds: bool, is_rds_cluster: bool): RDS_PROXY = True, False, RDS_INSTANCE = True, False, RDS_AURORA_LIMITLESS_DB_SHARD_GROUP = True, False, + DSQL_CLUSTER = False, False, OTHER = False, False diff --git a/aws_advanced_python_wrapper/utils/rdsutils.py b/aws_advanced_python_wrapper/utils/rdsutils.py index 7e289d2d..6dade817 100644 --- a/aws_advanced_python_wrapper/utils/rdsutils.py +++ b/aws_advanced_python_wrapper/utils/rdsutils.py @@ -108,6 +108,10 @@ class RdsUtils: r"(?Pcluster-|cluster-ro-)+" \ r"(?P[a-zA-Z0-9]+\.rds\.(?P[a-zA-Z0-9\-]+)" \ r"\.(amazonaws\.com|c2s\.ic\.gov|sc2s\.sgov\.gov))$" + AURORA_DSQL_CLUSTER_PATTERN = r"^(?P[^.]+)\." \ + r"(?Pdsql(?:-[^.]+)?)\." \ + r"(?P(?P[a-zA-Z0-9\-]+)" \ + r"\.on\.aws\.?)$" ELB_PATTERN = r"^(?.+)\.elb\.((?[a-zA-Z0-9\-]+)\.amazonaws\.com)$" IP_V4 = r"^(([1-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){1}" \ @@ -148,6 +152,14 @@ def is_rds_dns(self, host: str) -> bool: def is_rds_instance(self, host: str) -> bool: return self._get_dns_group(host) is None and self.is_rds_dns(host) + + def is_dsql_cluster(self, host: str) -> bool: + if not host or not host.strip(): + return False + + pattern = self._find(host, [RdsUtils.AURORA_DSQL_CLUSTER_PATTERN]) + + return pattern is not None def is_rds_proxy_dns(self, host: str) -> bool: dns_group = self._get_dns_group(host) @@ -257,6 +269,8 @@ def identify_rds_type(self, host: Optional[str]) -> RdsUrlType: return RdsUrlType.RDS_PROXY elif self.is_rds_instance(host): return RdsUrlType.RDS_INSTANCE + elif self.is_dsql_cluster(host): + return RdsUrlType.DSQL_CLUSTER return RdsUrlType.OTHER diff --git a/aws_advanced_python_wrapper/utils/token_utils.py b/aws_advanced_python_wrapper/utils/token_utils.py new file mode 100644 index 00000000..b31332f5 --- /dev/null +++ b/aws_advanced_python_wrapper/utils/token_utils.py @@ -0,0 +1,35 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Dict, Optional + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.plugin_service import PluginService +from boto3 import Session + +class TokenUtils(ABC): + @abstractmethod + def generate_authentication_token( + self, + plugin_service: PluginService, + user: Optional[str], + host_name: Optional[str], + port: Optional[int], + region: Optional[str], + credentials: Optional[Dict[str, str]] = None, + client_session: Optional[Session] = None) -> str: + pass diff --git a/docs/examples/DSQLIamAuthentication.py b/docs/examples/DSQLIamAuthentication.py new file mode 100644 index 00000000..1a440601 --- /dev/null +++ b/docs/examples/DSQLIamAuthentication.py @@ -0,0 +1,39 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import psycopg + +from aws_advanced_python_wrapper import AwsWrapperConnection + +if __name__ == "__main__": + with AwsWrapperConnection.connect( + psycopg.Connection.connect, + host="abcd.dsql.us-east-1.on.aws", + dbname="postgres", + user="admin", + plugins="iam_dsql", + iam_region="us-east-1", + wrapper_dialect="aurora-pg", + autocommit=True + ) as awsconn, awsconn.cursor() as awscursor: + awscursor.execute("CREATE TABLE IF NOT EXISTS bank_test (id int primary key, name varchar(40), account_balance int)") + awscursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (0, "Jane Doe", 200)) + awscursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (1, "John Smith", 200)) + awscursor.execute("SELECT * FROM bank_test") + + res = awscursor.fetchall() + for record in res: + print(record) + awscursor.execute("DROP TABLE bank_test") + From 821750a5c3a227cbef5b948be0d0b6432670c227 Mon Sep 17 00:00:00 2001 From: leszekamz Date: Thu, 17 Jul 2025 16:16:50 -0700 Subject: [PATCH 02/12] Fixed errors with updated interfaces for federated and octa plugins. --- aws_advanced_python_wrapper/federated_plugin.py | 9 ++++++--- aws_advanced_python_wrapper/okta_plugin.py | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/aws_advanced_python_wrapper/federated_plugin.py b/aws_advanced_python_wrapper/federated_plugin.py index 0eb8258f..08d63d2c 100644 --- a/aws_advanced_python_wrapper/federated_plugin.py +++ b/aws_advanced_python_wrapper/federated_plugin.py @@ -44,6 +44,8 @@ from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils +from aws_advanced_python_wrapper.utils.token_utils import TokenUtils +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils logger = Logger(__name__) @@ -55,12 +57,13 @@ class FederatedAuthPlugin(Plugin): _rds_utils: RdsUtils = RdsUtils() _token_cache: Dict[str, TokenInfo] = {} - def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, session: Optional[Session] = None): + def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, token_utils: TokenUtils, session: Optional[Session] = None): self._plugin_service = plugin_service self._credentials_provider_factory = credentials_provider_factory self._session = session self._region_utils = RegionUtils() + self._token_utils = token_utils telemetry_factory = self._plugin_service.get_telemetry_factory() self._fetch_token_counter = telemetry_factory.create_counter("federated.fetch_token.count") self._cache_size_gauge = telemetry_factory.create_gauge("federated.token_cache.size", lambda: len(FederatedAuthPlugin._token_cache)) @@ -145,7 +148,7 @@ def _update_authentication_token(self, credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props) self._fetch_token_counter.inc() - token: str = IamAuthUtils.generate_authentication_token( + token: str = self._token_utils.generate_authentication_token( self._plugin_service, user, host_info.host, @@ -159,7 +162,7 @@ def _update_authentication_token(self, class FederatedAuthPluginFactory(PluginFactory): def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: - return FederatedAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props)) + return FederatedAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props), RDSTokenUtils()) def get_credentials_provider_factory(self, plugin_service: PluginService, props: Properties) -> AdfsCredentialsProviderFactory: idp_name = WrapperProperties.IDP_NAME.get(props) diff --git a/aws_advanced_python_wrapper/okta_plugin.py b/aws_advanced_python_wrapper/okta_plugin.py index 55bd9980..6909c079 100644 --- a/aws_advanced_python_wrapper/okta_plugin.py +++ b/aws_advanced_python_wrapper/okta_plugin.py @@ -41,6 +41,8 @@ from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils +from aws_advanced_python_wrapper.utils.token_utils import TokenUtils +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils logger = Logger(__name__) @@ -51,12 +53,13 @@ class OktaAuthPlugin(Plugin): _rds_utils: RdsUtils = RdsUtils() _token_cache: Dict[str, TokenInfo] = {} - def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, session: Optional[Session] = None): + def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, token_utils: TokenUtils, session: Optional[Session] = None): self._plugin_service = plugin_service self._credentials_provider_factory = credentials_provider_factory self._session = session self._region_utils = RegionUtils() + self._token_utils = token_utils telemetry_factory = self._plugin_service.get_telemetry_factory() self._fetch_token_counter = telemetry_factory.create_counter("okta.fetch_token.count") self._cache_size_gauge = telemetry_factory.create_gauge("okta.token_cache.size", lambda: len(OktaAuthPlugin._token_cache)) @@ -140,7 +143,7 @@ def _update_authentication_token(self, port: int = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port) credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props) - token: str = IamAuthUtils.generate_authentication_token( + token: str = self._token_utils.generate_authentication_token( self._plugin_service, user, host_info.host, @@ -228,7 +231,7 @@ def get_saml_assertion(self, props: Properties): class OktaAuthPluginFactory(PluginFactory): def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: - return OktaAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props)) + return OktaAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props), RDSTokenUtils()) def get_credentials_provider_factory(self, plugin_service: PluginService, props: Properties) -> OktaCredentialsProviderFactory: return OktaCredentialsProviderFactory(plugin_service, props) From 69a1a5cdd4edd383b68f7cea4976d1463a673fe0 Mon Sep 17 00:00:00 2001 From: leszekamz Date: Thu, 17 Jul 2025 17:38:27 -0700 Subject: [PATCH 03/12] flake and isort updates --- .../dsql_iam_auth_plugin_factory.py | 5 +++-- aws_advanced_python_wrapper/federated_plugin.py | 10 +++++++--- aws_advanced_python_wrapper/iam_plugin.py | 7 ++++--- aws_advanced_python_wrapper/okta_plugin.py | 10 +++++++--- aws_advanced_python_wrapper/plugin_service.py | 3 ++- aws_advanced_python_wrapper/utils/dsql_token_utils.py | 9 +++++---- aws_advanced_python_wrapper/utils/iam_utils.py | 9 ++------- aws_advanced_python_wrapper/utils/rds_token_utils.py | 9 +++++---- aws_advanced_python_wrapper/utils/rdsutils.py | 4 ++-- aws_advanced_python_wrapper/utils/token_utils.py | 5 +++-- docs/examples/DSQLIamAuthentication.py | 1 - 11 files changed, 40 insertions(+), 32 deletions(-) diff --git a/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py b/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py index 05ca33f8..19a4ed05 100644 --- a/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py +++ b/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py @@ -16,13 +16,14 @@ from typing import TYPE_CHECKING -from aws_advanced_python_wrapper.utils.dsql_token_utils import DSQLTokenUtils from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory -from aws_advanced_python_wrapper.utils.properties import (Properties) +from aws_advanced_python_wrapper.utils.dsql_token_utils import DSQLTokenUtils if TYPE_CHECKING: from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.utils.properties import (Properties) + class DsqlIamAuthPluginFactory(PluginFactory): def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: diff --git a/aws_advanced_python_wrapper/federated_plugin.py b/aws_advanced_python_wrapper/federated_plugin.py index 08d63d2c..abfdfa71 100644 --- a/aws_advanced_python_wrapper/federated_plugin.py +++ b/aws_advanced_python_wrapper/federated_plugin.py @@ -31,6 +31,7 @@ from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.utils.token_utils import TokenUtils from datetime import datetime, timedelta from typing import Callable, Dict, Optional, Set @@ -43,9 +44,9 @@ from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) -from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils -from aws_advanced_python_wrapper.utils.token_utils import TokenUtils from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils +from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils + logger = Logger(__name__) @@ -57,7 +58,10 @@ class FederatedAuthPlugin(Plugin): _rds_utils: RdsUtils = RdsUtils() _token_cache: Dict[str, TokenInfo] = {} - def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, token_utils: TokenUtils, session: Optional[Session] = None): + def __init__(self, plugin_service: PluginService, + credentials_provider_factory: CredentialsProviderFactory, + token_utils: TokenUtils, + session: Optional[Session] = None): self._plugin_service = plugin_service self._credentials_provider_factory = credentials_provider_factory self._session = session diff --git a/aws_advanced_python_wrapper/iam_plugin.py b/aws_advanced_python_wrapper/iam_plugin.py index 7128c572..fecfa6ac 100644 --- a/aws_advanced_python_wrapper/iam_plugin.py +++ b/aws_advanced_python_wrapper/iam_plugin.py @@ -17,9 +17,9 @@ from typing import TYPE_CHECKING from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo -from aws_advanced_python_wrapper.utils.region_utils import RegionUtils -from aws_advanced_python_wrapper.utils.token_utils import TokenUtils +from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils +from aws_advanced_python_wrapper.utils.region_utils import RegionUtils if TYPE_CHECKING: from boto3 import Session @@ -27,6 +27,7 @@ from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.utils.token_utils import TokenUtils from datetime import datetime, timedelta from typing import Callable, Dict, Optional, Set @@ -37,7 +38,7 @@ from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) -from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils + logger = Logger(__name__) diff --git a/aws_advanced_python_wrapper/okta_plugin.py b/aws_advanced_python_wrapper/okta_plugin.py index 6909c079..3e5c1104 100644 --- a/aws_advanced_python_wrapper/okta_plugin.py +++ b/aws_advanced_python_wrapper/okta_plugin.py @@ -31,6 +31,7 @@ from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.utils.token_utils import TokenUtils import requests @@ -40,9 +41,9 @@ from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) -from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils -from aws_advanced_python_wrapper.utils.token_utils import TokenUtils from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils +from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils + logger = Logger(__name__) @@ -53,7 +54,10 @@ class OktaAuthPlugin(Plugin): _rds_utils: RdsUtils = RdsUtils() _token_cache: Dict[str, TokenInfo] = {} - def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, token_utils: TokenUtils, session: Optional[Session] = None): + def __init__(self, plugin_service: PluginService, + credentials_provider_factory: CredentialsProviderFactory, + token_utils: TokenUtils, + session: Optional[Session] = None): self._plugin_service = plugin_service self._credentials_provider_factory = credentials_provider_factory self._session = session diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index b38a26b2..8d2d4fea 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -57,6 +57,8 @@ from aws_advanced_python_wrapper.developer_plugin import DeveloperPluginFactory from aws_advanced_python_wrapper.driver_configuration_profiles import \ DriverConfigurationProfiles +from aws_advanced_python_wrapper.dsql_iam_auth_plugin_factory import \ + DsqlIamAuthPluginFactory from aws_advanced_python_wrapper.errors import (AwsWrapperError, QueryTimeoutError, UnsupportedOperationError) @@ -73,7 +75,6 @@ HostMonitoringPluginFactory from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole from aws_advanced_python_wrapper.iam_plugin import IamAuthPluginFactory -from aws_advanced_python_wrapper.dsql_iam_auth_plugin_factory import DsqlIamAuthPluginFactory from aws_advanced_python_wrapper.plugin import CanReleaseResources from aws_advanced_python_wrapper.read_write_splitting_plugin import \ ReadWriteSplittingPluginFactory diff --git a/aws_advanced_python_wrapper/utils/dsql_token_utils.py b/aws_advanced_python_wrapper/utils/dsql_token_utils.py index d337f1e1..e5ba47f7 100644 --- a/aws_advanced_python_wrapper/utils/dsql_token_utils.py +++ b/aws_advanced_python_wrapper/utils/dsql_token_utils.py @@ -10,12 +10,12 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. +# limitations under the License. from __future__ import annotations -from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Dict, Optional + from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ TelemetryTraceLevel @@ -23,12 +23,13 @@ if TYPE_CHECKING: from aws_advanced_python_wrapper.plugin_service import PluginService -from boto3 import Session + from boto3 import Session import boto3 logger = Logger(__name__) + class DSQLTokenUtils(TokenUtils): def generate_authentication_token( self, @@ -57,4 +58,4 @@ def generate_authentication_token( context.set_exception(ex) raise ex finally: - context.close_context() \ No newline at end of file + context.close_context() diff --git a/aws_advanced_python_wrapper/utils/iam_utils.py b/aws_advanced_python_wrapper/utils/iam_utils.py index 7610bce7..412499d0 100644 --- a/aws_advanced_python_wrapper/utils/iam_utils.py +++ b/aws_advanced_python_wrapper/utils/iam_utils.py @@ -15,22 +15,16 @@ from __future__ import annotations from datetime import datetime -from typing import TYPE_CHECKING, Dict, Optional - -import boto3 +from typing import TYPE_CHECKING, Optional from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils -from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ - TelemetryTraceLevel if TYPE_CHECKING: from aws_advanced_python_wrapper.hostinfo import HostInfo - from aws_advanced_python_wrapper.plugin_service import PluginService - from boto3 import Session from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) @@ -70,6 +64,7 @@ def get_port(props: Properties, host_info: HostInfo, dialect_default_port: int) def get_cache_key(user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str: return f"{region}:{hostname}:{port}:{user}" + class TokenInfo: @property def token(self): diff --git a/aws_advanced_python_wrapper/utils/rds_token_utils.py b/aws_advanced_python_wrapper/utils/rds_token_utils.py index a312a71d..3d6c0dea 100644 --- a/aws_advanced_python_wrapper/utils/rds_token_utils.py +++ b/aws_advanced_python_wrapper/utils/rds_token_utils.py @@ -10,12 +10,12 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. +# limitations under the License. from __future__ import annotations -from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Dict, Optional + from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ TelemetryTraceLevel @@ -23,12 +23,13 @@ if TYPE_CHECKING: from aws_advanced_python_wrapper.plugin_service import PluginService -from boto3 import Session + from boto3 import Session import boto3 logger = Logger(__name__) + class RDSTokenUtils(TokenUtils): def generate_authentication_token( self, @@ -39,7 +40,7 @@ def generate_authentication_token( region: Optional[str], credentials: Optional[Dict[str, str]] = None, client_session: Optional[Session] = None) -> str: - + telemetry_factory = plugin_service.get_telemetry_factory() context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED) diff --git a/aws_advanced_python_wrapper/utils/rdsutils.py b/aws_advanced_python_wrapper/utils/rdsutils.py index 6dade817..41726035 100644 --- a/aws_advanced_python_wrapper/utils/rdsutils.py +++ b/aws_advanced_python_wrapper/utils/rdsutils.py @@ -152,13 +152,13 @@ def is_rds_dns(self, host: str) -> bool: def is_rds_instance(self, host: str) -> bool: return self._get_dns_group(host) is None and self.is_rds_dns(host) - + def is_dsql_cluster(self, host: str) -> bool: if not host or not host.strip(): return False pattern = self._find(host, [RdsUtils.AURORA_DSQL_CLUSTER_PATTERN]) - + return pattern is not None def is_rds_proxy_dns(self, host: str) -> bool: diff --git a/aws_advanced_python_wrapper/utils/token_utils.py b/aws_advanced_python_wrapper/utils/token_utils.py index b31332f5..5a13da4d 100644 --- a/aws_advanced_python_wrapper/utils/token_utils.py +++ b/aws_advanced_python_wrapper/utils/token_utils.py @@ -10,7 +10,7 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. +# limitations under the License. from __future__ import annotations @@ -19,7 +19,8 @@ if TYPE_CHECKING: from aws_advanced_python_wrapper.plugin_service import PluginService -from boto3 import Session + from boto3 import Session + class TokenUtils(ABC): @abstractmethod diff --git a/docs/examples/DSQLIamAuthentication.py b/docs/examples/DSQLIamAuthentication.py index 1a440601..49dbd115 100644 --- a/docs/examples/DSQLIamAuthentication.py +++ b/docs/examples/DSQLIamAuthentication.py @@ -36,4 +36,3 @@ for record in res: print(record) awscursor.execute("DROP TABLE bank_test") - From ed40d7189ef67a88e921af1565aef93d5e16053e Mon Sep 17 00:00:00 2001 From: leszekamz Date: Thu, 17 Jul 2025 18:26:53 -0700 Subject: [PATCH 04/12] isort updates --- aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py | 2 +- aws_advanced_python_wrapper/federated_plugin.py | 1 - aws_advanced_python_wrapper/iam_plugin.py | 3 +-- aws_advanced_python_wrapper/okta_plugin.py | 1 - aws_advanced_python_wrapper/utils/token_utils.py | 3 ++- 5 files changed, 4 insertions(+), 6 deletions(-) diff --git a/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py b/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py index 19a4ed05..62068870 100644 --- a/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py +++ b/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from aws_advanced_python_wrapper.plugin_service import PluginService - from aws_advanced_python_wrapper.utils.properties import (Properties) + from aws_advanced_python_wrapper.utils.properties import Properties class DsqlIamAuthPluginFactory(PluginFactory): diff --git a/aws_advanced_python_wrapper/federated_plugin.py b/aws_advanced_python_wrapper/federated_plugin.py index abfdfa71..e6f1038e 100644 --- a/aws_advanced_python_wrapper/federated_plugin.py +++ b/aws_advanced_python_wrapper/federated_plugin.py @@ -47,7 +47,6 @@ from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils - logger = Logger(__name__) diff --git a/aws_advanced_python_wrapper/iam_plugin.py b/aws_advanced_python_wrapper/iam_plugin.py index fecfa6ac..06fb3b07 100644 --- a/aws_advanced_python_wrapper/iam_plugin.py +++ b/aws_advanced_python_wrapper/iam_plugin.py @@ -17,8 +17,8 @@ from typing import TYPE_CHECKING from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo -from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils +from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils from aws_advanced_python_wrapper.utils.region_utils import RegionUtils if TYPE_CHECKING: @@ -39,7 +39,6 @@ from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) - logger = Logger(__name__) diff --git a/aws_advanced_python_wrapper/okta_plugin.py b/aws_advanced_python_wrapper/okta_plugin.py index 3e5c1104..bd46e2e6 100644 --- a/aws_advanced_python_wrapper/okta_plugin.py +++ b/aws_advanced_python_wrapper/okta_plugin.py @@ -44,7 +44,6 @@ from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils - logger = Logger(__name__) diff --git a/aws_advanced_python_wrapper/utils/token_utils.py b/aws_advanced_python_wrapper/utils/token_utils.py index 5a13da4d..f23c61e7 100644 --- a/aws_advanced_python_wrapper/utils/token_utils.py +++ b/aws_advanced_python_wrapper/utils/token_utils.py @@ -18,9 +18,10 @@ from typing import TYPE_CHECKING, Dict, Optional if TYPE_CHECKING: - from aws_advanced_python_wrapper.plugin_service import PluginService from boto3 import Session + from aws_advanced_python_wrapper.plugin_service import PluginService + class TokenUtils(ABC): @abstractmethod From c8383798965be5a851f06cb40ff035e0da1b121c Mon Sep 17 00:00:00 2001 From: leszekamz Date: Fri, 18 Jul 2025 11:34:28 -0700 Subject: [PATCH 05/12] fixed unit tests --- tests/unit/test_federated_auth_plugin.py | 21 +++++++++++++---- tests/unit/test_iam_plugin.py | 30 ++++++++++++++++++------ tests/unit/test_okta_plugin.py | 23 ++++++++++++++---- 3 files changed, 58 insertions(+), 16 deletions(-) diff --git a/tests/unit/test_federated_auth_plugin.py b/tests/unit/test_federated_auth_plugin.py index 1c3a77e3..3f982a70 100644 --- a/tests/unit/test_federated_auth_plugin.py +++ b/tests/unit/test_federated_auth_plugin.py @@ -25,6 +25,7 @@ from aws_advanced_python_wrapper.iam_plugin import TokenInfo from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils _GENERATED_TOKEN = "generated_token" _TEST_TOKEN = "test_token" @@ -101,6 +102,7 @@ def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_sessi _token_cache[_PG_CACHE_KEY] = initial_token target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, + RDSTokenUtils(), mock_session) key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + str(_DEFAULT_PG_PORT) + ":postgesqlUser" _token_cache[key] = initial_token @@ -129,7 +131,10 @@ def test_expired_cached_token(mocker, mock_plugin_service, mock_session, mock_fu initial_token = TokenInfo(_TEST_TOKEN, datetime.now() - timedelta(minutes=5)) _token_cache[_PG_CACHE_KEY] = initial_token - target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -154,7 +159,10 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m test_props: Properties = Properties({"plugins": "federated_auth", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"}) WrapperProperties.DB_USER.set(test_props, _DB_USER) - target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -183,7 +191,9 @@ def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_sess exception_message = "generic exception" mock_func.side_effect = Exception(exception_message) - target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, + target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), mock_session) with pytest.raises(Exception) as e_info: target_plugin.connect( @@ -229,7 +239,10 @@ def test_connect_with_specified_iam_host_port_region(mocker, mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{expected_region}" - target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, diff --git a/tests/unit/test_iam_plugin.py b/tests/unit/test_iam_plugin.py index 04273698..10a3a3cd 100644 --- a/tests/unit/test_iam_plugin.py +++ b/tests/unit/test_iam_plugin.py @@ -26,6 +26,7 @@ from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin, TokenInfo from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils _GENERATED_TOKEN = "generated_token" _TEST_TOKEN = "test_token" @@ -99,6 +100,7 @@ def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_sessi _token_cache[_PG_CACHE_KEY] = initial_token target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -127,6 +129,7 @@ def test_pg_connect_with_invalid_port_fall_backs_to_host_port( assert test_props.get("password") is None target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -163,6 +166,7 @@ def test_pg_connect_with_invalid_port_and_no_host_port_fall_backs_to_host_port( assert test_props.get("password") is None target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -195,7 +199,9 @@ def test_connect_expired_token_in_cache(mocker, mock_plugin_service, mock_sessio _token_cache[_PG_CACHE_KEY] = initial_token mock_func.side_effect = Exception("generic exception") - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) with pytest.raises(Exception): target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -220,7 +226,9 @@ def test_connect_expired_token_in_cache(mocker, mock_plugin_service, mock_sessio @patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) def test_connect_empty_cache(mocker, mock_plugin_service, mock_connection, mock_session, mock_func, mock_client, mock_dialect): test_props: Properties = Properties({"user": "postgresqlUser"}) - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) actual_connection = target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, @@ -251,7 +259,9 @@ def test_connect_with_specified_port(mocker, mock_plugin_service, mock_session, # Assert no password has been set assert test_props.get("password") is None - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, @@ -285,7 +295,9 @@ def test_connect_with_specified_iam_default_port(mocker, mock_plugin_service, mo # Assert no password has been set assert test_props.get("password") is None - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, @@ -323,7 +335,9 @@ def test_connect_with_specified_region(mocker, mock_plugin_service, mock_session assert test_props.get("password") is None mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{iam_region}" - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, @@ -369,7 +383,9 @@ def test_connect_with_specified_host(iam_host: str, mocker, mock_plugin_service, assert test_props.get("password") is None mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{iam_host}" - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, @@ -411,7 +427,7 @@ def test_aws_supported_regions_url_exists(): def test_invalid_iam_host(host, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): test_props: Properties = Properties({"user": "postgresqlUser"}) with pytest.raises(AwsWrapperError): - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, RDSTokenUtils(), mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, diff --git a/tests/unit/test_okta_plugin.py b/tests/unit/test_okta_plugin.py index 72f9727a..e2823568 100644 --- a/tests/unit/test_okta_plugin.py +++ b/tests/unit/test_okta_plugin.py @@ -25,6 +25,7 @@ from aws_advanced_python_wrapper.okta_plugin import OktaAuthPlugin from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils _GENERATED_TOKEN = "generated_token" _TEST_TOKEN = "test_token" @@ -100,7 +101,7 @@ def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_sessi initial_token = TokenInfo(_TEST_TOKEN, datetime.now() + timedelta(minutes=5)) _token_cache[_PG_CACHE_KEY] = initial_token - target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_session) + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, RDSTokenUtils(), mock_session) key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + str(_DEFAULT_PG_PORT) + ":postgesqlUser" _token_cache[key] = initial_token @@ -127,7 +128,10 @@ def test_expired_cached_token(mocker, mock_plugin_service, mock_session, mock_fu initial_token = TokenInfo(_TEST_TOKEN, datetime.now() - timedelta(minutes=5)) _token_cache[_PG_CACHE_KEY] = initial_token - target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -151,7 +155,10 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m test_props: Properties = Properties({"plugins": "okta", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"}) WrapperProperties.DB_USER.set(test_props, _DB_USER) - target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -179,7 +186,10 @@ def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_sess exception_message = "generic exception" mock_func.side_effect = Exception(exception_message) - target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) with pytest.raises(Exception) as e_info: target_plugin.connect( @@ -225,7 +235,10 @@ def test_connect_with_specified_iam_host_port_region(mocker, mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{expected_region}" - target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, From 737adcfe918aa293644a9d1d0367ed8273caaa6f Mon Sep 17 00:00:00 2001 From: leszekamz Date: Mon, 21 Jul 2025 21:00:01 -0700 Subject: [PATCH 06/12] added DSQL unit tests --- .../utils/dsql_token_utils.py | 15 +- docs/examples/DSQLIamAuthentication.py | 2 +- tests/unit/test_iam_dsql_plugin.py | 470 ++++++++++++++++++ 3 files changed, 485 insertions(+), 2 deletions(-) create mode 100644 tests/unit/test_iam_dsql_plugin.py diff --git a/aws_advanced_python_wrapper/utils/dsql_token_utils.py b/aws_advanced_python_wrapper/utils/dsql_token_utils.py index e5ba47f7..47f0f7a7 100644 --- a/aws_advanced_python_wrapper/utils/dsql_token_utils.py +++ b/aws_advanced_python_wrapper/utils/dsql_token_utils.py @@ -44,7 +44,20 @@ def generate_authentication_token( context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED) try: - client = boto3.client("dsql", region_name=region) + session = client_session if client_session else boto3.Session() + if credentials is not None: + client = session.client( + 'dsql', + region_name=region, + aws_access_key_id=credentials.get('AccessKeyId'), + aws_secret_access_key=credentials.get('SecretAccessKey'), + aws_session_token=credentials.get('SessionToken') + ) + else: + client = session.client( + 'dsql', + region_name=region + ) if user == "admin": token = client.generate_db_connect_admin_auth_token(host_name, region) diff --git a/docs/examples/DSQLIamAuthentication.py b/docs/examples/DSQLIamAuthentication.py index 49dbd115..4e9b2870 100644 --- a/docs/examples/DSQLIamAuthentication.py +++ b/docs/examples/DSQLIamAuthentication.py @@ -24,7 +24,7 @@ user="admin", plugins="iam_dsql", iam_region="us-east-1", - wrapper_dialect="aurora-pg", + wrapper_dialect="pg", autocommit=True ) as awsconn, awsconn.cursor() as awscursor: awscursor.execute("CREATE TABLE IF NOT EXISTS bank_test (id int primary key, name varchar(40), account_balance int)") diff --git a/tests/unit/test_iam_dsql_plugin.py b/tests/unit/test_iam_dsql_plugin.py new file mode 100644 index 00000000..3ebff4c8 --- /dev/null +++ b/tests/unit/test_iam_dsql_plugin.py @@ -0,0 +1,470 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import urllib.request +from datetime import datetime, timedelta +from typing import Dict +from unittest.mock import patch + +import pytest + +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.hostinfo import HostInfo +from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin, TokenInfo +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) +from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils +from aws_advanced_python_wrapper.utils.dsql_token_utils import DSQLTokenUtils + +_GENERATED_TOKEN = "generated_token admin" +_GENERATED_TOKEN_NON_ADMIN = "generated_token non-admin" +_TEST_TOKEN = "test_token" +_DEFAULT_PG_PORT = 5432 + +_PG_HOST_INFO = HostInfo("dsqltestclusternamefoobar1.dsql-gamma.us-east-2.on.aws") +_PG_HOST_INFO_WITH_PORT = HostInfo(_PG_HOST_INFO.url, port=1234) +_PG_REGION = "us-east-2" + +_PG_CACHE_KEY = f"{_PG_REGION}:{_PG_HOST_INFO.url}:{_DEFAULT_PG_PORT}:admin" + + +_token_cache: Dict[str, TokenInfo] = {} + + +@pytest.fixture(autouse=True) +def clear_caches(): + _token_cache.clear() + + +@pytest.fixture +def mock_session(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_client(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_connection(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_func(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_plugin_service(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_dialect(mocker): + return mocker.MagicMock() + + +@pytest.fixture(autouse=True) +def mock_default_behavior(mock_session, mock_client, mock_func, mock_connection, mock_plugin_service, mock_dialect): + mock_session.client.return_value = mock_client + mock_client.generate_db_connect_admin_auth_token.return_value = _GENERATED_TOKEN + mock_client.generate_db_connect_auth_token.return_value = _GENERATED_TOKEN_NON_ADMIN + mock_session.get_available_regions.return_value = ['us-east-1', 'us-east-2', 'us-west-1', 'us-west-2'] + mock_func.return_value = mock_connection + mock_plugin_service.driver_dialect = mock_dialect + mock_plugin_service.database_dialect = mock_dialect + mock_dialect.default_port = _DEFAULT_PG_PORT + + +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def set_token_cache(user, host, port, region, expired=False): + if not expired: + initial_token = TokenInfo(_TEST_TOKEN, datetime.now() + timedelta(minutes=5)) + else: + initial_token = TokenInfo(_TEST_TOKEN, datetime.now() - timedelta(minutes=5)) + cache_key: str = IamAuthUtils.get_cache_key( + user, + host, + port, + region + ) + _token_cache[cache_key] = initial_token + + return cache_key, initial_token + + +@pytest.mark.parametrize("user", [ + pytest.param("admin"), + pytest.param("non-admin"), +]) +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_pg_connect_valid_token_in_cache(user, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": user}) + cache_key, _ = set_token_cache(user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, _PG_REGION) + + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + actual_token = _token_cache.get(cache_key) + if user == "admin": + mock_client.generate_db_connect_admin_auth_token.assert_not_called() + assert _GENERATED_TOKEN != actual_token.token + else: + mock_client.generate_db_connect_auth_token.assert_not_called() + assert _GENERATED_TOKEN_NON_ADMIN != actual_token.token + + assert _TEST_TOKEN == actual_token.token + assert actual_token.is_expired() is False + + +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_pg_connect_with_invalid_port_fall_backs_to_host_port( + mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": "admin"}) + invalid_port = "0" + test_props[WrapperProperties.IAM_DEFAULT_PORT.name] = invalid_port + + # Assert no password has been set + assert test_props.get("password") is None + + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO_WITH_PORT, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION + ) + + actual_token = _token_cache.get(f"{_PG_REGION}:{_PG_HOST_INFO.url}:1234:admin") + assert _GENERATED_TOKEN == actual_token.token + assert actual_token.is_expired() is False + + # Assert password has been updated to the value in token cache + expected_props = {"user": "admin", "iam_default_port": "0"} + mock_dialect.set_password.assert_called_with(expected_props, _GENERATED_TOKEN) + + +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_pg_connect_with_invalid_port_and_no_host_port_fall_backs_to_host_port( + mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": "admin"}) + expected_default_pg_port = 5432 + invalid_port = "0" + test_props[WrapperProperties.IAM_DEFAULT_PORT.name] = invalid_port + + # Assert no password has been set + assert test_props.get("password") is None + + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION + ) + + actual_token = _token_cache.get( + f"{_PG_REGION}:{_PG_HOST_INFO.url}:{expected_default_pg_port}:admin") + assert _GENERATED_TOKEN == actual_token.token + assert actual_token.is_expired() is False + + # Assert password has been updated to the value in token cache + expected_props = {"user": "admin", "iam_default_port": "0"} + mock_dialect.set_password.assert_called_with(expected_props, _GENERATED_TOKEN) + + +@pytest.mark.parametrize("user", [ + pytest.param("admin"), + pytest.param("non-admin"), +]) +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_expired_token_in_cache(user, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": user}) + cache_key, initial_token = set_token_cache(user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, _PG_REGION) + + mock_func.side_effect = Exception("generic exception") + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + with pytest.raises(Exception): + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + actual_token = _token_cache.get(cache_key) + assert initial_token != actual_token + assert actual_token.is_expired() is False + + if user == "admin": + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION) + assert _GENERATED_TOKEN == actual_token.token + else: + mock_client.generate_db_connect_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION) + assert _GENERATED_TOKEN_NON_ADMIN == actual_token.token + + +@pytest.mark.parametrize("user", [ + pytest.param("admin"), + pytest.param("non-admin"), +]) +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_empty_cache(user, mocker, mock_plugin_service, mock_connection, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": user}) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + actual_connection = target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + cache_key: str = IamAuthUtils.get_cache_key( + user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, _PG_REGION + ) + actual_token = _token_cache.get(cache_key) + + if user == "admin": + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION + ) + assert _GENERATED_TOKEN == actual_token.token + else: + mock_client.generate_db_connect_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION) + assert _GENERATED_TOKEN_NON_ADMIN == actual_token.token + + assert mock_connection == actual_connection + assert actual_token.is_expired() is False + + +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_with_specified_port(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": "admin"}) + cache_key_with_new_port: str = f"{_PG_REGION}:{_PG_HOST_INFO.url}:1234:admin" + initial_token = TokenInfo(f"{_TEST_TOKEN}:1234", datetime.now() + timedelta(minutes=5)) + _token_cache[cache_key_with_new_port] = initial_token + + # Assert no password has been set + assert test_props.get("password") is None + + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO_WITH_PORT, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_connect_admin_auth_token.assert_not_called() + + actual_token = _token_cache.get(cache_key_with_new_port) + assert _token_cache.get(_PG_CACHE_KEY) is None + assert _GENERATED_TOKEN != actual_token.token + assert f"{_TEST_TOKEN}:1234" == actual_token.token + assert actual_token.is_expired() is False + + # Assert password has been updated to the value in token cache + expected_props = {"user": "admin"} + mock_dialect.set_password.assert_called_with(expected_props, f"{_TEST_TOKEN}:1234") + + +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_with_specified_iam_default_port(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": "admin"}) + iam_default_port: str = "9999" + test_props[WrapperProperties.IAM_DEFAULT_PORT.name] = iam_default_port + cache_key_with_new_port = f"{_PG_REGION}:{_PG_HOST_INFO.url}:{iam_default_port}:admin" + initial_token = TokenInfo(f"{_TEST_TOKEN}:{iam_default_port}", datetime.now() + timedelta(minutes=5)) + _token_cache[cache_key_with_new_port] = initial_token + + # Assert no password has been set + assert test_props.get("password") is None + + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO_WITH_PORT, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_connect_admin_auth_token.assert_not_called() + + actual_token = _token_cache.get(cache_key_with_new_port) + assert _token_cache.get(_PG_CACHE_KEY) is None + assert _GENERATED_TOKEN != actual_token.token + assert f"{_TEST_TOKEN}:{iam_default_port}" == actual_token.token + assert actual_token.is_expired() is False + + # Assert password has been updated to the value in token cache + expected_props = {"user": "admin", "iam_default_port": "9999"} + mock_dialect.set_password.assert_called_with(expected_props, f"{_TEST_TOKEN}:{iam_default_port}") + + +@pytest.mark.parametrize("user", [ + pytest.param("admin"), + pytest.param("non-admin"), +]) +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_with_specified_region(user, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": user}) + iam_region: str = "us-east-1" + + # Cache a token with a different region + set_token_cache(user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, _PG_REGION) + test_props[WrapperProperties.IAM_REGION.name] = iam_region + + # Assert no password has been set + assert test_props.get("password") is None + + mock_client.generate_db_connect_admin_auth_token.return_value = f"{_TEST_TOKEN}:{iam_region}" + mock_client.generate_db_connect_auth_token.return_value = f"{_GENERATED_TOKEN_NON_ADMIN}:{iam_region}" + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=HostInfo(_PG_HOST_INFO.url), + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_session.client.assert_called_with( + "dsql", + region_name=iam_region + ) + + expected_props = {"iam_region": "us-east-1", "user": user} + actual_token = _token_cache.get(IamAuthUtils.get_cache_key(user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, iam_region)) + assert actual_token.is_expired() is False + + if user == "admin": + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + _PG_HOST_INFO.url, iam_region + ) + assert f"{_TEST_TOKEN}:{iam_region}" == actual_token.token + mock_dialect.set_password.assert_called_with(expected_props, f"{_TEST_TOKEN}:{iam_region}") + else: + mock_client.generate_db_connect_auth_token.assert_called_with( + _PG_HOST_INFO.url, iam_region) + assert f"{_GENERATED_TOKEN_NON_ADMIN}:{iam_region}" == actual_token.token + mock_dialect.set_password.assert_called_with(expected_props, f"{_GENERATED_TOKEN_NON_ADMIN}:{iam_region}") + + +@pytest.mark.parametrize("iam_host", [ + pytest.param("dsqltestclusternamefoobar1.dsql-gamma.us-east-2.on.aws"), + pytest.param("dsqltestclusternamefoobar2.dsql-gamma.us-east-2.on.aws"), +]) +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_with_specified_host(iam_host: str, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": "admin"}) + + test_props[WrapperProperties.IAM_HOST.name] = iam_host + + # Assert no password has been set + assert test_props.get("password") is None + + mock_client.generate_db_connect_admin_auth_token.return_value = f"{_TEST_TOKEN}:{iam_host}" + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=HostInfo("bar.foo.com"), + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + iam_host, _PG_REGION + ) + + actual_token = _token_cache.get(f"{_PG_REGION}:{iam_host}:5432:admin") + assert actual_token is not None + assert _GENERATED_TOKEN != actual_token.token + assert f"{_TEST_TOKEN}:{iam_host}" == actual_token.token + assert actual_token.is_expired() is False + + +def test_aws_supported_regions_url_exists(): + url = "https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html" + assert 200 == urllib.request.urlopen(url).getcode() + + +@pytest.mark.parametrize("host", [ + pytest.param("<>"), + pytest.param("#"), + pytest.param("'"), + pytest.param("\""), + pytest.param("%"), + pytest.param("^"), + pytest.param("https://foo.com/abc.html"), + pytest.param("foo.boo//"), + pytest.param("8.8.8.8"), + pytest.param("a.b"), +]) +def test_invalid_iam_host(host, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": "admin"}) + with pytest.raises(AwsWrapperError): + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, DSQLTokenUtils(), mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=HostInfo(host), + props=test_props, + is_initial_connection=False, + connect_func=mock_func) From c52315b678ddc435cf54bc96b92b12b3b33cea6d Mon Sep 17 00:00:00 2001 From: leszekamz Date: Mon, 21 Jul 2025 21:02:49 -0700 Subject: [PATCH 07/12] sorting imports - isort --- tests/unit/test_iam_dsql_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_iam_dsql_plugin.py b/tests/unit/test_iam_dsql_plugin.py index 3ebff4c8..c20b7379 100644 --- a/tests/unit/test_iam_dsql_plugin.py +++ b/tests/unit/test_iam_dsql_plugin.py @@ -24,10 +24,10 @@ from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin, TokenInfo +from aws_advanced_python_wrapper.utils.dsql_token_utils import DSQLTokenUtils +from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) -from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils -from aws_advanced_python_wrapper.utils.dsql_token_utils import DSQLTokenUtils _GENERATED_TOKEN = "generated_token admin" _GENERATED_TOKEN_NON_ADMIN = "generated_token non-admin" From 0cc92c5188ee8cf15f271893b8cb39ead1dd9fec Mon Sep 17 00:00:00 2001 From: leszekamz Date: Mon, 21 Jul 2025 23:38:00 -0700 Subject: [PATCH 08/12] adding iam_dsql plugin documentation --- docs/README.md | 1 + .../UsingTheDSQLIamAuthenticationPlugin.md | 37 +++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md diff --git a/docs/README.md b/docs/README.md index ef76cae1..ab7a88c9 100644 --- a/docs/README.md +++ b/docs/README.md @@ -17,6 +17,7 @@ - [Aurora Initial Connection Strategy Plugin](./using-the-python-driver/using-plugins/UsingTheAuroraInitialConnectionStrategyPlugin.md) - [Host Availability Strategy](./using-the-python-driver/HostAvailabilityStrategy.md) - [IAM Authentication Plugin](./using-the-python-driver/using-plugins/UsingTheIamAuthenticationPlugin.md) + - [DSQL IAM Authentication Plugin](./using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md) - [AWS Secrets Manager Plugin](./using-the-python-driver/using-plugins/UsingTheAwsSecretsManagerPlugin.md) - [Federated Authentication Plugin](./using-the-python-driver/using-plugins/UsingTheFederatedAuthenticationPlugin.md) - [Read Write Splitting Plugin](./using-the-python-driver/using-plugins/UsingTheReadWriteSplittingPlugin.md) diff --git a/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md b/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md new file mode 100644 index 00000000..4defef79 --- /dev/null +++ b/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md @@ -0,0 +1,37 @@ +# AWS Aurora DSQL IAM Authentication Plugin + +This plugin enables connecting to AWS Aurora DSQL databases through AWS Identity and Access Management (IAM). + +## What is IAM? +AWS Identity and Access Management (IAM) grants users access control across all Amazon Web Services. IAM supports granular permissions, giving you the ability to grant different permissions to different users. For more information on IAM and its use cases, please refer to the [IAM documentation](https://docs.aws.amazon.com/IAM/latest/UserGuide/introduction.html). + +## Prerequisites +> [!WARNING]\ +> To preserve compatibility with customers using the community driver, IAM Authentication requires the AWS SDK for Python; [Boto3](https://pypi.org/project/boto3/). Boto3 is a runtime dependency and must be resolved. It can be installed via pip like so: `pip install boto3`. + +The IAM Authentication plugin requires authentication via AWS Credentials. These credentials can be defined in `~/.aws/credentials` or set as environment variables. All users must set `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`. Users who are using temporary security credentials will also need to additionally set `AWS_SESSION_TOKEN`. + +To enable the AWS Aurora DSQL IAM Authentication Plugin, add the plugin code `iam_dsql` to the [`plugins`](../UsingThePythonDriver.md#connection-plugin-manager-parameters) parameter. + +> [!WARNING]\ +> The `iam` plugin must NOT be specified when using the `iam_dsql` plugin. + +## AWS IAM Database Authentication +The AWS Python Driver supports Amazon AWS Identity and Access Management (IAM) authentication. When using AWS IAM database authentication, the host URL must be a valid AWS Aurora DSQL endpoint, and not a custom domain or an IP address. +
i.e. `cluster-identifier.dsql.us-east-1.on.aws` + + +## How do I use IAM with the AWS Python Driver? +1. Configure IAM roles for the cluster according to [Using database roles and IAM authentication](https://docs.aws.amazon.com/aurora-dsql/latest/userguide/using-database-and-iam-roles.html). +2. Add the plugin code `iam_dsql` to the [`plugins`](../UsingThePythonDriver.md#connection-plugin-manager-parameters) parameter value. + +| Parameter | Value | Required | Description | Example Value | +|--------------------|:-------:|:--------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------| +| `iam_host` | String | No | This property will override the default hostname that is used to generate the IAM token. The default hostname is derived from the connection string. This parameter is required when users are connecting with custom endpoints. | `cluster-identifier.dsql.us-east-1.on.aws` | +| `iam_region` | String | No | This property will override the default region that is used to generate the IAM token. The default region is parsed from the connection string. | `us-east-2` | +| `iam_expiration` | Integer | No | This property determines how long an IAM token is kept in the driver cache before a new one is generated. The default expiration time is set to 14 minutes and 30 seconds. Note that IAM database authentication tokens have a lifetime of 15 minutes. | `600` | + +## Sample code + +[DSQLIamAuthentication.py](../../examples/DSQLIamAuthentication.py) + From b1f2180920df2efd2c058de96c44a7748abdbfb3 Mon Sep 17 00:00:00 2001 From: leszekamz Date: Tue, 22 Jul 2025 08:42:05 -0700 Subject: [PATCH 09/12] updated readme.md --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index 2f7122f7..e7123796 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,12 @@ Since a database failover is usually identified by reaching a network or a conne Enhanced Failure Monitoring (EFM) is a feature available from the [Host Monitoring Connection Plugin](./docs/using-the-python-driver/using-plugins/UsingTheHostMonitoringPlugin.md#enhanced-failure-monitoring) that periodically checks the connected database host's health and availability. If a database host is determined to be unhealthy, the connection is aborted (and potentially routed to another healthy host in the cluster). +### Using the AWS Advanced Python Driver with AWS Aurora DSQL +The AWS Advanced Python Driver is able to handle IAM authentication when working with AWS Aurora DSQL clusters. + +Please visit [this page](./docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md) for more information. + + ### Using the AWS Advanced Python Driver with plain RDS databases The AWS Advanced Python Driver also works with RDS provided databases that are not Aurora. From cf805a8225546436d04fb3e4680ad500af9f28d2 Mon Sep 17 00:00:00 2001 From: leszekamz Date: Wed, 23 Jul 2025 14:56:05 -0700 Subject: [PATCH 10/12] addressing code review comments --- aws_advanced_python_wrapper/federated_plugin.py | 3 ++- aws_advanced_python_wrapper/okta_plugin.py | 3 ++- .../aws_advanced_python_wrapper_messages.properties | 3 ++- aws_advanced_python_wrapper/utils/dsql_token_utils.py | 5 +++-- aws_advanced_python_wrapper/utils/rds_token_utils.py | 2 +- .../using-plugins/UsingTheDSQLIamAuthenticationPlugin.md | 6 ++++-- .../using-plugins/UsingTheIamAuthenticationPlugin.md | 3 +++ 7 files changed, 17 insertions(+), 8 deletions(-) diff --git a/aws_advanced_python_wrapper/federated_plugin.py b/aws_advanced_python_wrapper/federated_plugin.py index e6f1038e..186f6d16 100644 --- a/aws_advanced_python_wrapper/federated_plugin.py +++ b/aws_advanced_python_wrapper/federated_plugin.py @@ -57,7 +57,8 @@ class FederatedAuthPlugin(Plugin): _rds_utils: RdsUtils = RdsUtils() _token_cache: Dict[str, TokenInfo] = {} - def __init__(self, plugin_service: PluginService, + def __init__(self, + plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, token_utils: TokenUtils, session: Optional[Session] = None): diff --git a/aws_advanced_python_wrapper/okta_plugin.py b/aws_advanced_python_wrapper/okta_plugin.py index bd46e2e6..d1e9e19e 100644 --- a/aws_advanced_python_wrapper/okta_plugin.py +++ b/aws_advanced_python_wrapper/okta_plugin.py @@ -53,7 +53,8 @@ class OktaAuthPlugin(Plugin): _rds_utils: RdsUtils = RdsUtils() _token_cache: Dict[str, TokenInfo] = {} - def __init__(self, plugin_service: PluginService, + def __init__(self, + plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, token_utils: TokenUtils, session: Optional[Session] = None): diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index a7c37c48..56d13449 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -151,7 +151,6 @@ IamAuthPlugin.UnhandledException=[IamAuthPlugin] Unhandled exception: {} IamAuthPlugin.UseCachedIamToken=[IamAuthPlugin] Used cached IAM token = {} IamAuthPlugin.InvalidHost=[IamAuthPlugin] Invalid IAM host {}. The IAM host must be a valid RDS or Aurora endpoint. IamAuthPlugin.IsNoneOrEmpty=[IamAuthPlugin] Property "{}" is None or empty. -IamAuthUtils.GeneratedNewAuthToken=Generated new authentication token = {} LimitlessPlugin.FailedToConnectToHost=[LimitlessPlugin] Failed to connect to host {}. LimitlessPlugin.UnsupportedDialectOrDatabase=[LimitlessPlugin] Unsupported dialect '{}' encountered. Please ensure the connection parameters are correct, and refer to the documentation to ensure that the connecting database is compatible with the Limitless Connection Plugin. @@ -316,6 +315,8 @@ RoundRobinHostSelector.ClusterInfoNone=[RoundRobinHostSelector] The round robin RoundRobinHostSelector.RoundRobinInvalidDefaultWeight=[RoundRobinHostSelector] The provided default weight value is not valid. Weight values must be an integer greater than or equal to 1. RoundRobinHostSelector.RoundRobinInvalidHostWeightPairs= [RoundRobinHostSelector] The provided host weight pairs have not been configured correctly. Please ensure the provided host weight pairs is a comma separated list of pairs, each pair in the format of :. Weight values must be an integer greater than or equal to the default weight value of 1. Weight pair: '{}' +TokenUtils.GeneratedNewAuthTokenLength=Generated new authentication token length = {} + WeightedRandomHostSelector.WeightedRandomInvalidHostWeightPairs= [WeightedRandomHostSelector] The provided host weight pairs have not been configured correctly. Please ensure the provided host weight pairs is a comma separated list of pairs, each pair in the format of :. Weight values must be an integer greater than or equal to the default weight value of 1. Weight pair: '{}' WeightedRandomHostSelector.WeightedRandomInvalidDefaultWeight=[WeightedRandomHostSelector] The provided default weight value is not valid. Weight values must be an integer greater than or equal to 1. diff --git a/aws_advanced_python_wrapper/utils/dsql_token_utils.py b/aws_advanced_python_wrapper/utils/dsql_token_utils.py index 47f0f7a7..aa61e02e 100644 --- a/aws_advanced_python_wrapper/utils/dsql_token_utils.py +++ b/aws_advanced_python_wrapper/utils/dsql_token_utils.py @@ -41,7 +41,7 @@ def generate_authentication_token( credentials: Optional[Dict[str, str]] = None, client_session: Optional[Session] = None) -> str: telemetry_factory = plugin_service.get_telemetry_factory() - context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED) + context = telemetry_factory.open_telemetry_context("fetch DSQL authentication token", TelemetryTraceLevel.NESTED) try: session = client_session if client_session else boto3.Session() @@ -64,7 +64,8 @@ def generate_authentication_token( else: token = client.generate_db_connect_auth_token(host_name, region) - logger.debug("IamAuthUtils.GeneratedNewAuthToken", token) + logger.debug("TokenUtils.GeneratedNewAuthTokenLength", len(token) if token else 0) + client.close() return token except Exception as ex: context.set_success(False) diff --git a/aws_advanced_python_wrapper/utils/rds_token_utils.py b/aws_advanced_python_wrapper/utils/rds_token_utils.py index 3d6c0dea..3497cc65 100644 --- a/aws_advanced_python_wrapper/utils/rds_token_utils.py +++ b/aws_advanced_python_wrapper/utils/rds_token_utils.py @@ -69,7 +69,7 @@ def generate_authentication_token( client.close() - logger.debug("IamAuthUtils.GeneratedNewAuthToken", token) + logger.debug("TokenUtils.GeneratedNewAuthTokenLength", len(token) if token else 0) return token except Exception as ex: context.set_success(False) diff --git a/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md b/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md index 4defef79..6b790b1a 100644 --- a/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md +++ b/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md @@ -9,7 +9,7 @@ AWS Identity and Access Management (IAM) grants users access control across all > [!WARNING]\ > To preserve compatibility with customers using the community driver, IAM Authentication requires the AWS SDK for Python; [Boto3](https://pypi.org/project/boto3/). Boto3 is a runtime dependency and must be resolved. It can be installed via pip like so: `pip install boto3`. -The IAM Authentication plugin requires authentication via AWS Credentials. These credentials can be defined in `~/.aws/credentials` or set as environment variables. All users must set `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`. Users who are using temporary security credentials will also need to additionally set `AWS_SESSION_TOKEN`. +The DSQL IAM Authentication plugin requires authentication via AWS Credentials. These credentials can be defined in `~/.aws/credentials` or set as environment variables. All users must set `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`. Users who are using temporary security credentials will also need to additionally set `AWS_SESSION_TOKEN`. To enable the AWS Aurora DSQL IAM Authentication Plugin, add the plugin code `iam_dsql` to the [`plugins`](../UsingThePythonDriver.md#connection-plugin-manager-parameters) parameter. @@ -20,6 +20,8 @@ To enable the AWS Aurora DSQL IAM Authentication Plugin, add the plugin code `ia The AWS Python Driver supports Amazon AWS Identity and Access Management (IAM) authentication. When using AWS IAM database authentication, the host URL must be a valid AWS Aurora DSQL endpoint, and not a custom domain or an IP address.
i.e. `cluster-identifier.dsql.us-east-1.on.aws` +Connections established by the `iamDsql` plugin are beholden to the [Cluster quotas and database limits in Amazon Aurora DSQL](https://docs.aws.amazon.com/aurora-dsql/latest/userguide/CHAP_quotas.html). In particular, applications need to consider the maximum transaction duration, and maximum connection duration limits. Ensure connections are returned to the pool regularly, and not retained for long periods. + ## How do I use IAM with the AWS Python Driver? 1. Configure IAM roles for the cluster according to [Using database roles and IAM authentication](https://docs.aws.amazon.com/aurora-dsql/latest/userguide/using-database-and-iam-roles.html). @@ -28,7 +30,7 @@ The AWS Python Driver supports Amazon AWS Identity and Access Management (IAM) a | Parameter | Value | Required | Description | Example Value | |--------------------|:-------:|:--------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------| | `iam_host` | String | No | This property will override the default hostname that is used to generate the IAM token. The default hostname is derived from the connection string. This parameter is required when users are connecting with custom endpoints. | `cluster-identifier.dsql.us-east-1.on.aws` | -| `iam_region` | String | No | This property will override the default region that is used to generate the IAM token. The default region is parsed from the connection string. | `us-east-2` | +| `iam_region` | String | No | This property will override the default region that is used to generate the IAM token. The default region is parsed from the connection string where possible. Some connection string formats may not be supported, and the `iam_region` must be provided in these cases. | `us-east-2` | | `iam_expiration` | Integer | No | This property determines how long an IAM token is kept in the driver cache before a new one is generated. The default expiration time is set to 14 minutes and 30 seconds. Note that IAM database authentication tokens have a lifetime of 15 minutes. | `600` | ## Sample code diff --git a/docs/using-the-python-driver/using-plugins/UsingTheIamAuthenticationPlugin.md b/docs/using-the-python-driver/using-plugins/UsingTheIamAuthenticationPlugin.md index e879e2a6..a5fadfcc 100644 --- a/docs/using-the-python-driver/using-plugins/UsingTheIamAuthenticationPlugin.md +++ b/docs/using-the-python-driver/using-plugins/UsingTheIamAuthenticationPlugin.md @@ -11,6 +11,9 @@ The IAM Authentication plugin requires authentication via AWS Credentials. These To enable the IAM Authentication Connection Plugin, add the plugin code `iam` to the [`plugins`](../UsingThePythonDriver.md#connection-plugin-manager-parameters) parameter. +> [!WARNING]\ +> The `iam` plugin must NOT be specified when using the `iam_dsql` plugin. + ## AWS IAM Database Authentication The AWS Python Driver supports Amazon AWS Identity and Access Management (IAM) authentication. When using AWS IAM database authentication, the host URL must be a valid Amazon endpoint, and not a custom domain or an IP address.
i.e. `db-identifier.cluster-XYZ.us-east-2.rds.amazonaws.com` From e1503c9a5c9beaae7c0397fabf6978454825cc75 Mon Sep 17 00:00:00 2001 From: leszekamz Date: Wed, 23 Jul 2025 17:44:56 -0700 Subject: [PATCH 11/12] addressing more code review comments --- tests/unit/test_iam_dsql_plugin.py | 52 ++++++++---------------------- 1 file changed, 13 insertions(+), 39 deletions(-) diff --git a/tests/unit/test_iam_dsql_plugin.py b/tests/unit/test_iam_dsql_plugin.py index c20b7379..edd81d48 100644 --- a/tests/unit/test_iam_dsql_plugin.py +++ b/tests/unit/test_iam_dsql_plugin.py @@ -21,7 +21,6 @@ import pytest -from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin, TokenInfo from aws_advanced_python_wrapper.utils.dsql_token_utils import DSQLTokenUtils @@ -34,7 +33,7 @@ _TEST_TOKEN = "test_token" _DEFAULT_PG_PORT = 5432 -_PG_HOST_INFO = HostInfo("dsqltestclusternamefoobar1.dsql-gamma.us-east-2.on.aws") +_PG_HOST_INFO = HostInfo("dsqltestclusternamefoobar1.dsql.us-east-2.on.aws") _PG_HOST_INFO_WITH_PORT = HostInfo(_PG_HOST_INFO.url, port=1234) _PG_REGION = "us-east-2" @@ -137,7 +136,7 @@ def test_pg_connect_valid_token_in_cache(user, mocker, mock_plugin_service, mock assert _GENERATED_TOKEN_NON_ADMIN != actual_token.token assert _TEST_TOKEN == actual_token.token - assert actual_token.is_expired() is False + assert not actual_token.is_expired() @patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) @@ -167,7 +166,7 @@ def test_pg_connect_with_invalid_port_fall_backs_to_host_port( actual_token = _token_cache.get(f"{_PG_REGION}:{_PG_HOST_INFO.url}:1234:admin") assert _GENERATED_TOKEN == actual_token.token - assert actual_token.is_expired() is False + assert not actual_token.is_expired() # Assert password has been updated to the value in token cache expected_props = {"user": "admin", "iam_default_port": "0"} @@ -203,7 +202,7 @@ def test_pg_connect_with_invalid_port_and_no_host_port_fall_backs_to_host_port( actual_token = _token_cache.get( f"{_PG_REGION}:{_PG_HOST_INFO.url}:{expected_default_pg_port}:admin") assert _GENERATED_TOKEN == actual_token.token - assert actual_token.is_expired() is False + assert not actual_token.is_expired() # Assert password has been updated to the value in token cache expected_props = {"user": "admin", "iam_default_port": "0"} @@ -217,7 +216,7 @@ def test_pg_connect_with_invalid_port_and_no_host_port_fall_backs_to_host_port( @patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) def test_connect_expired_token_in_cache(user, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): test_props: Properties = Properties({"user": user}) - cache_key, initial_token = set_token_cache(user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, _PG_REGION) + cache_key, initial_token = set_token_cache(user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, _PG_REGION, True) mock_func.side_effect = Exception("generic exception") target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, @@ -234,7 +233,7 @@ def test_connect_expired_token_in_cache(user, mocker, mock_plugin_service, mock_ actual_token = _token_cache.get(cache_key) assert initial_token != actual_token - assert actual_token.is_expired() is False + assert not actual_token.is_expired() if user == "admin": mock_client.generate_db_connect_admin_auth_token.assert_called_with( @@ -280,7 +279,7 @@ def test_connect_empty_cache(user, mocker, mock_plugin_service, mock_connection, assert _GENERATED_TOKEN_NON_ADMIN == actual_token.token assert mock_connection == actual_connection - assert actual_token.is_expired() is False + assert not actual_token.is_expired() @patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) @@ -310,7 +309,7 @@ def test_connect_with_specified_port(mocker, mock_plugin_service, mock_session, assert _token_cache.get(_PG_CACHE_KEY) is None assert _GENERATED_TOKEN != actual_token.token assert f"{_TEST_TOKEN}:1234" == actual_token.token - assert actual_token.is_expired() is False + assert not actual_token.is_expired() # Assert password has been updated to the value in token cache expected_props = {"user": "admin"} @@ -346,7 +345,7 @@ def test_connect_with_specified_iam_default_port(mocker, mock_plugin_service, mo assert _token_cache.get(_PG_CACHE_KEY) is None assert _GENERATED_TOKEN != actual_token.token assert f"{_TEST_TOKEN}:{iam_default_port}" == actual_token.token - assert actual_token.is_expired() is False + assert not actual_token.is_expired() # Assert password has been updated to the value in token cache expected_props = {"user": "admin", "iam_default_port": "9999"} @@ -389,7 +388,7 @@ def test_connect_with_specified_region(user, mocker, mock_plugin_service, mock_s expected_props = {"iam_region": "us-east-1", "user": user} actual_token = _token_cache.get(IamAuthUtils.get_cache_key(user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, iam_region)) - assert actual_token.is_expired() is False + assert not actual_token.is_expired() if user == "admin": mock_client.generate_db_connect_admin_auth_token.assert_called_with( @@ -405,8 +404,8 @@ def test_connect_with_specified_region(user, mocker, mock_plugin_service, mock_s @pytest.mark.parametrize("iam_host", [ - pytest.param("dsqltestclusternamefoobar1.dsql-gamma.us-east-2.on.aws"), - pytest.param("dsqltestclusternamefoobar2.dsql-gamma.us-east-2.on.aws"), + pytest.param("dsqltestclusternamefoobar1.dsql.us-east-2.on.aws"), + pytest.param("dsqltestclusternamefoobar2.dsql.us-east-2.on.aws"), ]) @patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) def test_connect_with_specified_host(iam_host: str, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): @@ -437,34 +436,9 @@ def test_connect_with_specified_host(iam_host: str, mocker, mock_plugin_service, assert actual_token is not None assert _GENERATED_TOKEN != actual_token.token assert f"{_TEST_TOKEN}:{iam_host}" == actual_token.token - assert actual_token.is_expired() is False + assert not actual_token.is_expired() def test_aws_supported_regions_url_exists(): url = "https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html" assert 200 == urllib.request.urlopen(url).getcode() - - -@pytest.mark.parametrize("host", [ - pytest.param("<>"), - pytest.param("#"), - pytest.param("'"), - pytest.param("\""), - pytest.param("%"), - pytest.param("^"), - pytest.param("https://foo.com/abc.html"), - pytest.param("foo.boo//"), - pytest.param("8.8.8.8"), - pytest.param("a.b"), -]) -def test_invalid_iam_host(host, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): - test_props: Properties = Properties({"user": "admin"}) - with pytest.raises(AwsWrapperError): - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, DSQLTokenUtils(), mock_session) - target_plugin.connect( - target_driver_func=mocker.MagicMock(), - driver_dialect=mock_dialect, - host_info=HostInfo(host), - props=test_props, - is_initial_connection=False, - connect_func=mock_func) From ab49fbc8e514b055996674875ba745354d6e3e9d Mon Sep 17 00:00:00 2001 From: leszekamz Date: Fri, 25 Jul 2025 14:17:58 -0700 Subject: [PATCH 12/12] Added integration tests for dsql and updated relevant files --- .github/workflows/integration_tests.yml | 2 +- .../UsingTheDSQLIamAuthenticationPlugin.md | 2 +- .../container/test_aurora_failover.py | 4 +- .../integration/container/test_autoscaling.py | 3 +- .../container/test_basic_connectivity.py | 4 +- .../container/test_basic_functionality.py | 4 +- .../container/test_custom_endpoint.py | 4 +- .../container/test_iam_authentication.py | 4 +- .../container/test_iam_dsql_authentication.py | 118 ++++++++++++++ .../container/test_read_write_splitting.py | 4 +- .../utils/database_engine_deployment.py | 1 + .../utils/test_environment_features.py | 1 + tests/integration/host/build.gradle.kts | 96 +++++++++++ .../integration/DatabaseEngineDeployment.java | 3 +- .../integration/TestEnvironmentFeatures.java | 1 + .../integration/host/TestEnvironment.java | 151 ++++++++++++++++-- .../host/TestEnvironmentConfiguration.java | 2 + .../host/TestEnvironmentProvider.java | 11 +- .../integration/util/AuroraTestUtility.java | 110 ++++++++++++- 19 files changed, 495 insertions(+), 30 deletions(-) create mode 100644 tests/integration/container/test_iam_dsql_authentication.py diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index 937545b8..3bbc2572 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -19,7 +19,7 @@ jobs: matrix: python-version: [ "3.8", "3.11" ] engine-version: [ "lts", "latest"] - environment: ["mysql", "pg"] + environment: ["mysql", "pg", "dsql"] steps: - name: 'Clone repository' diff --git a/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md b/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md index 6b790b1a..87e36553 100644 --- a/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md +++ b/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md @@ -20,7 +20,7 @@ To enable the AWS Aurora DSQL IAM Authentication Plugin, add the plugin code `ia The AWS Python Driver supports Amazon AWS Identity and Access Management (IAM) authentication. When using AWS IAM database authentication, the host URL must be a valid AWS Aurora DSQL endpoint, and not a custom domain or an IP address.
i.e. `cluster-identifier.dsql.us-east-1.on.aws` -Connections established by the `iamDsql` plugin are beholden to the [Cluster quotas and database limits in Amazon Aurora DSQL](https://docs.aws.amazon.com/aurora-dsql/latest/userguide/CHAP_quotas.html). In particular, applications need to consider the maximum transaction duration, and maximum connection duration limits. Ensure connections are returned to the pool regularly, and not retained for long periods. +Connections established by the `iam_dsql` plugin are beholden to the [Cluster quotas and database limits in Amazon Aurora DSQL](https://docs.aws.amazon.com/aurora-dsql/latest/userguide/CHAP_quotas.html). In particular, applications need to consider the maximum transaction duration, and maximum connection duration limits. Ensure connections are returned to the pool regularly, and not retained for long periods. ## How do I use IAM with the AWS Python Driver? diff --git a/tests/integration/container/test_aurora_failover.py b/tests/integration/container/test_aurora_failover.py index 1a303e33..f82dcaa3 100644 --- a/tests/integration/container/test_aurora_failover.py +++ b/tests/integration/container/test_aurora_failover.py @@ -42,7 +42,9 @@ @enable_on_num_instances(min_instances=2) @enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ]) -@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.PERFORMANCE]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestAuroraFailover: IDLE_CONNECTIONS_NUM: int = 5 logger = Logger(__name__) diff --git a/tests/integration/container/test_autoscaling.py b/tests/integration/container/test_autoscaling.py index 61c9e7d2..ba328e82 100644 --- a/tests/integration/container/test_autoscaling.py +++ b/tests/integration/container/test_autoscaling.py @@ -42,7 +42,8 @@ @enable_on_num_instances(min_instances=5) -@enable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY]) +@enable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestAutoScaling: @pytest.fixture def rds_utils(self): diff --git a/tests/integration/container/test_basic_connectivity.py b/tests/integration/container/test_basic_connectivity.py index ef03614e..bf56f14b 100644 --- a/tests/integration/container/test_basic_connectivity.py +++ b/tests/integration/container/test_basic_connectivity.py @@ -36,7 +36,9 @@ from .utils.test_environment_features import TestEnvironmentFeatures -@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.PERFORMANCE]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestBasicConnectivity: @pytest.fixture(scope='class') diff --git a/tests/integration/container/test_basic_functionality.py b/tests/integration/container/test_basic_functionality.py index 34f66c62..193f2d39 100644 --- a/tests/integration/container/test_basic_functionality.py +++ b/tests/integration/container/test_basic_functionality.py @@ -46,7 +46,9 @@ from .utils.test_environment_features import TestEnvironmentFeatures -@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.PERFORMANCE]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestBasicFunctionality: @pytest.fixture(scope='class') diff --git a/tests/integration/container/test_custom_endpoint.py b/tests/integration/container/test_custom_endpoint.py index ee33bcfa..dbda019d 100644 --- a/tests/integration/container/test_custom_endpoint.py +++ b/tests/integration/container/test_custom_endpoint.py @@ -45,7 +45,9 @@ @enable_on_num_instances(min_instances=3) @enable_on_deployments([DatabaseEngineDeployment.AURORA]) -@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.PERFORMANCE]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestCustomEndpoint: logger: ClassVar[Logger] = Logger(__name__) endpoint_id: ClassVar[str] = f"test-endpoint-1-{uuid4()}" diff --git a/tests/integration/container/test_iam_authentication.py b/tests/integration/container/test_iam_authentication.py index 0e4e2e01..9ef98f21 100644 --- a/tests/integration/container/test_iam_authentication.py +++ b/tests/integration/container/test_iam_authentication.py @@ -39,7 +39,9 @@ @enable_on_features([TestEnvironmentFeatures.IAM]) -@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.PERFORMANCE]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestAwsIamAuthentication: @pytest.fixture(scope='class') diff --git a/tests/integration/container/test_iam_dsql_authentication.py b/tests/integration/container/test_iam_dsql_authentication.py new file mode 100644 index 00000000..3c4ce39c --- /dev/null +++ b/tests/integration/container/test_iam_dsql_authentication.py @@ -0,0 +1,118 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures + +if TYPE_CHECKING: + from tests.integration.container.utils.test_driver import TestDriver + +from socket import gethostbyname +from typing import Callable + +import pytest + +from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper.errors import AwsWrapperError +from tests.integration.container.utils.conditions import enable_on_features +from tests.integration.container.utils.driver_helper import DriverHelper +from tests.integration.container.utils.test_environment import TestEnvironment + + +@enable_on_features([TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) +class TestAwsIamDSQLAuthentication: + + @pytest.fixture(scope='class') + def props(self): + p: Properties = Properties() + + if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in TestEnvironment.get_current().get_features() \ + or TestEnvironmentFeatures.TELEMETRY_METRICS_ENABLED in TestEnvironment.get_current().get_features(): + WrapperProperties.ENABLE_TELEMETRY.set(p, "True") + WrapperProperties.TELEMETRY_SUBMIT_TOPLEVEL.set(p, "True") + + if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in TestEnvironment.get_current().get_features(): + WrapperProperties.TELEMETRY_TRACES_BACKEND.set(p, "XRAY") + + if TestEnvironmentFeatures.TELEMETRY_METRICS_ENABLED in TestEnvironment.get_current().get_features(): + WrapperProperties.TELEMETRY_METRICS_BACKEND.set(p, "OTLP") + + return p + + def test_iam_wrong_database_username(self, test_environment: TestEnvironment, + test_driver: TestDriver, conn_utils, props): + target_driver_connect = DriverHelper.get_connect_func(test_driver) + user = f"WRONG_{conn_utils.iam_user}_USER" + params = conn_utils.get_connect_params(user=user) + params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver + + with pytest.raises(AwsWrapperError): + AwsWrapperConnection.connect( + target_driver_connect, + **params, + plugins="iam_dsql", + **props) + + def test_iam_no_database_username(self, test_driver: TestDriver, conn_utils, props): + target_driver_connect = DriverHelper.get_connect_func(test_driver) + params = conn_utils.get_connect_params() + params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver + params.pop("user", None) + + with pytest.raises(AwsWrapperError): + AwsWrapperConnection.connect(target_driver_connect, **params, plugins="iam_dsql", **props) + + def test_iam_invalid_host(self, test_driver: TestDriver, conn_utils, props): + target_driver_connect = DriverHelper.get_connect_func(test_driver) + params = conn_utils.get_connect_params() + params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver + params.update({"iam_host": "<>", "plugins": "iam_dsql"}) + + with pytest.raises(AwsWrapperError): + AwsWrapperConnection.connect(target_driver_connect, **params, **props) + + def test_iam_valid_connection_properties( + self, test_environment: TestEnvironment, test_driver: TestDriver, conn_utils, props): + target_driver_connect = DriverHelper.get_connect_func(test_driver) + params = conn_utils.get_connect_params(user=conn_utils.iam_user, password="") + params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver + params["plugins"] = "iam_dsql" + + self.validate_connection(target_driver_connect, **params, **props) + + def test_iam_valid_connection_properties_no_password( + self, test_environment: TestEnvironment, test_driver: TestDriver, conn_utils, props): + target_driver_connect = DriverHelper.get_connect_func(test_driver) + params = conn_utils.get_connect_params(user=conn_utils.iam_user) + params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver + params.pop("password", None) + params["plugins"] = "iam_dsql" + + self.validate_connection(target_driver_connect, **params, **props) + + def get_ip_address(self, hostname: str): + return gethostbyname(hostname) + + def validate_connection(self, target_driver_connect: Callable, **connect_params): + with AwsWrapperConnection.connect(target_driver_connect, **connect_params) as conn, \ + conn.cursor() as cursor: + cursor.execute("SELECT now()") + records = cursor.fetchall() + assert len(records) == 1 diff --git a/tests/integration/container/test_read_write_splitting.py b/tests/integration/container/test_read_write_splitting.py index 86ad049a..6e352419 100644 --- a/tests/integration/container/test_read_write_splitting.py +++ b/tests/integration/container/test_read_write_splitting.py @@ -43,7 +43,9 @@ @enable_on_num_instances(min_instances=2) @enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ]) -@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.PERFORMANCE]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestReadWriteSplitting: @pytest.fixture(scope='class') def rds_utils(self): diff --git a/tests/integration/container/utils/database_engine_deployment.py b/tests/integration/container/utils/database_engine_deployment.py index 3f542745..8bf401a3 100644 --- a/tests/integration/container/utils/database_engine_deployment.py +++ b/tests/integration/container/utils/database_engine_deployment.py @@ -20,3 +20,4 @@ class DatabaseEngineDeployment(str, Enum): RDS = "RDS" RDS_MULTI_AZ = "RDS_MULTI_AZ" AURORA = "AURORA" + DSQL = "DSQL" diff --git a/tests/integration/container/utils/test_environment_features.py b/tests/integration/container/utils/test_environment_features.py index dfbb7fd9..62c5b4a4 100644 --- a/tests/integration/container/utils/test_environment_features.py +++ b/tests/integration/container/utils/test_environment_features.py @@ -26,6 +26,7 @@ class TestEnvironmentFeatures(Enum): AWS_CREDENTIALS_ENABLED = "AWS_CREDENTIALS_ENABLED" PERFORMANCE = "PERFORMANCE" RUN_AUTOSCALING_TESTS_ONLY = "RUN_AUTOSCALING_TESTS_ONLY" + RUN_DSQL_TESTS_ONLY = "RUN_DSQL_TESTS_ONLY" SKIP_MYSQL_DRIVER_TESTS = "SKIP_MYSQL_DRIVER_TESTS" SKIP_PG_DRIVER_TESTS = "SKIP_PG_DRIVER_TESTS" TELEMETRY_TRACES_ENABLED = "TELEMETRY_TRACES_ENABLED" diff --git a/tests/integration/host/build.gradle.kts b/tests/integration/host/build.gradle.kts index 7351d814..112cde47 100644 --- a/tests/integration/host/build.gradle.kts +++ b/tests/integration/host/build.gradle.kts @@ -30,6 +30,7 @@ dependencies { testImplementation("software.amazon.awssdk:rds:2.20.49") testImplementation("software.amazon.awssdk:ec2:2.20.61") testImplementation("software.amazon.awssdk:secretsmanager:2.20.49") + testImplementation("software.amazon.awssdk:dsql:2.29.34") // Note: all org.testcontainers dependencies should have the same version testImplementation("org.testcontainers:testcontainers:1.21.2") testImplementation("org.testcontainers:mysql:1.21.2") @@ -72,6 +73,7 @@ tasks.register("test-python-3.11-mysql") { systemProperty("exclude-multi-az", "true") systemProperty("exclude-pg-driver", "true") systemProperty("exclude-pg-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -84,6 +86,7 @@ tasks.register("test-python-3.8-mysql") { systemProperty("exclude-multi-az", "true") systemProperty("exclude-pg-driver", "true") systemProperty("exclude-pg-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -98,6 +101,7 @@ tasks.register("test-python-3.11-pg") { systemProperty("exclude-mysql-engine", "true") systemProperty("exclude-mariadb-driver", "true") systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -112,6 +116,43 @@ tasks.register("test-python-3.8-pg") { systemProperty("exclude-mysql-engine", "true") systemProperty("exclude-mariadb-driver", "true") systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-dsql", "true") + } +} + +tasks.register("test-python-3.11-dsql") { + group = "verification" + filter.includeTestsMatching("integration.host.TestRunner.runTests") + doFirst { + systemProperty("exclude-aurora", "true") + systemProperty("exclude-python-38", "true") + systemProperty("exclude-autoscaling", "true") + systemProperty("exclude-docker", "true") + systemProperty("exclude-multi-az", "true") + systemProperty("exclude-mysql-driver", "true") + systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-mariadb-driver", "true") + systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-performance", "true") + systemProperty("exclude-secrets-manager", "true") + } +} + +tasks.register("test-python-3.8-dsql") { + group = "verification" + filter.includeTestsMatching("integration.host.TestRunner.runTests") + doFirst { + systemProperty("exclude-aurora", "true") + systemProperty("exclude-python-311", "true") + systemProperty("exclude-autoscaling", "true") + systemProperty("exclude-docker", "true") + systemProperty("exclude-multi-az", "true") + systemProperty("exclude-mysql-driver", "true") + systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-mariadb-driver", "true") + systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-performance", "true") + systemProperty("exclude-secrets-manager", "true") } } @@ -123,6 +164,7 @@ tasks.register("test-docker") { systemProperty("exclude-multi-az", "true") systemProperty("exclude-performance", "true") systemProperty("exclude-python-38", "true") + systemProperty("exclude-dsql", "true") } } @@ -134,6 +176,7 @@ tasks.register("test-aurora") { systemProperty("exclude-multi-az", "true") systemProperty("exclude-performance", "true") systemProperty("exclude-python-38", "true") + systemProperty("exclude-dsql", "true") } } @@ -148,6 +191,7 @@ tasks.register("test-pg-aurora") { systemProperty("exclude-mysql-engine", "true") systemProperty("exclude-mariadb-driver", "true") systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -160,6 +204,7 @@ tasks.register("test-mysql-aurora") { systemProperty("exclude-performance", "true") systemProperty("exclude-pg-driver", "true") systemProperty("exclude-pg-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -171,6 +216,7 @@ tasks.register("test-multi-az") { systemProperty("exclude-performance", "true") systemProperty("exclude-aurora", "true") systemProperty("exclude-python-38", "true") + systemProperty("exclude-dsql", "true") } } @@ -185,6 +231,7 @@ tasks.register("test-pg-multi-az") { systemProperty("exclude-mysql-engine", "true") systemProperty("exclude-mariadb-driver", "true") systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -197,6 +244,7 @@ tasks.register("test-mysql-multi-az") { systemProperty("exclude-aurora", "true") systemProperty("exclude-pg-driver", "true") systemProperty("exclude-pg-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -209,6 +257,7 @@ tasks.register("test-autoscaling") { systemProperty("exclude-performance", "true") systemProperty("exclude-mysql-driver", "true") systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -224,6 +273,7 @@ tasks.register("test-pg-aurora-performance") { systemProperty("exclude-mysql-engine", "true") systemProperty("exclude-mariadb-driver", "true") systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -237,6 +287,24 @@ tasks.register("test-mysql-aurora-performance") { systemProperty("exclude-secrets-manager", "true") systemProperty("exclude-pg-driver", "true") systemProperty("exclude-pg-engine", "true") + systemProperty("exclude-dsql", "true") + } +} + +tasks.register("test-all-dsql") { + group = "verification" + filter.includeTestsMatching("integration.host.TestRunner.runTests") + doFirst { + systemProperty("exclude-aurora", "true") + systemProperty("exclude-autoscaling", "true") + systemProperty("exclude-docker", "true") + systemProperty("exclude-multi-az", "true") + systemProperty("exclude-mysql-driver", "true") + systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-mariadb-driver", "true") + systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-performance", "true") + systemProperty("exclude-secrets-manager", "true") } } @@ -248,6 +316,7 @@ tasks.register("debug-all-environments") { doFirst { systemProperty("exclude-performance", "true") systemProperty("exclude-python-38", "true") + systemProperty("exclude-dsql", "true") } } @@ -259,6 +328,7 @@ tasks.register("debug-docker") { systemProperty("exclude-multi-az", "true") systemProperty("exclude-performance", "true") systemProperty("exclude-python-38", "true") + systemProperty("exclude-dsql", "true") } } @@ -270,6 +340,7 @@ tasks.register("debug-aurora") { systemProperty("exclude-multi-az", "true") systemProperty("exclude-performance", "true") systemProperty("exclude-python-38", "true") + systemProperty("exclude-dsql", "true") } } @@ -282,6 +353,7 @@ tasks.register("debug-pg-aurora") { systemProperty("exclude-performance", "true") systemProperty("exclude-mysql-driver", "true") systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -294,6 +366,7 @@ tasks.register("debug-mysql-aurora") { systemProperty("exclude-performance", "true") systemProperty("exclude-pg-driver", "true") systemProperty("exclude-pg-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -307,6 +380,7 @@ tasks.register("debug-autoscaling") { systemProperty("exclude-performance", "true") systemProperty("exclude-mysql-driver", "true") systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -322,6 +396,7 @@ tasks.register("debug-pg-aurora-performance") { systemProperty("exclude-mysql-engine", "true") systemProperty("exclude-mariadb-driver", "true") systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -335,6 +410,7 @@ tasks.register("debug-mysql-aurora-performance") { systemProperty("exclude-secrets-manager", "true") systemProperty("exclude-pg-driver", "true") systemProperty("exclude-pg-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -346,6 +422,7 @@ tasks.register("debug-multi-az") { systemProperty("exclude-aurora", "true") systemProperty("exclude-performance", "true") systemProperty("exclude-python-38", "true") + systemProperty("exclude-dsql", "true") } } @@ -358,6 +435,7 @@ tasks.register("debug-pg-multi-az") { systemProperty("exclude-performance", "true") systemProperty("exclude-mysql-driver", "true") systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -370,5 +448,23 @@ tasks.register("debug-mysql-multi-az") { systemProperty("exclude-performance", "true") systemProperty("exclude-pg-driver", "true") systemProperty("exclude-pg-engine", "true") + systemProperty("exclude-dsql", "true") + } +} + +tasks.register("debug-all-dsql") { + group = "verification" + filter.includeTestsMatching("integration.host.TestRunner.debugTests") + doFirst { + systemProperty("exclude-aurora", "true") + systemProperty("exclude-autoscaling", "true") + systemProperty("exclude-docker", "true") + systemProperty("exclude-multi-az", "true") + systemProperty("exclude-mysql-driver", "true") + systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-mariadb-driver", "true") + systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-performance", "true") + systemProperty("exclude-secrets-manager", "true") } } diff --git a/tests/integration/host/src/test/java/integration/DatabaseEngineDeployment.java b/tests/integration/host/src/test/java/integration/DatabaseEngineDeployment.java index 0126e0f2..3a16b5f0 100644 --- a/tests/integration/host/src/test/java/integration/DatabaseEngineDeployment.java +++ b/tests/integration/host/src/test/java/integration/DatabaseEngineDeployment.java @@ -20,5 +20,6 @@ public enum DatabaseEngineDeployment { DOCKER, RDS, RDS_MULTI_AZ, - AURORA + AURORA, + DSQL } diff --git a/tests/integration/host/src/test/java/integration/TestEnvironmentFeatures.java b/tests/integration/host/src/test/java/integration/TestEnvironmentFeatures.java index 6cf8514a..800f332a 100644 --- a/tests/integration/host/src/test/java/integration/TestEnvironmentFeatures.java +++ b/tests/integration/host/src/test/java/integration/TestEnvironmentFeatures.java @@ -25,6 +25,7 @@ public enum TestEnvironmentFeatures { AWS_CREDENTIALS_ENABLED, PERFORMANCE, RUN_AUTOSCALING_TESTS_ONLY, + RUN_DSQL_TESTS_ONLY, SKIP_MYSQL_DRIVER_TESTS, SKIP_PG_DRIVER_TESTS, TELEMETRY_TRACES_ENABLED, diff --git a/tests/integration/host/src/test/java/integration/host/TestEnvironment.java b/tests/integration/host/src/test/java/integration/host/TestEnvironment.java index 7cabd4b2..f0a96de5 100644 --- a/tests/integration/host/src/test/java/integration/host/TestEnvironment.java +++ b/tests/integration/host/src/test/java/integration/host/TestEnvironment.java @@ -41,6 +41,8 @@ import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -127,6 +129,7 @@ public static TestEnvironment build(TestEnvironmentRequest request) throws IOExc break; case AURORA: case RDS_MULTI_AZ: + case DSQL: env = createAuroraOrMultiAzEnvironment(request); authorizeIP(env); @@ -200,7 +203,12 @@ private static TestEnvironment createAuroraOrMultiAzEnvironment(TestEnvironmentR } else { TestEnvironment env = new TestEnvironment(request); initDatabaseParams(env); - createDbCluster(env); + if (request.getDatabaseEngineDeployment() == DatabaseEngineDeployment.DSQL) { + createDsqlCluster(env); + } + else { + createDbCluster(env); + } if (request.getFeatures().contains(TestEnvironmentFeatures.IAM)) { if (request.getDatabaseEngineDeployment() == DatabaseEngineDeployment.RDS_MULTI_AZ) { @@ -334,8 +342,11 @@ private static void createDbCluster(TestEnvironment env, int numOfInstances) thr ArrayList instances = new ArrayList<>(); if (env.reuseAuroraDbCluster) { + if (StringUtils.isNullOrEmpty(env.auroraClusterName)) { + throw new RuntimeException("Environment variable RDS_CLUSTER_NAME is required."); + } if (StringUtils.isNullOrEmpty(env.auroraClusterDomain)) { - throw new RuntimeException("Environment variable AURORA_CLUSTER_DOMAIN is required."); + throw new RuntimeException("Environment variable RDS_CLUSTER_DOMAIN is required."); } if (!env.auroraUtil.doesClusterExist(env.auroraClusterName)) { @@ -439,6 +450,88 @@ private static void createDbCluster(TestEnvironment env, int numOfInstances) thr } } + + + private static void createDsqlCluster(TestEnvironment env) throws URISyntaxException { + + initAwsCredentials(env); + + env.info.setRegion( + !StringUtils.isNullOrEmpty(config.rdsDbRegion) + ? config.rdsDbRegion + : "us-east-2"); + + env.reuseAuroraDbCluster = config.reuseRdsCluster; + env.auroraClusterName = config.rdsClusterName; // "cluster-mysql" + env.auroraClusterDomain = config.rdsClusterDomain; // "XYZ.us-west-2.rds.amazonaws.com" + env.rdsEndpoint = config.rdsEndpoint; // "https://rds-int.amazon.com" + env.info.setRdsEndpoint(env.rdsEndpoint); + + env.auroraUtil = + new AuroraTestUtility( + env.info.getRegion(), + env.rdsEndpoint, + env.awsAccessKeyId, + env.awsSecretAccessKey, + env.awsSessionToken); + + + final String endpoint; + if (env.reuseAuroraDbCluster) { + if (StringUtils.isNullOrEmpty(env.auroraClusterName)) { + throw new RuntimeException("Environment variable RDS_CLUSTER_NAME is required."); + } + if (StringUtils.isNullOrEmpty(env.auroraClusterDomain)) { + throw new RuntimeException("Environment variable RDS_CLUSTER_DOMAIN is required."); + } + + endpoint = env.auroraClusterName + "." + env.auroraClusterDomain; + + final String identifier = env.auroraUtil.getDsqlInstanceId(endpoint); + if (!env.auroraUtil.doesDsqlClusterExist(identifier)) { + throw new RuntimeException( + String.format("It's requested to reuse existing DSQL cluster '%s' but it doesn't exist in region %s ", + endpoint, + env.info.getRegion())); + } + + LOGGER.finer( + "Reuse existing cluster " + endpoint); + + } else { + final String name = getRandomName(env.info.getRequest()); + try { + final String identifier = env.auroraUtil.createDsqlCluster(name); + env.auroraClusterName = identifier; + endpoint = String.format("%s.dsql.%s.on.aws", identifier, env.info.getRegion()); + } catch (Exception e) { + LOGGER.finer("Error creating a cluster " + name + ". " + e.getMessage()); + throw new RuntimeException(e); + } + } + + env.info.setClusterName(env.auroraClusterName); + + int port = getPort(env.info.getRequest()); + + env.info + .getDatabaseInfo() + .setClusterEndpoint(endpoint, port); + env.info + .getDatabaseInfo() + .setClusterReadOnlyEndpoint(endpoint, port); + + List instances = new LinkedList<>(); + instances.add(new TestInstanceInfo(env.auroraClusterName, endpoint, port)); + + env.info.getDatabaseInfo().getInstances().clear(); + env.info.getDatabaseInfo().getInstances().addAll(instances); + + authorizeIP(env); + + } + + private static void authorizeIP(TestEnvironment env) { try { env.runnerIP = env.auroraUtil.getPublicIPAddress(); @@ -578,14 +671,21 @@ private static int getPort(TestEnvironmentRequest request) { } private static void initDatabaseParams(TestEnvironment env) { - final String dbName = - !StringUtils.isNullOrEmpty(config.dbName) - ? config.dbName - : "test_database"; - final String dbUsername = - !StringUtils.isNullOrEmpty(config.dbUsername) - ? config.dbUsername - : "test_user"; + + final TestEnvironmentRequest request = env.info.getRequest(); + final boolean isDsql = (request.getDatabaseEngineDeployment() == DatabaseEngineDeployment.DSQL); + + final String dbName = isDsql + ? "postgres" + : !StringUtils.isNullOrEmpty(config.dbName) + ? config.dbName + : "test_database"; + final String dbUsername = isDsql + ? "admin" + : !StringUtils.isNullOrEmpty(config.dbUsername) + ? config.dbUsername + : "test_user"; + final String dbPassword = !StringUtils.isNullOrEmpty(config.dbPassword) ? config.dbPassword @@ -805,17 +905,24 @@ private static String getContainerBaseImageName(TestEnvironmentRequest request) private static void configureIamAccess(TestEnvironment env) { - if (env.info.getRequest().getDatabaseEngineDeployment() != DatabaseEngineDeployment.AURORA) { + if (env.info.getRequest().getDatabaseEngineDeployment() != DatabaseEngineDeployment.AURORA && + env.info.getRequest().getDatabaseEngineDeployment() != DatabaseEngineDeployment.DSQL) + { throw new UnsupportedOperationException( env.info.getRequest().getDatabaseEngineDeployment().toString()); } + final TestEnvironmentRequest request = env.info.getRequest(); + final boolean isDsql = (request.getDatabaseEngineDeployment() == DatabaseEngineDeployment.DSQL); + env.info.setIamUsername( - !StringUtils.isNullOrEmpty(config.iamUser) - ? config.iamUser - : "jane_doe"); + isDsql + ? "admin" + : !StringUtils.isNullOrEmpty(config.iamUser) + ? config.iamUser + : "jane_doe"); - if (!env.reuseAuroraDbCluster) { + if (!env.reuseAuroraDbCluster && !isDsql) { try { Class.forName(DriverHelper.getDriverClassname(env.info.getRequest().getDatabaseEngine())); } catch (ClassNotFoundException e) { @@ -918,6 +1025,7 @@ public void close() throws Exception { switch (this.info.getRequest().getDatabaseEngineDeployment()) { case AURORA: case RDS_MULTI_AZ: + case DSQL: deleteDbCluster(); break; case RDS: @@ -932,10 +1040,19 @@ private void deleteDbCluster() { auroraUtil.ec2DeauthorizesIP(runnerIP); } + final DatabaseEngineDeployment deployment = this.info.getRequest().getDatabaseEngineDeployment(); + + final String identifier; + if (deployment == DatabaseEngineDeployment.DSQL) { + identifier = this.auroraClusterName; + } else { + identifier = this.auroraClusterName + ".cluster-" + this.auroraClusterDomain; + } + if (!this.reuseAuroraDbCluster) { - LOGGER.finest("Deleting cluster " + this.auroraClusterName + ".cluster-" + this.auroraClusterDomain); + LOGGER.finest("Deleting cluster " + identifier); auroraUtil.deleteCluster(this.auroraClusterName); - LOGGER.finest("Deleted cluster " + this.auroraClusterName + ".cluster-" + this.auroraClusterDomain); + LOGGER.finest("Deleted cluster " + identifier); } } diff --git a/tests/integration/host/src/test/java/integration/host/TestEnvironmentConfiguration.java b/tests/integration/host/src/test/java/integration/host/TestEnvironmentConfiguration.java index 6789df0b..f98c0d13 100644 --- a/tests/integration/host/src/test/java/integration/host/TestEnvironmentConfiguration.java +++ b/tests/integration/host/src/test/java/integration/host/TestEnvironmentConfiguration.java @@ -44,6 +44,8 @@ public class TestEnvironmentConfiguration { Boolean.parseBoolean(System.getProperty("exclude-secrets-manager", "false")); public boolean testAutoscalingOnly = Boolean.parseBoolean(System.getProperty("test-autoscaling", "false")); + public boolean excludeDsql = + Boolean.parseBoolean(System.getProperty("exclude-dsql", "false")); public boolean excludeInstances1 = Boolean.parseBoolean(System.getProperty("exclude-instances-1", "false")); diff --git a/tests/integration/host/src/test/java/integration/host/TestEnvironmentProvider.java b/tests/integration/host/src/test/java/integration/host/TestEnvironmentProvider.java index d5cd972d..86929324 100644 --- a/tests/integration/host/src/test/java/integration/host/TestEnvironmentProvider.java +++ b/tests/integration/host/src/test/java/integration/host/TestEnvironmentProvider.java @@ -68,6 +68,9 @@ public Stream provideTestTemplateInvocationContex if (deployment == DatabaseEngineDeployment.RDS_MULTI_AZ && config.excludeMultiAz) { continue; } + if (deployment == DatabaseEngineDeployment.DSQL && config.excludeDsql) { + continue; + } for (DatabaseEngine engine : DatabaseEngine.values()) { if (engine == DatabaseEngine.PG && config.excludePgEngine) { @@ -76,9 +79,12 @@ public Stream provideTestTemplateInvocationContex if (engine == DatabaseEngine.MYSQL && config.excludeMysqlEngine) { continue; } + if (engine != DatabaseEngine.PG && DatabaseEngineDeployment.DSQL == deployment) { + continue; + } for (DatabaseInstances instances : DatabaseInstances.values()) { - if (deployment == DatabaseEngineDeployment.DOCKER + if ((deployment == DatabaseEngineDeployment.DOCKER || deployment == DatabaseEngineDeployment.DSQL) && instances != DatabaseInstances.SINGLE_INSTANCE) { continue; } @@ -141,6 +147,9 @@ public Stream provideTestTemplateInvocationContex || config.excludeIam ? null : TestEnvironmentFeatures.IAM, + deployment == DatabaseEngineDeployment.DSQL + ? TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY + : null, config.excludeSecretsManager ? null : TestEnvironmentFeatures.SECRETS_MANAGER, config.excludePerformance ? null : TestEnvironmentFeatures.PERFORMANCE, config.excludeMysqlDriver ? TestEnvironmentFeatures.SKIP_MYSQL_DRIVER_TESTS : null, diff --git a/tests/integration/host/src/test/java/integration/util/AuroraTestUtility.java b/tests/integration/host/src/test/java/integration/util/AuroraTestUtility.java index 919737b8..a58c3a6d 100644 --- a/tests/integration/host/src/test/java/integration/util/AuroraTestUtility.java +++ b/tests/integration/host/src/test/java/integration/util/AuroraTestUtility.java @@ -34,13 +34,19 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Comparator; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Random; import java.util.concurrent.TimeUnit; import java.util.logging.Logger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Collectors; +import org.apache.logging.log4j.CloseableThreadContext.Instance; + import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; @@ -48,6 +54,12 @@ import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.waiters.WaiterResponse; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.retries.api.BackoffStrategy; +import software.amazon.awssdk.services.dsql.DsqlClient; +import software.amazon.awssdk.services.dsql.model.CreateClusterRequest; +import software.amazon.awssdk.services.dsql.model.CreateClusterResponse; +import software.amazon.awssdk.services.dsql.model.GetClusterResponse; +import software.amazon.awssdk.services.dsql.model.ResourceNotFoundException; import software.amazon.awssdk.services.ec2.Ec2Client; import software.amazon.awssdk.services.ec2.model.DescribeSecurityGroupsResponse; import software.amazon.awssdk.services.ec2.model.Ec2Exception; @@ -98,10 +110,19 @@ public class AuroraTestUtility { private final RdsClient rdsClient; private final Ec2Client ec2Client; + private final DsqlClient dsqlClient; private static final Random rand = new Random(); private static final String DUPLICATE_IP_ERROR_CODE = "InvalidPermission.Duplicate"; + private static final Pattern AURORA_DSQL_CLUSTER_PATTERN = + Pattern.compile( + "^(?[^.]+)\\." + + "(?dsql(?:-[^.]+)?)\\." + + "(?(?[a-zA-Z0-9\\-]+)" + + "\\.on\\.aws\\.?)$", + Pattern.CASE_INSENSITIVE); + public AuroraTestUtility( String region, String rdsEndpoint, String awsAccessKeyId, String awsSecretAccessKey, String awsSessionToken) throws URISyntaxException { @@ -139,7 +160,11 @@ public AuroraTestUtility(Region region, String rdsEndpoint, AwsCredentialsProvid .region(dbRegion) .credentialsProvider(credentialsProvider) .build(); - } + dsqlClient = DsqlClient.builder() + .region(dbRegion) + .credentialsProvider(credentialsProvider) + .build(); + } protected static Region getRegionInternal(String rdsRegion) { Optional regionOptional = @@ -339,6 +364,41 @@ public String createMultiAzCluster() throws InterruptedException { return clusterDomainPrefix; } + /** + * Create a DSQL cluster. + * + * @param name A human-readable name to tag the cluster with. + * @return The unique identifier of the created cluster. + */ + public String createDsqlCluster(final String name) throws InterruptedException { + final Map tagMap = new HashMap<>(); + tagMap.put("Name", name); + + final CreateClusterRequest request = CreateClusterRequest.builder() + .deletionProtectionEnabled(false) + .tags(tagMap) + .build(); + final CreateClusterResponse cluster = dsqlClient.createCluster(request); + + this.dbEngineDeployment = DatabaseEngineDeployment.DSQL; + this.dbIdentifier = cluster.identifier(); + + final WaiterResponse waiterResponse = dsqlClient.waiter().waitUntilClusterActive( + getCluster -> getCluster.identifier(cluster.identifier()), + config -> config.backoffStrategyV2( + BackoffStrategy.fixedDelayWithoutJitter(Duration.ofSeconds(10)) + ).waitTimeout(Duration.ofMinutes(30)) + ); + + if (waiterResponse.matched().exception().isPresent()) { + deleteCluster(); + throw new InterruptedException( + "Unable to create DSQL cluster after waiting for 30 minutes"); + } + + return cluster.identifier(); + } + /** * Gets public IP. * @@ -435,14 +495,21 @@ public void deleteCluster(String identifier) { * Destroys all instances and clusters. Removes IP from EC2 whitelist. */ public void deleteCluster() { + final DatabaseEngineDeployment deployment = this.dbEngineDeployment; + if (deployment == null) { + throw new UnsupportedOperationException("DB engine deployment must be non-null"); + } - switch (this.dbEngineDeployment) { + switch (deployment) { case AURORA: this.deleteAuroraCluster(); break; case RDS_MULTI_AZ: this.deleteMultiAzCluster(); break; + case DSQL: + this.deleteDsqlCluster(); + break; default: throw new UnsupportedOperationException(this.dbEngineDeployment.toString()); } @@ -509,6 +576,23 @@ public void deleteMultiAzCluster() { } } + public void deleteDsqlCluster() { + dsqlClient.deleteCluster(r -> r.identifier(dbIdentifier)); + + WaiterResponse waiterResponse = dsqlClient.waiter().waitUntilClusterNotExists( + getCluster -> getCluster.identifier(dbIdentifier), + config -> config.backoffStrategyV2( + BackoffStrategy.fixedDelayWithoutJitter(Duration.ofSeconds(10)) + ).waitTimeout(Duration.ofMinutes(30)) + ); + + if (waiterResponse.matched().exception().isPresent() + && !(waiterResponse.matched().exception().get() instanceof ResourceNotFoundException)) { + throw new RuntimeException( + "Unable to delete DSQL cluster after waiting for 30 minutes"); + } + } + public boolean doesClusterExist(final String clusterId) { final DescribeDbClustersRequest request = DescribeDbClustersRequest.builder().dbClusterIdentifier(clusterId).build(); @@ -520,6 +604,28 @@ public boolean doesClusterExist(final String clusterId) { return true; } + public boolean doesDsqlClusterExist(final String identifier) { + try { + final GetClusterResponse response = dsqlClient.getCluster(r -> r.identifier(identifier)); + return response.sdkHttpResponse().isSuccessful(); + } catch (ResourceNotFoundException ex) { + return false; + } + } + + public String getDsqlInstanceId(final String host) { + + if (StringUtils.isNullOrEmpty(host)) { + return null; + } + + final Matcher matcher = AURORA_DSQL_CLUSTER_PATTERN.matcher(host); + if (!matcher.matches()) { + return null; + } + return matcher.group("instance"); + } + public DBCluster getClusterInfo(final String clusterId) { final DescribeDbClustersRequest request = DescribeDbClustersRequest.builder().dbClusterIdentifier(clusterId).build();