diff --git a/server/session.go b/server/session.go index 165ecaedd..3544cabed 100644 --- a/server/session.go +++ b/server/session.go @@ -48,6 +48,13 @@ type SessionWithClientInfo interface { SetClientInfo(clientInfo mcp.Implementation) } +// SessionWithParams is an extension of ClientSession that can store session parameters +type SessionWithParams interface { + ClientSession + // Params returns the parameters associated with the session. + Params() map[string]string +} + // SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations type SessionWithStreamableHTTPConfig interface { ClientSession diff --git a/server/session_test.go b/server/session_test.go index 22da95714..aa8a61dbd 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "sync" "sync/atomic" "testing" @@ -172,12 +173,72 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +// sessionTestClientWithParams implements the SessionWithParams interface for testing +type sessionTestClientWithParams struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification + initialized bool + params map[string]string + mu sync.RWMutex // Mutex to protect concurrent access to params +} + +func (f *sessionTestClientWithParams) SessionID() string { + return f.sessionID +} + +func (f *sessionTestClientWithParams) NotificationChannel() chan<- mcp.JSONRPCNotification { + return f.notificationChannel +} + +func (f *sessionTestClientWithParams) Initialize() { + f.initialized = true +} + +func (f *sessionTestClientWithParams) Initialized() bool { + return f.initialized +} + +func (f *sessionTestClientWithParams) Params() map[string]string { + f.mu.RLock() + defer f.mu.RUnlock() + + // Return a copy of the map to prevent concurrent modification + if f.params == nil { + return nil + } + + paramsCopy := make(map[string]string, len(f.params)) + for k, v := range f.params { + paramsCopy[k] = v + } + return paramsCopy +} + +// SetParams sets the params (for testing purposes) +func (f *sessionTestClientWithParams) SetParams(params map[string]string) { + f.mu.Lock() + defer f.mu.Unlock() + + // Create a copy of the map to prevent concurrent modification + if params == nil { + f.params = nil + return + } + + paramsCopy := make(map[string]string, len(params)) + for k, v := range params { + paramsCopy[k] = v + } + f.params = paramsCopy +} + // Verify that all implementations satisfy their respective interfaces var ( _ ClientSession = (*sessionTestClient)(nil) _ SessionWithTools = (*sessionTestClientWithTools)(nil) _ SessionWithLogging = (*sessionTestClientWithLogging)(nil) _ SessionWithClientInfo = (*sessionTestClientWithClientInfo)(nil) + _ SessionWithParams = (*sessionTestClientWithParams)(nil) ) func TestSessionWithTools_Integration(t *testing.T) { @@ -1507,3 +1568,112 @@ func TestMCPServer_LoggingNotificationFormat(t *testing.T) { }) } } + +func TestSessionWithParams_Integration(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0") + + // Create a session with params + testParams := map[string]string{ + "tenant_id": "test-tenant-123", + "user_id": "user-456", + "environment": "development", + } + + session := &sessionTestClientWithParams{ + sessionID: "session-1", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: true, + params: testParams, + } + + // Register the session + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // Test that we can access the session from context + sessionCtx := server.WithContext(context.Background(), session) + + // Check if the session was stored in the context correctly + s := ClientSessionFromContext(sessionCtx) + require.NotNil(t, s, "Session should be available from context") + assert.Equal(t, session.SessionID(), s.SessionID(), "Session ID should match") + + // Check if the session can be cast to SessionWithParams + swp, ok := s.(SessionWithParams) + require.True(t, ok, "Session should implement SessionWithParams") + + // Test accessing params + params := swp.Params() + require.NotNil(t, params, "Session params should be available") + require.Len(t, params, 3, "Should have 3 params") + assert.Equal(t, "test-tenant-123", params["tenant_id"], "tenant_id should match") + assert.Equal(t, "user-456", params["user_id"], "user_id should match") + assert.Equal(t, "development", params["environment"], "environment should match") + + // Test that params are returned as a copy (not the original map) + params["tenant_id"] = "modified-tenant" + originalParams := swp.Params() + assert.Equal(t, "test-tenant-123", originalParams["tenant_id"], "Original params should not be modified") + + t.Run("test concurrent access", func(t *testing.T) { + // Test concurrent access to params + var wg sync.WaitGroup + errors := make(chan error, 10) + + // Start multiple goroutines reading params + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + params := swp.Params() + if params == nil { + errors <- fmt.Errorf("params should not be nil") + return + } + if len(params) != 3 { + errors <- fmt.Errorf("expected 3 params, got %d", len(params)) + return + } + } + }() + } + + // Start goroutines modifying params + for i := 0; i < 3; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + newParams := map[string]string{ + "tenant_id": fmt.Sprintf("tenant-%d", idx), + "user_id": fmt.Sprintf("user-%d", idx), + "environment": "test", + } + session.SetParams(newParams) + }(i) + } + + wg.Wait() + close(errors) + + // Check for any errors during concurrent access + for err := range errors { + t.Error(err) + } + }) + + t.Run("test nil params", func(t *testing.T) { + // Test with nil params + session.SetParams(nil) + params := swp.Params() + assert.Nil(t, params, "Params should be nil when set to nil") + }) + + t.Run("test empty params", func(t *testing.T) { + // Test with empty params + session.SetParams(map[string]string{}) + params := swp.Params() + require.NotNil(t, params, "Params should not be nil for empty map") + assert.Len(t, params, 0, "Params should be empty") + }) +} diff --git a/server/sse.go b/server/sse.go index c20e67820..f5f2092e9 100644 --- a/server/sse.go +++ b/server/sse.go @@ -30,6 +30,7 @@ type sseSession struct { loggingLevel atomic.Value tools sync.Map // stores session-specific tools clientInfo atomic.Value // stores session-specific client info + params map[string]string } // SSEContextFunc is a function that takes an existing context and the current @@ -48,6 +49,10 @@ func (s *sseSession) SessionID() string { return s.sessionID } +func (s *sseSession) Params() map[string]string { + return s.params +} + func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification { return s.notificationChannel } @@ -347,12 +352,18 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { return } + params := make(map[string]string) + for k, v := range r.URL.Query() { + params[k] = v[0] + } + sessionID := uuid.New().String() session := &sseSession{ done: make(chan struct{}), eventQueue: make(chan string, 100), // Buffer for events sessionID: sessionID, notificationChannel: make(chan mcp.JSONRPCNotification, 100), + params: params, } s.sessions.Store(sessionID, session) diff --git a/server/sse_test.go b/server/sse_test.go index 2a2b03b08..08a1ed409 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -1606,6 +1606,182 @@ func TestSSEServer(t *testing.T) { t.Error("Headers check hook was not called within timeout") } }) + + t.Run("Parameters parsing from URL query", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0", WithToolCapabilities(true)) + + // Add a tool that can access session params to verify they're passed correctly + var capturedParams map[string]string + mcpServer.AddTool( + mcp.NewTool("get-params"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + session := ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no session in context") + } + + sessionWithParams, ok := session.(SessionWithParams) + if !ok { + return nil, fmt.Errorf("session does not implement SessionWithParams") + } + + capturedParams = sessionWithParams.Params() + return mcp.NewToolResultText("params captured"), nil + }, + ) + + testServer := NewTestServer(mcpServer) + defer testServer.Close() + + // Test with various query parameters + testCases := []struct { + name string + queryString string + expectedParams map[string]string + }{ + { + name: "single parameter", + queryString: "?tenant_id=test-tenant-123", + expectedParams: map[string]string{ + "tenant_id": "test-tenant-123", + }, + }, + { + name: "multiple parameters", + queryString: "?tenant_id=test-tenant-123&user_id=user-456&environment=development", + expectedParams: map[string]string{ + "tenant_id": "test-tenant-123", + "user_id": "user-456", + "environment": "development", + }, + }, + { + name: "no parameters", + queryString: "", + expectedParams: map[string]string{}, + }, + { + name: "parameters with special characters", + queryString: "?key1=value%20with%20spaces&key2=value%2Bwith%2Bplus", + expectedParams: map[string]string{ + "key1": "value with spaces", + "key2": "value+with+plus", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + capturedParams = nil // Reset captured params + + // Connect to SSE endpoint with query parameters + sseURL := fmt.Sprintf("%s/sse%s", testServer.URL, tc.queryString) + sseResp, err := http.Get(sseURL) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer sseResp.Body.Close() + + // Read the endpoint event + endpointEvent, err := readSSEEvent(sseResp) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + if !strings.Contains(endpointEvent, "event: endpoint") { + t.Fatalf("Expected endpoint event, got: %s", endpointEvent) + } + + // Extract message endpoint URL + messageURL := strings.TrimSpace( + strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], + ) + + // Send initialize request + initRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2024-11-05", + "clientInfo": map[string]any{ + "name": "test-client", + "version": "1.0.0", + }, + }, + } + + requestBody, err := json.Marshal(initRequest) + if err != nil { + t.Fatalf("Failed to marshal init request: %v", err) + } + + initResp, err := http.Post(messageURL, "application/json", bytes.NewReader(requestBody)) + if err != nil { + t.Fatalf("Failed to send init request: %v", err) + } + defer initResp.Body.Close() + + if initResp.StatusCode != http.StatusAccepted { + t.Fatalf("Expected status 202, got %d", initResp.StatusCode) + } + + // Send a request to get params + paramsRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": map[string]any{ + "name": "get-params", + "arguments": map[string]any{}, + }, + } + + requestBody, err = json.Marshal(paramsRequest) + if err != nil { + t.Fatalf("Failed to marshal params request: %v", err) + } + + paramsResp, err := http.Post(messageURL, "application/json", bytes.NewReader(requestBody)) + if err != nil { + t.Fatalf("Failed to send params request: %v", err) + } + defer paramsResp.Body.Close() + + if paramsResp.StatusCode != http.StatusAccepted { + t.Fatalf("Expected status 202, got %d", paramsResp.StatusCode) + } + + // Verify captured params match expected + if capturedParams == nil { + t.Fatal("No params were captured") + } + + // Check length + if len(capturedParams) != len(tc.expectedParams) { + t.Errorf("Expected %d params, got %d: %v", len(tc.expectedParams), len(capturedParams), capturedParams) + } + + // Check each expected param + for expectedKey, expectedValue := range tc.expectedParams { + actualValue, exists := capturedParams[expectedKey] + if !exists { + t.Errorf("Expected param %s not found in captured params: %v", expectedKey, capturedParams) + continue + } + if actualValue != expectedValue { + t.Errorf("Expected param %s to be %q, got %q", expectedKey, expectedValue, actualValue) + } + } + + // Check no unexpected params + for actualKey := range capturedParams { + if _, expected := tc.expectedParams[actualKey]; !expected { + t.Errorf("Unexpected param %s=%s found in captured params", actualKey, capturedParams[actualKey]) + } + } + }) + } + }) } func readSSEEvent(sseResp *http.Response) (string, error) { diff --git a/server/streamable_http.go b/server/streamable_http.go index b4d344abf..b5f4ca90e 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -237,6 +237,11 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } isInitializeRequest := baseMessage.Method == mcp.MethodInitialize + params := make(map[string]string) + for k, v := range r.URL.Query() { + params[k] = v[0] + } + // Prepare the session for the mcp server // The session is ephemeral. Its life is the same as the request. It's only created // for interaction with the mcp server. @@ -259,7 +264,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } } - session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels) + session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels, params) // Set the client context before handling the message ctx := s.server.WithContext(r.Context(), session) @@ -370,7 +375,12 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) sessionID = uuid.New().String() } - session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels) + params := make(map[string]string) + for k, v := range r.URL.Query() { + params[k] = v[0] + } + + session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels, params) if err := s.server.RegisterSession(r.Context(), session); err != nil { http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest) return @@ -587,14 +597,16 @@ type streamableHttpSession struct { tools *sessionToolsStore upgradeToSSE atomic.Bool logLevels *sessionLogLevelsStore + params map[string]string } -func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession { +func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore, params map[string]string) *streamableHttpSession { s := &streamableHttpSession{ sessionID: sessionID, notificationChannel: make(chan mcp.JSONRPCNotification, 100), tools: toolStore, logLevels: levels, + params: params, } return s } @@ -603,6 +615,10 @@ func (s *streamableHttpSession) SessionID() string { return s.sessionID } +func (s *streamableHttpSession) Params() map[string]string { + return s.params +} + func (s *streamableHttpSession) NotificationChannel() chan<- mcp.JSONRPCNotification { return s.notificationChannel } diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 5be010a74..0f6742ce6 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -896,6 +896,171 @@ func TestStreamableHTTP_HeaderPassthrough(t *testing.T) { } } +func TestStreamableHTTP_ParameterParsing(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0", WithToolCapabilities(true)) + + // Add a tool that can access session params to verify they're passed correctly + var capturedParams map[string]string + mcpServer.AddTool( + mcp.NewTool("get-params"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + session := ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no session in context") + } + + sessionWithParams, ok := session.(SessionWithParams) + if !ok { + return nil, fmt.Errorf("session does not implement SessionWithParams") + } + + capturedParams = sessionWithParams.Params() + return mcp.NewToolResultText("params captured"), nil + }, + ) + + server := NewStreamableHTTPServer(mcpServer) + httpServer := httptest.NewServer(server) + defer httpServer.Close() + + // Test cases with different query parameters + testCases := []struct { + name string + queryString string + expectedParams map[string]string + }{ + { + name: "single parameter", + queryString: "?tenant_id=tenant-123", + expectedParams: map[string]string{ + "tenant_id": "tenant-123", + }, + }, + { + name: "multiple parameters", + queryString: "?tenant_id=tenant-123&user_id=user-456&environment=prod", + expectedParams: map[string]string{ + "tenant_id": "tenant-123", + "user_id": "user-456", + "environment": "prod", + }, + }, + { + name: "no parameters", + queryString: "", + expectedParams: map[string]string{}, + }, + { + name: "parameters with special characters", + queryString: "?key1=value%20with%20spaces&key2=value%2Bwith%2Bplus&key3=hello%26world", + expectedParams: map[string]string{ + "key1": "value with spaces", + "key2": "value+with+plus", + "key3": "hello&world", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Test POST endpoint (simpler and more reliable) + t.Run("POST endpoint", func(t *testing.T) { + capturedParams = nil + + testURL := httpServer.URL + tc.queryString + + // First initialize to get session ID + initRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2025-03-26", + "clientInfo": map[string]any{ + "name": "test-client", + "version": "1.0.0", + }, + }, + } + + initResp, err := postJSON(testURL, initRequest) + if err != nil { + t.Fatalf("Failed to send init request: %v", err) + } + defer initResp.Body.Close() + + if initResp.StatusCode != http.StatusOK { + t.Fatalf("Expected status 200 for init, got %d", initResp.StatusCode) + } + + // Get session ID from header + sessionID := initResp.Header.Get("Mcp-Session-Id") + if sessionID == "" { + t.Fatal("Expected session id in header") + } + + // Now call the tool with session ID + toolRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": map[string]any{ + "name": "get-params", + "arguments": map[string]any{}, + }, + } + + jsonBody, _ := json.Marshal(toolRequest) + req, _ := http.NewRequest("POST", testURL, bytes.NewReader(jsonBody)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Mcp-Session-Id", sessionID) + + toolResp, err := httpServer.Client().Do(req) + if err != nil { + t.Fatalf("Failed to send tool request: %v", err) + } + defer toolResp.Body.Close() + + if toolResp.StatusCode != http.StatusOK { + t.Fatalf("Expected status 200, got %d", toolResp.StatusCode) + } + + // Small delay to ensure async operations complete + time.Sleep(10 * time.Millisecond) + + // Verify captured params match expected + if capturedParams == nil { + t.Fatal("No params were captured") + } + + // Check length + if len(capturedParams) != len(tc.expectedParams) { + t.Errorf("Expected %d params, got %d: %v", len(tc.expectedParams), len(capturedParams), capturedParams) + } + + // Check each expected param + for expectedKey, expectedValue := range tc.expectedParams { + actualValue, exists := capturedParams[expectedKey] + if !exists { + t.Errorf("Expected param %s not found in captured params: %v", expectedKey, capturedParams) + continue + } + if actualValue != expectedValue { + t.Errorf("Expected param %s to be %q, got %q", expectedKey, expectedValue, actualValue) + } + } + + // Check no unexpected params + for actualKey := range capturedParams { + if _, expected := tc.expectedParams[actualKey]; !expected { + t.Errorf("Unexpected param %s=%s found in captured params", actualKey, capturedParams[actualKey]) + } + } + }) + }) + } +} + func postJSON(url string, bodyObject any) (*http.Response, error) { jsonBody, _ := json.Marshal(bodyObject) req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))