Skip to content

feat: implement sampling support for Streamable HTTP transport #515

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Aug 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
881e095
feat: implement sampling support for Streamable HTTP transport
andig Jul 27, 2025
d239784
feat: implement server-side sampling support for HTTP transport
andig Jul 27, 2025
dd877e0
fix: replace time.Sleep with synchronization primitives in tests
andig Jul 27, 2025
a4ec0b3
fix: improve request detection logic and add nil pointer checks
andig Jul 27, 2025
1cae3a9
fix: correct misleading comment about response delivery
andig Jul 27, 2025
204b273
fix: implement EnableSampling() to properly declare sampling capability
andig Jul 27, 2025
5d4fb64
fix: prevent panic from unsafe type assertion in example server
andig Jul 27, 2025
4e41f25
fix: add missing EnableSampling() call in interface test
andig Jul 27, 2025
178e234
fix: expand error test coverage and avoid t.Fatalf
andig Jul 27, 2025
27322ca
fix: eliminate recursive response handling and improve routing
andig Jul 27, 2025
d025975
fix: improve sampling response delivery robustness
andig Jul 27, 2025
a9b20be
fix: add graceful shutdown handling to sampling client
andig Jul 28, 2025
b7afbb9
fix: improve context handling in streamable HTTP transport
andig Jul 28, 2025
a664289
fix: improve error message for notification channel queue full condition
andig Jul 28, 2025
bac5dad
refactor: rename struct variable for clarity in message parsing
andig Jul 28, 2025
e69716d
test: add concurrent sampling requests test with response association
andig Jul 28, 2025
e28a859
fix: improve context handling in async goroutine
andig Jul 28, 2025
4fa5295
refactor: replace interface{} with any throughout codebase
andig Jul 28, 2025
3852e2d
fix: improve context handling in async goroutine for StreamableHTTP
andig Jul 28, 2025
83883ed
refactor: remove unused samplingResponseChan field from session struct
andig Jul 28, 2025
9ea4a10
feat: add graceful shutdown handling to sampling HTTP client example
andig Jul 28, 2025
11f8d0e
refactor: remove unused mu field from streamableHttpSession
andig Jul 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 155 additions & 4 deletions client/transport/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ func WithSession(sessionID string) StreamableHTTPCOption {
// The current implementation does not support the following features:
// - resuming stream
// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
// - server -> client request
type StreamableHTTP struct {
serverURL *url.URL
httpClient *http.Client
Expand All @@ -110,6 +109,10 @@ type StreamableHTTP struct {
notificationHandler func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex

// Request handler for incoming server-to-client requests (like sampling)
requestHandler RequestHandler
requestMu sync.RWMutex

closed chan struct{}

// OAuth support
Expand Down Expand Up @@ -397,15 +400,23 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
// Create a channel for this specific request
responseChan := make(chan *JSONRPCResponse, 1)

// Add timeout context for request processing if not already set
if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > 30*time.Second {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 30*time.Second)
defer cancel()
}

ctx, cancel := context.WithCancel(ctx)
defer cancel()

// Start a goroutine to process the SSE stream
go func() {
// only close responseChan after readingSSE()
// Ensure this goroutine respects the context
defer close(responseChan)

c.readSSE(ctx, reader, func(event, data string) {
// Try to unmarshal as a response first
var message JSONRPCResponse
if err := json.Unmarshal([]byte(data), &message); err != nil {
c.logger.Errorf("failed to unmarshal message: %v", err)
Expand All @@ -427,6 +438,19 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
return
}

// Check if this is actually a request from the server by looking for method field
var rawMessage map[string]json.RawMessage
if err := json.Unmarshal([]byte(data), &rawMessage); err == nil {
if _, hasMethod := rawMessage["method"]; hasMethod && !message.ID.IsNil() {
var request JSONRPCRequest
if err := json.Unmarshal([]byte(data), &request); err == nil {
// This is a request from the server
c.handleIncomingRequest(ctx, request)
return
}
}
}

if !ignoreResponse {
responseChan <- &message
}
Expand Down Expand Up @@ -547,6 +571,13 @@ func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotifica
c.notificationHandler = handler
}

// SetRequestHandler sets the handler for incoming requests from the server.
func (c *StreamableHTTP) SetRequestHandler(handler RequestHandler) {
c.requestMu.Lock()
defer c.requestMu.Unlock()
c.requestHandler = handler
}

func (c *StreamableHTTP) GetSessionId() string {
return c.sessionID.Load().(string)
}
Expand All @@ -564,7 +595,11 @@ func (c *StreamableHTTP) IsOAuthEnabled() bool {
func (c *StreamableHTTP) listenForever(ctx context.Context) {
c.logger.Infof("listening to server forever")
for {
err := c.createGETConnectionToServer(ctx)
// Add timeout for individual connection attempts
connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
err := c.createGETConnectionToServer(connectCtx)
cancel()

if errors.Is(err, ErrGetMethodNotAllowed) {
// server does not support listening
c.logger.Errorf("server does not support listening")
Expand All @@ -580,7 +615,13 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) {
if err != nil {
c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err)
}
time.Sleep(retryInterval)

// Use context-aware sleep
select {
case <-time.After(retryInterval):
case <-ctx.Done():
return
}
}
}

Expand Down Expand Up @@ -627,6 +668,116 @@ func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error
return nil
}

// handleIncomingRequest processes requests from the server (like sampling requests)
func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSONRPCRequest) {
c.requestMu.RLock()
handler := c.requestHandler
c.requestMu.RUnlock()

if handler == nil {
c.logger.Errorf("received request from server but no handler set: %s", request.Method)
// Send method not found error
errorResponse := &JSONRPCResponse{
JSONRPC: "2.0",
ID: request.ID,
Error: &struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}{
Code: -32601, // Method not found
Message: fmt.Sprintf("no handler configured for method: %s", request.Method),
},
}
c.sendResponseToServer(ctx, errorResponse)
return
}

// Handle the request in a goroutine to avoid blocking the SSE reader
go func() {
// Create a new context with timeout for request handling, respecting parent context
requestCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()

response, err := handler(requestCtx, request)
if err != nil {
c.logger.Errorf("error handling request %s: %v", request.Method, err)

// Determine appropriate JSON-RPC error code based on error type
var errorCode int
var errorMessage string

// Check for specific sampling-related errors
if errors.Is(err, context.Canceled) {
errorCode = -32800 // Request cancelled
errorMessage = "request was cancelled"
} else if errors.Is(err, context.DeadlineExceeded) {
errorCode = -32800 // Request timeout
errorMessage = "request timed out"
} else {
// Generic error cases
switch request.Method {
case string(mcp.MethodSamplingCreateMessage):
errorCode = -32603 // Internal error
errorMessage = fmt.Sprintf("sampling request failed: %v", err)
default:
errorCode = -32603 // Internal error
errorMessage = err.Error()
}
}

// Send error response
errorResponse := &JSONRPCResponse{
JSONRPC: "2.0",
ID: request.ID,
Error: &struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}{
Code: errorCode,
Message: errorMessage,
},
}
c.sendResponseToServer(requestCtx, errorResponse)
return
}

if response != nil {
c.sendResponseToServer(requestCtx, response)
}
}()
}

// sendResponseToServer sends a response back to the server via HTTP POST
func (c *StreamableHTTP) sendResponseToServer(ctx context.Context, response *JSONRPCResponse) {
if response == nil {
c.logger.Errorf("cannot send nil response to server")
return
}

responseBody, err := json.Marshal(response)
if err != nil {
c.logger.Errorf("failed to marshal response: %v", err)
return
}

ctx, cancel := c.contextAwareOfClientClose(ctx)
defer cancel()

resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json")
if err != nil {
c.logger.Errorf("failed to send response to server: %v", err)
return
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
body, _ := io.ReadAll(resp.Body)
c.logger.Errorf("server rejected response with status %d: %s", resp.StatusCode, body)
}
}

func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) {
newCtx, cancel := context.WithCancel(ctx)
go func() {
Expand Down
Loading