From ad65312e777b4826d9b27e7952cd188d4a59cefe Mon Sep 17 00:00:00 2001 From: YunKui Lu Date: Wed, 7 May 2025 14:10:35 +0800 Subject: [PATCH] fix: Add an exception handling function to WebFluxSseClientTransport - `sseErrorHandler` is used to handle errors when processing SSE events and situations where an unrecognized event type is received. - Add unit tests to verify. Signed-off-by: YunKui Lu --- .../transport/WebFluxSseClientTransport.java | 27 ++++++++++++++++- .../WebFluxSseClientTransportTests.java | 30 +++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 37abe295b..9aec47b27 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -5,6 +5,7 @@ import java.io.IOException; import java.util.function.BiConsumer; +import java.util.function.BiFunction; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; @@ -123,6 +124,16 @@ public class WebFluxSseClientTransport implements McpClientTransport { */ private String sseEndpoint; + /** + * Handle exceptions that occur during the processing of SSE events to avoid + * connection interruption. + */ + private BiFunction, Mono> sseErrorHandler = (error, + event) -> { + logger.warn("Failed to handle SSE event {}", event, error); + return Mono.empty(); + }; + /** * Constructs a new SseClientTransport with the specified WebClient builder. Uses a * default ObjectMapper instance for JSON processing. @@ -215,12 +226,26 @@ else if (MESSAGE_EVENT_TYPE.equals(event.event())) { else { s.error(new McpError("Received unrecognized SSE event type: " + event.event())); } - }).transform(handler)).subscribe(); + }).onErrorResume(e -> sseErrorHandler.apply(e, event)).transform(handler)).subscribe(); // The connection is established once the server sends the endpoint event return messageEndpointSink.asMono().then(); } + /** + * Sets the handler for processing transport-level errors. + * + *

+ * The provided handler will be called when errors occur during transport operations, + * such as connection failures or protocol violations. + *

+ * @param errorHandler a consumer that processes error messages + */ + public void setSseErrorHandler( + BiFunction, Mono> errorHandler) { + this.sseErrorHandler = errorHandler; + } + /** * Sends a JSON-RPC message to the server using the endpoint provided during * connection. diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index c757d3da9..ade024261 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -6,7 +6,10 @@ import java.time.Duration; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import com.fasterxml.jackson.databind.ObjectMapper; @@ -86,6 +89,11 @@ public void simulateMessageEvent(String jsonMessage) { inboundMessageCount.incrementAndGet(); } + public void simulateComment(String comment) { + events.tryEmitNext(ServerSentEvent.builder().comment(comment).build()); + inboundMessageCount.incrementAndGet(); + } + } void startContainer() { @@ -338,4 +346,26 @@ void testMessageOrderPreservation() { assertThat(transport.getInboundMessageCount()).isEqualTo(3); } + @Test + void customErrorHandlerShouldProcessErrors() throws InterruptedException { + AtomicReference> receivedErrorEvent = new AtomicReference<>(); + AtomicReference handledError = new AtomicReference<>(); + + transport.setSseErrorHandler((error, event) -> { + receivedErrorEvent.set(event); + handledError.set(error); + return Mono.empty(); + }); + + // Mock receive a common message `: This is a comment.\n\n` + transport.simulateComment("This is a comment."); + + assertThat(receivedErrorEvent.get().comment()).isNotNull().isEqualTo("This is a comment."); + + // Mock receive a common message `:ping - 2025-05-06 08:42:06.508759+00:00\n\n` + transport.simulateComment("ping - 2025-05-06 08:42:06.508759+00:00"); + + assertThat(receivedErrorEvent.get().comment()).isNotNull().isEqualTo("ping - 2025-05-06 08:42:06.508759+00:00"); + } + }