diff --git a/ci/start-broker.sh b/ci/start-broker.sh index 38c60cc1b9..f954f9aaa8 100755 --- a/ci/start-broker.sh +++ b/ci/start-broker.sh @@ -20,7 +20,7 @@ cp -R "${PWD}"/tls-gen/basic/result/* rabbitmq-configuration/tls chmod o+r rabbitmq-configuration/tls/* chmod g+r rabbitmq-configuration/tls/* -echo "[rabbitmq_stream,rabbitmq_mqtt,rabbitmq_stomp,rabbitmq_amqp1_0,rabbitmq_auth_mechanism_ssl]." >> rabbitmq-configuration/enabled_plugins +echo "[rabbitmq_stream,rabbitmq_mqtt,rabbitmq_stomp,rabbitmq_amqp1_0,rabbitmq_auth_mechanism_ssl,rabbitmq_auth_backend_oauth2]." >> rabbitmq-configuration/enabled_plugins echo "loopback_users = none @@ -37,8 +37,25 @@ auth_mechanisms.1 = PLAIN auth_mechanisms.2 = ANONYMOUS auth_mechanisms.3 = EXTERNAL +auth_backends.1 = internal +auth_backends.2 = rabbit_auth_backend_oauth2 + stream.listeners.ssl.1 = 5551" >> rabbitmq-configuration/rabbitmq.conf +echo "[ + {rabbitmq_auth_backend_oauth2, [{key_config, + [{signing_keys, + #{<<\"token-key\">> => + {map, + #{<<\"alg\">> => <<\"HS256\">>, + <<\"k\">> => <<\"abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGH\">>, + <<\"kid\">> => <<\"token-key\">>, + <<\"kty\">> => <<\"oct\">>, + <<\"use\">> => <<\"sig\">>, + <<\"value\">> => <<\"token-key\">>}}}}]}, + {resource_server_id,<<\"rabbitmq\">>}]} +]." >> rabbitmq-configuration/advanced.config + echo "Running RabbitMQ ${RABBITMQ_IMAGE}" docker rm -f rabbitmq 2>/dev/null || echo "rabbitmq was not running" diff --git a/pom.xml b/pom.xml index 5bc6656ac7..930cfd580d 100644 --- a/pom.xml +++ b/pom.xml @@ -54,6 +54,7 @@ 0.34.1 4.2.33 1.15.3 + 2.13.1 13.1.2 4.7.5 1.28.0 @@ -66,10 +67,10 @@ 5.26.0 3.18.0 1.19.0 - 2.13.1 0.10.7 1.2.5 1.5.3 + 1.81 1.0.4 2.0.72.Final 3.14.0 @@ -105,6 +106,7 @@ 0.8.0 true true + 0.9.6 @@ -185,6 +187,13 @@ true + + com.google.code.gson + gson + ${gson.version} + true + + org.junit.jupiter junit-jupiter-engine @@ -285,13 +294,6 @@ test - - com.google.code.gson - gson - ${gson.version} - test - - io.vavr vavr @@ -313,6 +315,20 @@ test + + org.bouncycastle + bcpkix-jdk18on + ${bouncycastle.version} + test + + + + org.bitbucket.b_c + jose4j + ${jose4j.version} + test + + org.openjdk.jmh jmh-core diff --git a/src/docs/asciidoc/advanced-topics.adoc b/src/docs/asciidoc/advanced-topics.adoc index 0c338657b9..ba2254dc83 100644 --- a/src/docs/asciidoc/advanced-topics.adoc +++ b/src/docs/asciidoc/advanced-topics.adoc @@ -86,6 +86,31 @@ A defined set of values shared across the messages is a good candidate: geograph Cardinality of filter values can be from a few to a few thousands. Extreme cardinality (a couple or dozens of thousands) can make filtering less efficient. +=== OAuth 2 Support + +The client can authenticate against an OAuth 2 server like https://github.com/cloudfoundry/uaa[UAA]. +It uses the https://tools.ietf.org/html/rfc6749#section-4.4[OAuth 2 Client Credentials flow]. +The https://www.rabbitmq.com/docs/oauth2[OAuth 2 plugin] must be enabled on the server side and configured to use the same OAuth 2 server as the client. + +How to retrieve the OAuth 2 token is configured at the environment level: + +.Configuring OAuth 2 token retrieval +[source,java,indent=0] +-------- +include::{test-examples}/EnvironmentUsage.java[tag=oauth2] +-------- +<1> Access the OAuth 2 configuration +<2> Set the token endpoint URI +<3> Authenticate the client application +<4> Set the grant type +<5> Set optional parameters (depends on the OAuth 2 server) +<6> Set the SSL context (e.g. to verify and trust the identity of the OAuth 2 server) + +The environment retrieves tokens and uses them to create stream connections. +It also takes care of refreshing the tokens before they expire and of re-authenticating existing connections so the broker does not close them when their token expires. + +The environment uses the same token for all the connections it maintains. + === Using Native `epoll` The stream Java client uses the https://netty.io/[Netty] network framework and its Java NIO transport implementation by default. diff --git a/src/main/java/com/rabbitmq/stream/EnvironmentBuilder.java b/src/main/java/com/rabbitmq/stream/EnvironmentBuilder.java index 4c8a5239f5..6eb1506f8b 100644 --- a/src/main/java/com/rabbitmq/stream/EnvironmentBuilder.java +++ b/src/main/java/com/rabbitmq/stream/EnvironmentBuilder.java @@ -31,6 +31,7 @@ import java.util.Map; import java.util.concurrent.ScheduledExecutorService; import java.util.function.Consumer; +import javax.net.ssl.SSLContext; /** * API to configure and create an {@link Environment}. @@ -517,4 +518,90 @@ interface NettyConfiguration { */ EnvironmentBuilder environmentBuilder(); } + + /** + * OAuth 2 settings. + * + * @return OAuth 2 settings + * @see OAuth2Configuration + * @since 1.3.0 + */ + OAuth2Configuration oauth2(); + + /** + * Configuration to retrieve a token using the OAuth 2 Client Credentials flow. + * + * @since 1.3.0 + */ + interface OAuth2Configuration { + + /** + * Set the URI to access to get the token. + * + *

TLS is supported by providing a HTTPS URI and setting a {@link + * javax.net.ssl.SSLContext}. See {@link #tls()} for more information. Applications in + * production should always use HTTPS to retrieve tokens. + * + * @param uri access URI + * @return OAuth 2 configuration + * @see #sslContext(javax.net.ssl.SSLContext) + */ + OAuth2Configuration tokenEndpointUri(String uri); + + /** + * Set the OAuth 2 client ID + * + *

The client ID usually identifies the application that requests a token. + * + * @param clientId client ID + * @return OAuth 2 configuration + */ + OAuth2Configuration clientId(String clientId); + + /** + * Set the secret (password) to use to get a token. + * + * @param clientSecret client secret + * @return OAuth 2 configuration + */ + OAuth2Configuration clientSecret(String clientSecret); + + /** + * Set the grant type to use when requesting the token. + * + *

The default is client_credentials, but some OAuth 2 servers can use + * non-standard grant types to request tokens with extra-information. + * + * @param grantType grant type + * @return OAuth 2 configuration + */ + OAuth2Configuration grantType(String grantType); + + /** + * Set a parameter to pass in the request. + * + *

The OAuth 2 server may require extra parameters to narrow down the identity of the user. + * + * @param name name of the parameter + * @param value value of the parameter + * @return OAuth 2 configuration + */ + OAuth2Configuration parameter(String name, String value); + + /** + * {@link javax.net.ssl.SSLContext} for HTTPS requests. + * + * @param sslContext the SSL context + * @return OAuth 2 configuration + */ + OAuth2Configuration sslContext(SSLContext sslContext); + + /** + * Go back to the environment builder + * + * @return the environment builder + */ + EnvironmentBuilder environmentBuilder(); + } } diff --git a/src/main/java/com/rabbitmq/stream/impl/Client.java b/src/main/java/com/rabbitmq/stream/impl/Client.java index 892e7553fb..56ac077aaf 100644 --- a/src/main/java/com/rabbitmq/stream/impl/Client.java +++ b/src/main/java/com/rabbitmq/stream/impl/Client.java @@ -48,6 +48,8 @@ import com.rabbitmq.stream.impl.ServerFrameHandler.FrameHandlerInfo; import com.rabbitmq.stream.metrics.MetricsCollector; import com.rabbitmq.stream.metrics.NoOpMetricsCollector; +import com.rabbitmq.stream.oauth2.CredentialsManager; +import com.rabbitmq.stream.oauth2.CredentialsManager.Registration; import com.rabbitmq.stream.sasl.CredentialsProvider; import com.rabbitmq.stream.sasl.DefaultSaslConfiguration; import com.rabbitmq.stream.sasl.DefaultUsernamePasswordCredentialsProvider; @@ -115,6 +117,7 @@ */ public class Client implements AutoCloseable { + private static final AtomicLong ID_SEQUENCE = new AtomicLong(0); private static final Charset CHARSET = StandardCharsets.UTF_8; public static final int DEFAULT_PORT = 5552; public static final int DEFAULT_TLS_PORT = 5551; @@ -170,7 +173,6 @@ public long applyAsLong(Object value) { }; private final AtomicInteger correlationSequence = new AtomicInteger(0); private final SaslConfiguration saslConfiguration; - private final CredentialsProvider credentialsProvider; private final Runnable nettyClosing; private final int maxFrameSize; private final boolean frameSizeCapped; @@ -190,6 +192,7 @@ public long applyAsLong(Object value) { private final Runnable streamStatsCommandVersionsCheck; private final boolean filteringSupported; private final Runnable superStreamManagementCommandVersionsCheck; + private final Registration credentialsRegistration; @SuppressFBWarnings("CT_CONSTRUCTOR_THROW") public Client() { @@ -206,7 +209,6 @@ public Client(ClientParameters parameters) { this.creditNotification = parameters.creditNotification; this.codec = parameters.codec == null ? Codecs.DEFAULT : parameters.codec; this.saslConfiguration = parameters.saslConfiguration; - this.credentialsProvider = parameters.credentialsProvider; this.chunkChecksum = parameters.chunkChecksum; this.metricsCollector = parameters.metricsCollector; this.metadataListener = parameters.metadataListener; @@ -381,8 +383,36 @@ public void initChannel(SocketChannel ch) { debug(() -> "starting SASL handshake"); this.saslMechanisms = getSaslMechanisms(); debug(() -> "SASL mechanisms supported by server ({})", this.saslMechanisms); + + CredentialsProvider credentialsProvider = parameters.credentialsProvider; + CredentialsManager credentialsManager = parameters.credentialsManager; + CredentialsManager.AuthenticationCallback authCallback, renewCallback; + String regName = + clientConnectionName.isBlank() + ? String.valueOf(ID_SEQUENCE.getAndIncrement()) + : clientConnectionName + "-" + ID_SEQUENCE.getAndIncrement(); + if (credentialsManager == null) { + this.credentialsRegistration = CredentialsManagerFactory.get(); + authCallback = (u, p) -> this.authenticate(credentialsProvider); + } else { + renewCallback = + authCallback = + (u, p) -> { + if (u == null && p == null) { + // no username/password provided by the credentials manager + // using the credentials manager + this.authenticate(credentialsProvider); + } else { + // the credentials manager provides username/password (e.g. after token + // retrieval) + // we use them with a one-time credentials provider + this.authenticate(new DefaultUsernamePasswordCredentialsProvider(u, p)); + } + }; + this.credentialsRegistration = credentialsManager.register(regName, renewCallback); + } debug(() -> "starting authentication"); - authenticate(this.credentialsProvider); + this.credentialsRegistration.connect(authCallback); debug(() -> "authenticated"); this.tuneState.await(Duration.ofSeconds(10)); this.maxFrameSize = this.tuneState.getMaxFrameSize(); @@ -1454,6 +1484,9 @@ void closingSequence(ShutdownContext.ShutdownReason reason) { if (reason != null) { this.shutdownListenerCallback.accept(reason); } + if (this.credentialsRegistration != null) { + this.credentialsRegistration.close(); + } this.nettyClosing.run(); this.failOutstandingRequests(); if (this.closeDispatchingExecutorService != null) { @@ -2378,6 +2411,10 @@ String label() { } } + static ClientParameters cp() { + return new ClientParameters(); + } + public static class ClientParameters { private final Map clientProperties = new ConcurrentHashMap<>(); @@ -2410,6 +2447,7 @@ public static class ClientParameters { private SaslConfiguration saslConfiguration = DefaultSaslConfiguration.PLAIN; private CredentialsProvider credentialsProvider = new DefaultUsernamePasswordCredentialsProvider(DEFAULT_USERNAME, "guest"); + private CredentialsManager credentialsManager; private ChunkChecksum chunkChecksum = JdkChunkChecksum.CRC32_SINGLETON; private MetricsCollector metricsCollector = NoOpMetricsCollector.SINGLETON; private SslContext sslContext; @@ -2492,6 +2530,11 @@ public ClientParameters credentialsProvider(CredentialsProvider credentialsProvi return this; } + public ClientParameters credentialsManager(CredentialsManager credentialsManager) { + this.credentialsManager = credentialsManager; + return this; + } + public ClientParameters username(String username) { if (this.credentialsProvider instanceof UsernamePasswordCredentialsProvider) { this.credentialsProvider = diff --git a/src/main/java/com/rabbitmq/stream/impl/ConsumersCoordinator.java b/src/main/java/com/rabbitmq/stream/impl/ConsumersCoordinator.java index 05a9ae00c1..fcff1116f3 100644 --- a/src/main/java/com/rabbitmq/stream/impl/ConsumersCoordinator.java +++ b/src/main/java/com/rabbitmq/stream/impl/ConsumersCoordinator.java @@ -506,6 +506,18 @@ SubscriptionState state() { return this.state.get(); } + private void markConsuming() { + if (this.consumer != null) { + this.consumer.consuming(); + } + } + + private void markNotConsuming() { + if (this.consumer != null) { + this.consumer.notConsuming(); + } + } + String label() { return String.format( "[id=%d, stream=%s, name=%s, consumer=%d]", @@ -700,6 +712,7 @@ private ClientSubscriptionsManager( "Subscription connection has {} consumer(s) over {} stream(s) to recover", this.subscriptionTrackers.stream().filter(Objects::nonNull).count(), this.streamToStreamSubscriptions.size()); + iterate(this.subscriptionTrackers, SubscriptionTracker::markNotConsuming); environment .scheduledExecutorService() .execute( @@ -774,6 +787,7 @@ private ClientSubscriptionsManager( } if (affectedSubscriptions != null && !affectedSubscriptions.isEmpty()) { + iterate(affectedSubscriptions, SubscriptionTracker::markNotConsuming); environment .scheduledExecutorService() .execute( @@ -1132,6 +1146,7 @@ void add( throw e; } subscriptionTracker.state(SubscriptionState.ACTIVE); + subscriptionTracker.markConsuming(); LOGGER.debug("Subscribed to '{}'", subscriptionTracker.stream); } finally { this.subscriptionManagerLock.unlock(); @@ -1397,4 +1412,13 @@ static Broker pickBroker( Function, Broker> picker, Collection candidates) { return picker.apply(keepReplicasIfPossible(candidates)); } + + private static void iterate( + Collection l, java.util.function.Consumer c) { + for (SubscriptionTracker tracker : l) { + if (tracker != null) { + c.accept(tracker); + } + } + } } diff --git a/src/main/java/com/rabbitmq/stream/impl/CredentialsManagerFactory.java b/src/main/java/com/rabbitmq/stream/impl/CredentialsManagerFactory.java new file mode 100644 index 0000000000..4ffce36df7 --- /dev/null +++ b/src/main/java/com/rabbitmq/stream/impl/CredentialsManagerFactory.java @@ -0,0 +1,73 @@ +// Copyright (c) 2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.impl; + +import com.rabbitmq.stream.impl.StreamEnvironmentBuilder.DefaultOAuth2Configuration; +import com.rabbitmq.stream.oauth2.CredentialsManager; +import com.rabbitmq.stream.oauth2.GsonTokenParser; +import com.rabbitmq.stream.oauth2.HttpTokenRequester; +import com.rabbitmq.stream.oauth2.TokenCredentialsManager; +import java.net.http.HttpClient; +import java.util.concurrent.ScheduledExecutorService; +import java.util.function.Consumer; +import javax.net.ssl.SSLContext; + +final class CredentialsManagerFactory { + + private static final CredentialsManager.Registration CALLBACK_DELEGATING_REGISTRATION = + new CredentialsManager.Registration() { + @Override + public void connect(CredentialsManager.AuthenticationCallback callback) { + callback.authenticate(null, null); + } + + @Override + public void close() {} + }; + + private static final CredentialsManager CREDENTIALS_MANAGER = + (name, updateCallback) -> CALLBACK_DELEGATING_REGISTRATION; + + static CredentialsManager get( + DefaultOAuth2Configuration oauth2, ScheduledExecutorService scheduledExecutorService) { + if (oauth2 != null && oauth2.enabled()) { + Consumer clientBuilderConsumer; + if (oauth2.tlsEnabled()) { + SSLContext sslContext = oauth2.sslContext(); + clientBuilderConsumer = b -> b.sslContext(sslContext); + } else { + clientBuilderConsumer = ignored -> {}; + } + HttpTokenRequester tokenRequester = + new HttpTokenRequester( + oauth2.tokenEndpointUri(), + oauth2.clientId(), + oauth2.clientSecret(), + oauth2.grantType(), + oauth2.parameters(), + clientBuilderConsumer, + null, + new GsonTokenParser()); + return new TokenCredentialsManager( + tokenRequester, scheduledExecutorService, oauth2.refreshDelayStrategy()); + } else { + return CREDENTIALS_MANAGER; + } + } + + static CredentialsManager.Registration get() { + return CALLBACK_DELEGATING_REGISTRATION; + } +} diff --git a/src/main/java/com/rabbitmq/stream/impl/StreamConsumer.java b/src/main/java/com/rabbitmq/stream/impl/StreamConsumer.java index 11c57e4517..1c868002e6 100644 --- a/src/main/java/com/rabbitmq/stream/impl/StreamConsumer.java +++ b/src/main/java/com/rabbitmq/stream/impl/StreamConsumer.java @@ -64,6 +64,7 @@ class StreamConsumer implements Consumer { private final boolean sac; private final OffsetSpecification initialOffsetSpecification; private final Lock lock = new ReentrantLock(); + private volatile boolean consuming; @SuppressFBWarnings("CT_CONSTRUCTOR_THROW") StreamConsumer( @@ -249,6 +250,7 @@ class StreamConsumer implements Consumer { this.closed.set(true); throw e; } + this.consuming = true; } static OffsetSpecification getStoredOffset( @@ -605,4 +607,16 @@ String subscriptionConnectionName() { return client.clientConnectionName(); } } + + void notConsuming() { + this.consuming = false; + } + + void consuming() { + this.consuming = true; + } + + boolean isConsuming() { + return this.consuming; + } } diff --git a/src/main/java/com/rabbitmq/stream/impl/StreamEnvironment.java b/src/main/java/com/rabbitmq/stream/impl/StreamEnvironment.java index ad41eccbaf..292dd8ff02 100644 --- a/src/main/java/com/rabbitmq/stream/impl/StreamEnvironment.java +++ b/src/main/java/com/rabbitmq/stream/impl/StreamEnvironment.java @@ -31,8 +31,10 @@ import com.rabbitmq.stream.impl.Client.StreamStatsResponse; import com.rabbitmq.stream.impl.OffsetTrackingCoordinator.Registration; import com.rabbitmq.stream.impl.StreamConsumerBuilder.TrackingConfiguration; +import com.rabbitmq.stream.impl.StreamEnvironmentBuilder.DefaultOAuth2Configuration; import com.rabbitmq.stream.impl.StreamEnvironmentBuilder.DefaultTlsConfiguration; import com.rabbitmq.stream.impl.Utils.ClientConnectionType; +import com.rabbitmq.stream.oauth2.CredentialsManager; import com.rabbitmq.stream.sasl.CredentialsProvider; import com.rabbitmq.stream.sasl.UsernamePasswordCredentialsProvider; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; @@ -103,6 +105,7 @@ class StreamEnvironment implements Environment { int maxTrackingConsumersByConnection, int maxConsumersByConnection, DefaultTlsConfiguration tlsConfiguration, + DefaultOAuth2Configuration oauth, ByteBufAllocator byteBufAllocator, boolean lazyInit, Function connectionNamingStrategy, @@ -158,6 +161,7 @@ class StreamEnvironment implements Environment { } AddressResolver addressResolverToUse = addressResolver; + // trying to detect development environment if (this.addresses.size() == 1 && "localhost".equals(this.addresses.get(0).host()) && addressResolver == DEFAULT_ADDRESS_RESOLVER) { @@ -211,18 +215,6 @@ class StreamEnvironment implements Environment { this.addresses.size(), 1, "rabbitmq-stream-locator-connection-"); shutdownService.wrap(this.executorServiceFactory::close); - if (clientParametersPrototype.eventLoopGroup == null) { - this.eventLoopGroup = Utils.eventLoopGroup(); - shutdownService.wrap(() -> closeEventLoopGroup(this.eventLoopGroup)); - this.clientParametersPrototype = - clientParametersPrototype.duplicate().eventLoopGroup(this.eventLoopGroup); - } else { - this.eventLoopGroup = null; - this.clientParametersPrototype = - clientParametersPrototype - .duplicate() - .eventLoopGroup(clientParametersPrototype.eventLoopGroup); - } ScheduledExecutorService executorService; if (scheduledExecutorService == null) { int threads = AVAILABLE_PROCESSORS; @@ -237,6 +229,25 @@ class StreamEnvironment implements Environment { } this.scheduledExecutorService = executorService; + CredentialsManager credentialsManager = + CredentialsManagerFactory.get(oauth, this.scheduledExecutorService); + + clientParametersPrototype = + clientParametersPrototype.duplicate().credentialsManager(credentialsManager); + + if (clientParametersPrototype.eventLoopGroup == null) { + this.eventLoopGroup = Utils.eventLoopGroup(); + shutdownService.wrap(() -> closeEventLoopGroup(this.eventLoopGroup)); + this.clientParametersPrototype = + clientParametersPrototype.duplicate().eventLoopGroup(this.eventLoopGroup); + } else { + this.eventLoopGroup = null; + this.clientParametersPrototype = + clientParametersPrototype + .duplicate() + .eventLoopGroup(clientParametersPrototype.eventLoopGroup); + } + this.producersCoordinator = new ProducersCoordinator( this, diff --git a/src/main/java/com/rabbitmq/stream/impl/StreamEnvironmentBuilder.java b/src/main/java/com/rabbitmq/stream/impl/StreamEnvironmentBuilder.java index 1be66aa36b..026df49a49 100644 --- a/src/main/java/com/rabbitmq/stream/impl/StreamEnvironmentBuilder.java +++ b/src/main/java/com/rabbitmq/stream/impl/StreamEnvironmentBuilder.java @@ -21,7 +21,9 @@ import com.rabbitmq.stream.compression.CompressionCodecFactory; import com.rabbitmq.stream.impl.Utils.ClientConnectionType; import com.rabbitmq.stream.metrics.MetricsCollector; +import com.rabbitmq.stream.oauth2.TokenCredentialsManager; import com.rabbitmq.stream.sasl.CredentialsProvider; +import com.rabbitmq.stream.sasl.DefaultSaslConfiguration; import com.rabbitmq.stream.sasl.SaslConfiguration; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.ByteBufAllocator; @@ -32,13 +34,16 @@ import java.net.URI; import java.net.URISyntaxException; import java.time.Duration; +import java.time.Instant; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ScheduledExecutorService; import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; +import javax.net.ssl.SSLContext; import javax.net.ssl.SSLException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -51,6 +56,7 @@ public class StreamEnvironmentBuilder implements EnvironmentBuilder { private final Client.ClientParameters clientParameters = new Client.ClientParameters(); private final DefaultTlsConfiguration tls = new DefaultTlsConfiguration(this); private final DefaultNettyConfiguration netty = new DefaultNettyConfiguration(this); + private final DefaultOAuth2Configuration oauth2 = new DefaultOAuth2Configuration(this); private ScheduledExecutorService scheduledExecutorService; private List uris = Collections.emptyList(); private BackOffDelayPolicy recoveryBackOffDelayPolicy = @@ -295,6 +301,11 @@ public NettyConfiguration netty() { return this.netty; } + @Override + public OAuth2Configuration oauth2() { + return this.oauth2; + } + StreamEnvironmentBuilder clientFactory(Function clientFactory) { this.clientFactory = clientFactory; return this; @@ -348,6 +359,7 @@ public Environment build() { maxTrackingConsumersByConnection, maxConsumersByConnection, tls, + oauth2, netty.byteBufAllocator, lazyInit, connectionNamingStrategy, @@ -459,4 +471,110 @@ public EnvironmentBuilder environmentBuilder() { return this.environmentBuilder; } } + + static class DefaultOAuth2Configuration implements OAuth2Configuration { + + private final EnvironmentBuilder builder; + private final Map parameters = new HashMap<>(); + private String tokenEndpointUri; + private String clientId; + private String clientSecret; + private String grantType = "client_credentials"; + private Function refreshDelayStrategy = + TokenCredentialsManager.DEFAULT_REFRESH_DELAY_STRATEGY; + private SSLContext sslContext; + + DefaultOAuth2Configuration(StreamEnvironmentBuilder builder) { + this.builder = builder; + } + + @Override + public OAuth2Configuration tokenEndpointUri(String uri) { + this.builder.saslConfiguration(DefaultSaslConfiguration.PLAIN); + this.builder.credentialsProvider(null); + this.tokenEndpointUri = uri; + return this; + } + + @Override + public OAuth2Configuration clientId(String clientId) { + this.clientId = clientId; + return this; + } + + @Override + public OAuth2Configuration clientSecret(String clientSecret) { + this.clientSecret = clientSecret; + return this; + } + + @Override + public OAuth2Configuration grantType(String grantType) { + this.grantType = grantType; + return this; + } + + @Override + public OAuth2Configuration parameter(String name, String value) { + if (value == null) { + this.parameters.remove(name); + } else { + this.parameters.put(name, value); + } + return this; + } + + @Override + public OAuth2Configuration sslContext(SSLContext sslContext) { + this.sslContext = sslContext; + return this; + } + + @Override + public EnvironmentBuilder environmentBuilder() { + return this.builder; + } + + DefaultOAuth2Configuration refreshDelayStrategy( + Function refreshDelayStrategy) { + this.refreshDelayStrategy = refreshDelayStrategy; + return this; + } + + Function refreshDelayStrategy() { + return this.refreshDelayStrategy; + } + + String tokenEndpointUri() { + return this.tokenEndpointUri; + } + + String clientId() { + return this.clientId; + } + + String clientSecret() { + return this.clientSecret; + } + + String grantType() { + return this.grantType; + } + + Map parameters() { + return Map.copyOf(this.parameters); + } + + SSLContext sslContext() { + return this.sslContext; + } + + boolean enabled() { + return this.tokenEndpointUri != null; + } + + boolean tlsEnabled() { + return this.sslContext != null; + } + } } diff --git a/src/main/java/com/rabbitmq/stream/oauth2/CredentialsManager.java b/src/main/java/com/rabbitmq/stream/oauth2/CredentialsManager.java new file mode 100644 index 0000000000..bdc340f037 --- /dev/null +++ b/src/main/java/com/rabbitmq/stream/oauth2/CredentialsManager.java @@ -0,0 +1,87 @@ +// Copyright (c) 2024-2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.oauth2; + +/** + * Contract to authenticate and possibly re-authenticate application components. + * + *

A typical "application component" is a connection. + */ +public interface CredentialsManager { + + /** No-op credentials manager. */ + CredentialsManager NO_OP = new NoOpCredentialsManager(); + + /** + * Register a component for authentication. + * + * @param name component name (must be unique) + * @param updateCallback callback to update the component authentication + * @return the registration (must be closed when no longer necessary) + */ + Registration register(String name, AuthenticationCallback updateCallback); + + /** A component registration. */ + interface Registration extends AutoCloseable { + + /** + * Connection request from the component. + * + *

The component calls this method when it needs to authenticate. The underlying credentials + * manager implementation must take care of providing the component with the appropriate + * credentials in the callback. + * + * @param callback client code to authenticate the component + */ + void connect(AuthenticationCallback callback); + + /** Close the registration. */ + void close(); + } + + /** + * Component authentication callback. + * + *

The component provides the logic and the manager implementation calls it with the + * appropriate credentials. + */ + interface AuthenticationCallback { + + /** + * Authentication logic. + * + * @param username username + * @param password password + */ + void authenticate(String username, String password); + } + + class NoOpCredentialsManager implements CredentialsManager { + + @Override + public Registration register(String name, AuthenticationCallback updateCallback) { + return new NoOpRegistration(); + } + } + + class NoOpRegistration implements Registration { + + @Override + public void connect(AuthenticationCallback callback) {} + + @Override + public void close() {} + } +} diff --git a/src/main/java/com/rabbitmq/stream/oauth2/GsonTokenParser.java b/src/main/java/com/rabbitmq/stream/oauth2/GsonTokenParser.java new file mode 100644 index 0000000000..47d060131d --- /dev/null +++ b/src/main/java/com/rabbitmq/stream/oauth2/GsonTokenParser.java @@ -0,0 +1,65 @@ +// Copyright (c) 2024-2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.oauth2; + +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; +import java.time.Duration; +import java.time.Instant; +import java.util.Map; + +/** + * Token parser for JSON OAuth 2 Access + * tokens. + * + *

Uses GSON for the JSON parsing. + */ +public class GsonTokenParser implements TokenParser { + + private static final Gson GSON = new Gson(); + private static final TypeToken> MAP_TYPE = new TypeToken<>() {}; + + @Override + public Token parse(String tokenAsString) { + Map tokenAsMap = GSON.fromJson(tokenAsString, MAP_TYPE); + String accessToken = (String) tokenAsMap.get("access_token"); + // in seconds, see https://www.rfc-editor.org/rfc/rfc6749#section-5.1 + Duration expiresIn = Duration.ofSeconds(((Number) tokenAsMap.get("expires_in")).longValue()); + Instant expirationTime = + Instant.ofEpochMilli(System.currentTimeMillis() + expiresIn.toMillis()); + return new DefaultTokenInfo(accessToken, expirationTime); + } + + private static final class DefaultTokenInfo implements Token { + + private final String value; + private final Instant expirationTime; + + private DefaultTokenInfo(String value, Instant expirationTime) { + this.value = value; + this.expirationTime = expirationTime; + } + + @Override + public String value() { + return this.value; + } + + @Override + public Instant expirationTime() { + return this.expirationTime; + } + } +} diff --git a/src/main/java/com/rabbitmq/stream/oauth2/HttpTokenRequester.java b/src/main/java/com/rabbitmq/stream/oauth2/HttpTokenRequester.java new file mode 100644 index 0000000000..defe24ed44 --- /dev/null +++ b/src/main/java/com/rabbitmq/stream/oauth2/HttpTokenRequester.java @@ -0,0 +1,159 @@ +// Copyright (c) 2024-2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.oauth2; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URLEncoder; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.Base64; +import java.util.Map; +import java.util.function.Consumer; + +/** + * Token requester using HTTP(S) to request an OAuth2 Access token. + * + *

Uses {@link HttpClient} for the HTTP operations. + */ +public final class HttpTokenRequester implements TokenRequester { + + private static final Duration REQUEST_TIMEOUT = Duration.ofSeconds(60); + private static final Duration CONNECT_TIMEOUT = Duration.ofSeconds(30); + + private final URI tokenEndpointUri; + private final String clientId; + private final String clientSecret; + private final String grantType; + + private final Map parameters; + + private final HttpClient client; + private final Consumer requestBuilderConsumer; + + private final TokenParser parser; + + public HttpTokenRequester( + String tokenEndpointUri, + String clientId, + String clientSecret, + String grantType, + Map parameters, + Consumer clientBuilderConsumer, + Consumer requestBuilderConsumer, + TokenParser parser) { + try { + this.tokenEndpointUri = new URI(tokenEndpointUri); + } catch (URISyntaxException e) { + throw new IllegalArgumentException("Error in URI: " + tokenEndpointUri); + } + this.clientId = clientId; + this.clientSecret = clientSecret; + this.grantType = grantType; + this.parameters = Map.copyOf(parameters); + this.parser = parser; + if (requestBuilderConsumer == null) { + this.requestBuilderConsumer = + requestBuilder -> + requestBuilder + .timeout(REQUEST_TIMEOUT) + .setHeader("authorization", authorization(this.clientId, this.clientSecret)); + } else { + this.requestBuilderConsumer = requestBuilderConsumer; + } + + HttpClient.Builder builder = + HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_1_1) + .followRedirects(HttpClient.Redirect.NORMAL) + .connectTimeout(CONNECT_TIMEOUT); + if (clientBuilderConsumer != null) { + clientBuilderConsumer.accept(builder); + } + this.client = builder.build(); + // TODO handle HTTPS configuration + } + + @Override + public Token request() { + StringBuilder urlParameters = new StringBuilder(); + encode(urlParameters, "grant_type", grantType); + for (Map.Entry parameter : parameters.entrySet()) { + encode(urlParameters, parameter.getKey(), parameter.getValue()); + } + byte[] postData = urlParameters.toString().getBytes(UTF_8); + + HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(this.tokenEndpointUri); + requestBuilder.header("content-type", "application/x-www-form-urlencoded"); + requestBuilder.header("charset", UTF_8.name()); + requestBuilder.header("accept", "application/json"); + + requestBuilderConsumer.accept(requestBuilder); + HttpRequest request = + requestBuilder.POST(HttpRequest.BodyPublishers.ofByteArray(postData)).build(); + + try { + HttpResponse response = + this.client.send(request, HttpResponse.BodyHandlers.ofString(UTF_8)); + checkStatusCode(response.statusCode()); + checkContentType(response.headers().firstValue("content-type").orElse(null)); + return this.parser.parse(response.body()); + } catch (IOException e) { + throw new OAuth2Exception("Error while retrieving OAuth 2 token", e); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new OAuth2Exception("Error while retrieving OAuth 2 token", e); + } + } + + private static String authorization(String username, String password) { + return "Basic " + base64(username + ":" + password); + } + + private static String base64(String in) { + return Base64.getEncoder().encodeToString(in.getBytes(UTF_8)); + } + + private static void encode(StringBuilder builder, String name, String value) { + if (value != null) { + if (builder.length() > 0) { + builder.append("&"); + } + builder.append(encode(name)).append("=").append(encode(value)); + } + } + + private static String encode(String value) { + return URLEncoder.encode(value, UTF_8); + } + + private static void checkContentType(String contentType) { + if (contentType == null || !contentType.toLowerCase().contains("json")) { + throw new OAuth2Exception("HTTP request for token retrieval is not JSON: " + contentType); + } + } + + private static void checkStatusCode(int statusCode) { + if (statusCode != 200) { + throw new OAuth2Exception( + "HTTP request for token retrieval did not " + "return 200 status code: " + statusCode); + } + } +} diff --git a/src/main/java/com/rabbitmq/stream/oauth2/OAuth2Exception.java b/src/main/java/com/rabbitmq/stream/oauth2/OAuth2Exception.java new file mode 100644 index 0000000000..2abdf895b0 --- /dev/null +++ b/src/main/java/com/rabbitmq/stream/oauth2/OAuth2Exception.java @@ -0,0 +1,27 @@ +// Copyright (c) 2024-2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.oauth2; + +/** OAuth 2-related exception. */ +public class OAuth2Exception extends RuntimeException { + + public OAuth2Exception(String message) { + super(message); + } + + public OAuth2Exception(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/src/main/java/com/rabbitmq/stream/oauth2/Token.java b/src/main/java/com/rabbitmq/stream/oauth2/Token.java new file mode 100644 index 0000000000..7f0f8568a4 --- /dev/null +++ b/src/main/java/com/rabbitmq/stream/oauth2/Token.java @@ -0,0 +1,35 @@ +// Copyright (c) 2024-2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.oauth2; + +import java.time.Instant; + +/** A token. */ +public interface Token { + + /** + * The value of the token. + * + * @return the token value + */ + String value(); + + /** + * The expiration time of the token. + * + * @return the expiration time + */ + Instant expirationTime(); +} diff --git a/src/main/java/com/rabbitmq/stream/oauth2/TokenCredentialsManager.java b/src/main/java/com/rabbitmq/stream/oauth2/TokenCredentialsManager.java new file mode 100644 index 0000000000..cb57badc31 --- /dev/null +++ b/src/main/java/com/rabbitmq/stream/oauth2/TokenCredentialsManager.java @@ -0,0 +1,351 @@ +// Copyright (c) 2024-2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.oauth2; + +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.time.Duration; +import java.time.Instant; +import java.time.format.DateTimeFormatter; +import java.util.Collection; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Credentials manager implementation that requests and refreshes tokens. + * + *

It also keeps track of registration and update them with refreshed tokens when appropriate. + */ +public final class TokenCredentialsManager implements CredentialsManager { + + public static final Function DEFAULT_REFRESH_DELAY_STRATEGY = + ratioRefreshDelayStrategy(0.8f); + private static final Logger LOGGER = LoggerFactory.getLogger(TokenCredentialsManager.class); + + private final TokenRequester requester; + private final ScheduledExecutorService scheduledExecutorService; + private volatile Token token; + private final Lock lock = new ReentrantLock(); + private final Map registrations = new ConcurrentHashMap<>(); + private final AtomicLong registrationSequence = new AtomicLong(0); + private final AtomicBoolean schedulingRefresh = new AtomicBoolean(false); + private final Function refreshDelayStrategy; + private volatile ScheduledFuture refreshTask; + + public TokenCredentialsManager( + TokenRequester requester, + ScheduledExecutorService scheduledExecutorService, + Function refreshDelayStrategy) { + this.requester = requester; + this.scheduledExecutorService = scheduledExecutorService; + this.refreshDelayStrategy = refreshDelayStrategy; + } + + private void lock() { + this.lock.lock(); + } + + private void unlock() { + this.lock.unlock(); + } + + private boolean expiresSoon(Token ignores) { + return false; + } + + private Token getToken() { + if (debug()) { + LOGGER.debug( + "Requesting new token ({})...", registrationSummary(this.registrations.values())); + } + long start = 0L; + if (debug()) { + start = System.nanoTime(); + } + Token token = requester.request(); + if (debug()) { + LOGGER.debug( + "Got new token in {} ms, token expires on {} ({})", + Duration.ofNanos(System.nanoTime() - start), + format(token.expirationTime()), + registrationSummary(this.registrations.values())); + } + return token; + } + + @Override + public Registration register(String name, AuthenticationCallback updateCallback) { + Long id = this.registrationSequence.getAndIncrement(); + name = name == null ? id.toString() : name; + RegistrationImpl registration = new RegistrationImpl(id, name, updateCallback); + this.registrations.put(id, registration); + return registration; + } + + private void updateRegistrations(Token t) { + this.scheduledExecutorService.execute( + () -> { + LOGGER.debug("Updating {} registration(s)", this.registrations.size()); + int refreshedCount = 0; + for (RegistrationImpl registration : this.registrations.values()) { + if (t.equals(this.token)) { + if (!registration.isClosed() && !registration.hasSameToken(t)) { + // the registration does not have the new token yet + try { + registration.updateCallback().authenticate("", this.token.value()); + } catch (Exception e) { + LOGGER.warn( + "Error while updating token for registration '{}': {}", + registration.name(), + e.getMessage()); + } + registration.registrationToken = this.token; + refreshedCount++; + } else { + if (debug()) { + LOGGER.debug( + "Not updating registration {} (closed or already has the new token)", + registration.name()); + } + } + } else { + if (debug()) { + LOGGER.debug( + "Not updating registration {} (the token has changed)", registration.name()); + } + } + } + LOGGER.debug("Updated {} registration(s)", refreshedCount); + }); + } + + private void token(Token t) { + lock(); + try { + if (!t.equals(this.token)) { + this.token = t; + scheduleTokenRefresh(t); + } + } finally { + unlock(); + } + } + + private void scheduleTokenRefresh(Token t) { + if (this.schedulingRefresh.compareAndSet(false, true)) { + if (this.refreshTask != null) { + if (debug()) { + LOGGER.debug("Cancelling refresh task (scheduling a new one)"); + } + this.refreshTask.cancel(false); + } + Duration delay = this.refreshDelayStrategy.apply(t.expirationTime()); + if (!this.registrations.isEmpty()) { + if (debug()) { + LOGGER.debug( + "Scheduling token update in {} ({})", + delay, + registrationSummary(this.registrations.values())); + } + this.refreshTask = + this.scheduledExecutorService.schedule( + () -> { + if (debug()) { + LOGGER.debug("Starting token update task"); + } + Token previousToken = this.token; + this.lock(); + try { + if (this.token.equals(previousToken)) { + Token newToken = getToken(); + token(newToken); + updateRegistrations(newToken); + } else { + if (debug()) { + LOGGER.debug("Token has already been updated"); + } + } + } finally { + unlock(); + } + }, + delay.toMillis(), + TimeUnit.MILLISECONDS); + if (debug()) { + LOGGER.debug("Task scheduled"); + } + } else { + this.refreshTask = null; + } + this.schedulingRefresh.set(false); + } + } + + private static String format(Instant instant) { + return DateTimeFormatter.ISO_INSTANT.format(instant); + } + + private final class RegistrationImpl implements Registration { + + private final Long id; + private final String name; + private final AuthenticationCallback updateCallback; + private volatile Token registrationToken; + private final AtomicBoolean closed = new AtomicBoolean(false); + + private RegistrationImpl(Long id, String name, AuthenticationCallback updateCallback) { + this.id = id; + this.name = name; + this.updateCallback = updateCallback; + } + + @Override + public void connect(AuthenticationCallback callback) { + if (debug()) { + LOGGER.debug("Connecting registration {}", this.name); + } + boolean shouldRefresh = false; + Token tokenToUse; + lock(); + try { + Token globalToken = token; + if (globalToken == null) { + token(getToken()); + } else if (expiresSoon(globalToken)) { + shouldRefresh = true; + token(getToken()); + } + if (!token.equals(this.registrationToken)) { + this.registrationToken = token; + } + tokenToUse = this.registrationToken; + if (refreshTask == null) { + scheduleTokenRefresh(tokenToUse); + } + } finally { + unlock(); + } + if (debug()) { + if (debug()) { + LOGGER.debug("Authenticating registration {}", this.name); + } + } + callback.authenticate("", tokenToUse.value()); + if (shouldRefresh) { + updateRegistrations(tokenToUse); + } + } + + @Override + public void close() { + if (this.closed.compareAndSet(false, true)) { + LOGGER.debug("Closing credentials registration {}", this.name); + registrations.remove(this.id); + ScheduledFuture task = refreshTask; + if (registrations.isEmpty() && task != null) { + lock(); + try { + if (refreshTask != null) { + refreshTask.cancel(false); + } + } finally { + unlock(); + } + } + } + } + + private AuthenticationCallback updateCallback() { + return this.updateCallback; + } + + private String name() { + return this.name; + } + + private boolean hasSameToken(Token t) { + return t.equals(this.registrationToken); + } + + private boolean isClosed() { + return this.closed.get(); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + RegistrationImpl that = (RegistrationImpl) o; + return Objects.equals(id, that.id); + } + + @Override + public int hashCode() { + return Objects.hashCode(id); + } + + @Override + public String toString() { + return this.name(); + } + } + + public static Function ratioRefreshDelayStrategy(float ratio) { + return new RatioRefreshDelayStrategy(ratio); + } + + private static class RatioRefreshDelayStrategy implements Function { + + private final float ratio; + + @SuppressFBWarnings("CT_CONSTRUCTOR_THROW") + private RatioRefreshDelayStrategy(float ratio) { + if (ratio < 0 || ratio > 1) { + throw new IllegalArgumentException("Ratio should be > 0 and <= 1: " + ratio); + } + this.ratio = ratio; + } + + @Override + public Duration apply(Instant expirationTime) { + Duration expiresIn = Duration.between(Instant.now(), expirationTime); + Duration delay; + if (expiresIn.isZero() || expiresIn.isNegative()) { + delay = Duration.ofSeconds(1); + } else { + delay = Duration.ofMillis((long) (expiresIn.toMillis() * ratio)); + } + return delay; + } + } + + private static String registrationSummary(Collection registrations) { + return registrations.stream().map(Registration::toString).collect(Collectors.joining(", ")); + } + + private static boolean debug() { + return LOGGER.isDebugEnabled(); + } +} diff --git a/src/main/java/com/rabbitmq/stream/oauth2/TokenParser.java b/src/main/java/com/rabbitmq/stream/oauth2/TokenParser.java new file mode 100644 index 0000000000..c8c17bc10c --- /dev/null +++ b/src/main/java/com/rabbitmq/stream/oauth2/TokenParser.java @@ -0,0 +1,27 @@ +// Copyright (c) 2024-2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.oauth2; + +/** Contract to parse a token from a string. */ +public interface TokenParser { + + /** + * Parse the token. + * + * @param tokenAsString token as a string + * @return the token + */ + Token parse(String tokenAsString); +} diff --git a/src/main/java/com/rabbitmq/stream/oauth2/TokenRequester.java b/src/main/java/com/rabbitmq/stream/oauth2/TokenRequester.java new file mode 100644 index 0000000000..63e6e4ec5d --- /dev/null +++ b/src/main/java/com/rabbitmq/stream/oauth2/TokenRequester.java @@ -0,0 +1,26 @@ +// Copyright (c) 2024-2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.oauth2; + +/** Contract to request a token (usually on HTTP). */ +public interface TokenRequester { + + /** + * Request a token. + * + * @return the token + */ + Token request(); +} diff --git a/src/test/java/com/rabbitmq/stream/docs/EnvironmentUsage.java b/src/test/java/com/rabbitmq/stream/docs/EnvironmentUsage.java index bd7a9c6bc2..ae92d1f3a5 100644 --- a/src/test/java/com/rabbitmq/stream/docs/EnvironmentUsage.java +++ b/src/test/java/com/rabbitmq/stream/docs/EnvironmentUsage.java @@ -20,11 +20,12 @@ import io.micrometer.observation.ObservationRegistry; import io.netty.channel.EventLoopGroup; import io.netty.channel.MultiThreadIoEventLoopGroup; -import io.netty.channel.epoll.EpollEventLoopGroup; import io.netty.channel.epoll.EpollIoHandler; import io.netty.channel.epoll.EpollSocketChannel; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; + +import javax.net.ssl.SSLContext; import java.io.FileInputStream; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; @@ -140,6 +141,22 @@ void deleteStream() { // end::stream-deletion[] } + void oauth2() { + SSLContext sslContext = null; + // tag::oauth2[] + Environment env = Environment.builder() + .oauth2() // <1> + .tokenEndpointUri("https://localhost:8443/uaa/oauth/token/") // <2> + .clientId("rabbitmq").clientSecret("rabbitmq") // <3> + .grantType("password") // <4> + .parameter("username", "rabbit_super") // <5> + .parameter("password", "rabbit_super") // <5> + .sslContext(sslContext) // <6> + .environmentBuilder() + .build(); + // end::oauth2[] + } + void nativeEpoll() { // tag::native-epoll[] EventLoopGroup epollEventLoopGroup = new MultiThreadIoEventLoopGroup( // <1> diff --git a/src/test/java/com/rabbitmq/stream/impl/Assertions.java b/src/test/java/com/rabbitmq/stream/impl/Assertions.java index 10b7093774..ec214b108e 100644 --- a/src/test/java/com/rabbitmq/stream/impl/Assertions.java +++ b/src/test/java/com/rabbitmq/stream/impl/Assertions.java @@ -20,7 +20,7 @@ import java.time.Duration; import org.assertj.core.api.AbstractObjectAssert; -final class Assertions { +public final class Assertions { private Assertions() {} @@ -28,7 +28,7 @@ static ResponseAssert assertThat(Client.Response response) { return new ResponseAssert(response); } - static SyncAssert assertThat(TestUtils.Sync sync) { + public static SyncAssert assertThat(TestUtils.Sync sync) { return new SyncAssert(sync); } @@ -68,13 +68,13 @@ ResponseAssert hasCodeNoOffset() { } } - static class SyncAssert extends AbstractObjectAssert { + public static class SyncAssert extends AbstractObjectAssert { private SyncAssert(TestUtils.Sync sync) { super(sync, SyncAssert.class); } - SyncAssert completes() { + public SyncAssert completes() { return this.completes(TestUtils.DEFAULT_CONDITION_TIMEOUT); } diff --git a/src/test/java/com/rabbitmq/stream/impl/AuthenticationTest.java b/src/test/java/com/rabbitmq/stream/impl/AuthenticationTest.java index 779d2a8bce..f67d0d1dc5 100644 --- a/src/test/java/com/rabbitmq/stream/impl/AuthenticationTest.java +++ b/src/test/java/com/rabbitmq/stream/impl/AuthenticationTest.java @@ -134,7 +134,7 @@ void updateSecretShouldSucceedWithNewCorrectPassword() { try { addUser(username, password); setPermissions(username, "/", "^stream.*$"); - Client client = cf.get(new Client.ClientParameters().username(username).password(password)); + Client client = cf.get(new Client.ClientParameters().username("stream").password(username)); changePassword(username, newPassword); // OK client.authenticate(credentialsProvider(username, newPassword)); diff --git a/src/test/java/com/rabbitmq/stream/impl/HttpTestUtils.java b/src/test/java/com/rabbitmq/stream/impl/HttpTestUtils.java new file mode 100644 index 0000000000..5802441087 --- /dev/null +++ b/src/test/java/com/rabbitmq/stream/impl/HttpTestUtils.java @@ -0,0 +1,136 @@ +// Copyright (c) 2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.impl; + +import static java.lang.System.currentTimeMillis; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.rabbitmq.stream.oauth2.OAuth2TestUtils; +import com.sun.net.httpserver.Headers; +import com.sun.net.httpserver.HttpHandler; +import com.sun.net.httpserver.HttpServer; +import com.sun.net.httpserver.HttpsConfigurator; +import com.sun.net.httpserver.HttpsServer; +import java.io.OutputStream; +import java.math.BigInteger; +import java.net.InetSocketAddress; +import java.security.*; +import java.security.cert.Certificate; +import java.security.cert.X509Certificate; +import java.security.spec.ECGenParameterSpec; +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Date; +import java.util.function.LongSupplier; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import org.bouncycastle.asn1.x500.X500NameBuilder; +import org.bouncycastle.asn1.x500.style.BCStyle; +import org.bouncycastle.cert.X509CertificateHolder; +import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; +import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder; +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; + +public final class HttpTestUtils { + + private static final char[] KEY_STORE_PASSWORD = "password".toCharArray(); + + private HttpTestUtils() {} + + public static HttpServer startServer(int port, String path, HttpHandler handler) { + return startServer(port, path, null, handler); + } + + public static HttpServer startServer( + int port, String path, KeyStore keyStore, HttpHandler handler) { + HttpServer server; + try { + if (keyStore != null) { + KeyManagerFactory keyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(keyStore, KEY_STORE_PASSWORD); + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(keyManagerFactory.getKeyManagers(), null, null); + server = HttpsServer.create(new InetSocketAddress(port), 0); + ((HttpsServer) server).setHttpsConfigurator(new HttpsConfigurator(sslContext)); + } else { + server = HttpServer.create(new InetSocketAddress(port), 0); + } + server.createContext(path, handler); + server.start(); + return server; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public static KeyStore generateKeyPair() { + try { + KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + keyStore.load(null, KEY_STORE_PASSWORD); + + KeyPairGenerator kpg = KeyPairGenerator.getInstance("EC"); + ECGenParameterSpec spec = new ECGenParameterSpec("secp521r1"); + kpg.initialize(spec); + + KeyPair kp = kpg.generateKeyPair(); + + JcaX509v3CertificateBuilder certificateBuilder = + new JcaX509v3CertificateBuilder( + new X500NameBuilder().addRDN(BCStyle.CN, "localhost").build(), + BigInteger.valueOf(new SecureRandom().nextInt()), + Date.from(Instant.now().minus(10, ChronoUnit.DAYS)), + Date.from(Instant.now().plus(10, ChronoUnit.DAYS)), + new X500NameBuilder().addRDN(BCStyle.CN, "localhost").build(), + kp.getPublic()); + + X509CertificateHolder certificateHolder = + certificateBuilder.build( + new JcaContentSignerBuilder("SHA512withECDSA").build(kp.getPrivate())); + + X509Certificate certificate = + new JcaX509CertificateConverter().getCertificate(certificateHolder); + + keyStore.setKeyEntry( + "default", kp.getPrivate(), KEY_STORE_PASSWORD, new Certificate[] {certificate}); + return keyStore; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + static HttpHandler oAuth2TokenHttpHandler(LongSupplier expirationTimeSupplier) { + return oAuth2TokenHttpHandler(expirationTimeSupplier, () -> {}); + } + + static HttpHandler oAuth2TokenHttpHandler( + LongSupplier expirationTimeSupplier, Runnable requestCallback) { + return exchange -> { + long expirationTime = expirationTimeSupplier.getAsLong(); + String jwtToken = JwtTestUtils.token(expirationTime); + Duration expiresIn = Duration.ofMillis(expirationTime - currentTimeMillis()); + String oauthToken = OAuth2TestUtils.sampleJsonToken(jwtToken, expiresIn); + byte[] data = oauthToken.getBytes(UTF_8); + Headers responseHeaders = exchange.getResponseHeaders(); + responseHeaders.set("content-type", "application/json"); + exchange.sendResponseHeaders(200, data.length); + OutputStream responseBody = exchange.getResponseBody(); + responseBody.write(data); + responseBody.close(); + requestCallback.run(); + }; + } +} diff --git a/src/test/java/com/rabbitmq/stream/impl/JwtTestUtils.java b/src/test/java/com/rabbitmq/stream/impl/JwtTestUtils.java new file mode 100644 index 0000000000..42f2c5362b --- /dev/null +++ b/src/test/java/com/rabbitmq/stream/impl/JwtTestUtils.java @@ -0,0 +1,97 @@ +// Copyright (c) 2024-2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.impl; + +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; +import com.rabbitmq.stream.oauth2.Token; +import java.time.Instant; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import org.apache.commons.lang3.RandomStringUtils; +import org.jose4j.jws.AlgorithmIdentifiers; +import org.jose4j.jws.JsonWebSignature; +import org.jose4j.jwt.JwtClaims; +import org.jose4j.jwt.NumericDate; +import org.jose4j.jwt.consumer.JwtConsumer; +import org.jose4j.jwt.consumer.JwtConsumerBuilder; +import org.jose4j.keys.HmacKey; + +final class JwtTestUtils { + + private static final String BASE64_KEY = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGH"; + private static final HmacKey KEY = new HmacKey(Base64.getDecoder().decode(BASE64_KEY)); + private static final String AUDIENCE = "rabbitmq"; + private static final Gson GSON = new Gson(); + private static final TypeToken> MAP_TYPE = new TypeToken<>() {}; + + private JwtTestUtils() {} + + static String token(long expirationTime) { + try { + JwtClaims claims = new JwtClaims(); + claims.setIssuer("unit_test"); + claims.setAudience(AUDIENCE); + claims.setExpirationTime(NumericDate.fromMilliseconds(expirationTime)); + claims.setStringListClaim( + "scope", List.of("rabbitmq.configure:*/*", "rabbitmq.write:*/*", "rabbitmq.read:*/*")); + claims.setStringClaim("random", RandomStringUtils.insecure().nextAscii(6)); + + JsonWebSignature signature = new JsonWebSignature(); + + signature.setKeyIdHeaderValue("token-key"); + signature.setAlgorithmHeaderValue(AlgorithmIdentifiers.HMAC_SHA256); + signature.setKey(KEY); + signature.setPayload(claims.toJson()); + return signature.getCompactSerialization(); + } catch (Exception e) { + System.out.println("ERROR " + e.getMessage()); + throw new RuntimeException(e); + } + } + + static Token parseToken(String tokenAsString) { + long expirationTime; + try { + JwtConsumer consumer = + new JwtConsumerBuilder() + .setExpectedAudience(AUDIENCE) + // we do not validate the expiration time + .setEvaluationTime(NumericDate.fromMilliseconds(0)) + .setVerificationKey(KEY) + .build(); + JwtClaims claims = consumer.processToClaims(tokenAsString); + expirationTime = claims.getExpirationTime().getValueInMillis(); + } catch (Exception e) { + throw new RuntimeException(e); + } + return new Token() { + @Override + public String value() { + return tokenAsString; + } + + @Override + public Instant expirationTime() { + return Instant.ofEpochMilli(expirationTime); + } + }; + } + + static Map parse(String json) { + return GSON.fromJson(json, MAP_TYPE); + } +} diff --git a/src/test/java/com/rabbitmq/stream/impl/OAuth2ClientTest.java b/src/test/java/com/rabbitmq/stream/impl/OAuth2ClientTest.java new file mode 100644 index 0000000000..ac69fa2522 --- /dev/null +++ b/src/test/java/com/rabbitmq/stream/impl/OAuth2ClientTest.java @@ -0,0 +1,177 @@ +// Copyright (c) 2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.impl; + +import static com.rabbitmq.stream.impl.Assertions.*; +import static com.rabbitmq.stream.impl.Client.cp; +import static com.rabbitmq.stream.impl.HttpTestUtils.generateKeyPair; +import static com.rabbitmq.stream.impl.HttpTestUtils.oAuth2TokenHttpHandler; +import static com.rabbitmq.stream.impl.TestUtils.sync; +import static java.lang.System.currentTimeMillis; +import static java.time.Duration.ofMinutes; +import static java.time.Duration.ofSeconds; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.rabbitmq.stream.AuthenticationFailureException; +import com.rabbitmq.stream.impl.TestUtils.ClientFactory; +import com.rabbitmq.stream.impl.TestUtils.DisabledIfOauth2AuthBackendNotEnabled; +import com.rabbitmq.stream.impl.TestUtils.Sync; +import com.rabbitmq.stream.oauth2.GsonTokenParser; +import com.rabbitmq.stream.oauth2.HttpTokenRequester; +import com.rabbitmq.stream.oauth2.TokenCredentialsManager; +import com.sun.net.httpserver.HttpHandler; +import com.sun.net.httpserver.HttpServer; +import java.security.KeyStore; +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.function.Function; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManagerFactory; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +@StreamTestInfrastructure +@DisabledIfOauth2AuthBackendNotEnabled +public class OAuth2ClientTest { + + ClientFactory cf; + HttpServer server; + String contextPath = "/uaa/oauth/token"; + int port; + ScheduledExecutorService scheduledExecutorService; + + @BeforeEach + void init() throws Exception { + this.scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(); + this.port = TestUtils.randomNetworkPort(); + } + + @AfterEach + void tearDown() { + this.scheduledExecutorService.shutdown(); + if (this.server != null) { + server.stop(0); + } + } + + @Test + void openingConnectionWithValidTokenShouldSucceed() { + this.server = + start(oAuth2TokenHttpHandler(() -> currentTimeMillis() + ofMinutes(60).toMillis())); + + TokenCredentialsManager tokenCredentialsManager = createTokenCredentialsManager(); + try (Client ignored = cf.get(cp().credentialsManager(tokenCredentialsManager))) {} + } + + @Test + void openingConnectionWithExpiredTokenShouldFail() { + this.server = + start(oAuth2TokenHttpHandler(() -> currentTimeMillis() - ofMinutes(60).toMillis())); + + TokenCredentialsManager tokenCredentialsManager = createTokenCredentialsManager(); + + assertThatThrownBy(() -> cf.get(cp().credentialsManager(tokenCredentialsManager))) + .isInstanceOf(AuthenticationFailureException.class); + } + + @Test + void connectionShouldBeClosedWhenRefreshedTokenExpires() { + Duration tokenDuration = Duration.ofSeconds(2); + long expiry = currentTimeMillis() + tokenDuration.toMillis(); + String token = JwtTestUtils.token(expiry); + + Sync sync = sync(); + cf.get(cp().username("").password(token).shutdownListener(shutdownContext -> sync.down())); + assertThat(sync).completes(tokenDuration.multipliedBy(4)); + } + + @Test + void tokenWithHttpShouldBeRefreshedWhenItExpires() throws Exception { + this.tokenShouldBeRefreshedWhenItExpires(null); + } + + @Test + void tokenWithHttpsShouldBeRefreshedWhenItExpires() throws Exception { + this.tokenShouldBeRefreshedWhenItExpires(generateKeyPair()); + } + + private void tokenShouldBeRefreshedWhenItExpires(KeyStore ks) throws Exception { + int tokenRefreshCount = 3; + Sync tokenRefreshedSync = sync(tokenRefreshCount); + Duration tokenLifetime = ofSeconds(3); + HttpHandler httpHandler = + oAuth2TokenHttpHandler( + () -> currentTimeMillis() + tokenLifetime.toMillis(), tokenRefreshedSync::down); + + this.server = start(httpHandler, ks); + + SSLContext sslContext = null; + if (ks != null) { + sslContext = SSLContext.getInstance("TLS"); + TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); + tmf.init(ks); + sslContext.init(null, tmf.getTrustManagers(), null); + } + + TokenCredentialsManager tokenCredentialsManager = createTokenCredentialsManager(sslContext); + + try (Client ignored = cf.get(cp().credentialsManager(tokenCredentialsManager))) { + assertThat(tokenRefreshedSync).completes(tokenLifetime.multipliedBy(tokenRefreshCount + 1)); + } + } + + private HttpServer start(HttpHandler handler) { + return start(handler, null); + } + + private HttpServer start(HttpHandler handler, KeyStore ks) { + return HttpTestUtils.startServer(port, contextPath, ks, handler); + } + + private TokenCredentialsManager createTokenCredentialsManager() { + return this.createTokenCredentialsManager(null); + } + + private TokenCredentialsManager createTokenCredentialsManager(SSLContext sslContext) { + String uri = + String.format( + "%s://localhost:%d%s", + sslContext == null ? "http" : "https", this.port, this.contextPath); + HttpTokenRequester tokenRequester = + new HttpTokenRequester( + uri, + "rabbitmq", + "rabbitmq", + "client_credentials", + Collections.emptyMap(), + c -> { + if (sslContext != null) { + c.sslContext(sslContext); + } + }, + null, + new GsonTokenParser()); + // the broker works at the second level for expiration + // we have to make sure to renew fast enough for short-lived tokens + Function refreshDelayStrategy = + TokenCredentialsManager.ratioRefreshDelayStrategy(0.4f); + return new TokenCredentialsManager( + tokenRequester, this.scheduledExecutorService, refreshDelayStrategy); + } +} diff --git a/src/test/java/com/rabbitmq/stream/impl/StreamEnvironmentOAuth2Test.java b/src/test/java/com/rabbitmq/stream/impl/StreamEnvironmentOAuth2Test.java new file mode 100644 index 0000000000..b16797169a --- /dev/null +++ b/src/test/java/com/rabbitmq/stream/impl/StreamEnvironmentOAuth2Test.java @@ -0,0 +1,220 @@ +// Copyright (c) 2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.impl; + +import static com.rabbitmq.stream.Constants.CODE_PRODUCER_NOT_AVAILABLE; +import static com.rabbitmq.stream.impl.Assertions.assertThat; +import static com.rabbitmq.stream.impl.HttpTestUtils.generateKeyPair; +import static com.rabbitmq.stream.impl.HttpTestUtils.oAuth2TokenHttpHandler; +import static com.rabbitmq.stream.impl.TestUtils.localhost; +import static com.rabbitmq.stream.impl.TestUtils.localhostTls; +import static com.rabbitmq.stream.impl.TestUtils.sync; +import static com.rabbitmq.stream.impl.TestUtils.waitAtMost; +import static java.lang.System.currentTimeMillis; +import static java.time.Duration.ofSeconds; +import static org.assertj.core.api.Assertions.*; + +import com.rabbitmq.stream.Environment; +import com.rabbitmq.stream.EnvironmentBuilder; +import com.rabbitmq.stream.Producer; +import com.rabbitmq.stream.impl.StreamEnvironmentBuilder.DefaultOAuth2Configuration; +import com.rabbitmq.stream.impl.TestUtils.DisabledIfOauth2AuthBackendNotEnabled; +import com.rabbitmq.stream.impl.TestUtils.Sync; +import com.rabbitmq.stream.oauth2.TokenCredentialsManager; +import com.sun.net.httpserver.HttpHandler; +import com.sun.net.httpserver.HttpServer; +import io.netty.channel.EventLoopGroup; +import java.security.KeyStore; +import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManagerFactory; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +@StreamTestInfrastructure +@DisabledIfOauth2AuthBackendNotEnabled +public class StreamEnvironmentOAuth2Test { + + EnvironmentBuilder environmentBuilder; + + String stream; + TestUtils.ClientFactory cf; + EventLoopGroup eventLoopGroup; + + HttpServer server; + String contextPath = "/uaa/oauth/token"; + int port; + + @BeforeEach + void init() throws Exception { + environmentBuilder = Environment.builder(); + environmentBuilder.addressResolver( + add -> add.port() == Client.DEFAULT_PORT ? localhost() : localhostTls()); + environmentBuilder.netty().eventLoopGroup(eventLoopGroup); + this.port = TestUtils.randomNetworkPort(); + } + + @AfterEach + void tearDown() { + if (this.server != null) { + server.stop(0); + } + } + + @Test + void tokenWithHttpShouldBeRefreshedWhenItExpires() throws Exception { + tokenShouldBeRefreshedWhenItExpires(null); + } + + @Test + void tokenWithHttpsShouldBeRefreshedWhenItExpires() throws Exception { + tokenShouldBeRefreshedWhenItExpires(generateKeyPair()); + } + + @Test + void environmentShouldNotWorkAfterTokenExpires() throws Exception { + Duration tokenLifetime = ofSeconds(3); + AtomicInteger serverCallCount = new AtomicInteger(0); + Sync tokenRefreshedSync = sync(2); + HttpHandler httpHandler = + oAuth2TokenHttpHandler( + () -> { + // if (serverCallCount.getAndIncrement() == 0) { + return currentTimeMillis() + tokenLifetime.toMillis(); + // } else { + // return currentTimeMillis() - 100; + // } + }, + tokenRefreshedSync::down); + this.server = start(httpHandler, null); + + try (Environment env = + environmentBuilder + .oauth2() + .tokenEndpointUri(uri()) + .clientId("rabbitmq") + .clientSecret("rabbitmq") + .grantType("client_credentials") + .environmentBuilder() + .build()) { + Producer producer = env.producerBuilder().stream(stream).build(); + Sync consumeSync = sync(); + StreamConsumer consumer = + (StreamConsumer) + env.consumerBuilder().stream(stream) + .messageHandler((ctx, msg) -> consumeSync.down()) + .build(); + producer.send(producer.messageBuilder().build(), ctx -> {}); + assertThat(consumeSync).completes(); + assertThat(tokenRefreshedSync).completes(); + org.assertj.core.api.Assertions.assertThat(consumer.isConsuming()); + + // stopping the token server, there won't be attempts to re-authenticate + this.server.stop(0); + + waitAtMost( + () -> { + try { + env.streamExists(stream); + return false; + } catch (Exception e) { + return true; + } + }); + + AtomicInteger lastResponseCode = new AtomicInteger(-1); + waitAtMost( + () -> { + CountDownLatch latch = new CountDownLatch(1); + AtomicBoolean confirmed = new AtomicBoolean(false); + producer.send( + producer.messageBuilder().build(), + ctx -> { + lastResponseCode.set(ctx.getCode()); + confirmed.set(ctx.isConfirmed()); + latch.countDown(); + }); + boolean receivedCallback = latch.await(10, TimeUnit.SECONDS); + return receivedCallback && !confirmed.get(); + }); + org.assertj.core.api.Assertions.assertThat(lastResponseCode) + .hasValue(CODE_PRODUCER_NOT_AVAILABLE); + + waitAtMost(() -> !consumer.isConsuming()); + } + } + + private void tokenShouldBeRefreshedWhenItExpires(KeyStore ks) throws Exception { + int tokenRefreshCount = 3; + Sync tokenRefreshedSync = sync(tokenRefreshCount); + Duration tokenLifetime = ofSeconds(3); + HttpHandler httpHandler = + oAuth2TokenHttpHandler( + () -> currentTimeMillis() + tokenLifetime.toMillis(), tokenRefreshedSync::down); + + this.server = start(httpHandler, ks); + + SSLContext sslContext = null; + if (ks != null) { + sslContext = SSLContext.getInstance("TLS"); + TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); + tmf.init(ks); + sslContext.init(null, tmf.getTrustManagers(), null); + } + + DefaultOAuth2Configuration oauth = (DefaultOAuth2Configuration) environmentBuilder.oauth2(); + // the broker works at the second level for expiration + // we have to make sure to renew fast enough for short-lived tokens + oauth.refreshDelayStrategy(TokenCredentialsManager.ratioRefreshDelayStrategy(0.4f)); + try (Environment env = + environmentBuilder + .oauth2() + .tokenEndpointUri(uri(sslContext)) + .clientId("rabbitmq") + .clientSecret("rabbitmq") + .grantType("client_credentials") + .sslContext(sslContext) + .environmentBuilder() + .build()) { + + Producer producer = env.producerBuilder().stream(stream).build(); + + Sync consumeSync = sync(); + env.consumerBuilder().stream(stream).messageHandler((ctx, msg) -> consumeSync.down()).build(); + assertThat(tokenRefreshedSync).completes(tokenLifetime.multipliedBy(tokenRefreshCount + 1)); + + producer.send(producer.messageBuilder().build(), ctx -> {}); + assertThat(consumeSync).completes(); + } + } + + private String uri() { + return this.uri(null); + } + + private String uri(SSLContext sslContext) { + return String.format( + "%s://localhost:%d%s", sslContext == null ? "http" : "https", this.port, this.contextPath); + } + + private HttpServer start(HttpHandler handler, KeyStore ks) { + return HttpTestUtils.startServer(port, contextPath, ks, handler); + } +} diff --git a/src/test/java/com/rabbitmq/stream/impl/StreamEnvironmentUnitTest.java b/src/test/java/com/rabbitmq/stream/impl/StreamEnvironmentUnitTest.java index 8dafb175ad..655c50c265 100644 --- a/src/test/java/com/rabbitmq/stream/impl/StreamEnvironmentUnitTest.java +++ b/src/test/java/com/rabbitmq/stream/impl/StreamEnvironmentUnitTest.java @@ -93,6 +93,7 @@ Client.ClientParameters duplicate() { ProducersCoordinator.MAX_TRACKING_CONSUMERS_PER_CLIENT, ConsumersCoordinator.MAX_SUBSCRIPTIONS_PER_CLIENT, null, + null, Utils.byteBufAllocator(), false, type -> "locator-connection", @@ -162,6 +163,7 @@ void shouldTryUrisOnInitializationFailure() throws Exception { ProducersCoordinator.MAX_TRACKING_CONSUMERS_PER_CLIENT, ConsumersCoordinator.MAX_SUBSCRIPTIONS_PER_CLIENT, null, + null, Utils.byteBufAllocator(), false, type -> "locator-connection", @@ -194,6 +196,7 @@ void shouldNotOpenConnectionWhenLazyInitIsEnabled( ProducersCoordinator.MAX_TRACKING_CONSUMERS_PER_CLIENT, ConsumersCoordinator.MAX_SUBSCRIPTIONS_PER_CLIENT, null, + null, Utils.byteBufAllocator(), lazyInit, type -> "locator-connection", diff --git a/src/test/java/com/rabbitmq/stream/impl/TestUtils.java b/src/test/java/com/rabbitmq/stream/impl/TestUtils.java index 2c286571a4..78db8f4cac 100644 --- a/src/test/java/com/rabbitmq/stream/impl/TestUtils.java +++ b/src/test/java/com/rabbitmq/stream/impl/TestUtils.java @@ -51,6 +51,7 @@ import java.lang.annotation.Target; import java.lang.reflect.Field; import java.lang.reflect.Method; +import java.net.ServerSocket; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.*; @@ -121,19 +122,31 @@ public static Duration waitAtMost( public static Duration waitAtMost( Duration timeout, CallableBooleanSupplier condition, Supplier message) throws Exception { + return waitAtMost(timeout, Duration.ofMillis(100), condition, message); + } + + public static Duration waitAtMost( + Duration timeout, Duration waitTime, CallableBooleanSupplier condition) throws Exception { + return waitAtMost(timeout, waitTime, condition, null); + } + + public static Duration waitAtMost( + Duration timeout, + Duration waitTime, + CallableBooleanSupplier condition, + Supplier message) + throws Exception { if (condition.getAsBoolean()) { return Duration.ZERO; } - int waitTime = 100; - int waitedTime = 0; - int timeoutInMs = (int) timeout.toMillis(); + Duration waitedTime = Duration.ZERO; Exception exception = null; - while (waitedTime <= timeoutInMs) { - Thread.sleep(waitTime); - waitedTime += waitTime; + while (waitedTime.compareTo(timeout) <= 0) { + Thread.sleep(waitTime.toMillis()); + waitedTime = waitedTime.plus(waitTime); try { if (condition.getAsBoolean()) { - return Duration.ofMillis(waitedTime); + return waitedTime; } exception = null; } catch (Exception e) { @@ -151,7 +164,7 @@ public static Duration waitAtMost( } else { fail(msg, exception); } - return Duration.ofMillis(waitedTime); + return waitedTime; } public static Address localhost() { @@ -518,6 +531,12 @@ static boolean atLeastVersion(String expectedVersion, String currentVersion) { @ExtendWith(DisabledIfAuthMechanismSslNotEnabledCondition.class) @interface DisabledIfAuthMechanismSslNotEnabled {} + @Target({ElementType.TYPE, ElementType.METHOD}) + @Retention(RetentionPolicy.RUNTIME) + @Documented + @ExtendWith(DisabledIfOauth2AuthBackendNotEnabledCondition.class) + @interface DisabledIfOauth2AuthBackendNotEnabled {} + @Target({ElementType.TYPE, ElementType.METHOD}) @Retention(RetentionPolicy.RUNTIME) @Documented @@ -879,6 +898,16 @@ static class DisabledIfAmqp10NotEnabledCondition extends DisabledIfPluginNotEnab } } + static class DisabledIfOauth2AuthBackendNotEnabledCondition + extends DisabledIfPluginNotEnabledCondition { + + DisabledIfOauth2AuthBackendNotEnabledCondition() { + super( + "OAuth2 authentication backend", + output -> output.contains("rabbitmq_auth_backend_oauth2")); + } + } + static class DisabledIfTlsNotEnabledCondition implements ExecutionCondition { @Override @@ -1139,15 +1168,15 @@ private static Connection connection() throws IOException, TimeoutException { return AMQP_CF.newConnection(); } - static Sync sync() { + public static Sync sync() { return sync(1); } - static Sync sync(int count) { + public static Sync sync(int count) { return new Sync(count); } - static class Sync { + public static class Sync { private final AtomicReference latch = new AtomicReference<>(); @@ -1155,7 +1184,7 @@ private Sync(int count) { this.latch.set(new CountDownLatch(count)); } - void down() { + public void down() { this.latch.get().countDown(); } @@ -1192,4 +1221,12 @@ boolean hasCompleted() { static Collection threads() { return Thread.getAllStackTraces().keySet(); } + + public static int randomNetworkPort() throws IOException { + ServerSocket socket = new ServerSocket(); + socket.bind(null); + int port = socket.getLocalPort(); + socket.close(); + return port; + } } diff --git a/src/test/java/com/rabbitmq/stream/oauth2/GsonTokenParserTest.java b/src/test/java/com/rabbitmq/stream/oauth2/GsonTokenParserTest.java new file mode 100644 index 0000000000..03ded35b6d --- /dev/null +++ b/src/test/java/com/rabbitmq/stream/oauth2/GsonTokenParserTest.java @@ -0,0 +1,42 @@ +// Copyright (c) 2024-2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.oauth2; + +import static com.rabbitmq.stream.oauth2.OAuth2TestUtils.sampleJsonToken; +import static java.time.Duration.ofSeconds; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.within; + +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.UUID; +import org.junit.jupiter.api.Test; + +public class GsonTokenParserTest { + + TokenParser parser = new GsonTokenParser(); + + @Test + void parse() { + String accessToken = UUID.randomUUID().toString(); + Duration expireIn = ofSeconds(60); + String jsonToken = sampleJsonToken(accessToken, expireIn); + Token token = parser.parse(jsonToken); + assertThat(token.value()).isEqualTo(accessToken); + assertThat(token.expirationTime()) + .isCloseTo(Instant.now().plus(expireIn), within(1, ChronoUnit.SECONDS)); + } +} diff --git a/src/test/java/com/rabbitmq/stream/oauth2/HttpTokenRequesterTest.java b/src/test/java/com/rabbitmq/stream/oauth2/HttpTokenRequesterTest.java new file mode 100644 index 0000000000..a5338b9992 --- /dev/null +++ b/src/test/java/com/rabbitmq/stream/oauth2/HttpTokenRequesterTest.java @@ -0,0 +1,170 @@ +// Copyright (c) 2024-2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.oauth2; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.assertj.core.api.Assertions.assertThat; + +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; +import com.sun.net.httpserver.Headers; +import com.sun.net.httpserver.HttpServer; +import java.io.IOException; +import java.io.OutputStream; +import java.net.http.HttpClient; +import java.security.KeyStore; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManagerFactory; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +public class HttpTokenRequesterTest { + + HttpServer server; + int port; + String contextPath = "/uaa/oauth/token"; + + @BeforeEach + void init() throws IOException { + this.port = OAuth2TestUtils.randomNetworkPort(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void requestToken(boolean tls) throws Exception { + String protocol; + KeyStore keyStore; + Consumer clientBuilderConsumer; + if (tls) { + protocol = "https"; + keyStore = OAuth2TestUtils.generateKeyPair(); + SSLContext sslContext = SSLContext.getInstance("TLS"); + TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); + tmf.init(keyStore); + sslContext.init(null, tmf.getTrustManagers(), null); + clientBuilderConsumer = b -> b.sslContext(sslContext); + } else { + protocol = "http"; + keyStore = null; + clientBuilderConsumer = b -> {}; + } + String uri = String.format("%s://localhost:%d%s", protocol, port, contextPath); + AtomicReference httpMethod = new AtomicReference<>(); + AtomicReference contentType = new AtomicReference<>(); + AtomicReference authorization = new AtomicReference<>(); + AtomicReference accept = new AtomicReference<>(); + AtomicReference> httpParameters = new AtomicReference<>(); + + String accessToken = UUID.randomUUID().toString(); + + Duration expiresIn = Duration.ofSeconds(60); + server = + OAuth2TestUtils.startServer( + port, + contextPath, + keyStore, + exchange -> { + Headers headers = exchange.getRequestHeaders(); + httpMethod.set(exchange.getRequestMethod()); + contentType.set(headers.getFirst("content-type")); + authorization.set(headers.getFirst("authorization")); + accept.set(headers.getFirst("accept")); + + String requestBody = new String(exchange.getRequestBody().readAllBytes(), UTF_8); + Map parameters = + Arrays.stream(requestBody.split("&")) + .map(p -> p.split("=")) + .collect(Collectors.toMap(p -> p[0], p -> p[1])); + httpParameters.set(parameters); + + byte[] data = OAuth2TestUtils.sampleJsonToken(accessToken, expiresIn).getBytes(UTF_8); + + Headers responseHeaders = exchange.getResponseHeaders(); + responseHeaders.set("content-type", "application/json"); + exchange.sendResponseHeaders(200, data.length); + OutputStream responseBody = exchange.getResponseBody(); + responseBody.write(data); + responseBody.close(); + }); + + TokenRequester requester = + new HttpTokenRequester( + uri, + "rabbit_client", + "rabbit_secret", + "password", + Map.of("username", "rabbit_username", "password", "rabbit_password"), + clientBuilderConsumer, + null, + StringToken::new); + + String token = requester.request().value(); + assertThat(token).contains(accessToken); + Gson gson = new Gson(); + TypeToken> mapType = new TypeToken<>() {}; + Map tokenMap = gson.fromJson(token, mapType); + assertThat(tokenMap) + .containsEntry("access_token", accessToken) + .containsEntry("expires_in", (double) expiresIn.toSeconds()); + + assertThat(httpMethod).hasValue("POST"); + assertThat(contentType).hasValue("application/x-www-form-urlencoded"); + assertThat(authorization).hasValue("Basic cmFiYml0X2NsaWVudDpyYWJiaXRfc2VjcmV0"); + assertThat(accept).hasValue("application/json"); + Map parameters = httpParameters.get(); + assertThat(parameters) + .isNotNull() + .hasSize(3) + .containsEntry("grant_type", "password") + .containsEntry("username", "rabbit_username") + .containsEntry("password", "rabbit_password"); + } + + @AfterEach + public void tearDown() { + if (server != null) { + server.stop(0); + } + } + + private static class StringToken implements Token { + + private final String value; + + private StringToken(String value) { + this.value = value; + } + + @Override + public String value() { + return this.value; + } + + @Override + public Instant expirationTime() { + return Instant.EPOCH; + } + } +} diff --git a/src/test/java/com/rabbitmq/stream/oauth2/OAuth2TestUtils.java b/src/test/java/com/rabbitmq/stream/oauth2/OAuth2TestUtils.java new file mode 100644 index 0000000000..62f74d7ec4 --- /dev/null +++ b/src/test/java/com/rabbitmq/stream/oauth2/OAuth2TestUtils.java @@ -0,0 +1,205 @@ +// Copyright (c) 2024-2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.oauth2; + +import static org.junit.jupiter.api.Assertions.fail; + +import com.sun.net.httpserver.HttpHandler; +import com.sun.net.httpserver.HttpServer; +import com.sun.net.httpserver.HttpsConfigurator; +import com.sun.net.httpserver.HttpsServer; +import java.io.IOException; +import java.math.BigInteger; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.KeyStore; +import java.security.SecureRandom; +import java.security.cert.X509Certificate; +import java.security.spec.ECGenParameterSpec; +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Date; +import java.util.function.Supplier; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import org.bouncycastle.asn1.x500.X500NameBuilder; +import org.bouncycastle.asn1.x500.style.BCStyle; +import org.bouncycastle.cert.X509CertificateHolder; +import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; +import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder; +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; + +public final class OAuth2TestUtils { + + private static final char[] KEY_STORE_PASSWORD = "password".toCharArray(); + + private OAuth2TestUtils() {} + + public static String sampleJsonToken(String accessToken, Duration expiresIn) { + String json = + "{\n" + + " \"access_token\" : \"{accessToken}\",\n" + + " \"token_type\" : \"bearer\",\n" + + " \"expires_in\" : {expiresIn},\n" + + " \"scope\" : \"clients.read emails.write scim.userids password.write idps.write notifications.write oauth.login scim.write critical_notifications.write\",\n" + + " \"jti\" : \"18c1b1dfdda04382a8bcc14d077b71dd\"\n" + + "}"; + return json.replace("{accessToken}", accessToken) + .replace("{expiresIn}", expiresIn.toSeconds() + ""); + } + + public static int randomNetworkPort() throws IOException { + ServerSocket socket = new ServerSocket(); + socket.bind(null); + int port = socket.getLocalPort(); + socket.close(); + return port; + } + + public static Duration waitAtMost( + Duration timeout, + Duration waitTime, + CallableBooleanSupplier condition, + Supplier message) + throws Exception { + if (condition.getAsBoolean()) { + return Duration.ZERO; + } + Duration waitedTime = Duration.ZERO; + Exception exception = null; + while (waitedTime.compareTo(timeout) <= 0) { + Thread.sleep(waitTime.toMillis()); + waitedTime = waitedTime.plus(waitTime); + try { + if (condition.getAsBoolean()) { + return waitedTime; + } + exception = null; + } catch (Exception e) { + exception = e; + } + } + String msg; + if (message == null) { + msg = "Waited " + timeout.getSeconds() + " second(s), condition never got true"; + } else { + msg = "Waited " + timeout.getSeconds() + " second(s), " + message.get(); + } + if (exception == null) { + fail(msg); + } else { + fail(msg, exception); + } + return waitedTime; + } + + public static Duration waitAtMost( + Duration timeout, Duration waitTime, CallableBooleanSupplier condition) throws Exception { + return waitAtMost(timeout, waitTime, condition, null); + } + + public static HttpServer startServer(int port, String path, HttpHandler handler) { + return startServer(port, path, null, handler); + } + + public static HttpServer startServer( + int port, String path, KeyStore keyStore, HttpHandler handler) { + HttpServer server; + try { + if (keyStore != null) { + KeyManagerFactory keyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(keyStore, KEY_STORE_PASSWORD); + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(keyManagerFactory.getKeyManagers(), null, null); + server = HttpsServer.create(new InetSocketAddress(port), 0); + ((HttpsServer) server).setHttpsConfigurator(new HttpsConfigurator(sslContext)); + } else { + server = HttpServer.create(new InetSocketAddress(port), 0); + } + server.createContext(path, handler); + server.start(); + return server; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public static KeyStore generateKeyPair() { + try { + KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + keyStore.load(null, KEY_STORE_PASSWORD); + + KeyPairGenerator kpg = KeyPairGenerator.getInstance("EC"); + ECGenParameterSpec spec = new ECGenParameterSpec("secp521r1"); + kpg.initialize(spec); + + KeyPair kp = kpg.generateKeyPair(); + + JcaX509v3CertificateBuilder certificateBuilder = + new JcaX509v3CertificateBuilder( + new X500NameBuilder().addRDN(BCStyle.CN, "localhost").build(), + BigInteger.valueOf(new SecureRandom().nextInt()), + Date.from(Instant.now().minus(10, ChronoUnit.DAYS)), + Date.from(Instant.now().plus(10, ChronoUnit.DAYS)), + new X500NameBuilder().addRDN(BCStyle.CN, "localhost").build(), + kp.getPublic()); + + X509CertificateHolder certificateHolder = + certificateBuilder.build( + new JcaContentSignerBuilder("SHA256withECDSA").build(kp.getPrivate())); + + X509Certificate certificate = + new JcaX509CertificateConverter().getCertificate(certificateHolder); + + keyStore.setKeyEntry( + "localhost", kp.getPrivate(), KEY_STORE_PASSWORD, new X509Certificate[] {certificate}); + + return keyStore; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public static Pair pair(A v1, B v2) { + return new Pair<>(v1, v2); + } + + public interface CallableBooleanSupplier { + boolean getAsBoolean() throws Exception; + } + + public static class Pair { + + private final A v1; + private final B v2; + + private Pair(A v1, B v2) { + this.v1 = v1; + this.v2 = v2; + } + + public A v1() { + return this.v1; + } + + public B v2() { + return this.v2; + } + } +} diff --git a/src/test/java/com/rabbitmq/stream/oauth2/TokenCredentialsManagerTest.java b/src/test/java/com/rabbitmq/stream/oauth2/TokenCredentialsManagerTest.java new file mode 100644 index 0000000000..a1924e98e8 --- /dev/null +++ b/src/test/java/com/rabbitmq/stream/oauth2/TokenCredentialsManagerTest.java @@ -0,0 +1,177 @@ +// Copyright (c) 2024-2025 Broadcom. All Rights Reserved. +// The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries. +// +// This software, the RabbitMQ Stream Java client library, is dual-licensed under the +// Mozilla Public License 2.0 ("MPL"), and the Apache License version 2 ("ASL"). +// For the MPL, please see LICENSE-MPL-RabbitMQ. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. +package com.rabbitmq.stream.oauth2; + +import static com.rabbitmq.stream.oauth2.OAuth2TestUtils.pair; +import static com.rabbitmq.stream.oauth2.OAuth2TestUtils.waitAtMost; +import static com.rabbitmq.stream.oauth2.TokenCredentialsManager.DEFAULT_REFRESH_DELAY_STRATEGY; +import static java.time.Duration.ofMillis; +import static java.time.Duration.ofSeconds; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.stream.Collectors.toList; +import static java.util.stream.IntStream.range; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.when; + +import com.rabbitmq.stream.oauth2.CredentialsManager.Registration; +import com.rabbitmq.stream.oauth2.OAuth2TestUtils.Pair; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +public class TokenCredentialsManagerTest { + + ScheduledExecutorService scheduledExecutorService; + AutoCloseable mocks; + @Mock TokenRequester requester; + + @BeforeEach + void init() { + this.scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(); + this.mocks = MockitoAnnotations.openMocks(this); + } + + @AfterEach + void tearDown() throws Exception { + this.scheduledExecutorService.shutdownNow(); + this.mocks.close(); + } + + @Test + void refreshShouldStopOnceUnregistered() throws InterruptedException { + Duration tokenExpiry = ofMillis(50); + AtomicInteger requestCount = new AtomicInteger(0); + when(this.requester.request()) + .thenAnswer( + ignored -> { + requestCount.incrementAndGet(); + return token("ok", Instant.now().plus(tokenExpiry)); + }); + TokenCredentialsManager credentials = + new TokenCredentialsManager( + this.requester, this.scheduledExecutorService, DEFAULT_REFRESH_DELAY_STRATEGY); + int expectedRefreshCount = 3; + AtomicInteger refreshCount = new AtomicInteger(); + CountDownLatch refreshLatch = new CountDownLatch(expectedRefreshCount); + Registration registration = + credentials.register( + "", + (u, p) -> { + refreshCount.incrementAndGet(); + refreshLatch.countDown(); + }); + registration.connect(connectionCallback(() -> {})); + assertThat(requestCount).hasValue(1); + assertThat(refreshLatch.await(ofSeconds(10).toMillis(), MILLISECONDS)).isTrue(); + assertThat(requestCount).hasValue(expectedRefreshCount + 1); + registration.close(); + assertThat(refreshCount).hasValue(expectedRefreshCount); + assertThat(requestCount).hasValue(expectedRefreshCount + 1); + Thread.sleep(tokenExpiry.multipliedBy(2).toMillis()); + assertThat(refreshCount).hasValue(expectedRefreshCount); + assertThat(requestCount).hasValue(expectedRefreshCount + 1); + } + + @Test + void severalRegistrationsShouldBeRefreshed() throws Exception { + Duration tokenExpiry = ofMillis(50); + Duration waitTime = tokenExpiry.dividedBy(4); + Duration timeout = tokenExpiry.multipliedBy(20); + when(this.requester.request()) + .thenAnswer(ignored -> token("ok", Instant.now().plus(tokenExpiry))); + TokenCredentialsManager credentials = + new TokenCredentialsManager( + this.requester, this.scheduledExecutorService, DEFAULT_REFRESH_DELAY_STRATEGY); + int expectedRefreshCountPerConnection = 3; + int connectionCount = 10; + AtomicInteger totalRefreshCount = new AtomicInteger(); + List> registrations = + range(0, connectionCount) + .mapToObj( + ignored -> { + CountDownLatch sync = new CountDownLatch(expectedRefreshCountPerConnection); + Registration r = + credentials.register( + "", + (username, password) -> { + totalRefreshCount.incrementAndGet(); + sync.countDown(); + }); + return pair(r, sync); + }) + .collect(toList()); + + registrations.forEach(r -> r.v1().connect(connectionCallback(() -> {}))); + for (Pair registrationPair : registrations) { + assertThat(registrationPair.v2().await(ofSeconds(10).toMillis(), MILLISECONDS)).isTrue(); + } + // all connections have been refreshed once + int refreshCountSnapshot = totalRefreshCount.get(); + assertThat(refreshCountSnapshot).isEqualTo(connectionCount * expectedRefreshCountPerConnection); + + // unregister half of the connections + int splitCount = connectionCount / 2; + registrations.subList(0, splitCount).forEach(r -> r.v1().close()); + // only the remaining connections should get refreshed again + waitAtMost( + timeout, waitTime, () -> totalRefreshCount.get() == refreshCountSnapshot + splitCount); + // waiting another round of refresh + waitAtMost( + timeout, waitTime, () -> totalRefreshCount.get() == refreshCountSnapshot + splitCount * 2); + // unregister all connections + registrations.forEach(r -> r.v1().close()); + // wait 2 expiry times + Thread.sleep(tokenExpiry.multipliedBy(2).toMillis()); + // no new refresh + assertThat(totalRefreshCount).hasValue(refreshCountSnapshot + splitCount * 2); + } + + @Test + void refreshDelayStrategy() { + Duration diff = ofMillis(100); + Function strategy = TokenCredentialsManager.ratioRefreshDelayStrategy(0.8f); + assertThat(strategy.apply(Instant.now().plusSeconds(10))).isCloseTo(ofSeconds(8), diff); + assertThat(strategy.apply(Instant.now().minusSeconds(10))).isEqualTo(ofSeconds(1)); + } + + private static Token token(String value, Instant expirationTime) { + return new Token() { + @Override + public String value() { + return value; + } + + @Override + public Instant expirationTime() { + return expirationTime; + } + }; + } + + private static CredentialsManager.AuthenticationCallback connectionCallback( + Runnable passwordCallback) { + return (username, password) -> passwordCallback.run(); + } +} diff --git a/src/test/resources/logback-test.xml b/src/test/resources/logback-test.xml index 0cf733c381..cc391f4fd9 100644 --- a/src/test/resources/logback-test.xml +++ b/src/test/resources/logback-test.xml @@ -11,6 +11,7 @@ +