Skip to content

Commit 9d7f9f7

Browse files
author
Zachary German
committed
Added per-request exchange instances to route server-sent JSONRPCMessage to their related transport streams
1 parent 91cacc1 commit 9d7f9f7

File tree

5 files changed

+238
-137
lines changed

5 files changed

+238
-137
lines changed

mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,7 @@ private McpServerSession.RequestHandler<Object> setLoggerRequestHandler() {
778778
new TypeReference<SetLevelRequest>() {
779779
});
780780

781+
// This will update both the exchange and session logging levels
781782
exchange.setMinLoggingLevel(newMinLoggingLevel.level());
782783

783784
// FIXME: this field is deprecated and should be removed together

mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,22 @@ public class McpAsyncServerExchange {
4545
public static final TypeReference<Object> OBJECT_TYPE_REF = new TypeReference<>() {
4646
};
4747

48+
private final String transportId;
49+
4850
/**
4951
* Create a new asynchronous exchange with the client.
5052
* @param session The server session representing a 1-1 interaction.
5153
* @param clientCapabilities The client capabilities that define the supported
5254
* features and functionality.
5355
* @param clientInfo The client implementation information.
56+
* @param transportId The transport ID to use for outgoing messages
5457
*/
5558
public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities,
56-
McpSchema.Implementation clientInfo) {
59+
McpSchema.Implementation clientInfo, String transportId) {
5760
this.session = session;
5861
this.clientCapabilities = clientCapabilities;
5962
this.clientInfo = clientInfo;
63+
this.transportId = transportId;
6064
}
6165

6266
/**
@@ -99,7 +103,7 @@ public Mono<McpSchema.CreateMessageResult> createMessage(McpSchema.CreateMessage
99103
return Mono.error(new McpError("Client must be configured with sampling capabilities"));
100104
}
101105
return this.session.sendRequest(McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, createMessageRequest,
102-
CREATE_MESSAGE_RESULT_TYPE_REF);
106+
CREATE_MESSAGE_RESULT_TYPE_REF, transportId);
103107
}
104108

105109
/**
@@ -123,8 +127,8 @@ public Mono<McpSchema.ElicitResult> createElicitation(McpSchema.ElicitRequest el
123127
if (this.clientCapabilities.elicitation() == null) {
124128
return Mono.error(new McpError("Client must be configured with elicitation capabilities"));
125129
}
126-
return this.session.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest,
127-
ELICITATION_RESULT_TYPE_REF);
130+
return this.session.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest, ELICITATION_RESULT_TYPE_REF,
131+
transportId);
128132
}
129133

130134
/**
@@ -154,7 +158,7 @@ public Mono<McpSchema.ListRootsResult> listRoots() {
154158
*/
155159
public Mono<McpSchema.ListRootsResult> listRoots(String cursor) {
156160
return this.session.sendRequest(McpSchema.METHOD_ROOTS_LIST, new McpSchema.PaginatedRequest(cursor),
157-
LIST_ROOTS_RESULT_TYPE_REF);
161+
LIST_ROOTS_RESULT_TYPE_REF, transportId);
158162
}
159163

160164
/**
@@ -171,7 +175,8 @@ public Mono<Void> loggingNotification(LoggingMessageNotification loggingMessageN
171175

172176
return Mono.defer(() -> {
173177
if (this.isNotificationForLevelAllowed(loggingMessageNotification.level())) {
174-
return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, loggingMessageNotification);
178+
return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, loggingMessageNotification,
179+
transportId);
175180
}
176181
return Mono.empty();
177182
});
@@ -182,17 +187,19 @@ public Mono<Void> loggingNotification(LoggingMessageNotification loggingMessageN
182187
* @return A Mono that completes with clients's ping response
183188
*/
184189
public Mono<Object> ping() {
185-
return this.session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF);
190+
return this.session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF, transportId);
186191
}
187192

188193
/**
189194
* Set the minimum logging level for the client. Messages below this level will be
190195
* filtered out.
191196
* @param minLoggingLevel The minimum logging level
192197
*/
193-
void setMinLoggingLevel(LoggingLevel minLoggingLevel) {
198+
public void setMinLoggingLevel(LoggingLevel minLoggingLevel) {
194199
Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null");
195200
this.minLoggingLevel = minLoggingLevel;
201+
// Also update the session level for future exchanges
202+
this.session.setMinLoggingLevel(minLoggingLevel);
196203
}
197204

198205
private boolean isNotificationForLevelAllowed(LoggingLevel loggingLevel) {

mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java

Lines changed: 80 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import io.modelcontextprotocol.server.McpAsyncServerExchange;
1515
import io.modelcontextprotocol.spec.SseEvent;
1616
import io.modelcontextprotocol.spec.McpSchema.McpId;
17+
import io.modelcontextprotocol.spec.McpError;
1718

1819
import org.slf4j.Logger;
1920
import org.slf4j.LoggerFactory;
@@ -53,8 +54,6 @@ public class McpServerSession implements McpSession {
5354

5455
private final Map<String, NotificationHandler> notificationHandlers;
5556

56-
private final Sinks.One<McpAsyncServerExchange> exchangeSink = Sinks.one();
57-
5857
private final AtomicReference<McpSchema.ClientCapabilities> clientCapabilities = new AtomicReference<>();
5958

6059
private final AtomicReference<McpSchema.Implementation> clientInfo = new AtomicReference<>();
@@ -73,6 +72,8 @@ public class McpServerSession implements McpSession {
7372

7473
private final Map<String, Map<String, SseEvent>> transportEventHistories = new ConcurrentHashMap<>();
7574

75+
private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO;
76+
7677
/**
7778
* Creates a new server session with the given parameters and the transport to use.
7879
* @param id session id
@@ -104,6 +105,16 @@ public McpServerSession(String id, Duration requestTimeout, InitRequestHandler i
104105
this(id, requestTimeout, null, initHandler, initNotificationHandler, requestHandlers, notificationHandlers);
105106
}
106107

108+
/**
109+
* Updates the session's minimum logging level for all future exchanges.
110+
*/
111+
public void setMinLoggingLevel(McpSchema.LoggingLevel level) {
112+
if (level != null) {
113+
this.minLoggingLevel = level;
114+
logger.debug("Updated session {} minimum logging level to {}", id, level);
115+
}
116+
}
117+
107118
/**
108119
* Retrieve the session initialization state
109120
* @return session initialization state
@@ -240,14 +251,26 @@ public RequestHandler<?> getRequestHandler(String method) {
240251

241252
@Override
242253
public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef) {
243-
McpId requestId = this.generateRequestId();
254+
return sendRequest(method, requestParams, typeRef, LISTENING_TRANSPORT);
255+
}
256+
257+
public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef, String transportId) {
258+
McpServerTransport transport = getTransport(transportId);
259+
if (transport == null) {
260+
// Fallback to listening transport if specific transport not found
261+
transport = getTransport(LISTENING_TRANSPORT);
262+
if (transport == null) {
263+
return Mono.error(new RuntimeException("Transport not found: " + transportId));
264+
}
265+
}
244266

267+
final McpServerTransport finalTransport = transport;
268+
McpId requestId = this.generateRequestId();
245269
return Mono.<McpSchema.JSONRPCResponse>create(sink -> {
246270
this.pendingResponses.put(requestId, sink);
247271
McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method,
248272
requestId, requestParams);
249-
250-
Flux.from(listeningTransport.sendMessage(jsonrpcRequest)).subscribe(v -> {
273+
Flux.from(finalTransport.sendMessage(jsonrpcRequest)).subscribe(v -> {
251274
}, error -> {
252275
this.pendingResponses.remove(requestId);
253276
sink.error(error);
@@ -260,17 +283,29 @@ else if (typeRef.getType().equals(Void.class)) {
260283
sink.complete();
261284
}
262285
else {
263-
T result = listeningTransport.unmarshalFrom(jsonRpcResponse.result(), typeRef);
286+
T result = finalTransport.unmarshalFrom(jsonRpcResponse.result(), typeRef);
264287
sink.next(result);
265288
}
266289
});
267290
}
268291

269292
@Override
270293
public Mono<Void> sendNotification(String method, Object params) {
294+
return sendNotification(method, params, LISTENING_TRANSPORT);
295+
}
296+
297+
public Mono<Void> sendNotification(String method, Object params, String transportId) {
298+
McpServerTransport transport = getTransport(transportId);
299+
if (transport == null) {
300+
// Fallback to listening transport if specific transport not found
301+
transport = getTransport(LISTENING_TRANSPORT);
302+
if (transport == null) {
303+
return Mono.error(new RuntimeException("Transport not found: " + transportId));
304+
}
305+
}
271306
McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION,
272307
method, params);
273-
return this.listeningTransport.sendMessage(jsonrpcNotification);
308+
return transport.sendMessage(jsonrpcNotification);
274309
}
275310

276311
/**
@@ -300,26 +335,20 @@ public Mono<Void> handle(McpSchema.JSONRPCMessage message) {
300335
}
301336
else if (message instanceof McpSchema.JSONRPCRequest request) {
302337
logger.debug("Received request: {}", request);
303-
final String transportId;
304-
if (transports.isEmpty()) {
305-
transportId = LISTENING_TRANSPORT;
306-
}
307-
else {
308-
transportId = request.id().toString();
309-
}
338+
final String transportId = determineTransportId(request);
310339
return handleIncomingRequest(request).onErrorResume(error -> {
311340
var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null,
312341
new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
313342
error.getMessage(), null));
314-
McpServerTransport transport = getTransport(transportId);
343+
McpServerTransport transport = getTransportWithFallback(transportId);
315344
return transport != null ? transport.sendMessage(errorResponse).then(Mono.empty()) : Mono.empty();
316345
}).flatMap(response -> {
317-
McpServerTransport transport = getTransport(transportId);
346+
McpServerTransport transport = getTransportWithFallback(transportId);
318347
if (transport != null) {
319348
return transport.sendMessage(response);
320349
}
321350
else {
322-
return Mono.error(new RuntimeException("Transport not found: " + transportId));
351+
return Mono.error(new RuntimeException("No transport available"));
323352
}
324353
});
325354
}
@@ -369,10 +398,10 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
369398
error.message(), error.data())));
370399
}
371400

372-
// We would need to add request.id() as a parameter to handler.handle() if
373-
// we want client-request-driven requests/notifications to go to the
374-
// related stream
375-
resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params()));
401+
McpAsyncServerExchange requestExchange = new McpAsyncServerExchange(this, clientCapabilities.get(),
402+
clientInfo.get(), determineTransportId(request));
403+
requestExchange.setMinLoggingLevel(minLoggingLevel);
404+
resultMono = handler.handle(requestExchange, request.params());
376405
}
377406
return resultMono
378407
.map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null))
@@ -392,7 +421,6 @@ private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification noti
392421
return Mono.defer(() -> {
393422
if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) {
394423
this.state.lazySet(STATE_INITIALIZED);
395-
exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get()));
396424
return this.initNotificationHandler.handle();
397425
}
398426

@@ -401,7 +429,10 @@ private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification noti
401429
logger.error("No handler registered for notification method: {}", notification.method());
402430
return Mono.empty();
403431
}
404-
return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params()));
432+
McpAsyncServerExchange notificationExchange = new McpAsyncServerExchange(this, clientCapabilities.get(),
433+
clientInfo.get(), LISTENING_TRANSPORT);
434+
notificationExchange.setMinLoggingLevel(minLoggingLevel);
435+
return handler.handle(notificationExchange, notification.params());
405436
});
406437
}
407438

@@ -412,6 +443,32 @@ private MethodNotFoundError getMethodNotFoundError(String method) {
412443
return new MethodNotFoundError(method, "Method not found: " + method, null);
413444
}
414445

446+
/**
447+
* Determines the appropriate transport ID for a request. Uses request ID for
448+
* per-request routing only if a transport with that ID exists, otherwise falls back
449+
* to listening transport.
450+
*/
451+
private String determineTransportId(McpSchema.JSONRPCRequest request) {
452+
String requestTransportId = request.id().toString();
453+
// Check if a transport exists for this specific request ID
454+
if (getTransport(requestTransportId) != null) {
455+
return requestTransportId;
456+
}
457+
// Fallback to listening transport
458+
return LISTENING_TRANSPORT;
459+
}
460+
461+
/**
462+
* Gets a transport with fallback to listening transport.
463+
*/
464+
private McpServerTransport getTransportWithFallback(String transportId) {
465+
McpServerTransport transport = getTransport(transportId);
466+
if (transport == null) {
467+
transport = getTransport(LISTENING_TRANSPORT);
468+
}
469+
return transport;
470+
}
471+
415472
@Override
416473
public Mono<Void> closeGracefully() {
417474
return Mono.defer(() -> {

0 commit comments

Comments
 (0)