Skip to content

Commit 29a39f7

Browse files
author
Zachary German
committed
Removing streamTools in favor of exchange-driven transport upgrading. Added extensive integration testing.
1 parent 9d7f9f7 commit 29a39f7

File tree

8 files changed

+643
-1060
lines changed

8 files changed

+643
-1060
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import io.modelcontextprotocol.spec.McpTransportSession;
3030
import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException;
3131
import io.modelcontextprotocol.spec.McpTransportStream;
32+
import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification;
3233
import io.modelcontextprotocol.util.Assert;
3334
import reactor.core.Disposable;
3435
import reactor.core.publisher.Flux;
@@ -117,10 +118,6 @@ public static Builder builder(WebClient.Builder webClientBuilder) {
117118
public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
118119
return Mono.deferContextual(ctx -> {
119120
this.handler.set(handler);
120-
if (openConnectionOnStartup) {
121-
logger.debug("Eagerly opening connection on startup");
122-
return this.reconnect(null).then();
123-
}
124121
return Mono.empty();
125122
});
126123
}
@@ -250,11 +247,13 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
250247
})
251248
.bodyValue(message)
252249
.exchangeToFlux(response -> {
253-
if (transportSession
254-
.markInitialized(response.headers().asHttpHeaders().getFirst("mcp-session-id"))) {
255-
// Once we have a session, we try to open an async stream for
256-
// the server to send notifications and requests out-of-band.
257-
reconnect(null).contextWrite(sink.contextView()).subscribe();
250+
transportSession.markInitialized(response.headers().asHttpHeaders().getFirst("mcp-session-id"));
251+
if (response.statusCode().is2xxSuccessful()
252+
&& message instanceof JSONRPCNotification notification) {
253+
if (notification.method().equals("notifications/initialized")) {
254+
// Establish SSE stream after session is initialized
255+
reconnect(null).contextWrite(sink.contextView()).subscribe();
256+
}
258257
}
259258

260259
String sessionRepresentation = sessionIdOrPlaceholder(transportSession);

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpTransportIntegrationTest.java

Lines changed: 501 additions & 106 deletions
Large diffs are not rendered by default.

mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java

Lines changed: 13 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,12 @@ public class McpAsyncServer {
8888

8989
private final McpSchema.ServerCapabilities serverCapabilities;
9090

91-
private final boolean isStreamableHttp;
92-
9391
private final McpSchema.Implementation serverInfo;
9492

9593
private final String instructions;
9694

9795
private final CopyOnWriteArrayList<McpServerFeatures.AsyncToolSpecification> tools = new CopyOnWriteArrayList<>();
9896

99-
private final CopyOnWriteArrayList<McpServerFeatures.AsyncStreamingToolSpecification> streamTools = new CopyOnWriteArrayList<>();
100-
10197
private final CopyOnWriteArrayList<McpSchema.ResourceTemplate> resourceTemplates = new CopyOnWriteArrayList<>();
10298

10399
private final ConcurrentHashMap<String, McpServerFeatures.AsyncResourceSpecification> resources = new ConcurrentHashMap<>();
@@ -123,7 +119,7 @@ public class McpAsyncServer {
123119
*/
124120
McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper,
125121
McpServerFeatures.Async features, Duration requestTimeout,
126-
McpUriTemplateManagerFactory uriTemplateManagerFactory, boolean isStreamableHttp) {
122+
McpUriTemplateManagerFactory uriTemplateManagerFactory) {
127123
this.mcpTransportProvider = mcpTransportProvider;
128124
this.objectMapper = objectMapper;
129125
this.serverInfo = features.serverInfo();
@@ -135,7 +131,6 @@ public class McpAsyncServer {
135131
this.prompts.putAll(features.prompts());
136132
this.completions.putAll(features.completions());
137133
this.uriTemplateManagerFactory = uriTemplateManagerFactory;
138-
this.isStreamableHttp = isStreamableHttp;
139134

140135
Map<String, McpServerSession.RequestHandler<?>> requestHandlers = new HashMap<>();
141136

@@ -193,13 +188,6 @@ public class McpAsyncServer {
193188
notificationHandlers));
194189
}
195190

196-
// Alternate constructor for HTTP+SSE servers (past spec)
197-
McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper,
198-
McpServerFeatures.Async features, Duration requestTimeout,
199-
McpUriTemplateManagerFactory uriTemplateManagerFactory) {
200-
this(mcpTransportProvider, objectMapper, features, requestTimeout, uriTemplateManagerFactory, false);
201-
}
202-
203191
// ---------------------------------------
204192
// Lifecycle Management
205193
// ---------------------------------------
@@ -341,69 +329,6 @@ public Mono<Void> removeTool(String toolName) {
341329
});
342330
}
343331

344-
/**
345-
* Add a new tool specification at runtime.
346-
* @param toolSpecification The tool specification to add
347-
* @return Mono that completes when clients have been notified of the change
348-
*/
349-
public Mono<Void> addStreamTool(McpServerFeatures.AsyncStreamingToolSpecification toolSpecification) {
350-
if (toolSpecification == null) {
351-
return Mono.error(new McpError("Tool specification must not be null"));
352-
}
353-
if (toolSpecification.tool() == null) {
354-
return Mono.error(new McpError("Tool must not be null"));
355-
}
356-
if (toolSpecification.call() == null) {
357-
return Mono.error(new McpError("Tool call handler must not be null"));
358-
}
359-
if (this.serverCapabilities.tools() == null) {
360-
return Mono.error(new McpError("Server must be configured with tool capabilities"));
361-
}
362-
363-
return Mono.defer(() -> {
364-
// Check for duplicate tool names
365-
if (this.streamTools.stream().anyMatch(th -> th.tool().name().equals(toolSpecification.tool().name()))) {
366-
return Mono
367-
.error(new McpError("Tool with name '" + toolSpecification.tool().name() + "' already exists"));
368-
}
369-
370-
this.streamTools.add(toolSpecification);
371-
logger.debug("Added tool handler: {}", toolSpecification.tool().name());
372-
373-
if (this.serverCapabilities.tools().listChanged()) {
374-
return notifyToolsListChanged();
375-
}
376-
return Mono.empty();
377-
});
378-
}
379-
380-
/**
381-
* Remove a tool handler at runtime.
382-
* @param toolName The name of the tool handler to remove
383-
* @return Mono that completes when clients have been notified of the change
384-
*/
385-
public Mono<Void> removeStreamTool(String toolName) {
386-
if (toolName == null) {
387-
return Mono.error(new McpError("Tool name must not be null"));
388-
}
389-
if (this.serverCapabilities.tools() == null) {
390-
return Mono.error(new McpError("Server must be configured with tool capabilities"));
391-
}
392-
393-
return Mono.defer(() -> {
394-
boolean removed = this.tools
395-
.removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName));
396-
if (removed) {
397-
logger.debug("Removed tool handler: {}", toolName);
398-
if (this.serverCapabilities.tools().listChanged()) {
399-
return notifyToolsListChanged();
400-
}
401-
return Mono.empty();
402-
}
403-
return Mono.error(new McpError("Tool with name '" + toolName + "' not found"));
404-
});
405-
}
406-
407332
/**
408333
* Notifies clients that the list of available tools has changed.
409334
* @return A Mono that completes when all clients have been notified
@@ -416,95 +341,27 @@ private McpServerSession.RequestHandler<McpSchema.ListToolsResult> toolsListRequ
416341
return (exchange, params) -> {
417342
List<Tool> tools = new ArrayList<>();
418343
tools.addAll(this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList());
419-
tools.addAll(
420-
this.streamTools.stream().map(McpServerFeatures.AsyncStreamingToolSpecification::tool).toList());
421344

422345
return Mono.just(new McpSchema.ListToolsResult(tools, null));
423346
};
424347
}
425348

426349
private McpServerSession.RequestHandler<CallToolResult> toolsCallRequestHandler() {
427-
if (isStreamableHttp) {
428-
return new McpServerSession.StreamingRequestHandler<CallToolResult>() {
429-
@Override
430-
public Mono<CallToolResult> handle(McpAsyncServerExchange exchange, Object params) {
431-
var callToolRequest = objectMapper.convertValue(params, McpSchema.CallToolRequest.class);
432-
433-
// Check regular tools first
434-
var regularTool = tools.stream()
435-
.filter(tool -> callToolRequest.name().equals(tool.tool().name()))
436-
.findFirst();
437-
438-
if (regularTool.isPresent()) {
439-
return regularTool.get().call().apply(exchange, callToolRequest.arguments());
440-
}
441-
442-
// Check streaming tools (take first result)
443-
var streamingTool = streamTools.stream()
444-
.filter(tool -> callToolRequest.name().equals(tool.tool().name()))
445-
.findFirst();
446-
447-
if (streamingTool.isPresent()) {
448-
return streamingTool.get().call().apply(exchange, callToolRequest.arguments()).next();
449-
}
450-
451-
return Mono.error(new McpError("Tool not found: " + callToolRequest.name()));
452-
}
453-
454-
@Override
455-
public Flux<CallToolResult> handleStreaming(McpAsyncServerExchange exchange, Object params) {
456-
var callToolRequest = objectMapper.convertValue(params, McpSchema.CallToolRequest.class);
457-
458-
// Check streaming tools first (preferred for streaming)
459-
var streamingTool = streamTools.stream()
460-
.filter(tool -> callToolRequest.name().equals(tool.tool().name()))
461-
.findFirst();
462-
463-
if (streamingTool.isPresent()) {
464-
return streamingTool.get().call().apply(exchange, callToolRequest.arguments());
465-
}
466-
467-
// Fallback to regular tools (convert Mono to Flux)
468-
var regularTool = tools.stream()
469-
.filter(tool -> callToolRequest.name().equals(tool.tool().name()))
470-
.findFirst();
471-
472-
if (regularTool.isPresent()) {
473-
return regularTool.get().call().apply(exchange, callToolRequest.arguments()).flux();
474-
}
475-
476-
return Flux.error(new McpError("Tool not found: " + callToolRequest.name()));
477-
}
478-
};
479-
}
480-
else {
481-
return (exchange, params) -> {
482-
McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params,
483-
new TypeReference<McpSchema.CallToolRequest>() {
484-
});
485-
486-
// Check regular tools first
487-
Optional<McpServerFeatures.AsyncToolSpecification> toolSpecification = this.tools.stream()
488-
.filter(tr -> callToolRequest.name().equals(tr.tool().name()))
489-
.findAny();
490-
491-
if (toolSpecification.isPresent()) {
492-
return toolSpecification.get().call().apply(exchange, callToolRequest.arguments());
493-
}
350+
return (exchange, params) -> {
351+
McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params,
352+
new TypeReference<McpSchema.CallToolRequest>() {
353+
});
494354

495-
// Check streaming tools (take first result)
496-
Optional<McpServerFeatures.AsyncStreamingToolSpecification> streamToolSpecification = this.streamTools
497-
.stream()
498-
.filter(tr -> callToolRequest.name().equals(tr.tool().name()))
499-
.findAny();
355+
Optional<McpServerFeatures.AsyncToolSpecification> toolSpecification = this.tools.stream()
356+
.filter(tr -> callToolRequest.name().equals(tr.tool().name()))
357+
.findAny();
500358

501-
if (streamToolSpecification.isPresent()) {
502-
return streamToolSpecification.get().call().apply(exchange, callToolRequest.arguments()).next();
503-
}
359+
if (toolSpecification.isPresent()) {
360+
return toolSpecification.get().call().apply(exchange, callToolRequest.arguments());
361+
}
504362

505-
return Mono.error(new McpError("Tool not found: " + callToolRequest.name()));
506-
};
507-
}
363+
return Mono.error(new McpError("Tool not found: " + callToolRequest.name()));
364+
};
508365
}
509366

510367
// ---------------------------------------

mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@
88
import java.util.Collections;
99

1010
import com.fasterxml.jackson.core.type.TypeReference;
11+
12+
import io.modelcontextprotocol.server.transport.StreamableHttpServerTransportProvider;
1113
import io.modelcontextprotocol.spec.McpError;
1214
import io.modelcontextprotocol.spec.McpSchema;
1315
import io.modelcontextprotocol.spec.McpSchema.LoggingLevel;
1416
import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification;
1517
import io.modelcontextprotocol.spec.McpServerSession;
18+
import io.modelcontextprotocol.spec.McpServerTransport;
1619
import io.modelcontextprotocol.util.Assert;
1720
import reactor.core.publisher.Mono;
1821

@@ -79,6 +82,28 @@ public McpSchema.Implementation getClientInfo() {
7982
return this.clientInfo;
8083
}
8184

85+
/**
86+
* If the exchange's session is using StreamableHttp: Upgrades the transport
87+
* referenced by this exchange's transportId to an SSE stream if it isn't already one.
88+
*/
89+
private void establishSseStream() {
90+
final McpServerTransport currentTransport = session.getTransport(transportId);
91+
if (session.isStreamableHttp()
92+
&& currentTransport instanceof StreamableHttpServerTransportProvider.HttpTransport transport) {
93+
session.registerTransport(transportId,
94+
new StreamableHttpServerTransportProvider.SseTransport(transport.getObjectMapper(),
95+
transport.getResponse(), transport.getAsyncContext(), null, transportId, session.getId()));
96+
}
97+
}
98+
99+
/**
100+
* This is for tool writers to use if they want to send their tool response over an
101+
* SSE stream without using any other McpAsyncServerExchange methods
102+
*/
103+
public void upgradeTransport() {
104+
establishSseStream();
105+
}
106+
82107
/**
83108
* Create a new message using the sampling capabilities of the client. The Model
84109
* Context Protocol (MCP) provides a standardized way for servers to request LLM
@@ -96,6 +121,9 @@ public McpSchema.Implementation getClientInfo() {
96121
* Specification</a>
97122
*/
98123
public Mono<McpSchema.CreateMessageResult> createMessage(McpSchema.CreateMessageRequest createMessageRequest) {
124+
125+
establishSseStream();
126+
99127
if (this.clientCapabilities == null) {
100128
return Mono.error(new McpError("Client must be initialized. Call the initialize method first!"));
101129
}
@@ -121,6 +149,9 @@ public Mono<McpSchema.CreateMessageResult> createMessage(McpSchema.CreateMessage
121149
* Specification</a>
122150
*/
123151
public Mono<McpSchema.ElicitResult> createElicitation(McpSchema.ElicitRequest elicitRequest) {
152+
153+
establishSseStream();
154+
124155
if (this.clientCapabilities == null) {
125156
return Mono.error(new McpError("Client must be initialized. Call the initialize method first!"));
126157
}
@@ -157,6 +188,9 @@ public Mono<McpSchema.ListRootsResult> listRoots() {
157188
* @return A Mono that emits the list of roots result containing
158189
*/
159190
public Mono<McpSchema.ListRootsResult> listRoots(String cursor) {
191+
192+
establishSseStream();
193+
160194
return this.session.sendRequest(McpSchema.METHOD_ROOTS_LIST, new McpSchema.PaginatedRequest(cursor),
161195
LIST_ROOTS_RESULT_TYPE_REF, transportId);
162196
}
@@ -169,6 +203,8 @@ public Mono<McpSchema.ListRootsResult> listRoots(String cursor) {
169203
*/
170204
public Mono<Void> loggingNotification(LoggingMessageNotification loggingMessageNotification) {
171205

206+
establishSseStream();
207+
172208
if (loggingMessageNotification == null) {
173209
return Mono.error(new McpError("Logging message must not be null"));
174210
}
@@ -187,6 +223,9 @@ public Mono<Void> loggingNotification(LoggingMessageNotification loggingMessageN
187223
* @return A Mono that completes with clients's ping response
188224
*/
189225
public Mono<Object> ping() {
226+
227+
establishSseStream();
228+
190229
return this.session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF, transportId);
191230
}
192231

0 commit comments

Comments
 (0)