2
2
3
3
import com .fasterxml .jackson .core .type .TypeReference ;
4
4
import com .fasterxml .jackson .databind .ObjectMapper ;
5
+ import io .modelcontextprotocol .spec .DefaultMcpTransportContext ;
5
6
import io .modelcontextprotocol .spec .McpError ;
6
7
import io .modelcontextprotocol .spec .McpSchema ;
7
8
import io .modelcontextprotocol .spec .McpServerTransport ;
8
9
import io .modelcontextprotocol .spec .McpStreamableServerSession ;
9
10
import io .modelcontextprotocol .spec .McpStreamableServerTransportProvider ;
11
+ import io .modelcontextprotocol .spec .McpTransportContext ;
10
12
import io .modelcontextprotocol .util .Assert ;
11
13
import org .slf4j .Logger ;
12
14
import org .slf4j .LoggerFactory ;
25
27
26
28
import java .io .IOException ;
27
29
import java .util .concurrent .ConcurrentHashMap ;
30
+ import java .util .function .Function ;
28
31
29
32
public class WebFluxStreamableServerTransportProvider implements McpStreamableServerTransportProvider {
30
33
@@ -48,6 +51,9 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe
48
51
49
52
private final ConcurrentHashMap <String , McpStreamableServerSession > sessions = new ConcurrentHashMap <>();
50
53
54
+ // TODO: add means to specify this
55
+ private Function <ServerRequest , McpTransportContext > contextExtractor = req -> new DefaultMcpTransportContext ();
56
+
51
57
/**
52
58
* Flag indicating if the transport is shutting down.
53
59
*/
@@ -183,6 +189,8 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
183
189
return ServerResponse .status (HttpStatus .SERVICE_UNAVAILABLE ).bodyValue ("Server is shutting down" );
184
190
}
185
191
192
+ McpTransportContext transportContext = this .contextExtractor .apply (request );
193
+
186
194
return Mono .defer (() -> {
187
195
if (!request .headers ().asHttpHeaders ().containsKey ("mcp-session-id" )) {
188
196
return ServerResponse .badRequest ().build (); // TODO: say we need a session id
@@ -204,11 +212,11 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
204
212
return ServerResponse .ok ().contentType (MediaType .TEXT_EVENT_STREAM )
205
213
.body (Flux .<ServerSentEvent <?>>create (sink -> {
206
214
WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport (sink );
207
- McpStreamableServerSession .McpStreamableServerSessionStream genericStream = session .newStream (sessionTransport );
208
- sink .onDispose (genericStream ::close );
215
+ McpStreamableServerSession .McpStreamableServerSessionStream listeningStream = session .listeningStream (sessionTransport );
216
+ sink .onDispose (listeningStream ::close );
209
217
}), ServerSentEvent .class );
210
218
211
- });
219
+ }). contextWrite ( ctx -> ctx . put ( McpTransportContext . KEY , transportContext )) ;
212
220
}
213
221
214
222
/**
@@ -231,6 +239,8 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
231
239
return ServerResponse .status (HttpStatus .SERVICE_UNAVAILABLE ).bodyValue ("Server is shutting down" );
232
240
}
233
241
242
+ McpTransportContext transportContext = this .contextExtractor .apply (request );
243
+
234
244
return request .bodyToMono (String .class ).<ServerResponse >flatMap (body -> {
235
245
try {
236
246
McpSchema .JSONRPCMessage message = McpSchema .deserializeJsonRpcMessage (objectMapper , body );
@@ -261,7 +271,7 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
261
271
return ServerResponse .ok ().contentType (MediaType .TEXT_EVENT_STREAM )
262
272
.body (Flux .<ServerSentEvent <?>>create (sink -> {
263
273
WebFluxStreamableMcpSessionTransport st = new WebFluxStreamableMcpSessionTransport (sink );
264
- Mono <Void > stream = session .handleStream (jsonrpcRequest , st );
274
+ Mono <Void > stream = session .responseStream (jsonrpcRequest , st );
265
275
Disposable streamSubscription = stream
266
276
.doOnError (err -> sink .error (err ))
267
277
.contextWrite (sink .contextView ())
@@ -276,7 +286,7 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
276
286
logger .error ("Failed to deserialize message: {}" , e .getMessage ());
277
287
return ServerResponse .badRequest ().bodyValue (new McpError ("Invalid message format" ));
278
288
}
279
- });
289
+ }). contextWrite ( ctx -> ctx . put ( McpTransportContext . KEY , transportContext )) ;
280
290
}
281
291
282
292
private class WebFluxStreamableMcpSessionTransport implements McpServerTransport {
0 commit comments