Skip to content

Commit 5d69dc7

Browse files
committed
add McpTransportContext capability to HttpServletSseServerTransportProvider
1 parent b6d01a3 commit 5d69dc7

File tree

3 files changed

+82
-11
lines changed

3 files changed

+82
-11
lines changed

mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.modelcontextprotocol.server;
22

33
import java.util.Map;
4+
import java.util.StringJoiner;
45
import java.util.concurrent.ConcurrentHashMap;
56

67
/**
@@ -42,4 +43,13 @@ public McpTransportContext copy() {
4243
return new DefaultMcpTransportContext(new ConcurrentHashMap<>(this.storage));
4344
}
4445

46+
// TODO for debugging
47+
48+
@Override
49+
public String toString() {
50+
return new StringJoiner(", ", DefaultMcpTransportContext.class.getSimpleName() + "[", "]")
51+
.add("storage=" + storage)
52+
.toString();
53+
}
54+
4555
}

mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
import com.fasterxml.jackson.core.type.TypeReference;
1616
import com.fasterxml.jackson.databind.ObjectMapper;
17+
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
18+
import io.modelcontextprotocol.server.McpTransportContext;
19+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
1720
import io.modelcontextprotocol.spec.McpError;
1821
import io.modelcontextprotocol.spec.McpSchema;
1922
import io.modelcontextprotocol.spec.McpServerSession;
@@ -99,6 +102,8 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement
99102
/** Map of active client sessions, keyed by session ID */
100103
private final Map<String, McpServerSession> sessions = new ConcurrentHashMap<>();
101104

105+
private McpTransportContextExtractor<HttpServletRequest> contextExtractor;
106+
102107
/** Flag indicating if the transport is in the process of shutting down */
103108
private final AtomicBoolean isClosing = new AtomicBoolean(false);
104109

@@ -141,7 +146,7 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String m
141146
@Deprecated
142147
public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
143148
String sseEndpoint) {
144-
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null);
149+
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null, null);
145150
}
146151

147152
/**
@@ -160,11 +165,33 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String b
160165
@Deprecated
161166
public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
162167
String sseEndpoint, Duration keepAliveInterval) {
168+
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, null);
169+
}
170+
171+
/**
172+
* Creates a new HttpServletSseServerTransportProvider instance with a custom SSE
173+
* endpoint.
174+
* @param objectMapper The JSON object mapper to use for message
175+
* serialization/deserialization
176+
* @param baseUrl The base URL for the server transport
177+
* @param messageEndpoint The endpoint path where clients will send their messages
178+
* @param sseEndpoint The endpoint path where clients will establish SSE connections
179+
* @param keepAliveInterval The interval for keep-alive pings, or null to disable
180+
* @param contextExtractor The extractor for transport context from the request.
181+
* keep-alive functionality
182+
* @deprecated Use the builder {@link #builder()} instead for better configuration
183+
* options.
184+
*/
185+
@Deprecated
186+
public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
187+
String sseEndpoint, Duration keepAliveInterval,
188+
McpTransportContextExtractor<HttpServletRequest> contextExtractor) {
163189

164190
this.objectMapper = objectMapper;
165191
this.baseUrl = baseUrl;
166192
this.messageEndpoint = messageEndpoint;
167193
this.sseEndpoint = sseEndpoint;
194+
this.contextExtractor = contextExtractor;
168195

169196
if (keepAliveInterval != null) {
170197

@@ -336,10 +363,15 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
336363
body.append(line);
337364
}
338365

366+
final McpTransportContext transportContext = contextExtractor.extract(request,
367+
new DefaultMcpTransportContext());
339368
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString());
340369

341370
// Process the message through the session's handle method
342-
session.handle(message).block(); // Block for Servlet compatibility
371+
session.handle(message).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); // Block
372+
// for
373+
// Servlet
374+
// compatibility
343375

344376
response.setStatus(HttpServletResponse.SC_OK);
345377
}
@@ -531,6 +563,8 @@ public static class Builder {
531563

532564
private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
533565

566+
private McpTransportContextExtractor<HttpServletRequest> contextExtractor = (serverRequest, context) -> context;
567+
534568
private Duration keepAliveInterval;
535569

536570
/**
@@ -580,6 +614,19 @@ public Builder sseEndpoint(String sseEndpoint) {
580614
return this;
581615
}
582616

617+
/**
618+
* Sets the context extractor for extracting transport context from the request.
619+
* @param contextExtractor The context extractor to use. Must not be null.
620+
* @return this builder instance
621+
* @throws IllegalArgumentException if contextExtractor is null
622+
*/
623+
public HttpServletSseServerTransportProvider.Builder contextExtractor(
624+
McpTransportContextExtractor<HttpServletRequest> contextExtractor) {
625+
Assert.notNull(contextExtractor, "Context extractor must not be null");
626+
this.contextExtractor = contextExtractor;
627+
return this;
628+
}
629+
583630
/**
584631
* Sets the interval for keep-alive pings.
585632
* <p>
@@ -606,7 +653,7 @@ public HttpServletSseServerTransportProvider build() {
606653
throw new IllegalStateException("MessageEndpoint must be set");
607654
}
608655
return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
609-
keepAliveInterval);
656+
keepAliveInterval, contextExtractor);
610657
}
611658

612659
}

mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,9 @@ public Mono<Void> sendNotification(String method, Object params) {
194194
* @return a Mono that completes when the message is processed
195195
*/
196196
public Mono<Void> handle(McpSchema.JSONRPCMessage message) {
197-
return Mono.defer(() -> {
197+
return Mono.deferContextual(ctx -> {
198+
McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY);
199+
198200
// TODO handle errors for communication to without initialization happening
199201
// first
200202
if (message instanceof McpSchema.JSONRPCResponse response) {
@@ -210,7 +212,7 @@ public Mono<Void> handle(McpSchema.JSONRPCMessage message) {
210212
}
211213
else if (message instanceof McpSchema.JSONRPCRequest request) {
212214
logger.debug("Received request: {}", request);
213-
return handleIncomingRequest(request).onErrorResume(error -> {
215+
return handleIncomingRequest(request, transportContext).onErrorResume(error -> {
214216
var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null,
215217
new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
216218
error.getMessage(), null));
@@ -223,7 +225,7 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) {
223225
// happening first
224226
logger.debug("Received notification: {}", notification);
225227
// TODO: in case of error, should the POST request be signalled?
226-
return handleIncomingNotification(notification)
228+
return handleIncomingNotification(notification, transportContext)
227229
.doOnError(error -> logger.error("Error handling notification: {}", error.getMessage()));
228230
}
229231
else {
@@ -236,9 +238,11 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) {
236238
/**
237239
* Handles an incoming JSON-RPC request by routing it to the appropriate handler.
238240
* @param request The incoming JSON-RPC request
241+
* @param transportContext
239242
* @return A Mono containing the JSON-RPC response
240243
*/
241-
private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCRequest request) {
244+
private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCRequest request,
245+
McpTransportContext transportContext) {
242246
return Mono.defer(() -> {
243247
Mono<?> resultMono;
244248
if (McpSchema.METHOD_INITIALIZE.equals(request.method())) {
@@ -262,7 +266,11 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
262266
error.message(), error.data())));
263267
}
264268

265-
resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params()));
269+
resultMono = this.exchangeSink.asMono().flatMap(exchange -> {
270+
McpAsyncServerExchange newExchange = new McpAsyncServerExchange(exchange.sessionId(), this,
271+
exchange.getClientCapabilities(), exchange.getClientInfo(), transportContext);
272+
return handler.handle(newExchange, request.params());
273+
});
266274
}
267275
return resultMono
268276
.map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null))
@@ -276,24 +284,30 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
276284
/**
277285
* Handles an incoming JSON-RPC notification by routing it to the appropriate handler.
278286
* @param notification The incoming JSON-RPC notification
287+
* @param transportContext
279288
* @return A Mono that completes when the notification is processed
280289
*/
281-
private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification notification) {
290+
private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification notification,
291+
McpTransportContext transportContext) {
282292
return Mono.defer(() -> {
283293
if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) {
284294
this.state.lazySet(STATE_INITIALIZED);
285295
// FIXME: The session ID passed here is not the same as the one in the
286296
// legacy SSE transport.
287297
exchangeSink.tryEmitValue(new McpAsyncServerExchange(this.id, this, clientCapabilities.get(),
288-
clientInfo.get(), McpTransportContext.EMPTY));
298+
clientInfo.get(), transportContext));
289299
}
290300

291301
var handler = notificationHandlers.get(notification.method());
292302
if (handler == null) {
293303
logger.error("No handler registered for notification method: {}", notification.method());
294304
return Mono.empty();
295305
}
296-
return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params()));
306+
return this.exchangeSink.asMono().flatMap(exchange -> {
307+
McpAsyncServerExchange newExchange = new McpAsyncServerExchange(exchange.sessionId(), this,
308+
exchange.getClientCapabilities(), exchange.getClientInfo(), transportContext);
309+
return handler.handle(newExchange, notification.params());
310+
});
297311
});
298312
}
299313

0 commit comments

Comments
 (0)