19
19
import com .fasterxml .jackson .core .type .TypeReference ;
20
20
import com .fasterxml .jackson .databind .ObjectMapper ;
21
21
22
- import io .modelcontextprotocol .spec .DefaultMcpTransportContext ;
22
+ import io .modelcontextprotocol .server .DefaultMcpTransportContext ;
23
+ import io .modelcontextprotocol .server .McpTransportContext ;
24
+ import io .modelcontextprotocol .server .McpTransportContextExtractor ;
23
25
import io .modelcontextprotocol .spec .McpError ;
24
26
import io .modelcontextprotocol .spec .McpSchema ;
25
27
import io .modelcontextprotocol .spec .McpStreamableServerSession ;
26
28
import io .modelcontextprotocol .spec .McpStreamableServerTransport ;
27
29
import io .modelcontextprotocol .spec .McpStreamableServerTransportProvider ;
28
- import io .modelcontextprotocol .spec .McpTransportContext ;
29
30
import io .modelcontextprotocol .util .Assert ;
30
31
import jakarta .servlet .AsyncContext ;
31
32
import jakarta .servlet .ServletException ;
@@ -117,8 +118,7 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
117
118
*/
118
119
private final ConcurrentHashMap <String , McpStreamableServerSession > sessions = new ConcurrentHashMap <>();
119
120
120
- // TODO: add means to specify this
121
- private Function <HttpServletRequest , McpTransportContext > contextExtractor = req -> new DefaultMcpTransportContext ();
121
+ private McpTransportContextExtractor <HttpServletRequest > contextExtractor ;
122
122
123
123
/**
124
124
* Flag indicating if the transport is shutting down.
@@ -132,16 +132,19 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
132
132
* @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC
133
133
* messages via HTTP. This endpoint will handle GET, POST, and DELETE requests.
134
134
* @param disallowDelete Whether to disallow DELETE requests on the endpoint.
135
+ * @param contextExtractor The extractor for transport context from the request.
135
136
* @throws IllegalArgumentException if any parameter is null
136
137
*/
137
- public HttpServletStreamableServerTransportProvider (ObjectMapper objectMapper , String mcpEndpoint ,
138
- boolean disallowDelete ) {
138
+ private HttpServletStreamableServerTransportProvider (ObjectMapper objectMapper , String mcpEndpoint ,
139
+ boolean disallowDelete , McpTransportContextExtractor < HttpServletRequest > contextExtractor ) {
139
140
Assert .notNull (objectMapper , "ObjectMapper must not be null" );
140
141
Assert .notNull (mcpEndpoint , "MCP endpoint must not be null" );
142
+ Assert .notNull (contextExtractor , "Context extractor must not be null" );
141
143
142
144
this .objectMapper = objectMapper ;
143
145
this .mcpEndpoint = mcpEndpoint ;
144
146
this .disallowDelete = disallowDelete ;
147
+ this .contextExtractor = contextExtractor ;
145
148
}
146
149
147
150
@ Override
@@ -224,8 +227,6 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
224
227
return ;
225
228
}
226
229
227
- McpTransportContext transportContext = this .contextExtractor .apply (request );
228
-
229
230
List <String > badRequestErrors = new ArrayList <>();
230
231
231
232
String accept = request .getHeader (ACCEPT );
@@ -254,6 +255,8 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
254
255
255
256
logger .debug ("Handling GET request for session: {}" , sessionId );
256
257
258
+ McpTransportContext transportContext = this .contextExtractor .extract (request , new DefaultMcpTransportContext ());
259
+
257
260
try {
258
261
response .setContentType (TEXT_EVENT_STREAM );
259
262
response .setCharacterEncoding (UTF_8 );
@@ -277,7 +280,9 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
277
280
.toIterable ()
278
281
.forEach (message -> {
279
282
try {
280
- sessionTransport .sendMessage (message ).block ();
283
+ sessionTransport .sendMessage (message )
284
+ .contextWrite (ctx -> ctx .put (McpTransportContext .KEY , transportContext ))
285
+ .block ();
281
286
}
282
287
catch (Exception e ) {
283
288
logger .error ("Failed to replay message: {}" , e .getMessage ());
@@ -359,7 +364,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
359
364
badRequestErrors .add ("application/json required in Accept header" );
360
365
}
361
366
362
- McpTransportContext transportContext = this .contextExtractor .apply (request );
367
+ McpTransportContext transportContext = this .contextExtractor .extract (request , new DefaultMcpTransportContext () );
363
368
364
369
try {
365
370
BufferedReader reader = request .getReader ();
@@ -517,7 +522,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
517
522
return ;
518
523
}
519
524
520
- McpTransportContext transportContext = this .contextExtractor .apply (request );
525
+ McpTransportContext transportContext = this .contextExtractor .extract (request , new DefaultMcpTransportContext () );
521
526
522
527
if (request .getHeader (MCP_SESSION_ID ) == null ) {
523
528
this .responseError (response , HttpServletResponse .SC_BAD_REQUEST ,
@@ -745,6 +750,8 @@ public static class Builder {
745
750
746
751
private boolean disallowDelete = false ;
747
752
753
+ private McpTransportContextExtractor <HttpServletRequest > contextExtractor = (serverRequest , context ) -> context ;
754
+
748
755
/**
749
756
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
750
757
* messages.
@@ -780,6 +787,18 @@ public Builder disallowDelete(boolean disallowDelete) {
780
787
return this ;
781
788
}
782
789
790
+ /**
791
+ * Sets the context extractor for extracting transport context from the request.
792
+ * @param contextExtractor The context extractor to use. Must not be null.
793
+ * @return this builder instance
794
+ * @throws IllegalArgumentException if contextExtractor is null
795
+ */
796
+ public Builder contextExtractor (McpTransportContextExtractor <HttpServletRequest > contextExtractor ) {
797
+ Assert .notNull (contextExtractor , "Context extractor must not be null" );
798
+ this .contextExtractor = contextExtractor ;
799
+ return this ;
800
+ }
801
+
783
802
/**
784
803
* Builds a new instance of {@link HttpServletStreamableServerTransportProvider}
785
804
* with the configured settings.
@@ -791,7 +810,7 @@ public HttpServletStreamableServerTransportProvider build() {
791
810
Assert .notNull (this .mcpEndpoint , "MCP endpoint must be set" );
792
811
793
812
return new HttpServletStreamableServerTransportProvider (this .objectMapper , this .mcpEndpoint ,
794
- this .disallowDelete );
813
+ this .disallowDelete , this . contextExtractor );
795
814
}
796
815
797
816
}
0 commit comments