diff --git a/README.md b/README.md index a35a3ebe0..0c056883d 100644 --- a/README.md +++ b/README.md @@ -547,6 +547,7 @@ MCP-Go provides a robust session management system that allows you to: - Register and track client sessions - Send notifications to specific clients - Provide per-session tool customization +- Provide per-session prompt customization
Show Session Management Examples diff --git a/server/errors.go b/server/errors.go index ecbe91e5f..5edf318e4 100644 --- a/server/errors.go +++ b/server/errors.go @@ -17,6 +17,7 @@ var ( ErrSessionExists = errors.New("session already exists") ErrSessionNotInitialized = errors.New("session not properly initialized") ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools") + ErrSessionDoesNotSupportPrompts = errors.New("session does not support per-session prompts") ErrSessionDoesNotSupportLogging = errors.New("session does not support setting logging level") // Notification-related errors diff --git a/server/server.go b/server/server.go index 46e6d9c57..3a78b84b9 100644 --- a/server/server.go +++ b/server/server.go @@ -832,6 +832,34 @@ func (s *MCPServer) handleListPrompts( } s.promptsMu.RUnlock() + // Check if there are session-specific prompts + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithPrompts, ok := session.(SessionWithPrompts); ok { + if sessionPrompts := sessionWithPrompts.GetSessionPrompts(); sessionPrompts != nil { + // Override or add session-specific prompts + // We need to create a map first to merge the prompts properly + promptMap := make(map[string]mcp.Prompt) + + // Add global prompts first + for _, prompt := range prompts { + promptMap[prompt.Name] = prompt + } + + // Then override with session-specific tools + for name, serverPrompt := range sessionPrompts { + promptMap[name] = serverPrompt.Prompt + } + + // Convert back to slice + prompts = make([]mcp.Prompt, 0, len(promptMap)) + for _, prompt := range promptMap { + prompts = append(prompts, prompt) + } + } + } + } + // sort prompts by name sort.Slice(prompts, func(i, j int) bool { return prompts[i].Name < prompts[j].Name @@ -863,9 +891,28 @@ func (s *MCPServer) handleGetPrompt( id any, request mcp.GetPromptRequest, ) (*mcp.GetPromptResult, *requestError) { - s.promptsMu.RLock() - handler, ok := s.promptHandlers[request.Params.Name] - s.promptsMu.RUnlock() + // First check session-specific prompts + var handler PromptHandlerFunc + var ok bool + + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithPrompts, typeAssertOk := session.(SessionWithPrompts); typeAssertOk { + if sessionPrompts := sessionWithPrompts.GetSessionPrompts(); sessionPrompts != nil { + if serverPrompt, sessionOk := sessionPrompts[request.Params.Name]; sessionOk { + handler = serverPrompt.Handler + ok = true + } + } + } + } + + // If not found in session prompts, check global prompts + if !ok { + s.promptsMu.RLock() + handler, ok = s.promptHandlers[request.Params.Name] + s.promptsMu.RUnlock() + } if !ok { return nil, &requestError{ diff --git a/server/session.go b/server/session.go index a79da22ca..9c58c5873 100644 --- a/server/session.go +++ b/server/session.go @@ -39,6 +39,14 @@ type SessionWithTools interface { SetSessionTools(tools map[string]ServerTool) } +type SessionWithPrompts interface { + ClientSession + // GetPrompts returns the prompts specific to this session, if any + GetSessionPrompts() map[string]ServerPrompt + // SetPrompts sets prompts specific to this session + SetSessionPrompts(prompts map[string]ServerPrompt) +} + // SessionWithClientInfo is an extension of ClientSession that can store client info type SessionWithClientInfo interface { ClientSession @@ -378,3 +386,60 @@ func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error return nil } + +func (s *MCPServer) AddSessionPrompts(sessionID string, prompts ...ServerPrompt) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(SessionWithPrompts) + if !ok { + return ErrSessionDoesNotSupportPrompts + } + + s.implicitlyRegisterPromptCapabilities() + + // Get existing prompts (this should return a thread-safe copy) + sessionPrompts := session.GetSessionPrompts() + + // Create a new map to avoid concurrent modification issues + newSessionPrompts := make(map[string]ServerPrompt, len(sessionPrompts)+len(prompts)) + + // Copy existing prompts + for k, v := range sessionPrompts { + newSessionPrompts[k] = v + } + + // Add new prompts + for _, prompt := range prompts { + newSessionPrompts[prompt.Prompt.Name] = prompt + } + + // Set the prompts (this should be thread-safe) + session.SetSessionPrompts(newSessionPrompts) + + if session.Initialized() && s.capabilities.prompts != nil && s.capabilities.prompts.listChanged { + // Send notification only to this session + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/prompts/list_changed", nil); err != nil { + // Log the error but don't fail the operation + // The prompts were successfully added, but notification failed + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": "notifications/prompts/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after adding prompts: %w", err)) + }(sessionID, hooks) + } + } + } + + return nil +} + +func (s *MCPServer) AddSessionPrompt(sessionID string, prompt mcp.Prompt, handler PromptHandlerFunc) error { + return s.AddSessionPrompts(sessionID, ServerPrompt{Prompt: prompt, Handler: handler}) +} diff --git a/server/session_test.go b/server/session_test.go index 3067f4e9c..330b19f7a 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -137,7 +137,7 @@ func (f *sessionTestClientWithClientInfo) SetClientInfo(clientInfo mcp.Implement f.clientInfo.Store(clientInfo) } -// sessionTestClientWithTools implements the SessionWithLogging interface for testing +// sessionTestClientWithLogging implements the SessionWithLogging interface for testing type sessionTestClientWithLogging struct { sessionID string notificationChannel chan mcp.JSONRPCNotification @@ -172,12 +172,70 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +// sessionTestClientWithPrompts implements the SessionWithPrompts interface for testing +type sessionTestClientWithPrompts struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification + initialized bool + sessionPrompts map[string]ServerPrompt + mu sync.RWMutex // Mutex to protect concurrent access to sessionPrompts +} + +func (f *sessionTestClientWithPrompts) SessionID() string { + return f.sessionID +} + +func (f *sessionTestClientWithPrompts) NotificationChannel() chan<- mcp.JSONRPCNotification { + return f.notificationChannel +} + +func (f *sessionTestClientWithPrompts) Initialize() { + f.initialized = true +} + +func (f *sessionTestClientWithPrompts) Initialized() bool { + return f.initialized +} + +func (f *sessionTestClientWithPrompts) GetSessionPrompts() map[string]ServerPrompt { + f.mu.RLock() + defer f.mu.RUnlock() + + // Return a copy of the map to prevent concurrent modification + if f.sessionPrompts == nil { + return nil + } + + promptsCopy := make(map[string]ServerPrompt, len(f.sessionPrompts)) + for k, v := range f.sessionPrompts { + promptsCopy[k] = v + } + return promptsCopy +} + +func (f *sessionTestClientWithPrompts) SetSessionPrompts(prompts map[string]ServerPrompt) { + f.mu.Lock() + defer f.mu.Unlock() + + if prompts == nil { + f.sessionPrompts = nil + return + } + + promptsCopy := make(map[string]ServerPrompt, len(prompts)) + for k, v := range prompts { + promptsCopy[k] = v + } + f.sessionPrompts = promptsCopy +} + // Verify that all implementations satisfy their respective interfaces var ( _ ClientSession = (*sessionTestClient)(nil) _ SessionWithTools = (*sessionTestClientWithTools)(nil) _ SessionWithLogging = (*sessionTestClientWithLogging)(nil) _ SessionWithClientInfo = (*sessionTestClientWithClientInfo)(nil) + _ SessionWithPrompts = (*sessionTestClientWithPrompts)(nil) ) func TestSessionWithTools_Integration(t *testing.T) { @@ -335,6 +393,41 @@ func TestMCPServer_AddSessionTools(t *testing.T) { assert.Contains(t, session.GetSessionTools(), "session-tool") } +func TestMCPServer_AddSessionPrompts(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true)) + ctx := context.Background() + + // Create a session + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithPrompts{ + sessionID: "session-1", + notificationChannel: sessionChan, + initialized: true, + } + + // Register the session + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + // Add session-specific prompts + err = server.AddSessionPrompts(session.SessionID(), + ServerPrompt{Prompt: mcp.NewPrompt("session-prompt")}, + ) + require.NoError(t, err) + + // Check that notification was sent + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/prompts/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received") + } + + // Verify prompt was added to session + assert.Len(t, session.GetSessionPrompts(), 1) + assert.Contains(t, session.GetSessionPrompts(), "session-prompt") +} + func TestMCPServer_AddSessionTool(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) ctx := context.Background() @@ -374,6 +467,50 @@ func TestMCPServer_AddSessionTool(t *testing.T) { assert.Contains(t, session.GetSessionTools(), "session-tool-helper") } +func TestMCPServer_AddSessionPrompt(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true)) + ctx := context.Background() + + // Create a session + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithPrompts{ + sessionID: "session-1", + notificationChannel: sessionChan, + initialized: true, + } + + // Register the session + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + // Add session-specific tool using the new helper method + err = server.AddSessionPrompt( + session.SessionID(), + mcp.NewPrompt("session-prompt-helper"), + func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return mcp.NewGetPromptResult("helper result", []mcp.PromptMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{Text: "helper result"}, + }, + }), nil + }, + ) + require.NoError(t, err) + + // Check that notification was sent + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/prompts/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected notification not received") + } + + // Verify tool was added to session + assert.Len(t, session.GetSessionPrompts(), 1) + assert.Contains(t, session.GetSessionPrompts(), "session-prompt-helper") +} + func TestMCPServer_AddSessionToolsUninitialized(t *testing.T) { // This test verifies that adding tools to an uninitialized session works correctly. // @@ -465,6 +602,97 @@ func TestMCPServer_AddSessionToolsUninitialized(t *testing.T) { assert.Contains(t, session.GetSessionTools(), "initialized-tool") } +func TestMCPServer_AddSessionPromptsUninitialized(t *testing.T) { + // This test verifies that adding prompts to an uninitialized session works correctly. + // + // This scenario can occur when prompts are added during the session registration hook, + // before the session is fully initialized. In this case, we should: + // 1. Successfully add the prompts to the session + // 2. Not attempt to send a notification (since the session isn't ready) + // 3. Have the prompts available once the session is initialized + // 4. Not trigger any error hooks when adding prompts to uninitialized sessions + + // Set up error hook to track if it's called + errorChan := make(chan error) + hooks := &Hooks{} + hooks.AddOnError( + func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { + errorChan <- err + }, + ) + + server := NewMCPServer("test-server", "1.0.0", + WithPromptCapabilities(true), + WithHooks(hooks), + ) + ctx := context.Background() + + // Create an uninitialized session + sessionChan := make(chan mcp.JSONRPCNotification, 1) + session := &sessionTestClientWithPrompts{ + sessionID: "uninitialized-session", + notificationChannel: sessionChan, + initialized: false, + } + + // Register the session + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + // Add session-specific tools to the uninitialized session + err = server.AddSessionPrompts(session.SessionID(), + ServerPrompt{Prompt: mcp.NewPrompt("uninitialized-prompt")}, + ) + require.NoError(t, err) + + // Verify no errors + select { + case err := <-errorChan: + t.Error("Expected no errors, but OnError called with: ", err) + case <-time.After(25 * time.Millisecond): // no errors + } + + // Verify no notification was sent (channel should be empty) + select { + case <-sessionChan: + t.Error("Expected no notification to be sent for uninitialized session") + default: // no notifications + } + + // Verify prompt was added to session + assert.Len(t, session.GetSessionPrompts(), 1) + assert.Contains(t, session.GetSessionPrompts(), "uninitialized-prompt") + + // Initialize the session + session.Initialize() + + // Now verify that subsequent tool additions will send notifications + err = server.AddSessionPrompts(session.SessionID(), + ServerPrompt{Prompt: mcp.NewPrompt("initialized-prompt")}, + ) + require.NoError(t, err) + + // Verify no errors + select { + case err := <-errorChan: + t.Error("Expected no errors, but OnError called with:", err) + case <-time.After(200 * time.Millisecond): // No errors + } + + // Verify notification was sent for the initialized session + select { + case notification := <-sessionChan: + assert.Equal(t, "notifications/prompts/list_changed", notification.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Timeout waiting for expected notifications/prompts/list_changed notification") + } + + // Verify both tools are available + assert.Len(t, session.GetSessionPrompts(), 2) + assert.Contains(t, session.GetSessionPrompts(), "uninitialized-prompt") + assert.Contains(t, session.GetSessionPrompts(), "initialized-prompt") +} + func TestMCPServer_DeleteSessionToolsUninitialized(t *testing.T) { // This test verifies that deleting tools from an uninitialized session works correctly. // @@ -614,6 +842,78 @@ func TestMCPServer_CallSessionTool(t *testing.T) { } } +func TestMCPServer_GetSessionPrompt(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithPromptCapabilities(true)) + + // Add global prompt + server.AddPrompt(mcp.NewPrompt("test_prompt"), func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return mcp.NewGetPromptResult("global result", []mcp.PromptMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{Text: "global result"}, + }, + }), nil + }) + + // Create a session + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithPrompts{ + sessionID: "session-1", + notificationChannel: sessionChan, + initialized: true, + } + + // Register the session + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // Add session-specific prompt with the same name to override the global prompt + err = server.AddSessionPrompt( + session.SessionID(), + mcp.NewPrompt("test_prompt"), + func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return mcp.NewGetPromptResult("session result", []mcp.PromptMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{Text: "session result"}, + }, + }), nil + }, + ) + require.NoError(t, err) + + // Get the prompt using session context + sessionCtx := server.WithContext(context.Background(), session) + toolRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "prompts/get", + "params": map[string]any{ + "name": "test_prompt", + }, + } + requestBytes, err := json.Marshal(toolRequest) + if err != nil { + t.Fatalf("Failed to marshal prompt request: %v", err) + } + + response := server.HandleMessage(sessionCtx, requestBytes) + resp, ok := response.(mcp.JSONRPCResponse) + assert.True(t, ok) + + getPromptResult, ok := resp.Result.(mcp.GetPromptResult) + assert.True(t, ok) + + // Since we specify a prompt with the same name for current session, the expected text should be "session result" + if textContent, ok := getPromptResult.Messages[0].Content.(mcp.TextContent); ok { + if textContent.Text != "session result" { + t.Errorf("Expected result 'session result', got %q", textContent.Text) + } + } else { + t.Error("Expected TextContent") + } +} + func TestMCPServer_DeleteSessionTools(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) ctx := context.Background() diff --git a/server/sse.go b/server/sse.go index 416995730..aacf0cb5a 100644 --- a/server/sse.go +++ b/server/sse.go @@ -29,6 +29,7 @@ type sseSession struct { initialized atomic.Bool loggingLevel atomic.Value tools sync.Map // stores session-specific tools + prompts sync.Map // stores session-specific prompts clientInfo atomic.Value // stores session-specific client info } @@ -74,6 +75,17 @@ func (s *sseSession) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +func (s *sseSession) GetSessionPrompts() map[string]ServerPrompt { + prompts := make(map[string]ServerPrompt) + s.prompts.Range(func(key, value any) bool { + if prompt, ok := value.(ServerPrompt); ok { + prompts[key.(string)] = prompt + } + return true + }) + return prompts +} + func (s *sseSession) GetSessionTools() map[string]ServerTool { tools := make(map[string]ServerTool) s.tools.Range(func(key, value any) bool { @@ -85,6 +97,16 @@ func (s *sseSession) GetSessionTools() map[string]ServerTool { return tools } +func (s *sseSession) SetSessionPrompts(prompts map[string]ServerPrompt) { + // Clear existing prompts + s.prompts.Clear() + + // Set new prompts + for name, prompt := range prompts { + s.prompts.Store(name, prompt) + } +} + func (s *sseSession) SetSessionTools(tools map[string]ServerTool) { // Clear existing tools s.tools.Clear() diff --git a/server/sse_test.go b/server/sse_test.go index 96912be49..7ccef4164 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -1141,6 +1141,112 @@ func TestSSEServer(t *testing.T) { } }) + t.Run("TestSessionWithPrompts", func(t *testing.T) { + // Create hooks to track sessions + hooks := &Hooks{} + var registeredSession *sseSession + hooks.AddOnRegisterSession(func(ctx context.Context, session ClientSession) { + if s, ok := session.(*sseSession); ok { + registeredSession = s + } + }) + + mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks)) + testServer := NewTestServer(mcpServer) + defer testServer.Close() + + // Connect to SSE endpoint + sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL)) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer sseResp.Body.Close() + + // Read the endpoint event to ensure session is established + _, err = readSSEEvent(sseResp) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + + // Verify we got a session + if registeredSession == nil { + t.Fatal("Session was not registered via hook") + } + + // Test setting and getting prompts + prompts := map[string]ServerPrompt{ + "test_prompt": { + Prompt: mcp.Prompt{ + Name: "test_prompt", + Description: "A test prompt", + }, + Handler: func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return mcp.NewGetPromptResult("test", []mcp.PromptMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{Text: "test"}, + }, + }), nil + }, + }, + } + + // Test SetSessionPrompts + registeredSession.SetSessionPrompts(prompts) + + // Test GetSessionPrompts + retrievedPrompts := registeredSession.GetSessionPrompts() + if len(retrievedPrompts) != 1 { + t.Errorf("Expected 1 prompt, got %d", len(retrievedPrompts)) + } + if prompt, exists := retrievedPrompts["test_prompt"]; !exists { + t.Error("Expected test_prompt to exist") + } else if prompt.Prompt.Name != "test_prompt" { + t.Errorf("Expected prompt name test_prompt, got %s", prompt.Prompt.Name) + } + + // Test concurrent access + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(2) + go func(i int) { + defer wg.Done() + prompts := map[string]ServerPrompt{ + fmt.Sprintf("prompt_%d", i): { + Prompt: mcp.Prompt{ + Name: fmt.Sprintf("prompt_%d", i), + Description: fmt.Sprintf("Prompt %d", i), + }, + }, + } + registeredSession.SetSessionPrompts(prompts) + }(i) + go func() { + defer wg.Done() + _ = registeredSession.GetSessionPrompts() + }() + } + wg.Wait() + + // Verify we can still get and set tools after concurrent access + finalPrompts := map[string]ServerPrompt{ + "final_prompt": { + Prompt: mcp.Prompt{ + Name: "final_prompt", + Description: "Final Prompt", + }, + }, + } + registeredSession.SetSessionPrompts(finalPrompts) + retrievedPrompts = registeredSession.GetSessionPrompts() + if len(retrievedPrompts) != 1 { + t.Errorf("Expected 1 prompt, got %d", len(retrievedPrompts)) + } + if _, exists := retrievedPrompts["final_prompt"]; !exists { + t.Error("Expected final_prompt to exist") + } + }) + t.Run("SessionWithTools implementation", func(t *testing.T) { // Create hooks to track sessions hooks := &Hooks{}