diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index 937545b8..3bbc2572 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -19,7 +19,7 @@ jobs: matrix: python-version: [ "3.8", "3.11" ] engine-version: [ "lts", "latest"] - environment: ["mysql", "pg"] + environment: ["mysql", "pg", "dsql"] steps: - name: 'Clone repository' diff --git a/README.md b/README.md index 61c5a81d..cb7d7d08 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,12 @@ Since a database failover is usually identified by reaching a network or a conne Enhanced Failure Monitoring (EFM) is a feature available from the [Host Monitoring Connection Plugin](./docs/using-the-python-driver/using-plugins/UsingTheHostMonitoringPlugin.md#enhanced-failure-monitoring) that periodically checks the connected database host's health and availability. If a database host is determined to be unhealthy, the connection is aborted (and potentially routed to another healthy host in the cluster). +### Using the AWS Advanced Python Driver with AWS Aurora DSQL +The AWS Advanced Python Driver is able to handle IAM authentication when working with AWS Aurora DSQL clusters. + +Please visit [this page](./docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md) for more information. + + ### Using the AWS Advanced Python Driver with plain RDS databases The AWS Advanced Python Driver also works with RDS provided databases that are not Aurora. diff --git a/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py b/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py new file mode 100644 index 00000000..62068870 --- /dev/null +++ b/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py @@ -0,0 +1,30 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin +from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory +from aws_advanced_python_wrapper.utils.dsql_token_utils import DSQLTokenUtils + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.utils.properties import Properties + + +class DsqlIamAuthPluginFactory(PluginFactory): + def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + return IamAuthPlugin(plugin_service, DSQLTokenUtils()) diff --git a/aws_advanced_python_wrapper/federated_plugin.py b/aws_advanced_python_wrapper/federated_plugin.py index 0eb8258f..186f6d16 100644 --- a/aws_advanced_python_wrapper/federated_plugin.py +++ b/aws_advanced_python_wrapper/federated_plugin.py @@ -31,6 +31,7 @@ from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.utils.token_utils import TokenUtils from datetime import datetime, timedelta from typing import Callable, Dict, Optional, Set @@ -43,6 +44,7 @@ from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils logger = Logger(__name__) @@ -55,12 +57,17 @@ class FederatedAuthPlugin(Plugin): _rds_utils: RdsUtils = RdsUtils() _token_cache: Dict[str, TokenInfo] = {} - def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, session: Optional[Session] = None): + def __init__(self, + plugin_service: PluginService, + credentials_provider_factory: CredentialsProviderFactory, + token_utils: TokenUtils, + session: Optional[Session] = None): self._plugin_service = plugin_service self._credentials_provider_factory = credentials_provider_factory self._session = session self._region_utils = RegionUtils() + self._token_utils = token_utils telemetry_factory = self._plugin_service.get_telemetry_factory() self._fetch_token_counter = telemetry_factory.create_counter("federated.fetch_token.count") self._cache_size_gauge = telemetry_factory.create_gauge("federated.token_cache.size", lambda: len(FederatedAuthPlugin._token_cache)) @@ -145,7 +152,7 @@ def _update_authentication_token(self, credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props) self._fetch_token_counter.inc() - token: str = IamAuthUtils.generate_authentication_token( + token: str = self._token_utils.generate_authentication_token( self._plugin_service, user, host_info.host, @@ -159,7 +166,7 @@ def _update_authentication_token(self, class FederatedAuthPluginFactory(PluginFactory): def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: - return FederatedAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props)) + return FederatedAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props), RDSTokenUtils()) def get_credentials_provider_factory(self, plugin_service: PluginService, props: Properties) -> AdfsCredentialsProviderFactory: idp_name = WrapperProperties.IDP_NAME.get(props) diff --git a/aws_advanced_python_wrapper/iam_plugin.py b/aws_advanced_python_wrapper/iam_plugin.py index 1a26c58a..06fb3b07 100644 --- a/aws_advanced_python_wrapper/iam_plugin.py +++ b/aws_advanced_python_wrapper/iam_plugin.py @@ -17,6 +17,8 @@ from typing import TYPE_CHECKING from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils +from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils from aws_advanced_python_wrapper.utils.region_utils import RegionUtils if TYPE_CHECKING: @@ -25,6 +27,7 @@ from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.utils.token_utils import TokenUtils from datetime import datetime, timedelta from typing import Callable, Dict, Optional, Set @@ -35,7 +38,6 @@ from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) -from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils logger = Logger(__name__) @@ -48,11 +50,12 @@ class IamAuthPlugin(Plugin): _rds_utils: RdsUtils = RdsUtils() _token_cache: Dict[str, TokenInfo] = {} - def __init__(self, plugin_service: PluginService, session: Optional[Session] = None): + def __init__(self, plugin_service: PluginService, token_utils: TokenUtils, session: Optional[Session] = None): self._plugin_service = plugin_service self._session = session self._region_utils = RegionUtils() + self._token_utils = token_utils telemetry_factory = self._plugin_service.get_telemetry_factory() self._fetch_token_counter = telemetry_factory.create_counter("iam.fetch_token.count") self._cache_size_gauge = telemetry_factory.create_gauge( @@ -102,7 +105,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl else: token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec) self._fetch_token_counter.inc() - token: str = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session) + token: str = self._token_utils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session) self._plugin_service.driver_dialect.set_password(props, token) IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry) @@ -120,7 +123,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl # Try to generate a new token and try to connect again token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec) self._fetch_token_counter.inc() - token = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session) + token = self._token_utils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session) self._plugin_service.driver_dialect.set_password(props, token) IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry) @@ -142,4 +145,4 @@ def force_connect( class IamAuthPluginFactory(PluginFactory): def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: - return IamAuthPlugin(plugin_service) + return IamAuthPlugin(plugin_service, RDSTokenUtils()) diff --git a/aws_advanced_python_wrapper/okta_plugin.py b/aws_advanced_python_wrapper/okta_plugin.py index 55bd9980..d1e9e19e 100644 --- a/aws_advanced_python_wrapper/okta_plugin.py +++ b/aws_advanced_python_wrapper/okta_plugin.py @@ -31,6 +31,7 @@ from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.utils.token_utils import TokenUtils import requests @@ -40,6 +41,7 @@ from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils logger = Logger(__name__) @@ -51,12 +53,17 @@ class OktaAuthPlugin(Plugin): _rds_utils: RdsUtils = RdsUtils() _token_cache: Dict[str, TokenInfo] = {} - def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, session: Optional[Session] = None): + def __init__(self, + plugin_service: PluginService, + credentials_provider_factory: CredentialsProviderFactory, + token_utils: TokenUtils, + session: Optional[Session] = None): self._plugin_service = plugin_service self._credentials_provider_factory = credentials_provider_factory self._session = session self._region_utils = RegionUtils() + self._token_utils = token_utils telemetry_factory = self._plugin_service.get_telemetry_factory() self._fetch_token_counter = telemetry_factory.create_counter("okta.fetch_token.count") self._cache_size_gauge = telemetry_factory.create_gauge("okta.token_cache.size", lambda: len(OktaAuthPlugin._token_cache)) @@ -140,7 +147,7 @@ def _update_authentication_token(self, port: int = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port) credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props) - token: str = IamAuthUtils.generate_authentication_token( + token: str = self._token_utils.generate_authentication_token( self._plugin_service, user, host_info.host, @@ -228,7 +235,7 @@ def get_saml_assertion(self, props: Properties): class OktaAuthPluginFactory(PluginFactory): def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: - return OktaAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props)) + return OktaAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props), RDSTokenUtils()) def get_credentials_provider_factory(self, plugin_service: PluginService, props: Properties) -> OktaCredentialsProviderFactory: return OktaCredentialsProviderFactory(plugin_service, props) diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index b58f8ea3..95769ef2 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -61,6 +61,8 @@ from aws_advanced_python_wrapper.developer_plugin import DeveloperPluginFactory from aws_advanced_python_wrapper.driver_configuration_profiles import \ DriverConfigurationProfiles +from aws_advanced_python_wrapper.dsql_iam_auth_plugin_factory import \ + DsqlIamAuthPluginFactory from aws_advanced_python_wrapper.errors import (AwsWrapperError, QueryTimeoutError, UnsupportedOperationError) @@ -752,6 +754,7 @@ class PluginManager(CanReleaseResources): PLUGIN_FACTORIES: Dict[str, Type[PluginFactory]] = { "iam": IamAuthPluginFactory, + "iam_dsql": DsqlIamAuthPluginFactory, "aws_secrets_manager": AwsSecretsManagerPluginFactory, "aurora_connection_tracker": AuroraConnectionTrackerPluginFactory, "host_monitoring": HostMonitoringPluginFactory, @@ -786,6 +789,7 @@ class PluginManager(CanReleaseResources): BlueGreenPluginFactory: 550, FastestResponseStrategyPluginFactory: 600, IamAuthPluginFactory: 700, + DsqlIamAuthPluginFactory: 710, AwsSecretsManagerPluginFactory: 800, FederatedAuthPluginFactory: 900, LimitlessPluginFactory: 950, diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index 9d690871..7c15ca43 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -186,7 +186,6 @@ IamAuthPlugin.UnhandledException=[IamAuthPlugin] Unhandled exception: {} IamAuthPlugin.UseCachedIamToken=[IamAuthPlugin] Used cached IAM token = {} IamAuthPlugin.InvalidHost=[IamAuthPlugin] Invalid IAM host {}. The IAM host must be a valid RDS or Aurora endpoint. IamAuthPlugin.IsNoneOrEmpty=[IamAuthPlugin] Property "{}" is None or empty. -IamAuthUtils.GeneratedNewAuthToken=Generated new authentication token = {} LimitlessPlugin.FailedToConnectToHost=[LimitlessPlugin] Failed to connect to host {}. LimitlessPlugin.UnsupportedDialectOrDatabase=[LimitlessPlugin] Unsupported dialect '{}' encountered. Please ensure the connection parameters are correct, and refer to the documentation to ensure that the connecting database is compatible with the Limitless Connection Plugin. @@ -357,6 +356,8 @@ RoundRobinHostSelector.ClusterInfoNone=[RoundRobinHostSelector] The round robin RoundRobinHostSelector.RoundRobinInvalidDefaultWeight=[RoundRobinHostSelector] The provided default weight value is not valid. Weight values must be an integer greater than or equal to 1. RoundRobinHostSelector.RoundRobinInvalidHostWeightPairs= [RoundRobinHostSelector] The provided host weight pairs have not been configured correctly. Please ensure the provided host weight pairs is a comma separated list of pairs, each pair in the format of :. Weight values must be an integer greater than or equal to the default weight value of 1. Weight pair: '{}' +TokenUtils.GeneratedNewAuthTokenLength=Generated new authentication token length = {} + WeightedRandomHostSelector.WeightedRandomInvalidHostWeightPairs= [WeightedRandomHostSelector] The provided host weight pairs have not been configured correctly. Please ensure the provided host weight pairs is a comma separated list of pairs, each pair in the format of :. Weight values must be an integer greater than or equal to the default weight value of 1. Weight pair: '{}' WeightedRandomHostSelector.WeightedRandomInvalidDefaultWeight=[WeightedRandomHostSelector] The provided default weight value is not valid. Weight values must be an integer greater than or equal to 1. diff --git a/aws_advanced_python_wrapper/utils/dsql_token_utils.py b/aws_advanced_python_wrapper/utils/dsql_token_utils.py new file mode 100644 index 00000000..aa61e02e --- /dev/null +++ b/aws_advanced_python_wrapper/utils/dsql_token_utils.py @@ -0,0 +1,75 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional + +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ + TelemetryTraceLevel +from aws_advanced_python_wrapper.utils.token_utils import TokenUtils + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.plugin_service import PluginService + from boto3 import Session + +import boto3 + +logger = Logger(__name__) + + +class DSQLTokenUtils(TokenUtils): + def generate_authentication_token( + self, + plugin_service: PluginService, + user: Optional[str], + host_name: Optional[str], + port: Optional[int], + region: Optional[str], + credentials: Optional[Dict[str, str]] = None, + client_session: Optional[Session] = None) -> str: + telemetry_factory = plugin_service.get_telemetry_factory() + context = telemetry_factory.open_telemetry_context("fetch DSQL authentication token", TelemetryTraceLevel.NESTED) + + try: + session = client_session if client_session else boto3.Session() + if credentials is not None: + client = session.client( + 'dsql', + region_name=region, + aws_access_key_id=credentials.get('AccessKeyId'), + aws_secret_access_key=credentials.get('SecretAccessKey'), + aws_session_token=credentials.get('SessionToken') + ) + else: + client = session.client( + 'dsql', + region_name=region + ) + + if user == "admin": + token = client.generate_db_connect_admin_auth_token(host_name, region) + else: + token = client.generate_db_connect_auth_token(host_name, region) + + logger.debug("TokenUtils.GeneratedNewAuthTokenLength", len(token) if token else 0) + client.close() + return token + except Exception as ex: + context.set_success(False) + context.set_exception(ex) + raise ex + finally: + context.close_context() diff --git a/aws_advanced_python_wrapper/utils/iam_utils.py b/aws_advanced_python_wrapper/utils/iam_utils.py index ecb5868f..412499d0 100644 --- a/aws_advanced_python_wrapper/utils/iam_utils.py +++ b/aws_advanced_python_wrapper/utils/iam_utils.py @@ -15,22 +15,16 @@ from __future__ import annotations from datetime import datetime -from typing import TYPE_CHECKING, Dict, Optional - -import boto3 +from typing import TYPE_CHECKING, Optional from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils -from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ - TelemetryTraceLevel if TYPE_CHECKING: from aws_advanced_python_wrapper.hostinfo import HostInfo - from aws_advanced_python_wrapper.plugin_service import PluginService - from boto3 import Session from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) @@ -70,52 +64,6 @@ def get_port(props: Properties, host_info: HostInfo, dialect_default_port: int) def get_cache_key(user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str: return f"{region}:{hostname}:{port}:{user}" - @staticmethod - def generate_authentication_token( - plugin_service: PluginService, - user: Optional[str], - host_name: Optional[str], - port: Optional[int], - region: Optional[str], - credentials: Optional[Dict[str, str]] = None, - client_session: Optional[Session] = None) -> str: - telemetry_factory = plugin_service.get_telemetry_factory() - context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED) - - try: - session = client_session if client_session else boto3.Session() - - if credentials is not None: - client = session.client( - 'rds', - region_name=region, - aws_access_key_id=credentials.get('AccessKeyId'), - aws_secret_access_key=credentials.get('SecretAccessKey'), - aws_session_token=credentials.get('SessionToken') - ) - else: - client = session.client( - 'rds', - region_name=region - ) - - token = client.generate_db_auth_token( - DBHostname=host_name, - Port=port, - DBUsername=user - ) - - client.close() - - logger.debug("IamAuthUtils.GeneratedNewAuthToken", token) - return token - except Exception as ex: - context.set_success(False) - context.set_exception(ex) - raise ex - finally: - context.close_context() - class TokenInfo: @property diff --git a/aws_advanced_python_wrapper/utils/rds_token_utils.py b/aws_advanced_python_wrapper/utils/rds_token_utils.py new file mode 100644 index 00000000..3497cc65 --- /dev/null +++ b/aws_advanced_python_wrapper/utils/rds_token_utils.py @@ -0,0 +1,79 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional + +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ + TelemetryTraceLevel +from aws_advanced_python_wrapper.utils.token_utils import TokenUtils + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.plugin_service import PluginService + from boto3 import Session + +import boto3 + +logger = Logger(__name__) + + +class RDSTokenUtils(TokenUtils): + def generate_authentication_token( + self, + plugin_service: PluginService, + user: Optional[str], + host_name: Optional[str], + port: Optional[int], + region: Optional[str], + credentials: Optional[Dict[str, str]] = None, + client_session: Optional[Session] = None) -> str: + + telemetry_factory = plugin_service.get_telemetry_factory() + context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED) + + try: + session = client_session if client_session else boto3.Session() + + if credentials is not None: + client = session.client( + 'rds', + region_name=region, + aws_access_key_id=credentials.get('AccessKeyId'), + aws_secret_access_key=credentials.get('SecretAccessKey'), + aws_session_token=credentials.get('SessionToken') + ) + else: + client = session.client( + 'rds', + region_name=region + ) + + token = client.generate_db_auth_token( + DBHostname=host_name, + Port=port, + DBUsername=user + ) + + client.close() + + logger.debug("TokenUtils.GeneratedNewAuthTokenLength", len(token) if token else 0) + return token + except Exception as ex: + context.set_success(False) + context.set_exception(ex) + raise ex + finally: + context.close_context() diff --git a/aws_advanced_python_wrapper/utils/rds_url_type.py b/aws_advanced_python_wrapper/utils/rds_url_type.py index 7226c33c..911c25c6 100644 --- a/aws_advanced_python_wrapper/utils/rds_url_type.py +++ b/aws_advanced_python_wrapper/utils/rds_url_type.py @@ -34,4 +34,5 @@ def __init__(self, is_rds: bool, is_rds_cluster: bool): RDS_PROXY = True, False, RDS_INSTANCE = True, False, RDS_AURORA_LIMITLESS_DB_SHARD_GROUP = True, False, + DSQL_CLUSTER = False, False, OTHER = False, False diff --git a/aws_advanced_python_wrapper/utils/rdsutils.py b/aws_advanced_python_wrapper/utils/rdsutils.py index ab8f1b1a..508a7cf1 100644 --- a/aws_advanced_python_wrapper/utils/rdsutils.py +++ b/aws_advanced_python_wrapper/utils/rdsutils.py @@ -108,6 +108,10 @@ class RdsUtils: r"(?Pcluster-|cluster-ro-)+" \ r"(?P[a-zA-Z0-9]+\.rds\.(?P[a-zA-Z0-9\-]+)" \ r"\.(amazonaws\.com|c2s\.ic\.gov|sc2s\.sgov\.gov))$" + AURORA_DSQL_CLUSTER_PATTERN = r"^(?P[^.]+)\." \ + r"(?Pdsql(?:-[^.]+)?)\." \ + r"(?P(?P[a-zA-Z0-9\-]+)" \ + r"\.on\.aws\.?)$" ELB_PATTERN = r"^(?.+)\.elb\.((?[a-zA-Z0-9\-]+)\.amazonaws\.com)$" IP_V4 = r"^(([1-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){1}" \ @@ -153,6 +157,14 @@ def is_rds_dns(self, host: str) -> bool: def is_rds_instance(self, host: str) -> bool: return self._get_dns_group(host) is None and self.is_rds_dns(host) + def is_dsql_cluster(self, host: str) -> bool: + if not host or not host.strip(): + return False + + pattern = self._find(host, [RdsUtils.AURORA_DSQL_CLUSTER_PATTERN]) + + return pattern is not None + def is_rds_proxy_dns(self, host: str) -> bool: dns_group = self._get_dns_group(host) return dns_group is not None and dns_group.casefold() == "proxy-" @@ -261,6 +273,8 @@ def identify_rds_type(self, host: Optional[str]) -> RdsUrlType: return RdsUrlType.RDS_PROXY elif self.is_rds_instance(host): return RdsUrlType.RDS_INSTANCE + elif self.is_dsql_cluster(host): + return RdsUrlType.DSQL_CLUSTER return RdsUrlType.OTHER diff --git a/aws_advanced_python_wrapper/utils/token_utils.py b/aws_advanced_python_wrapper/utils/token_utils.py new file mode 100644 index 00000000..f23c61e7 --- /dev/null +++ b/aws_advanced_python_wrapper/utils/token_utils.py @@ -0,0 +1,37 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Dict, Optional + +if TYPE_CHECKING: + from boto3 import Session + + from aws_advanced_python_wrapper.plugin_service import PluginService + + +class TokenUtils(ABC): + @abstractmethod + def generate_authentication_token( + self, + plugin_service: PluginService, + user: Optional[str], + host_name: Optional[str], + port: Optional[int], + region: Optional[str], + credentials: Optional[Dict[str, str]] = None, + client_session: Optional[Session] = None) -> str: + pass diff --git a/docs/README.md b/docs/README.md index ef76cae1..ab7a88c9 100644 --- a/docs/README.md +++ b/docs/README.md @@ -17,6 +17,7 @@ - [Aurora Initial Connection Strategy Plugin](./using-the-python-driver/using-plugins/UsingTheAuroraInitialConnectionStrategyPlugin.md) - [Host Availability Strategy](./using-the-python-driver/HostAvailabilityStrategy.md) - [IAM Authentication Plugin](./using-the-python-driver/using-plugins/UsingTheIamAuthenticationPlugin.md) + - [DSQL IAM Authentication Plugin](./using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md) - [AWS Secrets Manager Plugin](./using-the-python-driver/using-plugins/UsingTheAwsSecretsManagerPlugin.md) - [Federated Authentication Plugin](./using-the-python-driver/using-plugins/UsingTheFederatedAuthenticationPlugin.md) - [Read Write Splitting Plugin](./using-the-python-driver/using-plugins/UsingTheReadWriteSplittingPlugin.md) diff --git a/docs/examples/DSQLIamAuthentication.py b/docs/examples/DSQLIamAuthentication.py new file mode 100644 index 00000000..4e9b2870 --- /dev/null +++ b/docs/examples/DSQLIamAuthentication.py @@ -0,0 +1,38 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import psycopg + +from aws_advanced_python_wrapper import AwsWrapperConnection + +if __name__ == "__main__": + with AwsWrapperConnection.connect( + psycopg.Connection.connect, + host="abcd.dsql.us-east-1.on.aws", + dbname="postgres", + user="admin", + plugins="iam_dsql", + iam_region="us-east-1", + wrapper_dialect="pg", + autocommit=True + ) as awsconn, awsconn.cursor() as awscursor: + awscursor.execute("CREATE TABLE IF NOT EXISTS bank_test (id int primary key, name varchar(40), account_balance int)") + awscursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (0, "Jane Doe", 200)) + awscursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (1, "John Smith", 200)) + awscursor.execute("SELECT * FROM bank_test") + + res = awscursor.fetchall() + for record in res: + print(record) + awscursor.execute("DROP TABLE bank_test") diff --git a/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md b/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md new file mode 100644 index 00000000..87e36553 --- /dev/null +++ b/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md @@ -0,0 +1,39 @@ +# AWS Aurora DSQL IAM Authentication Plugin + +This plugin enables connecting to AWS Aurora DSQL databases through AWS Identity and Access Management (IAM). + +## What is IAM? +AWS Identity and Access Management (IAM) grants users access control across all Amazon Web Services. IAM supports granular permissions, giving you the ability to grant different permissions to different users. For more information on IAM and its use cases, please refer to the [IAM documentation](https://docs.aws.amazon.com/IAM/latest/UserGuide/introduction.html). + +## Prerequisites +> [!WARNING]\ +> To preserve compatibility with customers using the community driver, IAM Authentication requires the AWS SDK for Python; [Boto3](https://pypi.org/project/boto3/). Boto3 is a runtime dependency and must be resolved. It can be installed via pip like so: `pip install boto3`. + +The DSQL IAM Authentication plugin requires authentication via AWS Credentials. These credentials can be defined in `~/.aws/credentials` or set as environment variables. All users must set `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`. Users who are using temporary security credentials will also need to additionally set `AWS_SESSION_TOKEN`. + +To enable the AWS Aurora DSQL IAM Authentication Plugin, add the plugin code `iam_dsql` to the [`plugins`](../UsingThePythonDriver.md#connection-plugin-manager-parameters) parameter. + +> [!WARNING]\ +> The `iam` plugin must NOT be specified when using the `iam_dsql` plugin. + +## AWS IAM Database Authentication +The AWS Python Driver supports Amazon AWS Identity and Access Management (IAM) authentication. When using AWS IAM database authentication, the host URL must be a valid AWS Aurora DSQL endpoint, and not a custom domain or an IP address. +
i.e. `cluster-identifier.dsql.us-east-1.on.aws` + +Connections established by the `iam_dsql` plugin are beholden to the [Cluster quotas and database limits in Amazon Aurora DSQL](https://docs.aws.amazon.com/aurora-dsql/latest/userguide/CHAP_quotas.html). In particular, applications need to consider the maximum transaction duration, and maximum connection duration limits. Ensure connections are returned to the pool regularly, and not retained for long periods. + + +## How do I use IAM with the AWS Python Driver? +1. Configure IAM roles for the cluster according to [Using database roles and IAM authentication](https://docs.aws.amazon.com/aurora-dsql/latest/userguide/using-database-and-iam-roles.html). +2. Add the plugin code `iam_dsql` to the [`plugins`](../UsingThePythonDriver.md#connection-plugin-manager-parameters) parameter value. + +| Parameter | Value | Required | Description | Example Value | +|--------------------|:-------:|:--------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------| +| `iam_host` | String | No | This property will override the default hostname that is used to generate the IAM token. The default hostname is derived from the connection string. This parameter is required when users are connecting with custom endpoints. | `cluster-identifier.dsql.us-east-1.on.aws` | +| `iam_region` | String | No | This property will override the default region that is used to generate the IAM token. The default region is parsed from the connection string where possible. Some connection string formats may not be supported, and the `iam_region` must be provided in these cases. | `us-east-2` | +| `iam_expiration` | Integer | No | This property determines how long an IAM token is kept in the driver cache before a new one is generated. The default expiration time is set to 14 minutes and 30 seconds. Note that IAM database authentication tokens have a lifetime of 15 minutes. | `600` | + +## Sample code + +[DSQLIamAuthentication.py](../../examples/DSQLIamAuthentication.py) + diff --git a/docs/using-the-python-driver/using-plugins/UsingTheIamAuthenticationPlugin.md b/docs/using-the-python-driver/using-plugins/UsingTheIamAuthenticationPlugin.md index e879e2a6..a5fadfcc 100644 --- a/docs/using-the-python-driver/using-plugins/UsingTheIamAuthenticationPlugin.md +++ b/docs/using-the-python-driver/using-plugins/UsingTheIamAuthenticationPlugin.md @@ -11,6 +11,9 @@ The IAM Authentication plugin requires authentication via AWS Credentials. These To enable the IAM Authentication Connection Plugin, add the plugin code `iam` to the [`plugins`](../UsingThePythonDriver.md#connection-plugin-manager-parameters) parameter. +> [!WARNING]\ +> The `iam` plugin must NOT be specified when using the `iam_dsql` plugin. + ## AWS IAM Database Authentication The AWS Python Driver supports Amazon AWS Identity and Access Management (IAM) authentication. When using AWS IAM database authentication, the host URL must be a valid Amazon endpoint, and not a custom domain or an IP address.
i.e. `db-identifier.cluster-XYZ.us-east-2.rds.amazonaws.com` diff --git a/tests/integration/container/test_aurora_failover.py b/tests/integration/container/test_aurora_failover.py index 71524bee..a00573a4 100644 --- a/tests/integration/container/test_aurora_failover.py +++ b/tests/integration/container/test_aurora_failover.py @@ -44,7 +44,8 @@ @enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) @disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, - TestEnvironmentFeatures.PERFORMANCE]) + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestAuroraFailover: IDLE_CONNECTIONS_NUM: int = 5 logger = Logger(__name__) diff --git a/tests/integration/container/test_autoscaling.py b/tests/integration/container/test_autoscaling.py index 61c9e7d2..ba328e82 100644 --- a/tests/integration/container/test_autoscaling.py +++ b/tests/integration/container/test_autoscaling.py @@ -42,7 +42,8 @@ @enable_on_num_instances(min_instances=5) -@enable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY]) +@enable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestAutoScaling: @pytest.fixture def rds_utils(self): diff --git a/tests/integration/container/test_basic_connectivity.py b/tests/integration/container/test_basic_connectivity.py index 6745eac6..947a0ec0 100644 --- a/tests/integration/container/test_basic_connectivity.py +++ b/tests/integration/container/test_basic_connectivity.py @@ -38,7 +38,8 @@ @disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, - TestEnvironmentFeatures.PERFORMANCE]) + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestBasicConnectivity: @pytest.fixture(scope='class') diff --git a/tests/integration/container/test_basic_functionality.py b/tests/integration/container/test_basic_functionality.py index 005fcb6c..954f1e92 100644 --- a/tests/integration/container/test_basic_functionality.py +++ b/tests/integration/container/test_basic_functionality.py @@ -48,7 +48,8 @@ @disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, - TestEnvironmentFeatures.PERFORMANCE]) + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestBasicFunctionality: @pytest.fixture(scope='class') diff --git a/tests/integration/container/test_custom_endpoint.py b/tests/integration/container/test_custom_endpoint.py index 57162437..3265f613 100644 --- a/tests/integration/container/test_custom_endpoint.py +++ b/tests/integration/container/test_custom_endpoint.py @@ -47,7 +47,8 @@ @enable_on_deployments([DatabaseEngineDeployment.AURORA]) @disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, - TestEnvironmentFeatures.PERFORMANCE]) + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestCustomEndpoint: logger: ClassVar[Logger] = Logger(__name__) endpoint_id: ClassVar[str] = f"test-endpoint-1-{uuid4()}" diff --git a/tests/integration/container/test_iam_authentication.py b/tests/integration/container/test_iam_authentication.py index 4bc6ee3d..d2549cee 100644 --- a/tests/integration/container/test_iam_authentication.py +++ b/tests/integration/container/test_iam_authentication.py @@ -41,7 +41,8 @@ @enable_on_features([TestEnvironmentFeatures.IAM]) @disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, - TestEnvironmentFeatures.PERFORMANCE]) + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestAwsIamAuthentication: @pytest.fixture(scope='class') diff --git a/tests/integration/container/test_iam_dsql_authentication.py b/tests/integration/container/test_iam_dsql_authentication.py new file mode 100644 index 00000000..1de40d8b --- /dev/null +++ b/tests/integration/container/test_iam_dsql_authentication.py @@ -0,0 +1,114 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures + +if TYPE_CHECKING: + from tests.integration.container.utils.test_driver import TestDriver + +from typing import Callable + +import pytest + +from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper.errors import AwsWrapperError +from tests.integration.container.utils.conditions import enable_on_features +from tests.integration.container.utils.driver_helper import DriverHelper +from tests.integration.container.utils.test_environment import TestEnvironment + + +@enable_on_features([TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) +class TestAwsIamDSQLAuthentication: + + @pytest.fixture(scope='class') + def props(self): + p: Properties = Properties() + + if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in TestEnvironment.get_current().get_features() \ + or TestEnvironmentFeatures.TELEMETRY_METRICS_ENABLED in TestEnvironment.get_current().get_features(): + WrapperProperties.ENABLE_TELEMETRY.set(p, "True") + WrapperProperties.TELEMETRY_SUBMIT_TOPLEVEL.set(p, "True") + + if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in TestEnvironment.get_current().get_features(): + WrapperProperties.TELEMETRY_TRACES_BACKEND.set(p, "XRAY") + + if TestEnvironmentFeatures.TELEMETRY_METRICS_ENABLED in TestEnvironment.get_current().get_features(): + WrapperProperties.TELEMETRY_METRICS_BACKEND.set(p, "OTLP") + + return p + + def test_iam_wrong_database_username(self, test_environment: TestEnvironment, + test_driver: TestDriver, conn_utils, props): + target_driver_connect = DriverHelper.get_connect_func(test_driver) + user = f"WRONG_{conn_utils.iam_user}_USER" + params = conn_utils.get_connect_params(user=user) + params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver + + with pytest.raises(AwsWrapperError): + AwsWrapperConnection.connect( + target_driver_connect, + **params, + plugins="iam_dsql", + **props) + + def test_iam_no_database_username(self, test_driver: TestDriver, conn_utils, props): + target_driver_connect = DriverHelper.get_connect_func(test_driver) + params = conn_utils.get_connect_params() + params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver + params.pop("user", None) + + with pytest.raises(AwsWrapperError): + AwsWrapperConnection.connect(target_driver_connect, **params, plugins="iam_dsql", **props) + + def test_iam_invalid_host(self, test_driver: TestDriver, conn_utils, props): + target_driver_connect = DriverHelper.get_connect_func(test_driver) + params = conn_utils.get_connect_params() + params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver + params.update({"iam_host": "<>", "plugins": "iam_dsql"}) + + with pytest.raises(AwsWrapperError): + AwsWrapperConnection.connect(target_driver_connect, **params, **props) + + def test_iam_valid_connection_properties( + self, test_environment: TestEnvironment, test_driver: TestDriver, conn_utils, props): + target_driver_connect = DriverHelper.get_connect_func(test_driver) + params = conn_utils.get_connect_params(user=conn_utils.iam_user, password="") + params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver + params["plugins"] = "iam_dsql" + + self.validate_connection(target_driver_connect, **params, **props) + + def test_iam_valid_connection_properties_no_password( + self, test_environment: TestEnvironment, test_driver: TestDriver, conn_utils, props): + target_driver_connect = DriverHelper.get_connect_func(test_driver) + params = conn_utils.get_connect_params(user=conn_utils.iam_user) + params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver + params.pop("password", None) + params["plugins"] = "iam_dsql" + + self.validate_connection(target_driver_connect, **params, **props) + + def validate_connection(self, target_driver_connect: Callable, **connect_params): + with AwsWrapperConnection.connect(target_driver_connect, **connect_params) as conn, \ + conn.cursor() as cursor: + cursor.execute("SELECT now()") + records = cursor.fetchall() + assert len(records) == 1 diff --git a/tests/integration/container/test_read_write_splitting.py b/tests/integration/container/test_read_write_splitting.py index de75badb..5cf559df 100644 --- a/tests/integration/container/test_read_write_splitting.py +++ b/tests/integration/container/test_read_write_splitting.py @@ -47,7 +47,8 @@ DatabaseEngineDeployment.RDS_MULTI_AZ_INSTANCE]) @disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, - TestEnvironmentFeatures.PERFORMANCE]) + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestReadWriteSplitting: @pytest.fixture(scope='class') def rds_utils(self): diff --git a/tests/integration/container/utils/database_engine_deployment.py b/tests/integration/container/utils/database_engine_deployment.py index c58817a3..941aa21a 100644 --- a/tests/integration/container/utils/database_engine_deployment.py +++ b/tests/integration/container/utils/database_engine_deployment.py @@ -21,3 +21,4 @@ class DatabaseEngineDeployment(str, Enum): RDS_MULTI_AZ_CLUSTER = "RDS_MULTI_AZ_CLUSTER" RDS_MULTI_AZ_INSTANCE = "RDS_MULTI_AZ_INSTANCE" AURORA = "AURORA" + DSQL = "DSQL" diff --git a/tests/integration/container/utils/test_environment_features.py b/tests/integration/container/utils/test_environment_features.py index ec42d197..a2914329 100644 --- a/tests/integration/container/utils/test_environment_features.py +++ b/tests/integration/container/utils/test_environment_features.py @@ -26,6 +26,7 @@ class TestEnvironmentFeatures(Enum): AWS_CREDENTIALS_ENABLED = "AWS_CREDENTIALS_ENABLED" PERFORMANCE = "PERFORMANCE" RUN_AUTOSCALING_TESTS_ONLY = "RUN_AUTOSCALING_TESTS_ONLY" + RUN_DSQL_TESTS_ONLY = "RUN_DSQL_TESTS_ONLY" BLUE_GREEN_DEPLOYMENT = "BLUE_GREEN_DEPLOYMENT" SKIP_MYSQL_DRIVER_TESTS = "SKIP_MYSQL_DRIVER_TESTS" SKIP_PG_DRIVER_TESTS = "SKIP_PG_DRIVER_TESTS" diff --git a/tests/integration/host/build.gradle.kts b/tests/integration/host/build.gradle.kts index b746eb64..5032ad59 100644 --- a/tests/integration/host/build.gradle.kts +++ b/tests/integration/host/build.gradle.kts @@ -30,6 +30,7 @@ dependencies { testImplementation("software.amazon.awssdk:rds:2.20.49") testImplementation("software.amazon.awssdk:ec2:2.20.61") testImplementation("software.amazon.awssdk:secretsmanager:2.20.49") + testImplementation("software.amazon.awssdk:dsql:2.29.34") // Note: all org.testcontainers dependencies should have the same version testImplementation("org.testcontainers:testcontainers:1.21.2") testImplementation("org.testcontainers:mysql:1.21.2") @@ -123,6 +124,48 @@ tasks.register("test-python-3.8-pg") { } } +tasks.register("test-python-3.11-dsql") { + group = "verification" + filter.includeTestsMatching("integration.host.TestRunner.runTests") + doFirst { + systemProperty("exclude-dsql", "false") + systemProperty("exclude-aurora", "true") + systemProperty("exclude-python-38", "true") + systemProperty("exclude-autoscaling", "true") + systemProperty("exclude-docker", "true") + systemProperty("exclude-multi-az-cluster", "true") + systemProperty("exclude-multi-az-instance", "true") + systemProperty("exclude-mysql-driver", "true") + systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-mariadb-driver", "true") + systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-performance", "true") + systemProperty("exclude-secrets-manager", "true") + systemProperty("exclude-bg", "true") + } +} + +tasks.register("test-python-3.8-dsql") { + group = "verification" + filter.includeTestsMatching("integration.host.TestRunner.runTests") + doFirst { + systemProperty("exclude-dsql", "false") + systemProperty("exclude-aurora", "true") + systemProperty("exclude-python-311", "true") + systemProperty("exclude-autoscaling", "true") + systemProperty("exclude-docker", "true") + systemProperty("exclude-multi-az-cluster", "true") + systemProperty("exclude-multi-az-instance", "true") + systemProperty("exclude-mysql-driver", "true") + systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-mariadb-driver", "true") + systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-performance", "true") + systemProperty("exclude-secrets-manager", "true") + systemProperty("exclude-bg", "true") + } +} + tasks.register("test-docker") { group = "verification" filter.includeTestsMatching("integration.host.TestRunner.runTests") @@ -264,6 +307,26 @@ tasks.register("test-mysql-aurora-performance") { } } +tasks.register("test-all-dsql") { + group = "verification" + filter.includeTestsMatching("integration.host.TestRunner.runTests") + doFirst { + systemProperty("exclude-dsql", "false") + systemProperty("exclude-aurora", "true") + systemProperty("exclude-autoscaling", "true") + systemProperty("exclude-docker", "true") + systemProperty("exclude-multi-az-cluster", "true") + systemProperty("exclude-multi-az-instance", "true") + systemProperty("exclude-mysql-driver", "true") + systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-mariadb-driver", "true") + systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-performance", "true") + systemProperty("exclude-secrets-manager", "true") + systemProperty("exclude-bg", "true") + } +} + tasks.register("test-bgd-mysql-instance") { group = "verification" filter.includeTestsMatching("integration.host.TestRunner.runTests") @@ -301,7 +364,6 @@ tasks.register("test-bgd-mysql-aurora") { systemProperty("exclude-instances-5", "true") systemProperty("exclude-multi-az-cluster", "true") systemProperty("test-bg-only", "true") - } } @@ -342,7 +404,6 @@ tasks.register("test-bgd-pg-aurora") { systemProperty("exclude-instances-5", "true") systemProperty("exclude-multi-az-cluster", "true") systemProperty("test-bg-only", "true") - } } @@ -497,6 +558,25 @@ tasks.register("debug-mysql-multi-az") { } } +tasks.register("debug-all-dsql") { + group = "verification" + filter.includeTestsMatching("integration.host.TestRunner.debugTests") + doFirst { + systemProperty("exclude-aurora", "true") + systemProperty("exclude-autoscaling", "true") + systemProperty("exclude-docker", "true") + systemProperty("exclude-multi-az-cluster", "true") + systemProperty("exclude-multi-az-instance", "true") + systemProperty("exclude-mysql-driver", "true") + systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-mariadb-driver", "true") + systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-performance", "true") + systemProperty("exclude-secrets-manager", "true") + systemProperty("exclude-bg", "true") + } +} + tasks.register("debug-bgd-pg-aurora") { group = "verification" filter.includeTestsMatching("integration.host.TestRunner.debugTests") diff --git a/tests/integration/host/src/test/java/integration/DatabaseEngineDeployment.java b/tests/integration/host/src/test/java/integration/DatabaseEngineDeployment.java index d7273e30..20d4c638 100644 --- a/tests/integration/host/src/test/java/integration/DatabaseEngineDeployment.java +++ b/tests/integration/host/src/test/java/integration/DatabaseEngineDeployment.java @@ -21,5 +21,6 @@ public enum DatabaseEngineDeployment { RDS, RDS_MULTI_AZ_CLUSTER, RDS_MULTI_AZ_INSTANCE, - AURORA + AURORA, + DSQL } diff --git a/tests/integration/host/src/test/java/integration/TestEnvironmentFeatures.java b/tests/integration/host/src/test/java/integration/TestEnvironmentFeatures.java index a80defb9..1286dd50 100644 --- a/tests/integration/host/src/test/java/integration/TestEnvironmentFeatures.java +++ b/tests/integration/host/src/test/java/integration/TestEnvironmentFeatures.java @@ -29,5 +29,6 @@ public enum TestEnvironmentFeatures { RUN_AUTOSCALING_TESTS_ONLY, TELEMETRY_TRACES_ENABLED, TELEMETRY_METRICS_ENABLED, - BLUE_GREEN_DEPLOYMENT + BLUE_GREEN_DEPLOYMENT, + RUN_DSQL_TESTS_ONLY } diff --git a/tests/integration/host/src/test/java/integration/host/TestEnvironment.java b/tests/integration/host/src/test/java/integration/host/TestEnvironment.java index 06a4fa30..24877f9e 100644 --- a/tests/integration/host/src/test/java/integration/host/TestEnvironment.java +++ b/tests/integration/host/src/test/java/integration/host/TestEnvironment.java @@ -39,6 +39,7 @@ import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; +import java.util.LinkedList; import java.util.List; import java.util.Random; import java.util.concurrent.ExecutionException; @@ -139,6 +140,7 @@ public static TestEnvironment build(TestEnvironmentRequest request) throws IOExc case AURORA: case RDS_MULTI_AZ_CLUSTER: case RDS_MULTI_AZ_INSTANCE: + case DSQL: env = createAuroraOrMultiAzEnvironment(request); @@ -177,7 +179,8 @@ private static void authorizeRunnerIpAddress(TestEnvironment env) { if (deployment == DatabaseEngineDeployment.AURORA || deployment == DatabaseEngineDeployment.RDS || deployment == DatabaseEngineDeployment.RDS_MULTI_AZ_INSTANCE - || deployment == DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER) { + || deployment == DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER + || deployment == DatabaseEngineDeployment.DSQL) { // These environment require creating external database cluster that should be publicly available. // Corresponding AWS Security Groups should be configured and the test task runner IP address // should be whitelisted. @@ -271,6 +274,12 @@ private static TestEnvironment createAuroraOrMultiAzEnvironment(TestEnvironmentR createDbCluster(env); configureIamAccess(env); break; + case DSQL: + initEnv(env); + authorizeRunnerIpAddress(env); + createDsqlCluster(env); + configureIamAccess(env); + break; default: throw new NotImplementedException(request.getDatabaseEngineDeployment().toString()); } @@ -775,6 +784,60 @@ private static void createMultiAzInstance(TestEnvironment env) { } } + private static void createDsqlCluster(TestEnvironment env) { + final String endpoint; + + if (env.reuseDb) { + if (StringUtils.isNullOrEmpty(env.rdsDbName)) { + throw new RuntimeException("Environment variable RDS_CLUSTER_NAME is required."); + } + if (StringUtils.isNullOrEmpty(env.rdsDbDomain)) { + throw new RuntimeException("Environment variable RDS_CLUSTER_DOMAIN is required."); + } + + endpoint = env.rdsDbName + "." + env.rdsDbDomain; + + final String identifier = env.auroraUtil.getDsqlInstanceId(endpoint); + if (!env.auroraUtil.doesDsqlClusterExist(identifier)) { + throw new RuntimeException( + String.format("It's requested to reuse existing DSQL cluster '%s' but it doesn't exist in region %s ", + endpoint, + env.info.getRegion())); + } + + LOGGER.finer( + "Reuse existing cluster " + endpoint); + + } else { + final String name = getRandomName(env); + try { + final String identifier = env.auroraUtil.createDsqlCluster(name); + env.rdsDbName = identifier; + endpoint = String.format("%s.dsql.%s.on.aws", identifier, env.info.getRegion()); + } catch (Exception e) { + LOGGER.finer("Error creating a cluster " + name + ". " + e.getMessage()); + throw new RuntimeException(e); + } + } + + int port = getPort(env.info.getRequest()); + + env.info.setRdsDbName(env.rdsDbName); + env.info + .getDatabaseInfo() + .setClusterEndpoint(endpoint, port); + env.info + .getDatabaseInfo() + .setClusterReadOnlyEndpoint(endpoint, port); + + List instances = new LinkedList<>(); + instances.add(new TestInstanceInfo(env.rdsDbName, endpoint, port)); + + env.info.getDatabaseInfo().getInstances().clear(); + env.info.getDatabaseInfo().getInstances().addAll(instances); + } + + private static void authorizeIP(TestEnvironment env) { try { env.runnerIP = env.auroraUtil.getPublicIPAddress(); @@ -924,15 +987,22 @@ private static int getPort(TestEnvironmentRequest request) { } private static void initDatabaseParams(TestEnvironment env) { - final String dbName = - config.dbName == null - ? "test_database" - : config.dbName.trim(); - - final String dbUsername = - !StringUtils.isNullOrEmpty(config.dbUsername) - ? config.dbUsername - : "test_user"; + + final TestEnvironmentRequest request = env.info.getRequest(); + final boolean isDsql = (request.getDatabaseEngineDeployment() == DatabaseEngineDeployment.DSQL); + + final String dbName = !StringUtils.isNullOrEmpty(config.dbName) + ? config.dbName.trim() + : isDsql + ? "postgres" + : "test_database"; + + final String dbUsername = !StringUtils.isNullOrEmpty(config.dbUsername) + ? config.dbUsername + : isDsql + ? "admin" + : "test_user"; + final String dbPassword = !StringUtils.isNullOrEmpty(config.dbPassword) ? config.dbPassword @@ -1157,13 +1227,16 @@ private static void configureIamAccess(TestEnvironment env) { } final DatabaseEngineDeployment deployment = env.info.getRequest().getDatabaseEngineDeployment(); + final boolean isDsql = deployment == DatabaseEngineDeployment.DSQL; env.info.setIamUsername( - !StringUtils.isNullOrEmpty(config.iamUser) - ? config.iamUser - : "jane_doe"); + isDsql + ? "admin" + : !StringUtils.isNullOrEmpty(config.iamUser) + ? config.iamUser + : "jane_doe"); - if (!env.reuseDb) { + if (!env.reuseDb && !isDsql ) { try { Class.forName(DriverHelper.getDriverClassname(env.info.getRequest().getDatabaseEngine())); } catch (ClassNotFoundException e) { @@ -1286,15 +1359,15 @@ public void close() throws Exception { if (this.info.getRequest().getFeatures().contains(TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT) && !StringUtils.isNullOrEmpty(this.info.getBlueGreenDeploymentId())) { deleteBlueGreenDeployment(); - deleteDbCluster(true); + deleteDbCluster(DatabaseEngineDeployment.AURORA, true); deleteCustomClusterParameterGroup(this.info.getClusterParameterGroupName()); } else { - deleteDbCluster(false); + deleteDbCluster(DatabaseEngineDeployment.AURORA, false); } deAuthorizeIP(this); break; case RDS_MULTI_AZ_CLUSTER: - deleteDbCluster(false); + deleteDbCluster(DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER, false); deAuthorizeIP(this); break; case RDS_MULTI_AZ_INSTANCE: @@ -1312,20 +1385,31 @@ public void close() throws Exception { // no external resources to dispose // do nothing break; + case DSQL: + deleteDbCluster(DatabaseEngineDeployment.DSQL, false); + deAuthorizeIP(this); + break; default: throw new NotImplementedException(this.info.getRequest().getDatabaseEngineDeployment().toString()); } } - private void deleteDbCluster(boolean waitForCompletion) { + private void deleteDbCluster(DatabaseEngineDeployment deployment, boolean waitForCompletion) { if (!this.reuseDb) { - LOGGER.finest("Deleting cluster " + this.rdsDbName + ".cluster-" + this.rdsDbDomain); + final String identifier; + if (deployment == DatabaseEngineDeployment.DSQL) { + identifier = this.rdsDbName; + } else { + identifier = this.rdsDbName + ".cluster-" + this.rdsDbDomain; + } + + LOGGER.finest("Deleting cluster " + identifier); auroraUtil.deleteCluster( - this.rdsDbName, this.info.getRequest().getDatabaseEngineDeployment(), waitForCompletion); - LOGGER.finest("Deleted cluster " + this.rdsDbName + ".cluster-" + this.rdsDbDomain); + this.rdsDbName, deployment, waitForCompletion); + LOGGER.finest("Deleted cluster " + identifier); } } - + private void deleteMultiAzInstance() { if (!this.reuseDb) { LOGGER.finest("Deleting MultiAz Instance " + this.rdsDbName + "." + this.rdsDbDomain); diff --git a/tests/integration/host/src/test/java/integration/host/TestEnvironmentConfiguration.java b/tests/integration/host/src/test/java/integration/host/TestEnvironmentConfiguration.java index f36f8af8..558e048f 100644 --- a/tests/integration/host/src/test/java/integration/host/TestEnvironmentConfiguration.java +++ b/tests/integration/host/src/test/java/integration/host/TestEnvironmentConfiguration.java @@ -46,6 +46,8 @@ public class TestEnvironmentConfiguration { Boolean.parseBoolean(System.getProperty("exclude-secrets-manager", "false")); public boolean testAutoscalingOnly = Boolean.parseBoolean(System.getProperty("test-autoscaling", "false")); + public boolean excludeDsql = + Boolean.parseBoolean(System.getProperty("exclude-dsql", "true")); public boolean excludeInstances1 = Boolean.parseBoolean(System.getProperty("exclude-instances-1", "false")); diff --git a/tests/integration/host/src/test/java/integration/host/TestEnvironmentProvider.java b/tests/integration/host/src/test/java/integration/host/TestEnvironmentProvider.java index 15011003..d95403f7 100644 --- a/tests/integration/host/src/test/java/integration/host/TestEnvironmentProvider.java +++ b/tests/integration/host/src/test/java/integration/host/TestEnvironmentProvider.java @@ -71,6 +71,9 @@ public Stream provideTestTemplateInvocationContex if (deployment == DatabaseEngineDeployment.RDS_MULTI_AZ_INSTANCE && config.excludeMultiAzInstance) { continue; } + if (deployment == DatabaseEngineDeployment.DSQL && config.excludeDsql) { + continue; + } for (DatabaseEngine engine : DatabaseEngine.values()) { if (engine == DatabaseEngine.PG && config.excludePgEngine) { @@ -79,9 +82,12 @@ public Stream provideTestTemplateInvocationContex if (engine == DatabaseEngine.MYSQL && config.excludeMysqlEngine) { continue; } + if (engine != DatabaseEngine.PG && DatabaseEngineDeployment.DSQL == deployment) { + continue; + } for (DatabaseInstances instances : DatabaseInstances.values()) { - if (deployment == DatabaseEngineDeployment.DOCKER + if ((deployment == DatabaseEngineDeployment.DOCKER || deployment == DatabaseEngineDeployment.DSQL) && instances != DatabaseInstances.SINGLE_INSTANCE) { continue; } @@ -168,6 +174,9 @@ public Stream provideTestTemplateInvocationContex || config.excludeIam ? null : TestEnvironmentFeatures.IAM, + deployment == DatabaseEngineDeployment.DSQL + ? TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY + : null, config.excludeSecretsManager ? null : TestEnvironmentFeatures.SECRETS_MANAGER, config.excludePerformance ? null : TestEnvironmentFeatures.PERFORMANCE, config.excludeMysqlDriver ? TestEnvironmentFeatures.SKIP_MYSQL_DRIVER_TESTS : null, diff --git a/tests/integration/host/src/test/java/integration/util/AuroraTestUtility.java b/tests/integration/host/src/test/java/integration/util/AuroraTestUtility.java index a04d44a9..d96c2e38 100644 --- a/tests/integration/host/src/test/java/integration/util/AuroraTestUtility.java +++ b/tests/integration/host/src/test/java/integration/util/AuroraTestUtility.java @@ -42,12 +42,17 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.logging.Logger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Collectors; + import org.checkerframework.checker.nullness.qual.Nullable; import org.testcontainers.shaded.org.apache.commons.lang3.NotImplementedException; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; @@ -57,6 +62,12 @@ import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.waiters.WaiterResponse; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.retries.api.BackoffStrategy; +import software.amazon.awssdk.services.dsql.DsqlClient; +import software.amazon.awssdk.services.dsql.model.CreateClusterRequest; +import software.amazon.awssdk.services.dsql.model.CreateClusterResponse; +import software.amazon.awssdk.services.dsql.model.GetClusterResponse; +import software.amazon.awssdk.services.dsql.model.ResourceNotFoundException; import software.amazon.awssdk.services.ec2.Ec2Client; import software.amazon.awssdk.services.ec2.model.DescribeSecurityGroupsResponse; import software.amazon.awssdk.services.ec2.model.Ec2Exception; @@ -122,6 +133,15 @@ public class AuroraTestUtility { private final RdsClient rdsClient; private final Ec2Client ec2Client; + private final DsqlClient dsqlClient; + + private static final Pattern AURORA_DSQL_CLUSTER_PATTERN = + Pattern.compile( + "^(?[^.]+)\\." + + "(?dsql(?:-[^.]+)?)\\." + + "(?(?[a-zA-Z0-9\\-]+)" + + "\\.on\\.aws\\.?)$", + Pattern.CASE_INSENSITIVE); public AuroraTestUtility( String region, String rdsEndpoint, String awsAccessKeyId, String awsSecretAccessKey, String awsSessionToken) { @@ -162,6 +182,10 @@ public AuroraTestUtility(Region region, String rdsEndpoint, AwsCredentialsProvid .region(region) .credentialsProvider(credentialsProvider) .build(); + dsqlClient = DsqlClient.builder() + .region(region) + .credentialsProvider(credentialsProvider) + .build(); } protected static Region getRegionInternal(String rdsRegion) { @@ -499,6 +523,39 @@ public void deleteCustomClusterParameterGroup(String groupName) { ); } + /** + * Create a DSQL cluster. + * + * @param name A human-readable name to tag the cluster with. + * @return The unique identifier of the created cluster. + */ + public String createDsqlCluster(final String name) throws InterruptedException { + final Map tagMap = new HashMap<>(); + tagMap.put("Name", name); + + final CreateClusterRequest request = CreateClusterRequest.builder() + .deletionProtectionEnabled(false) + .tags(tagMap) + .build(); + final CreateClusterResponse cluster = dsqlClient.createCluster(request); + String identifier = cluster.identifier(); + + final WaiterResponse waiterResponse = dsqlClient.waiter().waitUntilClusterActive( + getCluster -> getCluster.identifier(cluster.identifier()), + config -> config.backoffStrategyV2( + BackoffStrategy.fixedDelayWithoutJitter(Duration.ofSeconds(10)) + ).waitTimeout(Duration.ofMinutes(30)) + ); + + if (waiterResponse.matched().exception().isPresent()) { + deleteCluster(identifier, DatabaseEngineDeployment.DSQL, false); + throw new InterruptedException( + "Unable to create DSQL cluster after waiting for 30 minutes"); + } + + return identifier; + } + /** * Gets the public IP address for the current machine. * @@ -604,6 +661,9 @@ public void deleteCluster(String identifier, DatabaseEngineDeployment deployment case RDS_MULTI_AZ_CLUSTER: this.deleteMultiAzCluster(identifier, waitForCompletion); break; + case DSQL: + this.deleteDsqlCluster(identifier, waitForCompletion); + break; default: throw new UnsupportedOperationException(deployment.toString()); } @@ -803,6 +863,23 @@ public void promoteInstanceToStandalone(String instanceArn) { } } + public void deleteDsqlCluster(String identifier, boolean waitForCompletion) { + dsqlClient.deleteCluster(r -> r.identifier(identifier)); + + WaiterResponse waiterResponse = dsqlClient.waiter().waitUntilClusterNotExists( + getCluster -> getCluster.identifier(identifier), + config -> config.backoffStrategyV2( + BackoffStrategy.fixedDelayWithoutJitter(Duration.ofSeconds(10)) + ).waitTimeout(Duration.ofMinutes(30)) + ); + + if (waiterResponse.matched().exception().isPresent() + && !(waiterResponse.matched().exception().get() instanceof ResourceNotFoundException)) { + throw new RuntimeException( + "Unable to delete DSQL cluster after waiting for 30 minutes"); + } + } + public boolean doesClusterExist(final String clusterId) { final DescribeDbClustersRequest request = DescribeDbClustersRequest.builder().dbClusterIdentifier(clusterId).build(); @@ -825,6 +902,28 @@ public boolean doesInstanceExist(final String instanceId) { } } + public boolean doesDsqlClusterExist(final String identifier) { + try { + final GetClusterResponse response = dsqlClient.getCluster(r -> r.identifier(identifier)); + return response.sdkHttpResponse().isSuccessful(); + } catch (ResourceNotFoundException ex) { + return false; + } + } + + public String getDsqlInstanceId(final String host) { + + if (StringUtils.isNullOrEmpty(host)) { + return null; + } + + final Matcher matcher = AURORA_DSQL_CLUSTER_PATTERN.matcher(host); + if (!matcher.matches()) { + return null; + } + return matcher.group("instance"); + } + public DBCluster getClusterInfo(final String clusterId) { final DescribeDbClustersRequest request = DescribeDbClustersRequest.builder().dbClusterIdentifier(clusterId).build(); diff --git a/tests/unit/test_federated_auth_plugin.py b/tests/unit/test_federated_auth_plugin.py index 1c3a77e3..3f982a70 100644 --- a/tests/unit/test_federated_auth_plugin.py +++ b/tests/unit/test_federated_auth_plugin.py @@ -25,6 +25,7 @@ from aws_advanced_python_wrapper.iam_plugin import TokenInfo from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils _GENERATED_TOKEN = "generated_token" _TEST_TOKEN = "test_token" @@ -101,6 +102,7 @@ def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_sessi _token_cache[_PG_CACHE_KEY] = initial_token target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, + RDSTokenUtils(), mock_session) key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + str(_DEFAULT_PG_PORT) + ":postgesqlUser" _token_cache[key] = initial_token @@ -129,7 +131,10 @@ def test_expired_cached_token(mocker, mock_plugin_service, mock_session, mock_fu initial_token = TokenInfo(_TEST_TOKEN, datetime.now() - timedelta(minutes=5)) _token_cache[_PG_CACHE_KEY] = initial_token - target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -154,7 +159,10 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m test_props: Properties = Properties({"plugins": "federated_auth", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"}) WrapperProperties.DB_USER.set(test_props, _DB_USER) - target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -183,7 +191,9 @@ def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_sess exception_message = "generic exception" mock_func.side_effect = Exception(exception_message) - target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, + target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), mock_session) with pytest.raises(Exception) as e_info: target_plugin.connect( @@ -229,7 +239,10 @@ def test_connect_with_specified_iam_host_port_region(mocker, mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{expected_region}" - target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, diff --git a/tests/unit/test_iam_dsql_plugin.py b/tests/unit/test_iam_dsql_plugin.py new file mode 100644 index 00000000..edd81d48 --- /dev/null +++ b/tests/unit/test_iam_dsql_plugin.py @@ -0,0 +1,444 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import urllib.request +from datetime import datetime, timedelta +from typing import Dict +from unittest.mock import patch + +import pytest + +from aws_advanced_python_wrapper.hostinfo import HostInfo +from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin, TokenInfo +from aws_advanced_python_wrapper.utils.dsql_token_utils import DSQLTokenUtils +from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) + +_GENERATED_TOKEN = "generated_token admin" +_GENERATED_TOKEN_NON_ADMIN = "generated_token non-admin" +_TEST_TOKEN = "test_token" +_DEFAULT_PG_PORT = 5432 + +_PG_HOST_INFO = HostInfo("dsqltestclusternamefoobar1.dsql.us-east-2.on.aws") +_PG_HOST_INFO_WITH_PORT = HostInfo(_PG_HOST_INFO.url, port=1234) +_PG_REGION = "us-east-2" + +_PG_CACHE_KEY = f"{_PG_REGION}:{_PG_HOST_INFO.url}:{_DEFAULT_PG_PORT}:admin" + + +_token_cache: Dict[str, TokenInfo] = {} + + +@pytest.fixture(autouse=True) +def clear_caches(): + _token_cache.clear() + + +@pytest.fixture +def mock_session(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_client(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_connection(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_func(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_plugin_service(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_dialect(mocker): + return mocker.MagicMock() + + +@pytest.fixture(autouse=True) +def mock_default_behavior(mock_session, mock_client, mock_func, mock_connection, mock_plugin_service, mock_dialect): + mock_session.client.return_value = mock_client + mock_client.generate_db_connect_admin_auth_token.return_value = _GENERATED_TOKEN + mock_client.generate_db_connect_auth_token.return_value = _GENERATED_TOKEN_NON_ADMIN + mock_session.get_available_regions.return_value = ['us-east-1', 'us-east-2', 'us-west-1', 'us-west-2'] + mock_func.return_value = mock_connection + mock_plugin_service.driver_dialect = mock_dialect + mock_plugin_service.database_dialect = mock_dialect + mock_dialect.default_port = _DEFAULT_PG_PORT + + +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def set_token_cache(user, host, port, region, expired=False): + if not expired: + initial_token = TokenInfo(_TEST_TOKEN, datetime.now() + timedelta(minutes=5)) + else: + initial_token = TokenInfo(_TEST_TOKEN, datetime.now() - timedelta(minutes=5)) + cache_key: str = IamAuthUtils.get_cache_key( + user, + host, + port, + region + ) + _token_cache[cache_key] = initial_token + + return cache_key, initial_token + + +@pytest.mark.parametrize("user", [ + pytest.param("admin"), + pytest.param("non-admin"), +]) +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_pg_connect_valid_token_in_cache(user, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": user}) + cache_key, _ = set_token_cache(user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, _PG_REGION) + + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + actual_token = _token_cache.get(cache_key) + if user == "admin": + mock_client.generate_db_connect_admin_auth_token.assert_not_called() + assert _GENERATED_TOKEN != actual_token.token + else: + mock_client.generate_db_connect_auth_token.assert_not_called() + assert _GENERATED_TOKEN_NON_ADMIN != actual_token.token + + assert _TEST_TOKEN == actual_token.token + assert not actual_token.is_expired() + + +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_pg_connect_with_invalid_port_fall_backs_to_host_port( + mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": "admin"}) + invalid_port = "0" + test_props[WrapperProperties.IAM_DEFAULT_PORT.name] = invalid_port + + # Assert no password has been set + assert test_props.get("password") is None + + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO_WITH_PORT, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION + ) + + actual_token = _token_cache.get(f"{_PG_REGION}:{_PG_HOST_INFO.url}:1234:admin") + assert _GENERATED_TOKEN == actual_token.token + assert not actual_token.is_expired() + + # Assert password has been updated to the value in token cache + expected_props = {"user": "admin", "iam_default_port": "0"} + mock_dialect.set_password.assert_called_with(expected_props, _GENERATED_TOKEN) + + +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_pg_connect_with_invalid_port_and_no_host_port_fall_backs_to_host_port( + mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": "admin"}) + expected_default_pg_port = 5432 + invalid_port = "0" + test_props[WrapperProperties.IAM_DEFAULT_PORT.name] = invalid_port + + # Assert no password has been set + assert test_props.get("password") is None + + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION + ) + + actual_token = _token_cache.get( + f"{_PG_REGION}:{_PG_HOST_INFO.url}:{expected_default_pg_port}:admin") + assert _GENERATED_TOKEN == actual_token.token + assert not actual_token.is_expired() + + # Assert password has been updated to the value in token cache + expected_props = {"user": "admin", "iam_default_port": "0"} + mock_dialect.set_password.assert_called_with(expected_props, _GENERATED_TOKEN) + + +@pytest.mark.parametrize("user", [ + pytest.param("admin"), + pytest.param("non-admin"), +]) +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_expired_token_in_cache(user, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": user}) + cache_key, initial_token = set_token_cache(user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, _PG_REGION, True) + + mock_func.side_effect = Exception("generic exception") + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + with pytest.raises(Exception): + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + actual_token = _token_cache.get(cache_key) + assert initial_token != actual_token + assert not actual_token.is_expired() + + if user == "admin": + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION) + assert _GENERATED_TOKEN == actual_token.token + else: + mock_client.generate_db_connect_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION) + assert _GENERATED_TOKEN_NON_ADMIN == actual_token.token + + +@pytest.mark.parametrize("user", [ + pytest.param("admin"), + pytest.param("non-admin"), +]) +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_empty_cache(user, mocker, mock_plugin_service, mock_connection, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": user}) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + actual_connection = target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + cache_key: str = IamAuthUtils.get_cache_key( + user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, _PG_REGION + ) + actual_token = _token_cache.get(cache_key) + + if user == "admin": + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION + ) + assert _GENERATED_TOKEN == actual_token.token + else: + mock_client.generate_db_connect_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION) + assert _GENERATED_TOKEN_NON_ADMIN == actual_token.token + + assert mock_connection == actual_connection + assert not actual_token.is_expired() + + +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_with_specified_port(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": "admin"}) + cache_key_with_new_port: str = f"{_PG_REGION}:{_PG_HOST_INFO.url}:1234:admin" + initial_token = TokenInfo(f"{_TEST_TOKEN}:1234", datetime.now() + timedelta(minutes=5)) + _token_cache[cache_key_with_new_port] = initial_token + + # Assert no password has been set + assert test_props.get("password") is None + + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO_WITH_PORT, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_connect_admin_auth_token.assert_not_called() + + actual_token = _token_cache.get(cache_key_with_new_port) + assert _token_cache.get(_PG_CACHE_KEY) is None + assert _GENERATED_TOKEN != actual_token.token + assert f"{_TEST_TOKEN}:1234" == actual_token.token + assert not actual_token.is_expired() + + # Assert password has been updated to the value in token cache + expected_props = {"user": "admin"} + mock_dialect.set_password.assert_called_with(expected_props, f"{_TEST_TOKEN}:1234") + + +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_with_specified_iam_default_port(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": "admin"}) + iam_default_port: str = "9999" + test_props[WrapperProperties.IAM_DEFAULT_PORT.name] = iam_default_port + cache_key_with_new_port = f"{_PG_REGION}:{_PG_HOST_INFO.url}:{iam_default_port}:admin" + initial_token = TokenInfo(f"{_TEST_TOKEN}:{iam_default_port}", datetime.now() + timedelta(minutes=5)) + _token_cache[cache_key_with_new_port] = initial_token + + # Assert no password has been set + assert test_props.get("password") is None + + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO_WITH_PORT, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_connect_admin_auth_token.assert_not_called() + + actual_token = _token_cache.get(cache_key_with_new_port) + assert _token_cache.get(_PG_CACHE_KEY) is None + assert _GENERATED_TOKEN != actual_token.token + assert f"{_TEST_TOKEN}:{iam_default_port}" == actual_token.token + assert not actual_token.is_expired() + + # Assert password has been updated to the value in token cache + expected_props = {"user": "admin", "iam_default_port": "9999"} + mock_dialect.set_password.assert_called_with(expected_props, f"{_TEST_TOKEN}:{iam_default_port}") + + +@pytest.mark.parametrize("user", [ + pytest.param("admin"), + pytest.param("non-admin"), +]) +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_with_specified_region(user, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": user}) + iam_region: str = "us-east-1" + + # Cache a token with a different region + set_token_cache(user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, _PG_REGION) + test_props[WrapperProperties.IAM_REGION.name] = iam_region + + # Assert no password has been set + assert test_props.get("password") is None + + mock_client.generate_db_connect_admin_auth_token.return_value = f"{_TEST_TOKEN}:{iam_region}" + mock_client.generate_db_connect_auth_token.return_value = f"{_GENERATED_TOKEN_NON_ADMIN}:{iam_region}" + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=HostInfo(_PG_HOST_INFO.url), + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_session.client.assert_called_with( + "dsql", + region_name=iam_region + ) + + expected_props = {"iam_region": "us-east-1", "user": user} + actual_token = _token_cache.get(IamAuthUtils.get_cache_key(user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, iam_region)) + assert not actual_token.is_expired() + + if user == "admin": + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + _PG_HOST_INFO.url, iam_region + ) + assert f"{_TEST_TOKEN}:{iam_region}" == actual_token.token + mock_dialect.set_password.assert_called_with(expected_props, f"{_TEST_TOKEN}:{iam_region}") + else: + mock_client.generate_db_connect_auth_token.assert_called_with( + _PG_HOST_INFO.url, iam_region) + assert f"{_GENERATED_TOKEN_NON_ADMIN}:{iam_region}" == actual_token.token + mock_dialect.set_password.assert_called_with(expected_props, f"{_GENERATED_TOKEN_NON_ADMIN}:{iam_region}") + + +@pytest.mark.parametrize("iam_host", [ + pytest.param("dsqltestclusternamefoobar1.dsql.us-east-2.on.aws"), + pytest.param("dsqltestclusternamefoobar2.dsql.us-east-2.on.aws"), +]) +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_with_specified_host(iam_host: str, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": "admin"}) + + test_props[WrapperProperties.IAM_HOST.name] = iam_host + + # Assert no password has been set + assert test_props.get("password") is None + + mock_client.generate_db_connect_admin_auth_token.return_value = f"{_TEST_TOKEN}:{iam_host}" + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=HostInfo("bar.foo.com"), + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + iam_host, _PG_REGION + ) + + actual_token = _token_cache.get(f"{_PG_REGION}:{iam_host}:5432:admin") + assert actual_token is not None + assert _GENERATED_TOKEN != actual_token.token + assert f"{_TEST_TOKEN}:{iam_host}" == actual_token.token + assert not actual_token.is_expired() + + +def test_aws_supported_regions_url_exists(): + url = "https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html" + assert 200 == urllib.request.urlopen(url).getcode() diff --git a/tests/unit/test_iam_plugin.py b/tests/unit/test_iam_plugin.py index 04273698..10a3a3cd 100644 --- a/tests/unit/test_iam_plugin.py +++ b/tests/unit/test_iam_plugin.py @@ -26,6 +26,7 @@ from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin, TokenInfo from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils _GENERATED_TOKEN = "generated_token" _TEST_TOKEN = "test_token" @@ -99,6 +100,7 @@ def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_sessi _token_cache[_PG_CACHE_KEY] = initial_token target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -127,6 +129,7 @@ def test_pg_connect_with_invalid_port_fall_backs_to_host_port( assert test_props.get("password") is None target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -163,6 +166,7 @@ def test_pg_connect_with_invalid_port_and_no_host_port_fall_backs_to_host_port( assert test_props.get("password") is None target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -195,7 +199,9 @@ def test_connect_expired_token_in_cache(mocker, mock_plugin_service, mock_sessio _token_cache[_PG_CACHE_KEY] = initial_token mock_func.side_effect = Exception("generic exception") - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) with pytest.raises(Exception): target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -220,7 +226,9 @@ def test_connect_expired_token_in_cache(mocker, mock_plugin_service, mock_sessio @patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) def test_connect_empty_cache(mocker, mock_plugin_service, mock_connection, mock_session, mock_func, mock_client, mock_dialect): test_props: Properties = Properties({"user": "postgresqlUser"}) - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) actual_connection = target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, @@ -251,7 +259,9 @@ def test_connect_with_specified_port(mocker, mock_plugin_service, mock_session, # Assert no password has been set assert test_props.get("password") is None - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, @@ -285,7 +295,9 @@ def test_connect_with_specified_iam_default_port(mocker, mock_plugin_service, mo # Assert no password has been set assert test_props.get("password") is None - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, @@ -323,7 +335,9 @@ def test_connect_with_specified_region(mocker, mock_plugin_service, mock_session assert test_props.get("password") is None mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{iam_region}" - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, @@ -369,7 +383,9 @@ def test_connect_with_specified_host(iam_host: str, mocker, mock_plugin_service, assert test_props.get("password") is None mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{iam_host}" - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, @@ -411,7 +427,7 @@ def test_aws_supported_regions_url_exists(): def test_invalid_iam_host(host, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): test_props: Properties = Properties({"user": "postgresqlUser"}) with pytest.raises(AwsWrapperError): - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, RDSTokenUtils(), mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, diff --git a/tests/unit/test_okta_plugin.py b/tests/unit/test_okta_plugin.py index 72f9727a..e2823568 100644 --- a/tests/unit/test_okta_plugin.py +++ b/tests/unit/test_okta_plugin.py @@ -25,6 +25,7 @@ from aws_advanced_python_wrapper.okta_plugin import OktaAuthPlugin from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils _GENERATED_TOKEN = "generated_token" _TEST_TOKEN = "test_token" @@ -100,7 +101,7 @@ def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_sessi initial_token = TokenInfo(_TEST_TOKEN, datetime.now() + timedelta(minutes=5)) _token_cache[_PG_CACHE_KEY] = initial_token - target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_session) + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, RDSTokenUtils(), mock_session) key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + str(_DEFAULT_PG_PORT) + ":postgesqlUser" _token_cache[key] = initial_token @@ -127,7 +128,10 @@ def test_expired_cached_token(mocker, mock_plugin_service, mock_session, mock_fu initial_token = TokenInfo(_TEST_TOKEN, datetime.now() - timedelta(minutes=5)) _token_cache[_PG_CACHE_KEY] = initial_token - target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -151,7 +155,10 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m test_props: Properties = Properties({"plugins": "okta", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"}) WrapperProperties.DB_USER.set(test_props, _DB_USER) - target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -179,7 +186,10 @@ def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_sess exception_message = "generic exception" mock_func.side_effect = Exception(exception_message) - target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) with pytest.raises(Exception) as e_info: target_plugin.connect( @@ -225,7 +235,10 @@ def test_connect_with_specified_iam_host_port_region(mocker, mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{expected_region}" - target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect,