diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index 53b59cb30..d5ac8e95c 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -30,6 +30,7 @@ import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.spec.McpTransportStream; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -244,7 +245,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { Disposable connection = webClient.post() .uri(this.endpoint) - .accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON) + .accept(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM) .headers(httpHeaders -> { transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); }) @@ -287,7 +288,7 @@ else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { logger.trace("Received response to POST for session {}", sessionRepresentation); // communicate to caller the message was delivered sink.success(); - return responseFlux(response); + return directResponseFlux(message, response); } else { logger.warn("Unknown media type {} returned for POST in session {}", contentType, @@ -384,14 +385,22 @@ private static String sessionIdOrPlaceholder(McpTransportSession transportSes return transportSession.sessionId().orElse("[missing_session_id]"); } - private Flux responseFlux(ClientResponse response) { + private Flux directResponseFlux(McpSchema.JSONRPCMessage sentMessage, + ClientResponse response) { return response.bodyToMono(String.class).>handle((responseMessage, s) -> { try { - McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, - responseMessage); - s.next(List.of(jsonRpcResponse)); + if (sentMessage instanceof McpSchema.JSONRPCNotification && Utils.hasText(responseMessage)) { + logger.warn("Notification: {} received non-compliant response: {}", sentMessage, responseMessage); + s.complete(); + } + else { + McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, + responseMessage); + s.next(List.of(jsonRpcResponse)); + } } catch (IOException e) { + // TODO: this should be a McpTransportError s.error(e); } }).flatMapIterable(Function.identity()); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index 4cf1690ff..12baa1706 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -342,9 +342,9 @@ public String toString(McpSchema.JSONRPCMessage message) { } } - public Mono sendMessage(McpSchema.JSONRPCMessage sendMessage) { + public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { return Mono.create(messageSink -> { - logger.debug("Sending message {}", sendMessage); + logger.debug("Sending message {}", sentMessage); final AtomicReference disposableRef = new AtomicReference<>(); final McpTransportSession transportSession = this.activeSession.get(); @@ -355,10 +355,10 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sendMessage) { requestBuilder = requestBuilder.header("mcp-session-id", transportSession.sessionId().get()); } - String jsonBody = this.toString(sendMessage); + String jsonBody = this.toString(sentMessage); HttpRequest request = requestBuilder.uri(Utils.resolveUri(this.baseUri, this.endpoint)) - .header("Accept", TEXT_EVENT_STREAM + ", " + APPLICATION_JSON) + .header("Accept", APPLICATION_JSON + ", " + TEXT_EVENT_STREAM) .header("Content-Type", APPLICATION_JSON) .header("Cache-Control", "no-cache") .POST(HttpRequest.BodyPublishers.ofString(jsonBody)) @@ -436,10 +436,16 @@ else if (contentType.contains(TEXT_EVENT_STREAM)) { else if (contentType.contains(APPLICATION_JSON)) { messageSink.success(); String data = ((ResponseSubscribers.AggregateResponseEvent) responseEvent).data(); + if (sentMessage instanceof McpSchema.JSONRPCNotification && Utils.hasText(data)) { + logger.warn("Notification: {} received non-compliant response: {}", sentMessage, data); + return Mono.empty(); + } + try { return Mono.just(McpSchema.deserializeJsonRpcMessage(objectMapper, data)); } catch (IOException e) { + // TODO: this should be a McpTransportError return Mono.error(e); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java index 26b0d13bd..eb9d3c65c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java @@ -12,6 +12,7 @@ import org.reactivestreams.FlowAdapters; import org.reactivestreams.Subscription; +import io.modelcontextprotocol.spec.McpError; import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.FluxSink; @@ -135,6 +136,7 @@ protected void hookOnSubscribe(Subscription subscription) { @Override protected void hookOnNext(String line) { + if (line.isEmpty()) { // Empty line means end of event if (this.eventBuilder.length() > 0) { @@ -164,6 +166,13 @@ else if (line.startsWith("event:")) { this.currentEventType.set(matcher.group(1).trim()); } } + else { + // If the response is not successful, emit an error + // TODO: This should be a McpTransportError + this.sink.error(new McpError( + "Invalid SSE response. Status code: " + this.responseInfo.statusCode() + " Line: " + line)); + + } } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java index 8e654e596..039b0d68e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -69,6 +69,9 @@ public static boolean isEmpty(@Nullable Map map) { * base URL or URI is malformed */ public static URI resolveUri(URI baseUrl, String endpointUrl) { + if (!Utils.hasText(endpointUrl)) { + return baseUrl; + } URI endpointUri = URI.create(endpointUrl); if (endpointUri.isAbsolute() && !isUnderBaseUri(baseUrl, endpointUri)) { throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL.");