From 881e0955972fd53001fbdfa41283ea0c56fc7185 Mon Sep 17 00:00:00 2001 From: andig Date: Sun, 27 Jul 2025 21:29:35 +0200 Subject: [PATCH 01/22] feat: implement sampling support for Streamable HTTP transport MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements sampling capability for HTTP transport, resolving issue #419. Enables servers to send sampling requests to HTTP clients via SSE and receive LLM-generated responses. ## Key Changes ### Core Implementation - Add `BidirectionalInterface` support to `StreamableHTTP` - Implement `SetRequestHandler` for server-to-client requests - Enhance SSE parsing to handle requests alongside responses/notifications - Add `handleIncomingRequest` and `sendResponseToServer` methods ### HTTP-Specific Features - Leverage existing MCP headers (`Mcp-Session-Id`, `Mcp-Protocol-Version`) - Bidirectional communication via HTTP POST for responses - Proper JSON-RPC request/response handling over HTTP ### Error Handling - Add specific JSON-RPC error codes for different failure scenarios: - `-32601` (Method not found) when no handler configured - `-32603` (Internal error) for sampling failures - `-32800` (Request cancelled/timeout) for context errors - Enhanced error messages with sampling-specific context ### Testing & Examples - Comprehensive test suite in `streamable_http_sampling_test.go` - Complete working example in `examples/sampling_http_client/` - Tests cover success flows, error scenarios, and interface compliance ## Technical Details The implementation maintains full backward compatibility while adding bidirectional communication support. Server requests are processed asynchronously to avoid blocking the SSE stream reader. HTTP transport now supports the complete sampling flow that was previously only available in stdio and inprocess transports. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- client/transport/streamable_http.go | 125 ++++++++- .../streamable_http_sampling_test.go | 245 ++++++++++++++++++ examples/sampling_http_client/README.md | 95 +++++++ examples/sampling_http_client/main.go | 108 ++++++++ 4 files changed, 572 insertions(+), 1 deletion(-) create mode 100644 client/transport/streamable_http_sampling_test.go create mode 100644 examples/sampling_http_client/README.md create mode 100644 examples/sampling_http_client/main.go diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index e8b2fcc58..22d5a038c 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -92,7 +92,6 @@ func WithSession(sessionID string) StreamableHTTPCOption { // The current implementation does not support the following features: // - resuming stream // (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery) -// - server -> client request type StreamableHTTP struct { serverURL *url.URL httpClient *http.Client @@ -110,6 +109,10 @@ type StreamableHTTP struct { notificationHandler func(mcp.JSONRPCNotification) notifyMu sync.RWMutex + // Request handler for incoming server-to-client requests (like sampling) + requestHandler RequestHandler + requestMu sync.RWMutex + closed chan struct{} // OAuth support @@ -406,6 +409,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl defer close(responseChan) c.readSSE(ctx, reader, func(event, data string) { + // Try to unmarshal as a response first var message JSONRPCResponse if err := json.Unmarshal([]byte(data), &message); err != nil { c.logger.Errorf("failed to unmarshal message: %v", err) @@ -427,6 +431,17 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl return } + // Check if this is actually a request from the server + // If Result and Error are nil, it might be a request + if message.Result == nil && message.Error == nil { + var request JSONRPCRequest + if err := json.Unmarshal([]byte(data), &request); err == nil { + // This is a request from the server + c.handleIncomingRequest(ctx, request) + return + } + } + if !ignoreResponse { responseChan <- &message } @@ -547,6 +562,13 @@ func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotifica c.notificationHandler = handler } +// SetRequestHandler sets the handler for incoming requests from the server. +func (c *StreamableHTTP) SetRequestHandler(handler RequestHandler) { + c.requestMu.Lock() + defer c.requestMu.Unlock() + c.requestHandler = handler +} + func (c *StreamableHTTP) GetSessionId() string { return c.sessionID.Load().(string) } @@ -627,6 +649,107 @@ func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error return nil } +// handleIncomingRequest processes requests from the server (like sampling requests) +func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSONRPCRequest) { + c.requestMu.RLock() + handler := c.requestHandler + c.requestMu.RUnlock() + + if handler == nil { + c.logger.Errorf("received request from server but no handler set: %s", request.Method) + // Send method not found error + errorResponse := &JSONRPCResponse{ + JSONRPC: "2.0", + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: -32601, // Method not found + Message: fmt.Sprintf("no handler configured for method: %s", request.Method), + }, + } + c.sendResponseToServer(ctx, errorResponse) + return + } + + // Handle the request in a goroutine to avoid blocking the SSE reader + go func() { + response, err := handler(ctx, request) + if err != nil { + c.logger.Errorf("error handling request %s: %v", request.Method, err) + + // Determine appropriate JSON-RPC error code based on error type + var errorCode int + var errorMessage string + + // Check for specific sampling-related errors + if errors.Is(err, context.Canceled) { + errorCode = -32800 // Request cancelled + errorMessage = "request was cancelled" + } else if errors.Is(err, context.DeadlineExceeded) { + errorCode = -32800 // Request timeout + errorMessage = "request timed out" + } else { + // Generic error cases + switch request.Method { + case string(mcp.MethodSamplingCreateMessage): + errorCode = -32603 // Internal error + errorMessage = fmt.Sprintf("sampling request failed: %v", err) + default: + errorCode = -32603 // Internal error + errorMessage = err.Error() + } + } + + // Send error response + errorResponse := &JSONRPCResponse{ + JSONRPC: "2.0", + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: errorCode, + Message: errorMessage, + }, + } + c.sendResponseToServer(ctx, errorResponse) + return + } + + if response != nil { + c.sendResponseToServer(ctx, response) + } + }() +} + +// sendResponseToServer sends a response back to the server via HTTP POST +func (c *StreamableHTTP) sendResponseToServer(ctx context.Context, response *JSONRPCResponse) { + responseBody, err := json.Marshal(response) + if err != nil { + c.logger.Errorf("failed to marshal response: %v", err) + return + } + + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() + + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json") + if err != nil { + c.logger.Errorf("failed to send response to server: %v", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + c.logger.Errorf("server rejected response with status %d: %s", resp.StatusCode, body) + } +} + func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) { newCtx, cancel := context.WithCancel(ctx) go func() { diff --git a/client/transport/streamable_http_sampling_test.go b/client/transport/streamable_http_sampling_test.go new file mode 100644 index 000000000..0dce2da68 --- /dev/null +++ b/client/transport/streamable_http_sampling_test.go @@ -0,0 +1,245 @@ +package transport + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// TestStreamableHTTP_SamplingFlow tests the complete sampling flow with HTTP transport +func TestStreamableHTTP_SamplingFlow(t *testing.T) { + // Create simple test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Just respond OK to any requests + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Create HTTP client transport + client, err := NewStreamableHTTP(server.URL) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Set up sampling request handler + var handledRequest *JSONRPCRequest + client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { + handledRequest = &request + + // Simulate sampling handler response + result := map[string]interface{}{ + "role": "assistant", + "content": map[string]interface{}{ + "type": "text", + "text": "Hello! How can I help you today?", + }, + "model": "test-model", + "stopReason": "stop_sequence", + } + + resultBytes, _ := json.Marshal(result) + + return &JSONRPCResponse{ + JSONRPC: "2.0", + ID: request.ID, + Result: resultBytes, + }, nil + }) + + // Start the client + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err = client.Start(ctx) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Test direct request handling (simulating a sampling request) + samplingRequest := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(1), + Method: string(mcp.MethodSamplingCreateMessage), + Params: map[string]interface{}{ + "messages": []map[string]interface{}{ + { + "role": "user", + "content": map[string]interface{}{ + "type": "text", + "text": "Hello, world!", + }, + }, + }, + }, + } + + // Directly test request handling + client.handleIncomingRequest(ctx, samplingRequest) + + // Allow time for async processing + time.Sleep(100 * time.Millisecond) + + // Verify the request was handled + if handledRequest == nil { + t.Fatal("Sampling request was not handled") + } + + if handledRequest.Method != string(mcp.MethodSamplingCreateMessage) { + t.Errorf("Expected method %s, got %s", mcp.MethodSamplingCreateMessage, handledRequest.Method) + } +} + +// TestStreamableHTTP_SamplingErrorHandling tests error handling in sampling requests +func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var body map[string]interface{} + json.NewDecoder(r.Body).Decode(&body) + + // Check if this is an error response + if errorField, ok := body["error"]; ok { + errorMap := errorField.(map[string]interface{}) + if code, ok := errorMap["code"].(float64); ok && code == -32603 { + w.WriteHeader(http.StatusOK) + return + } + } + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client, err := NewStreamableHTTP(server.URL) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Set up request handler that returns an error + client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { + return nil, fmt.Errorf("sampling failed") + }) + + // Start the client + ctx := context.Background() + err = client.Start(ctx) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Simulate incoming sampling request + samplingRequest := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(1), + Method: string(mcp.MethodSamplingCreateMessage), + Params: map[string]interface{}{}, + } + + // This should trigger error handling + client.handleIncomingRequest(ctx, samplingRequest) + + // Allow some time for async error handling + time.Sleep(100 * time.Millisecond) +} + +// TestStreamableHTTP_NoSamplingHandler tests behavior when no sampling handler is set +func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { + var errorReceived bool + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var body map[string]interface{} + json.NewDecoder(r.Body).Decode(&body) + + // Check if this is an error response with method not found + if errorField, ok := body["error"]; ok { + errorMap := errorField.(map[string]interface{}) + if code, ok := errorMap["code"].(float64); ok && code == -32601 { + if message, ok := errorMap["message"].(string); ok && + strings.Contains(message, "no handler configured") { + errorReceived = true + } + } + } + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client, err := NewStreamableHTTP(server.URL) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Don't set any request handler + + ctx := context.Background() + err = client.Start(ctx) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Simulate incoming sampling request + samplingRequest := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(1), + Method: string(mcp.MethodSamplingCreateMessage), + Params: map[string]interface{}{}, + } + + // This should trigger "method not found" error + client.handleIncomingRequest(ctx, samplingRequest) + + // Allow some time for async error handling + time.Sleep(100 * time.Millisecond) + + if !errorReceived { + t.Error("Expected method not found error, but didn't receive it") + } +} + +// TestStreamableHTTP_BidirectionalInterface verifies the interface implementation +func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { + client, err := NewStreamableHTTP("http://example.com") + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Verify it implements BidirectionalInterface + _, ok := interface{}(client).(BidirectionalInterface) + if !ok { + t.Error("StreamableHTTP should implement BidirectionalInterface") + } + + // Test SetRequestHandler + handlerSet := false + client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { + handlerSet = true + return nil, nil + }) + + // Verify handler was set by triggering it + ctx := context.Background() + client.handleIncomingRequest(ctx, JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(1), + Method: "test", + }) + + // Allow async execution + time.Sleep(50 * time.Millisecond) + + if !handlerSet { + t.Error("Request handler was not properly set or called") + } +} \ No newline at end of file diff --git a/examples/sampling_http_client/README.md b/examples/sampling_http_client/README.md new file mode 100644 index 000000000..e4cf0ea4e --- /dev/null +++ b/examples/sampling_http_client/README.md @@ -0,0 +1,95 @@ +# HTTP Sampling Client Example + +This example demonstrates how to create an MCP client using HTTP transport that supports sampling requests from the server. + +## Overview + +This client: +- Connects to an MCP server via HTTP/HTTPS transport +- Declares sampling capability during initialization +- Handles incoming sampling requests from the server +- Uses a mock LLM to generate responses (replace with real LLM integration) + +## Usage + +1. Start an MCP server that supports sampling (e.g., using the `sampling_server` example) + +2. Update the server URL in `main.go`: + ```go + httpClient, err := client.NewStreamableHttpClient( + "http://your-server:port", // Replace with your server URL + ) + ``` + +3. Run the client: + ```bash + go run main.go + ``` + +## Key Features + +### HTTP Transport with Sampling +The client creates the HTTP transport directly and then wraps it with a client that supports sampling: + +```go +httpTransport, err := transport.NewStreamableHTTP("http://localhost:8080") +mcpClient := client.NewClient(httpTransport, client.WithSamplingHandler(samplingHandler)) +``` + +### Sampling Handler +The `MockSamplingHandler` implements the `client.SamplingHandler` interface: + +```go +type MockSamplingHandler struct{} + +func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Process the sampling request and return LLM response + // In production, integrate with OpenAI, Anthropic, or other LLM APIs +} +``` + +### Client Configuration +The client is configured with sampling capabilities: + +```go +mcpClient := client.NewClient( + httpTransport, + client.WithSamplingHandler(samplingHandler), +) +// Sampling capability is automatically declared when a handler is provided +``` + +## Real Implementation + +For a production implementation, replace the `MockSamplingHandler` with a real LLM client: + +```go +type RealSamplingHandler struct { + client *openai.Client // or other LLM client +} + +func (h *RealSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Convert MCP request to LLM API format + // Call LLM API + // Convert response back to MCP format + // Return the result +} +``` + +## HTTP-Specific Features + +The HTTP transport supports: +- Standard HTTP headers for authentication and customization +- OAuth 2.0 authentication (using `WithHTTPOAuth`) +- Custom headers (using `WithHTTPHeaders`) +- Server-side events (SSE) for bidirectional communication +- Proper error handling with HTTP status codes +- Session management via HTTP headers + +## Testing + +The implementation includes comprehensive tests in `client/transport/streamable_http_sampling_test.go` that verify: +- Sampling request handling +- Error scenarios +- Bidirectional interface compliance +- HTTP-specific error codes and responses \ No newline at end of file diff --git a/examples/sampling_http_client/main.go b/examples/sampling_http_client/main.go new file mode 100644 index 000000000..4f365ad1b --- /dev/null +++ b/examples/sampling_http_client/main.go @@ -0,0 +1,108 @@ +package main + +import ( + "context" + "fmt" + "log" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// MockSamplingHandler implements client.SamplingHandler for demonstration. +// In a real implementation, this would integrate with an actual LLM API. +type MockSamplingHandler struct{} + +func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Extract the user's message + if len(request.Messages) == 0 { + return nil, fmt.Errorf("no messages provided") + } + + // Get the last user message + lastMessage := request.Messages[len(request.Messages)-1] + userText := "" + if textContent, ok := lastMessage.Content.(mcp.TextContent); ok { + userText = textContent.Text + } + + // Generate a mock response + responseText := fmt.Sprintf("Mock LLM response to: '%s'", userText) + + log.Printf("Mock LLM generating response: %s", responseText) + + result := &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: responseText, + }, + }, + Model: "mock-model-v1", + StopReason: "endTurn", + } + + return result, nil +} + +func main() { + // Create sampling handler + samplingHandler := &MockSamplingHandler{} + + // Create HTTP transport directly + httpTransport, err := transport.NewStreamableHTTP( + "http://localhost:8080", // Replace with your MCP server URL + // You can add HTTP-specific options here like headers, OAuth, etc. + ) + if err != nil { + log.Fatalf("Failed to create HTTP transport: %v", err) + } + defer httpTransport.Close() + + // Create client with sampling support + mcpClient := client.NewClient( + httpTransport, + client.WithSamplingHandler(samplingHandler), + ) + + // Start the client + ctx := context.Background() + err = mcpClient.Start(ctx) + if err != nil { + log.Fatalf("Failed to start client: %v", err) + } + + // Initialize the MCP session + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{ + // Sampling capability will be automatically added by the client + }, + ClientInfo: mcp.Implementation{ + Name: "sampling-http-client", + Version: "1.0.0", + }, + }, + } + + _, err = mcpClient.Initialize(ctx, initRequest) + if err != nil { + log.Fatalf("Failed to initialize MCP session: %v", err) + } + + log.Println("HTTP MCP client with sampling support started successfully!") + log.Println("The client is now ready to handle sampling requests from the server.") + log.Println("When the server sends a sampling request, the MockSamplingHandler will process it.") + + // In a real application, you would keep the client running to handle sampling requests + // For this example, we'll just demonstrate that it's working + + // Keep the client running (in a real app, you'd have your main application logic here) + select { + case <-ctx.Done(): + log.Println("Client context cancelled") + } +} \ No newline at end of file From d239784b41fa15b410d08dca6c10e306c4e3ec7f Mon Sep 17 00:00:00 2001 From: andig Date: Sun, 27 Jul 2025 21:47:13 +0200 Subject: [PATCH 02/22] feat: implement server-side sampling support for HTTP transport MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This completes the server-side implementation of sampling support for HTTP transport, addressing the remaining requirements from issue #419. Changes: - Enhanced streamableHttpSession to implement SessionWithSampling interface - Added bidirectional SSE communication for server-to-client requests - Implemented session registry for proper response correlation - Added comprehensive error handling with JSON-RPC error codes - Created extensive test suite covering all scenarios - Added working example server with sampling tools Key Features: - Server can send sampling requests to HTTP clients via SSE - Clients respond via HTTP POST with proper session correlation - Queue overflow protection and timeout handling - Compatible with existing HTTP transport architecture 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- examples/sampling_http_server/README.md | 138 ++++++++++++++ examples/sampling_http_server/main.go | 142 +++++++++++++++ server/streamable_http.go | 228 +++++++++++++++++++++++- server/streamable_http_sampling_test.go | 166 +++++++++++++++++ 4 files changed, 666 insertions(+), 8 deletions(-) create mode 100644 examples/sampling_http_server/README.md create mode 100644 examples/sampling_http_server/main.go create mode 100644 server/streamable_http_sampling_test.go diff --git a/examples/sampling_http_server/README.md b/examples/sampling_http_server/README.md new file mode 100644 index 000000000..64be58c2c --- /dev/null +++ b/examples/sampling_http_server/README.md @@ -0,0 +1,138 @@ +# HTTP Sampling Server Example + +This example demonstrates how to create an MCP server using HTTP transport that can send sampling requests to clients. + +## Overview + +This server: +- Runs on HTTP transport (port 8080 by default) +- Declares sampling capability during initialization +- Can send sampling requests to connected clients via Server-Sent Events (SSE) +- Receives sampling responses from clients via HTTP POST +- Includes tools that demonstrate sampling functionality + +## Usage + +1. Start the server: + ```bash + go run main.go + ``` + +2. The server will be available at: `http://localhost:8080/mcp` + +3. Connect with an HTTP client that supports sampling (like the `sampling_http_client` example) + +## Tools Available + +### `ask_llm` +Demonstrates server-initiated sampling: +- Takes a question and optional system prompt +- Sends sampling request to client +- Returns the LLM's response + +### `echo` +Simple tool for testing basic functionality: +- Echoes back the input message +- Doesn't require sampling + +## How Sampling Works + +### Server → Client Flow +1. **Tool Invocation**: Client calls `ask_llm` tool +2. **Sampling Request**: Server creates sampling request with user's question +3. **SSE Transmission**: Server sends JSON-RPC request to client via SSE stream +4. **Client Processing**: Client's sampling handler processes the request +5. **HTTP Response**: Client sends JSON-RPC response back via HTTP POST +6. **Tool Response**: Server returns the LLM response to the original tool caller + +### Communication Architecture +``` +Client (HTTP + SSE) ←→ Server (HTTP) + │ │ + ├─ POST: Tool Call ──→ │ + │ │ + │ ←── SSE: Sampling ───┤ + │ Request │ + │ │ + ├─ POST: Sampling ───→ │ + │ Response │ + │ │ + │ ←── HTTP: Tool ──────┤ + Response +``` + +## Key Features + +### Bidirectional Communication +- **SSE Stream**: Server → Client requests (sampling, notifications) +- **HTTP POST**: Client → Server responses and requests + +### Session Management +- Session ID tracking for request/response correlation +- Proper session lifecycle management +- Session validation for security + +### Error Handling +- JSON-RPC error codes for different failure scenarios +- Timeout handling for sampling requests +- Queue overflow protection + +### HTTP-Specific Features +- Standard MCP headers (`Mcp-Session-Id`, `Mcp-Protocol-Version`) +- Content-Type validation +- Proper HTTP status codes +- SSE event formatting + +## Testing + +You can test the server using the `sampling_http_client` example: + +1. Start this server: + ```bash + go run examples/sampling_http_server/main.go + ``` + +2. In another terminal, start the client: + ```bash + go run examples/sampling_http_client/main.go + ``` + +3. The client will connect and be ready to handle sampling requests from the server. + +## Production Considerations + +### Security +- Implement proper authentication/authorization +- Use HTTPS in production +- Validate all incoming data +- Implement rate limiting + +### Scalability +- Consider connection pooling for multiple clients +- Implement proper session cleanup +- Monitor memory usage for long-running sessions +- Add metrics and monitoring + +### Reliability +- Implement request retries +- Add circuit breakers for failing clients +- Implement graceful degradation when sampling is unavailable +- Add comprehensive logging + +## Integration + +This server can be integrated into existing HTTP infrastructure: + +```go +// Custom HTTP server integration +mux := http.NewServeMux() +mux.Handle("/mcp", httpServer) +mux.Handle("/health", healthHandler) + +server := &http.Server{ + Addr: ":8080", + Handler: mux, +} +``` + +The sampling functionality works seamlessly with other MCP features like tools, resources, and prompts. \ No newline at end of file diff --git a/examples/sampling_http_server/main.go b/examples/sampling_http_server/main.go new file mode 100644 index 000000000..b2f84533e --- /dev/null +++ b/examples/sampling_http_server/main.go @@ -0,0 +1,142 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + // Create MCP server with sampling capability + mcpServer := server.NewMCPServer("sampling-http-server", "1.0.0") + + // Enable sampling capability + mcpServer.EnableSampling() + + // Add a tool that uses sampling to get LLM responses + mcpServer.AddTool(mcp.Tool{ + Name: "ask_llm", + Description: "Ask the LLM a question using sampling over HTTP", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "The question to ask the LLM", + }, + "system_prompt": map[string]any{ + "type": "string", + "description": "Optional system prompt to provide context", + }, + }, + Required: []string{"question"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract parameters + question, err := request.RequireString("question") + if err != nil { + return nil, err + } + + systemPrompt := request.GetString("system_prompt", "You are a helpful assistant.") + + // Create sampling request + samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + SystemPrompt: systemPrompt, + MaxTokens: 1000, + Temperature: 0.7, + }, + } + + // Request sampling from the client with timeout + samplingCtx, cancel := context.WithTimeout(ctx, 2*time.Minute) + defer cancel() + + serverFromCtx := server.ServerFromContext(ctx) + result, err := serverFromCtx.RequestSampling(samplingCtx, samplingRequest) + if err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Error requesting sampling: %v", err), + }, + }, + IsError: true, + }, nil + } + + // Return the LLM response + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("LLM Response (model: %s): %s", result.Model, result.Content.(mcp.TextContent).Text), + }, + }, + }, nil + }) + + // Add a simple echo tool for testing + mcpServer.AddTool(mcp.Tool{ + Name: "echo", + Description: "Echo back the input message", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "message": map[string]any{ + "type": "string", + "description": "The message to echo back", + }, + }, + Required: []string{"message"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + message := request.GetString("message", "") + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Echo: %s", message), + }, + }, + }, nil + }) + + // Create HTTP server + httpServer := server.NewStreamableHTTPServer(mcpServer) + + log.Println("Starting HTTP MCP server with sampling support on :8080") + log.Println("Endpoint: http://localhost:8080/mcp") + log.Println("") + log.Println("This server supports sampling over HTTP transport.") + log.Println("Clients must:") + log.Println("1. Initialize with sampling capability") + log.Println("2. Establish SSE connection for bidirectional communication") + log.Println("3. Handle incoming sampling requests from the server") + log.Println("4. Send responses back via HTTP POST") + log.Println("") + log.Println("Available tools:") + log.Println("- ask_llm: Ask the LLM a question (requires sampling)") + log.Println("- echo: Simple echo tool (no sampling required)") + + // Start the server + if err := httpServer.Start(":8080"); err != nil { + log.Fatalf("Server failed to start: %v", err) + } +} \ No newline at end of file diff --git a/server/streamable_http.go b/server/streamable_http.go index f39e24f87..f029674e1 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -120,6 +120,7 @@ type StreamableHTTPServer struct { server *MCPServer sessionTools *sessionToolsStore sessionRequestIDs sync.Map // sessionId --> last requestID(*atomic.Int64) + activeSessions sync.Map // sessionId --> *streamableHttpSession (for sampling responses) httpServer *http.Server mu sync.RWMutex @@ -223,14 +224,32 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, fmt.Sprintf("read request body error: %v", err)) return } - var baseMessage struct { - Method mcp.MCPMethod `json:"method"` + // First, try to parse as a response (sampling responses don't have a method field) + var responseMessage struct { + ID json.RawMessage `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error json.RawMessage `json:"error,omitempty"` + Method mcp.MCPMethod `json:"method,omitempty"` } - if err := json.Unmarshal(rawData, &baseMessage); err != nil { + if err := json.Unmarshal(rawData, &responseMessage); err != nil { s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json") return } - isInitializeRequest := baseMessage.Method == mcp.MethodInitialize + + // Check if this is a sampling response (has result/error but no method) + isSamplingResponse := responseMessage.Method == "" && responseMessage.ID != nil && + (responseMessage.Result != nil || responseMessage.Error != nil) + + isInitializeRequest := responseMessage.Method == mcp.MethodInitialize + + // Handle sampling responses separately + if isSamplingResponse { + if err := s.handleSamplingResponse(w, r, responseMessage); err != nil { + s.logger.Errorf("Failed to handle sampling response: %v", err) + http.Error(w, "Failed to handle sampling response", http.StatusInternalServerError) + } + return + } // Prepare the session for the mcp server // The session is ephemeral. Its life is the same as the request. It's only created @@ -371,6 +390,10 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) return } defer s.server.UnregisterSession(r.Context(), sessionID) + + // Register session for sampling response delivery + s.activeSessions.Store(sessionID, session) + defer s.activeSessions.Delete(sessionID) // Set the client context before handling the message w.Header().Set("Content-Type", "text/event-stream") @@ -399,6 +422,21 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) case <-done: return } + case samplingReq := <-session.samplingRequestChan: + // Send sampling request to client via SSE + jsonrpcRequest := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(samplingReq.requestID), + Request: mcp.Request{ + Method: string(mcp.MethodSamplingCreateMessage), + }, + Params: samplingReq.request.CreateMessageParams, + } + select { + case writeChan <- jsonrpcRequest: + case <-done: + return + } case <-done: return } @@ -487,6 +525,98 @@ func writeSSEEvent(w io.Writer, data any) error { return nil } +// handleSamplingResponse processes incoming sampling responses from clients +func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *http.Request, responseMessage struct { + ID json.RawMessage `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error json.RawMessage `json:"error,omitempty"` + Method mcp.MCPMethod `json:"method,omitempty"` +}) error { + // Get session ID from header + sessionID := r.Header.Get(HeaderKeySessionID) + if sessionID == "" { + http.Error(w, "Missing session ID for sampling response", http.StatusBadRequest) + return fmt.Errorf("missing session ID") + } + + // Validate session + isTerminated, err := s.sessionIdManager.Validate(sessionID) + if err != nil { + http.Error(w, "Invalid session ID", http.StatusBadRequest) + return err + } + if isTerminated { + http.Error(w, "Session terminated", http.StatusNotFound) + return fmt.Errorf("session terminated") + } + + // Parse the request ID + var requestID int64 + if err := json.Unmarshal(responseMessage.ID, &requestID); err != nil { + http.Error(w, "Invalid request ID in sampling response", http.StatusBadRequest) + return err + } + + // Create the sampling response item + response := samplingResponseItem{ + requestID: requestID, + } + + // Parse result or error + if responseMessage.Error != nil { + // Parse error + var jsonrpcError struct { + Code int `json:"code"` + Message string `json:"message"` + } + if err := json.Unmarshal(responseMessage.Error, &jsonrpcError); err != nil { + response.err = fmt.Errorf("failed to parse error: %v", err) + } else { + response.err = fmt.Errorf("sampling error %d: %s", jsonrpcError.Code, jsonrpcError.Message) + } + } else if responseMessage.Result != nil { + // Parse result + var result mcp.CreateMessageResult + if err := json.Unmarshal(responseMessage.Result, &result); err != nil { + response.err = fmt.Errorf("failed to parse sampling result: %v", err) + } else { + response.result = &result + } + } else { + response.err = fmt.Errorf("sampling response has neither result nor error") + } + + // Find the corresponding session and deliver the response + // Note: In a real implementation, we would need to maintain a mapping of sessionID to session instances + // For now, we'll use a simplified approach by broadcasting to all active sessions + // This is a limitation of the current architecture that would need to be addressed in production + s.deliverSamplingResponse(sessionID, response) + + // Acknowledge receipt + w.WriteHeader(http.StatusOK) + return nil +} + +// deliverSamplingResponse delivers a sampling response to the appropriate session +func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, response samplingResponseItem) { + // Look up the active session + if sessionInterface, ok := s.activeSessions.Load(sessionID); ok { + if session, ok := sessionInterface.(*streamableHttpSession); ok { + // Deliver the response via the session's response channel + select { + case session.samplingResponseChan <- response: + s.logger.Infof("Delivered sampling response for session %s, request %d", sessionID, response.requestID) + default: + s.logger.Errorf("Failed to deliver sampling response for session %s, request %d: channel full", sessionID, response.requestID) + } + } else { + s.logger.Errorf("Invalid session type for session %s", sessionID) + } + } else { + s.logger.Errorf("No active session found for session %s", sessionID) + } +} + // writeJSONRPCError writes a JSON-RPC error response with the given error details. func (s *StreamableHTTPServer) writeJSONRPCError( w http.ResponseWriter, @@ -573,6 +703,19 @@ func (s *sessionToolsStore) delete(sessionID string) { delete(s.tools, sessionID) } +// Sampling support types for HTTP transport +type samplingRequestItem struct { + requestID int64 + request mcp.CreateMessageRequest + response chan samplingResponseItem +} + +type samplingResponseItem struct { + requestID int64 + result *mcp.CreateMessageResult + err error +} + // streamableHttpSession is a session for streamable-http transport // When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler. // When in GET handlers(listening), it's a real session, and will be registered in the MCP server. @@ -582,14 +725,23 @@ type streamableHttpSession struct { tools *sessionToolsStore upgradeToSSE atomic.Bool logLevels *sessionLogLevelsStore + + // Sampling support for bidirectional communication + samplingRequestChan chan samplingRequestItem // server -> client sampling requests + samplingResponseChan chan samplingResponseItem // client -> server sampling responses + samplingRequests sync.Map // requestID -> pending sampling request context + requestIDCounter atomic.Int64 // for generating unique request IDs + mu sync.RWMutex // protects sampling channels and requests } func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession { s := &streamableHttpSession{ - sessionID: sessionID, - notificationChannel: make(chan mcp.JSONRPCNotification, 100), - tools: toolStore, - logLevels: levels, + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + tools: toolStore, + logLevels: levels, + samplingRequestChan: make(chan samplingRequestItem, 10), + samplingResponseChan: make(chan samplingResponseItem, 10), } return s } @@ -641,6 +793,66 @@ func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil) +// RequestSampling implements SessionWithSampling interface for HTTP transport +func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Generate unique request ID + requestID := s.requestIDCounter.Add(1) + + // Create response channel for this specific request + responseChan := make(chan samplingResponseItem, 1) + + // Create the sampling request item + samplingRequest := samplingRequestItem{ + requestID: requestID, + request: request, + response: responseChan, + } + + // Store the pending request + s.samplingRequests.Store(requestID, responseChan) + defer s.samplingRequests.Delete(requestID) + + // Send the sampling request via the channel (non-blocking) + select { + case s.samplingRequestChan <- samplingRequest: + // Request queued successfully + case <-ctx.Done(): + return nil, ctx.Err() + default: + return nil, fmt.Errorf("sampling request queue is full - server overloaded") + } + + // Wait for response or context cancellation + select { + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + return response.result, nil + case response := <-s.samplingResponseChan: + // Check if this response matches our request + if response.requestID == requestID { + if response.err != nil { + return nil, response.err + } + return response.result, nil + } + // If it's not our response, put it back and continue waiting + // This is a simplified approach; in production you'd want better routing + select { + case s.samplingResponseChan <- response: + default: + // Channel full, log and continue + } + // Continue waiting for our response + return s.RequestSampling(ctx, request) + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +var _ SessionWithSampling = (*streamableHttpSession)(nil) + // --- session id manager --- type SessionIdManager interface { diff --git a/server/streamable_http_sampling_test.go b/server/streamable_http_sampling_test.go new file mode 100644 index 000000000..8e1ff4b05 --- /dev/null +++ b/server/streamable_http_sampling_test.go @@ -0,0 +1,166 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// TestStreamableHTTPServer_SamplingBasic tests basic sampling session functionality +func TestStreamableHTTPServer_SamplingBasic(t *testing.T) { + // Create MCP server with sampling enabled + mcpServer := NewMCPServer("test-server", "1.0.0") + mcpServer.EnableSampling() + + // Create HTTP server + httpServer := NewStreamableHTTPServer(mcpServer) + testServer := httptest.NewServer(httpServer) + defer testServer.Close() + + // Test session creation and interface implementation + sessionID := "test-session" + session := newStreamableHttpSession(sessionID, httpServer.sessionTools, httpServer.sessionLogLevels) + + // Verify it implements SessionWithSampling + _, ok := interface{}(session).(SessionWithSampling) + if !ok { + t.Error("streamableHttpSession should implement SessionWithSampling") + } + + // Test that sampling request channels are initialized + if session.samplingRequestChan == nil { + t.Error("samplingRequestChan should be initialized") + } + if session.samplingResponseChan == nil { + t.Error("samplingResponseChan should be initialized") + } +} + +// TestStreamableHTTPServer_SamplingErrorHandling tests error scenarios +func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) { + mcpServer := NewMCPServer("test-server", "1.0.0") + mcpServer.EnableSampling() + + httpServer := NewStreamableHTTPServer(mcpServer) + testServer := httptest.NewServer(httpServer) + defer testServer.Close() + + client := &http.Client{} + baseURL := testServer.URL + + // Test sending sampling response without session ID + samplingResponse := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "result": map[string]interface{}{ + "role": "assistant", + "content": map[string]interface{}{ + "type": "text", + "text": "Test response", + }, + }, + } + + responseBody, _ := json.Marshal(samplingResponse) + resp, err := client.Post(baseURL, "application/json", bytes.NewReader(responseBody)) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("Expected 400 Bad Request for missing session ID, got %d", resp.StatusCode) + } +} + +// TestStreamableHTTPServer_SamplingInterface verifies interface implementation +func TestStreamableHTTPServer_SamplingInterface(t *testing.T) { + mcpServer := NewMCPServer("test-server", "1.0.0") + httpServer := NewStreamableHTTPServer(mcpServer) + testServer := httptest.NewServer(httpServer) + defer testServer.Close() + + // Create a session + sessionID := "test-session" + session := newStreamableHttpSession(sessionID, httpServer.sessionTools, httpServer.sessionLogLevels) + + // Verify it implements SessionWithSampling + _, ok := interface{}(session).(SessionWithSampling) + if !ok { + t.Error("streamableHttpSession should implement SessionWithSampling") + } + + // Test RequestSampling with timeout + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + request := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: "Test message", + }, + }, + }, + }, + } + + _, err := session.RequestSampling(ctx, request) + if err == nil { + t.Error("Expected timeout error, but got nil") + } + + if !strings.Contains(err.Error(), "context deadline exceeded") { + t.Errorf("Expected timeout error, got: %v", err) + } +} + +// TestStreamableHTTPServer_SamplingQueueFull tests queue overflow scenarios +func TestStreamableHTTPServer_SamplingQueueFull(t *testing.T) { + sessionID := "test-session" + session := newStreamableHttpSession(sessionID, nil, nil) + + // Fill the sampling request queue + for i := 0; i < cap(session.samplingRequestChan); i++ { + session.samplingRequestChan <- samplingRequestItem{ + requestID: int64(i), + request: mcp.CreateMessageRequest{}, + response: make(chan samplingResponseItem, 1), + } + } + + // Try to add another request (should fail) + ctx := context.Background() + request := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: "Test message", + }, + }, + }, + }, + } + + _, err := session.RequestSampling(ctx, request) + if err == nil { + t.Error("Expected queue full error, but got nil") + } + + if !strings.Contains(err.Error(), "queue is full") { + t.Errorf("Expected queue full error, got: %v", err) + } +} \ No newline at end of file From dd877e0aa244410fbe70c9478bd4faad53d15585 Mon Sep 17 00:00:00 2001 From: andig Date: Sun, 27 Jul 2025 21:53:12 +0200 Subject: [PATCH 03/22] fix: replace time.Sleep with synchronization primitives in tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace flaky time.Sleep calls with proper synchronization using channels and sync.WaitGroup to make tests deterministic and avoid race conditions. Also improves error handling robustness in test servers with proper JSON decoding error checks. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- .../streamable_http_sampling_test.go | 54 +++++++++++++++---- 1 file changed, 44 insertions(+), 10 deletions(-) diff --git a/client/transport/streamable_http_sampling_test.go b/client/transport/streamable_http_sampling_test.go index 0dce2da68..d5b49df29 100644 --- a/client/transport/streamable_http_sampling_test.go +++ b/client/transport/streamable_http_sampling_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" "time" @@ -31,8 +32,10 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { // Set up sampling request handler var handledRequest *JSONRPCRequest + handlerCalled := make(chan struct{}) client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { handledRequest = &request + close(handlerCalled) // Simulate sampling handler response result := map[string]interface{}{ @@ -84,8 +87,13 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { // Directly test request handling client.handleIncomingRequest(ctx, samplingRequest) - // Allow time for async processing - time.Sleep(100 * time.Millisecond) + // Wait for handler to be called + select { + case <-handlerCalled: + // Handler was called + case <-time.After(1 * time.Second): + t.Fatal("Handler was not called within timeout") + } // Verify the request was handled if handledRequest == nil { @@ -99,15 +107,23 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { // TestStreamableHTTP_SamplingErrorHandling tests error handling in sampling requests func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { + var errorHandled sync.WaitGroup + errorHandled.Add(1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { var body map[string]interface{} - json.NewDecoder(r.Body).Decode(&body) + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Logf("Failed to decode body: %v", err) + w.WriteHeader(http.StatusOK) + return + } // Check if this is an error response if errorField, ok := body["error"]; ok { errorMap := errorField.(map[string]interface{}) if code, ok := errorMap["code"].(float64); ok && code == -32603 { + errorHandled.Done() w.WriteHeader(http.StatusOK) return } @@ -146,18 +162,23 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { // This should trigger error handling client.handleIncomingRequest(ctx, samplingRequest) - // Allow some time for async error handling - time.Sleep(100 * time.Millisecond) + // Wait for error to be handled + errorHandled.Wait() } // TestStreamableHTTP_NoSamplingHandler tests behavior when no sampling handler is set func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { var errorReceived bool + errorReceivedChan := make(chan struct{}) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { var body map[string]interface{} - json.NewDecoder(r.Body).Decode(&body) + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Logf("Failed to decode body: %v", err) + w.WriteHeader(http.StatusOK) + return + } // Check if this is an error response with method not found if errorField, ok := body["error"]; ok { @@ -166,6 +187,7 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { if message, ok := errorMap["message"].(string); ok && strings.Contains(message, "no handler configured") { errorReceived = true + close(errorReceivedChan) } } } @@ -199,8 +221,13 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { // This should trigger "method not found" error client.handleIncomingRequest(ctx, samplingRequest) - // Allow some time for async error handling - time.Sleep(100 * time.Millisecond) + // Wait for error to be received + select { + case <-errorReceivedChan: + // Error was received + case <-time.After(1 * time.Second): + t.Fatal("Method not found error was not received within timeout") + } if !errorReceived { t.Error("Expected method not found error, but didn't receive it") @@ -223,8 +250,10 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { // Test SetRequestHandler handlerSet := false + handlerSetChan := make(chan struct{}) client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { handlerSet = true + close(handlerSetChan) return nil, nil }) @@ -236,8 +265,13 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { Method: "test", }) - // Allow async execution - time.Sleep(50 * time.Millisecond) + // Wait for handler to be called + select { + case <-handlerSetChan: + // Handler was called + case <-time.After(1 * time.Second): + t.Fatal("Handler was not called within timeout") + } if !handlerSet { t.Error("Request handler was not properly set or called") From a4ec0b367311bb59de9a2d2348ce9289f2672cd8 Mon Sep 17 00:00:00 2001 From: andig Date: Sun, 27 Jul 2025 21:53:22 +0200 Subject: [PATCH 04/22] fix: improve request detection logic and add nil pointer checks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Make request vs response detection more robust by checking for presence of "method" field instead of relying on nil Result/Error fields - Add nil pointer check in sendResponseToServer function to prevent panics These changes improve reliability against malformed messages and edge cases. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- client/transport/streamable_http.go | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 22d5a038c..454a3318d 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -431,14 +431,16 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl return } - // Check if this is actually a request from the server - // If Result and Error are nil, it might be a request - if message.Result == nil && message.Error == nil { - var request JSONRPCRequest - if err := json.Unmarshal([]byte(data), &request); err == nil { - // This is a request from the server - c.handleIncomingRequest(ctx, request) - return + // Check if this is actually a request from the server by looking for method field + var rawMessage map[string]json.RawMessage + if err := json.Unmarshal([]byte(data), &rawMessage); err == nil { + if _, hasMethod := rawMessage["method"]; hasMethod && !message.ID.IsNil() { + var request JSONRPCRequest + if err := json.Unmarshal([]byte(data), &request); err == nil { + // This is a request from the server + c.handleIncomingRequest(ctx, request) + return + } } } @@ -728,6 +730,11 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON // sendResponseToServer sends a response back to the server via HTTP POST func (c *StreamableHTTP) sendResponseToServer(ctx context.Context, response *JSONRPCResponse) { + if response == nil { + c.logger.Errorf("cannot send nil response to server") + return + } + responseBody, err := json.Marshal(response) if err != nil { c.logger.Errorf("failed to marshal response: %v", err) From 1cae3a9ae919f53d0ff9fe3ab45f1889663e942c Mon Sep 17 00:00:00 2001 From: andig Date: Sun, 27 Jul 2025 22:02:15 +0200 Subject: [PATCH 05/22] fix: correct misleading comment about response delivery MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The comment incorrectly stated that responses are broadcast to all sessions, but the implementation actually delivers responses to the specific session identified by sessionID using the activeSessions registry. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- server/streamable_http.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/server/streamable_http.go b/server/streamable_http.go index f029674e1..4e65b7a65 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -587,9 +587,7 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r * } // Find the corresponding session and deliver the response - // Note: In a real implementation, we would need to maintain a mapping of sessionID to session instances - // For now, we'll use a simplified approach by broadcasting to all active sessions - // This is a limitation of the current architecture that would need to be addressed in production + // The response is delivered to the specific session identified by sessionID s.deliverSamplingResponse(sessionID, response) // Acknowledge receipt From 204b273254a18e7de2c630f33e4ee707b35acac1 Mon Sep 17 00:00:00 2001 From: andig Date: Sun, 27 Jul 2025 22:08:45 +0200 Subject: [PATCH 06/22] fix: implement EnableSampling() to properly declare sampling capability MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, EnableSampling() was a no-op that didn't actually enable the sampling capability in the server's declared capabilities. Changes: - Add Sampling field to mcp.ServerCapabilities struct - Add sampling field to internal serverCapabilities struct - Update EnableSampling() to set the sampling capability flag - Update handleInitialize() to include sampling in capability response - Add test to verify sampling capability is properly declared Now when EnableSampling() is called, the server will properly declare sampling capability during initialization, allowing clients to know that the server supports sending sampling requests. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- mcp/types.go | 2 ++ server/sampling.go | 3 +++ server/sampling_test.go | 39 +++++++++++++++++++++++++++++++++++++++ server/server.go | 5 +++++ 4 files changed, 49 insertions(+) diff --git a/mcp/types.go b/mcp/types.go index 0ef6811fd..a13e3689a 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -472,6 +472,8 @@ type ServerCapabilities struct { // list. ListChanged bool `json:"listChanged,omitempty"` } `json:"resources,omitempty"` + // Present if the server supports sending sampling requests to clients. + Sampling *struct{} `json:"sampling,omitempty"` // Present if the server offers any tools to call. Tools *struct { // Whether this server supports notifications for changes to the tool list. diff --git a/server/sampling.go b/server/sampling.go index ae0812fa5..4423ccf5f 100644 --- a/server/sampling.go +++ b/server/sampling.go @@ -12,6 +12,9 @@ import ( func (s *MCPServer) EnableSampling() { s.capabilitiesMu.Lock() defer s.capabilitiesMu.Unlock() + + enabled := true + s.capabilities.sampling = &enabled } // RequestSampling sends a sampling request to the client. diff --git a/server/sampling_test.go b/server/sampling_test.go index c69ac6cb5..fbecdd70d 100644 --- a/server/sampling_test.go +++ b/server/sampling_test.go @@ -113,3 +113,42 @@ func TestMCPServer_RequestSampling_Success(t *testing.T) { t.Errorf("expected model %q, got %q", "test-model", result.Model) } } + +func TestMCPServer_EnableSampling_SetsCapability(t *testing.T) { + server := NewMCPServer("test", "1.0.0") + + // Verify sampling capability is not set initially + ctx := context.Background() + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: "2025-03-26", + ClientInfo: mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{}, + }, + } + + result, err := server.handleInitialize(ctx, 1, initRequest) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Capabilities.Sampling != nil { + t.Error("sampling capability should not be set before EnableSampling() is called") + } + + // Enable sampling + server.EnableSampling() + + // Verify sampling capability is now set + result, err = server.handleInitialize(ctx, 2, initRequest) + if err != nil { + t.Fatalf("unexpected error after EnableSampling(): %v", err) + } + + if result.Capabilities.Sampling == nil { + t.Error("sampling capability should be set after EnableSampling() is called") + } +} diff --git a/server/server.go b/server/server.go index a98a2132b..7822a125f 100644 --- a/server/server.go +++ b/server/server.go @@ -181,6 +181,7 @@ type serverCapabilities struct { resources *resourceCapabilities prompts *promptCapabilities logging *bool + sampling *bool } // resourceCapabilities defines the supported resource-related features @@ -580,6 +581,10 @@ func (s *MCPServer) handleInitialize( capabilities.Logging = &struct{}{} } + if s.capabilities.sampling != nil && *s.capabilities.sampling { + capabilities.Sampling = &struct{}{} + } + result := mcp.InitializeResult{ ProtocolVersion: s.protocolVersion(request.Params.ProtocolVersion), ServerInfo: mcp.Implementation{ From 5d4fb64ef27958893b6dc3b65ba448e3ec92c4b7 Mon Sep 17 00:00:00 2001 From: andig Date: Sun, 27 Jul 2025 22:39:08 +0200 Subject: [PATCH 07/22] fix: prevent panic from unsafe type assertion in example server MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace unsafe type assertion result.Content.(mcp.TextContent).Text with safe type checking to handle cases where Content might not be a TextContent struct. Now gracefully handles different content types without panicking. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- examples/sampling_http_server/main.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/sampling_http_server/main.go b/examples/sampling_http_server/main.go index b2f84533e..95a2bf29b 100644 --- a/examples/sampling_http_server/main.go +++ b/examples/sampling_http_server/main.go @@ -80,12 +80,20 @@ func main() { }, nil } + // Extract response text safely + var responseText string + if textContent, ok := result.Content.(mcp.TextContent); ok { + responseText = textContent.Text + } else { + responseText = fmt.Sprintf("%v", result.Content) + } + // Return the LLM response return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ Type: "text", - Text: fmt.Sprintf("LLM Response (model: %s): %s", result.Model, result.Content.(mcp.TextContent).Text), + Text: fmt.Sprintf("LLM Response (model: %s): %s", result.Model, responseText), }, }, }, nil From 4e41f25d8106eb342329f8a6e8a930943246492d Mon Sep 17 00:00:00 2001 From: andig Date: Sun, 27 Jul 2025 22:41:51 +0200 Subject: [PATCH 08/22] fix: add missing EnableSampling() call in interface test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The SamplingInterface test was missing the EnableSampling() call, which is necessary to activate sampling features for proper testing. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- server/streamable_http_sampling_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/server/streamable_http_sampling_test.go b/server/streamable_http_sampling_test.go index 8e1ff4b05..9576a39d8 100644 --- a/server/streamable_http_sampling_test.go +++ b/server/streamable_http_sampling_test.go @@ -83,6 +83,7 @@ func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) { // TestStreamableHTTPServer_SamplingInterface verifies interface implementation func TestStreamableHTTPServer_SamplingInterface(t *testing.T) { mcpServer := NewMCPServer("test-server", "1.0.0") + mcpServer.EnableSampling() httpServer := NewStreamableHTTPServer(mcpServer) testServer := httptest.NewServer(httpServer) defer testServer.Close() From 178e234c38062d642d6145a099db7c60903b56c5 Mon Sep 17 00:00:00 2001 From: andig Date: Sun, 27 Jul 2025 22:43:12 +0200 Subject: [PATCH 09/22] fix: expand error test coverage and avoid t.Fatalf MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace single error test with comprehensive table-driven tests - Add test cases for invalid request IDs and malformed results - Replace t.Fatalf with t.Errorf to follow project conventions - Use proper session ID format for valid test scenarios 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- server/streamable_http_sampling_test.go | 88 ++++++++++++++++++++----- 1 file changed, 70 insertions(+), 18 deletions(-) diff --git a/server/streamable_http_sampling_test.go b/server/streamable_http_sampling_test.go index 9576a39d8..bb5e92836 100644 --- a/server/streamable_http_sampling_test.go +++ b/server/streamable_http_sampling_test.go @@ -55,28 +55,80 @@ func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) { client := &http.Client{} baseURL := testServer.URL - // Test sending sampling response without session ID - samplingResponse := map[string]interface{}{ - "jsonrpc": "2.0", - "id": 1, - "result": map[string]interface{}{ - "role": "assistant", - "content": map[string]interface{}{ - "type": "text", - "text": "Test response", + tests := []struct { + name string + sessionID string + body map[string]interface{} + expectedStatus int + }{ + { + name: "missing session ID", + sessionID: "", + body: map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "result": map[string]interface{}{ + "role": "assistant", + "content": map[string]interface{}{ + "type": "text", + "text": "Test response", + }, + }, }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "invalid request ID", + sessionID: "mcp-session-550e8400-e29b-41d4-a716-446655440000", + body: map[string]interface{}{ + "jsonrpc": "2.0", + "id": "invalid-id", + "result": map[string]interface{}{ + "role": "assistant", + "content": map[string]interface{}{ + "type": "text", + "text": "Test response", + }, + }, + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "malformed result", + sessionID: "mcp-session-550e8400-e29b-41d4-a716-446655440000", + body: map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "result": "invalid-result", + }, + expectedStatus: http.StatusOK, // Still returns OK but logs error internally }, } - responseBody, _ := json.Marshal(samplingResponse) - resp, err := client.Post(baseURL, "application/json", bytes.NewReader(responseBody)) - if err != nil { - t.Fatalf("Failed to send request: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusBadRequest { - t.Errorf("Expected 400 Bad Request for missing session ID, got %d", resp.StatusCode) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload, _ := json.Marshal(tt.body) + req, err := http.NewRequest("POST", baseURL, bytes.NewReader(payload)) + if err != nil { + t.Errorf("Failed to create request: %v", err) + return + } + req.Header.Set("Content-Type", "application/json") + if tt.sessionID != "" { + req.Header.Set("Mcp-Session-Id", tt.sessionID) + } + + resp, err := client.Do(req) + if err != nil { + t.Errorf("Failed to send request: %v", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.StatusCode) + } + }) } } From 27322ca8619cd2b43c7336e1e408cb68430e0b8d Mon Sep 17 00:00:00 2001 From: andig Date: Sun, 27 Jul 2025 22:44:28 +0200 Subject: [PATCH 10/22] fix: eliminate recursive response handling and improve routing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove recursive call in RequestSampling that could cause stack overflow - Remove problematic response re-queuing to global channel - Update deliverSamplingResponse to route responses directly to dedicated request channels via samplingRequests map lookup - This prevents ordering issues and ensures responses reach the correct waiting request 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- server/streamable_http.go | 37 ++++++++++++++----------------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/server/streamable_http.go b/server/streamable_http.go index 4e65b7a65..44feca23f 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -600,12 +600,20 @@ func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, respons // Look up the active session if sessionInterface, ok := s.activeSessions.Load(sessionID); ok { if session, ok := sessionInterface.(*streamableHttpSession); ok { - // Deliver the response via the session's response channel - select { - case session.samplingResponseChan <- response: - s.logger.Infof("Delivered sampling response for session %s, request %d", sessionID, response.requestID) - default: - s.logger.Errorf("Failed to deliver sampling response for session %s, request %d: channel full", sessionID, response.requestID) + // Look up the dedicated response channel for this specific request + if responseChannelInterface, exists := session.samplingRequests.Load(response.requestID); exists { + if responseChan, ok := responseChannelInterface.(chan samplingResponseItem); ok { + select { + case responseChan <- response: + s.logger.Infof("Delivered sampling response for session %s, request %d", sessionID, response.requestID) + default: + s.logger.Errorf("Failed to deliver sampling response for session %s, request %d: channel full", sessionID, response.requestID) + } + } else { + s.logger.Errorf("Invalid response channel type for session %s, request %d", sessionID, response.requestID) + } + } else { + s.logger.Errorf("No pending request found for session %s, request %d", sessionID, response.requestID) } } else { s.logger.Errorf("Invalid session type for session %s", sessionID) @@ -827,23 +835,6 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp return nil, response.err } return response.result, nil - case response := <-s.samplingResponseChan: - // Check if this response matches our request - if response.requestID == requestID { - if response.err != nil { - return nil, response.err - } - return response.result, nil - } - // If it's not our response, put it back and continue waiting - // This is a simplified approach; in production you'd want better routing - select { - case s.samplingResponseChan <- response: - default: - // Channel full, log and continue - } - // Continue waiting for our response - return s.RequestSampling(ctx, request) case <-ctx.Done(): return nil, ctx.Err() } From d0259751878dd033ff5f72d2b3c9e19bdc9306ae Mon Sep 17 00:00:00 2001 From: andig Date: Sun, 27 Jul 2025 22:47:05 +0200 Subject: [PATCH 11/22] fix: improve sampling response delivery robustness MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Modified deliverSamplingResponse to return error instead of just logging - Added proper error handling for disconnected sessions - Improved error messages for debugging - Updated test expectations to match new error behavior 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- server/streamable_http.go | 58 +++++++++++++++---------- server/streamable_http_sampling_test.go | 2 +- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/server/streamable_http.go b/server/streamable_http.go index 44feca23f..e199758a9 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -588,7 +588,11 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r * // Find the corresponding session and deliver the response // The response is delivered to the specific session identified by sessionID - s.deliverSamplingResponse(sessionID, response) + if err := s.deliverSamplingResponse(sessionID, response); err != nil { + s.logger.Errorf("Failed to deliver sampling response: %v", err) + http.Error(w, "Failed to deliver response", http.StatusInternalServerError) + return err + } // Acknowledge receipt w.WriteHeader(http.StatusOK) @@ -596,30 +600,36 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r * } // deliverSamplingResponse delivers a sampling response to the appropriate session -func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, response samplingResponseItem) { +func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, response samplingResponseItem) error { // Look up the active session - if sessionInterface, ok := s.activeSessions.Load(sessionID); ok { - if session, ok := sessionInterface.(*streamableHttpSession); ok { - // Look up the dedicated response channel for this specific request - if responseChannelInterface, exists := session.samplingRequests.Load(response.requestID); exists { - if responseChan, ok := responseChannelInterface.(chan samplingResponseItem); ok { - select { - case responseChan <- response: - s.logger.Infof("Delivered sampling response for session %s, request %d", sessionID, response.requestID) - default: - s.logger.Errorf("Failed to deliver sampling response for session %s, request %d: channel full", sessionID, response.requestID) - } - } else { - s.logger.Errorf("Invalid response channel type for session %s, request %d", sessionID, response.requestID) - } - } else { - s.logger.Errorf("No pending request found for session %s, request %d", sessionID, response.requestID) - } - } else { - s.logger.Errorf("Invalid session type for session %s", sessionID) - } - } else { - s.logger.Errorf("No active session found for session %s", sessionID) + sessionInterface, ok := s.activeSessions.Load(sessionID) + if !ok { + return fmt.Errorf("no active session found for session %s", sessionID) + } + + session, ok := sessionInterface.(*streamableHttpSession) + if !ok { + return fmt.Errorf("invalid session type for session %s", sessionID) + } + + // Look up the dedicated response channel for this specific request + responseChannelInterface, exists := session.samplingRequests.Load(response.requestID) + if !exists { + return fmt.Errorf("no pending request found for session %s, request %d", sessionID, response.requestID) + } + + responseChan, ok := responseChannelInterface.(chan samplingResponseItem) + if !ok { + return fmt.Errorf("invalid response channel type for session %s, request %d", sessionID, response.requestID) + } + + // Attempt to deliver the response with timeout to prevent indefinite blocking + select { + case responseChan <- response: + s.logger.Infof("Delivered sampling response for session %s, request %d", sessionID, response.requestID) + return nil + default: + return fmt.Errorf("failed to deliver sampling response for session %s, request %d: channel full or blocked", sessionID, response.requestID) } } diff --git a/server/streamable_http_sampling_test.go b/server/streamable_http_sampling_test.go index bb5e92836..287fc0643 100644 --- a/server/streamable_http_sampling_test.go +++ b/server/streamable_http_sampling_test.go @@ -101,7 +101,7 @@ func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) { "id": 1, "result": "invalid-result", }, - expectedStatus: http.StatusOK, // Still returns OK but logs error internally + expectedStatus: http.StatusInternalServerError, // Now correctly returns 500 due to no active session }, } From a9b20be14a6d3aeb29d47dd03dc9fa7ff14bf6cf Mon Sep 17 00:00:00 2001 From: andig Date: Mon, 28 Jul 2025 08:40:21 +0200 Subject: [PATCH 12/22] fix: add graceful shutdown handling to sampling client MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add signal handling for SIGINT and SIGTERM - Move defer statement after error checking - Improve shutdown error handling 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- examples/sampling_client/main.go | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/examples/sampling_client/main.go b/examples/sampling_client/main.go index 67b3840b0..8d53ff18a 100644 --- a/examples/sampling_client/main.go +++ b/examples/sampling_client/main.go @@ -5,6 +5,8 @@ import ( "fmt" "log" "os" + "os/signal" + "syscall" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/client/transport" @@ -89,7 +91,25 @@ func main() { if err := mcpClient.Start(ctx); err != nil { log.Fatalf("Failed to start client: %v", err) } - defer mcpClient.Close() + + // Setup graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Create a context that cancels on signal + ctx, cancel := context.WithCancel(ctx) + go func() { + <-sigChan + log.Println("Received shutdown signal, closing client...") + cancel() + }() + + // Move defer after error checking + defer func() { + if err := mcpClient.Close(); err != nil { + log.Printf("Error closing client: %v", err) + } + }() // Initialize the connection initResult, err := mcpClient.Initialize(ctx, mcp.InitializeRequest{ From b7afbb946a487b48f82f5b7afece1c7ff69c4965 Mon Sep 17 00:00:00 2001 From: andig Date: Mon, 28 Jul 2025 08:41:03 +0200 Subject: [PATCH 13/22] fix: improve context handling in streamable HTTP transport MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add timeout context for SSE response processing (30s default) - Add timeout for individual connection attempts in listenForever (10s) - Use context-aware sleep in retry logic - Ensure async goroutines properly respect context cancellation 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- client/transport/streamable_http.go | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 454a3318d..9e9b5fc58 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -400,12 +400,19 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl // Create a channel for this specific request responseChan := make(chan *JSONRPCResponse, 1) + // Add timeout context for request processing if not already set + if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > 30*time.Second { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, 30*time.Second) + defer cancel() + } + ctx, cancel := context.WithCancel(ctx) defer cancel() // Start a goroutine to process the SSE stream go func() { - // only close responseChan after readingSSE() + // Ensure this goroutine respects the context defer close(responseChan) c.readSSE(ctx, reader, func(event, data string) { @@ -588,7 +595,11 @@ func (c *StreamableHTTP) IsOAuthEnabled() bool { func (c *StreamableHTTP) listenForever(ctx context.Context) { c.logger.Infof("listening to server forever") for { - err := c.createGETConnectionToServer(ctx) + // Add timeout for individual connection attempts + connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + err := c.createGETConnectionToServer(connectCtx) + cancel() + if errors.Is(err, ErrGetMethodNotAllowed) { // server does not support listening c.logger.Errorf("server does not support listening") @@ -604,7 +615,13 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) { if err != nil { c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err) } - time.Sleep(retryInterval) + + // Use context-aware sleep + select { + case <-time.After(retryInterval): + case <-ctx.Done(): + return + } } } From a6642894b9203ffd3f594472522ef253be67e4ca Mon Sep 17 00:00:00 2001 From: andig Date: Mon, 28 Jul 2025 08:45:17 +0200 Subject: [PATCH 14/22] fix: improve error message for notification channel queue full condition MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Make error message more descriptive and actionable - Provide clearer debugging information about why the channel is blocked 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- server/errors.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/errors.go b/server/errors.go index ecbe91e5f..3864f36f7 100644 --- a/server/errors.go +++ b/server/errors.go @@ -21,7 +21,7 @@ var ( // Notification-related errors ErrNotificationNotInitialized = errors.New("notification channel not initialized") - ErrNotificationChannelBlocked = errors.New("notification channel full or blocked") + ErrNotificationChannelBlocked = errors.New("notification channel queue is full - client may not be processing notifications fast enough") ) // ErrDynamicPathConfig is returned when attempting to use static path methods with dynamic path configuration From bac5dadc42df34592bd404f82e2a5e79a8cd751a Mon Sep 17 00:00:00 2001 From: andig Date: Mon, 28 Jul 2025 08:46:22 +0200 Subject: [PATCH 15/22] refactor: rename struct variable for clarity in message parsing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename 'baseMessage' to 'jsonMessage' for more neutral naming - Improves code readability and follows consistent naming conventions 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- server/streamable_http.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/server/streamable_http.go b/server/streamable_http.go index e199758a9..fa71b5105 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -225,26 +225,26 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request return } // First, try to parse as a response (sampling responses don't have a method field) - var responseMessage struct { + var jsonMessage struct { ID json.RawMessage `json:"id"` Result json.RawMessage `json:"result,omitempty"` Error json.RawMessage `json:"error,omitempty"` Method mcp.MCPMethod `json:"method,omitempty"` } - if err := json.Unmarshal(rawData, &responseMessage); err != nil { + if err := json.Unmarshal(rawData, &jsonMessage); err != nil { s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json") return } // Check if this is a sampling response (has result/error but no method) - isSamplingResponse := responseMessage.Method == "" && responseMessage.ID != nil && - (responseMessage.Result != nil || responseMessage.Error != nil) + isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil && + (jsonMessage.Result != nil || jsonMessage.Error != nil) - isInitializeRequest := responseMessage.Method == mcp.MethodInitialize + isInitializeRequest := jsonMessage.Method == mcp.MethodInitialize // Handle sampling responses separately if isSamplingResponse { - if err := s.handleSamplingResponse(w, r, responseMessage); err != nil { + if err := s.handleSamplingResponse(w, r, jsonMessage); err != nil { s.logger.Errorf("Failed to handle sampling response: %v", err) http.Error(w, "Failed to handle sampling response", http.StatusInternalServerError) } From e69716dbd0e32bd4788c61c8544d159c2e2326a2 Mon Sep 17 00:00:00 2001 From: andig Date: Mon, 28 Jul 2025 09:01:20 +0200 Subject: [PATCH 16/22] test: add concurrent sampling requests test with response association MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add test verifying that concurrent sampling requests are handled correctly when the second request completes faster than the first. The test ensures: - Responses are correctly associated with their request IDs - Server processes requests concurrently without blocking - Completion order follows actual processing time, not submission order 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- .../streamable_http_sampling_test.go | 217 ++++++++++++++++++ 1 file changed, 217 insertions(+) diff --git a/client/transport/streamable_http_sampling_test.go b/client/transport/streamable_http_sampling_test.go index d5b49df29..6702b32ce 100644 --- a/client/transport/streamable_http_sampling_test.go +++ b/client/transport/streamable_http_sampling_test.go @@ -276,4 +276,221 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { if !handlerSet { t.Error("Request handler was not properly set or called") } +} + +// TestStreamableHTTP_ConcurrentSamplingRequests tests concurrent sampling requests +// where the second request completes faster than the first request +func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { + var receivedResponses []map[string]interface{} + var responseMutex sync.Mutex + responseComplete := make(chan struct{}, 2) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var body map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Logf("Failed to decode body: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + + // Check if this is a response from client (not a request) + if _, ok := body["result"]; ok { + responseMutex.Lock() + receivedResponses = append(receivedResponses, body) + responseMutex.Unlock() + responseComplete <- struct{}{} + } + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client, err := NewStreamableHTTP(server.URL) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Track which requests have been received and their completion order + var requestOrder []int + var orderMutex sync.Mutex + + // Set up request handler that simulates different processing times + client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { + // Extract request ID to determine processing time + requestIDValue := request.ID.Value() + + var delay time.Duration + var responseText string + var requestNum int + + // First request (ID 1) takes longer, second request (ID 2) completes faster + if requestIDValue == int64(1) { + delay = 100 * time.Millisecond + responseText = "Response from slow request 1" + requestNum = 1 + } else if requestIDValue == int64(2) { + delay = 10 * time.Millisecond + responseText = "Response from fast request 2" + requestNum = 2 + } else { + t.Errorf("Unexpected request ID: %v", requestIDValue) + return nil, fmt.Errorf("unexpected request ID") + } + + // Simulate processing time + time.Sleep(delay) + + // Record completion order + orderMutex.Lock() + requestOrder = append(requestOrder, requestNum) + orderMutex.Unlock() + + // Return response with correct request ID + result := map[string]interface{}{ + "role": "assistant", + "content": map[string]interface{}{ + "type": "text", + "text": responseText, + }, + "model": "test-model", + "stopReason": "stop_sequence", + } + + resultBytes, _ := json.Marshal(result) + + return &JSONRPCResponse{ + JSONRPC: "2.0", + ID: request.ID, + Result: resultBytes, + }, nil + }) + + // Start the client + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err = client.Start(ctx) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Create two sampling requests with different IDs + request1 := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(int64(1)), + Method: string(mcp.MethodSamplingCreateMessage), + Params: map[string]interface{}{ + "messages": []map[string]interface{}{ + { + "role": "user", + "content": map[string]interface{}{ + "type": "text", + "text": "Slow request 1", + }, + }, + }, + }, + } + + request2 := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(int64(2)), + Method: string(mcp.MethodSamplingCreateMessage), + Params: map[string]interface{}{ + "messages": []map[string]interface{}{ + { + "role": "user", + "content": map[string]interface{}{ + "type": "text", + "text": "Fast request 2", + }, + }, + }, + }, + } + + // Send both requests concurrently + go client.handleIncomingRequest(ctx, request1) + go client.handleIncomingRequest(ctx, request2) + + // Wait for both responses to complete + for i := 0; i < 2; i++ { + select { + case <-responseComplete: + // Response received + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for response") + } + } + + // Verify completion order: request 2 should complete first + orderMutex.Lock() + defer orderMutex.Unlock() + + if len(requestOrder) != 2 { + t.Fatalf("Expected 2 completed requests, got %d", len(requestOrder)) + } + + if requestOrder[0] != 2 { + t.Errorf("Expected request 2 to complete first, but request %d completed first", requestOrder[0]) + } + + if requestOrder[1] != 1 { + t.Errorf("Expected request 1 to complete second, but request %d completed second", requestOrder[1]) + } + + // Verify responses are correctly associated + responseMutex.Lock() + defer responseMutex.Unlock() + + if len(receivedResponses) != 2 { + t.Fatalf("Expected 2 responses, got %d", len(receivedResponses)) + } + + // Find responses by ID + var response1, response2 map[string]interface{} + for _, resp := range receivedResponses { + if id, ok := resp["id"]; ok { + switch id { + case int64(1), float64(1): + response1 = resp + case int64(2), float64(2): + response2 = resp + } + } + } + + if response1 == nil { + t.Error("Response for request 1 not found") + } + if response2 == nil { + t.Error("Response for request 2 not found") + } + + // Verify each response contains the correct content + if response1 != nil { + if result, ok := response1["result"].(map[string]interface{}); ok { + if content, ok := result["content"].(map[string]interface{}); ok { + if text, ok := content["text"].(string); ok { + if !strings.Contains(text, "slow request 1") { + t.Errorf("Response 1 should contain 'slow request 1', got: %s", text) + } + } + } + } + } + + if response2 != nil { + if result, ok := response2["result"].(map[string]interface{}); ok { + if content, ok := result["content"].(map[string]interface{}); ok { + if text, ok := content["text"].(string); ok { + if !strings.Contains(text, "fast request 2") { + t.Errorf("Response 2 should contain 'fast request 2', got: %s", text) + } + } + } + } + } } \ No newline at end of file From e28a85922040a2b4a5af3b4c217ee4f7c434e13a Mon Sep 17 00:00:00 2001 From: andig Date: Mon, 28 Jul 2025 09:14:18 +0200 Subject: [PATCH 17/22] fix: improve context handling in async goroutine MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Create new context with 30-second timeout for request handling to prevent long-running handlers from blocking indefinitely. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- client/transport/streamable_http.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 9e9b5fc58..0f889cfec 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -695,7 +695,11 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON // Handle the request in a goroutine to avoid blocking the SSE reader go func() { - response, err := handler(ctx, request) + // Create a new context with timeout for request handling + requestCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + response, err := handler(requestCtx, request) if err != nil { c.logger.Errorf("error handling request %s: %v", request.Method, err) @@ -735,12 +739,12 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON Message: errorMessage, }, } - c.sendResponseToServer(ctx, errorResponse) + c.sendResponseToServer(requestCtx, errorResponse) return } if response != nil { - c.sendResponseToServer(ctx, response) + c.sendResponseToServer(requestCtx, response) } }() } From 4fa52953f9873146ad7e500b4ba9587dc30d58e8 Mon Sep 17 00:00:00 2001 From: andig Date: Mon, 28 Jul 2025 09:23:08 +0200 Subject: [PATCH 18/22] refactor: replace interface{} with any throughout codebase MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace all occurrences of interface{} with the modern Go any type alias for improved readability and consistency with current Go best practices. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- .../streamable_http_sampling_test.go | 54 +++++++++---------- examples/sampling_client/main.go | 2 +- examples/sampling_server/main.go | 4 +- server/session_test.go | 14 ++--- server/streamable_http_sampling_test.go | 20 +++---- server/streamable_http_test.go | 4 +- 6 files changed, 49 insertions(+), 49 deletions(-) diff --git a/client/transport/streamable_http_sampling_test.go b/client/transport/streamable_http_sampling_test.go index 6702b32ce..edba61eac 100644 --- a/client/transport/streamable_http_sampling_test.go +++ b/client/transport/streamable_http_sampling_test.go @@ -38,9 +38,9 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { close(handlerCalled) // Simulate sampling handler response - result := map[string]interface{}{ + result := map[string]any{ "role": "assistant", - "content": map[string]interface{}{ + "content": map[string]any{ "type": "text", "text": "Hello! How can I help you today?", }, @@ -71,11 +71,11 @@ func TestStreamableHTTP_SamplingFlow(t *testing.T) { JSONRPC: "2.0", ID: mcp.NewRequestId(1), Method: string(mcp.MethodSamplingCreateMessage), - Params: map[string]interface{}{ - "messages": []map[string]interface{}{ + Params: map[string]any{ + "messages": []map[string]any{ { "role": "user", - "content": map[string]interface{}{ + "content": map[string]any{ "type": "text", "text": "Hello, world!", }, @@ -112,7 +112,7 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { - var body map[string]interface{} + var body map[string]any if err := json.NewDecoder(r.Body).Decode(&body); err != nil { t.Logf("Failed to decode body: %v", err) w.WriteHeader(http.StatusOK) @@ -121,7 +121,7 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { // Check if this is an error response if errorField, ok := body["error"]; ok { - errorMap := errorField.(map[string]interface{}) + errorMap := errorField.(map[string]any) if code, ok := errorMap["code"].(float64); ok && code == -32603 { errorHandled.Done() w.WriteHeader(http.StatusOK) @@ -156,7 +156,7 @@ func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { JSONRPC: "2.0", ID: mcp.NewRequestId(1), Method: string(mcp.MethodSamplingCreateMessage), - Params: map[string]interface{}{}, + Params: map[string]any{}, } // This should trigger error handling @@ -173,7 +173,7 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { - var body map[string]interface{} + var body map[string]any if err := json.NewDecoder(r.Body).Decode(&body); err != nil { t.Logf("Failed to decode body: %v", err) w.WriteHeader(http.StatusOK) @@ -182,7 +182,7 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { // Check if this is an error response with method not found if errorField, ok := body["error"]; ok { - errorMap := errorField.(map[string]interface{}) + errorMap := errorField.(map[string]any) if code, ok := errorMap["code"].(float64); ok && code == -32601 { if message, ok := errorMap["message"].(string); ok && strings.Contains(message, "no handler configured") { @@ -215,7 +215,7 @@ func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { JSONRPC: "2.0", ID: mcp.NewRequestId(1), Method: string(mcp.MethodSamplingCreateMessage), - Params: map[string]interface{}{}, + Params: map[string]any{}, } // This should trigger "method not found" error @@ -243,7 +243,7 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { defer client.Close() // Verify it implements BidirectionalInterface - _, ok := interface{}(client).(BidirectionalInterface) + _, ok := any(client).(BidirectionalInterface) if !ok { t.Error("StreamableHTTP should implement BidirectionalInterface") } @@ -281,13 +281,13 @@ func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { // TestStreamableHTTP_ConcurrentSamplingRequests tests concurrent sampling requests // where the second request completes faster than the first request func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { - var receivedResponses []map[string]interface{} + var receivedResponses []map[string]any var responseMutex sync.Mutex responseComplete := make(chan struct{}, 2) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { - var body map[string]interface{} + var body map[string]any if err := json.NewDecoder(r.Body).Decode(&body); err != nil { t.Logf("Failed to decode body: %v", err) w.WriteHeader(http.StatusBadRequest) @@ -348,9 +348,9 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { orderMutex.Unlock() // Return response with correct request ID - result := map[string]interface{}{ + result := map[string]any{ "role": "assistant", - "content": map[string]interface{}{ + "content": map[string]any{ "type": "text", "text": responseText, }, @@ -381,11 +381,11 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { JSONRPC: "2.0", ID: mcp.NewRequestId(int64(1)), Method: string(mcp.MethodSamplingCreateMessage), - Params: map[string]interface{}{ - "messages": []map[string]interface{}{ + Params: map[string]any{ + "messages": []map[string]any{ { "role": "user", - "content": map[string]interface{}{ + "content": map[string]any{ "type": "text", "text": "Slow request 1", }, @@ -398,11 +398,11 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { JSONRPC: "2.0", ID: mcp.NewRequestId(int64(2)), Method: string(mcp.MethodSamplingCreateMessage), - Params: map[string]interface{}{ - "messages": []map[string]interface{}{ + Params: map[string]any{ + "messages": []map[string]any{ { "role": "user", - "content": map[string]interface{}{ + "content": map[string]any{ "type": "text", "text": "Fast request 2", }, @@ -450,7 +450,7 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { } // Find responses by ID - var response1, response2 map[string]interface{} + var response1, response2 map[string]any for _, resp := range receivedResponses { if id, ok := resp["id"]; ok { switch id { @@ -471,8 +471,8 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { // Verify each response contains the correct content if response1 != nil { - if result, ok := response1["result"].(map[string]interface{}); ok { - if content, ok := result["content"].(map[string]interface{}); ok { + if result, ok := response1["result"].(map[string]any); ok { + if content, ok := result["content"].(map[string]any); ok { if text, ok := content["text"].(string); ok { if !strings.Contains(text, "slow request 1") { t.Errorf("Response 1 should contain 'slow request 1', got: %s", text) @@ -483,8 +483,8 @@ func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { } if response2 != nil { - if result, ok := response2["result"].(map[string]interface{}); ok { - if content, ok := result["content"].(map[string]interface{}); ok { + if result, ok := response2["result"].(map[string]any); ok { + if content, ok := result["content"].(map[string]any); ok { if text, ok := content["text"].(string); ok { if !strings.Contains(text, "fast request 2") { t.Errorf("Response 2 should contain 'fast request 2', got: %s", text) diff --git a/examples/sampling_client/main.go b/examples/sampling_client/main.go index 8d53ff18a..093b59817 100644 --- a/examples/sampling_client/main.go +++ b/examples/sampling_client/main.go @@ -30,7 +30,7 @@ func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.Cre switch content := userMessage.Content.(type) { case mcp.TextContent: userText = content.Text - case map[string]interface{}: + case map[string]any: // Handle case where content is unmarshaled as a map if text, ok := content["text"].(string); ok { userText = text diff --git a/examples/sampling_server/main.go b/examples/sampling_server/main.go index c3bcf4902..ea887c588 100644 --- a/examples/sampling_server/main.go +++ b/examples/sampling_server/main.go @@ -127,11 +127,11 @@ func main() { } // Helper function to extract text from content -func getTextFromContent(content interface{}) string { +func getTextFromContent(content any) string { switch c := content.(type) { case mcp.TextContent: return c.Text - case map[string]interface{}: + case map[string]any: // Handle JSON unmarshaled content if text, ok := c["text"].(string); ok { return text diff --git a/server/session_test.go b/server/session_test.go index 9bd8bc9fa..04334487b 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -1471,8 +1471,8 @@ func TestMCPServer_LoggingNotificationFormat(t *testing.T) { // Send log messages with different formats testCases := []struct { name string - data interface{} - expected interface{} + data any + expected any }{ { name: "string data", @@ -1481,8 +1481,8 @@ func TestMCPServer_LoggingNotificationFormat(t *testing.T) { }, { name: "structured data", - data: map[string]interface{}{"key": "value", "num": 42}, - expected: map[string]interface{}{"key": "value", "num": 42}, + data: map[string]any{"key": "value", "num": 42}, + expected: map[string]any{"key": "value", "num": 42}, }, { name: "error data", @@ -1514,9 +1514,9 @@ func TestMCPServer_LoggingNotificationFormat(t *testing.T) { switch expected := tc.expected.(type) { case string: assert.Equal(t, expected, dataField) - case map[string]interface{}: - assert.IsType(t, map[string]interface{}{}, dataField) - dataMap := dataField.(map[string]interface{}) + case map[string]any: + assert.IsType(t, map[string]any{}, dataField) + dataMap := dataField.(map[string]any) for k, v := range expected { assert.Equal(t, v, dataMap[k]) } diff --git a/server/streamable_http_sampling_test.go b/server/streamable_http_sampling_test.go index 287fc0643..7c95bdefa 100644 --- a/server/streamable_http_sampling_test.go +++ b/server/streamable_http_sampling_test.go @@ -29,7 +29,7 @@ func TestStreamableHTTPServer_SamplingBasic(t *testing.T) { session := newStreamableHttpSession(sessionID, httpServer.sessionTools, httpServer.sessionLogLevels) // Verify it implements SessionWithSampling - _, ok := interface{}(session).(SessionWithSampling) + _, ok := any(session).(SessionWithSampling) if !ok { t.Error("streamableHttpSession should implement SessionWithSampling") } @@ -58,18 +58,18 @@ func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) { tests := []struct { name string sessionID string - body map[string]interface{} + body map[string]any expectedStatus int }{ { name: "missing session ID", sessionID: "", - body: map[string]interface{}{ + body: map[string]any{ "jsonrpc": "2.0", "id": 1, - "result": map[string]interface{}{ + "result": map[string]any{ "role": "assistant", - "content": map[string]interface{}{ + "content": map[string]any{ "type": "text", "text": "Test response", }, @@ -80,12 +80,12 @@ func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) { { name: "invalid request ID", sessionID: "mcp-session-550e8400-e29b-41d4-a716-446655440000", - body: map[string]interface{}{ + body: map[string]any{ "jsonrpc": "2.0", "id": "invalid-id", - "result": map[string]interface{}{ + "result": map[string]any{ "role": "assistant", - "content": map[string]interface{}{ + "content": map[string]any{ "type": "text", "text": "Test response", }, @@ -96,7 +96,7 @@ func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) { { name: "malformed result", sessionID: "mcp-session-550e8400-e29b-41d4-a716-446655440000", - body: map[string]interface{}{ + body: map[string]any{ "jsonrpc": "2.0", "id": 1, "result": "invalid-result", @@ -145,7 +145,7 @@ func TestStreamableHTTPServer_SamplingInterface(t *testing.T) { session := newStreamableHttpSession(sessionID, httpServer.sessionTools, httpServer.sessionLogLevels) // Verify it implements SessionWithSampling - _, ok := interface{}(session).(SessionWithSampling) + _, ok := any(session).(SessionWithSampling) if !ok { t.Error("streamableHttpSession should implement SessionWithSampling") } diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 6f4a6edad..105fd18ce 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -207,7 +207,7 @@ func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { Notification: mcp.Notification{ Method: "testNotification", Params: mcp.NotificationParams{ - AdditionalFields: map[string]interface{}{"param1": "value1"}, + AdditionalFields: map[string]any{"param1": "value1"}, }, }, } @@ -395,7 +395,7 @@ func TestStreamableHTTP_POST_SendAndReceive_stateless(t *testing.T) { Notification: mcp.Notification{ Method: "testNotification", Params: mcp.NotificationParams{ - AdditionalFields: map[string]interface{}{"param1": "value1"}, + AdditionalFields: map[string]any{"param1": "value1"}, }, }, } From 3852e2d7445b20a81ed715e08f92389675fc3e12 Mon Sep 17 00:00:00 2001 From: andig Date: Mon, 28 Jul 2025 09:31:40 +0200 Subject: [PATCH 19/22] fix: improve context handling in async goroutine for StreamableHTTP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Create timeout context from parent context instead of context.Background() to ensure request handlers respect parent context cancellation. Addresses review comment about context handling in async goroutine. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- client/transport/streamable_http.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 0f889cfec..f8965553a 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -695,8 +695,8 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON // Handle the request in a goroutine to avoid blocking the SSE reader go func() { - // Create a new context with timeout for request handling - requestCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + // Create a new context with timeout for request handling, respecting parent context + requestCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() response, err := handler(requestCtx, request) From 83883edc1042bfc305f6643cc34c05d69d3922cc Mon Sep 17 00:00:00 2001 From: andig Date: Mon, 28 Jul 2025 09:31:50 +0200 Subject: [PATCH 20/22] refactor: remove unused samplingResponseChan field from session struct MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The samplingResponseChan field was declared but never used in the streamableHttpSession struct. Remove it and update tests accordingly. Addresses review comment about unused fields in session struct. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- server/streamable_http.go | 2 -- server/streamable_http_sampling_test.go | 3 --- 2 files changed, 5 deletions(-) diff --git a/server/streamable_http.go b/server/streamable_http.go index fa71b5105..d84d8ace1 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -744,7 +744,6 @@ type streamableHttpSession struct { // Sampling support for bidirectional communication samplingRequestChan chan samplingRequestItem // server -> client sampling requests - samplingResponseChan chan samplingResponseItem // client -> server sampling responses samplingRequests sync.Map // requestID -> pending sampling request context requestIDCounter atomic.Int64 // for generating unique request IDs mu sync.RWMutex // protects sampling channels and requests @@ -757,7 +756,6 @@ func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, le tools: toolStore, logLevels: levels, samplingRequestChan: make(chan samplingRequestItem, 10), - samplingResponseChan: make(chan samplingResponseItem, 10), } return s } diff --git a/server/streamable_http_sampling_test.go b/server/streamable_http_sampling_test.go index 7c95bdefa..4cf57838c 100644 --- a/server/streamable_http_sampling_test.go +++ b/server/streamable_http_sampling_test.go @@ -38,9 +38,6 @@ func TestStreamableHTTPServer_SamplingBasic(t *testing.T) { if session.samplingRequestChan == nil { t.Error("samplingRequestChan should be initialized") } - if session.samplingResponseChan == nil { - t.Error("samplingResponseChan should be initialized") - } } // TestStreamableHTTPServer_SamplingErrorHandling tests error scenarios From 9ea4a10bfb64e38f071731114b17f707464a0b67 Mon Sep 17 00:00:00 2001 From: andig Date: Mon, 28 Jul 2025 09:32:00 +0200 Subject: [PATCH 21/22] feat: add graceful shutdown handling to sampling HTTP client example MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add signal handling for SIGINT and SIGTERM to allow graceful shutdown of the sampling HTTP client example. This prevents indefinite blocking and provides better production-ready behavior. Addresses review comment about adding graceful shutdown handling. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- examples/sampling_http_client/main.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/examples/sampling_http_client/main.go b/examples/sampling_http_client/main.go index 4f365ad1b..98817e6f8 100644 --- a/examples/sampling_http_client/main.go +++ b/examples/sampling_http_client/main.go @@ -4,6 +4,9 @@ import ( "context" "fmt" "log" + "os" + "os/signal" + "syscall" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/client/transport" @@ -101,8 +104,13 @@ func main() { // For this example, we'll just demonstrate that it's working // Keep the client running (in a real app, you'd have your main application logic here) + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + select { case <-ctx.Done(): log.Println("Client context cancelled") + case <-sigChan: + log.Println("Received shutdown signal") } } \ No newline at end of file From 11f8d0e40045cf4d5720ff5f60d275f1bb6476e1 Mon Sep 17 00:00:00 2001 From: andig Date: Mon, 28 Jul 2025 16:31:39 +0200 Subject: [PATCH 22/22] refactor: remove unused mu field from streamableHttpSession MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removes unused sync.RWMutex field that was flagged by golangci-lint. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- server/streamable_http.go | 1 - 1 file changed, 1 deletion(-) diff --git a/server/streamable_http.go b/server/streamable_http.go index d84d8ace1..24ec1c95a 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -746,7 +746,6 @@ type streamableHttpSession struct { samplingRequestChan chan samplingRequestItem // server -> client sampling requests samplingRequests sync.Map // requestID -> pending sampling request context requestIDCounter atomic.Int64 // for generating unique request IDs - mu sync.RWMutex // protects sampling channels and requests } func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession {