Skip to content

Commit a27a793

Browse files
authored
fix: implement SessionWithClientInfo for streamableHttpSession (#640)
- Add clientInfo and clientCapabilities fields to streamableHttpSession struct - Implement GetClientInfo, SetClientInfo, GetClientCapabilities, SetClientCapabilities methods - Add SessionWithClientInfo interface declaration - Enable client capability checking for StreamableHTTP transport in stateful mode Fixes #639
1 parent 2a23f4a commit a27a793

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed

server/streamable_http.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,8 @@ type streamableHttpSession struct {
986986
resourceTemplates *sessionResourceTemplatesStore
987987
upgradeToSSE atomic.Bool
988988
logLevels *sessionLogLevelsStore
989+
clientInfo atomic.Value // stores session-specific client info
990+
clientCapabilities atomic.Value // stores session-specific client capabilities
989991

990992
// Sampling support for bidirectional communication
991993
samplingRequestChan chan samplingRequestItem // server -> client sampling requests
@@ -1063,11 +1065,38 @@ func (s *streamableHttpSession) SetSessionResourceTemplates(templates map[string
10631065
s.resourceTemplates.set(s.sessionID, templates)
10641066
}
10651067

1068+
func (s *streamableHttpSession) GetClientInfo() mcp.Implementation {
1069+
if value := s.clientInfo.Load(); value != nil {
1070+
if clientInfo, ok := value.(mcp.Implementation); ok {
1071+
return clientInfo
1072+
}
1073+
}
1074+
return mcp.Implementation{}
1075+
}
1076+
1077+
func (s *streamableHttpSession) SetClientInfo(clientInfo mcp.Implementation) {
1078+
s.clientInfo.Store(clientInfo)
1079+
}
1080+
1081+
func (s *streamableHttpSession) GetClientCapabilities() mcp.ClientCapabilities {
1082+
if value := s.clientCapabilities.Load(); value != nil {
1083+
if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok {
1084+
return clientCapabilities
1085+
}
1086+
}
1087+
return mcp.ClientCapabilities{}
1088+
}
1089+
1090+
func (s *streamableHttpSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) {
1091+
s.clientCapabilities.Store(clientCapabilities)
1092+
}
1093+
10661094
var (
10671095
_ SessionWithTools = (*streamableHttpSession)(nil)
10681096
_ SessionWithResources = (*streamableHttpSession)(nil)
10691097
_ SessionWithResourceTemplates = (*streamableHttpSession)(nil)
10701098
_ SessionWithLogging = (*streamableHttpSession)(nil)
1099+
_ SessionWithClientInfo = (*streamableHttpSession)(nil)
10711100
)
10721101

10731102
func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() {
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package server
2+
3+
import (
4+
"testing"
5+
6+
"github.com/mark3labs/mcp-go/mcp"
7+
)
8+
9+
func TestStreamableHttpSessionImplementsSessionWithClientInfo(t *testing.T) {
10+
// Create the session stores
11+
toolStore := newSessionToolsStore()
12+
resourceStore := newSessionResourcesStore()
13+
templatesStore := newSessionResourceTemplatesStore()
14+
logStore := newSessionLogLevelsStore()
15+
16+
// Create a streamable HTTP session
17+
session := newStreamableHttpSession("test-session", toolStore, resourceStore, templatesStore, logStore)
18+
19+
// Verify it implements SessionWithClientInfo
20+
var clientSession ClientSession = session
21+
clientInfoSession, ok := clientSession.(SessionWithClientInfo)
22+
if !ok {
23+
t.Fatal("streamableHttpSession should implement SessionWithClientInfo")
24+
}
25+
26+
// Test GetClientInfo with no data set (should return empty)
27+
clientInfo := clientInfoSession.GetClientInfo()
28+
if clientInfo.Name != "" || clientInfo.Version != "" {
29+
t.Errorf("expected empty client info, got %+v", clientInfo)
30+
}
31+
32+
// Test SetClientInfo and GetClientInfo
33+
expectedClientInfo := mcp.Implementation{
34+
Name: "test-client",
35+
Version: "1.0.0",
36+
}
37+
clientInfoSession.SetClientInfo(expectedClientInfo)
38+
39+
actualClientInfo := clientInfoSession.GetClientInfo()
40+
if actualClientInfo.Name != expectedClientInfo.Name || actualClientInfo.Version != expectedClientInfo.Version {
41+
t.Errorf("expected client info %+v, got %+v", expectedClientInfo, actualClientInfo)
42+
}
43+
44+
// Test GetClientCapabilities with no data set (should return empty)
45+
capabilities := clientInfoSession.GetClientCapabilities()
46+
if capabilities.Sampling != nil || capabilities.Roots != nil {
47+
t.Errorf("expected empty client capabilities, got %+v", capabilities)
48+
}
49+
50+
// Test SetClientCapabilities and GetClientCapabilities
51+
expectedCapabilities := mcp.ClientCapabilities{
52+
Sampling: &struct{}{},
53+
}
54+
clientInfoSession.SetClientCapabilities(expectedCapabilities)
55+
56+
actualCapabilities := clientInfoSession.GetClientCapabilities()
57+
if actualCapabilities.Sampling == nil {
58+
t.Errorf("expected sampling capability to be set")
59+
}
60+
}

0 commit comments

Comments
 (0)