@@ -237,14 +237,13 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
237
237
}
238
238
239
239
// Check if this is a sampling response (has result/error but no method)
240
- isSamplingResponse := jsonMessage .Method == "" && jsonMessage .ID != nil &&
240
+ isResponse := jsonMessage .Method == "" && jsonMessage .ID != nil &&
241
241
(jsonMessage .Result != nil || jsonMessage .Error != nil )
242
-
243
242
isInitializeRequest := jsonMessage .Method == mcp .MethodInitialize
244
243
245
244
// Handle sampling responses separately
246
- if isSamplingResponse {
247
- if err := s .handleSamplingResponse (w , r , jsonMessage ); err != nil {
245
+ if isResponse {
246
+ if err := s .handleResponse (w , r , jsonMessage ); err != nil {
248
247
s .logger .Errorf ("Failed to handle sampling response: %v" , err )
249
248
http .Error (w , "Failed to handle sampling response" , http .StatusInternalServerError )
250
249
}
@@ -390,7 +389,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
390
389
return
391
390
}
392
391
defer s .server .UnregisterSession (r .Context (), sessionID )
393
-
392
+
394
393
// Register session for sampling response delivery
395
394
s .activeSessions .Store (sessionID , session )
396
395
defer s .activeSessions .Delete (sessionID )
@@ -437,6 +436,21 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
437
436
case <- done :
438
437
return
439
438
}
439
+ case elicitationReq := <- session .elicitationRequestChan :
440
+ // Send elicitation request to client via SSE
441
+ jsonrpcRequest := mcp.JSONRPCRequest {
442
+ JSONRPC : "2.0" ,
443
+ ID : mcp .NewRequestId (elicitationReq .requestID ),
444
+ Request : mcp.Request {
445
+ Method : string (mcp .MethodElicitationCreate ),
446
+ },
447
+ Params : elicitationReq .request .Params ,
448
+ }
449
+ select {
450
+ case writeChan <- jsonrpcRequest :
451
+ case <- done :
452
+ return
453
+ }
440
454
case <- done :
441
455
return
442
456
}
@@ -525,8 +539,8 @@ func writeSSEEvent(w io.Writer, data any) error {
525
539
return nil
526
540
}
527
541
528
- // handleSamplingResponse processes incoming sampling responses from clients
529
- func (s * StreamableHTTPServer ) handleSamplingResponse (w http.ResponseWriter , r * http.Request , responseMessage struct {
542
+ // handleResponse processes incoming responses from clients
543
+ func (s * StreamableHTTPServer ) handleResponse (w http.ResponseWriter , r * http.Request , responseMessage struct {
530
544
ID json.RawMessage `json:"id"`
531
545
Result json.RawMessage `json:"result,omitempty"`
532
546
Error json.RawMessage `json:"error,omitempty"`
@@ -558,7 +572,7 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *
558
572
}
559
573
560
574
// Create the sampling response item
561
- response := samplingResponseItem {
575
+ response := responseItem {
562
576
requestID : requestID ,
563
577
}
564
578
@@ -575,20 +589,14 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *
575
589
response .err = fmt .Errorf ("sampling error %d: %s" , jsonrpcError .Code , jsonrpcError .Message )
576
590
}
577
591
} else if responseMessage .Result != nil {
578
- // Parse result
579
- var result mcp.CreateMessageResult
580
- if err := json .Unmarshal (responseMessage .Result , & result ); err != nil {
581
- response .err = fmt .Errorf ("failed to parse sampling result: %v" , err )
582
- } else {
583
- response .result = & result
584
- }
592
+ response .result = responseMessage .Result
585
593
} else {
586
594
response .err = fmt .Errorf ("sampling response has neither result nor error" )
587
595
}
588
596
589
597
// Find the corresponding session and deliver the response
590
598
// The response is delivered to the specific session identified by sessionID
591
- if err := s .deliverSamplingResponse (sessionID , response ); err != nil {
599
+ if err := s .deliverResponse (sessionID , response ); err != nil {
592
600
s .logger .Errorf ("Failed to deliver sampling response: %v" , err )
593
601
http .Error (w , "Failed to deliver response" , http .StatusInternalServerError )
594
602
return err
@@ -600,7 +608,7 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *
600
608
}
601
609
602
610
// deliverSamplingResponse delivers a sampling response to the appropriate session
603
- func (s * StreamableHTTPServer ) deliverSamplingResponse (sessionID string , response samplingResponseItem ) error {
611
+ func (s * StreamableHTTPServer ) deliverResponse (sessionID string , response responseItem ) error {
604
612
// Look up the active session
605
613
sessionInterface , ok := s .activeSessions .Load (sessionID )
606
614
if ! ok {
@@ -613,12 +621,12 @@ func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, respons
613
621
}
614
622
615
623
// Look up the dedicated response channel for this specific request
616
- responseChannelInterface , exists := session .samplingRequests .Load (response .requestID )
624
+ responseChannelInterface , exists := session .requests .Load (response .requestID )
617
625
if ! exists {
618
626
return fmt .Errorf ("no pending request found for session %s, request %d" , sessionID , response .requestID )
619
627
}
620
628
621
- responseChan , ok := responseChannelInterface .(chan samplingResponseItem )
629
+ responseChan , ok := responseChannelInterface .(chan responseItem )
622
630
if ! ok {
623
631
return fmt .Errorf ("invalid response channel type for session %s, request %d" , sessionID , response .requestID )
624
632
}
@@ -723,15 +731,22 @@ func (s *sessionToolsStore) delete(sessionID string) {
723
731
type samplingRequestItem struct {
724
732
requestID int64
725
733
request mcp.CreateMessageRequest
726
- response chan samplingResponseItem
734
+ response chan responseItem
727
735
}
728
736
729
- type samplingResponseItem struct {
737
+ type responseItem struct {
730
738
requestID int64
731
- result * mcp. CreateMessageResult
739
+ result json. RawMessage
732
740
err error
733
741
}
734
742
743
+ // Elicitation support types for HTTP transport
744
+ type elicitationRequestItem struct {
745
+ requestID int64
746
+ request mcp.ElicitationRequest
747
+ response chan responseItem
748
+ }
749
+
735
750
// streamableHttpSession is a session for streamable-http transport
736
751
// When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler.
737
752
// When in GET handlers(listening), it's a real session, and will be registered in the MCP server.
@@ -743,18 +758,21 @@ type streamableHttpSession struct {
743
758
logLevels * sessionLogLevelsStore
744
759
745
760
// Sampling support for bidirectional communication
746
- samplingRequestChan chan samplingRequestItem // server -> client sampling requests
747
- samplingRequests sync.Map // requestID -> pending sampling request context
748
- requestIDCounter atomic.Int64 // for generating unique request IDs
761
+ samplingRequestChan chan samplingRequestItem // server -> client sampling requests
762
+ elicitationRequestChan chan elicitationRequestItem // server -> client elicitation requests
763
+
764
+ requests sync.Map // requestID -> pending request context
765
+ requestIDCounter atomic.Int64 // for generating unique request IDs
749
766
}
750
767
751
768
func newStreamableHttpSession (sessionID string , toolStore * sessionToolsStore , levels * sessionLogLevelsStore ) * streamableHttpSession {
752
769
s := & streamableHttpSession {
753
- sessionID : sessionID ,
754
- notificationChannel : make (chan mcp.JSONRPCNotification , 100 ),
755
- tools : toolStore ,
756
- logLevels : levels ,
757
- samplingRequestChan : make (chan samplingRequestItem , 10 ),
770
+ sessionID : sessionID ,
771
+ notificationChannel : make (chan mcp.JSONRPCNotification , 100 ),
772
+ tools : toolStore ,
773
+ logLevels : levels ,
774
+ samplingRequestChan : make (chan samplingRequestItem , 10 ),
775
+ elicitationRequestChan : make (chan elicitationRequestItem , 10 ),
758
776
}
759
777
return s
760
778
}
@@ -810,21 +828,21 @@ var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil)
810
828
func (s * streamableHttpSession ) RequestSampling (ctx context.Context , request mcp.CreateMessageRequest ) (* mcp.CreateMessageResult , error ) {
811
829
// Generate unique request ID
812
830
requestID := s .requestIDCounter .Add (1 )
813
-
831
+
814
832
// Create response channel for this specific request
815
- responseChan := make (chan samplingResponseItem , 1 )
816
-
833
+ responseChan := make (chan responseItem , 1 )
834
+
817
835
// Create the sampling request item
818
836
samplingRequest := samplingRequestItem {
819
837
requestID : requestID ,
820
838
request : request ,
821
839
response : responseChan ,
822
840
}
823
-
841
+
824
842
// Store the pending request
825
- s .samplingRequests .Store (requestID , responseChan )
826
- defer s .samplingRequests .Delete (requestID )
827
-
843
+ s .requests .Store (requestID , responseChan )
844
+ defer s .requests .Delete (requestID )
845
+
828
846
// Send the sampling request via the channel (non-blocking)
829
847
select {
830
848
case s .samplingRequestChan <- samplingRequest :
@@ -834,20 +852,70 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp
834
852
default :
835
853
return nil , fmt .Errorf ("sampling request queue is full - server overloaded" )
836
854
}
837
-
855
+
856
+ // Wait for response or context cancellation
857
+ select {
858
+ case response := <- responseChan :
859
+ if response .err != nil {
860
+ return nil , response .err
861
+ }
862
+ var result mcp.CreateMessageResult
863
+ if err := json .Unmarshal (response .result , & result ); err != nil {
864
+ return nil , fmt .Errorf ("failed to unmarshal sampling response: %v" , err )
865
+ }
866
+ return & result , nil
867
+ case <- ctx .Done ():
868
+ return nil , ctx .Err ()
869
+ }
870
+ }
871
+
872
+ // RequestElicitation implements SessionWithElicitation interface for HTTP transport
873
+ func (s * streamableHttpSession ) RequestElicitation (ctx context.Context , request mcp.ElicitationRequest ) (* mcp.ElicitationResult , error ) {
874
+ // Generate unique request ID
875
+ requestID := s .requestIDCounter .Add (1 )
876
+
877
+ // Create response channel for this specific request
878
+ responseChan := make (chan responseItem , 1 )
879
+
880
+ // Create the sampling request item
881
+ elicitationRequest := elicitationRequestItem {
882
+ requestID : requestID ,
883
+ request : request ,
884
+ response : responseChan ,
885
+ }
886
+
887
+ // Store the pending request
888
+ s .requests .Store (requestID , responseChan )
889
+ defer s .requests .Delete (requestID )
890
+
891
+ // Send the sampling request via the channel (non-blocking)
892
+ select {
893
+ case s .elicitationRequestChan <- elicitationRequest :
894
+ // Request queued successfully
895
+ case <- ctx .Done ():
896
+ return nil , ctx .Err ()
897
+ default :
898
+ return nil , fmt .Errorf ("elicitation request queue is full - server overloaded" )
899
+ }
900
+
838
901
// Wait for response or context cancellation
839
902
select {
840
903
case response := <- responseChan :
841
904
if response .err != nil {
842
905
return nil , response .err
843
906
}
844
- return response .result , nil
907
+ var result mcp.ElicitationResult
908
+ if err := json .Unmarshal (response .result , & result ); err != nil {
909
+ return nil , fmt .Errorf ("failed to unmarshal elicitation response: %v" , err )
910
+ }
911
+ return & result , nil
845
912
case <- ctx .Done ():
846
913
return nil , ctx .Err ()
847
914
}
848
915
}
849
916
850
917
var _ SessionWithSampling = (* streamableHttpSession )(nil )
918
+ var _ SessionWithElicitation = (* streamableHttpSession )(nil )
851
919
852
920
// --- session id manager ---
853
921
0 commit comments