Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions server/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,8 @@ type streamableHttpSession struct {
resourceTemplates *sessionResourceTemplatesStore
upgradeToSSE atomic.Bool
logLevels *sessionLogLevelsStore
clientInfo atomic.Value // stores session-specific client info
clientCapabilities atomic.Value // stores session-specific client capabilities

// Sampling support for bidirectional communication
samplingRequestChan chan samplingRequestItem // server -> client sampling requests
Expand Down Expand Up @@ -1053,11 +1055,38 @@ func (s *streamableHttpSession) SetSessionResourceTemplates(templates map[string
s.resourceTemplates.set(s.sessionID, templates)
}

func (s *streamableHttpSession) GetClientInfo() mcp.Implementation {
if value := s.clientInfo.Load(); value != nil {
if clientInfo, ok := value.(mcp.Implementation); ok {
return clientInfo
}
}
return mcp.Implementation{}
}

func (s *streamableHttpSession) SetClientInfo(clientInfo mcp.Implementation) {
s.clientInfo.Store(clientInfo)
}

func (s *streamableHttpSession) GetClientCapabilities() mcp.ClientCapabilities {
if value := s.clientCapabilities.Load(); value != nil {
if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok {
return clientCapabilities
}
}
return mcp.ClientCapabilities{}
}

func (s *streamableHttpSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) {
s.clientCapabilities.Store(clientCapabilities)
}

var (
_ SessionWithTools = (*streamableHttpSession)(nil)
_ SessionWithResources = (*streamableHttpSession)(nil)
_ SessionWithResourceTemplates = (*streamableHttpSession)(nil)
_ SessionWithLogging = (*streamableHttpSession)(nil)
_ SessionWithClientInfo = (*streamableHttpSession)(nil)
)

func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() {
Expand Down
60 changes: 60 additions & 0 deletions server/streamable_http_client_info_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package server

import (
"testing"

"github.com/mark3labs/mcp-go/mcp"
)

func TestStreamableHttpSessionImplementsSessionWithClientInfo(t *testing.T) {
// Create the session stores
toolStore := newSessionToolsStore()
resourceStore := newSessionResourcesStore()
templatesStore := newSessionResourceTemplatesStore()
logStore := newSessionLogLevelsStore()

// Create a streamable HTTP session
session := newStreamableHttpSession("test-session", toolStore, resourceStore, templatesStore, logStore)

// Verify it implements SessionWithClientInfo
var clientSession ClientSession = session
clientInfoSession, ok := clientSession.(SessionWithClientInfo)
if !ok {
t.Fatal("streamableHttpSession should implement SessionWithClientInfo")
}

// Test GetClientInfo with no data set (should return empty)
clientInfo := clientInfoSession.GetClientInfo()
if clientInfo.Name != "" || clientInfo.Version != "" {
t.Errorf("expected empty client info, got %+v", clientInfo)
}

// Test SetClientInfo and GetClientInfo
expectedClientInfo := mcp.Implementation{
Name: "test-client",
Version: "1.0.0",
}
clientInfoSession.SetClientInfo(expectedClientInfo)

actualClientInfo := clientInfoSession.GetClientInfo()
if actualClientInfo.Name != expectedClientInfo.Name || actualClientInfo.Version != expectedClientInfo.Version {
t.Errorf("expected client info %+v, got %+v", expectedClientInfo, actualClientInfo)
}

// Test GetClientCapabilities with no data set (should return empty)
capabilities := clientInfoSession.GetClientCapabilities()
if capabilities.Sampling != nil || capabilities.Roots != nil {
t.Errorf("expected empty client capabilities, got %+v", capabilities)
}

// Test SetClientCapabilities and GetClientCapabilities
expectedCapabilities := mcp.ClientCapabilities{
Sampling: &struct{}{},
}
clientInfoSession.SetClientCapabilities(expectedCapabilities)

actualCapabilities := clientInfoSession.GetClientCapabilities()
if actualCapabilities.Sampling == nil {
t.Errorf("expected sampling capability to be set")
}
}
Loading