Skip to content

Feature/mcp transport context for http servlet sse server transport provider #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
import io.modelcontextprotocol.server.McpTransportContext;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
Expand Down Expand Up @@ -102,6 +105,8 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement
/** Map of active client sessions, keyed by session ID */
private final Map<String, McpServerSession> sessions = new ConcurrentHashMap<>();

private McpTransportContextExtractor<HttpServletRequest> contextExtractor;

/** Flag indicating if the transport is in the process of shutting down */
private final AtomicBoolean isClosing = new AtomicBoolean(false);

Expand Down Expand Up @@ -144,7 +149,7 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String m
@Deprecated
public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint) {
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null);
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null, null);
}

/**
Expand All @@ -163,11 +168,33 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String b
@Deprecated
public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, Duration keepAliveInterval) {
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, null);
}

/**
* Creates a new HttpServletSseServerTransportProvider instance with a custom SSE
* endpoint.
* @param objectMapper The JSON object mapper to use for message
* serialization/deserialization
* @param baseUrl The base URL for the server transport
* @param messageEndpoint The endpoint path where clients will send their messages
* @param sseEndpoint The endpoint path where clients will establish SSE connections
* @param keepAliveInterval The interval for keep-alive pings, or null to disable
* keep-alive functionality
* @param contextExtractor The extractor for transport context from the request.
* @deprecated Use the builder {@link #builder()} instead for better configuration
* options.
*/
@Deprecated
public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, Duration keepAliveInterval,
McpTransportContextExtractor<HttpServletRequest> contextExtractor) {

this.objectMapper = objectMapper;
this.baseUrl = baseUrl;
this.messageEndpoint = messageEndpoint;
this.sseEndpoint = sseEndpoint;
this.contextExtractor = contextExtractor;

if (keepAliveInterval != null) {

Expand Down Expand Up @@ -339,10 +366,13 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
body.append(line);
}

final McpTransportContext transportContext = contextExtractor.extract(request,
new DefaultMcpTransportContext());
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString());

// Process the message through the session's handle method
session.handle(message).block(); // Block for Servlet compatibility
// Block for Servlet compatibility
session.handle(message).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block();

response.setStatus(HttpServletResponse.SC_OK);
}
Expand Down Expand Up @@ -534,6 +564,8 @@ public static class Builder {

private String sseEndpoint = DEFAULT_SSE_ENDPOINT;

private McpTransportContextExtractor<HttpServletRequest> contextExtractor = (serverRequest, context) -> context;

private Duration keepAliveInterval;

/**
Expand Down Expand Up @@ -583,6 +615,19 @@ public Builder sseEndpoint(String sseEndpoint) {
return this;
}

/**
* Sets the context extractor for extracting transport context from the request.
* @param contextExtractor The context extractor to use. Must not be null.
* @return this builder instance
* @throws IllegalArgumentException if contextExtractor is null
*/
public HttpServletSseServerTransportProvider.Builder contextExtractor(
McpTransportContextExtractor<HttpServletRequest> contextExtractor) {
Assert.notNull(contextExtractor, "Context extractor must not be null");
this.contextExtractor = contextExtractor;
return this;
}

/**
* Sets the interval for keep-alive pings.
* <p>
Expand All @@ -609,7 +654,7 @@ public HttpServletSseServerTransportProvider build() {
throw new IllegalStateException("MessageEndpoint must be set");
}
return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
keepAliveInterval);
keepAliveInterval, contextExtractor);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,9 @@ public Mono<Void> sendNotification(String method, Object params) {
* @return a Mono that completes when the message is processed
*/
public Mono<Void> handle(McpSchema.JSONRPCMessage message) {
return Mono.defer(() -> {
return Mono.deferContextual(ctx -> {
McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY);

// TODO handle errors for communication to without initialization happening
// first
if (message instanceof McpSchema.JSONRPCResponse response) {
Expand All @@ -214,7 +216,7 @@ public Mono<Void> handle(McpSchema.JSONRPCMessage message) {
}
else if (message instanceof McpSchema.JSONRPCRequest request) {
logger.debug("Received request: {}", request);
return handleIncomingRequest(request).onErrorResume(error -> {
return handleIncomingRequest(request, transportContext).onErrorResume(error -> {
var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null,
new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
error.getMessage(), null));
Expand All @@ -227,7 +229,7 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) {
// happening first
logger.debug("Received notification: {}", notification);
// TODO: in case of error, should the POST request be signalled?
return handleIncomingNotification(notification)
return handleIncomingNotification(notification, transportContext)
.doOnError(error -> logger.error("Error handling notification: {}", error.getMessage()));
}
else {
Expand All @@ -240,9 +242,11 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) {
/**
* Handles an incoming JSON-RPC request by routing it to the appropriate handler.
* @param request The incoming JSON-RPC request
* @param transportContext
* @return A Mono containing the JSON-RPC response
*/
private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCRequest request) {
private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCRequest request,
McpTransportContext transportContext) {
return Mono.defer(() -> {
Mono<?> resultMono;
if (McpSchema.METHOD_INITIALIZE.equals(request.method())) {
Expand All @@ -266,7 +270,11 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
error.message(), error.data())));
}

resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params()));
resultMono = this.exchangeSink.asMono().flatMap(exchange -> {
McpAsyncServerExchange newExchange = new McpAsyncServerExchange(exchange.sessionId(), this,
exchange.getClientCapabilities(), exchange.getClientInfo(), transportContext);
return handler.handle(newExchange, request.params());
});
}
return resultMono
.map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null))
Expand All @@ -280,24 +288,30 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
/**
* Handles an incoming JSON-RPC notification by routing it to the appropriate handler.
* @param notification The incoming JSON-RPC notification
* @param transportContext
* @return A Mono that completes when the notification is processed
*/
private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification notification) {
private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification notification,
McpTransportContext transportContext) {
return Mono.defer(() -> {
if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) {
this.state.lazySet(STATE_INITIALIZED);
// FIXME: The session ID passed here is not the same as the one in the
// legacy SSE transport.
exchangeSink.tryEmitValue(new McpAsyncServerExchange(this.id, this, clientCapabilities.get(),
clientInfo.get(), McpTransportContext.EMPTY));
clientInfo.get(), transportContext));
}

var handler = notificationHandlers.get(notification.method());
if (handler == null) {
logger.warn("No handler registered for notification method: {}", notification);
return Mono.empty();
}
return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params()));
return this.exchangeSink.asMono().flatMap(exchange -> {
McpAsyncServerExchange newExchange = new McpAsyncServerExchange(exchange.sessionId(), this,
exchange.getClientCapabilities(), exchange.getClientInfo(), transportContext);
return handler.handle(newExchange, notification.params());
});
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertWith;
import static org.awaitility.Awaitility.await;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;

import java.net.URI;
Expand All @@ -28,6 +29,7 @@
import java.util.function.Function;
import java.util.stream.Collectors;

import jakarta.servlet.http.HttpServletRequest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

Expand Down Expand Up @@ -825,6 +827,61 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) {
mcpServer.close();
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
void testToolCallSuccessWithTranportContextExtraction(String clientType) {

var clientBuilder = clientBuilders.get(clientType);

var expectedCallResponse = new McpSchema.CallToolResult(
List.of(new McpSchema.TextContent("CALL RESPONSE; ctx=value")), null);
McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder()
.tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build())
.callHandler((exchange, request) -> {

McpTransportContext transportContext = exchange.transportContext();
assertTrue(transportContext != null, "transportContext should not be null");
assertTrue(!transportContext.equals(McpTransportContext.EMPTY), "transportContext should not be empty");
String ctxValue = (String) transportContext.get("important");

try {
HttpResponse<String> response = HttpClient.newHttpClient()
.send(HttpRequest.newBuilder()
.uri(URI.create(
"https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md"))
.GET()
.build(), HttpResponse.BodyHandlers.ofString());
String responseBody = response.body();
assertThat(responseBody).isNotBlank();
}
catch (Exception e) {
e.printStackTrace();
}

return new McpSchema.CallToolResult(
List.of(new McpSchema.TextContent("CALL RESPONSE; ctx=" + ctxValue)), null);
})
.build();

var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build())
.tools(tool1)
.build();

try (var mcpClient = clientBuilder.build()) {

InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

assertThat(mcpClient.listTools().tools()).contains(tool1.tool());

CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));

assertThat(response).isNotNull().isEqualTo(expectedCallResponse);
}

mcpServer.close();
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "httpclient" })
void testToolListChangeHandlingSuccess(String clientType) {
Expand Down Expand Up @@ -1531,4 +1588,9 @@ private double evaluateExpression(String expression) {
};
}

protected static McpTransportContextExtractor<HttpServletRequest> extractor = (r, tc) -> {
tc.put("important", "value");
return tc;
};

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public void before() {
// Create and configure the transport provider
mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder()
.objectMapper(new ObjectMapper())
.contextExtractor(extractor)
.messageEndpoint(CUSTOM_MESSAGE_ENDPOINT)
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public void before() {
// Create and configure the transport provider
mcpServerTransportProvider = HttpServletStreamableServerTransportProvider.builder()
.objectMapper(new ObjectMapper())
.contextExtractor(extractor)
.mcpEndpoint(MESSAGE_ENDPOINT)
.keepAliveInterval(Duration.ofSeconds(1))
.build();
Expand Down