diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 128cda4c3..03a69c77a 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -5,6 +5,7 @@ import java.io.IOException; import java.util.function.BiConsumer; +import java.util.function.BiFunction; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; @@ -123,6 +124,16 @@ public class WebFluxSseClientTransport implements McpClientTransport { */ private String sseEndpoint; + /** + * Handle exceptions that occur during the processing of SSE events to avoid + * connection interruption. + */ + private BiFunction, Mono> sseErrorHandler = (error, + event) -> { + logger.warn("Failed to handle SSE event {}", event, error); + return Mono.empty(); + }; + /** * Constructs a new SseClientTransport with the specified WebClient builder. Uses a * default ObjectMapper instance for JSON processing. @@ -218,12 +229,26 @@ else if (MESSAGE_EVENT_TYPE.equals(event.event())) { else { s.error(new McpError("Received unrecognized SSE event type: " + event.event())); } - }).transform(handler)).subscribe(); + }).onErrorResume(e -> sseErrorHandler.apply(e, event)).transform(handler)).subscribe(); // The connection is established once the server sends the endpoint event return messageEndpointSink.asMono().then(); } + /** + * Sets the handler for processing transport-level errors. + * + *

+ * The provided handler will be called when errors occur during transport operations, + * such as connection failures or protocol violations. + *

+ * @param errorHandler a consumer that processes error messages + */ + public void setSseErrorHandler( + BiFunction, Mono> errorHandler) { + this.sseErrorHandler = errorHandler; + } + /** * Sends a JSON-RPC message to the server using the endpoint provided during * connection. diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index 42b91d14e..2ecea5cba 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -6,7 +6,10 @@ import java.time.Duration; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import com.fasterxml.jackson.databind.ObjectMapper; @@ -87,6 +90,11 @@ public void simulateMessageEvent(String jsonMessage) { inboundMessageCount.incrementAndGet(); } + public void simulateComment(String comment) { + events.tryEmitNext(ServerSentEvent.builder().comment(comment).build()); + inboundMessageCount.incrementAndGet(); + } + } void startContainer() { @@ -339,4 +347,26 @@ void testMessageOrderPreservation() { assertThat(transport.getInboundMessageCount()).isEqualTo(3); } + @Test + void customErrorHandlerShouldProcessErrors() throws InterruptedException { + AtomicReference> receivedErrorEvent = new AtomicReference<>(); + AtomicReference handledError = new AtomicReference<>(); + + transport.setSseErrorHandler((error, event) -> { + receivedErrorEvent.set(event); + handledError.set(error); + return Mono.empty(); + }); + + // Mock receive a common message `: This is a comment.\n\n` + transport.simulateComment("This is a comment."); + + assertThat(receivedErrorEvent.get().comment()).isNotNull().isEqualTo("This is a comment."); + + // Mock receive a common message `:ping - 2025-05-06 08:42:06.508759+00:00\n\n` + transport.simulateComment("ping - 2025-05-06 08:42:06.508759+00:00"); + + assertThat(receivedErrorEvent.get().comment()).isNotNull().isEqualTo("ping - 2025-05-06 08:42:06.508759+00:00"); + } + }