diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index e8b2fcc58..f8965553a 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -92,7 +92,6 @@ func WithSession(sessionID string) StreamableHTTPCOption { // The current implementation does not support the following features: // - resuming stream // (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery) -// - server -> client request type StreamableHTTP struct { serverURL *url.URL httpClient *http.Client @@ -110,6 +109,10 @@ type StreamableHTTP struct { notificationHandler func(mcp.JSONRPCNotification) notifyMu sync.RWMutex + // Request handler for incoming server-to-client requests (like sampling) + requestHandler RequestHandler + requestMu sync.RWMutex + closed chan struct{} // OAuth support @@ -397,15 +400,23 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl // Create a channel for this specific request responseChan := make(chan *JSONRPCResponse, 1) + // Add timeout context for request processing if not already set + if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > 30*time.Second { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, 30*time.Second) + defer cancel() + } + ctx, cancel := context.WithCancel(ctx) defer cancel() // Start a goroutine to process the SSE stream go func() { - // only close responseChan after readingSSE() + // Ensure this goroutine respects the context defer close(responseChan) c.readSSE(ctx, reader, func(event, data string) { + // Try to unmarshal as a response first var message JSONRPCResponse if err := json.Unmarshal([]byte(data), &message); err != nil { c.logger.Errorf("failed to unmarshal message: %v", err) @@ -427,6 +438,19 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl return } + // Check if this is actually a request from the server by looking for 
method field + var rawMessage map[string]json.RawMessage + if err := json.Unmarshal([]byte(data), &rawMessage); err == nil { + if _, hasMethod := rawMessage["method"]; hasMethod && !message.ID.IsNil() { + var request JSONRPCRequest + if err := json.Unmarshal([]byte(data), &request); err == nil { + // This is a request from the server + c.handleIncomingRequest(ctx, request) + return + } + } + } + if !ignoreResponse { responseChan <- &message } @@ -547,6 +571,13 @@ func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotifica c.notificationHandler = handler } +// SetRequestHandler sets the handler for incoming requests from the server. +func (c *StreamableHTTP) SetRequestHandler(handler RequestHandler) { + c.requestMu.Lock() + defer c.requestMu.Unlock() + c.requestHandler = handler +} + func (c *StreamableHTTP) GetSessionId() string { return c.sessionID.Load().(string) } @@ -564,7 +595,11 @@ func (c *StreamableHTTP) IsOAuthEnabled() bool { func (c *StreamableHTTP) listenForever(ctx context.Context) { c.logger.Infof("listening to server forever") for { - err := c.createGETConnectionToServer(ctx) + // Add timeout for individual connection attempts + connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + err := c.createGETConnectionToServer(connectCtx) + cancel() + if errors.Is(err, ErrGetMethodNotAllowed) { // server does not support listening c.logger.Errorf("server does not support listening") @@ -580,7 +615,13 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) { if err != nil { c.logger.Errorf("failed to listen to server. 
retry in 1 second: %v", err) } - time.Sleep(retryInterval) + + // Use context-aware sleep + select { + case <-time.After(retryInterval): + case <-ctx.Done(): + return + } } } @@ -627,6 +668,116 @@ func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error return nil } +// handleIncomingRequest processes requests from the server (like sampling requests) +func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSONRPCRequest) { + c.requestMu.RLock() + handler := c.requestHandler + c.requestMu.RUnlock() + + if handler == nil { + c.logger.Errorf("received request from server but no handler set: %s", request.Method) + // Send method not found error + errorResponse := &JSONRPCResponse{ + JSONRPC: "2.0", + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: -32601, // Method not found + Message: fmt.Sprintf("no handler configured for method: %s", request.Method), + }, + } + c.sendResponseToServer(ctx, errorResponse) + return + } + + // Handle the request in a goroutine to avoid blocking the SSE reader + go func() { + // Create a new context with timeout for request handling, respecting parent context + requestCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + response, err := handler(requestCtx, request) + if err != nil { + c.logger.Errorf("error handling request %s: %v", request.Method, err) + + // Determine appropriate JSON-RPC error code based on error type + var errorCode int + var errorMessage string + + // Check for specific sampling-related errors + if errors.Is(err, context.Canceled) { + errorCode = -32800 // Request cancelled + errorMessage = "request was cancelled" + } else if errors.Is(err, context.DeadlineExceeded) { + errorCode = -32800 // Request timeout + errorMessage = "request timed out" + } else { + // Generic error cases + switch request.Method { + case string(mcp.MethodSamplingCreateMessage): + 
errorCode = -32603 // Internal error + errorMessage = fmt.Sprintf("sampling request failed: %v", err) + default: + errorCode = -32603 // Internal error + errorMessage = err.Error() + } + } + + // Send error response + errorResponse := &JSONRPCResponse{ + JSONRPC: "2.0", + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: errorCode, + Message: errorMessage, + }, + } + c.sendResponseToServer(requestCtx, errorResponse) + return + } + + if response != nil { + c.sendResponseToServer(requestCtx, response) + } + }() +} + +// sendResponseToServer sends a response back to the server via HTTP POST +func (c *StreamableHTTP) sendResponseToServer(ctx context.Context, response *JSONRPCResponse) { + if response == nil { + c.logger.Errorf("cannot send nil response to server") + return + } + + responseBody, err := json.Marshal(response) + if err != nil { + c.logger.Errorf("failed to marshal response: %v", err) + return + } + + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() + + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json") + if err != nil { + c.logger.Errorf("failed to send response to server: %v", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + c.logger.Errorf("server rejected response with status %d: %s", resp.StatusCode, body) + } +} + func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) { newCtx, cancel := context.WithCancel(ctx) go func() { diff --git a/client/transport/streamable_http_sampling_test.go b/client/transport/streamable_http_sampling_test.go new file mode 100644 index 000000000..edba61eac --- /dev/null +++ b/client/transport/streamable_http_sampling_test.go @@ -0,0 +1,496 @@ +package transport + +import ( + "context" + "encoding/json" 
+ "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// TestStreamableHTTP_SamplingFlow tests the complete sampling flow with HTTP transport +func TestStreamableHTTP_SamplingFlow(t *testing.T) { + // Create simple test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Just respond OK to any requests + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Create HTTP client transport + client, err := NewStreamableHTTP(server.URL) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Set up sampling request handler + var handledRequest *JSONRPCRequest + handlerCalled := make(chan struct{}) + client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { + handledRequest = &request + close(handlerCalled) + + // Simulate sampling handler response + result := map[string]any{ + "role": "assistant", + "content": map[string]any{ + "type": "text", + "text": "Hello! 
How can I help you today?", + }, + "model": "test-model", + "stopReason": "stop_sequence", + } + + resultBytes, _ := json.Marshal(result) + + return &JSONRPCResponse{ + JSONRPC: "2.0", + ID: request.ID, + Result: resultBytes, + }, nil + }) + + // Start the client + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err = client.Start(ctx) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Test direct request handling (simulating a sampling request) + samplingRequest := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(1), + Method: string(mcp.MethodSamplingCreateMessage), + Params: map[string]any{ + "messages": []map[string]any{ + { + "role": "user", + "content": map[string]any{ + "type": "text", + "text": "Hello, world!", + }, + }, + }, + }, + } + + // Directly test request handling + client.handleIncomingRequest(ctx, samplingRequest) + + // Wait for handler to be called + select { + case <-handlerCalled: + // Handler was called + case <-time.After(1 * time.Second): + t.Fatal("Handler was not called within timeout") + } + + // Verify the request was handled + if handledRequest == nil { + t.Fatal("Sampling request was not handled") + } + + if handledRequest.Method != string(mcp.MethodSamplingCreateMessage) { + t.Errorf("Expected method %s, got %s", mcp.MethodSamplingCreateMessage, handledRequest.Method) + } +} + +// TestStreamableHTTP_SamplingErrorHandling tests error handling in sampling requests +func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { + var errorHandled sync.WaitGroup + errorHandled.Add(1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Logf("Failed to decode body: %v", err) + w.WriteHeader(http.StatusOK) + return + } + + // Check if this is an error response + if errorField, ok := 
body["error"]; ok { + errorMap := errorField.(map[string]any) + if code, ok := errorMap["code"].(float64); ok && code == -32603 { + errorHandled.Done() + w.WriteHeader(http.StatusOK) + return + } + } + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client, err := NewStreamableHTTP(server.URL) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Set up request handler that returns an error + client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { + return nil, fmt.Errorf("sampling failed") + }) + + // Start the client + ctx := context.Background() + err = client.Start(ctx) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Simulate incoming sampling request + samplingRequest := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(1), + Method: string(mcp.MethodSamplingCreateMessage), + Params: map[string]any{}, + } + + // This should trigger error handling + client.handleIncomingRequest(ctx, samplingRequest) + + // Wait for error to be handled + errorHandled.Wait() +} + +// TestStreamableHTTP_NoSamplingHandler tests behavior when no sampling handler is set +func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { + var errorReceived bool + errorReceivedChan := make(chan struct{}) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Logf("Failed to decode body: %v", err) + w.WriteHeader(http.StatusOK) + return + } + + // Check if this is an error response with method not found + if errorField, ok := body["error"]; ok { + errorMap := errorField.(map[string]any) + if code, ok := errorMap["code"].(float64); ok && code == -32601 { + if message, ok := errorMap["message"].(string); ok && + strings.Contains(message, "no handler configured") { + errorReceived = true 
+ close(errorReceivedChan) + } + } + } + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client, err := NewStreamableHTTP(server.URL) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Don't set any request handler + + ctx := context.Background() + err = client.Start(ctx) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Simulate incoming sampling request + samplingRequest := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(1), + Method: string(mcp.MethodSamplingCreateMessage), + Params: map[string]any{}, + } + + // This should trigger "method not found" error + client.handleIncomingRequest(ctx, samplingRequest) + + // Wait for error to be received + select { + case <-errorReceivedChan: + // Error was received + case <-time.After(1 * time.Second): + t.Fatal("Method not found error was not received within timeout") + } + + if !errorReceived { + t.Error("Expected method not found error, but didn't receive it") + } +} + +// TestStreamableHTTP_BidirectionalInterface verifies the interface implementation +func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { + client, err := NewStreamableHTTP("http://example.com") + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Verify it implements BidirectionalInterface + _, ok := any(client).(BidirectionalInterface) + if !ok { + t.Error("StreamableHTTP should implement BidirectionalInterface") + } + + // Test SetRequestHandler + handlerSet := false + handlerSetChan := make(chan struct{}) + client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { + handlerSet = true + close(handlerSetChan) + return nil, nil + }) + + // Verify handler was set by triggering it + ctx := context.Background() + client.handleIncomingRequest(ctx, JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(1), + Method: "test", + }) + + // Wait for handler to be 
called + select { + case <-handlerSetChan: + // Handler was called + case <-time.After(1 * time.Second): + t.Fatal("Handler was not called within timeout") + } + + if !handlerSet { + t.Error("Request handler was not properly set or called") + } +} + +// TestStreamableHTTP_ConcurrentSamplingRequests tests concurrent sampling requests +// where the second request completes faster than the first request +func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { + var receivedResponses []map[string]any + var responseMutex sync.Mutex + responseComplete := make(chan struct{}, 2) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Logf("Failed to decode body: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + + // Check if this is a response from client (not a request) + if _, ok := body["result"]; ok { + responseMutex.Lock() + receivedResponses = append(receivedResponses, body) + responseMutex.Unlock() + responseComplete <- struct{}{} + } + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client, err := NewStreamableHTTP(server.URL) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Track which requests have been received and their completion order + var requestOrder []int + var orderMutex sync.Mutex + + // Set up request handler that simulates different processing times + client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { + // Extract request ID to determine processing time + requestIDValue := request.ID.Value() + + var delay time.Duration + var responseText string + var requestNum int + + // First request (ID 1) takes longer, second request (ID 2) completes faster + if requestIDValue == int64(1) { + delay = 100 * time.Millisecond + responseText = "Response from 
slow request 1" + requestNum = 1 + } else if requestIDValue == int64(2) { + delay = 10 * time.Millisecond + responseText = "Response from fast request 2" + requestNum = 2 + } else { + t.Errorf("Unexpected request ID: %v", requestIDValue) + return nil, fmt.Errorf("unexpected request ID") + } + + // Simulate processing time + time.Sleep(delay) + + // Record completion order + orderMutex.Lock() + requestOrder = append(requestOrder, requestNum) + orderMutex.Unlock() + + // Return response with correct request ID + result := map[string]any{ + "role": "assistant", + "content": map[string]any{ + "type": "text", + "text": responseText, + }, + "model": "test-model", + "stopReason": "stop_sequence", + } + + resultBytes, _ := json.Marshal(result) + + return &JSONRPCResponse{ + JSONRPC: "2.0", + ID: request.ID, + Result: resultBytes, + }, nil + }) + + // Start the client + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err = client.Start(ctx) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Create two sampling requests with different IDs + request1 := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(int64(1)), + Method: string(mcp.MethodSamplingCreateMessage), + Params: map[string]any{ + "messages": []map[string]any{ + { + "role": "user", + "content": map[string]any{ + "type": "text", + "text": "Slow request 1", + }, + }, + }, + }, + } + + request2 := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(int64(2)), + Method: string(mcp.MethodSamplingCreateMessage), + Params: map[string]any{ + "messages": []map[string]any{ + { + "role": "user", + "content": map[string]any{ + "type": "text", + "text": "Fast request 2", + }, + }, + }, + }, + } + + // Send both requests concurrently + go client.handleIncomingRequest(ctx, request1) + go client.handleIncomingRequest(ctx, request2) + + // Wait for both responses to complete + for i := 0; i < 2; i++ { + select { + case <-responseComplete: + // Response 
received + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for response") + } + } + + // Verify completion order: request 2 should complete first + orderMutex.Lock() + defer orderMutex.Unlock() + + if len(requestOrder) != 2 { + t.Fatalf("Expected 2 completed requests, got %d", len(requestOrder)) + } + + if requestOrder[0] != 2 { + t.Errorf("Expected request 2 to complete first, but request %d completed first", requestOrder[0]) + } + + if requestOrder[1] != 1 { + t.Errorf("Expected request 1 to complete second, but request %d completed second", requestOrder[1]) + } + + // Verify responses are correctly associated + responseMutex.Lock() + defer responseMutex.Unlock() + + if len(receivedResponses) != 2 { + t.Fatalf("Expected 2 responses, got %d", len(receivedResponses)) + } + + // Find responses by ID + var response1, response2 map[string]any + for _, resp := range receivedResponses { + if id, ok := resp["id"]; ok { + switch id { + case int64(1), float64(1): + response1 = resp + case int64(2), float64(2): + response2 = resp + } + } + } + + if response1 == nil { + t.Error("Response for request 1 not found") + } + if response2 == nil { + t.Error("Response for request 2 not found") + } + + // Verify each response contains the correct content + if response1 != nil { + if result, ok := response1["result"].(map[string]any); ok { + if content, ok := result["content"].(map[string]any); ok { + if text, ok := content["text"].(string); ok { + if !strings.Contains(text, "slow request 1") { + t.Errorf("Response 1 should contain 'slow request 1', got: %s", text) + } + } + } + } + } + + if response2 != nil { + if result, ok := response2["result"].(map[string]any); ok { + if content, ok := result["content"].(map[string]any); ok { + if text, ok := content["text"].(string); ok { + if !strings.Contains(text, "fast request 2") { + t.Errorf("Response 2 should contain 'fast request 2', got: %s", text) + } + } + } + } + } +} \ No newline at end of file diff --git 
a/examples/sampling_client/main.go b/examples/sampling_client/main.go index 67b3840b0..093b59817 100644 --- a/examples/sampling_client/main.go +++ b/examples/sampling_client/main.go @@ -5,6 +5,8 @@ import ( "fmt" "log" "os" + "os/signal" + "syscall" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/client/transport" @@ -28,7 +30,7 @@ func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.Cre switch content := userMessage.Content.(type) { case mcp.TextContent: userText = content.Text - case map[string]interface{}: + case map[string]any: // Handle case where content is unmarshaled as a map if text, ok := content["text"].(string); ok { userText = text @@ -89,7 +91,25 @@ func main() { if err := mcpClient.Start(ctx); err != nil { log.Fatalf("Failed to start client: %v", err) } - defer mcpClient.Close() + + // Setup graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Create a context that cancels on signal + ctx, cancel := context.WithCancel(ctx) + go func() { + <-sigChan + log.Println("Received shutdown signal, closing client...") + cancel() + }() + + // Move defer after error checking + defer func() { + if err := mcpClient.Close(); err != nil { + log.Printf("Error closing client: %v", err) + } + }() // Initialize the connection initResult, err := mcpClient.Initialize(ctx, mcp.InitializeRequest{ diff --git a/examples/sampling_http_client/README.md b/examples/sampling_http_client/README.md new file mode 100644 index 000000000..e4cf0ea4e --- /dev/null +++ b/examples/sampling_http_client/README.md @@ -0,0 +1,95 @@ +# HTTP Sampling Client Example + +This example demonstrates how to create an MCP client using HTTP transport that supports sampling requests from the server. 
+ +## Overview + +This client: +- Connects to an MCP server via HTTP/HTTPS transport +- Declares sampling capability during initialization +- Handles incoming sampling requests from the server +- Uses a mock LLM to generate responses (replace with real LLM integration) + +## Usage + +1. Start an MCP server that supports sampling (e.g., using the `sampling_http_server` example) + +2. Update the server URL in `main.go`: + ```go + httpTransport, err := transport.NewStreamableHTTP( + "http://your-server:port", // Replace with your server URL + ) + ``` + +3. Run the client: + ```bash + go run main.go + ``` + +## Key Features + +### HTTP Transport with Sampling +The client creates the HTTP transport directly and then wraps it with a client that supports sampling: + +```go +httpTransport, err := transport.NewStreamableHTTP("http://localhost:8080") +mcpClient := client.NewClient(httpTransport, client.WithSamplingHandler(samplingHandler)) +``` + +### Sampling Handler +The `MockSamplingHandler` implements the `client.SamplingHandler` interface: + +```go +type MockSamplingHandler struct{} + +func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Process the sampling request and return LLM response + // In production, integrate with OpenAI, Anthropic, or other LLM APIs +} +``` + +### Client Configuration +The client is configured with sampling capabilities: + +```go +mcpClient := client.NewClient( + httpTransport, + client.WithSamplingHandler(samplingHandler), +) +// Sampling capability is automatically declared when a handler is provided +``` + +## Real Implementation + +For a production implementation, replace the `MockSamplingHandler` with a real LLM client: + +```go +type RealSamplingHandler struct { + client *openai.Client // or other LLM client +} + +func (h *RealSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // 
Convert MCP request to LLM API format + // Call LLM API + // Convert response back to MCP format + // Return the result +} +``` + +## HTTP-Specific Features + +The HTTP transport supports: +- Standard HTTP headers for authentication and customization +- OAuth 2.0 authentication (using `WithHTTPOAuth`) +- Custom headers (using `WithHTTPHeaders`) +- Server-Sent Events (SSE) for bidirectional communication +- Proper error handling with HTTP status codes +- Session management via HTTP headers + +## Testing + +The implementation includes comprehensive tests in `client/transport/streamable_http_sampling_test.go` that verify: +- Sampling request handling +- Error scenarios +- Bidirectional interface compliance +- HTTP-specific error codes and responses \ No newline at end of file diff --git a/examples/sampling_http_client/main.go b/examples/sampling_http_client/main.go new file mode 100644 index 000000000..98817e6f8 --- /dev/null +++ b/examples/sampling_http_client/main.go @@ -0,0 +1,116 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// MockSamplingHandler implements client.SamplingHandler for demonstration. +// In a real implementation, this would integrate with an actual LLM API. 
+type MockSamplingHandler struct{} + +func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Extract the user's message + if len(request.Messages) == 0 { + return nil, fmt.Errorf("no messages provided") + } + + // Get the last user message + lastMessage := request.Messages[len(request.Messages)-1] + userText := "" + if textContent, ok := lastMessage.Content.(mcp.TextContent); ok { + userText = textContent.Text + } + + // Generate a mock response + responseText := fmt.Sprintf("Mock LLM response to: '%s'", userText) + + log.Printf("Mock LLM generating response: %s", responseText) + + result := &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: responseText, + }, + }, + Model: "mock-model-v1", + StopReason: "endTurn", + } + + return result, nil +} + +func main() { + // Create sampling handler + samplingHandler := &MockSamplingHandler{} + + // Create HTTP transport directly + httpTransport, err := transport.NewStreamableHTTP( + "http://localhost:8080", // Replace with your MCP server URL + // You can add HTTP-specific options here like headers, OAuth, etc. 
+ ) + if err != nil { + log.Fatalf("Failed to create HTTP transport: %v", err) + } + defer httpTransport.Close() + + // Create client with sampling support + mcpClient := client.NewClient( + httpTransport, + client.WithSamplingHandler(samplingHandler), + ) + + // Start the client + ctx := context.Background() + err = mcpClient.Start(ctx) + if err != nil { + log.Fatalf("Failed to start client: %v", err) + } + + // Initialize the MCP session + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{ + // Sampling capability will be automatically added by the client + }, + ClientInfo: mcp.Implementation{ + Name: "sampling-http-client", + Version: "1.0.0", + }, + }, + } + + _, err = mcpClient.Initialize(ctx, initRequest) + if err != nil { + log.Fatalf("Failed to initialize MCP session: %v", err) + } + + log.Println("HTTP MCP client with sampling support started successfully!") + log.Println("The client is now ready to handle sampling requests from the server.") + log.Println("When the server sends a sampling request, the MockSamplingHandler will process it.") + + // In a real application, you would keep the client running to handle sampling requests + // For this example, we'll just demonstrate that it's working + + // Keep the client running (in a real app, you'd have your main application logic here) + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + select { + case <-ctx.Done(): + log.Println("Client context cancelled") + case <-sigChan: + log.Println("Received shutdown signal") + } +} \ No newline at end of file diff --git a/examples/sampling_http_server/README.md b/examples/sampling_http_server/README.md new file mode 100644 index 000000000..64be58c2c --- /dev/null +++ b/examples/sampling_http_server/README.md @@ -0,0 +1,138 @@ +# HTTP Sampling Server Example + +This example demonstrates how to create an MCP server 
using HTTP transport that can send sampling requests to clients. + +## Overview + +This server: +- Runs on HTTP transport (port 8080 by default) +- Declares sampling capability during initialization +- Can send sampling requests to connected clients via Server-Sent Events (SSE) +- Receives sampling responses from clients via HTTP POST +- Includes tools that demonstrate sampling functionality + +## Usage + +1. Start the server: + ```bash + go run main.go + ``` + +2. The server will be available at: `http://localhost:8080/mcp` + +3. Connect with an HTTP client that supports sampling (like the `sampling_http_client` example) + +## Tools Available + +### `ask_llm` +Demonstrates server-initiated sampling: +- Takes a question and optional system prompt +- Sends sampling request to client +- Returns the LLM's response + +### `echo` +Simple tool for testing basic functionality: +- Echoes back the input message +- Doesn't require sampling + +## How Sampling Works + +### Server → Client Flow +1. **Tool Invocation**: Client calls `ask_llm` tool +2. **Sampling Request**: Server creates sampling request with user's question +3. **SSE Transmission**: Server sends JSON-RPC request to client via SSE stream +4. **Client Processing**: Client's sampling handler processes the request +5. **HTTP Response**: Client sends JSON-RPC response back via HTTP POST +6. 
**Tool Response**: Server returns the LLM response to the original tool caller + +### Communication Architecture +``` +Client (HTTP + SSE) ←→ Server (HTTP) + │ │ + ├─ POST: Tool Call ──→ │ + │ │ + │ ←── SSE: Sampling ───┤ + │ Request │ + │ │ + ├─ POST: Sampling ───→ │ + │ Response │ + │ │ + │ ←── HTTP: Tool ──────┤ + Response +``` + +## Key Features + +### Bidirectional Communication +- **SSE Stream**: Server → Client requests (sampling, notifications) +- **HTTP POST**: Client → Server responses and requests + +### Session Management +- Session ID tracking for request/response correlation +- Proper session lifecycle management +- Session validation for security + +### Error Handling +- JSON-RPC error codes for different failure scenarios +- Timeout handling for sampling requests +- Queue overflow protection + +### HTTP-Specific Features +- Standard MCP headers (`Mcp-Session-Id`, `Mcp-Protocol-Version`) +- Content-Type validation +- Proper HTTP status codes +- SSE event formatting + +## Testing + +You can test the server using the `sampling_http_client` example: + +1. Start this server: + ```bash + go run examples/sampling_http_server/main.go + ``` + +2. In another terminal, start the client: + ```bash + go run examples/sampling_http_client/main.go + ``` + +3. The client will connect and be ready to handle sampling requests from the server. 
+ +## Production Considerations + +### Security +- Implement proper authentication/authorization +- Use HTTPS in production +- Validate all incoming data +- Implement rate limiting + +### Scalability +- Consider connection pooling for multiple clients +- Implement proper session cleanup +- Monitor memory usage for long-running sessions +- Add metrics and monitoring + +### Reliability +- Implement request retries +- Add circuit breakers for failing clients +- Implement graceful degradation when sampling is unavailable +- Add comprehensive logging + +## Integration + +This server can be integrated into existing HTTP infrastructure: + +```go +// Custom HTTP server integration +mux := http.NewServeMux() +mux.Handle("/mcp", httpServer) +mux.Handle("/health", healthHandler) + +server := &http.Server{ + Addr: ":8080", + Handler: mux, +} +``` + +The sampling functionality works seamlessly with other MCP features like tools, resources, and prompts. \ No newline at end of file diff --git a/examples/sampling_http_server/main.go b/examples/sampling_http_server/main.go new file mode 100644 index 000000000..95a2bf29b --- /dev/null +++ b/examples/sampling_http_server/main.go @@ -0,0 +1,150 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + // Create MCP server with sampling capability + mcpServer := server.NewMCPServer("sampling-http-server", "1.0.0") + + // Enable sampling capability + mcpServer.EnableSampling() + + // Add a tool that uses sampling to get LLM responses + mcpServer.AddTool(mcp.Tool{ + Name: "ask_llm", + Description: "Ask the LLM a question using sampling over HTTP", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "The question to ask the LLM", + }, + "system_prompt": map[string]any{ + "type": "string", + "description": "Optional system prompt to 
provide context", + }, + }, + Required: []string{"question"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract parameters + question, err := request.RequireString("question") + if err != nil { + return nil, err + } + + systemPrompt := request.GetString("system_prompt", "You are a helpful assistant.") + + // Create sampling request + samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + SystemPrompt: systemPrompt, + MaxTokens: 1000, + Temperature: 0.7, + }, + } + + // Request sampling from the client with timeout + samplingCtx, cancel := context.WithTimeout(ctx, 2*time.Minute) + defer cancel() + + serverFromCtx := server.ServerFromContext(ctx) + result, err := serverFromCtx.RequestSampling(samplingCtx, samplingRequest) + if err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Error requesting sampling: %v", err), + }, + }, + IsError: true, + }, nil + } + + // Extract response text safely + var responseText string + if textContent, ok := result.Content.(mcp.TextContent); ok { + responseText = textContent.Text + } else { + responseText = fmt.Sprintf("%v", result.Content) + } + + // Return the LLM response + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("LLM Response (model: %s): %s", result.Model, responseText), + }, + }, + }, nil + }) + + // Add a simple echo tool for testing + mcpServer.AddTool(mcp.Tool{ + Name: "echo", + Description: "Echo back the input message", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "message": map[string]any{ + "type": "string", + "description": "The message to echo back", + }, + }, + Required: []string{"message"}, + }, + }, 
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + message := request.GetString("message", "") + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Echo: %s", message), + }, + }, + }, nil + }) + + // Create HTTP server + httpServer := server.NewStreamableHTTPServer(mcpServer) + + log.Println("Starting HTTP MCP server with sampling support on :8080") + log.Println("Endpoint: http://localhost:8080/mcp") + log.Println("") + log.Println("This server supports sampling over HTTP transport.") + log.Println("Clients must:") + log.Println("1. Initialize with sampling capability") + log.Println("2. Establish SSE connection for bidirectional communication") + log.Println("3. Handle incoming sampling requests from the server") + log.Println("4. Send responses back via HTTP POST") + log.Println("") + log.Println("Available tools:") + log.Println("- ask_llm: Ask the LLM a question (requires sampling)") + log.Println("- echo: Simple echo tool (no sampling required)") + + // Start the server + if err := httpServer.Start(":8080"); err != nil { + log.Fatalf("Server failed to start: %v", err) + } +} \ No newline at end of file diff --git a/examples/sampling_server/main.go b/examples/sampling_server/main.go index c3bcf4902..ea887c588 100644 --- a/examples/sampling_server/main.go +++ b/examples/sampling_server/main.go @@ -127,11 +127,11 @@ func main() { } // Helper function to extract text from content -func getTextFromContent(content interface{}) string { +func getTextFromContent(content any) string { switch c := content.(type) { case mcp.TextContent: return c.Text - case map[string]interface{}: + case map[string]any: // Handle JSON unmarshaled content if text, ok := c["text"].(string); ok { return text diff --git a/mcp/types.go b/mcp/types.go index 0ef6811fd..a13e3689a 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -472,6 +472,8 @@ type ServerCapabilities struct { // list. 
ListChanged bool `json:"listChanged,omitempty"` } `json:"resources,omitempty"` + // Present if the server supports sending sampling requests to clients. + Sampling *struct{} `json:"sampling,omitempty"` // Present if the server offers any tools to call. Tools *struct { // Whether this server supports notifications for changes to the tool list. diff --git a/server/errors.go b/server/errors.go index ecbe91e5f..3864f36f7 100644 --- a/server/errors.go +++ b/server/errors.go @@ -21,7 +21,7 @@ var ( // Notification-related errors ErrNotificationNotInitialized = errors.New("notification channel not initialized") - ErrNotificationChannelBlocked = errors.New("notification channel full or blocked") + ErrNotificationChannelBlocked = errors.New("notification channel queue is full - client may not be processing notifications fast enough") ) // ErrDynamicPathConfig is returned when attempting to use static path methods with dynamic path configuration diff --git a/server/sampling.go b/server/sampling.go index ae0812fa5..4423ccf5f 100644 --- a/server/sampling.go +++ b/server/sampling.go @@ -12,6 +12,9 @@ import ( func (s *MCPServer) EnableSampling() { s.capabilitiesMu.Lock() defer s.capabilitiesMu.Unlock() + + enabled := true + s.capabilities.sampling = &enabled } // RequestSampling sends a sampling request to the client. 
diff --git a/server/sampling_test.go b/server/sampling_test.go index c69ac6cb5..fbecdd70d 100644 --- a/server/sampling_test.go +++ b/server/sampling_test.go @@ -113,3 +113,42 @@ func TestMCPServer_RequestSampling_Success(t *testing.T) { t.Errorf("expected model %q, got %q", "test-model", result.Model) } } + +func TestMCPServer_EnableSampling_SetsCapability(t *testing.T) { + server := NewMCPServer("test", "1.0.0") + + // Verify sampling capability is not set initially + ctx := context.Background() + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: "2025-03-26", + ClientInfo: mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{}, + }, + } + + result, err := server.handleInitialize(ctx, 1, initRequest) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Capabilities.Sampling != nil { + t.Error("sampling capability should not be set before EnableSampling() is called") + } + + // Enable sampling + server.EnableSampling() + + // Verify sampling capability is now set + result, err = server.handleInitialize(ctx, 2, initRequest) + if err != nil { + t.Fatalf("unexpected error after EnableSampling(): %v", err) + } + + if result.Capabilities.Sampling == nil { + t.Error("sampling capability should be set after EnableSampling() is called") + } +} diff --git a/server/server.go b/server/server.go index a98a2132b..7822a125f 100644 --- a/server/server.go +++ b/server/server.go @@ -181,6 +181,7 @@ type serverCapabilities struct { resources *resourceCapabilities prompts *promptCapabilities logging *bool + sampling *bool } // resourceCapabilities defines the supported resource-related features @@ -580,6 +581,10 @@ func (s *MCPServer) handleInitialize( capabilities.Logging = &struct{}{} } + if s.capabilities.sampling != nil && *s.capabilities.sampling { + capabilities.Sampling = &struct{}{} + } + result := mcp.InitializeResult{ ProtocolVersion: 
s.protocolVersion(request.Params.ProtocolVersion), ServerInfo: mcp.Implementation{ diff --git a/server/session_test.go b/server/session_test.go index 9bd8bc9fa..04334487b 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -1471,8 +1471,8 @@ func TestMCPServer_LoggingNotificationFormat(t *testing.T) { // Send log messages with different formats testCases := []struct { name string - data interface{} - expected interface{} + data any + expected any }{ { name: "string data", @@ -1481,8 +1481,8 @@ func TestMCPServer_LoggingNotificationFormat(t *testing.T) { }, { name: "structured data", - data: map[string]interface{}{"key": "value", "num": 42}, - expected: map[string]interface{}{"key": "value", "num": 42}, + data: map[string]any{"key": "value", "num": 42}, + expected: map[string]any{"key": "value", "num": 42}, }, { name: "error data", @@ -1514,9 +1514,9 @@ func TestMCPServer_LoggingNotificationFormat(t *testing.T) { switch expected := tc.expected.(type) { case string: assert.Equal(t, expected, dataField) - case map[string]interface{}: - assert.IsType(t, map[string]interface{}{}, dataField) - dataMap := dataField.(map[string]interface{}) + case map[string]any: + assert.IsType(t, map[string]any{}, dataField) + dataMap := dataField.(map[string]any) for k, v := range expected { assert.Equal(t, v, dataMap[k]) } diff --git a/server/streamable_http.go b/server/streamable_http.go index f39e24f87..24ec1c95a 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -120,6 +120,7 @@ type StreamableHTTPServer struct { server *MCPServer sessionTools *sessionToolsStore sessionRequestIDs sync.Map // sessionId --> last requestID(*atomic.Int64) + activeSessions sync.Map // sessionId --> *streamableHttpSession (for sampling responses) httpServer *http.Server mu sync.RWMutex @@ -223,14 +224,32 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, fmt.Sprintf("read request body 
error: %v", err)) return } - var baseMessage struct { - Method mcp.MCPMethod `json:"method"` + // First, try to parse as a response (sampling responses don't have a method field) + var jsonMessage struct { + ID json.RawMessage `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error json.RawMessage `json:"error,omitempty"` + Method mcp.MCPMethod `json:"method,omitempty"` } - if err := json.Unmarshal(rawData, &baseMessage); err != nil { + if err := json.Unmarshal(rawData, &jsonMessage); err != nil { s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json") return } - isInitializeRequest := baseMessage.Method == mcp.MethodInitialize + + // Check if this is a sampling response (has result/error but no method) + isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil && + (jsonMessage.Result != nil || jsonMessage.Error != nil) + + isInitializeRequest := jsonMessage.Method == mcp.MethodInitialize + + // Handle sampling responses separately + if isSamplingResponse { + if err := s.handleSamplingResponse(w, r, jsonMessage); err != nil { + s.logger.Errorf("Failed to handle sampling response: %v", err) + http.Error(w, "Failed to handle sampling response", http.StatusInternalServerError) + } + return + } // Prepare the session for the mcp server // The session is ephemeral. Its life is the same as the request. 
It's only created @@ -371,6 +390,10 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) return } defer s.server.UnregisterSession(r.Context(), sessionID) + + // Register session for sampling response delivery + s.activeSessions.Store(sessionID, session) + defer s.activeSessions.Delete(sessionID) // Set the client context before handling the message w.Header().Set("Content-Type", "text/event-stream") @@ -399,6 +422,21 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) case <-done: return } + case samplingReq := <-session.samplingRequestChan: + // Send sampling request to client via SSE + jsonrpcRequest := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(samplingReq.requestID), + Request: mcp.Request{ + Method: string(mcp.MethodSamplingCreateMessage), + }, + Params: samplingReq.request.CreateMessageParams, + } + select { + case writeChan <- jsonrpcRequest: + case <-done: + return + } case <-done: return } @@ -487,6 +525,114 @@ func writeSSEEvent(w io.Writer, data any) error { return nil } +// handleSamplingResponse processes incoming sampling responses from clients +func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *http.Request, responseMessage struct { + ID json.RawMessage `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error json.RawMessage `json:"error,omitempty"` + Method mcp.MCPMethod `json:"method,omitempty"` +}) error { + // Get session ID from header + sessionID := r.Header.Get(HeaderKeySessionID) + if sessionID == "" { + http.Error(w, "Missing session ID for sampling response", http.StatusBadRequest) + return fmt.Errorf("missing session ID") + } + + // Validate session + isTerminated, err := s.sessionIdManager.Validate(sessionID) + if err != nil { + http.Error(w, "Invalid session ID", http.StatusBadRequest) + return err + } + if isTerminated { + http.Error(w, "Session terminated", http.StatusNotFound) + return fmt.Errorf("session 
terminated") + } + + // Parse the request ID + var requestID int64 + if err := json.Unmarshal(responseMessage.ID, &requestID); err != nil { + http.Error(w, "Invalid request ID in sampling response", http.StatusBadRequest) + return err + } + + // Create the sampling response item + response := samplingResponseItem{ + requestID: requestID, + } + + // Parse result or error + if responseMessage.Error != nil { + // Parse error + var jsonrpcError struct { + Code int `json:"code"` + Message string `json:"message"` + } + if err := json.Unmarshal(responseMessage.Error, &jsonrpcError); err != nil { + response.err = fmt.Errorf("failed to parse error: %v", err) + } else { + response.err = fmt.Errorf("sampling error %d: %s", jsonrpcError.Code, jsonrpcError.Message) + } + } else if responseMessage.Result != nil { + // Parse result + var result mcp.CreateMessageResult + if err := json.Unmarshal(responseMessage.Result, &result); err != nil { + response.err = fmt.Errorf("failed to parse sampling result: %v", err) + } else { + response.result = &result + } + } else { + response.err = fmt.Errorf("sampling response has neither result nor error") + } + + // Find the corresponding session and deliver the response + // The response is delivered to the specific session identified by sessionID + if err := s.deliverSamplingResponse(sessionID, response); err != nil { + s.logger.Errorf("Failed to deliver sampling response: %v", err) + http.Error(w, "Failed to deliver response", http.StatusInternalServerError) + return err + } + + // Acknowledge receipt + w.WriteHeader(http.StatusOK) + return nil +} + +// deliverSamplingResponse delivers a sampling response to the appropriate session +func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, response samplingResponseItem) error { + // Look up the active session + sessionInterface, ok := s.activeSessions.Load(sessionID) + if !ok { + return fmt.Errorf("no active session found for session %s", sessionID) + } + + session, ok := 
sessionInterface.(*streamableHttpSession) + if !ok { + return fmt.Errorf("invalid session type for session %s", sessionID) + } + + // Look up the dedicated response channel for this specific request + responseChannelInterface, exists := session.samplingRequests.Load(response.requestID) + if !exists { + return fmt.Errorf("no pending request found for session %s, request %d", sessionID, response.requestID) + } + + responseChan, ok := responseChannelInterface.(chan samplingResponseItem) + if !ok { + return fmt.Errorf("invalid response channel type for session %s, request %d", sessionID, response.requestID) + } + + // Attempt a non-blocking delivery (no timeout): if the channel cannot accept the response immediately, fail instead of stalling this handler + select { + case responseChan <- response: + s.logger.Infof("Delivered sampling response for session %s, request %d", sessionID, response.requestID) + return nil + default: + return fmt.Errorf("failed to deliver sampling response for session %s, request %d: channel full or blocked", sessionID, response.requestID) + } +} + // writeJSONRPCError writes a JSON-RPC error response with the given error details. func (s *StreamableHTTPServer) writeJSONRPCError( w http.ResponseWriter, @@ -573,6 +719,19 @@ func (s *sessionToolsStore) delete(sessionID string) { delete(s.tools, sessionID) } +// Sampling support types for HTTP transport +type samplingRequestItem struct { + requestID int64 + request mcp.CreateMessageRequest + response chan samplingResponseItem +} + +type samplingResponseItem struct { + requestID int64 + result *mcp.CreateMessageResult + err error +} + // streamableHttpSession is a session for streamable-http transport // When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler. // When in GET handlers(listening), it's a real session, and will be registered in the MCP server. 
@@ -582,14 +741,20 @@ type streamableHttpSession struct { tools *sessionToolsStore upgradeToSSE atomic.Bool logLevels *sessionLogLevelsStore + + // Sampling support for bidirectional communication + samplingRequestChan chan samplingRequestItem // server -> client sampling requests + samplingRequests sync.Map // requestID -> response channel (chan samplingResponseItem) of the pending request + requestIDCounter atomic.Int64 // for generating unique request IDs } func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession { s := &streamableHttpSession{ - sessionID: sessionID, - notificationChannel: make(chan mcp.JSONRPCNotification, 100), - tools: toolStore, - logLevels: levels, + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + tools: toolStore, + logLevels: levels, + samplingRequestChan: make(chan samplingRequestItem, 10), } return s } @@ -641,6 +806,49 @@ func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil) +// RequestSampling implements SessionWithSampling interface for HTTP transport +func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Generate unique request ID + requestID := s.requestIDCounter.Add(1) + + // Create response channel for this specific request + responseChan := make(chan samplingResponseItem, 1) + + // Create the sampling request item + samplingRequest := samplingRequestItem{ + requestID: requestID, + request: request, + response: responseChan, + } + + // Store the pending request + s.samplingRequests.Store(requestID, responseChan) + defer s.samplingRequests.Delete(requestID) + + // Send the sampling request via the channel (non-blocking) + select { + case s.samplingRequestChan <- samplingRequest: + // Request queued successfully + case <-ctx.Done(): + return nil, ctx.Err() + default: + return 
nil, fmt.Errorf("sampling request queue is full - server overloaded") + } + + // Wait for response or context cancellation + select { + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + return response.result, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +var _ SessionWithSampling = (*streamableHttpSession)(nil) + // --- session id manager --- type SessionIdManager interface { diff --git a/server/streamable_http_sampling_test.go b/server/streamable_http_sampling_test.go new file mode 100644 index 000000000..4cf57838c --- /dev/null +++ b/server/streamable_http_sampling_test.go @@ -0,0 +1,216 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// TestStreamableHTTPServer_SamplingBasic tests basic sampling session functionality +func TestStreamableHTTPServer_SamplingBasic(t *testing.T) { + // Create MCP server with sampling enabled + mcpServer := NewMCPServer("test-server", "1.0.0") + mcpServer.EnableSampling() + + // Create HTTP server + httpServer := NewStreamableHTTPServer(mcpServer) + testServer := httptest.NewServer(httpServer) + defer testServer.Close() + + // Test session creation and interface implementation + sessionID := "test-session" + session := newStreamableHttpSession(sessionID, httpServer.sessionTools, httpServer.sessionLogLevels) + + // Verify it implements SessionWithSampling + _, ok := any(session).(SessionWithSampling) + if !ok { + t.Error("streamableHttpSession should implement SessionWithSampling") + } + + // Test that sampling request channels are initialized + if session.samplingRequestChan == nil { + t.Error("samplingRequestChan should be initialized") + } +} + +// TestStreamableHTTPServer_SamplingErrorHandling tests error scenarios +func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) { + mcpServer := NewMCPServer("test-server", "1.0.0") + 
mcpServer.EnableSampling() + + httpServer := NewStreamableHTTPServer(mcpServer) + testServer := httptest.NewServer(httpServer) + defer testServer.Close() + + client := &http.Client{} + baseURL := testServer.URL + + tests := []struct { + name string + sessionID string + body map[string]any + expectedStatus int + }{ + { + name: "missing session ID", + sessionID: "", + body: map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "result": map[string]any{ + "role": "assistant", + "content": map[string]any{ + "type": "text", + "text": "Test response", + }, + }, + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "invalid request ID", + sessionID: "mcp-session-550e8400-e29b-41d4-a716-446655440000", + body: map[string]any{ + "jsonrpc": "2.0", + "id": "invalid-id", + "result": map[string]any{ + "role": "assistant", + "content": map[string]any{ + "type": "text", + "text": "Test response", + }, + }, + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "malformed result", + sessionID: "mcp-session-550e8400-e29b-41d4-a716-446655440000", + body: map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "result": "invalid-result", + }, + expectedStatus: http.StatusInternalServerError, // Now correctly returns 500 due to no active session + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload, _ := json.Marshal(tt.body) + req, err := http.NewRequest("POST", baseURL, bytes.NewReader(payload)) + if err != nil { + t.Errorf("Failed to create request: %v", err) + return + } + req.Header.Set("Content-Type", "application/json") + if tt.sessionID != "" { + req.Header.Set("Mcp-Session-Id", tt.sessionID) + } + + resp, err := client.Do(req) + if err != nil { + t.Errorf("Failed to send request: %v", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.StatusCode) + } + }) + } +} + +// TestStreamableHTTPServer_SamplingInterface verifies interface 
implementation +func TestStreamableHTTPServer_SamplingInterface(t *testing.T) { + mcpServer := NewMCPServer("test-server", "1.0.0") + mcpServer.EnableSampling() + httpServer := NewStreamableHTTPServer(mcpServer) + testServer := httptest.NewServer(httpServer) + defer testServer.Close() + + // Create a session + sessionID := "test-session" + session := newStreamableHttpSession(sessionID, httpServer.sessionTools, httpServer.sessionLogLevels) + + // Verify it implements SessionWithSampling + _, ok := any(session).(SessionWithSampling) + if !ok { + t.Error("streamableHttpSession should implement SessionWithSampling") + } + + // Test RequestSampling with timeout + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + request := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: "Test message", + }, + }, + }, + }, + } + + _, err := session.RequestSampling(ctx, request) + if err == nil { + t.Error("Expected timeout error, but got nil") + } + + if !strings.Contains(err.Error(), "context deadline exceeded") { + t.Errorf("Expected timeout error, got: %v", err) + } +} + +// TestStreamableHTTPServer_SamplingQueueFull tests queue overflow scenarios +func TestStreamableHTTPServer_SamplingQueueFull(t *testing.T) { + sessionID := "test-session" + session := newStreamableHttpSession(sessionID, nil, nil) + + // Fill the sampling request queue + for i := 0; i < cap(session.samplingRequestChan); i++ { + session.samplingRequestChan <- samplingRequestItem{ + requestID: int64(i), + request: mcp.CreateMessageRequest{}, + response: make(chan samplingResponseItem, 1), + } + } + + // Try to add another request (should fail) + ctx := context.Background() + request := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: 
mcp.TextContent{ + Type: "text", + Text: "Test message", + }, + }, + }, + }, + } + + _, err := session.RequestSampling(ctx, request) + if err == nil { + t.Error("Expected queue full error, but got nil") + } + + if !strings.Contains(err.Error(), "queue is full") { + t.Errorf("Expected queue full error, got: %v", err) + } +} \ No newline at end of file diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index 6f4a6edad..105fd18ce 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -207,7 +207,7 @@ func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { Notification: mcp.Notification{ Method: "testNotification", Params: mcp.NotificationParams{ - AdditionalFields: map[string]interface{}{"param1": "value1"}, + AdditionalFields: map[string]any{"param1": "value1"}, }, }, } @@ -395,7 +395,7 @@ func TestStreamableHTTP_POST_SendAndReceive_stateless(t *testing.T) { Notification: mcp.Notification{ Method: "testNotification", Params: mcp.NotificationParams{ - AdditionalFields: map[string]interface{}{"param1": "value1"}, + AdditionalFields: map[string]any{"param1": "value1"}, }, }, }