Skip to content

feat: Implement per-session prompt functionality #459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ MCP-Go provides a robust session management system that allows you to:
- Register and track client sessions
- Send notifications to specific clients
- Provide per-session tool customization
- Provide per-session prompt customization

<details>
<summary>Show Session Management Examples</summary>
Expand Down
1 change: 1 addition & 0 deletions server/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ var (
ErrSessionExists = errors.New("session already exists")
ErrSessionNotInitialized = errors.New("session not properly initialized")
ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools")
ErrSessionDoesNotSupportPrompts = errors.New("session does not support per-session prompts")
ErrSessionDoesNotSupportLogging = errors.New("session does not support setting logging level")

// Notification-related errors
Expand Down
53 changes: 50 additions & 3 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,34 @@ func (s *MCPServer) handleListPrompts(
}
s.promptsMu.RUnlock()

// Check if there are session-specific prompts
session := ClientSessionFromContext(ctx)
if session != nil {
if sessionWithPrompts, ok := session.(SessionWithPrompts); ok {
if sessionPrompts := sessionWithPrompts.GetSessionPrompts(); sessionPrompts != nil {
// Override or add session-specific prompts
// We need to create a map first to merge the prompts properly
promptMap := make(map[string]mcp.Prompt)

// Add global prompts first
for _, prompt := range prompts {
promptMap[prompt.Name] = prompt
}

// Then override with session-specific tools
for name, serverPrompt := range sessionPrompts {
promptMap[name] = serverPrompt.Prompt
}

// Convert back to slice
prompts = make([]mcp.Prompt, 0, len(promptMap))
for _, prompt := range promptMap {
prompts = append(prompts, prompt)
}
}
}
}

// sort prompts by name
sort.Slice(prompts, func(i, j int) bool {
return prompts[i].Name < prompts[j].Name
Expand Down Expand Up @@ -863,9 +891,28 @@ func (s *MCPServer) handleGetPrompt(
id any,
request mcp.GetPromptRequest,
) (*mcp.GetPromptResult, *requestError) {
s.promptsMu.RLock()
handler, ok := s.promptHandlers[request.Params.Name]
s.promptsMu.RUnlock()
// First check session-specific prompts
var handler PromptHandlerFunc
var ok bool

session := ClientSessionFromContext(ctx)
if session != nil {
if sessionWithPrompts, typeAssertOk := session.(SessionWithPrompts); typeAssertOk {
if sessionPrompts := sessionWithPrompts.GetSessionPrompts(); sessionPrompts != nil {
if serverPrompt, sessionOk := sessionPrompts[request.Params.Name]; sessionOk {
handler = serverPrompt.Handler
ok = true
}
}
}
}

// If not found in session prompts, check global prompts
if !ok {
s.promptsMu.RLock()
handler, ok = s.promptHandlers[request.Params.Name]
s.promptsMu.RUnlock()
}

if !ok {
return nil, &requestError{
Expand Down
65 changes: 65 additions & 0 deletions server/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ type SessionWithTools interface {
SetSessionTools(tools map[string]ServerTool)
}

type SessionWithPrompts interface {
ClientSession
// GetPrompts returns the prompts specific to this session, if any
GetSessionPrompts() map[string]ServerPrompt
// SetPrompts sets prompts specific to this session
SetSessionPrompts(prompts map[string]ServerPrompt)
}
Comment on lines +42 to +48
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Document thread-safety requirements for consistency.

The SessionWithTools interface above documents thread-safety requirements in its method comments, but SessionWithPrompts does not. For consistency and to ensure proper implementation, please add similar documentation.

 type SessionWithPrompts interface {
 	ClientSession
-	// GetPrompts returns the prompts specific to this session, if any
+	// GetSessionPrompts returns the prompts specific to this session, if any
+	// This method must be thread-safe for concurrent access
 	GetSessionPrompts() map[string]ServerPrompt
-	// SetPrompts sets prompts specific to this session
+	// SetSessionPrompts sets prompts specific to this session
+	// This method must be thread-safe for concurrent access
 	SetSessionPrompts(prompts map[string]ServerPrompt)
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
type SessionWithPrompts interface {
ClientSession
// GetPrompts returns the prompts specific to this session, if any
GetSessionPrompts() map[string]ServerPrompt
// SetPrompts sets prompts specific to this session
SetSessionPrompts(prompts map[string]ServerPrompt)
}
type SessionWithPrompts interface {
ClientSession
// GetSessionPrompts returns the prompts specific to this session, if any
// This method must be thread-safe for concurrent access
GetSessionPrompts() map[string]ServerPrompt
// SetSessionPrompts sets prompts specific to this session
// This method must be thread-safe for concurrent access
SetSessionPrompts(prompts map[string]ServerPrompt)
}
🤖 Prompt for AI Agents
In server/session.go around lines 42 to 48, the SessionWithPrompts interface
lacks thread-safety documentation unlike the SessionWithTools interface. Add
comments to the GetSessionPrompts and SetSessionPrompts methods specifying their
thread-safety requirements, indicating whether they must be safe for concurrent
use or require external synchronization, to maintain consistency and guide
proper implementation.


// SessionWithClientInfo is an extension of ClientSession that can store client info
type SessionWithClientInfo interface {
ClientSession
Expand Down Expand Up @@ -378,3 +386,60 @@ func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error

return nil
}

func (s *MCPServer) AddSessionPrompts(sessionID string, prompts ...ServerPrompt) error {
sessionValue, ok := s.sessions.Load(sessionID)
if !ok {
return ErrSessionNotFound
}

session, ok := sessionValue.(SessionWithPrompts)
if !ok {
return ErrSessionDoesNotSupportPrompts
}

s.implicitlyRegisterPromptCapabilities()

// Get existing prompts (this should return a thread-safe copy)
sessionPrompts := session.GetSessionPrompts()

// Create a new map to avoid concurrent modification issues
newSessionPrompts := make(map[string]ServerPrompt, len(sessionPrompts)+len(prompts))

// Copy existing prompts
for k, v := range sessionPrompts {
newSessionPrompts[k] = v
}

// Add new prompts
for _, prompt := range prompts {
newSessionPrompts[prompt.Prompt.Name] = prompt
}

// Set the prompts (this should be thread-safe)
session.SetSessionPrompts(newSessionPrompts)

if session.Initialized() && s.capabilities.prompts != nil && s.capabilities.prompts.listChanged {
// Send notification only to this session
if err := s.SendNotificationToSpecificClient(sessionID, "notifications/prompts/list_changed", nil); err != nil {
// Log the error but don't fail the operation
// The prompts were successfully added, but notification failed
if s.hooks != nil && len(s.hooks.OnError) > 0 {
hooks := s.hooks
go func(sID string, hooks *Hooks) {
ctx := context.Background()
hooks.onError(ctx, nil, "notification", map[string]any{
"method": "notifications/prompts/list_changed",
"sessionID": sID,
}, fmt.Errorf("failed to send notification after adding prompts: %w", err))
}(sessionID, hooks)
}
}
}

return nil
}

func (s *MCPServer) AddSessionPrompt(sessionID string, prompt mcp.Prompt, handler PromptHandlerFunc) error {
return s.AddSessionPrompts(sessionID, ServerPrompt{Prompt: prompt, Handler: handler})
}
Loading