diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index fde067f03..2b32cf07e 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -11,6 +11,7 @@ import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.Exceptions; @@ -81,6 +82,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + public static final String DEFAULT_CONTEXT_PATH = ""; + public static final String DEFAULT_BASE_URL = ""; private final ObjectMapper objectMapper; @@ -91,6 +94,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ private final String baseUrl; + private final String contextPath; + private final String messageEndpoint; private final String sseEndpoint; @@ -133,33 +138,38 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa * @throws IllegalArgumentException if either parameter is null */ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); + this(objectMapper, DEFAULT_CONTEXT_PATH, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); } /** * Constructs a new WebFlux SSE server transport provider instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization * of MCP messages. Must not be null. + * @param contextPath The context path of the server. * @param baseUrl webflux message base path * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. * @throws IllegalArgumentException if either parameter is null */ - public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint) { + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String contextPath, String baseUrl, + String messageEndpoint, String sseEndpoint) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(contextPath, "Context path must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + Assert.hasText(messageEndpoint, "Message endpoint must not be empty"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + Assert.hasText(sseEndpoint, "SSE endpoint must not be empty"); this.objectMapper = objectMapper; - this.baseUrl = baseUrl; + this.contextPath = Utils.removeTrailingSlash(contextPath); + this.baseUrl = Utils.removeTrailingSlash(baseUrl); this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.routerFunction = RouterFunctions.route() - .GET(this.sseEndpoint, this::handleSseConnection) - .POST(this.messageEndpoint, this::handleMessage) + .GET(this.baseUrl + this.sseEndpoint, this::handleSseConnection) + .POST(this.baseUrl + this.messageEndpoint, this::handleMessage) .build(); } @@ -270,7 +280,7 @@ private Mono handleSseConnection(ServerRequest request) { logger.debug("Sending initial endpoint event to session: {}", sessionId); sink.next(ServerSentEvent.builder() .event(ENDPOINT_EVENT_TYPE) - .data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId) + .data(this.contextPath + this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId) .build()); sink.onCancel(() -> { logger.debug("Session {} cancelled", sessionId); @@ -390,6 +400,8 @@ public static class Builder { private ObjectMapper objectMapper; + private String contextPath = DEFAULT_CONTEXT_PATH; + private String baseUrl = DEFAULT_BASE_URL; private String messageEndpoint; @@ -422,6 +434,18 @@ public Builder basePath(String baseUrl) { return this; } + /** + * Sets the context path under which the server is running. + * @param contextPath the context path. + * @return this builder instance. + * @throws IllegalArgumentException if contextPath is null + */ + public Builder contextPath(String contextPath) { + Assert.notNull(contextPath, "contextPath must not be null"); + this.contextPath = contextPath; + return this; + } + /** * Sets the endpoint URI where clients should send their JSON-RPC messages. * @param messageEndpoint The message endpoint URI. Must not be null. @@ -456,7 +480,8 @@ public WebFluxSseServerTransportProvider build() { Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(messageEndpoint, "Message endpoint must be set"); - return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); + return new WebFluxSseServerTransportProvider(objectMapper, contextPath, baseUrl, messageEndpoint, + sseEndpoint); } } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseCustomPathIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseCustomPathIntegrationTests.java new file mode 100644 index 000000000..5cd90a5c7 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseCustomPathIntegrationTests.java @@ -0,0 +1,158 @@ +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerResponse; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.web.reactive.function.server.RequestPredicates.path; +import static org.springframework.web.reactive.function.server.RouterFunctions.nest; + +/** + * Tests the {@link WebFluxSseServerTransportProvider} with different values for the + * endpoint. + */ +public class WebFluxSseCustomPathIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private DisposableServer httpServer; + + private WebFluxSseServerTransportProvider mcpServerTransportProvider; + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @ParameterizedTest( + name = "baseUrl = \"{0}\" messageEndpoint = \"{1}\" sseEndpoint = \"{2}\" contextPath = \"{3}\" : {displayName} ") + @MethodSource("provideCustomEndpoints") + public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, String sseEndpoint, + String contextPath) { + + this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), contextPath, + baseUrl, messageEndpoint, sseEndpoint); + + RouterFunction router = this.mcpServerTransportProvider.getRouterFunction(); + // wrap the context path around the router function + RouterFunction nestedRouter = (RouterFunction) nest(path(contextPath), router); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(nestedRouter); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + var endpoint = buildSseEndpoint(contextPath, baseUrl, sseEndpoint); + + var clientBuilder = McpClient + .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .sseEndpoint(endpoint) + .build()); + + McpSchema.CallToolResult callResponse = new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + McpServerFeatures.AsyncToolSpecification tool1 = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), + (exchange, request) -> Mono.just(callResponse)); + + var server = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + assertThat(client.initialize()).isNotNull(); + assertThat(client.listTools().tools()).contains(tool1.tool()); + + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + assertThat(response).isNotNull().isEqualTo(callResponse); + } + + server.close(); + + } + + /** + * This is a helper function for the tests which builds the SSE endpoint to pass to + * the client transport. + * @param contextPath context path of the server. + * @param baseUrl base url of the sse endpoint. + * @param sseEndpoint the sse endpoint. + * @return the created sse endpoint. + */ + private String buildSseEndpoint(String contextPath, String baseUrl, String sseEndpoint) { + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.substring(0, baseUrl.length() - 1); + } + if (contextPath.endsWith("/")) { + contextPath = contextPath.substring(0, contextPath.length() - 1); + } + + return contextPath + baseUrl + sseEndpoint; + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (httpServer != null) { + httpServer.disposeNow(); + } + } + + /** + * Provides a stream of custom endpoints. This generates all possible combinations for + * allowed endpoint values. + * + *

+ * Each combination is returned as an {@link Arguments} object containing four + * parameters in the following order: + *

+ *
    + *
  1. Base URL (String)
  2. + *
  3. Message endpoint (String)
  4. + *
  5. SSE endpoint (String)
  6. + *
  7. Context path (String)
  8. + *
+ * @return a {@link Stream} of {@link Arguments} objects, each containing four String + * parameters representing different endpoint combinations for parameterized testing + */ + private static Stream provideCustomEndpoints() { + String[] baseUrls = { "", "/", "/v1", "/v1/" }; + String[] messageEndpoints = { "/", "/message", "/message/" }; + String[] sseEndpoints = { "/", "/sse", "/sse/" }; + String[] contextPaths = { "", "/", "/mcp", "/mcp/" }; + + return Stream.of(baseUrls) + .flatMap(baseUrl -> Stream.of(messageEndpoints) + .flatMap(messageEndpoint -> Stream.of(sseEndpoints) + .flatMap(sseEndpoint -> Stream.of(contextPaths) + .map(contextPath -> Arguments.of(baseUrl, messageEndpoint, sseEndpoint, contextPath))))); + } + +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index 4c6d37bf9..7df64b3b8 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -71,6 +71,12 @@ ${junit.version} test + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + test + org.mockito mockito-core diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index 114eff607..1b6404292 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -17,6 +17,7 @@ import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -92,6 +93,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi private final String baseUrl; + private final String contextPath; + private final RouterFunction routerFunction; private McpServerSession.Factory sessionFactory; @@ -131,13 +134,14 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * @throws IllegalArgumentException if any parameter is null */ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, "", messageEndpoint, sseEndpoint); + this(objectMapper, "", "", messageEndpoint, sseEndpoint); } /** * Constructs a new WebMvcSseServerTransportProvider instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization * of messages. + * @param contextPath The context path under which the server runs. * @param baseUrl The base URL for the message endpoint, used to construct the full * endpoint URL for clients. * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC @@ -146,20 +150,24 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * @param sseEndpoint The endpoint URI where clients establish their SSE connections. * @throws IllegalArgumentException if any parameter is null */ - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint) { + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String contextPath, String baseUrl, + String messageEndpoint, String sseEndpoint) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(contextPath, "Context path must not be null"); Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + Assert.hasText(messageEndpoint, "Message endpoint must not be empty"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + Assert.hasText(sseEndpoint, "SSE endpoint must not be empty"); this.objectMapper = objectMapper; - this.baseUrl = baseUrl; + this.contextPath = Utils.removeTrailingSlash(contextPath); + this.baseUrl = Utils.removeTrailingSlash(baseUrl); this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.routerFunction = RouterFunctions.route() - .GET(this.sseEndpoint, this::handleSseConnection) - .POST(this.messageEndpoint, this::handleMessage) + .GET(this.baseUrl + this.sseEndpoint, this::handleSseConnection) + .POST(this.baseUrl + this.messageEndpoint, this::handleMessage) .build(); } @@ -268,7 +276,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { try { sseBuilder.id(sessionId) .event(ENDPOINT_EVENT_TYPE) - .data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); + .data(this.contextPath + this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); } catch (Exception e) { logger.error("Failed to send initial endpoint event: {}", e.getMessage()); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java index 1b5218cc5..06a01f797 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java @@ -91,7 +91,7 @@ static class TestConfig { @Bean public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { - return new WebMvcSseServerTransportProvider(new ObjectMapper(), CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT, + return new WebMvcSseServerTransportProvider(new ObjectMapper(), CUSTOM_CONTEXT_PATH, "", MESSAGE_ENDPOINT, WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT); } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java new file mode 100644 index 000000000..7b39f1321 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java @@ -0,0 +1,192 @@ +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; + +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import org.springframework.core.env.Environment; +import reactor.core.publisher.Mono; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import com.fasterxml.jackson.databind.ObjectMapper; + +/** + * Tests the {@link WebMvcSseServerTransportProvider} with different values for the + * endpoint. + */ +public class WebMvcSseCustomPathIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private WebMvcSseServerTransportProvider mcpServerTransportProvider; + + private TomcatTestUtil.TomcatServer tomcatServer; + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransportProvider transportProvider(Environment env) { + String baseUrl = env.getProperty("test.baseUrl"); + String messageEndpoint = env.getProperty("test.messageEndpoint"); + String sseEndpoint = env.getProperty("test.sseEndpoint"); + String contextPath = env.getProperty("test.contextPath"); + + return new WebMvcSseServerTransportProvider(new ObjectMapper(), contextPath, baseUrl, messageEndpoint, + sseEndpoint); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + @ParameterizedTest( + name = "baseUrl = \"{0}\" messageEndpoint = \"{1}\" sseEndpoint = \"{2}\" contextPath = \"{3}\" : {displayName} ") + @MethodSource("provideCustomEndpoints") + public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, String sseEndpoint, + String contextPath) { + System.setProperty("test.baseUrl", baseUrl); + System.setProperty("test.messageEndpoint", messageEndpoint); + System.setProperty("test.sseEndpoint", sseEndpoint); + System.setProperty("test.contextPath", contextPath); + + tomcatServer = TomcatTestUtil.createTomcatServer(contextPath, PORT, TestConfig.class); + + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + var endpoint = buildSseEndpoint(contextPath, baseUrl, sseEndpoint); + + var clientBuilder = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).sseEndpoint(endpoint).build()); + + McpSchema.CallToolResult callResponse = new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + McpServerFeatures.AsyncToolSpecification tool1 = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), + (exchange, request) -> Mono.just(callResponse)); + + mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); + + var server = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + assertThat(client.initialize()).isNotNull(); + assertThat(client.listTools().tools()).contains(tool1.tool()); + + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + assertThat(response).isNotNull().isEqualTo(callResponse); + } + + server.close(); + } + + /** + * This is a helper function for the tests which builds the SSE endpoint to pass to + * the client transport. + * @param contextPath context path of the server. + * @param baseUrl base url of the sse endpoint. + * @param sseEndpoint the sse endpoint. + * @return the created sse endpoint. + */ + private String buildSseEndpoint(String contextPath, String baseUrl, String sseEndpoint) { + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.substring(0, baseUrl.length() - 1); + } + if (contextPath.endsWith("/")) { + contextPath = contextPath.substring(0, contextPath.length() - 1); + } + + return contextPath + baseUrl + sseEndpoint; + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + /** + * Provides a stream of custom endpoints. This generates all possible combinations for + * allowed endpoint values. + * + *

+ * Each combination is returned as an {@link Arguments} object containing four + * parameters in the following order: + *

+ *
    + *
  1. Base URL (String)
  2. + *
  3. Message endpoint (String)
  4. + *
  5. SSE endpoint (String)
  6. + *
  7. Context path (String)
  8. + *
+ * @return a {@link Stream} of {@link Arguments} objects, each containing four String + * parameters representing different endpoint combinations for parameterized testing + */ + private static Stream provideCustomEndpoints() { + String[] baseUrls = { "", "/", "/v1", "/v1/" }; + String[] messageEndpoints = { "/", "/message", "/message/" }; + String[] sseEndpoints = { "/", "/sse", "/sse/" }; + String[] contextPaths = { "", "/", "/mcp", "/mcp/" }; + + return Stream.of(baseUrls) + .flatMap(baseUrl -> Stream.of(messageEndpoints) + .flatMap(messageEndpoint -> Stream.of(sseEndpoints) + .flatMap(sseEndpoint -> Stream.of(contextPaths) + .map(contextPath -> Arguments.of(baseUrl, messageEndpoint, sseEndpoint, contextPath))))); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 8598e3164..f9bd39344 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -175,9 +175,9 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, String sseEndpoint, ObjectMapper objectMapper) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.hasText(baseUri, "baseUri must not be empty"); - Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); - Assert.notNull(httpClient, "httpClient must not be null"); + Assert.notNull(baseUri, "baseUri must not be null"); + Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + Assert.hasText(sseEndpoint, "SSE endpoint must not be empty"); Assert.notNull(requestBuilder, "requestBuilder must not be null"); this.baseUri = URI.create(baseUri); this.sseEndpoint = sseEndpoint; @@ -325,9 +325,11 @@ public HttpClientSseClientTransport build() { public Mono connect(Function, Mono> handler) { return Mono.create(sink -> { + URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); + logger.debug("Subscribing to {}", clientUri); HttpRequest request = requestBuilder.copy() - .uri(Utils.resolveUri(this.baseUri, this.sseEndpoint)) + .uri(clientUri) .header("Accept", "text/event-stream") .header("Cache-Control", "no-cache") .GET() diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java index 8e654e596..c45e31500 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -53,6 +53,18 @@ public static boolean isEmpty(@Nullable Map map) { return (map == null || map.isEmpty()); } + /** + * Removes the trailing slash character of the given String. + * @param str the String to remove the trailing slash + * @return the modified String. + */ + public static String removeTrailingSlash(String str) { + if (str.endsWith("/")) { + str = str.substring(0, str.length() - 1); + } + return str; + } + /** * Resolves the given endpoint URL against the base URL. *