diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index ceeea31b..d60e927f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -16,6 +16,9 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.DefaultMcpTransportContext; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; @@ -102,6 +105,8 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement /** Map of active client sessions, keyed by session ID */ private final Map sessions = new ConcurrentHashMap<>(); + private McpTransportContextExtractor contextExtractor; + /** Flag indicating if the transport is in the process of shutting down */ private final AtomicBoolean isClosing = new AtomicBoolean(false); @@ -144,7 +149,7 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String m @Deprecated public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { - this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null, null); } /** @@ -163,11 +168,33 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String b @Deprecated public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint, Duration keepAliveInterval) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, null); + } + + /** + * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE + * endpoint. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param baseUrl The base URL for the server transport + * @param messageEndpoint The endpoint path where clients will send their messages + * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @param keepAliveInterval The interval for keep-alive pings, or null to disable + * keep-alive functionality + * @param contextExtractor The extractor for transport context from the request. + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. + */ + @Deprecated + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval, + McpTransportContextExtractor contextExtractor) { this.objectMapper = objectMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.contextExtractor = contextExtractor; if (keepAliveInterval != null) { @@ -339,10 +366,13 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) body.append(line); } + final McpTransportContext transportContext = contextExtractor.extract(request, + new DefaultMcpTransportContext()); McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); // Process the message through the session's handle method - session.handle(message).block(); // Block for Servlet compatibility + // Block for Servlet compatibility + session.handle(message).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); response.setStatus(HttpServletResponse.SC_OK); } @@ -534,6 +564,8 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + private Duration keepAliveInterval; /** @@ -583,6 +615,19 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + /** + * Sets the context extractor for extracting transport context from the request. + * @param contextExtractor The context extractor to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public HttpServletSseServerTransportProvider.Builder contextExtractor( + McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + /** * Sets the interval for keep-alive pings. *

@@ -609,7 +654,7 @@ public HttpServletSseServerTransportProvider build() { throw new IllegalStateException("MessageEndpoint must be set"); } return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, - keepAliveInterval); + keepAliveInterval, contextExtractor); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 62985dc1..669c10b8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -198,7 +198,9 @@ public Mono sendNotification(String method, Object params) { * @return a Mono that completes when the message is processed */ public Mono handle(McpSchema.JSONRPCMessage message) { - return Mono.defer(() -> { + return Mono.deferContextual(ctx -> { + McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // TODO handle errors for communication to without initialization happening // first if (message instanceof McpSchema.JSONRPCResponse response) { @@ -214,7 +216,7 @@ public Mono handle(McpSchema.JSONRPCMessage message) { } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); - return handleIncomingRequest(request).onErrorResume(error -> { + return handleIncomingRequest(request, transportContext).onErrorResume(error -> { var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); @@ -227,7 +229,7 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) { // happening first logger.debug("Received notification: {}", notification); // TODO: in case of error, should the POST request be signalled? - return handleIncomingNotification(notification) + return handleIncomingNotification(notification, transportContext) .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); } else { @@ -240,9 +242,11 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) { /** * Handles an incoming JSON-RPC request by routing it to the appropriate handler. * @param request The incoming JSON-RPC request + * @param transportContext * @return A Mono containing the JSON-RPC response */ - private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { + private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request, + McpTransportContext transportContext) { return Mono.defer(() -> { Mono resultMono; if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { @@ -266,7 +270,11 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR error.message(), error.data()))); } - resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); + resultMono = this.exchangeSink.asMono().flatMap(exchange -> { + McpAsyncServerExchange newExchange = new McpAsyncServerExchange(exchange.sessionId(), this, + exchange.getClientCapabilities(), exchange.getClientInfo(), transportContext); + return handler.handle(newExchange, request.params()); + }); } return resultMono .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) @@ -280,16 +288,18 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR /** * Handles an incoming JSON-RPC notification by routing it to the appropriate handler. * @param notification The incoming JSON-RPC notification + * @param transportContext * @return A Mono that completes when the notification is processed */ - private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { + private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification, + McpTransportContext transportContext) { return Mono.defer(() -> { if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { this.state.lazySet(STATE_INITIALIZED); // FIXME: The session ID passed here is not the same as the one in the // legacy SSE transport. exchangeSink.tryEmitValue(new McpAsyncServerExchange(this.id, this, clientCapabilities.get(), - clientInfo.get(), McpTransportContext.EMPTY)); + clientInfo.get(), transportContext)); } var handler = notificationHandlers.get(notification.method()); @@ -297,7 +307,11 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti logger.warn("No handler registered for notification method: {}", notification); return Mono.empty(); } - return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params())); + return this.exchangeSink.asMono().flatMap(exchange -> { + McpAsyncServerExchange newExchange = new McpAsyncServerExchange(exchange.sessionId(), this, + exchange.getClientCapabilities(), exchange.getClientInfo(), transportContext); + return handler.handle(newExchange, notification.params()); + }); }); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java index e2adb340..cc580bdd 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java @@ -10,6 +10,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertWith; import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; import java.net.URI; @@ -28,6 +29,7 @@ import java.util.function.Function; import java.util.stream.Collectors; +import jakarta.servlet.http.HttpServletRequest; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -825,6 +827,61 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { mcpServer.close(); } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testToolCallSuccessWithTranportContextExtraction(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var expectedCallResponse = new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("CALL RESPONSE; ctx=value")), null); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + McpTransportContext transportContext = exchange.transportContext(); + assertTrue(transportContext != null, "transportContext should not be null"); + assertTrue(!transportContext.equals(McpTransportContext.EMPTY), "transportContext should not be empty"); + String ctxValue = (String) transportContext.get("important"); + + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("CALL RESPONSE; ctx=" + ctxValue)), null); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull().isEqualTo(expectedCallResponse); + } + + mcpServer.close(); + } + @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient" }) void testToolListChangeHandlingSuccess(String clientType) { @@ -1531,4 +1588,9 @@ private double evaluateExpression(String expression) { }; } + protected static McpTransportContextExtractor extractor = (r, tc) -> { + tc.put("important", "value"); + return tc; + }; + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java index 56e74218..4435b8b4 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java @@ -40,6 +40,7 @@ public void before() { // Create and configure the transport provider mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() .objectMapper(new ObjectMapper()) + .contextExtractor(extractor) .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .sseEndpoint(CUSTOM_SSE_ENDPOINT) .build(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java index 6ac10014..0815556b 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java @@ -38,6 +38,7 @@ public void before() { // Create and configure the transport provider mcpServerTransportProvider = HttpServletStreamableServerTransportProvider.builder() .objectMapper(new ObjectMapper()) + .contextExtractor(extractor) .mcpEndpoint(MESSAGE_ENDPOINT) .keepAliveInterval(Duration.ofSeconds(1)) .build();