Skip to content

Commit c56705b

Browse files
GosthellJBUinfo
authored andcommitted
feat(streamable_http): elicitation request
1 parent 713d4d3 commit c56705b

File tree

4 files changed

+126
-47
lines changed

4 files changed

+126
-47
lines changed

client/client.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,8 @@ func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JS
471471
return c.handleSamplingRequestTransport(ctx, request)
472472
case string(mcp.MethodElicitationCreate):
473473
return c.handleElicitationRequestTransport(ctx, request)
474+
case string(mcp.MethodPing):
475+
return c.handlePingRequestTransport(ctx, request)
474476
default:
475477
return nil, fmt.Errorf("unsupported request method: %s", request.Method)
476478
}
@@ -572,6 +574,15 @@ func (c *Client) handleElicitationRequestTransport(ctx context.Context, request
572574
return response, nil
573575
}
574576

577+
func (c *Client) handlePingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
578+
b, _ := json.Marshal(&mcp.EmptyResult{})
579+
return &transport.JSONRPCResponse{
580+
JSONRPC: mcp.JSONRPC_VERSION,
581+
ID: request.ID,
582+
Result: b,
583+
}, nil
584+
}
585+
575586
func listByPage[T any](
576587
ctx context.Context,
577588
client *Client,

client/transport/streamable_http.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) {
605605
connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
606606
err := c.createGETConnectionToServer(connectCtx)
607607
cancel()
608-
608+
609609
if errors.Is(err, ErrGetMethodNotAllowed) {
610610
// server does not support listening
611611
c.logger.Errorf("server does not support listening")
@@ -621,7 +621,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) {
621621
if err != nil {
622622
c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err)
623623
}
624-
624+
625625
// Use context-aware sleep
626626
select {
627627
case <-time.After(retryInterval):
@@ -704,15 +704,15 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON
704704
// Create a new context with timeout for request handling, respecting parent context
705705
requestCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
706706
defer cancel()
707-
707+
708708
response, err := handler(requestCtx, request)
709709
if err != nil {
710710
c.logger.Errorf("error handling request %s: %v", request.Method, err)
711-
711+
712712
// Determine appropriate JSON-RPC error code based on error type
713713
var errorCode int
714714
var errorMessage string
715-
715+
716716
// Check for specific sampling-related errors
717717
if errors.Is(err, context.Canceled) {
718718
errorCode = -32800 // Request cancelled
@@ -731,7 +731,7 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON
731731
errorMessage = err.Error()
732732
}
733733
}
734-
734+
735735
// Send error response
736736
errorResponse := &JSONRPCResponse{
737737
JSONRPC: "2.0",

server/streamable_http.go

Lines changed: 107 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -237,14 +237,13 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
237237
}
238238

239239
// Check if this is a sampling response (has result/error but no method)
240-
isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil &&
240+
isResponse := jsonMessage.Method == "" && jsonMessage.ID != nil &&
241241
(jsonMessage.Result != nil || jsonMessage.Error != nil)
242-
243242
isInitializeRequest := jsonMessage.Method == mcp.MethodInitialize
244243

245244
// Handle sampling responses separately
246-
if isSamplingResponse {
247-
if err := s.handleSamplingResponse(w, r, jsonMessage); err != nil {
245+
if isResponse {
246+
if err := s.handleResponse(w, r, jsonMessage); err != nil {
248247
s.logger.Errorf("Failed to handle sampling response: %v", err)
249248
http.Error(w, "Failed to handle sampling response", http.StatusInternalServerError)
250249
}
@@ -390,7 +389,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
390389
return
391390
}
392391
defer s.server.UnregisterSession(r.Context(), sessionID)
393-
392+
394393
// Register session for sampling response delivery
395394
s.activeSessions.Store(sessionID, session)
396395
defer s.activeSessions.Delete(sessionID)
@@ -437,6 +436,21 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
437436
case <-done:
438437
return
439438
}
439+
case elicitationReq := <-session.elicitationRequestChan:
440+
// Send elicitation request to client via SSE
441+
jsonrpcRequest := mcp.JSONRPCRequest{
442+
JSONRPC: "2.0",
443+
ID: mcp.NewRequestId(elicitationReq.requestID),
444+
Request: mcp.Request{
445+
Method: string(mcp.MethodElicitationCreate),
446+
},
447+
Params: elicitationReq.request.Params,
448+
}
449+
select {
450+
case writeChan <- jsonrpcRequest:
451+
case <-done:
452+
return
453+
}
440454
case <-done:
441455
return
442456
}
@@ -525,8 +539,8 @@ func writeSSEEvent(w io.Writer, data any) error {
525539
return nil
526540
}
527541

528-
// handleSamplingResponse processes incoming sampling responses from clients
529-
func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *http.Request, responseMessage struct {
542+
// handleResponse processes incoming responses from clients
543+
func (s *StreamableHTTPServer) handleResponse(w http.ResponseWriter, r *http.Request, responseMessage struct {
530544
ID json.RawMessage `json:"id"`
531545
Result json.RawMessage `json:"result,omitempty"`
532546
Error json.RawMessage `json:"error,omitempty"`
@@ -558,7 +572,7 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *
558572
}
559573

560574
// Create the sampling response item
561-
response := samplingResponseItem{
575+
response := responseItem{
562576
requestID: requestID,
563577
}
564578

@@ -575,20 +589,14 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *
575589
response.err = fmt.Errorf("sampling error %d: %s", jsonrpcError.Code, jsonrpcError.Message)
576590
}
577591
} else if responseMessage.Result != nil {
578-
// Parse result
579-
var result mcp.CreateMessageResult
580-
if err := json.Unmarshal(responseMessage.Result, &result); err != nil {
581-
response.err = fmt.Errorf("failed to parse sampling result: %v", err)
582-
} else {
583-
response.result = &result
584-
}
592+
response.result = responseMessage.Result
585593
} else {
586594
response.err = fmt.Errorf("sampling response has neither result nor error")
587595
}
588596

589597
// Find the corresponding session and deliver the response
590598
// The response is delivered to the specific session identified by sessionID
591-
if err := s.deliverSamplingResponse(sessionID, response); err != nil {
599+
if err := s.deliverResponse(sessionID, response); err != nil {
592600
s.logger.Errorf("Failed to deliver sampling response: %v", err)
593601
http.Error(w, "Failed to deliver response", http.StatusInternalServerError)
594602
return err
@@ -600,7 +608,7 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *
600608
}
601609

602610
// deliverSamplingResponse delivers a sampling response to the appropriate session
603-
func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, response samplingResponseItem) error {
611+
func (s *StreamableHTTPServer) deliverResponse(sessionID string, response responseItem) error {
604612
// Look up the active session
605613
sessionInterface, ok := s.activeSessions.Load(sessionID)
606614
if !ok {
@@ -613,12 +621,12 @@ func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, respons
613621
}
614622

615623
// Look up the dedicated response channel for this specific request
616-
responseChannelInterface, exists := session.samplingRequests.Load(response.requestID)
624+
responseChannelInterface, exists := session.requests.Load(response.requestID)
617625
if !exists {
618626
return fmt.Errorf("no pending request found for session %s, request %d", sessionID, response.requestID)
619627
}
620628

621-
responseChan, ok := responseChannelInterface.(chan samplingResponseItem)
629+
responseChan, ok := responseChannelInterface.(chan responseItem)
622630
if !ok {
623631
return fmt.Errorf("invalid response channel type for session %s, request %d", sessionID, response.requestID)
624632
}
@@ -723,15 +731,22 @@ func (s *sessionToolsStore) delete(sessionID string) {
723731
type samplingRequestItem struct {
724732
requestID int64
725733
request mcp.CreateMessageRequest
726-
response chan samplingResponseItem
734+
response chan responseItem
727735
}
728736

729-
type samplingResponseItem struct {
737+
type responseItem struct {
730738
requestID int64
731-
result *mcp.CreateMessageResult
739+
result json.RawMessage
732740
err error
733741
}
734742

743+
// Elicitation support types for HTTP transport
744+
type elicitationRequestItem struct {
745+
requestID int64
746+
request mcp.ElicitationRequest
747+
response chan responseItem
748+
}
749+
735750
// streamableHttpSession is a session for streamable-http transport
736751
// When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler.
737752
// When in GET handlers(listening), it's a real session, and will be registered in the MCP server.
@@ -743,18 +758,21 @@ type streamableHttpSession struct {
743758
logLevels *sessionLogLevelsStore
744759

745760
// Sampling support for bidirectional communication
746-
samplingRequestChan chan samplingRequestItem // server -> client sampling requests
747-
samplingRequests sync.Map // requestID -> pending sampling request context
748-
requestIDCounter atomic.Int64 // for generating unique request IDs
761+
samplingRequestChan chan samplingRequestItem // server -> client sampling requests
762+
elicitationRequestChan chan elicitationRequestItem // server -> client elicitation requests
763+
764+
requests sync.Map // requestID -> pending request context
765+
requestIDCounter atomic.Int64 // for generating unique request IDs
749766
}
750767

751768
func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession {
752769
s := &streamableHttpSession{
753-
sessionID: sessionID,
754-
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
755-
tools: toolStore,
756-
logLevels: levels,
757-
samplingRequestChan: make(chan samplingRequestItem, 10),
770+
sessionID: sessionID,
771+
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
772+
tools: toolStore,
773+
logLevels: levels,
774+
samplingRequestChan: make(chan samplingRequestItem, 10),
775+
elicitationRequestChan: make(chan elicitationRequestItem, 10),
758776
}
759777
return s
760778
}
@@ -810,21 +828,21 @@ var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil)
810828
func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
811829
// Generate unique request ID
812830
requestID := s.requestIDCounter.Add(1)
813-
831+
814832
// Create response channel for this specific request
815-
responseChan := make(chan samplingResponseItem, 1)
816-
833+
responseChan := make(chan responseItem, 1)
834+
817835
// Create the sampling request item
818836
samplingRequest := samplingRequestItem{
819837
requestID: requestID,
820838
request: request,
821839
response: responseChan,
822840
}
823-
841+
824842
// Store the pending request
825-
s.samplingRequests.Store(requestID, responseChan)
826-
defer s.samplingRequests.Delete(requestID)
827-
843+
s.requests.Store(requestID, responseChan)
844+
defer s.requests.Delete(requestID)
845+
828846
// Send the sampling request via the channel (non-blocking)
829847
select {
830848
case s.samplingRequestChan <- samplingRequest:
@@ -834,20 +852,70 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp
834852
default:
835853
return nil, fmt.Errorf("sampling request queue is full - server overloaded")
836854
}
837-
855+
856+
// Wait for response or context cancellation
857+
select {
858+
case response := <-responseChan:
859+
if response.err != nil {
860+
return nil, response.err
861+
}
862+
var result mcp.CreateMessageResult
863+
if err := json.Unmarshal(response.result, &result); err != nil {
864+
return nil, fmt.Errorf("failed to unmarshal sampling response: %v", err)
865+
}
866+
return &result, nil
867+
case <-ctx.Done():
868+
return nil, ctx.Err()
869+
}
870+
}
871+
872+
// RequestElicitation implements SessionWithElicitation interface for HTTP transport
873+
func (s *streamableHttpSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) {
874+
// Generate unique request ID
875+
requestID := s.requestIDCounter.Add(1)
876+
877+
// Create response channel for this specific request
878+
responseChan := make(chan responseItem, 1)
879+
880+
// Create the sampling request item
881+
elicitationRequest := elicitationRequestItem{
882+
requestID: requestID,
883+
request: request,
884+
response: responseChan,
885+
}
886+
887+
// Store the pending request
888+
s.requests.Store(requestID, responseChan)
889+
defer s.requests.Delete(requestID)
890+
891+
// Send the sampling request via the channel (non-blocking)
892+
select {
893+
case s.elicitationRequestChan <- elicitationRequest:
894+
// Request queued successfully
895+
case <-ctx.Done():
896+
return nil, ctx.Err()
897+
default:
898+
return nil, fmt.Errorf("elicitation request queue is full - server overloaded")
899+
}
900+
838901
// Wait for response or context cancellation
839902
select {
840903
case response := <-responseChan:
841904
if response.err != nil {
842905
return nil, response.err
843906
}
844-
return response.result, nil
907+
var result mcp.ElicitationResult
908+
if err := json.Unmarshal(response.result, &result); err != nil {
909+
return nil, fmt.Errorf("failed to unmarshal elicitation response: %v", err)
910+
}
911+
return &result, nil
845912
case <-ctx.Done():
846913
return nil, ctx.Err()
847914
}
848915
}
849916

850917
var _ SessionWithSampling = (*streamableHttpSession)(nil)
918+
var _ SessionWithElicitation = (*streamableHttpSession)(nil)
851919

852920
// --- session id manager ---
853921

server/streamable_http_sampling_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ func TestStreamableHTTPServer_SamplingQueueFull(t *testing.T) {
185185
session.samplingRequestChan <- samplingRequestItem{
186186
requestID: int64(i),
187187
request: mcp.CreateMessageRequest{},
188-
response: make(chan samplingResponseItem, 1),
188+
response: make(chan responseItem, 1),
189189
}
190190
}
191191

@@ -213,4 +213,4 @@ func TestStreamableHTTPServer_SamplingQueueFull(t *testing.T) {
213213
if !strings.Contains(err.Error(), "queue is full") {
214214
t.Errorf("Expected queue full error, got: %v", err)
215215
}
216-
}
216+
}

0 commit comments

Comments
 (0)