Skip to content

Commit 64493be

Browse files
authored
mcp: reject requests with bad IDs (#202)
Add an isRequest field to methodInfo, used it to reject non-notification requests that lack a valid ID. Furthermore, lift this validation to the transport layer for HTTP server transports, so that we can preemptively reject bad HTTP requests. Fixes #194 Fixes #197
1 parent 3ac4ca9 commit 64493be

File tree

8 files changed

+178
-39
lines changed

8 files changed

+178
-39
lines changed

mcp/client.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -287,16 +287,16 @@ func (c *Client) AddReceivingMiddleware(middleware ...Middleware[*ClientSession]
287287

288288
// clientMethodInfos maps from the RPC method name to serverMethodInfos.
289289
var clientMethodInfos = map[string]methodInfo{
290-
methodComplete: newMethodInfo(sessionMethod((*ClientSession).Complete)),
291-
methodPing: newMethodInfo(sessionMethod((*ClientSession).ping)),
292-
methodListRoots: newMethodInfo(clientMethod((*Client).listRoots)),
293-
methodCreateMessage: newMethodInfo(clientMethod((*Client).createMessage)),
294-
notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler)),
295-
notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler)),
296-
notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler)),
297-
notificationResourceUpdated: newMethodInfo(clientMethod((*Client).callResourceUpdatedHandler)),
298-
notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler)),
299-
notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler)),
290+
methodComplete: newMethodInfo(sessionMethod((*ClientSession).Complete), true),
291+
methodPing: newMethodInfo(sessionMethod((*ClientSession).ping), true),
292+
methodListRoots: newMethodInfo(clientMethod((*Client).listRoots), true),
293+
methodCreateMessage: newMethodInfo(clientMethod((*Client).createMessage), true),
294+
notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler), false),
295+
notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler), false),
296+
notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler), false),
297+
notificationResourceUpdated: newMethodInfo(clientMethod((*Client).callResourceUpdatedHandler), false),
298+
notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler), false),
299+
notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler), false),
300300
}
301301

302302
func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo {
@@ -323,7 +323,7 @@ func (cs *ClientSession) receivingMethodHandler() methodHandler {
323323
return cs.client.receivingMethodHandler_
324324
}
325325

326-
// getConn implements [session.getConn].
326+
// getConn implements [Session.getConn].
327327
func (cs *ClientSession) getConn() *jsonrpc2.Connection { return cs.conn }
328328

329329
func (*ClientSession) ping(context.Context, *PingParams) (*emptyResult, error) {

mcp/server.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -688,22 +688,22 @@ func (s *Server) AddReceivingMiddleware(middleware ...Middleware[*ServerSession]
688688

689689
// serverMethodInfos maps from the RPC method name to serverMethodInfos.
690690
var serverMethodInfos = map[string]methodInfo{
691-
methodComplete: newMethodInfo(serverMethod((*Server).complete)),
692-
methodInitialize: newMethodInfo(sessionMethod((*ServerSession).initialize)),
693-
methodPing: newMethodInfo(sessionMethod((*ServerSession).ping)),
694-
methodListPrompts: newMethodInfo(serverMethod((*Server).listPrompts)),
695-
methodGetPrompt: newMethodInfo(serverMethod((*Server).getPrompt)),
696-
methodListTools: newMethodInfo(serverMethod((*Server).listTools)),
697-
methodCallTool: newMethodInfo(serverMethod((*Server).callTool)),
698-
methodListResources: newMethodInfo(serverMethod((*Server).listResources)),
699-
methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates)),
700-
methodReadResource: newMethodInfo(serverMethod((*Server).readResource)),
701-
methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel)),
702-
methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe)),
703-
methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe)),
704-
notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler)),
705-
notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler)),
706-
notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler)),
691+
methodComplete: newMethodInfo(serverMethod((*Server).complete), true),
692+
methodInitialize: newMethodInfo(sessionMethod((*ServerSession).initialize), true),
693+
methodPing: newMethodInfo(sessionMethod((*ServerSession).ping), true),
694+
methodListPrompts: newMethodInfo(serverMethod((*Server).listPrompts), true),
695+
methodGetPrompt: newMethodInfo(serverMethod((*Server).getPrompt), true),
696+
methodListTools: newMethodInfo(serverMethod((*Server).listTools), true),
697+
methodCallTool: newMethodInfo(serverMethod((*Server).callTool), true),
698+
methodListResources: newMethodInfo(serverMethod((*Server).listResources), true),
699+
methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates), true),
700+
methodReadResource: newMethodInfo(serverMethod((*Server).readResource), true),
701+
methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel), true),
702+
methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe), true),
703+
methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe), true),
704+
notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler), false),
705+
notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler), false),
706+
notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler), false),
707707
}
708708

709709
func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos }

mcp/shared.go

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@ func defaultReceivingMethodHandler[S Session](ctx context.Context, session S, me
123123
}
124124

125125
func handleReceive[S Session](ctx context.Context, session S, req *jsonrpc.Request) (Result, error) {
126-
info, ok := session.receivingMethodInfos()[req.Method]
127-
if !ok {
128-
return nil, jsonrpc2.ErrNotHandled
126+
info, err := checkRequest(req, session.receivingMethodInfos())
127+
if err != nil {
128+
return nil, err
129129
}
130130
params, err := info.unmarshalParams(req.Params)
131131
if err != nil {
@@ -141,8 +141,30 @@ func handleReceive[S Session](ctx context.Context, session S, req *jsonrpc.Reque
141141
return res, nil
142142
}
143143

144+
// checkRequest checks the given request against the provided method info, to
145+
// ensure it is a valid MCP request.
146+
//
147+
// If valid, the relevant method info is returned. Otherwise, a non-nil error
148+
// is returned describing why the request is invalid.
149+
//
150+
// This is extracted from request handling so that it can be called in the
151+
// transport layer to preemptively reject bad requests.
152+
func checkRequest(req *jsonrpc.Request, infos map[string]methodInfo) (methodInfo, error) {
153+
info, ok := infos[req.Method]
154+
if !ok {
155+
return methodInfo{}, fmt.Errorf("%w: %q unsupported", jsonrpc2.ErrNotHandled, req.Method)
156+
}
157+
if info.isRequest && !req.ID.IsValid() {
158+
return methodInfo{}, fmt.Errorf("%w: %q missing ID", jsonrpc2.ErrInvalidRequest, req.Method)
159+
}
160+
return info, nil
161+
}
162+
144163
// methodInfo is information about sending and receiving a method.
145164
type methodInfo struct {
165+
// isRequest reports whether the method is a JSON-RPC request.
166+
// Otherwise, the method is treated as a notification.
167+
isRequest bool
146168
// Unmarshal params from the wire into a Params struct.
147169
// Used on the receive side.
148170
unmarshalParams func(json.RawMessage) (Params, error)
@@ -169,8 +191,12 @@ type paramsPtr[T any] interface {
169191
}
170192

171193
// newMethodInfo creates a methodInfo from a typedMethodHandler.
172-
func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHandler[S, P, R]) methodInfo {
194+
//
195+
// If isRequest is set, the method is treated as a request rather than a
196+
// notification.
197+
func newMethodInfo[S Session, P paramsPtr[T], R Result, T any](d typedMethodHandler[S, P, R], isRequest bool) methodInfo {
173198
return methodInfo{
199+
isRequest: isRequest,
174200
unmarshalParams: func(m json.RawMessage) (Params, error) {
175201
var p P
176202
if m != nil {

mcp/sse.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,12 @@ func (t *SSEServerTransport) ServeHTTP(w http.ResponseWriter, req *http.Request)
129129
http.Error(w, "failed to parse body", http.StatusBadRequest)
130130
return
131131
}
132+
if req, ok := msg.(*jsonrpc.Request); ok {
133+
if _, err := checkRequest(req, serverMethodInfos); err != nil {
134+
http.Error(w, err.Error(), http.StatusBadRequest)
135+
return
136+
}
137+
}
132138
select {
133139
case t.incoming <- msg:
134140
w.WriteHeader(http.StatusAccepted)

mcp/sse_test.go

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
package mcp
66

77
import (
8+
"bytes"
89
"context"
910
"fmt"
11+
"io"
1012
"net/http"
1113
"net/http/httptest"
1214
"sync/atomic"
@@ -24,10 +26,10 @@ func TestSSEServer(t *testing.T) {
2426

2527
sseHandler := NewSSEHandler(func(*http.Request) *Server { return server })
2628

27-
conns := make(chan *ServerSession, 1)
28-
sseHandler.onConnection = func(cc *ServerSession) {
29+
serverSessions := make(chan *ServerSession, 1)
30+
sseHandler.onConnection = func(ss *ServerSession) {
2931
select {
30-
case conns <- cc:
32+
case serverSessions <- ss:
3133
default:
3234
}
3335
}
@@ -54,7 +56,7 @@ func TestSSEServer(t *testing.T) {
5456
if err := cs.Ping(ctx, nil); err != nil {
5557
t.Fatal(err)
5658
}
57-
ss := <-conns
59+
ss := <-serverSessions
5860
gotHi, err := cs.CallTool(ctx, &CallToolParams{
5961
Name: "greet",
6062
Arguments: map[string]any{"Name": "user"},
@@ -76,6 +78,39 @@ func TestSSEServer(t *testing.T) {
7678
t.Error("Expected custom HTTP client to be used, but it wasn't")
7779
}
7880

81+
t.Run("badrequests", func(t *testing.T) {
82+
msgEndpoint := cs.mcpConn.(*sseClientConn).msgEndpoint.String()
83+
84+
// Test some invalid data, and verify that we get 400s.
85+
badRequests := []struct {
86+
name string
87+
body string
88+
responseContains string
89+
}{
90+
{"not a method", `{"jsonrpc":"2.0", "method":"notamethod"}`, "not handled"},
91+
{"missing ID", `{"jsonrpc":"2.0", "method":"ping"}`, "missing ID"},
92+
}
93+
for _, r := range badRequests {
94+
t.Run(r.name, func(t *testing.T) {
95+
resp, err := http.Post(msgEndpoint, "application/json", bytes.NewReader([]byte(r.body)))
96+
if err != nil {
97+
t.Fatal(err)
98+
}
99+
defer resp.Body.Close()
100+
if got, want := resp.StatusCode, http.StatusBadRequest; got != want {
101+
t.Errorf("Sending bad request %q: got status %d, want %d", r.body, got, want)
102+
}
103+
result, err := io.ReadAll(resp.Body)
104+
if err != nil {
105+
t.Fatalf("Reading response: %v", err)
106+
}
107+
if !bytes.Contains(result, []byte(r.responseContains)) {
108+
t.Errorf("Response body does not contain %q:\n%s", r.responseContains, string(result))
109+
}
110+
})
111+
}
112+
})
113+
79114
// Test that closing either end of the connection terminates the other
80115
// end.
81116
if closeServerFirst {

mcp/streamable.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,8 +387,16 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
387387
}
388388
requests := make(map[jsonrpc.ID]struct{})
389389
for _, msg := range incoming {
390-
if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() {
391-
requests[req.ID] = struct{}{}
390+
if req, ok := msg.(*jsonrpc.Request); ok {
391+
// Preemptively check that this is a valid request, so that we can fail
392+
// the HTTP request. If we didn't do this, a request with a bad method or
393+
// missing ID could be silently swallowed.
394+
if _, err := checkRequest(req, serverMethodInfos); err != nil {
395+
return http.StatusBadRequest, err.Error()
396+
}
397+
if req.ID.IsValid() {
398+
requests[req.ID] = struct{}{}
399+
}
392400
}
393401
}
394402

mcp/streamable_test.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ func TestStreamableServerTransport(t *testing.T) {
280280
}
281281

282282
// Predefined steps, to avoid repetition below.
283-
initReq := req(1, "initialize", &InitializeParams{})
283+
initReq := req(1, methodInitialize, &InitializeParams{})
284284
initResp := resp(1, &InitializeResult{
285285
Capabilities: &serverCapabilities{
286286
Completions: &completionCapabilities{},
@@ -290,7 +290,7 @@ func TestStreamableServerTransport(t *testing.T) {
290290
ProtocolVersion: latestProtocolVersion,
291291
ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"},
292292
}, nil)
293-
initializedMsg := req(0, "initialized", &InitializedParams{})
293+
initializedMsg := req(0, notificationInitialized, &InitializedParams{})
294294
initialize := step{
295295
Method: "POST",
296296
Send: []jsonrpc.Message{initReq},
@@ -438,6 +438,16 @@ func TestStreamableServerTransport(t *testing.T) {
438438
Method: "DELETE",
439439
StatusCode: http.StatusBadRequest,
440440
},
441+
{
442+
Method: "POST",
443+
Send: []jsonrpc.Message{req(1, "notamethod", nil)},
444+
StatusCode: http.StatusBadRequest, // notamethod is an invalid method
445+
},
446+
{
447+
Method: "POST",
448+
Send: []jsonrpc.Message{req(0, "tools/call", &CallToolParams{Name: "tool"})},
449+
StatusCode: http.StatusBadRequest, // tools/call must have an ID
450+
},
441451
{
442452
Method: "POST",
443453
Send: []jsonrpc.Message{req(2, "tools/call", &CallToolParams{Name: "tool"})},
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
Check robustness to missing fields: servers should reject and otherwise ignore
2+
bad requests.
3+
4+
Fixed bugs:
5+
- No id in 'initialize' should not panic (#197).
6+
- No id in 'ping' should not panic (#194).
7+
8+
TODO:
9+
- No params in 'initialize' should not panic (#195).
10+
11+
-- prompts --
12+
code_review
13+
14+
-- client --
15+
{
16+
"jsonrpc": "2.0",
17+
"method": "initialize",
18+
"params": {
19+
"protocolVersion": "2024-11-05",
20+
"capabilities": {},
21+
"clientInfo": { "name": "ExampleClient", "version": "1.0.0" }
22+
}
23+
}
24+
{
25+
"jsonrpc": "2.0",
26+
"id": 2,
27+
"method": "initialize",
28+
"params": {
29+
"protocolVersion": "2024-11-05",
30+
"capabilities": {},
31+
"clientInfo": { "name": "ExampleClient", "version": "1.0.0" }
32+
}
33+
}
34+
{"jsonrpc":"2.0", "method":"ping"}
35+
36+
-- server --
37+
{
38+
"jsonrpc": "2.0",
39+
"id": 2,
40+
"result": {
41+
"capabilities": {
42+
"completions": {},
43+
"logging": {},
44+
"prompts": {
45+
"listChanged": true
46+
}
47+
},
48+
"protocolVersion": "2024-11-05",
49+
"serverInfo": {
50+
"name": "testServer",
51+
"version": "v1.0.0"
52+
}
53+
}
54+
}

0 commit comments

Comments
 (0)