Skip to content

Add support for DSQL iam authentication #919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ Since a database failover is usually identified by reaching a network or a conne

Enhanced Failure Monitoring (EFM) is a feature available from the [Host Monitoring Connection Plugin](./docs/using-the-python-driver/using-plugins/UsingTheHostMonitoringPlugin.md#enhanced-failure-monitoring) that periodically checks the connected database host's health and availability. If a database host is determined to be unhealthy, the connection is aborted (and potentially routed to another healthy host in the cluster).

### Using the AWS Advanced Python Driver with AWS Aurora DSQL
The AWS Advanced Python Driver is able to handle IAM authentication when working with AWS Aurora DSQL clusters.

Please visit [this page](./docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md) for more information.


### Using the AWS Advanced Python Driver with plain RDS databases

The AWS Advanced Python Driver also works with RDS provided databases that are not Aurora.
Expand Down
30 changes: 30 additions & 0 deletions aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.dsql_token_utils import DSQLTokenUtils

if TYPE_CHECKING:
from aws_advanced_python_wrapper.plugin_service import PluginService
from aws_advanced_python_wrapper.utils.properties import Properties


class DsqlIamAuthPluginFactory(PluginFactory):
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
return IamAuthPlugin(plugin_service, DSQLTokenUtils())
12 changes: 9 additions & 3 deletions aws_advanced_python_wrapper/federated_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.pep249 import Connection
from aws_advanced_python_wrapper.plugin_service import PluginService
from aws_advanced_python_wrapper.utils.token_utils import TokenUtils

from datetime import datetime, timedelta
from typing import Callable, Dict, Optional, Set
Expand All @@ -43,6 +44,7 @@
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.properties import (Properties,
WrapperProperties)
from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils

logger = Logger(__name__)
Expand All @@ -55,12 +57,16 @@ class FederatedAuthPlugin(Plugin):
_rds_utils: RdsUtils = RdsUtils()
_token_cache: Dict[str, TokenInfo] = {}

def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, session: Optional[Session] = None):
def __init__(self, plugin_service: PluginService,
credentials_provider_factory: CredentialsProviderFactory,
Comment on lines +60 to +61
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
def __init__(self, plugin_service: PluginService,
credentials_provider_factory: CredentialsProviderFactory,
def __init__(self,
plugin_service: PluginService,
credentials_provider_factory: CredentialsProviderFactory,

token_utils: TokenUtils,
session: Optional[Session] = None):
self._plugin_service = plugin_service
self._credentials_provider_factory = credentials_provider_factory
self._session = session

self._region_utils = RegionUtils()
self._token_utils = token_utils
telemetry_factory = self._plugin_service.get_telemetry_factory()
self._fetch_token_counter = telemetry_factory.create_counter("federated.fetch_token.count")
self._cache_size_gauge = telemetry_factory.create_gauge("federated.token_cache.size", lambda: len(FederatedAuthPlugin._token_cache))
Expand Down Expand Up @@ -145,7 +151,7 @@ def _update_authentication_token(self,
credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props)

self._fetch_token_counter.inc()
token: str = IamAuthUtils.generate_authentication_token(
token: str = self._token_utils.generate_authentication_token(
self._plugin_service,
user,
host_info.host,
Expand All @@ -159,7 +165,7 @@ def _update_authentication_token(self,

class FederatedAuthPluginFactory(PluginFactory):
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
return FederatedAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props))
return FederatedAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props), RDSTokenUtils())

def get_credentials_provider_factory(self, plugin_service: PluginService, props: Properties) -> AdfsCredentialsProviderFactory:
idp_name = WrapperProperties.IDP_NAME.get(props)
Expand Down
13 changes: 8 additions & 5 deletions aws_advanced_python_wrapper/iam_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from typing import TYPE_CHECKING

from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo
from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
from aws_advanced_python_wrapper.utils.region_utils import RegionUtils

if TYPE_CHECKING:
Expand All @@ -25,6 +27,7 @@
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.pep249 import Connection
from aws_advanced_python_wrapper.plugin_service import PluginService
from aws_advanced_python_wrapper.utils.token_utils import TokenUtils

from datetime import datetime, timedelta
from typing import Callable, Dict, Optional, Set
Expand All @@ -35,7 +38,6 @@
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.properties import (Properties,
WrapperProperties)
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils

logger = Logger(__name__)

Expand All @@ -48,11 +50,12 @@ class IamAuthPlugin(Plugin):
_rds_utils: RdsUtils = RdsUtils()
_token_cache: Dict[str, TokenInfo] = {}

def __init__(self, plugin_service: PluginService, session: Optional[Session] = None):
def __init__(self, plugin_service: PluginService, token_utils: TokenUtils, session: Optional[Session] = None):
self._plugin_service = plugin_service
self._session = session

self._region_utils = RegionUtils()
self._token_utils = token_utils
telemetry_factory = self._plugin_service.get_telemetry_factory()
self._fetch_token_counter = telemetry_factory.create_counter("iam.fetch_token.count")
self._cache_size_gauge = telemetry_factory.create_gauge(
Expand Down Expand Up @@ -102,7 +105,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
else:
token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec)
self._fetch_token_counter.inc()
token: str = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session)
token: str = self._token_utils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session)
self._plugin_service.driver_dialect.set_password(props, token)
IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)

Expand All @@ -120,7 +123,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
# Try to generate a new token and try to connect again
token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec)
self._fetch_token_counter.inc()
token = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session)
token = self._token_utils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session)
self._plugin_service.driver_dialect.set_password(props, token)
IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)

Expand All @@ -142,4 +145,4 @@ def force_connect(

class IamAuthPluginFactory(PluginFactory):
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
return IamAuthPlugin(plugin_service)
return IamAuthPlugin(plugin_service, RDSTokenUtils())
12 changes: 9 additions & 3 deletions aws_advanced_python_wrapper/okta_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.pep249 import Connection
from aws_advanced_python_wrapper.plugin_service import PluginService
from aws_advanced_python_wrapper.utils.token_utils import TokenUtils

import requests

Expand All @@ -40,6 +41,7 @@
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.properties import (Properties,
WrapperProperties)
from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils

logger = Logger(__name__)
Expand All @@ -51,12 +53,16 @@ class OktaAuthPlugin(Plugin):
_rds_utils: RdsUtils = RdsUtils()
_token_cache: Dict[str, TokenInfo] = {}

def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, session: Optional[Session] = None):
def __init__(self, plugin_service: PluginService,
credentials_provider_factory: CredentialsProviderFactory,
Comment on lines +56 to +57
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def __init__(self, plugin_service: PluginService,
credentials_provider_factory: CredentialsProviderFactory,
def __init__(self,
plugin_service: PluginService,
credentials_provider_factory: CredentialsProviderFactory,

token_utils: TokenUtils,
session: Optional[Session] = None):
self._plugin_service = plugin_service
self._credentials_provider_factory = credentials_provider_factory
self._session = session

self._region_utils = RegionUtils()
self._token_utils = token_utils
telemetry_factory = self._plugin_service.get_telemetry_factory()
self._fetch_token_counter = telemetry_factory.create_counter("okta.fetch_token.count")
self._cache_size_gauge = telemetry_factory.create_gauge("okta.token_cache.size", lambda: len(OktaAuthPlugin._token_cache))
Expand Down Expand Up @@ -140,7 +146,7 @@ def _update_authentication_token(self,
port: int = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port)
credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props)

token: str = IamAuthUtils.generate_authentication_token(
token: str = self._token_utils.generate_authentication_token(
self._plugin_service,
user,
host_info.host,
Expand Down Expand Up @@ -228,7 +234,7 @@ def get_saml_assertion(self, props: Properties):

class OktaAuthPluginFactory(PluginFactory):
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
return OktaAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props))
return OktaAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props), RDSTokenUtils())

def get_credentials_provider_factory(self, plugin_service: PluginService, props: Properties) -> OktaCredentialsProviderFactory:
return OktaCredentialsProviderFactory(plugin_service, props)
4 changes: 4 additions & 0 deletions aws_advanced_python_wrapper/plugin_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
from aws_advanced_python_wrapper.developer_plugin import DeveloperPluginFactory
from aws_advanced_python_wrapper.driver_configuration_profiles import \
DriverConfigurationProfiles
from aws_advanced_python_wrapper.dsql_iam_auth_plugin_factory import \
DsqlIamAuthPluginFactory
from aws_advanced_python_wrapper.errors import (AwsWrapperError,
QueryTimeoutError,
UnsupportedOperationError)
Expand Down Expand Up @@ -716,6 +718,7 @@ class PluginManager(CanReleaseResources):

PLUGIN_FACTORIES: Dict[str, Type[PluginFactory]] = {
"iam": IamAuthPluginFactory,
"iam_dsql": DsqlIamAuthPluginFactory,
"aws_secrets_manager": AwsSecretsManagerPluginFactory,
"aurora_connection_tracker": AuroraConnectionTrackerPluginFactory,
"host_monitoring": HostMonitoringPluginFactory,
Expand Down Expand Up @@ -748,6 +751,7 @@ class PluginManager(CanReleaseResources):
HostMonitoringPluginFactory: 500,
FastestResponseStrategyPluginFactory: 600,
IamAuthPluginFactory: 700,
DsqlIamAuthPluginFactory: 710,
AwsSecretsManagerPluginFactory: 800,
FederatedAuthPluginFactory: 900,
LimitlessPluginFactory: 950,
Expand Down
74 changes: 74 additions & 0 deletions aws_advanced_python_wrapper/utils/dsql_token_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional

from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.telemetry.telemetry import \
TelemetryTraceLevel
from aws_advanced_python_wrapper.utils.token_utils import TokenUtils

if TYPE_CHECKING:
from aws_advanced_python_wrapper.plugin_service import PluginService
from boto3 import Session

import boto3

logger = Logger(__name__)


class DSQLTokenUtils(TokenUtils):
def generate_authentication_token(
self,
plugin_service: PluginService,
user: Optional[str],
host_name: Optional[str],
port: Optional[int],
region: Optional[str],
credentials: Optional[Dict[str, str]] = None,
client_session: Optional[Session] = None) -> str:
telemetry_factory = plugin_service.get_telemetry_factory()
context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to change this to "fetch DSQL authentication token" to distinguish it from regular IAM auth? Or should we leave it with the same name?


try:
session = client_session if client_session else boto3.Session()
if credentials is not None:
client = session.client(
'dsql',
region_name=region,
aws_access_key_id=credentials.get('AccessKeyId'),
aws_secret_access_key=credentials.get('SecretAccessKey'),
aws_session_token=credentials.get('SessionToken')
)
else:
client = session.client(
'dsql',
region_name=region
)

if user == "admin":
token = client.generate_db_connect_admin_auth_token(host_name, region)
else:
token = client.generate_db_connect_auth_token(host_name, region)

logger.debug("IamAuthUtils.GeneratedNewAuthToken", token)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In RdsTokenUtils we close the client before returning the token, should we do that here as well or no?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is token credentials right? This is a security risk imo, I don't think it's a good idea even if it's debug logs.

return token
except Exception as ex:
context.set_success(False)
context.set_exception(ex)
raise ex
finally:
context.close_context()
54 changes: 1 addition & 53 deletions aws_advanced_python_wrapper/utils/iam_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,16 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Dict, Optional

import boto3
from typing import TYPE_CHECKING, Optional

from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
from aws_advanced_python_wrapper.utils.telemetry.telemetry import \
TelemetryTraceLevel

if TYPE_CHECKING:
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.plugin_service import PluginService
from boto3 import Session

from aws_advanced_python_wrapper.utils.properties import (Properties,
WrapperProperties)
Expand Down Expand Up @@ -70,52 +64,6 @@ def get_port(props: Properties, host_info: HostInfo, dialect_default_port: int)
def get_cache_key(user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str:
return f"{region}:{hostname}:{port}:{user}"

@staticmethod
def generate_authentication_token(
plugin_service: PluginService,
user: Optional[str],
host_name: Optional[str],
port: Optional[int],
region: Optional[str],
credentials: Optional[Dict[str, str]] = None,
client_session: Optional[Session] = None) -> str:
telemetry_factory = plugin_service.get_telemetry_factory()
context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED)

try:
session = client_session if client_session else boto3.Session()

if credentials is not None:
client = session.client(
'rds',
region_name=region,
aws_access_key_id=credentials.get('AccessKeyId'),
aws_secret_access_key=credentials.get('SecretAccessKey'),
aws_session_token=credentials.get('SessionToken')
)
else:
client = session.client(
'rds',
region_name=region
)

token = client.generate_db_auth_token(
DBHostname=host_name,
Port=port,
DBUsername=user
)

client.close()

logger.debug("IamAuthUtils.GeneratedNewAuthToken", token)
return token
except Exception as ex:
context.set_success(False)
context.set_exception(ex)
raise ex
finally:
context.close_context()


class TokenInfo:
@property
Expand Down
Loading