Skip to content

Commit 0a21054

Browse files
authored
chore: validate iam_host (#471)
1 parent a4a55b2 commit 0a21054

File tree

4 files changed

+68
-14
lines changed

4 files changed

+68
-14
lines changed

aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ IamAuthPlugin.InvalidPort=[IamAuthPlugin] Port number: {} is not valid. Port num
125125
IamAuthPlugin.NoValidPort=[IamAuthPlugin] Unable to determine a valid port.
126126
IamAuthPlugin.UnhandledException=[IamAuthPlugin] Unhandled exception: {}
127127
IamAuthPlugin.UseCachedIamToken=[IamAuthPlugin] Used cached IAM token = {}
128-
128+
IAMAuthPlugin.InvalidHost=[IamAuthPlugin] Invalid IAM host {}. The IAM host must be a valid RDS or Aurora endpoint.
129129
IamPlugin.IsNoneOrEmpty=[IamPlugin] Property "{}" is None or empty.
130130

131131
LogUtils.Topology=[LogUtils] Topology {}

aws_advanced_python_wrapper/utils/iamutils.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
from datetime import datetime
1818
from typing import TYPE_CHECKING
1919

20+
from aws_advanced_python_wrapper.errors import AwsWrapperError
21+
from aws_advanced_python_wrapper.utils.messages import Messages
22+
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
23+
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
24+
2025
if TYPE_CHECKING:
2126
from aws_advanced_python_wrapper.hostinfo import HostInfo
2227

@@ -25,10 +30,21 @@
2530

2631

2732
class IamAuthUtils:
28-
2933
@staticmethod
3034
def get_iam_host(props: Properties, host_info: HostInfo):
31-
return WrapperProperties.IAM_HOST.get(props) if WrapperProperties.IAM_HOST.get(props) else host_info.host
35+
host = WrapperProperties.IAM_HOST.get(props) if WrapperProperties.IAM_HOST.get(props) else host_info.host
36+
IamAuthUtils.validate_iam_host(host)
37+
return host
38+
39+
@staticmethod
40+
def validate_iam_host(host: str | None):
41+
if host is None:
42+
raise AwsWrapperError(Messages.get_formatted("IAMAuthPlugin.InvalidHost", "[No host provided]"))
43+
44+
utils = RdsUtils()
45+
rds_type = utils.identify_rds_type(host)
46+
if rds_type == RdsUrlType.OTHER or rds_type == RdsUrlType.IP_ADDRESS:
47+
raise AwsWrapperError(Messages.get_formatted("IAMAuthPlugin.InvalidHost", host))
3248

3349
@staticmethod
3450
def get_port(props: Properties, host_info: HostInfo, dialect_default_port: int) -> int:

tests/integration/container/test_iam_authentication.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,7 @@ def test_iam_wrong_database_username(self, test_environment: TestEnvironment,
7373
plugins="iam",
7474
**props)
7575

76-
def test_iam_no_database_username(self, test_environment: TestEnvironment,
77-
test_driver: TestDriver, conn_utils, props):
76+
def test_iam_no_database_username(self, test_driver: TestDriver, conn_utils, props):
7877
target_driver_connect = DriverHelper.get_connect_func(test_driver)
7978
params = conn_utils.get_connect_params()
8079
params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver
@@ -83,6 +82,15 @@ def test_iam_no_database_username(self, test_environment: TestEnvironment,
8382
with pytest.raises(AwsWrapperError):
8483
AwsWrapperConnection.connect(target_driver_connect, **params, plugins="iam", **props)
8584

85+
def test_iam_invalid_host(self, test_driver: TestDriver, conn_utils, props):
86+
target_driver_connect = DriverHelper.get_connect_func(test_driver)
87+
params = conn_utils.get_connect_params()
88+
params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver
89+
params.update({"iam_host": "<>", "plugins": "iam"})
90+
91+
with pytest.raises(AwsWrapperError):
92+
AwsWrapperConnection.connect(target_driver_connect, **params, **props)
93+
8694
def test_iam_using_ip_address(self, test_environment: TestEnvironment,
8795
test_driver: TestDriver, conn_utils, props):
8896
target_driver_connect = DriverHelper.get_connect_func(test_driver)

tests/unit/test_iam_plugin.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import pytest
2323

24+
from aws_advanced_python_wrapper.errors import AwsWrapperError
2425
from aws_advanced_python_wrapper.hostinfo import HostInfo
2526
from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin, TokenInfo
2627
from aws_advanced_python_wrapper.utils.properties import (Properties,
@@ -350,10 +351,17 @@ def test_connect_with_specified_region(mocker, mock_plugin_service, mock_session
350351
mock_dialect.set_password.assert_called_with(expected_props, f"{_TEST_TOKEN}:{iam_region}")
351352

352353

354+
@pytest.mark.parametrize("iam_host", [
355+
pytest.param("foo.testdb.us-east-2.rds.amazonaws.com"),
356+
pytest.param("test.cluster-123456789012.us-east-2.rds.amazonaws.com"),
357+
pytest.param("test-.cluster-ro-123456789012.us-east-2.rds.amazonaws.com"),
358+
pytest.param("test.cluster-custom-123456789012.us-east-2.rds.amazonaws.com"),
359+
pytest.param("test-.proxy-123456789012.us-east-2.rds.amazonaws.com.cn"),
360+
pytest.param("test-.proxy-123456789012.us-east-2.rds.amazonaws.com.proxy"),
361+
])
353362
@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache)
354-
def test_connect_with_specified_host(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect):
363+
def test_connect_with_specified_host(iam_host: str, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect):
355364
test_props: Properties = Properties({"user": "postgresqlUser"})
356-
iam_host: str = "foo.testdb.us-east-2.rds.amazonaws.com"
357365

358366
test_props[WrapperProperties.IAM_HOST.name] = iam_host
359367

@@ -365,7 +373,7 @@ def test_connect_with_specified_host(mocker, mock_plugin_service, mock_session,
365373
target_plugin.connect(
366374
target_driver_func=mocker.MagicMock(),
367375
driver_dialect=mock_dialect,
368-
host_info=HostInfo("pg.testdb.us-east-2.rds.amazonaws.com"),
376+
host_info=HostInfo("bar.foo.com"),
369377
props=test_props,
370378
is_initial_connection=False,
371379
connect_func=mock_func)
@@ -376,16 +384,38 @@ def test_connect_with_specified_host(mocker, mock_plugin_service, mock_session,
376384
DBUsername="postgresqlUser"
377385
)
378386

379-
actual_token = _token_cache.get("us-east-2:foo.testdb.us-east-2.rds.amazonaws.com:5432:postgresqlUser")
387+
actual_token = _token_cache.get(f"us-east-2:{iam_host}:5432:postgresqlUser")
388+
assert actual_token is not None
380389
assert _GENERATED_TOKEN != actual_token.token
381-
assert f"{_TEST_TOKEN}:foo.testdb.us-east-2.rds.amazonaws.com" == actual_token.token
390+
assert f"{_TEST_TOKEN}:{iam_host}" == actual_token.token
382391
assert actual_token.is_expired() is False
383392

384-
# Assert password has been updated to the value in token cache
385-
expected_props = {"iam_host": "foo.testdb.us-east-2.rds.amazonaws.com", "user": "postgresqlUser"}
386-
mock_dialect.set_password.assert_called_with(expected_props, f"{_TEST_TOKEN}:foo.testdb.us-east-2.rds.amazonaws.com")
387-
388393

389394
def test_aws_supported_regions_url_exists():
390395
url = "https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html"
391396
assert 200 == urllib.request.urlopen(url).getcode()
397+
398+
399+
@pytest.mark.parametrize("host", [
400+
pytest.param("<>"),
401+
pytest.param("#"),
402+
pytest.param("'"),
403+
pytest.param("\""),
404+
pytest.param("%"),
405+
pytest.param("^"),
406+
pytest.param("https://foo.com/abc.html"),
407+
pytest.param("foo.boo//"),
408+
pytest.param("8.8.8.8"),
409+
pytest.param("a.b"),
410+
])
411+
def test_invalid_iam_host(host, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect):
412+
test_props: Properties = Properties({"user": "postgresqlUser"})
413+
with pytest.raises(AwsWrapperError):
414+
target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session)
415+
target_plugin.connect(
416+
target_driver_func=mocker.MagicMock(),
417+
driver_dialect=mock_dialect,
418+
host_info=HostInfo(host),
419+
props=test_props,
420+
is_initial_connection=False,
421+
connect_func=mock_func)

0 commit comments

Comments
 (0)