From 9758c2396c7866e2f3ade50edd1cfff3e44b6b3f Mon Sep 17 00:00:00 2001 From: David Dworken Date: Thu, 29 May 2025 18:15:17 -0700 Subject: [PATCH 1/3] Add support for DNS rebinding protections --- .../WebFluxSseServerTransportProvider.java | 78 +++++- .../WebMvcSseServerTransportProvider.java | 54 +++- .../DnsRebindingProtectionConfig.java | 108 ++++++++ ...HttpServletSseServerTransportProvider.java | 70 ++++- .../DnsRebindingProtectionConfigTests.java | 157 ++++++++++++ .../HttpServletSseHeaderValidationTests.java | 240 ++++++++++++++++++ 6 files changed, 702 insertions(+), 5 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfig.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfigTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseHeaderValidationTests.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 62264d9aa..f43d93adb 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -1,6 +1,7 @@ package io.modelcontextprotocol.server.transport; import java.io.IOException; +import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -110,6 +111,11 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ private volatile boolean isClosing = false; + /** + * DNS rebinding protection configuration. + */ + private final DnsRebindingProtectionConfig dnsRebindingProtectionConfig; + /** * Constructs a new WebFlux SSE server transport provider instance with the default * SSE endpoint. @@ -134,7 +140,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa * @throws IllegalArgumentException if either parameter is null */ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); + this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint, null); } /** @@ -149,6 +155,24 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa */ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Constructs a new WebFlux SSE server transport provider instance with optional DNS + * rebinding protection. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param baseUrl webflux message base path + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @param dnsRebindingProtectionConfig The DNS rebinding protection configuration (may be null). + * @throws IllegalArgumentException if required parameters are null + */ + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, DnsRebindingProtectionConfig dnsRebindingProtectionConfig) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); @@ -158,6 +182,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.dnsRebindingProtectionConfig = dnsRebindingProtectionConfig; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) @@ -256,6 +281,16 @@ private Mono handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + // Validate headers + if (dnsRebindingProtectionConfig != null) { + String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); + String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); + if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader); + return ServerResponse.status(HttpStatus.FORBIDDEN).bodyValue("DNS rebinding protection validation failed"); + } + } + return ServerResponse.ok() .contentType(MediaType.TEXT_EVENT_STREAM) .body(Flux.>create(sink -> { @@ -300,6 +335,25 @@ private Mono handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + // Always validate Content-Type for POST requests + String contentType = request.headers().contentType() + .map(MediaType::toString) + .orElse(null); + if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) { + logger.warn("Invalid Content-Type header: '{}'", contentType); + return ServerResponse.badRequest().bodyValue(new McpError("Content-Type must be application/json")); + } + + // Validate headers for POST requests if DNS rebinding protection is configured + if (dnsRebindingProtectionConfig != null) { + String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); + String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); + if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader); + return ServerResponse.status(HttpStatus.FORBIDDEN).bodyValue("DNS rebinding protection validation failed"); + } + } + if (request.queryParam("sessionId").isEmpty()) { return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing in message endpoint")); } @@ -397,6 +451,8 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private DnsRebindingProtectionConfig dnsRebindingProtectionConfig; + /** * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP * messages. @@ -447,6 +503,23 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + + /** + * Sets the DNS rebinding protection configuration. + *

+ * When set, this configuration will be used to create a header validator that + * enforces DNS rebinding protection rules. This will override any previously set + * header validator. + * @param config The DNS rebinding protection configuration + * @return this builder instance + * @throws IllegalArgumentException if config is null + */ + public Builder dnsRebindingProtectionConfig(DnsRebindingProtectionConfig config) { + Assert.notNull(config, "DNS rebinding protection config must not be null"); + this.dnsRebindingProtectionConfig = config; + return this; + } + /** * Builds a new instance of {@link WebFluxSseServerTransportProvider} with the * configured settings. @@ -457,7 +530,8 @@ public WebFluxSseServerTransportProvider build() { Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(messageEndpoint, "Message endpoint must be set"); - return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); + return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, + dnsRebindingProtectionConfig); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fc86cfaa0..d350d9ab4 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -6,6 +6,7 @@ import java.io.IOException; import java.time.Duration; +import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -107,6 +108,11 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi */ private volatile boolean isClosing = false; + /** + * DNS rebinding protection configuration. + */ + private final DnsRebindingProtectionConfig dnsRebindingProtectionConfig; + /** * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE * endpoint. @@ -132,7 +138,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * @throws IllegalArgumentException if any parameter is null */ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, "", messageEndpoint, sseEndpoint); + this(objectMapper, "", messageEndpoint, sseEndpoint, null); } /** @@ -149,6 +155,24 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag */ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Constructs a new WebMvcSseServerTransportProvider instance with DNS rebinding protection. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param baseUrl The base URL for the message endpoint, used to construct the full + * endpoint URL for clients. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @param dnsRebindingProtectionConfig The DNS rebinding protection configuration (may be null). + * @throws IllegalArgumentException if any required parameter is null + */ + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, DnsRebindingProtectionConfig dnsRebindingProtectionConfig) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); @@ -158,6 +182,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.dnsRebindingProtectionConfig = dnsRebindingProtectionConfig; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) @@ -247,6 +272,16 @@ private ServerResponse handleSseConnection(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } + // Validate headers + if (dnsRebindingProtectionConfig != null) { + String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); + String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); + if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader); + return ServerResponse.status(HttpStatus.FORBIDDEN).body("DNS rebinding protection validation failed"); + } + } + String sessionId = UUID.randomUUID().toString(); logger.debug("Creating new SSE connection for session: {}", sessionId); @@ -300,6 +335,23 @@ private ServerResponse handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } + // Always validate Content-Type for POST requests + String contentType = request.headers().asHttpHeaders().getFirst("Content-Type"); + if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) { + logger.warn("Invalid Content-Type header: '{}'", contentType); + return ServerResponse.badRequest().body(new McpError("Content-Type must be application/json")); + } + + // Validate headers for POST requests if DNS rebinding protection is configured + if (dnsRebindingProtectionConfig != null) { + String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); + String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); + if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader); + return ServerResponse.status(HttpStatus.FORBIDDEN).body("DNS rebinding protection validation failed"); + } + } + if (request.param("sessionId").isEmpty()) { return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint")); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfig.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfig.java new file mode 100644 index 000000000..e03052dbb --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfig.java @@ -0,0 +1,108 @@ +package io.modelcontextprotocol.server.transport; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +/** + * Configuration for DNS rebinding protection in SSE server transports. Provides + * validation for Host and Origin headers to prevent DNS rebinding attacks. + */ +public class DnsRebindingProtectionConfig { + + private final Set allowedHosts; + + private final Set allowedOrigins; + + private final boolean enableDnsRebindingProtection; + + private DnsRebindingProtectionConfig(Builder builder) { + this.allowedHosts = Collections.unmodifiableSet(new HashSet<>(builder.allowedHosts)); + this.allowedOrigins = Collections.unmodifiableSet(new HashSet<>(builder.allowedOrigins)); + this.enableDnsRebindingProtection = builder.enableDnsRebindingProtection; + } + + /** + * Validates Host and Origin headers for DNS rebinding protection. Returns true if the + * headers are valid, false otherwise. + * @param hostHeader The value of the Host header (may be null) + * @param originHeader The value of the Origin header (may be null) + * @return true if the headers are valid, false otherwise + */ + public boolean validate(String hostHeader, String originHeader) { + // Skip validation if protection is not enabled + if (!enableDnsRebindingProtection) { + return true; + } + + // Validate Host header + if (hostHeader != null) { + String lowerHost = hostHeader.toLowerCase(); + if (!allowedHosts.contains(lowerHost)) { + return false; + } + } + + // Validate Origin header + if (originHeader != null) { + String lowerOrigin = originHeader.toLowerCase(); + if (!allowedOrigins.contains(lowerOrigin)) { + return false; + } + } + + return true; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private final Set allowedHosts = new HashSet<>(); + + private final Set allowedOrigins = new HashSet<>(); + + private boolean enableDnsRebindingProtection = true; + + public Builder allowedHost(String host) { + if (host != null) { + this.allowedHosts.add(host.toLowerCase()); + } + return this; + } + + public Builder allowedHosts(Set hosts) { + if (hosts != null) { + hosts.forEach(this::allowedHost); + } + return this; + } + + public Builder allowedOrigin(String origin) { + if (origin != null) { + this.allowedOrigins.add(origin.toLowerCase()); + } + return this; + } + + public Builder allowedOrigins(Set origins) { + if (origins != null) { + origins.forEach(this::allowedOrigin); + } + return this; + } + + public Builder enableDnsRebindingProtection(boolean enable) { + this.enableDnsRebindingProtection = enable; + return this; + } + + public DnsRebindingProtectionConfig build() { + return new DnsRebindingProtectionConfig(this); + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index afdbff472..1ce7ee0c5 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -103,6 +103,9 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement /** Session factory for creating new sessions */ private McpServerSession.Factory sessionFactory; + /** DNS rebinding protection configuration */ + private final DnsRebindingProtectionConfig dnsRebindingProtectionConfig; + /** * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE * endpoint. @@ -113,7 +116,7 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement */ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); + this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint, null); } /** @@ -127,10 +130,27 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String m */ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Creates a new HttpServletSseServerTransportProvider instance with optional DNS + * rebinding protection. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param baseUrl The base URL for the server transport + * @param messageEndpoint The endpoint path where clients will send their messages + * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @param dnsRebindingProtectionConfig The DNS rebinding protection configuration (may + * be null) + */ + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, DnsRebindingProtectionConfig dnsRebindingProtectionConfig) { this.objectMapper = objectMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.dnsRebindingProtectionConfig = dnsRebindingProtectionConfig; } /** @@ -202,6 +222,18 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) return; } + // Validate headers if DNS rebinding protection is configured + if (dnsRebindingProtectionConfig != null) { + String hostHeader = request.getHeader("Host"); + String originHeader = request.getHeader("Origin"); + if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, + originHeader); + response.sendError(HttpServletResponse.SC_FORBIDDEN, "DNS rebinding protection validation failed"); + return; + } + } + response.setContentType("text/event-stream"); response.setCharacterEncoding(UTF_8); response.setHeader("Cache-Control", "no-cache"); @@ -252,6 +284,26 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } + // Always validate Content-Type for POST requests + String contentType = request.getContentType(); + if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) { + logger.warn("Invalid Content-Type header: '{}'", contentType); + response.sendError(HttpServletResponse.SC_BAD_REQUEST, "Content-Type must be application/json"); + return; + } + + // Validate headers for POST requests if DNS rebinding protection is configured + if (dnsRebindingProtectionConfig != null) { + String hostHeader = request.getHeader("Host"); + String originHeader = request.getHeader("Origin"); + if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, + originHeader); + response.sendError(HttpServletResponse.SC_FORBIDDEN, "DNS rebinding protection validation failed"); + return; + } + } + // Get the session ID from the request parameter String sessionId = request.getParameter("sessionId"); if (sessionId == null) { @@ -475,6 +527,8 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private DnsRebindingProtectionConfig dnsRebindingProtectionConfig; + /** * Sets the JSON object mapper to use for message serialization/deserialization. * @param objectMapper The object mapper to use @@ -522,6 +576,17 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + /** + * Sets the DNS rebinding protection configuration. + * @param config The DNS rebinding protection configuration + * @return This builder instance for method chaining + */ + public Builder dnsRebindingProtectionConfig(DnsRebindingProtectionConfig config) { + Assert.notNull(config, "DNS rebinding protection config must not be null"); + this.dnsRebindingProtectionConfig = config; + return this; + } + /** * Builds a new instance of HttpServletSseServerTransportProvider with the * configured settings. @@ -535,7 +600,8 @@ public HttpServletSseServerTransportProvider build() { if (messageEndpoint == null) { throw new IllegalStateException("MessageEndpoint must be set"); } - return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); + return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, + dnsRebindingProtectionConfig); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfigTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfigTests.java new file mode 100644 index 000000000..388a48cfa --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfigTests.java @@ -0,0 +1,157 @@ +package io.modelcontextprotocol.server.transport; + +import org.junit.jupiter.api.Test; + +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for DNS rebinding protection configuration. + */ +public class DnsRebindingProtectionConfigTests { + + @Test + void testDefaultConfiguration() { + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder().build(); + + // Test default behavior - when allowed lists are empty and headers are provided, + // validation fails because the headers are not in the (empty) allowed lists + assertThat(config.validate("any.host.com", "http://any.origin.com")).isFalse(); + assertThat(config.validate("localhost", null)).isFalse(); + assertThat(config.validate(null, "http://example.com")).isFalse(); + // Null values are allowed when lists are empty + assertThat(config.validate(null, null)).isTrue(); + } + + @Test + void testDisableDnsRebindingProtection() { + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() + .enableDnsRebindingProtection(false) + .allowedHost("localhost") // Should be ignored when protection is disabled + .allowedOrigin("http://localhost") // Should be ignored when protection is + // disabled + .build(); + + // When protection is disabled, all hosts and origins should be allowed + assertThat(config.validate("evil.com", "http://evil.com")).isTrue(); + assertThat(config.validate("any.host", "http://any.origin")).isTrue(); + assertThat(config.validate(null, null)).isTrue(); + } + + @Test + void testHostValidation() { + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() + .allowedHost("localhost") + .allowedHost("127.0.0.1") + .build(); + + // Valid hosts + assertThat(config.validate("localhost", null)).isTrue(); + assertThat(config.validate("127.0.0.1", null)).isTrue(); + + // Invalid hosts + assertThat(config.validate("evil.com", null)).isFalse(); + + // Null host is allowed when no specific hosts are being checked + assertThat(config.validate(null, null)).isTrue(); + } + + @Test + void testOriginValidation() { + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() + .allowedOrigin("http://localhost:8080") + .allowedOrigin("https://app.example.com") + .build(); + + // Valid origins + assertThat(config.validate(null, "http://localhost:8080")).isTrue(); + assertThat(config.validate(null, "https://app.example.com")).isTrue(); + + // Invalid origins + assertThat(config.validate(null, "http://evil.com")).isFalse(); + + // Null origin is allowed when no specific origins are being checked + assertThat(config.validate(null, null)).isTrue(); + } + + @Test + void testCombinedHostAndOriginValidation() { + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() + .allowedHost("localhost") + .allowedOrigin("http://localhost:8080") + .build(); + + // Both valid + assertThat(config.validate("localhost", "http://localhost:8080")).isTrue(); + + // Host valid, origin invalid + assertThat(config.validate("localhost", "http://evil.com")).isFalse(); + + // Host invalid, origin valid + assertThat(config.validate("evil.com", "http://localhost:8080")).isFalse(); + + // Both invalid + assertThat(config.validate("evil.com", "http://evil.com")).isFalse(); + } + + @Test + void testCaseInsensitiveHostAndOrigin() { + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() + .allowedHost("LOCALHOST") + .allowedOrigin("HTTP://LOCALHOST:8080") + .build(); + + // Case insensitive matching + assertThat(config.validate("localhost", null)).isTrue(); + assertThat(config.validate("LOCALHOST", null)).isTrue(); + assertThat(config.validate("LoCaLhOsT", null)).isTrue(); + + assertThat(config.validate(null, "http://localhost:8080")).isTrue(); + assertThat(config.validate(null, "HTTP://LOCALHOST:8080")).isTrue(); + } + + @Test + void testEmptyAllowedListsDenyNonNull() { + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder().build(); + + // When allowed lists are empty and headers are provided, validation fails + assertThat(config.validate("any.host.com", "http://any.origin.com")).isFalse(); + assertThat(config.validate("random.host", "http://random.origin")).isFalse(); + // But null values are allowed + assertThat(config.validate(null, null)).isTrue(); + } + + @Test + void testBuilderWithSets() { + Set hosts = Set.of("host1.com", "host2.com"); + Set origins = Set.of("http://origin1.com", "http://origin2.com"); + + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() + .allowedHosts(hosts) + .allowedOrigins(origins) + .build(); + + assertThat(config.validate("host1.com", null)).isTrue(); + assertThat(config.validate("host2.com", null)).isTrue(); + assertThat(config.validate("host3.com", null)).isFalse(); + + assertThat(config.validate(null, "http://origin1.com")).isTrue(); + assertThat(config.validate(null, "http://origin2.com")).isTrue(); + assertThat(config.validate(null, "http://origin3.com")).isFalse(); + } + + @Test + void testNullValuesWithConfiguredLists() { + DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() + .allowedHost("localhost") + .allowedOrigin("http://localhost") + .build(); + + // Null values should be allowed when no check is needed for that header + assertThat(config.validate(null, "http://localhost")).isTrue(); + assertThat(config.validate("localhost", null)).isTrue(); + assertThat(config.validate(null, null)).isTrue(); + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseHeaderValidationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseHeaderValidationTests.java new file mode 100644 index 000000000..68ec9ec0d --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseHeaderValidationTests.java @@ -0,0 +1,240 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Map; +import java.util.concurrent.CompletionException; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Integration tests for header validation in + * {@link HttpServletSseServerTransportProvider}. + */ +class HttpServletSseHeaderValidationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private static final String SSE_ENDPOINT = "/sse"; + + private Tomcat tomcat; + + private HttpServletSseServerTransportProvider transportProvider; + + private McpSyncServer server; + + @AfterEach + void tearDown() { + if (server != null) { + server.close(); + } + if (transportProvider != null) { + transportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Test + void testConnectionSucceedsWithValidHeaders() { + // Create DNS rebinding protection config that validates API key + DnsRebindingProtectionConfig dnsRebindingProtectionConfig = DnsRebindingProtectionConfig.builder() + .enableDnsRebindingProtection(false) // Disable Host/Origin validation for + // this test + .build(); + + // For this test, we'll need to use a custom transport provider implementation + // since DnsRebindingProtectionConfig doesn't support custom header validation + transportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(SSE_ENDPOINT) + .dnsRebindingProtectionConfig(dnsRebindingProtectionConfig) + .build(); + + startServer(); + + // Create client - should succeed since DNS rebinding protection is disabled + try (var client = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).sseEndpoint(SSE_ENDPOINT).build()) + .build()) { + + // Connection should succeed + McpSchema.InitializeResult result = client.initialize(); + assertThat(result).isNotNull(); + assertThat(result.serverInfo().name()).isEqualTo("test-server"); + } + } + + @Test + void testConnectionFailsWithInvalidHeaders() { + // Create DNS rebinding protection config with restricted hosts + DnsRebindingProtectionConfig dnsRebindingProtectionConfig = DnsRebindingProtectionConfig.builder() + .allowedHost("valid-host.com") + .build(); + + // Create server with header validation + transportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(SSE_ENDPOINT) + .dnsRebindingProtectionConfig(dnsRebindingProtectionConfig) + .build(); + + startServer(); + + // Create client with localhost which won't match the allowed host + // The Host header will be "localhost:PORT" which won't match "valid-host.com" + var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(SSE_ENDPOINT) + .build(); + + // Connection should fail during initialization + assertThatThrownBy(() -> { + try (var client = McpClient.sync(clientTransport).build()) { + client.initialize(); + } + }).isInstanceOf(RuntimeException.class); + } + + @Test + void testConnectionFailsWithEmptyAllowedHostsButProvidedHost() { + // Create DNS rebinding protection config with specific allowed origin but no + // allowed hosts + // This means any non-null host will be rejected + DnsRebindingProtectionConfig dnsRebindingProtectionConfig = DnsRebindingProtectionConfig.builder() + .allowedOrigin("http://allowed-origin.com") + .build(); + + // Create server with header validation + transportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(SSE_ENDPOINT) + .dnsRebindingProtectionConfig(dnsRebindingProtectionConfig) + .build(); + + startServer(); + + // Create client - the client will send a Host header like "localhost:PORT" + // Since allowedHosts is empty, any non-null host will be rejected + var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(SSE_ENDPOINT) + .build(); + + // With the new behavior, a non-null Host header is rejected when allowedHosts is + // empty + assertThatThrownBy(() -> { + try (var client = McpClient.sync(clientTransport).build()) { + client.initialize(); + } + }).isInstanceOf(RuntimeException.class); + } + + @Test + void testComplexHeaderValidation() { + // Create DNS rebinding protection config with specific allowed hosts and origins + // Note: The Host header will include the port, so we need to allow + // "localhost:PORT" + DnsRebindingProtectionConfig dnsRebindingProtectionConfig = DnsRebindingProtectionConfig.builder() + .allowedHost("localhost:" + PORT) + .allowedOrigin("http://localhost:" + PORT) + .build(); + + // Create server with DNS rebinding protection + transportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(SSE_ENDPOINT) + .dnsRebindingProtectionConfig(dnsRebindingProtectionConfig) + .build(); + + startServer(); + + // Test with valid headers (localhost is allowed) + try (var client = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).sseEndpoint(SSE_ENDPOINT).build()) + .build()) { + + McpSchema.InitializeResult result = client.initialize(); + assertThat(result).isNotNull(); + } + + // Test with different host (should fail) + var invalidHostTransport = HttpClientSseClientTransport.builder("http://127.0.0.1:" + PORT) + .sseEndpoint(SSE_ENDPOINT) + .build(); + + assertThatThrownBy(() -> { + try (var client = McpClient.sync(invalidHostTransport).build()) { + client.initialize(); + } + }).isInstanceOf(RuntimeException.class); + } + + @Test + void testDefaultValidatorAllowsAllHeaders() { + // Create server without specifying a DNS rebinding protection config (no + // validation) + transportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(SSE_ENDPOINT) + .build(); + + startServer(); + + // Create client with arbitrary headers + try (var client = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(SSE_ENDPOINT) + .customizeRequest(requestBuilder -> { + requestBuilder.header("X-Random-Header", "random-value"); + requestBuilder.header("X-Another-Header", "another-value"); + }) + .build()).build()) { + + // Connection should succeed with any headers + McpSchema.InitializeResult result = client.initialize(); + assertThat(result).isNotNull(); + } + } + + private void startServer() { + tomcat = TomcatTestUtil.createTomcatServer("", PORT, transportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + server = McpServer.sync(transportProvider).serverInfo("test-server", "1.0.0").build(); + } + +} From 7d44740cb967c6e1a2db4623e20afba219fa46b9 Mon Sep 17 00:00:00 2001 From: David Dworken Date: Thu, 26 Jun 2025 17:03:18 -0700 Subject: [PATCH 2/3] Extract DNS rebinding validation logic to separate methods in Spring transport providers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add validateDnsRebindingProtection() method to WebFluxSseServerTransportProvider - Add validateDnsRebindingProtection() method to WebMvcSseServerTransportProvider - Remove duplicated validation logic from GET and POST endpoints - Maintain framework-specific return types (Mono vs ServerResponse) - Apply Spring Java formatting to all modules 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- .../WebFluxSseServerTransportProvider.java | 70 +++-- .../WebMvcSseServerTransportProvider.java | 59 ++-- .../transport/DnsRebindingProtection.java | 266 ++++++++++++++++++ .../DnsRebindingProtectionConfig.java | 108 ------- ...HttpServletSseServerTransportProvider.java | 101 ++++--- .../DnsRebindingProtectionConfigTests.java | 157 ----------- .../DnsRebindingProtectionTests.java | 157 +++++++++++ .../HttpServletSseHeaderValidationTests.java | 18 +- 8 files changed, 576 insertions(+), 360 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtection.java delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfig.java delete mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfigTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionTests.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index f43d93adb..00cca331d 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -114,7 +114,7 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv /** * DNS rebinding protection configuration. */ - private final DnsRebindingProtectionConfig dnsRebindingProtectionConfig; + private final DnsRebindingProtection dnsRebindingProtection; /** * Constructs a new WebFlux SSE server transport provider instance with the default @@ -124,8 +124,10 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. + * @deprecated Use {@link #builder()} instead. * @throws IllegalArgumentException if either parameter is null */ + @Deprecated public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); } @@ -137,8 +139,10 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. + * @deprecated Use {@link #builder()} instead. * @throws IllegalArgumentException if either parameter is null */ + @Deprecated public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint, null); } @@ -151,8 +155,10 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. + * @deprecated Use {@link #builder()} instead. * @throws IllegalArgumentException if either parameter is null */ + @Deprecated public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); @@ -168,11 +174,12 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. * @param sseEndpoint The endpoint URI where clients establish their SSE connections. - * @param dnsRebindingProtectionConfig The DNS rebinding protection configuration (may be null). + * @param dnsRebindingProtection The DNS rebinding protection configuration (may be + * null). * @throws IllegalArgumentException if required parameters are null */ - public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint, DnsRebindingProtectionConfig dnsRebindingProtectionConfig) { + private WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, DnsRebindingProtection dnsRebindingProtection) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); @@ -182,7 +189,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; - this.dnsRebindingProtectionConfig = dnsRebindingProtectionConfig; + this.dnsRebindingProtection = dnsRebindingProtection; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) @@ -282,13 +289,9 @@ private Mono handleSseConnection(ServerRequest request) { } // Validate headers - if (dnsRebindingProtectionConfig != null) { - String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); - String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); - if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { - logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader); - return ServerResponse.status(HttpStatus.FORBIDDEN).bodyValue("DNS rebinding protection validation failed"); - } + Mono validationError = validateDnsRebindingProtection(request); + if (validationError != null) { + return validationError; } return ServerResponse.ok() @@ -336,22 +339,16 @@ private Mono handleMessage(ServerRequest request) { } // Always validate Content-Type for POST requests - String contentType = request.headers().contentType() - .map(MediaType::toString) - .orElse(null); + String contentType = request.headers().contentType().map(MediaType::toString).orElse(null); if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) { logger.warn("Invalid Content-Type header: '{}'", contentType); return ServerResponse.badRequest().bodyValue(new McpError("Content-Type must be application/json")); } // Validate headers for POST requests if DNS rebinding protection is configured - if (dnsRebindingProtectionConfig != null) { - String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); - String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); - if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { - logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader); - return ServerResponse.status(HttpStatus.FORBIDDEN).bodyValue("DNS rebinding protection validation failed"); - } + Mono validationError = validateDnsRebindingProtection(request); + if (validationError != null) { + return validationError; } if (request.queryParam("sessionId").isEmpty()) { @@ -451,7 +448,7 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; - private DnsRebindingProtectionConfig dnsRebindingProtectionConfig; + private DnsRebindingProtection dnsRebindingProtection; /** * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP @@ -503,7 +500,6 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } - /** * Sets the DNS rebinding protection configuration. *

@@ -514,9 +510,9 @@ public Builder sseEndpoint(String sseEndpoint) { * @return this builder instance * @throws IllegalArgumentException if config is null */ - public Builder dnsRebindingProtectionConfig(DnsRebindingProtectionConfig config) { + public Builder dnsRebindingProtection(DnsRebindingProtection config) { Assert.notNull(config, "DNS rebinding protection config must not be null"); - this.dnsRebindingProtectionConfig = config; + this.dnsRebindingProtection = config; return this; } @@ -531,9 +527,29 @@ public WebFluxSseServerTransportProvider build() { Assert.notNull(messageEndpoint, "Message endpoint must be set"); return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, - dnsRebindingProtectionConfig); + dnsRebindingProtection); } } + /** + * Validates DNS rebinding protection for the given request. + * @param request The incoming server request + * @return A ServerResponse with forbidden status if validation fails, or null if + * validation passes + */ + private Mono validateDnsRebindingProtection(ServerRequest request) { + if (dnsRebindingProtection != null) { + String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); + String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); + if (!dnsRebindingProtection.isValid(hostHeader, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, + originHeader); + return ServerResponse.status(HttpStatus.FORBIDDEN) + .bodyValue("DNS rebinding protection validation failed"); + } + } + return null; + } + } diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index d350d9ab4..ced54f62c 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -111,7 +111,7 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi /** * DNS rebinding protection configuration. */ - private final DnsRebindingProtectionConfig dnsRebindingProtectionConfig; + private final DnsRebindingProtection dnsRebindingProtection; /** * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE @@ -121,8 +121,10 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC * messages via HTTP POST. This endpoint will be communicated to clients through the * SSE connection's initial endpoint event. + * @deprecated Use {@link #builder()} instead. * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null */ + @Deprecated public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); } @@ -135,8 +137,10 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * messages via HTTP POST. This endpoint will be communicated to clients through the * SSE connection's initial endpoint event. * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @deprecated Use {@link #builder()} instead. * @throws IllegalArgumentException if any parameter is null */ + @Deprecated public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { this(objectMapper, "", messageEndpoint, sseEndpoint, null); } @@ -151,15 +155,18 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * messages via HTTP POST. This endpoint will be communicated to clients through the * SSE connection's initial endpoint event. * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @deprecated Use {@link #builder()} instead. * @throws IllegalArgumentException if any parameter is null */ + @Deprecated public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); } /** - * Constructs a new WebMvcSseServerTransportProvider instance with DNS rebinding protection. + * Constructs a new WebMvcSseServerTransportProvider instance with DNS rebinding + * protection. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization * of messages. * @param baseUrl The base URL for the message endpoint, used to construct the full @@ -168,11 +175,12 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr * messages via HTTP POST. This endpoint will be communicated to clients through the * SSE connection's initial endpoint event. * @param sseEndpoint The endpoint URI where clients establish their SSE connections. - * @param dnsRebindingProtectionConfig The DNS rebinding protection configuration (may be null). + * @param dnsRebindingProtection The DNS rebinding protection configuration (may be + * null). * @throws IllegalArgumentException if any required parameter is null */ - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint, DnsRebindingProtectionConfig dnsRebindingProtectionConfig) { + private WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, DnsRebindingProtection dnsRebindingProtection) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); @@ -182,7 +190,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; - this.dnsRebindingProtectionConfig = dnsRebindingProtectionConfig; + this.dnsRebindingProtection = dnsRebindingProtection; this.routerFunction = RouterFunctions.route() .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) @@ -273,13 +281,9 @@ private ServerResponse handleSseConnection(ServerRequest request) { } // Validate headers - if (dnsRebindingProtectionConfig != null) { - String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); - String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); - if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { - logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader); - return ServerResponse.status(HttpStatus.FORBIDDEN).body("DNS rebinding protection validation failed"); - } + ServerResponse validationError = validateDnsRebindingProtection(request); + if (validationError != null) { + return validationError; } String sessionId = UUID.randomUUID().toString(); @@ -343,13 +347,9 @@ private ServerResponse handleMessage(ServerRequest request) { } // Validate headers for POST requests if DNS rebinding protection is configured - if (dnsRebindingProtectionConfig != null) { - String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); - String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); - if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { - logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, originHeader); - return ServerResponse.status(HttpStatus.FORBIDDEN).body("DNS rebinding protection validation failed"); - } + ServerResponse validationError = validateDnsRebindingProtection(request); + if (validationError != null) { + return validationError; } if (request.param("sessionId").isEmpty()) { @@ -469,4 +469,23 @@ public void close() { } + /** + * Validates DNS rebinding protection for the given request. + * @param request The incoming server request + * @return A ServerResponse with forbidden status if validation fails, or null if + * validation passes + */ + private ServerResponse validateDnsRebindingProtection(ServerRequest request) { + if (dnsRebindingProtection != null) { + String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); + String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); + if (!dnsRebindingProtection.isValid(hostHeader, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, + originHeader); + return ServerResponse.status(HttpStatus.FORBIDDEN).body("DNS rebinding protection validation failed"); + } + } + return null; + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtection.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtection.java new file mode 100644 index 000000000..d1d2530db --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtection.java @@ -0,0 +1,266 @@ +package io.modelcontextprotocol.server.transport; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +/** + * Configuration for DNS rebinding protection in SSE server transports. + *

+ * DNS Rebinding Attacks and Protection + *

+ * DNS rebinding attacks allow malicious websites to bypass same-origin policy and + * interact with local services by rebinding a domain name to localhost (127.0.0.1) or + * other local IP addresses. This protection is critical when running applications + * locally that are not meant to be accessible to the broader network. + * + *

+ * When to Use DNS Rebinding Protection + *

    + *
  • Local development servers: MCP servers running on localhost during + * development
  • + *
  • Enterprise environments: Internal services that should not be + * accessible from external networks
  • + *
+ * + *

+ * Allowlist Configuration + *

+ * This protection requires configuring an allowlist of accepted Host and Origin values: + *

    + *
  • Host values: The expected domain/IP and port where your server is + * hosted.
    + * These values are constructed from {@code HttpServletRequest.getServerName()} and + * {@code getServerPort()}.
    + * Examples: "localhost:8080", "127.0.0.1:3000", "app.mycompany.com" (standard ports + * 80/443 omitted)
  • + *
  • Origin headers: The domains allowed to make requests to your + * server.
    + * These are taken directly from the Origin HTTP header.
    + * Examples: "http://localhost:3000", "https://app.mycompany.com"
  • + *
+ * + *

+ * Validation Behavior + *

+ * When protection is enabled, validation succeeds when the Host and Origin are either + * null, or match values in the respective allowlists (case-insensitive). + * + *

+ * Example Usage

{@code
+ * // For a local development server:
+ * DnsRebindingProtection config = DnsRebindingProtection.builder()
+ *     .allowedHost("localhost:8080")
+ *     .allowedHost("127.0.0.1:8080")
+ *     .allowedOrigin("http://localhost:3000")
+ *     .build();
+ *
+ * // To disable protection entirely:
+ * DnsRebindingProtection config = DnsRebindingProtection.builder()
+ *     .enableDnsRebindingProtection(false)
+ *     .build();
+ * }
+ * + * @see DNS Rebinding Attack + */ +public class DnsRebindingProtection { + + private final Set allowedHosts; + + private final Set allowedOrigins; + + private final boolean enable; + + /** + * Constructs a new DNS rebinding protection configuration. + * @param allowedHosts The set of allowed host header values (case-insensitive) + * @param allowedOrigins The set of allowed origin header values (case-insensitive) + * @param enable Whether DNS rebinding protection is enabled + */ + private DnsRebindingProtection(Set allowedHosts, Set allowedOrigins, boolean enable) { + this.allowedHosts = Collections.unmodifiableSet(new HashSet<>(allowedHosts)); + this.allowedOrigins = Collections.unmodifiableSet(new HashSet<>(allowedOrigins)); + this.enable = enable; + } + + /** + * Validates Host and Origin headers for DNS rebinding protection. + *

+ * Validation Logic: + *

    + *
  • If protection is disabled ({@code enable=false}): always returns + * true
  • + *
  • If both headers are null: returns true (no validation + * needed)
  • + *
  • If allowlists are empty and headers are provided: returns + * false (reject all)
  • + *
  • If headers are provided and match allowlist values: returns + * true (case-insensitive)
  • + *
  • If any provided header doesn't match its allowlist: returns + * false
  • + *
+ * + *

+ * Important Behavior: An empty allowlist will reject + * any non-null header value. This is intentional - you must + * explicitly configure allowed values for protection to be meaningful. + * @param hostHeader The value of the Host header from the HTTP request (may be null) + * @param originHeader The value of the Origin header from the HTTP request (may be + * null) + * @return {@code true} if the headers are valid according to the protection rules, + * {@code false} if validation failed + */ + public boolean isValid(String hostHeader, String originHeader) { + // Skip validation if protection is not enabled + if (!enable) { + return true; + } + + // Validate Host header + if (hostHeader != null) { + String lowerHost = hostHeader.toLowerCase(); + if (!allowedHosts.contains(lowerHost)) { + return false; + } + } + + // Validate Origin header + if (originHeader != null) { + String lowerOrigin = originHeader.toLowerCase(); + if (!allowedOrigins.contains(lowerOrigin)) { + return false; + } + } + + return true; + } + + /** + * Creates a new builder for constructing DNS rebinding protection configurations. + * @return A new builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for constructing {@link DnsRebindingProtection} instances with a fluent + * API. + *

+ * This builder provides a convenient way to configure DNS rebinding protection with + * explicit allowlists for host and origin values. The builder enforces safe + * construction patterns and provides sensible defaults. + * + *

+ * Default Behavior: + *

    + *
  • Protection is enabled by default ({@code enable = true})
  • + *
  • Both allowlists start empty, which means all non-null + * headers will be rejected
  • + *
  • This secure default forces explicit allowlist configuration
  • + *
+ * + *

+ * Usage Pattern: + *

    + *
  1. Create builder via {@link DnsRebindingProtection#builder()}
  2. + *
  3. Configure allowlists using {@link #allowedHost(String)} and + * {@link #allowedOrigin(String)}
  4. + *
  5. Optionally enable/disable protection via + * {@link #enableDnsRebindingProtection(boolean)}
  6. + *
  7. Build final instance via {@link #build()}
  8. + *
+ * + *

+ * Thread Safety: This builder is not thread-safe. + * Each thread should use its own builder instance. + * + * @see DnsRebindingProtection#builder() + */ + public static class Builder { + + private final Set allowedHosts = new HashSet<>(); + + private final Set allowedOrigins = new HashSet<>(); + + private boolean enable = true; + + /** + * Private constructor to restrict instantiation to builder() method. + */ + private Builder() { + } + + /** + * Adds an allowed host header value. + * @param host The host header value to allow (case-insensitive, may be null) + * @return This builder instance for method chaining + */ + public Builder allowedHost(String host) { + if (host != null) { + this.allowedHosts.add(host.toLowerCase()); + } + return this; + } + + /** + * Adds multiple allowed host header values. + * @param hosts The set of host header values to allow (case-insensitive, may be + * null) + * @return This builder instance for method chaining + */ + public Builder allowedHosts(Set hosts) { + if (hosts != null) { + hosts.forEach(this::allowedHost); + } + return this; + } + + /** + * Adds an allowed origin header value. + * @param origin The origin header value to allow (case-insensitive, may be null) + * @return This builder instance for method chaining + */ + public Builder allowedOrigin(String origin) { + if (origin != null) { + this.allowedOrigins.add(origin.toLowerCase()); + } + return this; + } + + /** + * Adds multiple allowed origin header values. + * @param origins The set of origin header values to allow (case-insensitive, may + * be null) + * @return This builder instance for method chaining + */ + public Builder allowedOrigins(Set origins) { + if (origins != null) { + origins.forEach(this::allowedOrigin); + } + return this; + } + + /** + * Sets whether DNS rebinding protection is enabled. + * @param enable True to enable protection, false to + * disable it + * @return This builder instance for method chaining + */ + public Builder enableDnsRebindingProtection(boolean enable) { + this.enable = enable; + return this; + } + + /** + * Builds a new DNS rebinding protection configuration with the configured + * settings. + * @return A new DnsRebindingProtection instance + */ + public DnsRebindingProtection build() { + return new DnsRebindingProtection(allowedHosts, allowedOrigins, enable); + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfig.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfig.java deleted file mode 100644 index e03052dbb..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfig.java +++ /dev/null @@ -1,108 +0,0 @@ -package io.modelcontextprotocol.server.transport; - -import java.util.Collections; -import java.util.HashSet; -import java.util.Set; - -/** - * Configuration for DNS rebinding protection in SSE server transports. Provides - * validation for Host and Origin headers to prevent DNS rebinding attacks. - */ -public class DnsRebindingProtectionConfig { - - private final Set allowedHosts; - - private final Set allowedOrigins; - - private final boolean enableDnsRebindingProtection; - - private DnsRebindingProtectionConfig(Builder builder) { - this.allowedHosts = Collections.unmodifiableSet(new HashSet<>(builder.allowedHosts)); - this.allowedOrigins = Collections.unmodifiableSet(new HashSet<>(builder.allowedOrigins)); - this.enableDnsRebindingProtection = builder.enableDnsRebindingProtection; - } - - /** - * Validates Host and Origin headers for DNS rebinding protection. Returns true if the - * headers are valid, false otherwise. - * @param hostHeader The value of the Host header (may be null) - * @param originHeader The value of the Origin header (may be null) - * @return true if the headers are valid, false otherwise - */ - public boolean validate(String hostHeader, String originHeader) { - // Skip validation if protection is not enabled - if (!enableDnsRebindingProtection) { - return true; - } - - // Validate Host header - if (hostHeader != null) { - String lowerHost = hostHeader.toLowerCase(); - if (!allowedHosts.contains(lowerHost)) { - return false; - } - } - - // Validate Origin header - if (originHeader != null) { - String lowerOrigin = originHeader.toLowerCase(); - if (!allowedOrigins.contains(lowerOrigin)) { - return false; - } - } - - return true; - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - - private final Set allowedHosts = new HashSet<>(); - - private final Set allowedOrigins = new HashSet<>(); - - private boolean enableDnsRebindingProtection = true; - - public Builder allowedHost(String host) { - if (host != null) { - this.allowedHosts.add(host.toLowerCase()); - } - return this; - } - - public Builder allowedHosts(Set hosts) { - if (hosts != null) { - hosts.forEach(this::allowedHost); - } - return this; - } - - public Builder allowedOrigin(String origin) { - if (origin != null) { - this.allowedOrigins.add(origin.toLowerCase()); - } - return this; - } - - public Builder allowedOrigins(Set origins) { - if (origins != null) { - origins.forEach(this::allowedOrigin); - } - return this; - } - - public Builder enableDnsRebindingProtection(boolean enable) { - this.enableDnsRebindingProtection = enable; - return this; - } - - public DnsRebindingProtectionConfig build() { - return new DnsRebindingProtectionConfig(this); - } - - } - -} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index 1ce7ee0c5..12fd48c03 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -104,7 +104,7 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement private McpServerSession.Factory sessionFactory; /** DNS rebinding protection configuration */ - private final DnsRebindingProtectionConfig dnsRebindingProtectionConfig; + private final DnsRebindingProtection dnsRebindingProtection; /** * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE @@ -113,7 +113,9 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement * serialization/deserialization * @param messageEndpoint The endpoint path where clients will send their messages * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @deprecated Use {@link #builder()} instead. */ + @Deprecated public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint, null); @@ -127,12 +129,27 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String m * @param baseUrl The base URL for the server transport * @param messageEndpoint The endpoint path where clients will send their messages * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @deprecated Use {@link #builder()} instead. */ + @Deprecated public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); } + /** + * Creates a new HttpServletSseServerTransportProvider instance with the default SSE + * endpoint. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param messageEndpoint The endpoint path where clients will send their messages + * @deprecated Use {@link #builder()} instead. + */ + @Deprecated + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { + this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + } + /** * Creates a new HttpServletSseServerTransportProvider instance with optional DNS * rebinding protection. @@ -141,27 +158,16 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String b * @param baseUrl The base URL for the server transport * @param messageEndpoint The endpoint path where clients will send their messages * @param sseEndpoint The endpoint path where clients will establish SSE connections - * @param dnsRebindingProtectionConfig The DNS rebinding protection configuration (may - * be null) + * @param dnsRebindingProtection The DNS rebinding protection configuration (may be + * null) */ - public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint, DnsRebindingProtectionConfig dnsRebindingProtectionConfig) { + private HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, DnsRebindingProtection dnsRebindingProtection) { this.objectMapper = objectMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; - this.dnsRebindingProtectionConfig = dnsRebindingProtectionConfig; - } - - /** - * Creates a new HttpServletSseServerTransportProvider instance with the default SSE - * endpoint. - * @param objectMapper The JSON object mapper to use for message - * serialization/deserialization - * @param messageEndpoint The endpoint path where clients will send their messages - */ - public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { - this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); + this.dnsRebindingProtection = dnsRebindingProtection; } /** @@ -223,15 +229,8 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } // Validate headers if DNS rebinding protection is configured - if (dnsRebindingProtectionConfig != null) { - String hostHeader = request.getHeader("Host"); - String originHeader = request.getHeader("Origin"); - if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { - logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, - originHeader); - response.sendError(HttpServletResponse.SC_FORBIDDEN, "DNS rebinding protection validation failed"); - return; - } + if (!validateDnsRebindingProtection(request, response)) { + return; } response.setContentType("text/event-stream"); @@ -293,15 +292,8 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) } // Validate headers for POST requests if DNS rebinding protection is configured - if (dnsRebindingProtectionConfig != null) { - String hostHeader = request.getHeader("Host"); - String originHeader = request.getHeader("Origin"); - if (!dnsRebindingProtectionConfig.validate(hostHeader, originHeader)) { - logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, - originHeader); - response.sendError(HttpServletResponse.SC_FORBIDDEN, "DNS rebinding protection validation failed"); - return; - } + if (!validateDnsRebindingProtection(request, response)) { + return; } // Get the session ID from the request parameter @@ -379,6 +371,37 @@ public Mono closeGracefully() { return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then(); } + /** + * Validates DNS rebinding protection for the given request and sends an error + * response if validation fails. + *

+ * Uses {@link HttpServletRequest#getServerName()} and + * {@link HttpServletRequest#getServerPort()} instead of the Host header for HTTP/2 + * compatibility. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @return true if validation passed (or no validation needed), false if validation + * failed and error was sent + * @throws IOException If an I/O error occurs while sending error response + */ + private boolean validateDnsRebindingProtection(HttpServletRequest request, HttpServletResponse response) + throws IOException { + if (dnsRebindingProtection != null) { + String serverName = request.getServerName(); + int serverPort = request.getServerPort(); + String hostValue = serverPort == 80 || serverPort == 443 ? serverName : serverName + ":" + serverPort; + + String originHeader = request.getHeader("Origin"); + if (!dnsRebindingProtection.isValid(hostValue, originHeader)) { + logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostValue, + originHeader); + response.sendError(HttpServletResponse.SC_FORBIDDEN, "DNS rebinding protection validation failed"); + return false; + } + } + return true; + } + /** * Sends an SSE event to a client. * @param writer The writer to send the event through @@ -527,7 +550,7 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; - private DnsRebindingProtectionConfig dnsRebindingProtectionConfig; + private DnsRebindingProtection dnsRebindingProtection; /** * Sets the JSON object mapper to use for message serialization/deserialization. @@ -581,9 +604,9 @@ public Builder sseEndpoint(String sseEndpoint) { * @param config The DNS rebinding protection configuration * @return This builder instance for method chaining */ - public Builder dnsRebindingProtectionConfig(DnsRebindingProtectionConfig config) { + public Builder dnsRebindingProtection(DnsRebindingProtection config) { Assert.notNull(config, "DNS rebinding protection config must not be null"); - this.dnsRebindingProtectionConfig = config; + this.dnsRebindingProtection = config; return this; } @@ -601,7 +624,7 @@ public HttpServletSseServerTransportProvider build() { throw new IllegalStateException("MessageEndpoint must be set"); } return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, - dnsRebindingProtectionConfig); + dnsRebindingProtection); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfigTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfigTests.java deleted file mode 100644 index 388a48cfa..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionConfigTests.java +++ /dev/null @@ -1,157 +0,0 @@ -package io.modelcontextprotocol.server.transport; - -import org.junit.jupiter.api.Test; - -import java.util.Set; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * Unit tests for DNS rebinding protection configuration. - */ -public class DnsRebindingProtectionConfigTests { - - @Test - void testDefaultConfiguration() { - DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder().build(); - - // Test default behavior - when allowed lists are empty and headers are provided, - // validation fails because the headers are not in the (empty) allowed lists - assertThat(config.validate("any.host.com", "http://any.origin.com")).isFalse(); - assertThat(config.validate("localhost", null)).isFalse(); - assertThat(config.validate(null, "http://example.com")).isFalse(); - // Null values are allowed when lists are empty - assertThat(config.validate(null, null)).isTrue(); - } - - @Test - void testDisableDnsRebindingProtection() { - DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() - .enableDnsRebindingProtection(false) - .allowedHost("localhost") // Should be ignored when protection is disabled - .allowedOrigin("http://localhost") // Should be ignored when protection is - // disabled - .build(); - - // When protection is disabled, all hosts and origins should be allowed - assertThat(config.validate("evil.com", "http://evil.com")).isTrue(); - assertThat(config.validate("any.host", "http://any.origin")).isTrue(); - assertThat(config.validate(null, null)).isTrue(); - } - - @Test - void testHostValidation() { - DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() - .allowedHost("localhost") - .allowedHost("127.0.0.1") - .build(); - - // Valid hosts - assertThat(config.validate("localhost", null)).isTrue(); - assertThat(config.validate("127.0.0.1", null)).isTrue(); - - // Invalid hosts - assertThat(config.validate("evil.com", null)).isFalse(); - - // Null host is allowed when no specific hosts are being checked - assertThat(config.validate(null, null)).isTrue(); - } - - @Test - void testOriginValidation() { - DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() - .allowedOrigin("http://localhost:8080") - .allowedOrigin("https://app.example.com") - .build(); - - // Valid origins - assertThat(config.validate(null, "http://localhost:8080")).isTrue(); - assertThat(config.validate(null, "https://app.example.com")).isTrue(); - - // Invalid origins - assertThat(config.validate(null, "http://evil.com")).isFalse(); - - // Null origin is allowed when no specific origins are being checked - assertThat(config.validate(null, null)).isTrue(); - } - - @Test - void testCombinedHostAndOriginValidation() { - DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() - .allowedHost("localhost") - .allowedOrigin("http://localhost:8080") - .build(); - - // Both valid - assertThat(config.validate("localhost", "http://localhost:8080")).isTrue(); - - // Host valid, origin invalid - assertThat(config.validate("localhost", "http://evil.com")).isFalse(); - - // Host invalid, origin valid - assertThat(config.validate("evil.com", "http://localhost:8080")).isFalse(); - - // Both invalid - assertThat(config.validate("evil.com", "http://evil.com")).isFalse(); - } - - @Test - void testCaseInsensitiveHostAndOrigin() { - DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() - .allowedHost("LOCALHOST") - .allowedOrigin("HTTP://LOCALHOST:8080") - .build(); - - // Case insensitive matching - assertThat(config.validate("localhost", null)).isTrue(); - assertThat(config.validate("LOCALHOST", null)).isTrue(); - assertThat(config.validate("LoCaLhOsT", null)).isTrue(); - - assertThat(config.validate(null, "http://localhost:8080")).isTrue(); - assertThat(config.validate(null, "HTTP://LOCALHOST:8080")).isTrue(); - } - - @Test - void testEmptyAllowedListsDenyNonNull() { - DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder().build(); - - // When allowed lists are empty and headers are provided, validation fails - assertThat(config.validate("any.host.com", "http://any.origin.com")).isFalse(); - assertThat(config.validate("random.host", "http://random.origin")).isFalse(); - // But null values are allowed - assertThat(config.validate(null, null)).isTrue(); - } - - @Test - void testBuilderWithSets() { - Set hosts = Set.of("host1.com", "host2.com"); - Set origins = Set.of("http://origin1.com", "http://origin2.com"); - - DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() - .allowedHosts(hosts) - .allowedOrigins(origins) - .build(); - - assertThat(config.validate("host1.com", null)).isTrue(); - assertThat(config.validate("host2.com", null)).isTrue(); - assertThat(config.validate("host3.com", null)).isFalse(); - - assertThat(config.validate(null, "http://origin1.com")).isTrue(); - assertThat(config.validate(null, "http://origin2.com")).isTrue(); - assertThat(config.validate(null, "http://origin3.com")).isFalse(); - } - - @Test - void testNullValuesWithConfiguredLists() { - DnsRebindingProtectionConfig config = DnsRebindingProtectionConfig.builder() - .allowedHost("localhost") - .allowedOrigin("http://localhost") - .build(); - - // Null values should be allowed when no check is needed for that header - assertThat(config.validate(null, "http://localhost")).isTrue(); - assertThat(config.validate("localhost", null)).isTrue(); - assertThat(config.validate(null, null)).isTrue(); - } - -} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionTests.java new file mode 100644 index 000000000..87ae4b404 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/DnsRebindingProtectionTests.java @@ -0,0 +1,157 @@ +package io.modelcontextprotocol.server.transport; + +import org.junit.jupiter.api.Test; + +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit tests for DNS rebinding protection configuration. + */ +public class DnsRebindingProtectionTests { + + @Test + void testDefaultConfiguration() { + DnsRebindingProtection config = DnsRebindingProtection.builder().build(); + + // Test default behavior - when allowed lists are empty and headers are provided, + // validation fails because the headers are not in the (empty) allowed lists + assertThat(config.isValid("any.host.com", "http://any.origin.com")).isFalse(); + assertThat(config.isValid("localhost", null)).isFalse(); + assertThat(config.isValid(null, "http://example.com")).isFalse(); + // Null values are allowed when lists are empty + assertThat(config.isValid(null, null)).isTrue(); + } + + @Test + void testDisableDnsRebindingProtection() { + DnsRebindingProtection config = DnsRebindingProtection.builder() + .enableDnsRebindingProtection(false) + .allowedHost("localhost") // Should be ignored when protection is disabled + .allowedOrigin("http://localhost") // Should be ignored when protection is + // disabled + .build(); + + // When protection is disabled, all hosts and origins should be allowed + assertThat(config.isValid("evil.com", "http://evil.com")).isTrue(); + assertThat(config.isValid("any.host", "http://any.origin")).isTrue(); + assertThat(config.isValid(null, null)).isTrue(); + } + + @Test + void testHostValidation() { + DnsRebindingProtection config = DnsRebindingProtection.builder() + .allowedHost("localhost") + .allowedHost("127.0.0.1") + .build(); + + // Valid hosts + assertThat(config.isValid("localhost", null)).isTrue(); + assertThat(config.isValid("127.0.0.1", null)).isTrue(); + + // Invalid hosts + assertThat(config.isValid("evil.com", null)).isFalse(); + + // Null host is allowed when no specific hosts are being checked + assertThat(config.isValid(null, null)).isTrue(); + } + + @Test + void testOriginValidation() { + DnsRebindingProtection config = DnsRebindingProtection.builder() + .allowedOrigin("http://localhost:8080") + .allowedOrigin("https://app.example.com") + .build(); + + // Valid origins + assertThat(config.isValid(null, "http://localhost:8080")).isTrue(); + assertThat(config.isValid(null, "https://app.example.com")).isTrue(); + + // Invalid origins + assertThat(config.isValid(null, "http://evil.com")).isFalse(); + + // Null origin is allowed when no specific origins are being checked + assertThat(config.isValid(null, null)).isTrue(); + } + + @Test + void testCombinedHostAndOriginValidation() { + DnsRebindingProtection config = DnsRebindingProtection.builder() + .allowedHost("localhost") + .allowedOrigin("http://localhost:8080") + .build(); + + // Both valid + assertThat(config.isValid("localhost", "http://localhost:8080")).isTrue(); + + // Host valid, origin invalid + assertThat(config.isValid("localhost", "http://evil.com")).isFalse(); + + // Host invalid, origin valid + assertThat(config.isValid("evil.com", "http://localhost:8080")).isFalse(); + + // Both invalid + assertThat(config.isValid("evil.com", "http://evil.com")).isFalse(); + } + + @Test + void testCaseInsensitiveHostAndOrigin() { + DnsRebindingProtection config = DnsRebindingProtection.builder() + .allowedHost("LOCALHOST") + .allowedOrigin("HTTP://LOCALHOST:8080") + .build(); + + // Case insensitive matching + assertThat(config.isValid("localhost", null)).isTrue(); + assertThat(config.isValid("LOCALHOST", null)).isTrue(); + assertThat(config.isValid("LoCaLhOsT", null)).isTrue(); + + assertThat(config.isValid(null, "http://localhost:8080")).isTrue(); + assertThat(config.isValid(null, "HTTP://LOCALHOST:8080")).isTrue(); + } + + @Test + void testEmptyAllowedListsDenyNonNull() { + DnsRebindingProtection config = DnsRebindingProtection.builder().build(); + + // When allowed lists are empty and headers are provided, validation fails + assertThat(config.isValid("any.host.com", "http://any.origin.com")).isFalse(); + assertThat(config.isValid("random.host", "http://random.origin")).isFalse(); + // But null values are allowed + assertThat(config.isValid(null, null)).isTrue(); + } + + @Test + void testBuilderWithSets() { + Set hosts = Set.of("host1.com", "host2.com"); + Set origins = Set.of("http://origin1.com", "http://origin2.com"); + + DnsRebindingProtection config = DnsRebindingProtection.builder() + .allowedHosts(hosts) + .allowedOrigins(origins) + .build(); + + assertThat(config.isValid("host1.com", null)).isTrue(); + assertThat(config.isValid("host2.com", null)).isTrue(); + assertThat(config.isValid("host3.com", null)).isFalse(); + + assertThat(config.isValid(null, "http://origin1.com")).isTrue(); + assertThat(config.isValid(null, "http://origin2.com")).isTrue(); + assertThat(config.isValid(null, "http://origin3.com")).isFalse(); + } + + @Test + void testNullValuesWithConfiguredLists() { + DnsRebindingProtection config = DnsRebindingProtection.builder() + .allowedHost("localhost") + .allowedOrigin("http://localhost") + .build(); + + // Null values should be allowed when no check is needed for that header + assertThat(config.isValid(null, "http://localhost")).isTrue(); + assertThat(config.isValid("localhost", null)).isTrue(); + assertThat(config.isValid(null, null)).isTrue(); + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseHeaderValidationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseHeaderValidationTests.java index 68ec9ec0d..c10d83b0f 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseHeaderValidationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseHeaderValidationTests.java @@ -63,18 +63,18 @@ void tearDown() { @Test void testConnectionSucceedsWithValidHeaders() { // Create DNS rebinding protection config that validates API key - DnsRebindingProtectionConfig dnsRebindingProtectionConfig = DnsRebindingProtectionConfig.builder() + DnsRebindingProtection dnsRebindingProtection = DnsRebindingProtection.builder() .enableDnsRebindingProtection(false) // Disable Host/Origin validation for // this test .build(); // For this test, we'll need to use a custom transport provider implementation - // since DnsRebindingProtectionConfig doesn't support custom header validation + // since DnsRebindingProtection doesn't support custom header validation transportProvider = HttpServletSseServerTransportProvider.builder() .objectMapper(new ObjectMapper()) .messageEndpoint(MESSAGE_ENDPOINT) .sseEndpoint(SSE_ENDPOINT) - .dnsRebindingProtectionConfig(dnsRebindingProtectionConfig) + .dnsRebindingProtection(dnsRebindingProtection) .build(); startServer(); @@ -94,7 +94,7 @@ void testConnectionSucceedsWithValidHeaders() { @Test void testConnectionFailsWithInvalidHeaders() { // Create DNS rebinding protection config with restricted hosts - DnsRebindingProtectionConfig dnsRebindingProtectionConfig = DnsRebindingProtectionConfig.builder() + DnsRebindingProtection dnsRebindingProtection = DnsRebindingProtection.builder() .allowedHost("valid-host.com") .build(); @@ -103,7 +103,7 @@ void testConnectionFailsWithInvalidHeaders() { .objectMapper(new ObjectMapper()) .messageEndpoint(MESSAGE_ENDPOINT) .sseEndpoint(SSE_ENDPOINT) - .dnsRebindingProtectionConfig(dnsRebindingProtectionConfig) + .dnsRebindingProtection(dnsRebindingProtection) .build(); startServer(); @@ -127,7 +127,7 @@ void testConnectionFailsWithEmptyAllowedHostsButProvidedHost() { // Create DNS rebinding protection config with specific allowed origin but no // allowed hosts // This means any non-null host will be rejected - DnsRebindingProtectionConfig dnsRebindingProtectionConfig = DnsRebindingProtectionConfig.builder() + DnsRebindingProtection dnsRebindingProtection = DnsRebindingProtection.builder() .allowedOrigin("http://allowed-origin.com") .build(); @@ -136,7 +136,7 @@ void testConnectionFailsWithEmptyAllowedHostsButProvidedHost() { .objectMapper(new ObjectMapper()) .messageEndpoint(MESSAGE_ENDPOINT) .sseEndpoint(SSE_ENDPOINT) - .dnsRebindingProtectionConfig(dnsRebindingProtectionConfig) + .dnsRebindingProtection(dnsRebindingProtection) .build(); startServer(); @@ -161,7 +161,7 @@ void testComplexHeaderValidation() { // Create DNS rebinding protection config with specific allowed hosts and origins // Note: The Host header will include the port, so we need to allow // "localhost:PORT" - DnsRebindingProtectionConfig dnsRebindingProtectionConfig = DnsRebindingProtectionConfig.builder() + DnsRebindingProtection dnsRebindingProtection = DnsRebindingProtection.builder() .allowedHost("localhost:" + PORT) .allowedOrigin("http://localhost:" + PORT) .build(); @@ -171,7 +171,7 @@ void testComplexHeaderValidation() { .objectMapper(new ObjectMapper()) .messageEndpoint(MESSAGE_ENDPOINT) .sseEndpoint(SSE_ENDPOINT) - .dnsRebindingProtectionConfig(dnsRebindingProtectionConfig) + .dnsRebindingProtection(dnsRebindingProtection) .build(); startServer(); From 145b592ec1196f97bb81a59c99dd9619dcac06a1 Mon Sep 17 00:00:00 2001 From: David Dworken Date: Thu, 26 Jun 2025 17:07:34 -0700 Subject: [PATCH 3/3] Add javadocs --- .../server/transport/DnsRebindingProtection.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtection.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtection.java index d1d2530db..2d1fd22d3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtection.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/DnsRebindingProtection.java @@ -167,7 +167,8 @@ public static Builder builder() { *

  • Configure allowlists using {@link #allowedHost(String)} and * {@link #allowedOrigin(String)}
  • *
  • Optionally enable/disable protection via - * {@link #enableDnsRebindingProtection(boolean)}
  • + * {@link #enableDnsRebindingProtection(boolean)}. Protection is default-enabled when + * creating a builder. *
  • Build final instance via {@link #build()}
  • * *