Skip to content

Commit cc3f765

Browse files
authored
Embed client capabilities into the Session (#491)
* Embed client capabilities into the Session Client sends its capabilities during the initialization step. This commit embeds the client capabilities into the client session in this step to enable subsequent executions are able to check the client capabilities to determine what actions they can perform. For instance, MCP Server checks the support of elicitation. If elicititation is supported by the client MCP Server can send elicitation request. * fix * Add document for client capability based filtering * Add return statement into the new example
1 parent 5800c20 commit cc3f765

File tree

7 files changed

+141
-18
lines changed

7 files changed

+141
-18
lines changed

server/inprocess_session.go

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@ type SamplingHandler interface {
1616
}
1717

1818
type InProcessSession struct {
19-
sessionID string
20-
notifications chan mcp.JSONRPCNotification
21-
initialized atomic.Bool
22-
loggingLevel atomic.Value
23-
clientInfo atomic.Value
24-
samplingHandler SamplingHandler
25-
mu sync.RWMutex
19+
sessionID string
20+
notifications chan mcp.JSONRPCNotification
21+
initialized atomic.Bool
22+
loggingLevel atomic.Value
23+
clientInfo atomic.Value
24+
clientCapabilities atomic.Value
25+
samplingHandler SamplingHandler
26+
mu sync.RWMutex
2627
}
2728

2829
func NewInProcessSession(sessionID string, samplingHandler SamplingHandler) *InProcessSession {
@@ -63,6 +64,19 @@ func (s *InProcessSession) SetClientInfo(clientInfo mcp.Implementation) {
6364
s.clientInfo.Store(clientInfo)
6465
}
6566

67+
func (s *InProcessSession) GetClientCapabilities() mcp.ClientCapabilities {
68+
if value := s.clientCapabilities.Load(); value != nil {
69+
if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok {
70+
return clientCapabilities
71+
}
72+
}
73+
return mcp.ClientCapabilities{}
74+
}
75+
76+
func (s *InProcessSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) {
77+
s.clientCapabilities.Store(clientCapabilities)
78+
}
79+
6680
func (s *InProcessSession) SetLogLevel(level mcp.LoggingLevel) {
6781
s.loggingLevel.Store(level)
6882
}

server/server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,8 +583,10 @@ func (s *MCPServer) handleInitialize(
583583
// Store client info if the session supports it
584584
if sessionWithClientInfo, ok := session.(SessionWithClientInfo); ok {
585585
sessionWithClientInfo.SetClientInfo(request.Params.ClientInfo)
586+
sessionWithClientInfo.SetClientCapabilities(request.Params.Capabilities)
586587
}
587588
}
589+
588590
return &result, nil
589591
}
590592

server/session.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ type SessionWithClientInfo interface {
4646
GetClientInfo() mcp.Implementation
4747
// SetClientInfo sets the client information for this session
4848
SetClientInfo(clientInfo mcp.Implementation)
49+
// GetClientCapabilities returns the client capabilities for this session
50+
GetClientCapabilities() mcp.ClientCapabilities
51+
// SetClientCapabilities sets the client capabilities for this session
52+
SetClientCapabilities(clientCapabilities mcp.ClientCapabilities)
4953
}
5054

5155
// SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations

server/session_test.go

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ type sessionTestClientWithClientInfo struct {
106106
notificationChannel chan mcp.JSONRPCNotification
107107
initialized bool
108108
clientInfo atomic.Value
109+
clientCapabilities atomic.Value
109110
}
110111

111112
func (f *sessionTestClientWithClientInfo) SessionID() string {
@@ -137,6 +138,19 @@ func (f *sessionTestClientWithClientInfo) SetClientInfo(clientInfo mcp.Implement
137138
f.clientInfo.Store(clientInfo)
138139
}
139140

141+
func (f *sessionTestClientWithClientInfo) GetClientCapabilities() mcp.ClientCapabilities {
142+
if value := f.clientCapabilities.Load(); value != nil {
143+
if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok {
144+
return clientCapabilities
145+
}
146+
}
147+
return mcp.ClientCapabilities{}
148+
}
149+
150+
func (f *sessionTestClientWithClientInfo) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) {
151+
f.clientCapabilities.Store(clientCapabilities)
152+
}
153+
140154
// sessionTestClientWithTools implements the SessionWithLogging interface for testing
141155
type sessionTestClientWithLogging struct {
142156
sessionID string
@@ -888,7 +902,7 @@ func TestMCPServer_SessionToolCapabilitiesBehavior(t *testing.T) {
888902
validateServer func(t *testing.T, s *MCPServer, session *sessionTestClientWithTools)
889903
}{
890904
{
891-
name: "no tool capabilities provided",
905+
name: "no tool capabilities provided",
892906
serverOptions: []ServerOption{
893907
// No WithToolCapabilities
894908
},
@@ -1099,10 +1113,14 @@ func TestSessionWithClientInfo_Integration(t *testing.T) {
10991113
Version: "1.0.0",
11001114
}
11011115

1116+
clientCapability := mcp.ClientCapabilities{
1117+
Sampling: &struct{}{},
1118+
}
1119+
11021120
initRequest := mcp.InitializeRequest{}
11031121
initRequest.Params.ClientInfo = clientInfo
11041122
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
1105-
initRequest.Params.Capabilities = mcp.ClientCapabilities{}
1123+
initRequest.Params.Capabilities = clientCapability
11061124

11071125
sessionCtx := server.WithContext(context.Background(), session)
11081126

@@ -1125,6 +1143,10 @@ func TestSessionWithClientInfo_Integration(t *testing.T) {
11251143

11261144
assert.Equal(t, clientInfo.Name, storedClientInfo.Name, "Client name should match")
11271145
assert.Equal(t, clientInfo.Version, storedClientInfo.Version, "Client version should match")
1146+
1147+
storedClientCapabilities := sessionWithClientInfo.GetClientCapabilities()
1148+
1149+
assert.Equal(t, clientCapability, storedClientCapabilities, "Client capability should match")
11281150
}
11291151

11301152
// New test function to cover log notification functionality

server/sse.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ type sseSession struct {
3030
loggingLevel atomic.Value
3131
tools sync.Map // stores session-specific tools
3232
clientInfo atomic.Value // stores session-specific client info
33+
clientCapabilities atomic.Value // stores session-specific client capabilities
3334
}
3435

3536
// SSEContextFunc is a function that takes an existing context and the current
@@ -108,6 +109,19 @@ func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) {
108109
s.clientInfo.Store(clientInfo)
109110
}
110111

112+
func (s *sseSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) {
113+
s.clientCapabilities.Store(clientCapabilities)
114+
}
115+
116+
func (s *sseSession) GetClientCapabilities() mcp.ClientCapabilities {
117+
if value := s.clientCapabilities.Load(); value != nil {
118+
if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok {
119+
return clientCapabilities
120+
}
121+
}
122+
return mcp.ClientCapabilities{}
123+
}
124+
111125
var (
112126
_ ClientSession = (*sseSession)(nil)
113127
_ SessionWithTools = (*sseSession)(nil)

server/stdio.go

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,16 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
5252

5353
// stdioSession is a static client session, since stdio has only one client.
5454
type stdioSession struct {
55-
notifications chan mcp.JSONRPCNotification
56-
initialized atomic.Bool
57-
loggingLevel atomic.Value
58-
clientInfo atomic.Value // stores session-specific client info
59-
writer io.Writer // for sending requests to client
60-
requestID atomic.Int64 // for generating unique request IDs
61-
mu sync.RWMutex // protects writer
62-
pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests
63-
pendingMu sync.RWMutex // protects pendingRequests
55+
notifications chan mcp.JSONRPCNotification
56+
initialized atomic.Bool
57+
loggingLevel atomic.Value
58+
clientInfo atomic.Value // stores session-specific client info
59+
clientCapabilities atomic.Value // stores session-specific client capabilities
60+
writer io.Writer // for sending requests to client
61+
requestID atomic.Int64 // for generating unique request IDs
62+
mu sync.RWMutex // protects writer
63+
pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests
64+
pendingMu sync.RWMutex // protects pendingRequests
6465
}
6566

6667
// samplingResponse represents a response to a sampling request
@@ -100,6 +101,19 @@ func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) {
100101
s.clientInfo.Store(clientInfo)
101102
}
102103

104+
func (s *stdioSession) GetClientCapabilities() mcp.ClientCapabilities {
105+
if value := s.clientCapabilities.Load(); value != nil {
106+
if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok {
107+
return clientCapabilities
108+
}
109+
}
110+
return mcp.ClientCapabilities{}
111+
}
112+
113+
func (s *stdioSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) {
114+
s.clientCapabilities.Store(clientCapabilities)
115+
}
116+
103117
func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) {
104118
s.loggingLevel.Store(level)
105119
}

www/docs/pages/servers/advanced.mdx

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,59 @@ func startWithGracefulShutdown(s *server.MCPServer) {
821821
}
822822
```
823823

824+
## Client Capability Based Filtering
825+
826+
```go
827+
package main
828+
829+
import (
830+
"context"
831+
"fmt"
832+
833+
"github.com/mark3labs/mcp-go/mcp"
834+
"github.com/mark3labs/mcp-go/server"
835+
)
836+
837+
func main() {
838+
s := server.NewMCPServer("Typed Server", "1.0.0",
839+
server.WithToolCapabilities(true),
840+
)
841+
842+
s.AddTool(
843+
mcp.NewTool("calculate",
844+
mcp.WithDescription("Perform basic mathematical calculations"),
845+
mcp.WithString("operation",
846+
mcp.Required(),
847+
mcp.Enum("add", "subtract", "multiply", "divide"),
848+
mcp.Description("The operation to perform"),
849+
),
850+
mcp.WithNumber("x", mcp.Required(), mcp.Description("First number")),
851+
mcp.WithNumber("y", mcp.Required(), mcp.Description("Second number")),
852+
),
853+
handleCalculate,
854+
)
855+
856+
server.ServeStdio(s)
857+
}
858+
859+
func handleCalculate(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
860+
session := server.ClientSessionFromContext(ctx)
861+
if session == nil {
862+
return nil, fmt.Errorf("no active session")
863+
}
864+
865+
if clientSession, ok := session.(server.SessionWithClientInfo); ok {
866+
clientCapabilities := clientSession.GetClientCapabilities()
867+
if clientCapabilities.Sampling == nil {
868+
fmt.Println("sampling is not enabled in client")
869+
}
870+
}
871+
872+
// TODO: implement calculation logic
873+
return mcp.NewToolResultError("not implemented"), nil
874+
}
875+
```
876+
824877
## Sampling (Advanced)
825878

826879
Sampling is an advanced feature that allows servers to request LLM completions from clients. This enables bidirectional communication where servers can leverage client-side LLM capabilities for content generation, reasoning, and question answering.

0 commit comments

Comments
 (0)