Skip to content

Commit fda6b38

Browse files
andigclaude
andauthored
feat: implement sampling support for Streamable HTTP transport (#515)
* feat: implement sampling support for Streamable HTTP transport Implements sampling capability for HTTP transport, resolving issue #419. Enables servers to send sampling requests to HTTP clients via SSE and receive LLM-generated responses. ## Key Changes ### Core Implementation - Add `BidirectionalInterface` support to `StreamableHTTP` - Implement `SetRequestHandler` for server-to-client requests - Enhance SSE parsing to handle requests alongside responses/notifications - Add `handleIncomingRequest` and `sendResponseToServer` methods ### HTTP-Specific Features - Leverage existing MCP headers (`Mcp-Session-Id`, `Mcp-Protocol-Version`) - Bidirectional communication via HTTP POST for responses - Proper JSON-RPC request/response handling over HTTP ### Error Handling - Add specific JSON-RPC error codes for different failure scenarios: - `-32601` (Method not found) when no handler configured - `-32603` (Internal error) for sampling failures - `-32800` (Request cancelled/timeout) for context errors - Enhanced error messages with sampling-specific context ### Testing & Examples - Comprehensive test suite in `streamable_http_sampling_test.go` - Complete working example in `examples/sampling_http_client/` - Tests cover success flows, error scenarios, and interface compliance ## Technical Details The implementation maintains full backward compatibility while adding bidirectional communication support. Server requests are processed asynchronously to avoid blocking the SSE stream reader. HTTP transport now supports the complete sampling flow that was previously only available in stdio and inprocess transports. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * feat: implement server-side sampling support for HTTP transport This completes the server-side implementation of sampling support for HTTP transport, addressing the remaining requirements from issue #419. Changes: - Enhanced streamableHttpSession to implement SessionWithSampling interface - Added bidirectional SSE communication for server-to-client requests - Implemented session registry for proper response correlation - Added comprehensive error handling with JSON-RPC error codes - Created extensive test suite covering all scenarios - Added working example server with sampling tools Key Features: - Server can send sampling requests to HTTP clients via SSE - Clients respond via HTTP POST with proper session correlation - Queue overflow protection and timeout handling - Compatible with existing HTTP transport architecture 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * fix: replace time.Sleep with synchronization primitives in tests Replace flaky time.Sleep calls with proper synchronization using channels and sync.WaitGroup to make tests deterministic and avoid race conditions. Also improves error handling robustness in test servers with proper JSON decoding error checks. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * fix: improve request detection logic and add nil pointer checks - Make request vs response detection more robust by checking for presence of "method" field instead of relying on nil Result/Error fields - Add nil pointer check in sendResponseToServer function to prevent panics These changes improve reliability against malformed messages and edge cases. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * fix: correct misleading comment about response delivery The comment incorrectly stated that responses are broadcast to all sessions, but the implementation actually delivers responses to the specific session identified by sessionID using the activeSessions registry. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * fix: implement EnableSampling() to properly declare sampling capability Previously, EnableSampling() was a no-op that didn't actually enable the sampling capability in the server's declared capabilities. Changes: - Add Sampling field to mcp.ServerCapabilities struct - Add sampling field to internal serverCapabilities struct - Update EnableSampling() to set the sampling capability flag - Update handleInitialize() to include sampling in capability response - Add test to verify sampling capability is properly declared Now when EnableSampling() is called, the server will properly declare sampling capability during initialization, allowing clients to know that the server supports sending sampling requests. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * fix: prevent panic from unsafe type assertion in example server Replace unsafe type assertion result.Content.(mcp.TextContent).Text with safe type checking to handle cases where Content might not be a TextContent struct. Now gracefully handles different content types without panicking. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * fix: add missing EnableSampling() call in interface test The SamplingInterface test was missing the EnableSampling() call, which is necessary to activate sampling features for proper testing. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * fix: expand error test coverage and avoid t.Fatalf - Replace single error test with comprehensive table-driven tests - Add test cases for invalid request IDs and malformed results - Replace t.Fatalf with t.Errorf to follow project conventions - Use proper session ID format for valid test scenarios 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * fix: eliminate recursive response handling and improve routing - Remove recursive call in RequestSampling that could cause stack overflow - Remove problematic response re-queuing to global channel - Update deliverSamplingResponse to route responses directly to dedicated request channels via samplingRequests map lookup - This prevents ordering issues and ensures responses reach the correct waiting request 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * fix: improve sampling response delivery robustness - Modified deliverSamplingResponse to return error instead of just logging - Added proper error handling for disconnected sessions - Improved error messages for debugging - Updated test expectations to match new error behavior 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * fix: add graceful shutdown handling to sampling client - Add signal handling for SIGINT and SIGTERM - Move defer statement after error checking - Improve shutdown error handling 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * fix: improve context handling in streamable HTTP transport - Add timeout context for SSE response processing (30s default) - Add timeout for individual connection attempts in listenForever (10s) - Use context-aware sleep in retry logic - Ensure async goroutines properly respect context cancellation 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * fix: improve error message for notification channel queue full condition - Make error message more descriptive and actionable - Provide clearer debugging information about why the channel is blocked 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * refactor: rename struct variable for clarity in message parsing - Rename 'baseMessage' to 'jsonMessage' for more neutral naming - Improves code readability and follows consistent naming conventions 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * test: add concurrent sampling requests test with response association Add test verifying that concurrent sampling requests are handled correctly when the second request completes faster than the first. The test ensures: - Responses are correctly associated with their request IDs - Server processes requests concurrently without blocking - Completion order follows actual processing time, not submission order 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * fix: improve context handling in async goroutine Create new context with 30-second timeout for request handling to prevent long-running handlers from blocking indefinitely. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * refactor: replace interface{} with any throughout codebase Replace all occurrences of interface{} with the modern Go any type alias for improved readability and consistency with current Go best practices. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * fix: improve context handling in async goroutine for StreamableHTTP Create timeout context from parent context instead of context.Background() to ensure request handlers respect parent context cancellation. Addresses review comment about context handling in async goroutine. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * refactor: remove unused samplingResponseChan field from session struct The samplingResponseChan field was declared but never used in the streamableHttpSession struct. Remove it and update tests accordingly. Addresses review comment about unused fields in session struct. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * feat: add graceful shutdown handling to sampling HTTP client example Add signal handling for SIGINT and SIGTERM to allow graceful shutdown of the sampling HTTP client example. This prevents indefinite blocking and provides better production-ready behavior. Addresses review comment about adding graceful shutdown handling. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * refactor: remove unused mu field from streamableHttpSession Removes unused sync.RWMutex field that was flagged by golangci-lint. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> --------- Co-authored-by: Claude <[email protected]>
1 parent 6e5d6fd commit fda6b38

File tree

17 files changed

+1665
-26
lines changed

17 files changed

+1665
-26
lines changed

client/transport/streamable_http.go

Lines changed: 155 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ func WithSession(sessionID string) StreamableHTTPCOption {
9292
// The current implementation does not support the following features:
9393
// - resuming stream
9494
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
95-
// - server -> client request
9695
type StreamableHTTP struct {
9796
serverURL *url.URL
9897
httpClient *http.Client
@@ -110,6 +109,10 @@ type StreamableHTTP struct {
110109
notificationHandler func(mcp.JSONRPCNotification)
111110
notifyMu sync.RWMutex
112111

112+
// Request handler for incoming server-to-client requests (like sampling)
113+
requestHandler RequestHandler
114+
requestMu sync.RWMutex
115+
113116
closed chan struct{}
114117

115118
// OAuth support
@@ -397,15 +400,23 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
397400
// Create a channel for this specific request
398401
responseChan := make(chan *JSONRPCResponse, 1)
399402

403+
// Add timeout context for request processing if not already set
404+
if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > 30*time.Second {
405+
var cancel context.CancelFunc
406+
ctx, cancel = context.WithTimeout(ctx, 30*time.Second)
407+
defer cancel()
408+
}
409+
400410
ctx, cancel := context.WithCancel(ctx)
401411
defer cancel()
402412

403413
// Start a goroutine to process the SSE stream
404414
go func() {
405-
// only close responseChan after readingSSE()
415+
// Ensure this goroutine respects the context
406416
defer close(responseChan)
407417

408418
c.readSSE(ctx, reader, func(event, data string) {
419+
// Try to unmarshal as a response first
409420
var message JSONRPCResponse
410421
if err := json.Unmarshal([]byte(data), &message); err != nil {
411422
c.logger.Errorf("failed to unmarshal message: %v", err)
@@ -427,6 +438,19 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
427438
return
428439
}
429440

441+
// Check if this is actually a request from the server by looking for method field
442+
var rawMessage map[string]json.RawMessage
443+
if err := json.Unmarshal([]byte(data), &rawMessage); err == nil {
444+
if _, hasMethod := rawMessage["method"]; hasMethod && !message.ID.IsNil() {
445+
var request JSONRPCRequest
446+
if err := json.Unmarshal([]byte(data), &request); err == nil {
447+
// This is a request from the server
448+
c.handleIncomingRequest(ctx, request)
449+
return
450+
}
451+
}
452+
}
453+
430454
if !ignoreResponse {
431455
responseChan <- &message
432456
}
@@ -547,6 +571,13 @@ func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotifica
547571
c.notificationHandler = handler
548572
}
549573

574+
// SetRequestHandler sets the handler for incoming requests from the server.
575+
func (c *StreamableHTTP) SetRequestHandler(handler RequestHandler) {
576+
c.requestMu.Lock()
577+
defer c.requestMu.Unlock()
578+
c.requestHandler = handler
579+
}
580+
550581
func (c *StreamableHTTP) GetSessionId() string {
551582
return c.sessionID.Load().(string)
552583
}
@@ -564,7 +595,11 @@ func (c *StreamableHTTP) IsOAuthEnabled() bool {
564595
func (c *StreamableHTTP) listenForever(ctx context.Context) {
565596
c.logger.Infof("listening to server forever")
566597
for {
567-
err := c.createGETConnectionToServer(ctx)
598+
// Add timeout for individual connection attempts
599+
connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
600+
err := c.createGETConnectionToServer(connectCtx)
601+
cancel()
602+
568603
if errors.Is(err, ErrGetMethodNotAllowed) {
569604
// server does not support listening
570605
c.logger.Errorf("server does not support listening")
@@ -580,7 +615,13 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) {
580615
if err != nil {
581616
c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err)
582617
}
583-
time.Sleep(retryInterval)
618+
619+
// Use context-aware sleep
620+
select {
621+
case <-time.After(retryInterval):
622+
case <-ctx.Done():
623+
return
624+
}
584625
}
585626
}
586627

@@ -627,6 +668,116 @@ func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error
627668
return nil
628669
}
629670

671+
// handleIncomingRequest processes requests from the server (like sampling requests)
672+
func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSONRPCRequest) {
673+
c.requestMu.RLock()
674+
handler := c.requestHandler
675+
c.requestMu.RUnlock()
676+
677+
if handler == nil {
678+
c.logger.Errorf("received request from server but no handler set: %s", request.Method)
679+
// Send method not found error
680+
errorResponse := &JSONRPCResponse{
681+
JSONRPC: "2.0",
682+
ID: request.ID,
683+
Error: &struct {
684+
Code int `json:"code"`
685+
Message string `json:"message"`
686+
Data json.RawMessage `json:"data"`
687+
}{
688+
Code: -32601, // Method not found
689+
Message: fmt.Sprintf("no handler configured for method: %s", request.Method),
690+
},
691+
}
692+
c.sendResponseToServer(ctx, errorResponse)
693+
return
694+
}
695+
696+
// Handle the request in a goroutine to avoid blocking the SSE reader
697+
go func() {
698+
// Create a new context with timeout for request handling, respecting parent context
699+
requestCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
700+
defer cancel()
701+
702+
response, err := handler(requestCtx, request)
703+
if err != nil {
704+
c.logger.Errorf("error handling request %s: %v", request.Method, err)
705+
706+
// Determine appropriate JSON-RPC error code based on error type
707+
var errorCode int
708+
var errorMessage string
709+
710+
// Check for specific sampling-related errors
711+
if errors.Is(err, context.Canceled) {
712+
errorCode = -32800 // Request cancelled
713+
errorMessage = "request was cancelled"
714+
} else if errors.Is(err, context.DeadlineExceeded) {
715+
errorCode = -32800 // Request timeout
716+
errorMessage = "request timed out"
717+
} else {
718+
// Generic error cases
719+
switch request.Method {
720+
case string(mcp.MethodSamplingCreateMessage):
721+
errorCode = -32603 // Internal error
722+
errorMessage = fmt.Sprintf("sampling request failed: %v", err)
723+
default:
724+
errorCode = -32603 // Internal error
725+
errorMessage = err.Error()
726+
}
727+
}
728+
729+
// Send error response
730+
errorResponse := &JSONRPCResponse{
731+
JSONRPC: "2.0",
732+
ID: request.ID,
733+
Error: &struct {
734+
Code int `json:"code"`
735+
Message string `json:"message"`
736+
Data json.RawMessage `json:"data"`
737+
}{
738+
Code: errorCode,
739+
Message: errorMessage,
740+
},
741+
}
742+
c.sendResponseToServer(requestCtx, errorResponse)
743+
return
744+
}
745+
746+
if response != nil {
747+
c.sendResponseToServer(requestCtx, response)
748+
}
749+
}()
750+
}
751+
752+
// sendResponseToServer sends a response back to the server via HTTP POST
753+
func (c *StreamableHTTP) sendResponseToServer(ctx context.Context, response *JSONRPCResponse) {
754+
if response == nil {
755+
c.logger.Errorf("cannot send nil response to server")
756+
return
757+
}
758+
759+
responseBody, err := json.Marshal(response)
760+
if err != nil {
761+
c.logger.Errorf("failed to marshal response: %v", err)
762+
return
763+
}
764+
765+
ctx, cancel := c.contextAwareOfClientClose(ctx)
766+
defer cancel()
767+
768+
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json")
769+
if err != nil {
770+
c.logger.Errorf("failed to send response to server: %v", err)
771+
return
772+
}
773+
defer resp.Body.Close()
774+
775+
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
776+
body, _ := io.ReadAll(resp.Body)
777+
c.logger.Errorf("server rejected response with status %d: %s", resp.StatusCode, body)
778+
}
779+
}
780+
630781
func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) {
631782
newCtx, cancel := context.WithCancel(ctx)
632783
go func() {

0 commit comments

Comments
 (0)