Skip to content

Commit b4b33d6

Browse files
committed
conflict resolution
Signed-off-by: Christian Tzolov <[email protected]>
1 parent 83f931c commit b4b33d6

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@
1919
import com.fasterxml.jackson.core.type.TypeReference;
2020
import com.fasterxml.jackson.databind.ObjectMapper;
2121

22-
import io.modelcontextprotocol.spec.DefaultMcpTransportContext;
22+
import io.modelcontextprotocol.server.DefaultMcpTransportContext;
23+
import io.modelcontextprotocol.server.McpTransportContext;
24+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
2325
import io.modelcontextprotocol.spec.McpError;
2426
import io.modelcontextprotocol.spec.McpSchema;
2527
import io.modelcontextprotocol.spec.McpStreamableServerSession;
2628
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
2729
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
28-
import io.modelcontextprotocol.spec.McpTransportContext;
2930
import io.modelcontextprotocol.util.Assert;
3031
import jakarta.servlet.AsyncContext;
3132
import jakarta.servlet.ServletException;
@@ -117,8 +118,7 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
117118
*/
118119
private final ConcurrentHashMap<String, McpStreamableServerSession> sessions = new ConcurrentHashMap<>();
119120

120-
// TODO: add means to specify this
121-
private Function<HttpServletRequest, McpTransportContext> contextExtractor = req -> new DefaultMcpTransportContext();
121+
private McpTransportContextExtractor<HttpServletRequest> contextExtractor;
122122

123123
/**
124124
* Flag indicating if the transport is shutting down.
@@ -132,16 +132,19 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
132132
* @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC
133133
* messages via HTTP. This endpoint will handle GET, POST, and DELETE requests.
134134
* @param disallowDelete Whether to disallow DELETE requests on the endpoint.
135+
* @param contextExtractor The extractor for transport context from the request.
135136
* @throws IllegalArgumentException if any parameter is null
136137
*/
137-
public HttpServletStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint,
138-
boolean disallowDelete) {
138+
private HttpServletStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint,
139+
boolean disallowDelete, McpTransportContextExtractor<HttpServletRequest> contextExtractor) {
139140
Assert.notNull(objectMapper, "ObjectMapper must not be null");
140141
Assert.notNull(mcpEndpoint, "MCP endpoint must not be null");
142+
Assert.notNull(contextExtractor, "Context extractor must not be null");
141143

142144
this.objectMapper = objectMapper;
143145
this.mcpEndpoint = mcpEndpoint;
144146
this.disallowDelete = disallowDelete;
147+
this.contextExtractor = contextExtractor;
145148
}
146149

147150
@Override
@@ -224,8 +227,6 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
224227
return;
225228
}
226229

227-
McpTransportContext transportContext = this.contextExtractor.apply(request);
228-
229230
List<String> badRequestErrors = new ArrayList<>();
230231

231232
String accept = request.getHeader(ACCEPT);
@@ -254,6 +255,8 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
254255

255256
logger.debug("Handling GET request for session: {}", sessionId);
256257

258+
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
259+
257260
try {
258261
response.setContentType(TEXT_EVENT_STREAM);
259262
response.setCharacterEncoding(UTF_8);
@@ -277,7 +280,9 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
277280
.toIterable()
278281
.forEach(message -> {
279282
try {
280-
sessionTransport.sendMessage(message).block();
283+
sessionTransport.sendMessage(message)
284+
.contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
285+
.block();
281286
}
282287
catch (Exception e) {
283288
logger.error("Failed to replay message: {}", e.getMessage());
@@ -359,7 +364,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
359364
badRequestErrors.add("application/json required in Accept header");
360365
}
361366

362-
McpTransportContext transportContext = this.contextExtractor.apply(request);
367+
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
363368

364369
try {
365370
BufferedReader reader = request.getReader();
@@ -517,7 +522,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
517522
return;
518523
}
519524

520-
McpTransportContext transportContext = this.contextExtractor.apply(request);
525+
McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
521526

522527
if (request.getHeader(MCP_SESSION_ID) == null) {
523528
this.responseError(response, HttpServletResponse.SC_BAD_REQUEST,
@@ -745,6 +750,8 @@ public static class Builder {
745750

746751
private boolean disallowDelete = false;
747752

753+
private McpTransportContextExtractor<HttpServletRequest> contextExtractor = (serverRequest, context) -> context;
754+
748755
/**
749756
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
750757
* messages.
@@ -780,6 +787,18 @@ public Builder disallowDelete(boolean disallowDelete) {
780787
return this;
781788
}
782789

790+
/**
791+
* Sets the context extractor for extracting transport context from the request.
792+
* @param contextExtractor The context extractor to use. Must not be null.
793+
* @return this builder instance
794+
* @throws IllegalArgumentException if contextExtractor is null
795+
*/
796+
public Builder contextExtractor(McpTransportContextExtractor<HttpServletRequest> contextExtractor) {
797+
Assert.notNull(contextExtractor, "Context extractor must not be null");
798+
this.contextExtractor = contextExtractor;
799+
return this;
800+
}
801+
783802
/**
784803
* Builds a new instance of {@link HttpServletStreamableServerTransportProvider}
785804
* with the configured settings.
@@ -791,7 +810,7 @@ public HttpServletStreamableServerTransportProvider build() {
791810
Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set");
792811

793812
return new HttpServletStreamableServerTransportProvider(this.objectMapper, this.mcpEndpoint,
794-
this.disallowDelete);
813+
this.disallowDelete, this.contextExtractor);
795814
}
796815

797816
}

0 commit comments

Comments
 (0)