diff --git a/server/streamable_http.go b/server/streamable_http.go index 5a596467..a87c3889 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -976,6 +976,8 @@ type streamableHttpSession struct { resourceTemplates *sessionResourceTemplatesStore upgradeToSSE atomic.Bool logLevels *sessionLogLevelsStore + clientInfo atomic.Value // stores session-specific client info + clientCapabilities atomic.Value // stores session-specific client capabilities // Sampling support for bidirectional communication samplingRequestChan chan samplingRequestItem // server -> client sampling requests @@ -1053,11 +1055,38 @@ func (s *streamableHttpSession) SetSessionResourceTemplates(templates map[string s.resourceTemplates.set(s.sessionID, templates) } +func (s *streamableHttpSession) GetClientInfo() mcp.Implementation { + if value := s.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} +} + +func (s *streamableHttpSession) SetClientInfo(clientInfo mcp.Implementation) { + s.clientInfo.Store(clientInfo) +} + +func (s *streamableHttpSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + +func (s *streamableHttpSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + var ( _ SessionWithTools = (*streamableHttpSession)(nil) _ SessionWithResources = (*streamableHttpSession)(nil) _ SessionWithResourceTemplates = (*streamableHttpSession)(nil) _ SessionWithLogging = (*streamableHttpSession)(nil) + _ SessionWithClientInfo = (*streamableHttpSession)(nil) ) func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { diff --git a/server/streamable_http_client_info_test.go b/server/streamable_http_client_info_test.go new file mode 100644 index 00000000..861065ea --- /dev/null +++ b/server/streamable_http_client_info_test.go @@ -0,0 +1,60 @@ +package server + +import ( + "testing" + + "github.com/mark3labs/mcp-go/mcp" +) + +func TestStreamableHttpSessionImplementsSessionWithClientInfo(t *testing.T) { + // Create the session stores + toolStore := newSessionToolsStore() + resourceStore := newSessionResourcesStore() + templatesStore := newSessionResourceTemplatesStore() + logStore := newSessionLogLevelsStore() + + // Create a streamable HTTP session + session := newStreamableHttpSession("test-session", toolStore, resourceStore, templatesStore, logStore) + + // Verify it implements SessionWithClientInfo + var clientSession ClientSession = session + clientInfoSession, ok := clientSession.(SessionWithClientInfo) + if !ok { + t.Fatal("streamableHttpSession should implement SessionWithClientInfo") + } + + // Test GetClientInfo with no data set (should return empty) + clientInfo := clientInfoSession.GetClientInfo() + if clientInfo.Name != "" || clientInfo.Version != "" { + t.Errorf("expected empty client info, got %+v", clientInfo) + } + + // Test SetClientInfo and GetClientInfo + expectedClientInfo := mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + clientInfoSession.SetClientInfo(expectedClientInfo) + + actualClientInfo := clientInfoSession.GetClientInfo() + if actualClientInfo.Name != expectedClientInfo.Name || actualClientInfo.Version != expectedClientInfo.Version { + t.Errorf("expected client info %+v, got %+v", expectedClientInfo, actualClientInfo) + } + + // Test GetClientCapabilities with no data set (should return empty) + capabilities := clientInfoSession.GetClientCapabilities() + if capabilities.Sampling != nil || capabilities.Roots != nil { + t.Errorf("expected empty client capabilities, got %+v", capabilities) + } + + // Test SetClientCapabilities and GetClientCapabilities + expectedCapabilities := mcp.ClientCapabilities{ + Sampling: &struct{}{}, + } + clientInfoSession.SetClientCapabilities(expectedCapabilities) + + actualCapabilities := clientInfoSession.GetClientCapabilities() + if actualCapabilities.Sampling == nil { + t.Errorf("expected sampling capability to be set") + } +} \ No newline at end of file