Skip to content

Commit 5667caf

Browse files
committed
fix the baseUrl is configured with a trailing slash
Signed-off-by: lance <[email protected]>
1 parent 5c46626 commit 5667caf

File tree

2 files changed

+133
-5
lines changed

2 files changed

+133
-5
lines changed

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.springframework.web.servlet.function.ServerRequest;
3737
import org.springframework.web.servlet.function.ServerResponse;
3838
import org.springframework.web.servlet.function.ServerResponse.SseBuilder;
39+
import org.springframework.web.util.UriComponentsBuilder;
3940

4041
/**
4142
* Server-side implementation of the Model Context Protocol (MCP) transport layer using
@@ -87,6 +88,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
8788
*/
8889
public static final String ENDPOINT_EVENT_TYPE = "endpoint";
8990

91+
public static final String SESSION_ID = "sessionId";
92+
9093
/**
9194
* Default SSE endpoint path as specified by the MCP transport specification.
9295
*/
@@ -275,9 +278,7 @@ private ServerResponse handleSseConnection(ServerRequest request) {
275278
this.sessions.put(sessionId, session);
276279

277280
try {
278-
sseBuilder.id(sessionId)
279-
.event(ENDPOINT_EVENT_TYPE)
280-
.data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId);
281+
sseBuilder.id(sessionId).event(ENDPOINT_EVENT_TYPE).data(buildEndpointUrl(sessionId));
281282
}
282283
catch (Exception e) {
283284
logger.error("Failed to send initial endpoint event: {}", e.getMessage());
@@ -292,6 +293,14 @@ private ServerResponse handleSseConnection(ServerRequest request) {
292293
}
293294
}
294295

296+
private String buildEndpointUrl(String sessionId) {
297+
return UriComponentsBuilder.fromUriString(baseUrl)
298+
.path(messageEndpoint)
299+
.queryParam(SESSION_ID, sessionId)
300+
.build()
301+
.toUriString();
302+
}
303+
295304
/**
296305
* Handles incoming JSON-RPC messages from clients. This method:
297306
* <ul>
@@ -308,11 +317,11 @@ private ServerResponse handleMessage(ServerRequest request) {
308317
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
309318
}
310319

311-
if (request.param("sessionId").isEmpty()) {
320+
if (request.param(SESSION_ID).isEmpty()) {
312321
return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint"));
313322
}
314323

315-
String sessionId = request.param("sessionId").get();
324+
String sessionId = request.param(SESSION_ID).get();
316325
McpServerSession session = sessions.get(sessionId);
317326

318327
if (session == null) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.server.transport;
6+
7+
import io.modelcontextprotocol.client.McpClient;
8+
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
9+
import io.modelcontextprotocol.common.McpTransportContext;
10+
import io.modelcontextprotocol.json.McpJsonMapper;
11+
import io.modelcontextprotocol.server.McpServer;
12+
import io.modelcontextprotocol.server.TestUtil;
13+
import io.modelcontextprotocol.server.TomcatTestUtil;
14+
import io.modelcontextprotocol.spec.McpSchema;
15+
import org.apache.catalina.LifecycleException;
16+
import org.apache.catalina.LifecycleState;
17+
import org.junit.jupiter.api.AfterEach;
18+
import org.junit.jupiter.api.BeforeEach;
19+
import org.junit.jupiter.api.Test;
20+
21+
import org.springframework.context.annotation.Bean;
22+
import org.springframework.context.annotation.Configuration;
23+
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
24+
import org.springframework.web.servlet.function.RouterFunction;
25+
import org.springframework.web.servlet.function.ServerResponse;
26+
27+
import static org.assertj.core.api.Assertions.assertThat;
28+
29+
/**
30+
* Integration tests for WebMvcSseServerTransportProvider
31+
*
32+
* @author lance
33+
*/
34+
class WebMvcSseServerTransportProviderTests {
35+
36+
private static final int PORT = TestUtil.findAvailablePort();
37+
38+
private static final String CUSTOM_CONTEXT_PATH = "/";
39+
40+
private static final String MESSAGE_ENDPOINT = "/mcp/message";
41+
42+
private WebMvcSseServerTransportProvider mcpServerTransportProvider;
43+
44+
McpClient.SyncSpec clientBuilder;
45+
46+
private TomcatTestUtil.TomcatServer tomcatServer;
47+
48+
@BeforeEach
49+
public void before() {
50+
tomcatServer = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class);
51+
52+
try {
53+
tomcatServer.tomcat().start();
54+
assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED);
55+
}
56+
catch (Exception e) {
57+
throw new RuntimeException("Failed to start Tomcat", e);
58+
}
59+
60+
HttpClientSseClientTransport transport = HttpClientSseClientTransport.builder("http://localhost:" + PORT)
61+
.sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT)
62+
.build();
63+
64+
clientBuilder = McpClient.sync(transport);
65+
mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class);
66+
}
67+
68+
@Test
69+
void validBaseUrl() {
70+
McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build();
71+
try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0"))
72+
.build()) {
73+
assertThat(client.initialize()).isNotNull();
74+
}
75+
}
76+
77+
@AfterEach
78+
public void after() {
79+
if (mcpServerTransportProvider != null) {
80+
mcpServerTransportProvider.closeGracefully().block();
81+
}
82+
if (tomcatServer.appContext() != null) {
83+
tomcatServer.appContext().close();
84+
}
85+
if (tomcatServer.tomcat() != null) {
86+
try {
87+
tomcatServer.tomcat().stop();
88+
tomcatServer.tomcat().destroy();
89+
}
90+
catch (LifecycleException e) {
91+
throw new RuntimeException("Failed to stop Tomcat", e);
92+
}
93+
}
94+
}
95+
96+
@Configuration
97+
@EnableWebMvc
98+
static class TestConfig {
99+
100+
@Bean
101+
public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() {
102+
103+
return WebMvcSseServerTransportProvider.builder()
104+
.baseUrl("http://localhost:" + PORT + "/")
105+
.messageEndpoint(MESSAGE_ENDPOINT)
106+
.sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT)
107+
.jsonMapper(McpJsonMapper.getDefault())
108+
.contextExtractor(req -> McpTransportContext.EMPTY)
109+
.build();
110+
}
111+
112+
@Bean
113+
public RouterFunction<ServerResponse> routerFunction(WebMvcSseServerTransportProvider transportProvider) {
114+
return transportProvider.getRouterFunction();
115+
}
116+
117+
}
118+
119+
}

0 commit comments

Comments
 (0)