diff --git a/client/client.go b/client/client.go index cda7665e..1ca4635d 100644 --- a/client/client.go +++ b/client/client.go @@ -25,6 +25,7 @@ type Client struct { serverCapabilities mcp.ServerCapabilities protocolVersion string samplingHandler SamplingHandler + elicitationHandler ElicitationHandler } type ClientOption func(*Client) @@ -44,6 +45,14 @@ func WithSamplingHandler(handler SamplingHandler) ClientOption { } } +// WithElicitationHandler sets the elicitation handler for the client. +// When set, the client will declare elicitation capability during initialization. +func WithElicitationHandler(handler ElicitationHandler) ClientOption { + return func(c *Client) { + c.elicitationHandler = handler + } +} + // WithSession assumes a MCP Session has already been initialized func WithSession() ClientOption { return func(c *Client) { @@ -167,6 +176,10 @@ func (c *Client) Initialize( if c.samplingHandler != nil { capabilities.Sampling = &struct{}{} } + // Add elicitation capability if handler is configured + if c.elicitationHandler != nil { + capabilities.Elicitation = &struct{}{} + } // Ensure we send a params object with all required fields params := struct { @@ -451,11 +464,15 @@ func (c *Client) Complete( } // handleIncomingRequest processes incoming requests from the server. -// This is the main entry point for server-to-client requests like sampling. +// This is the main entry point for server-to-client requests like sampling and elicitation. func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { switch request.Method { case string(mcp.MethodSamplingCreateMessage): return c.handleSamplingRequestTransport(ctx, request) + case string(mcp.MethodElicitationCreate): + return c.handleElicitationRequestTransport(ctx, request) + case string(mcp.MethodPing): + return c.handlePingRequestTransport(ctx, request) default: return nil, fmt.Errorf("unsupported request method: %s", request.Method) } @@ -508,6 +525,64 @@ func (c *Client) handleSamplingRequestTransport(ctx context.Context, request tra return response, nil } + +// handleElicitationRequestTransport handles elicitation requests at the transport level. +func (c *Client) handleElicitationRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + if c.elicitationHandler == nil { + return nil, fmt.Errorf("no elicitation handler configured") + } + + // Parse the request parameters + var params mcp.ElicitationParams + if request.Params != nil { + paramsBytes, err := json.Marshal(request.Params) + if err != nil { + return nil, fmt.Errorf("failed to marshal params: %w", err) + } + if err := json.Unmarshal(paramsBytes, ¶ms); err != nil { + return nil, fmt.Errorf("failed to unmarshal params: %w", err) + } + } + + // Create the MCP request + mcpRequest := mcp.ElicitationRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodElicitationCreate), + }, + Params: params, + } + + // Call the elicitation handler + result, err := c.elicitationHandler.Elicit(ctx, mcpRequest) + if err != nil { + return nil, err + } + + // Marshal the result + resultBytes, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal result: %w", err) + } + + // Create the transport response + response := &transport.JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Result: json.RawMessage(resultBytes), + } + + return response, nil +} + +func (c *Client) handlePingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + b, _ := json.Marshal(&mcp.EmptyResult{}) + return &transport.JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Result: b, + }, nil +} + func listByPage[T any]( ctx context.Context, client *Client, diff --git a/client/elicitation.go b/client/elicitation.go new file mode 100644 index 00000000..92f519bf --- /dev/null +++ b/client/elicitation.go @@ -0,0 +1,19 @@ +package client + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// ElicitationHandler defines the interface for handling elicitation requests from servers. +// Clients can implement this interface to request additional information from users. +type ElicitationHandler interface { + // Elicit handles an elicitation request from the server and returns the user's response. + // The implementation should: + // 1. Present the request message to the user + // 2. Validate input against the requested schema + // 3. Allow the user to accept, decline, or cancel + // 4. Return the appropriate response + Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) +} diff --git a/client/elicitation_test.go b/client/elicitation_test.go new file mode 100644 index 00000000..3482dab4 --- /dev/null +++ b/client/elicitation_test.go @@ -0,0 +1,242 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// mockElicitationHandler implements ElicitationHandler for testing +type mockElicitationHandler struct { + result *mcp.ElicitationResult + err error +} + +func (m *mockElicitationHandler) Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + if m.err != nil { + return nil, m.err + } + return m.result, nil +} + +func TestClient_HandleElicitationRequest(t *testing.T) { + tests := []struct { + name string + handler ElicitationHandler + expectedError string + }{ + { + name: "no handler configured", + handler: nil, + expectedError: "no elicitation handler configured", + }, + { + name: "successful elicitation - accept", + handler: &mockElicitationHandler{ + result: &mcp.ElicitationResult{ + Response: mcp.ElicitationResponse{ + Type: mcp.ElicitationResponseTypeAccept, + Value: map[string]interface{}{ + "name": "test-project", + "framework": "react", + }, + }, + }, + }, + }, + { + name: "successful elicitation - decline", + handler: &mockElicitationHandler{ + result: &mcp.ElicitationResult{ + Response: mcp.ElicitationResponse{ + Type: mcp.ElicitationResponseTypeDecline, + }, + }, + }, + }, + { + name: "successful elicitation - cancel", + handler: &mockElicitationHandler{ + result: &mcp.ElicitationResult{ + Response: mcp.ElicitationResponse{ + Type: mcp.ElicitationResponseTypeCancel, + }, + }, + }, + }, + { + name: "handler returns error", + handler: &mockElicitationHandler{ + err: fmt.Errorf("user interaction failed"), + }, + expectedError: "user interaction failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &Client{elicitationHandler: tt.handler} + + request := transport.JSONRPCRequest{ + ID: mcp.NewRequestId(1), + Method: string(mcp.MethodElicitationCreate), + Params: map[string]interface{}{ + "message": "Please provide project details", + "requestedSchema": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{"type": "string"}, + "framework": map[string]interface{}{"type": "string"}, + }, + }, + }, + } + + result, err := client.handleElicitationRequestTransport(context.Background(), request) + + if tt.expectedError != "" { + if err == nil { + t.Errorf("expected error %q, got nil", tt.expectedError) + } else if err.Error() != tt.expectedError { + t.Errorf("expected error %q, got %q", tt.expectedError, err.Error()) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if result == nil { + t.Error("expected result, got nil") + } else { + // Verify the response is properly formatted + var elicitationResult mcp.ElicitationResult + if err := json.Unmarshal(result.Result, &elicitationResult); err != nil { + t.Errorf("failed to unmarshal result: %v", err) + } + } + } + }) + } +} + +func TestWithElicitationHandler(t *testing.T) { + handler := &mockElicitationHandler{} + client := &Client{} + + option := WithElicitationHandler(handler) + option(client) + + if client.elicitationHandler != handler { + t.Error("elicitation handler not set correctly") + } +} + +func TestClient_Initialize_WithElicitationHandler(t *testing.T) { + mockTransport := &mockElicitationTransport{ + sendRequestFunc: func(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + // Verify that elicitation capability is included + // The client internally converts the typed params to a map for transport + // So we check if we're getting the initialize request + if request.Method != "initialize" { + t.Fatalf("expected initialize method, got %s", request.Method) + } + + // Verify that elicitation capability is included in the request + paramsBytes, err := json.Marshal(request.Params) + if err != nil { + t.Fatalf("failed to marshal params: %v", err) + } + + var initParams struct { + Capabilities mcp.ClientCapabilities `json:"capabilities"` + } + if err := json.Unmarshal(paramsBytes, &initParams); err != nil { + t.Fatalf("failed to unmarshal params: %v", err) + } + + if initParams.Capabilities.Elicitation == nil { + t.Error("expected elicitation capability to be declared") + } + + // Return successful initialization response + result := mcp.InitializeResult{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ServerInfo: mcp.Implementation{ + Name: "test-server", + Version: "1.0.0", + }, + Capabilities: mcp.ServerCapabilities{}, + } + + resultBytes, _ := json.Marshal(result) + return &transport.JSONRPCResponse{ + ID: request.ID, + Result: json.RawMessage(resultBytes), + }, nil + }, + sendNotificationFunc: func(ctx context.Context, notification mcp.JSONRPCNotification) error { + return nil + }, + } + + handler := &mockElicitationHandler{} + client := NewClient(mockTransport, WithElicitationHandler(handler)) + + err := client.Start(context.Background()) + if err != nil { + t.Fatalf("failed to start client: %v", err) + } + + _, err = client.Initialize(context.Background(), mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{}, + }, + }) + + if err != nil { + t.Fatalf("failed to initialize: %v", err) + } +} + +// mockElicitationTransport implements transport.Interface for testing +type mockElicitationTransport struct { + sendRequestFunc func(context.Context, transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) + sendNotificationFunc func(context.Context, mcp.JSONRPCNotification) error +} + +func (m *mockElicitationTransport) Start(ctx context.Context) error { + return nil +} + +func (m *mockElicitationTransport) Close() error { + return nil +} + +func (m *mockElicitationTransport) SendRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + if m.sendRequestFunc != nil { + return m.sendRequestFunc(ctx, request) + } + return nil, nil +} + +func (m *mockElicitationTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { + if m.sendNotificationFunc != nil { + return m.sendNotificationFunc(ctx, notification) + } + return nil +} + +func (m *mockElicitationTransport) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) { +} + +func (m *mockElicitationTransport) GetSessionId() string { + return "mock-session" +} diff --git a/client/inprocess_elicitation_test.go b/client/inprocess_elicitation_test.go new file mode 100644 index 00000000..4718341c --- /dev/null +++ b/client/inprocess_elicitation_test.go @@ -0,0 +1,206 @@ +package client + +import ( + "context" + "testing" + + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// MockElicitationHandler implements ElicitationHandler for testing +type MockElicitationHandler struct { + // Track calls for verification + CallCount int + LastRequest mcp.ElicitationRequest +} + +func (h *MockElicitationHandler) Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + h.CallCount++ + h.LastRequest = request + + // Simulate user accepting and providing data + return &mcp.ElicitationResult{ + Response: mcp.ElicitationResponse{ + Type: mcp.ElicitationResponseTypeAccept, + Value: map[string]interface{}{ + "response": "User provided data", + "accepted": true, + }, + }, + }, nil +} + +func TestInProcessElicitation(t *testing.T) { + // Create server with elicitation enabled + mcpServer := server.NewMCPServer("test-server", "1.0.0", server.WithElicitation()) + + // Add a tool that uses elicitation + mcpServer.AddTool(mcp.Tool{ + Name: "test_elicitation", + Description: "Test elicitation functionality", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "action": map[string]any{ + "type": "string", + "description": "Action to perform", + }, + }, + Required: []string{"action"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + action, err := request.RequireString("action") + if err != nil { + return nil, err + } + + // Create elicitation request + elicitationRequest := mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: "Need additional information for " + action, + RequestedSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "confirm": map[string]interface{}{ + "type": "boolean", + "description": "Confirm the action", + }, + "details": map[string]interface{}{ + "type": "string", + "description": "Additional details", + }, + }, + "required": []string{"confirm"}, + }, + }, + } + + // Request elicitation from client + result, err := mcpServer.RequestElicitation(ctx, elicitationRequest) + if err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "Elicitation failed: " + err.Error(), + }, + }, + IsError: true, + }, nil + } + + // Handle the response + var responseText string + switch result.Response.Type { + case mcp.ElicitationResponseTypeAccept: + responseText = "User accepted and provided data" + case mcp.ElicitationResponseTypeDecline: + responseText = "User declined to provide information" + case mcp.ElicitationResponseTypeCancel: + responseText = "User cancelled the request" + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: responseText, + }, + }, + }, nil + }) + + // Create handler for elicitation + mockHandler := &MockElicitationHandler{} + + // Create in-process client with elicitation handler + client, err := NewInProcessClientWithElicitationHandler(mcpServer, mockHandler) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Start the client + if err := client.Start(context.Background()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize the client + _, err = client.Initialize(context.Background(), mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{ + Elicitation: &struct{}{}, + }, + }, + }) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + // Call the tool that triggers elicitation + result, err := client.CallTool(context.Background(), mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "test_elicitation", + Arguments: map[string]any{ + "action": "test-action", + }, + }, + }) + + if err != nil { + t.Fatalf("Failed to call tool: %v", err) + } + + // Verify the result + if len(result.Content) == 0 { + t.Fatal("Expected content in result") + } + + textContent, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatal("Expected text content") + } + + if textContent.Text != "User accepted and provided data" { + t.Errorf("Unexpected result: %s", textContent.Text) + } + + // Verify the handler was called + if mockHandler.CallCount != 1 { + t.Errorf("Expected handler to be called once, got %d", mockHandler.CallCount) + } + + if mockHandler.LastRequest.Params.Message != "Need additional information for test-action" { + t.Errorf("Unexpected elicitation message: %s", mockHandler.LastRequest.Params.Message) + } +} + +// NewInProcessClientWithElicitationHandler creates an in-process client with elicitation support +func NewInProcessClientWithElicitationHandler(server *server.MCPServer, handler ElicitationHandler) (*Client, error) { + // Create a wrapper that implements server.ElicitationHandler + serverHandler := &inProcessElicitationHandlerWrapper{handler: handler} + + inProcessTransport := transport.NewInProcessTransportWithOptions(server, + transport.WithElicitationHandler(serverHandler)) + + client := NewClient(inProcessTransport) + client.elicitationHandler = handler + + return client, nil +} + +// inProcessElicitationHandlerWrapper wraps client.ElicitationHandler to implement server.ElicitationHandler +type inProcessElicitationHandlerWrapper struct { + handler ElicitationHandler +} + +func (w *inProcessElicitationHandlerWrapper) Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + return w.handler.Elicit(ctx, request) +} diff --git a/client/transport/inprocess.go b/client/transport/inprocess.go index 59c70940..3757664a 100644 --- a/client/transport/inprocess.go +++ b/client/transport/inprocess.go @@ -11,10 +11,11 @@ import ( ) type InProcessTransport struct { - server *server.MCPServer - samplingHandler server.SamplingHandler - session *server.InProcessSession - sessionID string + server *server.MCPServer + samplingHandler server.SamplingHandler + elicitationHandler server.ElicitationHandler + session *server.InProcessSession + sessionID string onNotification func(mcp.JSONRPCNotification) notifyMu sync.RWMutex @@ -28,6 +29,12 @@ func WithSamplingHandler(handler server.SamplingHandler) InProcessOption { } } +func WithElicitationHandler(handler server.ElicitationHandler) InProcessOption { + return func(t *InProcessTransport) { + t.elicitationHandler = handler + } +} + func NewInProcessTransport(server *server.MCPServer) *InProcessTransport { return &InProcessTransport{ server: server, @@ -48,9 +55,9 @@ func NewInProcessTransportWithOptions(server *server.MCPServer, opts ...InProces } func (c *InProcessTransport) Start(ctx context.Context) error { - // Create and register session if we have a sampling handler - if c.samplingHandler != nil { - c.session = server.NewInProcessSession(c.sessionID, c.samplingHandler) + // Create and register session if we have handlers + if c.samplingHandler != nil || c.elicitationHandler != nil { + c.session = server.NewInProcessSessionWithHandlers(c.sessionID, c.samplingHandler, c.elicitationHandler) if err := c.server.RegisterSession(ctx, c.session); err != nil { return fmt.Errorf("failed to register session: %w", err) } diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 268aeb34..40a5bc69 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -605,7 +605,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) { connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second) err := c.createGETConnectionToServer(connectCtx) cancel() - + if errors.Is(err, ErrGetMethodNotAllowed) { // server does not support listening c.logger.Errorf("server does not support listening") @@ -621,7 +621,7 @@ func (c *StreamableHTTP) listenForever(ctx context.Context) { if err != nil { c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err) } - + // Use context-aware sleep select { case <-time.After(retryInterval): @@ -704,15 +704,15 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON // Create a new context with timeout for request handling, respecting parent context requestCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - + response, err := handler(requestCtx, request) if err != nil { c.logger.Errorf("error handling request %s: %v", request.Method, err) - + // Determine appropriate JSON-RPC error code based on error type var errorCode int var errorMessage string - + // Check for specific sampling-related errors if errors.Is(err, context.Canceled) { errorCode = -32800 // Request cancelled @@ -731,7 +731,7 @@ func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSON errorMessage = err.Error() } } - + // Send error response errorResponse := &JSONRPCResponse{ JSONRPC: "2.0", diff --git a/examples/elicitation/main.go b/examples/elicitation/main.go new file mode 100644 index 00000000..f3f29f3d --- /dev/null +++ b/examples/elicitation/main.go @@ -0,0 +1,215 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "sync/atomic" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// demoElicitationHandler demonstrates how to use elicitation in a tool +func demoElicitationHandler(s *server.MCPServer) server.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Create an elicitation request to get project details + elicitationRequest := mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: "I need some information to set up your project. Please provide the project details.", + RequestedSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "projectName": map[string]interface{}{ + "type": "string", + "description": "Name of the project", + "minLength": 1, + }, + "framework": map[string]interface{}{ + "type": "string", + "description": "Frontend framework to use", + "enum": []string{"react", "vue", "angular", "none"}, + }, + "includeTests": map[string]interface{}{ + "type": "boolean", + "description": "Include test setup", + "default": true, + }, + }, + "required": []string{"projectName"}, + }, + }, + } + + // Request elicitation from the client + result, err := s.RequestElicitation(ctx, elicitationRequest) + if err != nil { + return nil, fmt.Errorf("failed to request elicitation: %w", err) + } + + // Handle the user's response + switch result.Response.Type { + case mcp.ElicitationResponseTypeAccept: + // User provided the information + data, ok := result.Response.Value.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("unexpected response format") + } + + projectName, ok := data["projectName"].(string) + if !ok || projectName == "" { + return nil, fmt.Errorf("invalid or missing 'projectName' in elicitation response") + } + + framework := "none" + if fw, ok := data["framework"].(string); ok { + framework = fw + } + includeTests := true + if tests, ok := data["includeTests"].(bool); ok { + includeTests = tests + } + + // Create project based on user input + message := fmt.Sprintf( + "Created project '%s' with framework: %s, tests: %v", + projectName, framework, includeTests, + ) + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent(message), + }, + }, nil + + case mcp.ElicitationResponseTypeDecline: + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent("Project creation cancelled - user declined to provide information"), + }, + }, nil + + case mcp.ElicitationResponseTypeCancel: + return nil, fmt.Errorf("project creation cancelled by user") + + default: + return nil, fmt.Errorf("unexpected response type: %s", result.Response.Type) + } + } +} + +var requestCount atomic.Int32 + +func main() { + // Create server with elicitation capability + mcpServer := server.NewMCPServer( + "elicitation-demo-server", + "1.0.0", + server.WithElicitation(), // Enable elicitation + ) + + // Add a tool that uses elicitation + mcpServer.AddTool( + mcp.NewTool( + "create_project", + mcp.WithDescription("Creates a new project with user-specified configuration"), + ), + demoElicitationHandler(mcpServer), + ) + + // Add another tool that demonstrates conditional elicitation + mcpServer.AddTool( + mcp.NewTool( + "process_data", + mcp.WithDescription("Processes data with optional user confirmation"), + mcp.WithString("data", mcp.Required(), mcp.Description("Data to process")), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + data := request.GetString("data", "") + // Only request elicitation if data seems sensitive + if len(data) > 100 { + elicitationRequest := mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: fmt.Sprintf("The data is %d characters long. Do you want to proceed with processing?", len(data)), + RequestedSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "proceed": map[string]interface{}{ + "type": "boolean", + "description": "Whether to proceed with processing", + }, + "reason": map[string]interface{}{ + "type": "string", + "description": "Optional reason for your decision", + }, + }, + "required": []string{"proceed"}, + }, + }, + } + + result, err := mcpServer.RequestElicitation(ctx, elicitationRequest) + if err != nil { + return nil, fmt.Errorf("failed to get confirmation: %w", err) + } + + if result.Response.Type != mcp.ElicitationResponseTypeAccept { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent("Processing cancelled by user"), + }, + }, nil + } + + responseData, ok := result.Response.Value.(map[string]interface{}) + if !ok { + responseData = make(map[string]interface{}) + } + + if proceed, ok := responseData["proceed"].(bool); !ok || !proceed { + reason := "No reason provided" + if r, ok := responseData["reason"].(string); ok && r != "" { + reason = r + } + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent(fmt.Sprintf("Processing declined: %s", reason)), + }, + }, nil + } + } + + // Process the data + processed := fmt.Sprintf("Processed %d characters of data", len(data)) + count := requestCount.Add(1) + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent(fmt.Sprintf("%s (request #%d)", processed, count)), + }, + }, nil + }, + ) + + // Create and start stdio server + stdioServer := server.NewStdioServer(mcpServer) + + // Handle graceful shutdown + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt) + + go func() { + <-sigChan + cancel() + }() + + fmt.Fprintln(os.Stderr, "Elicitation demo server started") + if err := stdioServer.Listen(ctx, os.Stdin, os.Stdout); err != nil { + log.Fatal(err) + } +} diff --git a/mcp/types.go b/mcp/types.go index f871b7d9..39dc811d 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -56,6 +56,10 @@ const ( // https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging MethodSetLogLevel MCPMethod = "logging/setLevel" + // MethodElicitationCreate requests additional information from the user during interactions. + // https://modelcontextprotocol.io/docs/concepts/elicitation + MethodElicitationCreate MCPMethod = "elicitation/create" + // MethodNotificationResourcesListChanged notifies when the list of available resources changes. // https://modelcontextprotocol.io/specification/2025-03-26/server/resources#list-changed-notification MethodNotificationResourcesListChanged = "notifications/resources/list_changed" @@ -462,6 +466,8 @@ type ClientCapabilities struct { } `json:"roots,omitempty"` // Present if the client supports sampling from an LLM. Sampling *struct{} `json:"sampling,omitempty"` + // Present if the client supports elicitation requests from the server. + Elicitation *struct{} `json:"elicitation,omitempty"` } // ServerCapabilities represents capabilities that a server may support. Known @@ -492,6 +498,8 @@ type ServerCapabilities struct { // Whether this server supports notifications for changes to the tool list. ListChanged bool `json:"listChanged,omitempty"` } `json:"tools,omitempty"` + // Present if the server supports elicitation requests to the client. + Elicitation *struct{} `json:"elicitation,omitempty"` } // Implementation describes the name and version of an MCP implementation. @@ -814,6 +822,54 @@ func (l LoggingLevel) ShouldSendTo(minLevel LoggingLevel) bool { return ia >= ib } +/* Elicitation */ + +// ElicitationRequest is a request from the server to the client to request additional +// information from the user during an interaction. +type ElicitationRequest struct { + Request + Params ElicitationParams `json:"params"` +} + +// ElicitationParams contains the parameters for an elicitation request. +type ElicitationParams struct { + // A human-readable message explaining what information is being requested and why. + Message string `json:"message"` + // A JSON Schema defining the expected structure of the user's response. + RequestedSchema any `json:"requestedSchema"` +} + +// ElicitationResult represents the result of an elicitation request. +type ElicitationResult struct { + Result + // The user's response, which could be: + // - The requested information (if user accepted) + // - A decline indicator (if user declined) + // - A cancel indicator (if user cancelled) + Response ElicitationResponse `json:"response"` +} + +// ElicitationResponse represents the user's response to an elicitation request. +type ElicitationResponse struct { + // Type indicates whether the user accepted, declined, or cancelled. + Type ElicitationResponseType `json:"type"` + // Value contains the user's response data if they accepted. + // Should conform to the requestedSchema from the ElicitationRequest. + Value any `json:"value,omitempty"` +} + +// ElicitationResponseType indicates how the user responded to an elicitation request. +type ElicitationResponseType string + +const ( + // ElicitationResponseTypeAccept indicates the user provided the requested information. + ElicitationResponseTypeAccept ElicitationResponseType = "accept" + // ElicitationResponseTypeDecline indicates the user explicitly declined to provide information. + ElicitationResponseTypeDecline ElicitationResponseType = "decline" + // ElicitationResponseTypeCancel indicates the user cancelled without making a choice. + ElicitationResponseTypeCancel ElicitationResponseType = "cancel" +) + /* Sampling */ const ( diff --git a/server/elicitation.go b/server/elicitation.go new file mode 100644 index 00000000..8deee383 --- /dev/null +++ b/server/elicitation.go @@ -0,0 +1,25 @@ +package server + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// RequestElicitation sends an elicitation request to the client. +// The client must have declared elicitation capability during initialization. +// The session must implement SessionWithElicitation to support this operation. +func (s *MCPServer) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + session := ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no active session") + } + + // Check if the session supports elicitation requests + if elicitationSession, ok := session.(SessionWithElicitation); ok { + return elicitationSession.RequestElicitation(ctx, request) + } + + return nil, fmt.Errorf("session does not support elicitation") +} diff --git a/server/elicitation_test.go b/server/elicitation_test.go new file mode 100644 index 00000000..ed6feb51 --- /dev/null +++ b/server/elicitation_test.go @@ -0,0 +1,263 @@ +package server + +import ( + "context" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockBasicSession implements ClientSession for testing (without elicitation support) +type mockBasicSession struct { + sessionID string +} + +func (m *mockBasicSession) SessionID() string { + return m.sessionID +} + +func (m *mockBasicSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return make(chan mcp.JSONRPCNotification, 1) +} + +func (m *mockBasicSession) Initialize() {} + +func (m *mockBasicSession) Initialized() bool { + return true +} + +// mockElicitationSession implements SessionWithElicitation for testing +type mockElicitationSession struct { + sessionID string + result *mcp.ElicitationResult + err error +} + +func (m *mockElicitationSession) SessionID() string { + return m.sessionID +} + +func (m *mockElicitationSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return make(chan mcp.JSONRPCNotification, 1) +} + +func (m *mockElicitationSession) Initialize() {} + +func (m *mockElicitationSession) Initialized() bool { + return true +} + +func (m *mockElicitationSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + if m.err != nil { + return nil, m.err + } + return m.result, nil +} + +func TestMCPServer_RequestElicitation_NoSession(t *testing.T) { + server := NewMCPServer("test", "1.0.0") + server.capabilities.elicitation = mcp.ToBoolPtr(true) + + request := mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: "Need some information", + RequestedSchema: map[string]interface{}{ + "type": "object", + }, + }, + } + + _, err := server.RequestElicitation(context.Background(), request) + + if err == nil { + t.Error("expected error when no session available") + } + + expectedError := "no active session" + if err.Error() != expectedError { + t.Errorf("expected error %q, got %q", expectedError, err.Error()) + } +} + +func TestMCPServer_RequestElicitation_SessionDoesNotSupportElicitation(t *testing.T) { + server := NewMCPServer("test", "1.0.0", WithElicitation()) + + // Use a regular session that doesn't implement SessionWithElicitation + mockSession := &mockBasicSession{sessionID: "test-session"} + + ctx := context.Background() + ctx = server.WithContext(ctx, mockSession) + + request := mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: "Need some information", + RequestedSchema: map[string]interface{}{ + "type": "object", + }, + }, + } + + _, err := server.RequestElicitation(ctx, request) + + if err == nil { + t.Error("expected error when session doesn't support elicitation") + } + + expectedError := "session does not support elicitation" + if err.Error() != expectedError { + t.Errorf("expected error %q, got %q", expectedError, err.Error()) + } +} + +func TestMCPServer_RequestElicitation_Success(t *testing.T) { + server := NewMCPServer("test", "1.0.0", WithElicitation()) + + // Create a mock elicitation session + mockSession := &mockElicitationSession{ + sessionID: "test-session", + result: &mcp.ElicitationResult{ + Response: mcp.ElicitationResponse{ + Type: mcp.ElicitationResponseTypeAccept, + Value: map[string]interface{}{ + "projectName": "my-project", + "framework": "react", + }, + }, + }, + } + + // Create context with session + ctx := context.Background() + ctx = server.WithContext(ctx, mockSession) + + request := mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: "Please provide project details", + RequestedSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "projectName": map[string]interface{}{"type": "string"}, + "framework": map[string]interface{}{"type": "string"}, + }, + }, + }, + } + + result, err := server.RequestElicitation(ctx, request) + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if result == nil { + t.Error("expected result, got nil") + return + } + + if result.Response.Type != mcp.ElicitationResponseTypeAccept { + t.Errorf("expected response type %q, got %q", mcp.ElicitationResponseTypeAccept, result.Response.Type) + } + + value, ok := result.Response.Value.(map[string]interface{}) + if !ok { + t.Error("expected value to be a map") + return + } + + if value["projectName"] != "my-project" { + t.Errorf("expected projectName %q, got %q", "my-project", value["projectName"]) + } +} + +func TestRequestElicitation(t *testing.T) { + tests := []struct { + name string + session ClientSession + request mcp.ElicitationRequest + expectedError string + expectedType mcp.ElicitationResponseType + }{ + { + name: "successful elicitation with accept", + session: &mockElicitationSession{ + sessionID: "test-1", + result: &mcp.ElicitationResult{ + Response: mcp.ElicitationResponse{ + Type: mcp.ElicitationResponseTypeAccept, + Value: map[string]interface{}{ + "name": "test-project", + "framework": "react", + }, + }, + }, + }, + request: mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: "Please provide project details", + RequestedSchema: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "name": map[string]interface{}{"type": "string"}, + "framework": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + expectedType: mcp.ElicitationResponseTypeAccept, + }, + { + name: "elicitation declined by user", + session: &mockElicitationSession{ + sessionID: "test-2", + result: &mcp.ElicitationResult{ + Response: mcp.ElicitationResponse{ + Type: mcp.ElicitationResponseTypeDecline, + }, + }, + }, + request: mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: "Need some info", + RequestedSchema: map[string]interface{}{"type": "object"}, + }, + }, + expectedType: mcp.ElicitationResponseTypeDecline, + }, + { + name: "session does not support elicitation", + session: &fakeSession{sessionID: "test-3"}, + request: mcp.ElicitationRequest{ + Params: mcp.ElicitationParams{ + Message: "Need info", + RequestedSchema: map[string]interface{}{"type": "object"}, + }, + }, + expectedError: "session does not support elicitation", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := NewMCPServer("test", "1.0", WithElicitation()) + ctx := server.WithContext(context.Background(), tt.session) + + result, err := server.RequestElicitation(ctx, tt.request) + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + return + } + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, tt.expectedType, result.Response.Type) + + if tt.expectedType == mcp.ElicitationResponseTypeAccept { + assert.NotNil(t, result.Response.Value) + } + }) + } +} diff --git a/server/inprocess_session.go b/server/inprocess_session.go index daaf28a5..c6fddc60 100644 --- a/server/inprocess_session.go +++ b/server/inprocess_session.go @@ -15,6 +15,11 @@ type SamplingHandler interface { CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) } +// ElicitationHandler defines the interface for handling elicitation requests from servers. +type ElicitationHandler interface { + Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) +} + type InProcessSession struct { sessionID string notifications chan mcp.JSONRPCNotification @@ -23,6 +28,7 @@ type InProcessSession struct { clientInfo atomic.Value clientCapabilities atomic.Value samplingHandler SamplingHandler + elicitationHandler ElicitationHandler mu sync.RWMutex } @@ -34,6 +40,15 @@ func NewInProcessSession(sessionID string, samplingHandler SamplingHandler) *InP } } +func NewInProcessSessionWithHandlers(sessionID string, samplingHandler SamplingHandler, elicitationHandler ElicitationHandler) *InProcessSession { + return &InProcessSession{ + sessionID: sessionID, + notifications: make(chan mcp.JSONRPCNotification, 100), + samplingHandler: samplingHandler, + elicitationHandler: elicitationHandler, + } +} + func (s *InProcessSession) SessionID() string { return s.sessionID } @@ -101,6 +116,18 @@ func (s *InProcessSession) RequestSampling(ctx context.Context, request mcp.Crea return handler.CreateMessage(ctx, request) } +func (s *InProcessSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + s.mu.RLock() + handler := s.elicitationHandler + s.mu.RUnlock() + + if handler == nil { + return nil, fmt.Errorf("no elicitation handler available") + } + + return handler.Elicit(ctx, request) +} + // GenerateInProcessSessionID generates a unique session ID for inprocess clients func GenerateInProcessSessionID() string { return fmt.Sprintf("inprocess-%d", time.Now().UnixNano()) @@ -108,8 +135,9 @@ func GenerateInProcessSessionID() string { // Ensure interface compliance var ( - _ ClientSession = (*InProcessSession)(nil) - _ SessionWithLogging = (*InProcessSession)(nil) - _ SessionWithClientInfo = (*InProcessSession)(nil) - _ SessionWithSampling = (*InProcessSession)(nil) + _ ClientSession = (*InProcessSession)(nil) + _ SessionWithLogging = (*InProcessSession)(nil) + _ SessionWithClientInfo = (*InProcessSession)(nil) + _ SessionWithSampling = (*InProcessSession)(nil) + _ SessionWithElicitation = (*InProcessSession)(nil) ) diff --git a/server/server.go b/server/server.go index 366bf661..c29bec57 100644 --- a/server/server.go +++ b/server/server.go @@ -177,10 +177,11 @@ func WithPaginationLimit(limit int) ServerOption { // serverCapabilities defines the supported features of the MCP server type serverCapabilities struct { - tools *toolCapabilities - resources *resourceCapabilities - prompts *promptCapabilities - logging *bool + tools *toolCapabilities + resources *resourceCapabilities + prompts *promptCapabilities + logging *bool + elicitation *bool sampling *bool } @@ -288,6 +289,13 @@ func WithLogging() ServerOption { } } +// WithElicitation enables elicitation capabilities for the server +func WithElicitation() ServerOption { + return func(s *MCPServer) { + s.capabilities.elicitation = mcp.ToBoolPtr(true) + } +} + // WithInstructions sets the server instructions for the client returned in the initialize response func WithInstructions(instructions string) ServerOption { return func(s *MCPServer) { @@ -628,6 +636,10 @@ func (s *MCPServer) handleInitialize( capabilities.Sampling = &struct{}{} } + if s.capabilities.elicitation != nil && *s.capabilities.elicitation { + capabilities.Elicitation = &struct{}{} + } + result := mcp.InitializeResult{ ProtocolVersion: s.protocolVersion(request.Params.ProtocolVersion), ServerInfo: mcp.Implementation{ diff --git a/server/session.go b/server/session.go index 11ee8a2f..3d11df93 100644 --- a/server/session.go +++ b/server/session.go @@ -52,6 +52,13 @@ type SessionWithClientInfo interface { SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) } +// SessionWithElicitation is an extension of ClientSession that can send elicitation requests +type SessionWithElicitation interface { + ClientSession + // RequestElicitation sends an elicitation request to the client and waits for response + RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) +} + // SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations type SessionWithStreamableHTTPConfig interface { ClientSession diff --git a/server/stdio.go b/server/stdio.go index 8c270e18..8d533f94 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -92,16 +92,17 @@ func WithQueueSize(size int) StdioOption { // stdioSession is a static client session, since stdio has only one client. type stdioSession struct { - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value - clientInfo atomic.Value // stores session-specific client info + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value // stores session-specific client info clientCapabilities atomic.Value // stores session-specific client capabilities - writer io.Writer // for sending requests to client - requestID atomic.Int64 // for generating unique request IDs - mu sync.RWMutex // protects writer - pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests - pendingMu sync.RWMutex // protects pendingRequests + writer io.Writer // for sending requests to client + requestID atomic.Int64 // for generating unique request IDs + mu sync.RWMutex // protects writer + pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests + pendingElicitations map[int64]chan *elicitationResponse // for tracking pending elicitation requests + pendingMu sync.RWMutex // protects pendingRequests and pendingElicitations } // samplingResponse represents a response to a sampling request @@ -110,6 +111,12 @@ type samplingResponse struct { err error } +// elicitationResponse represents a response to an elicitation request +type elicitationResponse struct { + result *mcp.ElicitationResult + err error +} + func (s *stdioSession) SessionID() string { return "stdio" } @@ -229,6 +236,69 @@ func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMe } } +// RequestElicitation sends an elicitation request to the client and waits for the response. +func (s *stdioSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + s.mu.RLock() + writer := s.writer + s.mu.RUnlock() + + if writer == nil { + return nil, fmt.Errorf("no writer available for sending requests") + } + + // Generate a unique request ID + id := s.requestID.Add(1) + + // Create a response channel for this request + responseChan := make(chan *elicitationResponse, 1) + s.pendingMu.Lock() + s.pendingElicitations[id] = responseChan + s.pendingMu.Unlock() + + // Cleanup function to remove the pending request + cleanup := func() { + s.pendingMu.Lock() + delete(s.pendingElicitations, id) + s.pendingMu.Unlock() + } + defer cleanup() + + // Create the JSON-RPC request + jsonRPCRequest := struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params mcp.ElicitationParams `json:"params"` + }{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Method: string(mcp.MethodElicitationCreate), + Params: request.Params, + } + + // Marshal and send the request + requestBytes, err := json.Marshal(jsonRPCRequest) + if err != nil { + return nil, fmt.Errorf("failed to marshal elicitation request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + if _, err := writer.Write(requestBytes); err != nil { + return nil, fmt.Errorf("failed to write elicitation request: %w", err) + } + + // Wait for the response or context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + return response.result, nil + } +} + // SetWriter sets the writer for sending requests to the client. func (s *stdioSession) SetWriter(writer io.Writer) { s.mu.Lock() @@ -237,15 +307,17 @@ func (s *stdioSession) SetWriter(writer io.Writer) { } var ( - _ ClientSession = (*stdioSession)(nil) - _ SessionWithLogging = (*stdioSession)(nil) - _ SessionWithClientInfo = (*stdioSession)(nil) - _ SessionWithSampling = (*stdioSession)(nil) + _ ClientSession = (*stdioSession)(nil) + _ SessionWithLogging = (*stdioSession)(nil) + _ SessionWithClientInfo = (*stdioSession)(nil) + _ SessionWithSampling = (*stdioSession)(nil) + _ SessionWithElicitation = (*stdioSession)(nil) ) var stdioSessionInstance = stdioSession{ - notifications: make(chan mcp.JSONRPCNotification, 100), - pendingRequests: make(map[int64]chan *samplingResponse), + notifications: make(chan mcp.JSONRPCNotification, 100), + pendingRequests: make(map[int64]chan *samplingResponse), + pendingElicitations: make(map[int64]chan *elicitationResponse), } // NewStdioServer creates a new stdio server wrapper around an MCPServer. @@ -445,6 +517,11 @@ func (s *StdioServer) processMessage( return nil } + // Check if this is a response to an elicitation request + if s.handleElicitationResponse(rawMessage) { + return nil + } + // Check if this is a tool call that might need sampling (and thus should be processed concurrently) var baseMessage struct { Method string `json:"method"` @@ -543,6 +620,67 @@ func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool { return true } +// handleElicitationResponse checks if the message is a response to an elicitation request +// and routes it to the appropriate pending request channel. +func (s *StdioServer) handleElicitationResponse(rawMessage json.RawMessage) bool { + return stdioSessionInstance.handleElicitationResponse(rawMessage) +} + +// handleElicitationResponse handles incoming elicitation responses for this session +func (s *stdioSession) handleElicitationResponse(rawMessage json.RawMessage) bool { + // Try to parse as a JSON-RPC response + var response struct { + JSONRPC string `json:"jsonrpc"` + ID json.Number `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` + } + + if err := json.Unmarshal(rawMessage, &response); err != nil { + return false + } + // Parse the ID as int64 + id, err := response.ID.Int64() + if err != nil { + return false + } + + // Check if we have a pending elicitation request with this ID + s.pendingMu.RLock() + responseChan, exists := s.pendingElicitations[id] + s.pendingMu.RUnlock() + + if !exists { + return false + } + + // Parse and send the response + elicitationResp := &elicitationResponse{} + + if response.Error != nil { + elicitationResp.err = fmt.Errorf("elicitation request failed: %s", response.Error.Message) + } else { + var result mcp.ElicitationResult + if err := json.Unmarshal(response.Result, &result); err != nil { + elicitationResp.err = fmt.Errorf("failed to unmarshal elicitation response: %w", err) + } else { + elicitationResp.result = &result + } + } + + // Send the response (non-blocking) + select { + case responseChan <- elicitationResp: + default: + // Channel is full or closed, ignore + } + + return true +} + // writeResponse marshals and writes a JSON-RPC response message followed by a newline. // Returns an error if marshaling or writing fails. func (s *StdioServer) writeResponse( diff --git a/server/streamable_http.go b/server/streamable_http.go index 24ec1c95..f861c319 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -237,14 +237,13 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } // Check if this is a sampling response (has result/error but no method) - isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil && + isResponse := jsonMessage.Method == "" && jsonMessage.ID != nil && (jsonMessage.Result != nil || jsonMessage.Error != nil) - isInitializeRequest := jsonMessage.Method == mcp.MethodInitialize // Handle sampling responses separately - if isSamplingResponse { - if err := s.handleSamplingResponse(w, r, jsonMessage); err != nil { + if isResponse { + if err := s.handleResponse(w, r, jsonMessage); err != nil { s.logger.Errorf("Failed to handle sampling response: %v", err) http.Error(w, "Failed to handle sampling response", http.StatusInternalServerError) } @@ -390,7 +389,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) return } defer s.server.UnregisterSession(r.Context(), sessionID) - + // Register session for sampling response delivery s.activeSessions.Store(sessionID, session) defer s.activeSessions.Delete(sessionID) @@ -437,6 +436,21 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) case <-done: return } + case elicitationReq := <-session.elicitationRequestChan: + // Send elicitation request to client via SSE + jsonrpcRequest := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(elicitationReq.requestID), + Request: mcp.Request{ + Method: string(mcp.MethodElicitationCreate), + }, + Params: elicitationReq.request.Params, + } + select { + case writeChan <- jsonrpcRequest: + case <-done: + return + } case <-done: return } @@ -525,8 +539,8 @@ func writeSSEEvent(w io.Writer, data any) error { return nil } -// handleSamplingResponse processes incoming sampling responses from clients -func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *http.Request, responseMessage struct { +// handleResponse processes incoming responses from clients +func (s *StreamableHTTPServer) handleResponse(w http.ResponseWriter, r *http.Request, responseMessage struct { ID json.RawMessage `json:"id"` Result json.RawMessage `json:"result,omitempty"` Error json.RawMessage `json:"error,omitempty"` @@ -558,7 +572,7 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r * } // Create the sampling response item - response := samplingResponseItem{ + response := responseItem{ requestID: requestID, } @@ -575,20 +589,14 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r * response.err = fmt.Errorf("sampling error %d: %s", jsonrpcError.Code, jsonrpcError.Message) } } else if responseMessage.Result != nil { - // Parse result - var result mcp.CreateMessageResult - if err := json.Unmarshal(responseMessage.Result, &result); err != nil { - response.err = fmt.Errorf("failed to parse sampling result: %v", err) - } else { - response.result = &result - } + response.result = responseMessage.Result } else { response.err = fmt.Errorf("sampling response has neither result nor error") } // Find the corresponding session and deliver the response // The response is delivered to the specific session identified by sessionID - if err := s.deliverSamplingResponse(sessionID, response); err != nil { + if err := s.deliverResponse(sessionID, response); err != nil { s.logger.Errorf("Failed to deliver sampling response: %v", err) http.Error(w, "Failed to deliver response", http.StatusInternalServerError) return err @@ -600,7 +608,7 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r * } // deliverSamplingResponse delivers a sampling response to the appropriate session -func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, response samplingResponseItem) error { +func (s *StreamableHTTPServer) deliverResponse(sessionID string, response responseItem) error { // Look up the active session sessionInterface, ok := s.activeSessions.Load(sessionID) if !ok { @@ -613,12 +621,12 @@ func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, respons } // Look up the dedicated response channel for this specific request - responseChannelInterface, exists := session.samplingRequests.Load(response.requestID) + responseChannelInterface, exists := session.requests.Load(response.requestID) if !exists { return fmt.Errorf("no pending request found for session %s, request %d", sessionID, response.requestID) } - responseChan, ok := responseChannelInterface.(chan samplingResponseItem) + responseChan, ok := responseChannelInterface.(chan responseItem) if !ok { return fmt.Errorf("invalid response channel type for session %s, request %d", sessionID, response.requestID) } @@ -723,15 +731,22 @@ func (s *sessionToolsStore) delete(sessionID string) { type samplingRequestItem struct { requestID int64 request mcp.CreateMessageRequest - response chan samplingResponseItem + response chan responseItem } -type samplingResponseItem struct { +type responseItem struct { requestID int64 - result *mcp.CreateMessageResult + result json.RawMessage err error } +// Elicitation support types for HTTP transport +type elicitationRequestItem struct { + requestID int64 + request mcp.ElicitationRequest + response chan responseItem +} + // streamableHttpSession is a session for streamable-http transport // When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler. // When in GET handlers(listening), it's a real session, and will be registered in the MCP server. @@ -743,18 +758,21 @@ type streamableHttpSession struct { logLevels *sessionLogLevelsStore // Sampling support for bidirectional communication - samplingRequestChan chan samplingRequestItem // server -> client sampling requests - samplingRequests sync.Map // requestID -> pending sampling request context - requestIDCounter atomic.Int64 // for generating unique request IDs + samplingRequestChan chan samplingRequestItem // server -> client sampling requests + elicitationRequestChan chan elicitationRequestItem // server -> client elicitation requests + + requests sync.Map // requestID -> pending request context + requestIDCounter atomic.Int64 // for generating unique request IDs } func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession { s := &streamableHttpSession{ - sessionID: sessionID, - notificationChannel: make(chan mcp.JSONRPCNotification, 100), - tools: toolStore, - logLevels: levels, - samplingRequestChan: make(chan samplingRequestItem, 10), + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + tools: toolStore, + logLevels: levels, + samplingRequestChan: make(chan samplingRequestItem, 10), + elicitationRequestChan: make(chan elicitationRequestItem, 10), } return s } @@ -810,21 +828,21 @@ var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil) func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { // Generate unique request ID requestID := s.requestIDCounter.Add(1) - + // Create response channel for this specific request - responseChan := make(chan samplingResponseItem, 1) - + responseChan := make(chan responseItem, 1) + // Create the sampling request item samplingRequest := samplingRequestItem{ requestID: requestID, request: request, response: responseChan, } - + // Store the pending request - s.samplingRequests.Store(requestID, responseChan) - defer s.samplingRequests.Delete(requestID) - + s.requests.Store(requestID, responseChan) + defer s.requests.Delete(requestID) + // Send the sampling request via the channel (non-blocking) select { case s.samplingRequestChan <- samplingRequest: @@ -834,20 +852,70 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp default: return nil, fmt.Errorf("sampling request queue is full - server overloaded") } - + + // Wait for response or context cancellation + select { + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + var result mcp.CreateMessageResult + if err := json.Unmarshal(response.result, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal sampling response: %v", err) + } + return &result, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// RequestElicitation implements SessionWithElicitation interface for HTTP transport +func (s *streamableHttpSession) RequestElicitation(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) { + // Generate unique request ID + requestID := s.requestIDCounter.Add(1) + + // Create response channel for this specific request + responseChan := make(chan responseItem, 1) + + // Create the sampling request item + elicitationRequest := elicitationRequestItem{ + requestID: requestID, + request: request, + response: responseChan, + } + + // Store the pending request + s.requests.Store(requestID, responseChan) + defer s.requests.Delete(requestID) + + // Send the sampling request via the channel (non-blocking) + select { + case s.elicitationRequestChan <- elicitationRequest: + // Request queued successfully + case <-ctx.Done(): + return nil, ctx.Err() + default: + return nil, fmt.Errorf("elicitation request queue is full - server overloaded") + } + // Wait for response or context cancellation select { case response := <-responseChan: if response.err != nil { return nil, response.err } - return response.result, nil + var result mcp.ElicitationResult + if err := json.Unmarshal(response.result, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal elicitation response: %v", err) + } + return &result, nil case <-ctx.Done(): return nil, ctx.Err() } } var _ SessionWithSampling = (*streamableHttpSession)(nil) +var _ SessionWithElicitation = (*streamableHttpSession)(nil) // --- session id manager --- diff --git a/server/streamable_http_sampling_test.go b/server/streamable_http_sampling_test.go index 4cf57838..c8e5e10c 100644 --- a/server/streamable_http_sampling_test.go +++ b/server/streamable_http_sampling_test.go @@ -185,7 +185,7 @@ func TestStreamableHTTPServer_SamplingQueueFull(t *testing.T) { session.samplingRequestChan <- samplingRequestItem{ requestID: int64(i), request: mcp.CreateMessageRequest{}, - response: make(chan samplingResponseItem, 1), + response: make(chan responseItem, 1), } } @@ -213,4 +213,4 @@ func TestStreamableHTTPServer_SamplingQueueFull(t *testing.T) { if !strings.Contains(err.Error(), "queue is full") { t.Errorf("Expected queue full error, got: %v", err) } -} \ No newline at end of file +}