Skip to content

Add support for DSQL iam authentication #919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

from aws_advanced_python_wrapper.utils.dsql_token_utils import DSQLTokenUtils
from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.properties import (Properties)

if TYPE_CHECKING:
from aws_advanced_python_wrapper.plugin_service import PluginService

class DsqlIamAuthPluginFactory(PluginFactory):
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
return IamAuthPlugin(plugin_service, DSQLTokenUtils())
11 changes: 7 additions & 4 deletions aws_advanced_python_wrapper/iam_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo
from aws_advanced_python_wrapper.utils.region_utils import RegionUtils
from aws_advanced_python_wrapper.utils.token_utils import TokenUtils
from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils

if TYPE_CHECKING:
from boto3 import Session
Expand Down Expand Up @@ -48,11 +50,12 @@ class IamAuthPlugin(Plugin):
_rds_utils: RdsUtils = RdsUtils()
_token_cache: Dict[str, TokenInfo] = {}

def __init__(self, plugin_service: PluginService, session: Optional[Session] = None):
def __init__(self, plugin_service: PluginService, token_utils: TokenUtils, session: Optional[Session] = None):
self._plugin_service = plugin_service
self._session = session

self._region_utils = RegionUtils()
self._token_utils = token_utils
telemetry_factory = self._plugin_service.get_telemetry_factory()
self._fetch_token_counter = telemetry_factory.create_counter("iam.fetch_token.count")
self._cache_size_gauge = telemetry_factory.create_gauge(
Expand Down Expand Up @@ -102,7 +105,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
else:
token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec)
self._fetch_token_counter.inc()
token: str = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session)
token: str = self._token_utils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session)
self._plugin_service.driver_dialect.set_password(props, token)
IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)

Expand All @@ -120,7 +123,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
# Try to generate a new token and try to connect again
token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec)
self._fetch_token_counter.inc()
token = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session)
token = self._token_utils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session)
self._plugin_service.driver_dialect.set_password(props, token)
IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)

Expand All @@ -142,4 +145,4 @@ def force_connect(

class IamAuthPluginFactory(PluginFactory):
def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
return IamAuthPlugin(plugin_service)
return IamAuthPlugin(plugin_service, RDSTokenUtils())
3 changes: 3 additions & 0 deletions aws_advanced_python_wrapper/plugin_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
HostMonitoringPluginFactory
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.iam_plugin import IamAuthPluginFactory
from aws_advanced_python_wrapper.dsql_iam_auth_plugin_factory import DsqlIamAuthPluginFactory
from aws_advanced_python_wrapper.plugin import CanReleaseResources
from aws_advanced_python_wrapper.read_write_splitting_plugin import \
ReadWriteSplittingPluginFactory
Expand Down Expand Up @@ -716,6 +717,7 @@ class PluginManager(CanReleaseResources):

PLUGIN_FACTORIES: Dict[str, Type[PluginFactory]] = {
"iam": IamAuthPluginFactory,
"iam_dsql": DsqlIamAuthPluginFactory,
"aws_secrets_manager": AwsSecretsManagerPluginFactory,
"aurora_connection_tracker": AuroraConnectionTrackerPluginFactory,
"host_monitoring": HostMonitoringPluginFactory,
Expand Down Expand Up @@ -748,6 +750,7 @@ class PluginManager(CanReleaseResources):
HostMonitoringPluginFactory: 500,
FastestResponseStrategyPluginFactory: 600,
IamAuthPluginFactory: 700,
DsqlIamAuthPluginFactory: 710,
AwsSecretsManagerPluginFactory: 800,
FederatedAuthPluginFactory: 900,
LimitlessPluginFactory: 950,
Expand Down
60 changes: 60 additions & 0 deletions aws_advanced_python_wrapper/utils/dsql_token_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, Optional
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.telemetry.telemetry import \
TelemetryTraceLevel
from aws_advanced_python_wrapper.utils.token_utils import TokenUtils

if TYPE_CHECKING:
from aws_advanced_python_wrapper.plugin_service import PluginService
from boto3 import Session

import boto3

logger = Logger(__name__)

class DSQLTokenUtils(TokenUtils):
def generate_authentication_token(
self,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: move the self to row above similar to other methods

plugin_service: PluginService,
user: Optional[str],
host_name: Optional[str],
port: Optional[int],
region: Optional[str],
credentials: Optional[Dict[str, str]] = None,
client_session: Optional[Session] = None) -> str:
telemetry_factory = plugin_service.get_telemetry_factory()
context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to change this to "fetch DSQL authentication token" to distinguish it from regular IAM auth? Or should we leave it with the same name?


try:
client = boto3.client("dsql", region_name=region)

if user == "admin":
token = client.generate_db_connect_admin_auth_token(host_name, region)
else:
token = client.generate_db_connect_auth_token(host_name, region)

logger.debug("IamAuthUtils.GeneratedNewAuthToken", token)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In RdsTokenUtils we close the client before returning the token, should we do that here as well or no?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is token credentials right? This is a security risk imo, I don't think it's a good idea even if it's debug logs.

return token
except Exception as ex:
context.set_success(False)
context.set_exception(ex)
raise ex
finally:
context.close_context()
47 changes: 0 additions & 47 deletions aws_advanced_python_wrapper/utils/iam_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,53 +70,6 @@ def get_port(props: Properties, host_info: HostInfo, dialect_default_port: int)
def get_cache_key(user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str:
return f"{region}:{hostname}:{port}:{user}"

@staticmethod
def generate_authentication_token(
plugin_service: PluginService,
user: Optional[str],
host_name: Optional[str],
port: Optional[int],
region: Optional[str],
credentials: Optional[Dict[str, str]] = None,
client_session: Optional[Session] = None) -> str:
telemetry_factory = plugin_service.get_telemetry_factory()
context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED)

try:
session = client_session if client_session else boto3.Session()

if credentials is not None:
client = session.client(
'rds',
region_name=region,
aws_access_key_id=credentials.get('AccessKeyId'),
aws_secret_access_key=credentials.get('SecretAccessKey'),
aws_session_token=credentials.get('SessionToken')
)
else:
client = session.client(
'rds',
region_name=region
)

token = client.generate_db_auth_token(
DBHostname=host_name,
Port=port,
DBUsername=user
)

client.close()

logger.debug("IamAuthUtils.GeneratedNewAuthToken", token)
return token
except Exception as ex:
context.set_success(False)
context.set_exception(ex)
raise ex
finally:
context.close_context()


class TokenInfo:
@property
def token(self):
Expand Down
78 changes: 78 additions & 0 deletions aws_advanced_python_wrapper/utils/rds_token_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, Optional
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.telemetry.telemetry import \
TelemetryTraceLevel
from aws_advanced_python_wrapper.utils.token_utils import TokenUtils

if TYPE_CHECKING:
from aws_advanced_python_wrapper.plugin_service import PluginService
from boto3 import Session

import boto3

logger = Logger(__name__)

class RDSTokenUtils(TokenUtils):
def generate_authentication_token(
self,
plugin_service: PluginService,
user: Optional[str],
host_name: Optional[str],
port: Optional[int],
region: Optional[str],
credentials: Optional[Dict[str, str]] = None,
client_session: Optional[Session] = None) -> str:

telemetry_factory = plugin_service.get_telemetry_factory()
context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED)

try:
session = client_session if client_session else boto3.Session()

if credentials is not None:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can any of this common client acquisition logic be factored out into the base class?

client = session.client(
'rds',
region_name=region,
aws_access_key_id=credentials.get('AccessKeyId'),
aws_secret_access_key=credentials.get('SecretAccessKey'),
aws_session_token=credentials.get('SessionToken')
)
else:
client = session.client(
'rds',
region_name=region
)

token = client.generate_db_auth_token(
DBHostname=host_name,
Port=port,
DBUsername=user
)

client.close()

logger.debug("IamAuthUtils.GeneratedNewAuthToken", token)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same with this, we should remove it. Ik it was here before but not a good idea to display it as it's part of the credentials.

return token
except Exception as ex:
context.set_success(False)
context.set_exception(ex)
raise ex
finally:
context.close_context()
1 change: 1 addition & 0 deletions aws_advanced_python_wrapper/utils/rds_url_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,5 @@ def __init__(self, is_rds: bool, is_rds_cluster: bool):
RDS_PROXY = True, False,
RDS_INSTANCE = True, False,
RDS_AURORA_LIMITLESS_DB_SHARD_GROUP = True, False,
DSQL_CLUSTER = False, False,
OTHER = False, False
14 changes: 14 additions & 0 deletions aws_advanced_python_wrapper/utils/rdsutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ class RdsUtils:
r"(?P<dns>cluster-|cluster-ro-)+" \
r"(?P<domain>[a-zA-Z0-9]+\.rds\.(?P<region>[a-zA-Z0-9\-]+)" \
r"\.(amazonaws\.com|c2s\.ic\.gov|sc2s\.sgov\.gov))$"
AURORA_DSQL_CLUSTER_PATTERN = r"^(?P<instance>[^.]+)\." \
r"(?P<dns>dsql(?:-[^.]+)?)\." \
r"(?P<domain>(?P<region>[a-zA-Z0-9\-]+)" \
r"\.on\.aws\.?)$"
ELB_PATTERN = r"^(?<instance>.+)\.elb\.((?<region>[a-zA-Z0-9\-]+)\.amazonaws\.com)$"

IP_V4 = r"^(([1-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){1}" \
Expand Down Expand Up @@ -148,6 +152,14 @@ def is_rds_dns(self, host: str) -> bool:

def is_rds_instance(self, host: str) -> bool:
return self._get_dns_group(host) is None and self.is_rds_dns(host)

def is_dsql_cluster(self, host: str) -> bool:
if not host or not host.strip():
return False

pattern = self._find(host, [RdsUtils.AURORA_DSQL_CLUSTER_PATTERN])

return pattern is not None

def is_rds_proxy_dns(self, host: str) -> bool:
dns_group = self._get_dns_group(host)
Expand Down Expand Up @@ -257,6 +269,8 @@ def identify_rds_type(self, host: Optional[str]) -> RdsUrlType:
return RdsUrlType.RDS_PROXY
elif self.is_rds_instance(host):
return RdsUrlType.RDS_INSTANCE
elif self.is_dsql_cluster(host):
return RdsUrlType.DSQL_CLUSTER

return RdsUrlType.OTHER

Expand Down
35 changes: 35 additions & 0 deletions aws_advanced_python_wrapper/utils/token_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, Optional

if TYPE_CHECKING:
from aws_advanced_python_wrapper.plugin_service import PluginService
from boto3 import Session

class TokenUtils(ABC):
@abstractmethod
def generate_authentication_token(
self,
plugin_service: PluginService,
user: Optional[str],
host_name: Optional[str],
port: Optional[int],
region: Optional[str],
credentials: Optional[Dict[str, str]] = None,
client_session: Optional[Session] = None) -> str:
pass
Loading
Loading