Skip to content

Commit 3964d51

Browse files
ezynda3opencode
andauthored
feat: implement protocol version negotiation (#502)
* feat: implement protocol version negotiation Implement protocol version negotiation following the TypeScript SDK approach: - Update LATEST_PROTOCOL_VERSION to 2025-06-18 - Add client-side validation of server protocol version - Return UnsupportedProtocolVersionError for incompatible versions - Add Mcp-Protocol-Version header support for HTTP transports - Implement SetProtocolVersion method on HTTP connections - Add comprehensive tests for protocol negotiation This ensures both client and server agree on a mutually supported protocol version, preventing compatibility issues. 🤖 Generated with [opencode](https://opencode.ai) Co-Authored-By: opencode <[email protected]> * fmt * refactor: improve protocol negotiation implementation - Move HTTP header constants to common location to avoid duplication - Add errors.Is interface to UnsupportedProtocolVersionError for better Go error handling - Add comprehensive edge case tests for empty and malformed protocol versions - Ensure consistent header constant usage across client and server packages 🤖 Generated with [opencode](https://opencode.ai) Co-Authored-By: opencode <[email protected]> * fmt * fix: include protocol version header in DELETE request for session termination As per MCP specification, the MCP-Protocol-Version header must be included on all subsequent requests to the MCP server, including DELETE requests for terminating sessions. 🤖 Generated with [opencode](https://opencode.ai) Co-Authored-By: opencode <[email protected]> * fix: maintain backward compatibility for protocol version negotiation When the server does not receive an MCP-Protocol-Version header, it should assume protocol version 2025-03-26 for backward compatibility as per the MCP specification. 🤖 Generated with [opencode](https://opencode.ai) Co-Authored-By: opencode <[email protected]> * test: update tests to reflect backward compatibility behavior Tests now expect protocol version 2025-03-26 when no protocol version is provided during initialization, as per MCP specification for backward compatibility. 🤖 Generated with [opencode](https://opencode.ai) Co-Authored-By: opencode <[email protected]> --------- Co-authored-by: opencode <[email protected]>
1 parent 38ac77c commit 3964d51

19 files changed

+433
-108
lines changed

client/client.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8+
"slices"
89
"sync"
910
"sync/atomic"
1011

@@ -22,6 +23,7 @@ type Client struct {
2223
requestID atomic.Int64
2324
clientCapabilities mcp.ClientCapabilities
2425
serverCapabilities mcp.ServerCapabilities
26+
protocolVersion string
2527
samplingHandler SamplingHandler
2628
}
2729

@@ -176,8 +178,19 @@ func (c *Client) Initialize(
176178
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
177179
}
178180

179-
// Store serverCapabilities
181+
// Validate protocol version
182+
if !slices.Contains(mcp.ValidProtocolVersions, result.ProtocolVersion) {
183+
return nil, mcp.UnsupportedProtocolVersionError{Version: result.ProtocolVersion}
184+
}
185+
186+
// Store serverCapabilities and protocol version
180187
c.serverCapabilities = result.Capabilities
188+
c.protocolVersion = result.ProtocolVersion
189+
190+
// Set protocol version on HTTP transports
191+
if httpConn, ok := c.transport.(transport.HTTPConnection); ok {
192+
httpConn.SetProtocolVersion(result.ProtocolVersion)
193+
}
181194

182195
// Send initialized notification
183196
notification := mcp.JSONRPCNotification{
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
package client
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"strings"
8+
"testing"
9+
10+
"github.com/mark3labs/mcp-go/client/transport"
11+
"github.com/mark3labs/mcp-go/mcp"
12+
)
13+
14+
// mockProtocolTransport implements transport.Interface for testing protocol negotiation
15+
type mockProtocolTransport struct {
16+
responses map[string]string
17+
notificationHandler func(mcp.JSONRPCNotification)
18+
started bool
19+
closed bool
20+
}
21+
22+
func (m *mockProtocolTransport) Start(ctx context.Context) error {
23+
m.started = true
24+
return nil
25+
}
26+
27+
func (m *mockProtocolTransport) SendRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
28+
responseStr, ok := m.responses[request.Method]
29+
if !ok {
30+
return nil, fmt.Errorf("no mock response for method %s", request.Method)
31+
}
32+
33+
return &transport.JSONRPCResponse{
34+
JSONRPC: "2.0",
35+
ID: request.ID,
36+
Result: json.RawMessage(responseStr),
37+
}, nil
38+
}
39+
40+
func (m *mockProtocolTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
41+
return nil
42+
}
43+
44+
func (m *mockProtocolTransport) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) {
45+
m.notificationHandler = handler
46+
}
47+
48+
func (m *mockProtocolTransport) Close() error {
49+
m.closed = true
50+
return nil
51+
}
52+
53+
func (m *mockProtocolTransport) GetSessionId() string {
54+
return "mock-session"
55+
}
56+
57+
func TestProtocolVersionNegotiation(t *testing.T) {
58+
tests := []struct {
59+
name string
60+
serverVersion string
61+
expectError bool
62+
errorContains string
63+
}{
64+
{
65+
name: "supported latest version",
66+
serverVersion: mcp.LATEST_PROTOCOL_VERSION,
67+
expectError: false,
68+
},
69+
{
70+
name: "supported older version 2025-03-26",
71+
serverVersion: "2025-03-26",
72+
expectError: false,
73+
},
74+
{
75+
name: "supported older version 2024-11-05",
76+
serverVersion: "2024-11-05",
77+
expectError: false,
78+
},
79+
{
80+
name: "unsupported version",
81+
serverVersion: "2023-01-01",
82+
expectError: true,
83+
errorContains: "unsupported protocol version",
84+
},
85+
{
86+
name: "unsupported future version",
87+
serverVersion: "2030-01-01",
88+
expectError: true,
89+
errorContains: "unsupported protocol version",
90+
},
91+
{
92+
name: "empty protocol version",
93+
serverVersion: "",
94+
expectError: true,
95+
errorContains: "unsupported protocol version",
96+
},
97+
{
98+
name: "malformed protocol version - invalid format",
99+
serverVersion: "not-a-date",
100+
expectError: true,
101+
errorContains: "unsupported protocol version",
102+
},
103+
{
104+
name: "malformed protocol version - partial date",
105+
serverVersion: "2025-06",
106+
expectError: true,
107+
errorContains: "unsupported protocol version",
108+
},
109+
{
110+
name: "malformed protocol version - just numbers",
111+
serverVersion: "20250618",
112+
expectError: true,
113+
errorContains: "unsupported protocol version",
114+
},
115+
}
116+
117+
for _, tt := range tests {
118+
t.Run(tt.name, func(t *testing.T) {
119+
// Create mock transport that returns specific version
120+
mockTransport := &mockProtocolTransport{
121+
responses: map[string]string{
122+
"initialize": fmt.Sprintf(`{
123+
"protocolVersion": "%s",
124+
"capabilities": {},
125+
"serverInfo": {"name": "test", "version": "1.0"}
126+
}`, tt.serverVersion),
127+
},
128+
}
129+
130+
client := NewClient(mockTransport)
131+
132+
_, err := client.Initialize(context.Background(), mcp.InitializeRequest{
133+
Params: mcp.InitializeParams{
134+
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
135+
ClientInfo: mcp.Implementation{Name: "test-client", Version: "1.0"},
136+
Capabilities: mcp.ClientCapabilities{},
137+
},
138+
})
139+
140+
if tt.expectError {
141+
if err == nil {
142+
t.Errorf("expected error but got none")
143+
} else if !strings.Contains(err.Error(), tt.errorContains) {
144+
t.Errorf("expected error containing %q, got %q", tt.errorContains, err.Error())
145+
}
146+
// Verify it's the correct error type
147+
if !mcp.IsUnsupportedProtocolVersion(err) {
148+
t.Errorf("expected UnsupportedProtocolVersionError, got %T", err)
149+
}
150+
} else {
151+
if err != nil {
152+
t.Errorf("unexpected error: %v", err)
153+
}
154+
// Verify the protocol version was stored
155+
if client.protocolVersion != tt.serverVersion {
156+
t.Errorf("expected protocol version %q, got %q", tt.serverVersion, client.protocolVersion)
157+
}
158+
}
159+
})
160+
}
161+
}
162+
163+
// mockHTTPTransport implements both transport.Interface and transport.HTTPConnection
164+
type mockHTTPTransport struct {
165+
mockProtocolTransport
166+
protocolVersion string
167+
}
168+
169+
func (m *mockHTTPTransport) SetProtocolVersion(version string) {
170+
m.protocolVersion = version
171+
}
172+
173+
func TestProtocolVersionHeaderSetting(t *testing.T) {
174+
// Create mock HTTP transport
175+
mockTransport := &mockHTTPTransport{
176+
mockProtocolTransport: mockProtocolTransport{
177+
responses: map[string]string{
178+
"initialize": fmt.Sprintf(`{
179+
"protocolVersion": "%s",
180+
"capabilities": {},
181+
"serverInfo": {"name": "test", "version": "1.0"}
182+
}`, mcp.LATEST_PROTOCOL_VERSION),
183+
},
184+
},
185+
}
186+
187+
client := NewClient(mockTransport)
188+
189+
_, err := client.Initialize(context.Background(), mcp.InitializeRequest{
190+
Params: mcp.InitializeParams{
191+
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
192+
ClientInfo: mcp.Implementation{Name: "test-client", Version: "1.0"},
193+
Capabilities: mcp.ClientCapabilities{},
194+
},
195+
})
196+
197+
if err != nil {
198+
t.Fatalf("unexpected error: %v", err)
199+
}
200+
201+
// Verify SetProtocolVersion was called on HTTP transport
202+
if mockTransport.protocolVersion != mcp.LATEST_PROTOCOL_VERSION {
203+
t.Errorf("expected SetProtocolVersion to be called with %q, got %q",
204+
mcp.LATEST_PROTOCOL_VERSION, mockTransport.protocolVersion)
205+
}
206+
}
207+
208+
func TestUnsupportedProtocolVersionError_Is(t *testing.T) {
209+
// Test that errors.Is works correctly with UnsupportedProtocolVersionError
210+
err1 := mcp.UnsupportedProtocolVersionError{Version: "2023-01-01"}
211+
err2 := mcp.UnsupportedProtocolVersionError{Version: "2024-01-01"}
212+
213+
// Test Is method
214+
if !err1.Is(err2) {
215+
t.Error("expected UnsupportedProtocolVersionError.Is to return true for same error type")
216+
}
217+
218+
// Test with different error type
219+
otherErr := fmt.Errorf("some other error")
220+
if err1.Is(otherErr) {
221+
t.Error("expected UnsupportedProtocolVersionError.Is to return false for different error type")
222+
}
223+
224+
// Test IsUnsupportedProtocolVersion helper
225+
if !mcp.IsUnsupportedProtocolVersion(err1) {
226+
t.Error("expected IsUnsupportedProtocolVersion to return true")
227+
}
228+
if mcp.IsUnsupportedProtocolVersion(otherErr) {
229+
t.Error("expected IsUnsupportedProtocolVersion to return false for different error type")
230+
}
231+
}

client/stdio_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func TestStdioMCPClient(t *testing.T) {
9393
defer cancel()
9494

9595
request := mcp.InitializeRequest{}
96-
request.Params.ProtocolVersion = "1.0"
96+
request.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
9797
request.Params.ClientInfo = mcp.Implementation{
9898
Name: "test-client",
9999
Version: "1.0.0",

client/transport/constants.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package transport
2+
3+
// Common HTTP header constants used across transports
4+
const (
5+
HeaderKeySessionID = "Mcp-Session-Id"
6+
HeaderKeyProtocolVersion = "Mcp-Protocol-Version"
7+
)

client/transport/interface.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ type BidirectionalInterface interface {
4747
SetRequestHandler(handler RequestHandler)
4848
}
4949

50+
// HTTPConnection is a Transport that runs over HTTP and supports
51+
// protocol version headers.
52+
type HTTPConnection interface {
53+
Interface
54+
SetProtocolVersion(version string)
55+
}
56+
5057
type JSONRPCRequest struct {
5158
JSONRPC string `json:"jsonrpc"`
5259
ID mcp.RequestId `json:"id"`

client/transport/sse.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type SSE struct {
3737
started atomic.Bool
3838
closed atomic.Bool
3939
cancelSSEStream context.CancelFunc
40+
protocolVersion atomic.Value // string
4041

4142
// OAuth support
4243
oauthHandler *OAuthHandler
@@ -324,6 +325,12 @@ func (c *SSE) SendRequest(
324325

325326
// Set headers
326327
req.Header.Set("Content-Type", "application/json")
328+
// Set protocol version header if negotiated
329+
if v := c.protocolVersion.Load(); v != nil {
330+
if version, ok := v.(string); ok && version != "" {
331+
req.Header.Set(HeaderKeyProtocolVersion, version)
332+
}
333+
}
327334
for k, v := range c.headers {
328335
req.Header.Set(k, v)
329336
}
@@ -434,6 +441,11 @@ func (c *SSE) GetSessionId() string {
434441
return ""
435442
}
436443

444+
// SetProtocolVersion sets the negotiated protocol version for this connection.
445+
func (c *SSE) SetProtocolVersion(version string) {
446+
c.protocolVersion.Store(version)
447+
}
448+
437449
// SendNotification sends a JSON-RPC notification to the server without expecting a response.
438450
func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
439451
if c.endpoint == nil {
@@ -456,6 +468,12 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
456468
}
457469

458470
req.Header.Set("Content-Type", "application/json")
471+
// Set protocol version header if negotiated
472+
if v := c.protocolVersion.Load(); v != nil {
473+
if version, ok := v.(string); ok && version != "" {
474+
req.Header.Set(HeaderKeyProtocolVersion, version)
475+
}
476+
}
459477
// Set custom HTTP headers
460478
for k, v := range c.headers {
461479
req.Header.Set(k, v)

client/transport/sse_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -408,9 +408,9 @@ func TestSSE(t *testing.T) {
408408
t.Run("SSEEventWithoutEventField", func(t *testing.T) {
409409
// Test that SSE events with only data field (no event field) are processed correctly
410410
// This tests the fix for issue #369
411-
411+
412412
var messageReceived chan struct{}
413-
413+
414414
// Create a custom mock server that sends SSE events without event field
415415
sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
416416
w.Header().Set("Content-Type", "text/event-stream")
@@ -449,7 +449,7 @@ func TestSSE(t *testing.T) {
449449
messageHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
450450
w.Header().Set("Content-Type", "application/json")
451451
w.WriteHeader(http.StatusAccepted)
452-
452+
453453
// Signal that message was received
454454
close(messageReceived)
455455
})

0 commit comments

Comments
 (0)