Skip to content

Commit e215c98

Browse files
committed
feat: implement MCP elicitation support (#413)
* Add ElicitationRequest, ElicitationResult, and related types to mcp/types.go * Implement server-side RequestElicitation method with session support * Add client-side ElicitationHandler interface and request handling * Implement elicitation in stdio and in-process transports * Add comprehensive tests following sampling patterns * Create elicitation example demonstrating usage patterns * Use 'Elicitation' prefix for type names to maintain clarity
1 parent baa7153 commit e215c98

18 files changed

+1360
-102
lines changed

client/client.go

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type Client struct {
2323
clientCapabilities mcp.ClientCapabilities
2424
serverCapabilities mcp.ServerCapabilities
2525
samplingHandler SamplingHandler
26+
elicitationHandler ElicitationHandler
2627
}
2728

2829
type ClientOption func(*Client)
@@ -42,6 +43,14 @@ func WithSamplingHandler(handler SamplingHandler) ClientOption {
4243
}
4344
}
4445

46+
// WithElicitationHandler sets the elicitation handler for the client.
47+
// When set, the client will declare elicitation capability during initialization.
48+
func WithElicitationHandler(handler ElicitationHandler) ClientOption {
49+
return func(c *Client) {
50+
c.elicitationHandler = handler
51+
}
52+
}
53+
4554
// WithSession assumes a MCP Session has already been initialized
4655
func WithSession() ClientOption {
4756
return func(c *Client) {
@@ -154,6 +163,10 @@ func (c *Client) Initialize(
154163
if c.samplingHandler != nil {
155164
capabilities.Sampling = &struct{}{}
156165
}
166+
// Add elicitation capability if handler is configured
167+
if c.elicitationHandler != nil {
168+
capabilities.Elicitation = &struct{}{}
169+
}
157170

158171
// Ensure we send a params object with all required fields
159172
params := struct {
@@ -427,11 +440,13 @@ func (c *Client) Complete(
427440
}
428441

429442
// handleIncomingRequest processes incoming requests from the server.
430-
// This is the main entry point for server-to-client requests like sampling.
443+
// This is the main entry point for server-to-client requests like sampling and elicitation.
431444
func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
432445
switch request.Method {
433446
case string(mcp.MethodSamplingCreateMessage):
434447
return c.handleSamplingRequestTransport(ctx, request)
448+
case string(mcp.MethodElicitationCreate):
449+
return c.handleElicitationRequestTransport(ctx, request)
435450
default:
436451
return nil, fmt.Errorf("unsupported request method: %s", request.Method)
437452
}
@@ -484,6 +499,55 @@ func (c *Client) handleSamplingRequestTransport(ctx context.Context, request tra
484499

485500
return response, nil
486501
}
502+
503+
// handleElicitationRequestTransport handles elicitation requests at the transport level.
504+
func (c *Client) handleElicitationRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
505+
if c.elicitationHandler == nil {
506+
return nil, fmt.Errorf("no elicitation handler configured")
507+
}
508+
509+
// Parse the request parameters
510+
var params mcp.ElicitationParams
511+
if request.Params != nil {
512+
paramsBytes, err := json.Marshal(request.Params)
513+
if err != nil {
514+
return nil, fmt.Errorf("failed to marshal params: %w", err)
515+
}
516+
if err := json.Unmarshal(paramsBytes, &params); err != nil {
517+
return nil, fmt.Errorf("failed to unmarshal params: %w", err)
518+
}
519+
}
520+
521+
// Create the MCP request
522+
mcpRequest := mcp.ElicitationRequest{
523+
Request: mcp.Request{
524+
Method: string(mcp.MethodElicitationCreate),
525+
},
526+
Params: params,
527+
}
528+
529+
// Call the elicitation handler
530+
result, err := c.elicitationHandler.Elicit(ctx, mcpRequest)
531+
if err != nil {
532+
return nil, err
533+
}
534+
535+
// Marshal the result
536+
resultBytes, err := json.Marshal(result)
537+
if err != nil {
538+
return nil, fmt.Errorf("failed to marshal result: %w", err)
539+
}
540+
541+
// Create the transport response
542+
response := &transport.JSONRPCResponse{
543+
JSONRPC: mcp.JSONRPC_VERSION,
544+
ID: request.ID,
545+
Result: json.RawMessage(resultBytes),
546+
}
547+
548+
return response, nil
549+
}
550+
487551
func listByPage[T any](
488552
ctx context.Context,
489553
client *Client,

client/elicitation.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package client
2+
3+
import (
4+
"context"
5+
6+
"github.com/mark3labs/mcp-go/mcp"
7+
)
8+
9+
// ElicitationHandler defines the interface for handling elicitation requests from servers.
10+
// Clients can implement this interface to request additional information from users.
11+
type ElicitationHandler interface {
12+
// Elicit handles an elicitation request from the server and returns the user's response.
13+
// The implementation should:
14+
// 1. Present the request message to the user
15+
// 2. Validate input against the requested schema
16+
// 3. Allow the user to accept, decline, or cancel
17+
// 4. Return the appropriate response
18+
Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error)
19+
}

client/elicitation_test.go

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
package client
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"testing"
8+
9+
"github.com/mark3labs/mcp-go/client/transport"
10+
"github.com/mark3labs/mcp-go/mcp"
11+
)
12+
13+
// mockElicitationHandler implements ElicitationHandler for testing
14+
type mockElicitationHandler struct {
15+
result *mcp.ElicitationResult
16+
err error
17+
}
18+
19+
func (m *mockElicitationHandler) Elicit(ctx context.Context, request mcp.ElicitationRequest) (*mcp.ElicitationResult, error) {
20+
if m.err != nil {
21+
return nil, m.err
22+
}
23+
return m.result, nil
24+
}
25+
26+
func TestClient_HandleElicitationRequest(t *testing.T) {
27+
tests := []struct {
28+
name string
29+
handler ElicitationHandler
30+
expectedError string
31+
}{
32+
{
33+
name: "no handler configured",
34+
handler: nil,
35+
expectedError: "no elicitation handler configured",
36+
},
37+
{
38+
name: "successful elicitation - accept",
39+
handler: &mockElicitationHandler{
40+
result: &mcp.ElicitationResult{
41+
Response: mcp.ElicitationResponse{
42+
Type: mcp.ElicitationResponseTypeAccept,
43+
Value: map[string]interface{}{
44+
"name": "test-project",
45+
"framework": "react",
46+
},
47+
},
48+
},
49+
},
50+
},
51+
{
52+
name: "successful elicitation - decline",
53+
handler: &mockElicitationHandler{
54+
result: &mcp.ElicitationResult{
55+
Response: mcp.ElicitationResponse{
56+
Type: mcp.ElicitationResponseTypeDecline,
57+
},
58+
},
59+
},
60+
},
61+
{
62+
name: "successful elicitation - cancel",
63+
handler: &mockElicitationHandler{
64+
result: &mcp.ElicitationResult{
65+
Response: mcp.ElicitationResponse{
66+
Type: mcp.ElicitationResponseTypeCancel,
67+
},
68+
},
69+
},
70+
},
71+
{
72+
name: "handler returns error",
73+
handler: &mockElicitationHandler{
74+
err: fmt.Errorf("user interaction failed"),
75+
},
76+
expectedError: "user interaction failed",
77+
},
78+
}
79+
80+
for _, tt := range tests {
81+
t.Run(tt.name, func(t *testing.T) {
82+
client := &Client{elicitationHandler: tt.handler}
83+
84+
request := transport.JSONRPCRequest{
85+
ID: mcp.NewRequestId(1),
86+
Method: string(mcp.MethodElicitationCreate),
87+
Params: map[string]interface{}{
88+
"message": "Please provide project details",
89+
"requestedSchema": map[string]interface{}{
90+
"type": "object",
91+
"properties": map[string]interface{}{
92+
"name": map[string]interface{}{"type": "string"},
93+
"framework": map[string]interface{}{"type": "string"},
94+
},
95+
},
96+
},
97+
}
98+
99+
result, err := client.handleElicitationRequestTransport(context.Background(), request)
100+
101+
if tt.expectedError != "" {
102+
if err == nil {
103+
t.Errorf("expected error %q, got nil", tt.expectedError)
104+
} else if err.Error() != tt.expectedError {
105+
t.Errorf("expected error %q, got %q", tt.expectedError, err.Error())
106+
}
107+
} else {
108+
if err != nil {
109+
t.Errorf("unexpected error: %v", err)
110+
}
111+
if result == nil {
112+
t.Error("expected result, got nil")
113+
} else {
114+
// Verify the response is properly formatted
115+
var elicitationResult mcp.ElicitationResult
116+
if err := json.Unmarshal(result.Result, &elicitationResult); err != nil {
117+
t.Errorf("failed to unmarshal result: %v", err)
118+
}
119+
}
120+
}
121+
})
122+
}
123+
}
124+
125+
func TestWithElicitationHandler(t *testing.T) {
126+
handler := &mockElicitationHandler{}
127+
client := &Client{}
128+
129+
option := WithElicitationHandler(handler)
130+
option(client)
131+
132+
if client.elicitationHandler != handler {
133+
t.Error("elicitation handler not set correctly")
134+
}
135+
}
136+
137+
func TestClient_Initialize_WithElicitationHandler(t *testing.T) {
138+
mockTransport := &mockElicitationTransport{
139+
sendRequestFunc: func(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
140+
// Verify that elicitation capability is included
141+
// The client internally converts the typed params to a map for transport
142+
// So we check if we're getting the initialize request
143+
if request.Method != "initialize" {
144+
t.Fatalf("expected initialize method, got %s", request.Method)
145+
}
146+
147+
// Return successful initialization response
148+
result := mcp.InitializeResult{
149+
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
150+
ServerInfo: mcp.Implementation{
151+
Name: "test-server",
152+
Version: "1.0.0",
153+
},
154+
Capabilities: mcp.ServerCapabilities{},
155+
}
156+
157+
resultBytes, _ := json.Marshal(result)
158+
return &transport.JSONRPCResponse{
159+
ID: request.ID,
160+
Result: json.RawMessage(resultBytes),
161+
}, nil
162+
},
163+
sendNotificationFunc: func(ctx context.Context, notification mcp.JSONRPCNotification) error {
164+
return nil
165+
},
166+
}
167+
168+
handler := &mockElicitationHandler{}
169+
client := NewClient(mockTransport, WithElicitationHandler(handler))
170+
171+
err := client.Start(context.Background())
172+
if err != nil {
173+
t.Fatalf("failed to start client: %v", err)
174+
}
175+
176+
_, err = client.Initialize(context.Background(), mcp.InitializeRequest{
177+
Params: mcp.InitializeParams{
178+
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
179+
ClientInfo: mcp.Implementation{
180+
Name: "test-client",
181+
Version: "1.0.0",
182+
},
183+
Capabilities: mcp.ClientCapabilities{},
184+
},
185+
})
186+
187+
if err != nil {
188+
t.Fatalf("failed to initialize: %v", err)
189+
}
190+
}
191+
192+
// mockElicitationTransport implements transport.Interface for testing
193+
type mockElicitationTransport struct {
194+
sendRequestFunc func(context.Context, transport.JSONRPCRequest) (*transport.JSONRPCResponse, error)
195+
sendNotificationFunc func(context.Context, mcp.JSONRPCNotification) error
196+
}
197+
198+
func (m *mockElicitationTransport) Start(ctx context.Context) error {
199+
return nil
200+
}
201+
202+
func (m *mockElicitationTransport) Close() error {
203+
return nil
204+
}
205+
206+
func (m *mockElicitationTransport) SendRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
207+
if m.sendRequestFunc != nil {
208+
return m.sendRequestFunc(ctx, request)
209+
}
210+
return nil, nil
211+
}
212+
213+
func (m *mockElicitationTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
214+
if m.sendNotificationFunc != nil {
215+
return m.sendNotificationFunc(ctx, notification)
216+
}
217+
return nil
218+
}
219+
220+
func (m *mockElicitationTransport) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) {
221+
}
222+
223+
func (m *mockElicitationTransport) GetSessionId() string {
224+
return "mock-session"
225+
}

0 commit comments

Comments
 (0)