Skip to content

feat: add SessionWithParams interface for URL query parameter access #504

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions server/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ type SessionWithClientInfo interface {
SetClientInfo(clientInfo mcp.Implementation)
}

// SessionWithParams is an extension of ClientSession that can store session parameters
type SessionWithParams interface {
ClientSession
// Params returns the parameters associated with the session.
Params() map[string]string
}

// SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations
type SessionWithStreamableHTTPConfig interface {
ClientSession
Expand Down
170 changes: 170 additions & 0 deletions server/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -172,12 +173,72 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel {
return level.(mcp.LoggingLevel)
}

// sessionTestClientWithParams implements the SessionWithParams interface for testing
type sessionTestClientWithParams struct {
sessionID string
notificationChannel chan mcp.JSONRPCNotification
initialized bool
params map[string]string
mu sync.RWMutex // Mutex to protect concurrent access to params
}

func (f *sessionTestClientWithParams) SessionID() string {
return f.sessionID
}

func (f *sessionTestClientWithParams) NotificationChannel() chan<- mcp.JSONRPCNotification {
return f.notificationChannel
}

func (f *sessionTestClientWithParams) Initialize() {
f.initialized = true
}

func (f *sessionTestClientWithParams) Initialized() bool {
return f.initialized
}

func (f *sessionTestClientWithParams) Params() map[string]string {
f.mu.RLock()
defer f.mu.RUnlock()

// Return a copy of the map to prevent concurrent modification
if f.params == nil {
return nil
}

paramsCopy := make(map[string]string, len(f.params))
for k, v := range f.params {
paramsCopy[k] = v
}
return paramsCopy
}

// SetParams sets the params (for testing purposes)
func (f *sessionTestClientWithParams) SetParams(params map[string]string) {
f.mu.Lock()
defer f.mu.Unlock()

// Create a copy of the map to prevent concurrent modification
if params == nil {
f.params = nil
return
}

paramsCopy := make(map[string]string, len(params))
for k, v := range params {
paramsCopy[k] = v
}
f.params = paramsCopy
}

// Verify that all implementations satisfy their respective interfaces
var (
_ ClientSession = (*sessionTestClient)(nil)
_ SessionWithTools = (*sessionTestClientWithTools)(nil)
_ SessionWithLogging = (*sessionTestClientWithLogging)(nil)
_ SessionWithClientInfo = (*sessionTestClientWithClientInfo)(nil)
_ SessionWithParams = (*sessionTestClientWithParams)(nil)
)

func TestSessionWithTools_Integration(t *testing.T) {
Expand Down Expand Up @@ -1507,3 +1568,112 @@ func TestMCPServer_LoggingNotificationFormat(t *testing.T) {
})
}
}

func TestSessionWithParams_Integration(t *testing.T) {
server := NewMCPServer("test-server", "1.0.0")

// Create a session with params
testParams := map[string]string{
"tenant_id": "test-tenant-123",
"user_id": "user-456",
"environment": "development",
}

session := &sessionTestClientWithParams{
sessionID: "session-1",
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
initialized: true,
params: testParams,
}

// Register the session
err := server.RegisterSession(context.Background(), session)
require.NoError(t, err)

// Test that we can access the session from context
sessionCtx := server.WithContext(context.Background(), session)

// Check if the session was stored in the context correctly
s := ClientSessionFromContext(sessionCtx)
require.NotNil(t, s, "Session should be available from context")
assert.Equal(t, session.SessionID(), s.SessionID(), "Session ID should match")

// Check if the session can be cast to SessionWithParams
swp, ok := s.(SessionWithParams)
require.True(t, ok, "Session should implement SessionWithParams")

// Test accessing params
params := swp.Params()
require.NotNil(t, params, "Session params should be available")
require.Len(t, params, 3, "Should have 3 params")
assert.Equal(t, "test-tenant-123", params["tenant_id"], "tenant_id should match")
assert.Equal(t, "user-456", params["user_id"], "user_id should match")
assert.Equal(t, "development", params["environment"], "environment should match")

// Test that params are returned as a copy (not the original map)
params["tenant_id"] = "modified-tenant"
originalParams := swp.Params()
assert.Equal(t, "test-tenant-123", originalParams["tenant_id"], "Original params should not be modified")

t.Run("test concurrent access", func(t *testing.T) {
// Test concurrent access to params
var wg sync.WaitGroup
errors := make(chan error, 10)

// Start multiple goroutines reading params
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 10; j++ {
params := swp.Params()
if params == nil {
errors <- fmt.Errorf("params should not be nil")
return
}
if len(params) != 3 {
errors <- fmt.Errorf("expected 3 params, got %d", len(params))
return
}
}
}()
}

// Start goroutines modifying params
for i := 0; i < 3; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
newParams := map[string]string{
"tenant_id": fmt.Sprintf("tenant-%d", idx),
"user_id": fmt.Sprintf("user-%d", idx),
"environment": "test",
}
session.SetParams(newParams)
}(i)
}

wg.Wait()
close(errors)

// Check for any errors during concurrent access
for err := range errors {
t.Error(err)
}
})

t.Run("test nil params", func(t *testing.T) {
// Test with nil params
session.SetParams(nil)
params := swp.Params()
assert.Nil(t, params, "Params should be nil when set to nil")
})

t.Run("test empty params", func(t *testing.T) {
// Test with empty params
session.SetParams(map[string]string{})
params := swp.Params()
require.NotNil(t, params, "Params should not be nil for empty map")
assert.Len(t, params, 0, "Params should be empty")
})
}
11 changes: 11 additions & 0 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type sseSession struct {
loggingLevel atomic.Value
tools sync.Map // stores session-specific tools
clientInfo atomic.Value // stores session-specific client info
params map[string]string
}

// SSEContextFunc is a function that takes an existing context and the current
Expand All @@ -48,6 +49,10 @@ func (s *sseSession) SessionID() string {
return s.sessionID
}

func (s *sseSession) Params() map[string]string {
return s.params
}

func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
return s.notificationChannel
}
Expand Down Expand Up @@ -347,12 +352,18 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
return
}

params := make(map[string]string)
for k, v := range r.URL.Query() {
params[k] = v[0]
}

sessionID := uuid.New().String()
session := &sseSession{
done: make(chan struct{}),
eventQueue: make(chan string, 100), // Buffer for events
sessionID: sessionID,
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
params: params,
}

s.sessions.Store(sessionID, session)
Expand Down
Loading
Loading