diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/AsyncHttpRequestCustomizer.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/AsyncHttpRequestCustomizer.java new file mode 100644 index 000000000..dee026d96 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/AsyncHttpRequestCustomizer.java @@ -0,0 +1,52 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.net.URI; +import java.net.http.HttpRequest; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.util.annotation.Nullable; + +/** + * Customize {@link HttpRequest.Builder} before executing the request, in either SSE or + * Streamable HTTP transport. + *

+ * When used in a non-blocking context, implementations MUST be non-blocking. + * + * @author Daniel Garnier-Moiroux + */ +public interface AsyncHttpRequestCustomizer { + + Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, + @Nullable String body); + + AsyncHttpRequestCustomizer NOOP = new Noop(); + + /** + * Wrap a sync implementation in an async wrapper. + *

+ * Do NOT wrap a blocking implementation for use in a non-blocking context. For a + * blocking implementation, consider using {@link Schedulers#boundedElastic()}. + */ + static AsyncHttpRequestCustomizer fromSync(SyncHttpRequestCustomizer customizer) { + return (builder, method, uri, body) -> Mono.fromSupplier(() -> { + customizer.customize(builder, method, uri, body); + return builder; + }); + } + + class Noop implements AsyncHttpRequestCustomizer { + + @Override + public Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, + String body) { + return Mono.just(builder); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index b610ad93a..39fb0d461 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -1,5 +1,5 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2024 - 2025 the original author or authors. */ package io.modelcontextprotocol.client.transport; @@ -102,6 +102,11 @@ public class HttpClientSseClientTransport implements McpClientTransport { */ protected final Sinks.One messageEndpointSink = Sinks.one(); + /** + * Customizer to modify requests before they are executed. + */ + private final AsyncHttpRequestCustomizer httpRequestCustomizer; + /** * Creates a new transport instance with default HTTP client and object mapper. * @param baseUri the base URI of the MCP server @@ -172,18 +177,38 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null */ + @Deprecated(forRemoval = true) HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, String sseEndpoint, ObjectMapper objectMapper) { + this(httpClient, requestBuilder, baseUri, sseEndpoint, objectMapper, AsyncHttpRequestCustomizer.NOOP); + } + + /** + * Creates a new transport instance with custom HTTP client builder, object mapper, + * and headers. + * @param httpClient the HTTP client to use + * @param requestBuilder the HTTP request builder to use + * @param baseUri the base URI of the MCP server + * @param sseEndpoint the SSE endpoint path + * @param objectMapper the object mapper for JSON serialization/deserialization + * @param httpRequestCustomizer customizer for the requestBuilder before executing + * requests + * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null + */ + HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, + String sseEndpoint, ObjectMapper objectMapper, AsyncHttpRequestCustomizer httpRequestCustomizer) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.hasText(baseUri, "baseUri must not be empty"); Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); Assert.notNull(httpClient, "httpClient must not be null"); Assert.notNull(requestBuilder, "requestBuilder must not be null"); + Assert.notNull(httpRequestCustomizer, "httpRequestCustomizer must not be null"); this.baseUri = URI.create(baseUri); this.sseEndpoint = sseEndpoint; this.objectMapper = objectMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; + this.httpRequestCustomizer = httpRequestCustomizer; } /** @@ -213,6 +238,8 @@ public static class Builder { private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() .header("Content-Type", "application/json"); + private AsyncHttpRequestCustomizer httpRequestCustomizer = AsyncHttpRequestCustomizer.NOOP; + /** * Creates a new builder instance. */ @@ -310,31 +337,66 @@ public Builder objectMapper(ObjectMapper objectMapper) { return this; } + /** + * Sets the customizer for {@link HttpRequest.Builder}, to modify requests before + * executing them. + *

+ * This overrides the customizer from + * {@link #asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer)}. + *

+ * Do NOT use a blocking {@link SyncHttpRequestCustomizer} in a non-blocking + * context. Use {@link #asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer)} + * instead. + * @param syncHttpRequestCustomizer the request customizer + * @return this builder + */ + public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCustomizer) { + this.httpRequestCustomizer = AsyncHttpRequestCustomizer.fromSync(syncHttpRequestCustomizer); + return this; + } + + /** + * Sets the customizer for {@link HttpRequest.Builder}, to modify requests before + * executing them. + *

+ * This overrides the customizer from + * {@link #httpRequestCustomizer(SyncHttpRequestCustomizer)}. + *

+ * Do NOT use a blocking implementation in a non-blocking context. + * @param asyncHttpRequestCustomizer the request customizer + * @return this builder + */ + public Builder asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer asyncHttpRequestCustomizer) { + this.httpRequestCustomizer = asyncHttpRequestCustomizer; + return this; + } + /** * Builds a new {@link HttpClientSseClientTransport} instance. * @return a new transport instance */ public HttpClientSseClientTransport build() { return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint, - objectMapper); + objectMapper, httpRequestCustomizer); } } @Override public Mono connect(Function, Mono> handler) { + var uri = Utils.resolveUri(this.baseUri, this.sseEndpoint); - return Mono.create(sink -> { - - HttpRequest request = requestBuilder.copy() - .uri(Utils.resolveUri(this.baseUri, this.sseEndpoint)) + return Mono.defer(() -> { + var builder = requestBuilder.copy() + .uri(uri) .header("Accept", "text/event-stream") .header("Cache-Control", "no-cache") - .GET() - .build(); - + .GET(); + return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null)); + }).flatMap(requestBuilder -> Mono.create(sink -> { Disposable connection = Flux.create(sseSink -> this.httpClient - .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) + .sendAsync(requestBuilder.build(), + responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) .exceptionallyCompose(e -> { sseSink.error(e); return CompletableFuture.failedFuture(e); @@ -397,7 +459,7 @@ else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { .subscribe(); this.sseSubscription.set(connection); - }); + })); } /** @@ -453,13 +515,13 @@ private Mono serializeMessage(final JSONRPCMessage message) { private Mono> sendHttpPost(final String endpoint, final String body) { final URI requestUri = Utils.resolveUri(baseUri, endpoint); - final HttpRequest request = this.requestBuilder.copy() - .uri(requestUri) - .POST(HttpRequest.BodyPublishers.ofString(body)) - .build(); - - // TODO: why discard the body? - return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())); + return Mono.defer(() -> { + var builder = this.requestBuilder.copy().uri(requestUri).POST(HttpRequest.BodyPublishers.ofString(body)); + return Mono.from(this.httpRequestCustomizer.customize(builder, "POST", requestUri, body)); + }).flatMap(customizedBuilder -> { + var request = customizedBuilder.build(); + return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())); + }); } /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index d8dd97f1e..799716584 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2025 the original author or authors. */ package io.modelcontextprotocol.client.transport; @@ -109,6 +109,8 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private final boolean resumableStreams; + private final AsyncHttpRequestCustomizer httpRequestCustomizer; + private final AtomicReference activeSession = new AtomicReference<>(); private final AtomicReference, Mono>> handler = new AtomicReference<>(); @@ -117,7 +119,7 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private HttpClientStreamableHttpTransport(ObjectMapper objectMapper, HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, String endpoint, boolean resumableStreams, - boolean openConnectionOnStartup) { + boolean openConnectionOnStartup, AsyncHttpRequestCustomizer httpRequestCustomizer) { this.objectMapper = objectMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; @@ -126,6 +128,7 @@ private HttpClientStreamableHttpTransport(ObjectMapper objectMapper, HttpClient this.resumableStreams = resumableStreams; this.openConnectionOnStartup = openConnectionOnStartup; this.activeSession.set(createTransportSession()); + this.httpRequestCustomizer = httpRequestCustomizer; } public static Builder builder(String baseUri) { @@ -154,14 +157,18 @@ private DefaultMcpTransportSession createTransportSession() { } private Publisher createDelete(String sessionId) { - HttpRequest request = this.requestBuilder.copy() - .uri(Utils.resolveUri(this.baseUri, this.endpoint)) - .header("Cache-Control", "no-cache") - .header("mcp-session-id", sessionId) - .DELETE() - .build(); - - return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())).then(); + var uri = Utils.resolveUri(this.baseUri, this.endpoint); + return Mono.defer(() -> { + var builder = this.requestBuilder.copy() + .uri(uri) + .header("Cache-Control", "no-cache") + .header("mcp-session-id", sessionId) + .DELETE(); + return Mono.from(this.httpRequestCustomizer.customize(builder, "DELETE", uri, null)); + }).flatMap(requestBuilder -> { + var request = requestBuilder.build(); + return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())); + }).then(); } @Override @@ -208,100 +215,110 @@ private Mono reconnect(McpTransportStream stream) { final AtomicReference disposableRef = new AtomicReference<>(); final McpTransportSession transportSession = this.activeSession.get(); + var uri = Utils.resolveUri(this.baseUri, this.endpoint); - HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); - - if (transportSession != null && transportSession.sessionId().isPresent()) { - requestBuilder = requestBuilder.header("mcp-session-id", transportSession.sessionId().get()); - } - - if (stream != null && stream.lastId().isPresent()) { - requestBuilder = requestBuilder.header("last-event-id", stream.lastId().get()); - } - - HttpRequest request = requestBuilder.uri(Utils.resolveUri(this.baseUri, this.endpoint)) - .header("Accept", TEXT_EVENT_STREAM) - .header("Cache-Control", "no-cache") - .GET() - .build(); - - Disposable connection = Flux.create(sseSink -> this.httpClient - .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) - .whenComplete((response, throwable) -> { - if (throwable != null) { - sseSink.error(throwable); - } - else { - logger.debug("SSE connection established successfully"); - } - })) - .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) - .flatMap(responseEvent -> { - int statusCode = responseEvent.responseInfo().statusCode(); - - if (statusCode >= 200 && statusCode < 300) { + Disposable connection = Mono.defer(() -> { + HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); - if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { - try { - // We don't support batching ATM and probably won't since - // the - // next version considers removing it. - McpSchema.JSONRPCMessage message = McpSchema - .deserializeJsonRpcMessage(this.objectMapper, responseEvent.sseEvent().data()); - - Tuple2, Iterable> idWithMessages = Tuples - .of(Optional.ofNullable(responseEvent.sseEvent().id()), List.of(message)); + if (transportSession != null && transportSession.sessionId().isPresent()) { + requestBuilder = requestBuilder.header("mcp-session-id", transportSession.sessionId().get()); + } - McpTransportStream sessionStream = stream != null ? stream - : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); - logger.debug("Connected stream {}", sessionStream.streamId()); + if (stream != null && stream.lastId().isPresent()) { + requestBuilder = requestBuilder.header("last-event-id", stream.lastId().get()); + } - return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + var builder = requestBuilder.uri(uri) + .header("Accept", TEXT_EVENT_STREAM) + .header("Cache-Control", "no-cache") + .GET(); + return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null)); + }) + .flatMapMany( + requestBuilder -> Flux.create( + sseSink -> this.httpClient + .sendAsync(requestBuilder.build(), + responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, + sseSink)) + .whenComplete((response, throwable) -> { + if (throwable != null) { + sseSink.error(throwable); + } + else { + logger.debug("SSE connection established successfully"); + } + })) + .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) + .flatMap(responseEvent -> { + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + + if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + try { + // We don't support batching ATM and probably + // won't since the next version considers + // removing it. + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage( + this.objectMapper, responseEvent.sseEvent().data()); + + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(responseEvent.sseEvent().id()), + List.of(message)); + + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, + this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); + + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + + } + catch (IOException ioException) { + return Flux.error( + new McpError("Error parsing JSON-RPC message: " + + responseEvent.sseEvent().data())); + } + } + else { + logger.debug("Received SSE event with type: {}", responseEvent.sseEvent()); + return Flux.empty(); + } + } + else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed + logger + .debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else if (statusCode == NOT_FOUND) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + else if (statusCode == BAD_REQUEST) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } - } - catch (IOException ioException) { return Flux.error(new McpError( - "Error parsing JSON-RPC message: " + responseEvent.sseEvent().data())); - } - } - else { - logger.debug("Received SSE event with type: {}", responseEvent.sseEvent()); - return Flux.empty(); - } - } - else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed - logger.debug("The server does not support SSE streams, using request-response mode."); - return Flux.empty(); - } - else if (statusCode == NOT_FOUND) { - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionIdRepresentation); - return Flux.error(exception); - } - else if (statusCode == BAD_REQUEST) { - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionIdRepresentation); - return Flux.error(exception); - } - - return Flux.error( - new McpError("Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); - - }).flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) - .onErrorMap(CompletionException.class, t -> t.getCause()) - .onErrorComplete(t -> { - this.handleException(t); - return true; - }) - .doFinally(s -> { - Disposable ref = disposableRef.getAndSet(null); - if (ref != null) { - transportSession.removeConnection(ref); - } - }) + "Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); + }).flatMap( + jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) + .onErrorMap(CompletionException.class, t -> t.getCause()) + .onErrorComplete(t -> { + this.handleException(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + })) .contextWrite(ctx) .subscribe(); @@ -347,31 +364,33 @@ public String toString(McpSchema.JSONRPCMessage message) { } public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { - return Mono.create(messageSink -> { + return Mono.create(deliveredSink -> { logger.debug("Sending message {}", sentMessage); final AtomicReference disposableRef = new AtomicReference<>(); final McpTransportSession transportSession = this.activeSession.get(); - HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); - - if (transportSession != null && transportSession.sessionId().isPresent()) { - requestBuilder = requestBuilder.header("mcp-session-id", transportSession.sessionId().get()); - } - + var uri = Utils.resolveUri(this.baseUri, this.endpoint); String jsonBody = this.toString(sentMessage); - HttpRequest request = requestBuilder.uri(Utils.resolveUri(this.baseUri, this.endpoint)) - .header("Accept", APPLICATION_JSON + ", " + TEXT_EVENT_STREAM) - .header("Content-Type", APPLICATION_JSON) - .header("Cache-Control", "no-cache") - .POST(HttpRequest.BodyPublishers.ofString(jsonBody)) - .build(); + Disposable connection = Mono.defer(() -> { + HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); - Disposable connection = Flux.create(responseEventSink -> { + if (transportSession != null && transportSession.sessionId().isPresent()) { + requestBuilder = requestBuilder.header("mcp-session-id", transportSession.sessionId().get()); + } + + var builder = requestBuilder.uri(uri) + .header("Accept", APPLICATION_JSON + ", " + TEXT_EVENT_STREAM) + .header("Content-Type", APPLICATION_JSON) + .header("Cache-Control", "no-cache") + .POST(HttpRequest.BodyPublishers.ofString(jsonBody)); + return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, jsonBody)); + }).flatMapMany(requestBuilder -> Flux.create(responseEventSink -> { // Create the async request with proper body subscriber selection - Mono.fromFuture(this.httpClient.sendAsync(request, this.toSendMessageBodySubscriber(responseEventSink)) + Mono.fromFuture(this.httpClient + .sendAsync(requestBuilder.build(), this.toSendMessageBodySubscriber(responseEventSink)) .whenComplete((response, throwable) -> { if (throwable != null) { responseEventSink.error(throwable); @@ -381,13 +400,13 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { } })).onErrorMap(CompletionException.class, t -> t.getCause()).onErrorComplete().subscribe(); - }).flatMap(responseEvent -> { + })).flatMap(responseEvent -> { if (transportSession.markInitialized( responseEvent.responseInfo().headers().firstValue("mcp-session-id").orElseGet(() -> null))) { // Once we have a session, we try to open an async stream for // the server to send notifications and requests out-of-band. - reconnect(null).contextWrite(messageSink.contextView()).subscribe(); + reconnect(null).contextWrite(deliveredSink.contextView()).subscribe(); } String sessionRepresentation = sessionIdOrPlaceholder(transportSession); @@ -404,16 +423,18 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { if (contentType.isBlank()) { logger.debug("No content type returned for POST in session {}", sessionRepresentation); - // No content type means no response body, so we can just return + // No content type means no response body, so we can just + // return // an empty stream - messageSink.success(); + deliveredSink.success(); return Flux.empty(); } else if (contentType.contains(TEXT_EVENT_STREAM)) { return Flux.just(((ResponseSubscribers.SseResponseEvent) responseEvent).sseEvent()) .flatMap(sseEvent -> { try { - // We don't support batching ATM and probably won't + // We don't support batching ATM and probably + // won't // since the // next version considers removing it. McpSchema.JSONRPCMessage message = McpSchema @@ -427,7 +448,7 @@ else if (contentType.contains(TEXT_EVENT_STREAM)) { logger.debug("Connected stream {}", sessionStream.streamId()); - messageSink.success(); + deliveredSink.success(); return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); } @@ -438,7 +459,7 @@ else if (contentType.contains(TEXT_EVENT_STREAM)) { }); } else if (contentType.contains(APPLICATION_JSON)) { - messageSink.success(); + deliveredSink.success(); String data = ((ResponseSubscribers.AggregateResponseEvent) responseEvent).data(); if (sentMessage instanceof McpSchema.JSONRPCNotification && Utils.hasText(data)) { logger.warn("Notification: {} received non-compliant response: {}", sentMessage, data); @@ -483,7 +504,7 @@ else if (statusCode == BAD_REQUEST) { // handle the error first this.handleException(t); // inform the caller of sendMessage - messageSink.error(t); + deliveredSink.error(t); return true; }) .doFinally(s -> { @@ -493,7 +514,7 @@ else if (statusCode == BAD_REQUEST) { transportSession.removeConnection(ref); } }) - .contextWrite(messageSink.contextView()) + .contextWrite(deliveredSink.contextView()) .subscribe(); disposableRef.set(connection); @@ -531,6 +552,8 @@ public static class Builder { private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + private AsyncHttpRequestCustomizer httpRequestCustomizer = AsyncHttpRequestCustomizer.NOOP; + /** * Creates a new builder with the specified base URI. * @param baseUri the base URI of the MCP server @@ -633,6 +656,40 @@ public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { return this; } + /** + * Sets the customizer for {@link HttpRequest.Builder}, to modify requests before + * executing them. + *

+ * This overrides the customizer from + * {@link #asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer)}. + *

+ * Do NOT use a blocking {@link SyncHttpRequestCustomizer} in a non-blocking + * context. Use {@link #asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer)} + * instead. + * @param syncHttpRequestCustomizer the request customizer + * @return this builder + */ + public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCustomizer) { + this.httpRequestCustomizer = AsyncHttpRequestCustomizer.fromSync(syncHttpRequestCustomizer); + return this; + } + + /** + * Sets the customizer for {@link HttpRequest.Builder}, to modify requests before + * executing them. + *

+ * This overrides the customizer from + * {@link #httpRequestCustomizer(SyncHttpRequestCustomizer)}. + *

+ * Do NOT use a blocking implementation in a non-blocking context. + * @param asyncHttpRequestCustomizer the request customizer + * @return this builder + */ + public Builder asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer asyncHttpRequestCustomizer) { + this.httpRequestCustomizer = asyncHttpRequestCustomizer; + return this; + } + /** * Construct a fresh instance of {@link HttpClientStreamableHttpTransport} using * the current builder configuration. @@ -642,7 +699,7 @@ public HttpClientStreamableHttpTransport build() { ObjectMapper objectMapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); return new HttpClientStreamableHttpTransport(objectMapper, clientBuilder.build(), requestBuilder, baseUri, - endpoint, resumableStreams, openConnectionOnStartup); + endpoint, resumableStreams, openConnectionOnStartup, httpRequestCustomizer); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/SyncHttpRequestCustomizer.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/SyncHttpRequestCustomizer.java new file mode 100644 index 000000000..72b6e6c1b --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/SyncHttpRequestCustomizer.java @@ -0,0 +1,21 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.net.URI; +import java.net.http.HttpRequest; +import reactor.util.annotation.Nullable; + +/** + * Customize {@link HttpRequest.Builder} before executing the request, either in SSE or + * Streamable HTTP transport. + * + * @author Daniel Garnier-Moiroux + */ +public interface SyncHttpRequestCustomizer { + + void customize(HttpRequest.Builder builder, String method, URI endpoint, @Nullable String body); + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index 31430543a..46b9207f6 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -17,10 +17,13 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.mockito.ArgumentCaptor; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; import reactor.core.publisher.Mono; @@ -28,9 +31,17 @@ import reactor.test.StepVerifier; import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.clearInvocations; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; /** * Tests for the {@link HttpClientSseClientTransport} class. @@ -43,7 +54,7 @@ class HttpClientSseClientTransportTests { static String host = "http://localhost:3001"; @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) @@ -61,7 +72,7 @@ static class TestHttpClientSseClientTransport extends HttpClientSseClientTranspo public TestHttpClientSseClientTransport(final String baseUri) { super(HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build(), HttpRequest.newBuilder().header("Content-Type", "application/json"), baseUri, "/sse", - new ObjectMapper()); + new ObjectMapper(), AsyncHttpRequestCustomizer.NOOP); } public int getInboundMessageCount() { @@ -80,15 +91,21 @@ public void simulateMessageEvent(String jsonMessage) { } - void startContainer() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; + + } + + @AfterAll + static void stopContainer() { + container.stop(); } @BeforeEach void setUp() { - startContainer(); transport = new TestHttpClientSseClientTransport(host); transport.connect(Function.identity()).block(); } @@ -98,11 +115,6 @@ void afterEach() { if (transport != null) { assertThatCode(() -> transport.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } - cleanup(); - } - - void cleanup() { - container.stop(); } @Test @@ -375,4 +387,74 @@ void testChainedCustomizations() { customizedTransport.closeGracefully().block(); } + @Test + void testRequestCustomizer() { + var mockCustomizer = mock(SyncHttpRequestCustomizer.class); + + // Create a transport with the customizer + var customizedTransport = HttpClientSseClientTransport.builder(host) + .httpRequestCustomizer(mockCustomizer) + .build(); + + // Connect + StepVerifier.create(customizedTransport.connect(Function.identity())).verifyComplete(); + + // Verify the customizer was called + verify(mockCustomizer).customize(any(), eq("GET"), + eq(UriComponentsBuilder.fromUriString(host).path("/sse").build().toUri()), isNull()); + clearInvocations(mockCustomizer); + + // Send test message + var testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + // Subscribe to messages and verify + StepVerifier.create(customizedTransport.sendMessage(testMessage)).verifyComplete(); + + // Verify the customizer was called + var uriArgumentCaptor = ArgumentCaptor.forClass(URI.class); + verify(mockCustomizer).customize(any(), eq("POST"), uriArgumentCaptor.capture(), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"test-method\",\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}")); + assertThat(uriArgumentCaptor.getValue().toString()).startsWith(host + "/message?sessionId="); + + // Clean up + customizedTransport.closeGracefully().block(); + } + + @Test + void testAsyncRequestCustomizer() { + var mockCustomizer = mock(AsyncHttpRequestCustomizer.class); + when(mockCustomizer.customize(any(), any(), any(), any())) + .thenAnswer(invocation -> Mono.just(invocation.getArguments()[0])); + + // Create a transport with the customizer + var customizedTransport = HttpClientSseClientTransport.builder(host) + .asyncHttpRequestCustomizer(mockCustomizer) + .build(); + + // Connect + StepVerifier.create(customizedTransport.connect(Function.identity())).verifyComplete(); + + // Verify the customizer was called + verify(mockCustomizer).customize(any(), eq("GET"), + eq(UriComponentsBuilder.fromUriString(host).path("/sse").build().toUri()), isNull()); + clearInvocations(mockCustomizer); + + // Send test message + var testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + // Subscribe to messages and verify + StepVerifier.create(customizedTransport.sendMessage(testMessage)).verifyComplete(); + + // Verify the customizer was called + var uriArgumentCaptor = ArgumentCaptor.forClass(URI.class); + verify(mockCustomizer).customize(any(), eq("POST"), uriArgumentCaptor.capture(), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"test-method\",\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}")); + assertThat(uriArgumentCaptor.getValue().toString()).startsWith(host + "/message?sessionId="); + + // Clean up + customizedTransport.closeGracefully().block(); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java new file mode 100644 index 000000000..479468f63 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java @@ -0,0 +1,115 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import io.modelcontextprotocol.spec.McpSchema; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.function.Consumer; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for the {@link HttpClientStreamableHttpTransport} class. + * + * @author Daniel Garnier-Moiroux + */ +class HttpClientStreamableHttpTransportTest { + + static String host = "http://localhost:3001"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + + void withTransport(HttpClientStreamableHttpTransport transport, Consumer c) { + try { + c.accept(transport); + } + finally { + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + } + + @Test + void testRequestCustomizer() throws URISyntaxException { + var uri = new URI(host + "/mcp"); + var mockRequestCustomizer = mock(SyncHttpRequestCustomizer.class); + + var transport = HttpClientStreamableHttpTransport.builder(host) + .httpRequestCustomizer(mockRequestCustomizer) + .build(); + + withTransport(transport, (t) -> { + // Send test message + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("Spring AI MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(t.sendMessage(testMessage)).verifyComplete(); + + // Verify the customizer was called + verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("GET"), eq(uri), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"Spring AI MCP Client\",\"version\":\"0.3.1\"}}}")); + }); + } + + @Test + void testAsyncRequestCustomizer() throws URISyntaxException { + var uri = new URI(host + "/mcp"); + var mockRequestCustomizer = mock(AsyncHttpRequestCustomizer.class); + when(mockRequestCustomizer.customize(any(), any(), any(), any())) + .thenAnswer(invocation -> Mono.just(invocation.getArguments()[0])); + + var transport = HttpClientStreamableHttpTransport.builder(host) + .asyncHttpRequestCustomizer(mockRequestCustomizer) + .build(); + + withTransport(transport, (t) -> { + // Send test message + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("Spring AI MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(t.sendMessage(testMessage)).verifyComplete(); + + // Verify the customizer was called + verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("GET"), eq(uri), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"Spring AI MCP Client\",\"version\":\"0.3.1\"}}}")); + }); + } + +}