-
Notifications
You must be signed in to change notification settings - Fork 653
Add support for DNS rebinding protections #284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ddworken
wants to merge
4
commits into
modelcontextprotocol:main
Choose a base branch
from
ddworken:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -106,6 +106,11 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi | |
*/ | ||
private volatile boolean isClosing = false; | ||
|
||
/** | ||
* DNS rebinding protection configuration. | ||
*/ | ||
private final DnsRebindingProtection dnsRebindingProtection; | ||
|
||
/** | ||
* Constructs a new WebMvcSseServerTransportProvider instance with the default SSE | ||
* endpoint. | ||
|
@@ -114,8 +119,10 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi | |
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC | ||
* messages via HTTP POST. This endpoint will be communicated to clients through the | ||
* SSE connection's initial endpoint event. | ||
* @deprecated Use {@link #builder()} instead. | ||
* @throws IllegalArgumentException if either objectMapper or messageEndpoint is null | ||
*/ | ||
@Deprecated | ||
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { | ||
this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); | ||
} | ||
|
@@ -128,10 +135,12 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag | |
* messages via HTTP POST. This endpoint will be communicated to clients through the | ||
* SSE connection's initial endpoint event. | ||
* @param sseEndpoint The endpoint URI where clients establish their SSE connections. | ||
* @deprecated Use {@link #builder()} instead. | ||
* @throws IllegalArgumentException if any parameter is null | ||
*/ | ||
@Deprecated | ||
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { | ||
this(objectMapper, "", messageEndpoint, sseEndpoint); | ||
this(objectMapper, "", messageEndpoint, sseEndpoint, null); | ||
} | ||
|
||
/** | ||
|
@@ -144,10 +153,32 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag | |
* messages via HTTP POST. This endpoint will be communicated to clients through the | ||
* SSE connection's initial endpoint event. | ||
* @param sseEndpoint The endpoint URI where clients establish their SSE connections. | ||
* @deprecated Use {@link #builder()} instead. | ||
* @throws IllegalArgumentException if any parameter is null | ||
*/ | ||
@Deprecated | ||
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, | ||
String sseEndpoint) { | ||
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); | ||
} | ||
|
||
/** | ||
* Constructs a new WebMvcSseServerTransportProvider instance with DNS rebinding | ||
* protection. | ||
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization | ||
* of messages. | ||
* @param baseUrl The base URL for the message endpoint, used to construct the full | ||
* endpoint URL for clients. | ||
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC | ||
* messages via HTTP POST. This endpoint will be communicated to clients through the | ||
* SSE connection's initial endpoint event. | ||
* @param sseEndpoint The endpoint URI where clients establish their SSE connections. | ||
* @param dnsRebindingProtection The DNS rebinding protection configuration (may be | ||
* null). | ||
* @throws IllegalArgumentException if any required parameter is null | ||
*/ | ||
private WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, | ||
String sseEndpoint, DnsRebindingProtection dnsRebindingProtection) { | ||
Assert.notNull(objectMapper, "ObjectMapper must not be null"); | ||
Assert.notNull(baseUrl, "Message base URL must not be null"); | ||
Assert.notNull(messageEndpoint, "Message endpoint must not be null"); | ||
|
@@ -157,6 +188,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr | |
this.baseUrl = baseUrl; | ||
this.messageEndpoint = messageEndpoint; | ||
this.sseEndpoint = sseEndpoint; | ||
this.dnsRebindingProtection = dnsRebindingProtection; | ||
this.routerFunction = RouterFunctions.route() | ||
.GET(this.sseEndpoint, this::handleSseConnection) | ||
.POST(this.messageEndpoint, this::handleMessage) | ||
|
@@ -246,6 +278,12 @@ private ServerResponse handleSseConnection(ServerRequest request) { | |
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); | ||
} | ||
|
||
// Validate headers | ||
ServerResponse validationError = validateDnsRebindingProtection(request); | ||
if (validationError != null) { | ||
return validationError; | ||
} | ||
|
||
String sessionId = UUID.randomUUID().toString(); | ||
logger.debug("Creating new SSE connection for session: {}", sessionId); | ||
|
||
|
@@ -299,6 +337,19 @@ private ServerResponse handleMessage(ServerRequest request) { | |
return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); | ||
} | ||
|
||
// Always validate Content-Type for POST requests | ||
String contentType = request.headers().asHttpHeaders().getFirst("Content-Type"); | ||
if (contentType == null || !contentType.toLowerCase().startsWith("application/json")) { | ||
logger.warn("Invalid Content-Type header: '{}'", contentType); | ||
return ServerResponse.badRequest().body(new McpError("Content-Type must be application/json")); | ||
} | ||
|
||
// Validate headers for POST requests if DNS rebinding protection is configured | ||
ServerResponse validationError = validateDnsRebindingProtection(request); | ||
if (validationError != null) { | ||
return validationError; | ||
} | ||
|
||
if (request.param("sessionId").isEmpty()) { | ||
return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint")); | ||
} | ||
|
@@ -416,4 +467,23 @@ public void close() { | |
|
||
} | ||
|
||
/** | ||
* Validates DNS rebinding protection for the given request. | ||
* @param request The incoming server request | ||
* @return A ServerResponse with forbidden status if validation fails, or null if | ||
* validation passes | ||
*/ | ||
private ServerResponse validateDnsRebindingProtection(ServerRequest request) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe a similar strategy as in the |
||
if (dnsRebindingProtection != null) { | ||
String hostHeader = request.headers().asHttpHeaders().getFirst("Host"); | ||
String originHeader = request.headers().asHttpHeaders().getFirst("Origin"); | ||
if (!dnsRebindingProtection.isValid(hostHeader, originHeader)) { | ||
logger.warn("DNS rebinding protection validation failed - Host: '{}', Origin: '{}'", hostHeader, | ||
originHeader); | ||
return ServerResponse.status(HttpStatus.FORBIDDEN).body("DNS rebinding protection validation failed"); | ||
} | ||
} | ||
return null; | ||
} | ||
|
||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not a valid style to return a
null
Mono
reference. If there's a synchronous action that does not block, simply return aboolean
to drive the behaviour.boolean passesDnsRebindingProtectionCheck(ServerRequest request)
Mono<ServerResponse failDnsRebindingProtection()
I think the style you're using in this PR is inspired by errors from Go. The style present in the codebase of this project should be consistent.