Skip to content

Commit a1c6593

Browse files
add plugin for DSQL auth tokens
1 parent 3b32ac2 commit a1c6593

File tree

8 files changed

+474
-1
lines changed

8 files changed

+474
-1
lines changed

wrapper/build.gradle.kts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ dependencies {
3737
compileOnly("software.amazon.awssdk:rds:2.31.78")
3838
compileOnly("software.amazon.awssdk:auth:2.31.45") // Required for IAM (light implementation)
3939
compileOnly("software.amazon.awssdk:http-client-spi:2.31.60") // Required for IAM (light implementation)
40+
compileOnly("software.amazon.awssdk:dsql:2.31.78")
4041
compileOnly("software.amazon.awssdk:sts:2.31.78")
4142
compileOnly("com.zaxxer:HikariCP:4.0.3") // Version 4.+ is compatible with Java 8
4243
compileOnly("com.mchange:c3p0:0.11.0")
@@ -73,6 +74,7 @@ dependencies {
7374
testImplementation("software.amazon.awssdk:rds:2.31.78")
7475
testImplementation("software.amazon.awssdk:auth:2.31.45") // Required for IAM (light implementation)
7576
testImplementation("software.amazon.awssdk:http-client-spi:2.31.60") // Required for IAM (light implementation)
77+
testImplementation("software.amazon.awssdk:dsql:2.31.78")
7678
testImplementation("software.amazon.awssdk:ec2:2.31.78")
7779
testImplementation("software.amazon.awssdk:secretsmanager:2.31.12")
7880
testImplementation("software.amazon.awssdk:sts:2.31.78")

wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import software.amazon.jdbc.plugin.failover.FailoverConnectionPluginFactory;
4646
import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPluginFactory;
4747
import software.amazon.jdbc.plugin.federatedauth.OktaAuthPluginFactory;
48+
import software.amazon.jdbc.plugin.iam.DsqlIamConnectionPluginFactory;
4849
import software.amazon.jdbc.plugin.iam.IamAuthConnectionPluginFactory;
4950
import software.amazon.jdbc.plugin.limitless.LimitlessConnectionPluginFactory;
5051
import software.amazon.jdbc.plugin.readwritesplitting.ReadWriteSplittingPluginFactory;
@@ -75,6 +76,7 @@ public class ConnectionPluginChainBuilder {
7576
put("failover", new FailoverConnectionPluginFactory());
7677
put("failover2", new software.amazon.jdbc.plugin.failover2.FailoverConnectionPluginFactory());
7778
put("iam", new IamAuthConnectionPluginFactory());
79+
put("iamDsql", new DsqlIamConnectionPluginFactory());
7880
put("awsSecretsManager", new AwsSecretsManagerConnectionPluginFactory());
7981
put("federatedAuth", new FederatedAuthPluginFactory());
8082
put("okta", new OktaAuthPluginFactory());
@@ -114,6 +116,7 @@ public class ConnectionPluginChainBuilder {
114116
put(FastestResponseStrategyPluginFactory.class, 900);
115117
put(LimitlessConnectionPluginFactory.class, 950);
116118
put(IamAuthConnectionPluginFactory.class, 1000);
119+
put(DsqlIamConnectionPluginFactory.class, 1001);
117120
put(AwsSecretsManagerConnectionPluginFactory.class, 1100);
118121
put(FederatedAuthPluginFactory.class, 1200);
119122
put(LogQueryConnectionPluginFactory.class, 1300);
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package software.amazon.jdbc.plugin.iam;
18+
19+
import java.util.Properties;
20+
import org.checkerframework.checker.nullness.qual.NonNull;
21+
import software.amazon.jdbc.ConnectionPlugin;
22+
import software.amazon.jdbc.ConnectionPluginFactory;
23+
import software.amazon.jdbc.PluginService;
24+
25+
/**
26+
* Provides {@link ConnectionPlugin} instances which can be used to connect to Amazon Aurora DSQL.
27+
*/
28+
public class DsqlIamConnectionPluginFactory implements ConnectionPluginFactory {
29+
@Override
30+
public ConnectionPlugin getInstance(@NonNull final PluginService pluginService, @NonNull final Properties props) {
31+
return new IamAuthConnectionPlugin(pluginService, new DsqlTokenUtility());
32+
}
33+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package software.amazon.jdbc.plugin.iam;
18+
19+
import org.checkerframework.checker.nullness.qual.NonNull;
20+
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
21+
import software.amazon.awssdk.regions.Region;
22+
import software.amazon.awssdk.services.dsql.DsqlUtilities;
23+
24+
/**
25+
* Represents an {@link IamTokenUtility} which provides auth tokens for connecting to Amazon Aurora DSQL.
26+
*/
27+
public class DsqlTokenUtility implements IamTokenUtility {
28+
29+
private DsqlUtilities utilities = null;
30+
31+
public DsqlTokenUtility() { }
32+
33+
@Override
34+
public String generateAuthenticationToken(
35+
@NonNull final AwsCredentialsProvider credentialsProvider,
36+
@NonNull final Region region,
37+
@NonNull final String hostname,
38+
final int port,
39+
@NonNull final String username) {
40+
if (this.utilities == null) {
41+
this.utilities = DsqlUtilities.builder()
42+
.credentialsProvider(credentialsProvider)
43+
.region(region)
44+
.build();
45+
}
46+
if (username.equals("admin")) {
47+
return this.utilities.generateDbConnectAdminAuthToken((builder) ->
48+
builder.hostname(hostname).region(region)
49+
);
50+
} else {
51+
return this.utilities.generateDbConnectAuthToken((builder) ->
52+
builder.hostname(hostname).region(region)
53+
);
54+
}
55+
}
56+
}

wrapper/src/main/java/software/amazon/jdbc/util/RdsUtils.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,14 @@ public class RdsUtils {
144144
+ "\\.(amazonaws\\.com\\.?|c2s\\.ic\\.gov\\.?|sc2s\\.sgov\\.gov\\.?))$",
145145
Pattern.CASE_INSENSITIVE);
146146

147+
private static final Pattern AURORA_DSQL_CLUSTER_PATTERN =
148+
Pattern.compile(
149+
"^(?<instance>[^.]+)\\."
150+
+ "(?<dns>dsql(?:-[^.]+)?)\\."
151+
+ "(?<domain>(?<region>[a-zA-Z0-9\\-]+)"
152+
+ "\\.on\\.aws\\.?)$",
153+
Pattern.CASE_INSENSITIVE);
154+
147155
private static final Pattern ELB_PATTERN =
148156
Pattern.compile(
149157
"^(?<instance>.+)\\.elb\\."
@@ -259,14 +267,25 @@ public String getRdsInstanceHostPattern(final String host) {
259267
return group == null ? "?" : "?." + group;
260268
}
261269

270+
public String getDsqlInstanceId(final String host) {
271+
final String preparedHost = getPreparedHost(host);
272+
if (StringUtils.isNullOrEmpty(preparedHost)) {
273+
return null;
274+
}
275+
276+
final Matcher matcher = cacheMatcher(preparedHost, AURORA_DSQL_CLUSTER_PATTERN);
277+
return getRegexGroup(matcher, INSTANCE_GROUP);
278+
}
279+
262280
public String getRdsRegion(final String host) {
263281
final String preparedHost = getPreparedHost(host);
264282
if (StringUtils.isNullOrEmpty(preparedHost)) {
265283
return null;
266284
}
267285

268286
final Matcher matcher = cacheMatcher(preparedHost,
269-
AURORA_DNS_PATTERN, AURORA_CHINA_DNS_PATTERN, AURORA_OLD_CHINA_DNS_PATTERN, AURORA_GOV_DNS_PATTERN);
287+
AURORA_DNS_PATTERN, AURORA_CHINA_DNS_PATTERN, AURORA_OLD_CHINA_DNS_PATTERN, AURORA_GOV_DNS_PATTERN,
288+
AURORA_DSQL_CLUSTER_PATTERN);
270289
final String group = getRegexGroup(matcher, REGION_GROUP);
271290
if (group != null) {
272291
return group;
@@ -294,6 +313,11 @@ public boolean isLimitlessDbShardGroupDns(final String host) {
294313
return dnsGroup != null && dnsGroup.equalsIgnoreCase("shardgrp-");
295314
}
296315

316+
public boolean isDsqlCluster(final String host) {
317+
final String instanceId = getDsqlInstanceId(host);
318+
return instanceId != null;
319+
}
320+
297321
public String getRdsClusterHostUrl(final String host) {
298322
final String preparedHost = getPreparedHost(host);
299323
if (StringUtils.isNullOrEmpty(preparedHost)) {
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package software.amazon.jdbc.plugin.iam;
18+
19+
import static org.junit.jupiter.api.Assertions.assertEquals;
20+
import static org.junit.jupiter.api.Assertions.assertFalse;
21+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
22+
import static org.junit.jupiter.api.Assertions.assertNotNull;
23+
import static org.junit.jupiter.api.Assertions.assertTrue;
24+
import static org.mockito.ArgumentMatchers.anyString;
25+
import static org.mockito.ArgumentMatchers.eq;
26+
import static org.mockito.Mockito.doReturn;
27+
import static software.amazon.jdbc.plugin.iam.DsqlTokenUtilityTest.ADMIN_USER;
28+
import static software.amazon.jdbc.plugin.iam.DsqlTokenUtilityTest.REGULAR_USER;
29+
30+
import java.sql.Connection;
31+
import java.sql.SQLException;
32+
import java.util.List;
33+
import java.util.Properties;
34+
import org.junit.jupiter.api.AfterEach;
35+
import org.junit.jupiter.api.BeforeEach;
36+
import org.junit.jupiter.api.Test;
37+
import org.junit.jupiter.params.ParameterizedTest;
38+
import org.junit.jupiter.params.provider.ValueSource;
39+
import org.mockito.Mock;
40+
import org.mockito.Mockito;
41+
import org.mockito.MockitoAnnotations;
42+
import software.amazon.awssdk.regions.Region;
43+
import software.amazon.jdbc.ConnectionPlugin;
44+
import software.amazon.jdbc.ConnectionPluginChainBuilder;
45+
import software.amazon.jdbc.ConnectionProvider;
46+
import software.amazon.jdbc.HostSpec;
47+
import software.amazon.jdbc.HostSpecBuilder;
48+
import software.amazon.jdbc.JdbcCallable;
49+
import software.amazon.jdbc.PluginManagerService;
50+
import software.amazon.jdbc.PluginService;
51+
import software.amazon.jdbc.PropertyDefinition;
52+
import software.amazon.jdbc.dialect.Dialect;
53+
import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy;
54+
import software.amazon.jdbc.plugin.TokenInfo;
55+
import software.amazon.jdbc.util.FullServicesContainer;
56+
import software.amazon.jdbc.util.IamAuthUtils;
57+
import software.amazon.jdbc.util.telemetry.TelemetryContext;
58+
import software.amazon.jdbc.util.telemetry.TelemetryFactory;
59+
import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel;
60+
61+
class DsqlIamConnectionPluginTest {
62+
63+
private static final Region TEST_REGION = Region.US_EAST_1;
64+
private static final String TEST_HOSTNAME = String.format("foo0bar1baz2quux3quuux4.dsql.%s.on.aws", TEST_REGION);
65+
private static final int TEST_PORT = 5432;
66+
67+
private static final String DRIVER_PROTOCOL = "jdbc:postgresql:";
68+
private static final HostSpec HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy())
69+
.host(TEST_HOSTNAME).port(TEST_PORT).build();
70+
71+
private static final String DEFAULT_USERNAME = "admin";
72+
73+
private final Properties props = new Properties();
74+
75+
static void assertTokenContainsProperties(final String token, final String hostname, final String username) {
76+
assertNotNull(token);
77+
assertFalse(token.isEmpty());
78+
79+
assertTrue(token.contains(hostname));
80+
81+
final String expectedAction;
82+
if (username.equals("admin")) {
83+
expectedAction = "DbConnectAdmin";
84+
} else {
85+
expectedAction = "DbConnect";
86+
}
87+
88+
// Include the ampersand to ensure the complete action is compared.
89+
assertTrue(token.contains("Action=" + expectedAction + "&"));
90+
}
91+
92+
private AutoCloseable cleanMocksCallback;
93+
@Mock private Connection mockConnection;
94+
@Mock private PluginService mockPluginService;
95+
@Mock private FullServicesContainer mockServicesContainer;
96+
@Mock private Dialect mockDialect;
97+
@Mock private TelemetryFactory mockTelemetryFactory;
98+
@Mock private TelemetryContext mockTelemetryContext;
99+
@Mock private JdbcCallable<Connection, SQLException> mockLambda;
100+
@Mock private ConnectionProvider mockConnectionProvider;
101+
@Mock private PluginManagerService mockPluginManagerService;
102+
103+
@BeforeEach
104+
public void init() {
105+
cleanMocksCallback = MockitoAnnotations.openMocks(this);
106+
107+
IamAuthConnectionPlugin.clearCache();
108+
109+
props.setProperty(PropertyDefinition.USER.name, DEFAULT_USERNAME);
110+
props.setProperty("iamRegion", Region.US_EAST_1.toString());
111+
props.setProperty(PropertyDefinition.PLUGINS.name, "iamDsql");
112+
113+
doReturn(mockPluginService).when(mockServicesContainer).getPluginService();
114+
doReturn(mockDialect).when(mockPluginService).getDialect();
115+
doReturn(TEST_PORT).when(mockDialect).getDefaultPort();
116+
doReturn(mockTelemetryFactory).when(mockPluginService).getTelemetryFactory();
117+
doReturn(mockTelemetryContext).when(mockTelemetryFactory)
118+
.openTelemetryContext(anyString(), eq(TelemetryTraceLevel.NESTED));
119+
}
120+
121+
@AfterEach
122+
public void cleanup() throws Exception {
123+
cleanMocksCallback.close();
124+
}
125+
126+
@SuppressWarnings("resource") // Prevent Mockito warning when mocking closeable return type.
127+
private void assertPluginProvidesDsqlTokens(final ConnectionPlugin plugin, final String username)
128+
throws SQLException {
129+
Mockito.doReturn(mockConnection).when(mockLambda).call();
130+
131+
plugin
132+
.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda)
133+
.close();
134+
135+
final String cacheKey = IamAuthUtils.getCacheKey(
136+
username,
137+
TEST_HOSTNAME,
138+
TEST_PORT,
139+
TEST_REGION);
140+
141+
final TokenInfo info = IamAuthCacheHolder.tokenCache.get(cacheKey);
142+
final String token = info.getToken();
143+
144+
assertTokenContainsProperties(token, TEST_HOSTNAME, username);
145+
}
146+
147+
@Test
148+
public void testDsqlPluginRegistration() throws SQLException {
149+
ConnectionPluginChainBuilder builder = new ConnectionPluginChainBuilder();
150+
151+
final List<ConnectionPlugin> result = builder.getPlugins(
152+
mockServicesContainer,
153+
mockConnectionProvider,
154+
null,
155+
mockPluginManagerService,
156+
props,
157+
null);
158+
159+
// 2 because default plugin is always included.
160+
assertEquals(2, result.size());
161+
final ConnectionPlugin plugin = result.get(0);
162+
163+
assertInstanceOf(IamAuthConnectionPlugin.class, plugin);
164+
assertPluginProvidesDsqlTokens(plugin, DEFAULT_USERNAME);
165+
}
166+
167+
@ParameterizedTest
168+
@ValueSource(strings = {REGULAR_USER, ADMIN_USER})
169+
public void testDsqlTokenGeneratedBasedOnUser(final String username) throws SQLException {
170+
props.setProperty(PropertyDefinition.USER.name, username);
171+
172+
final DsqlIamConnectionPluginFactory factory = new DsqlIamConnectionPluginFactory();
173+
final ConnectionPlugin plugin = factory.getInstance(mockPluginService, props);
174+
assertPluginProvidesDsqlTokens(plugin, username);
175+
}
176+
}

0 commit comments

Comments
 (0)