Skip to content

Commit c4e255a

Browse files
committed
feat: add SessionWithParams interface for URL query parameter access
- Add SessionWithParams interface to session.go for accessing URL query parameters - Implement Params() method in SSE and StreamableHTTP sessions to parse query parameters - Add comprehensive tests for parameter parsing with concurrent access protection - Support special characters and empty parameter scenarios This enables server.WithToolFilter and other middleware to access session-specific parameters like tenant_id, user_id, or environment from URL query strings, allowing for fine-grained filtering and context-aware tool execution.
1 parent 7c38b56 commit c4e255a

File tree

6 files changed

+553
-3
lines changed

6 files changed

+553
-3
lines changed

server/session.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ type SessionWithClientInfo interface {
4848
SetClientInfo(clientInfo mcp.Implementation)
4949
}
5050

51+
// SessionWithParams is an extension of ClientSession that can store session parameters
52+
type SessionWithParams interface {
53+
ClientSession
54+
// Params returns the parameters associated with the session.
55+
Params() map[string]string
56+
}
57+
5158
// SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations
5259
type SessionWithStreamableHTTPConfig interface {
5360
ClientSession

server/session_test.go

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"errors"
7+
"fmt"
78
"sync"
89
"sync/atomic"
910
"testing"
@@ -172,12 +173,77 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel {
172173
return level.(mcp.LoggingLevel)
173174
}
174175

176+
// sessionTestClientWithParams implements the SessionWithParams interface for testing
177+
type sessionTestClientWithParams struct {
178+
sessionID string
179+
notificationChannel chan mcp.JSONRPCNotification
180+
initialized bool
181+
params map[string]string
182+
mu sync.RWMutex // Mutex to protect concurrent access to params
183+
}
184+
185+
func (f *sessionTestClientWithParams) SessionID() string {
186+
return f.sessionID
187+
}
188+
189+
func (f *sessionTestClientWithParams) NotificationChannel() chan<- mcp.JSONRPCNotification {
190+
return f.notificationChannel
191+
}
192+
193+
func (f *sessionTestClientWithParams) Initialize() {
194+
f.initialized = true
195+
}
196+
197+
func (f *sessionTestClientWithParams) Initialized() bool {
198+
return f.initialized
199+
}
200+
201+
func (f *sessionTestClientWithParams) Params() map[string]string {
202+
f.mu.RLock()
203+
defer f.mu.RUnlock()
204+
205+
// Return a copy of the map to prevent concurrent modification
206+
if f.params == nil {
207+
return nil
208+
}
209+
210+
paramsCopy := make(map[string]string, len(f.params))
211+
for k, v := range f.params {
212+
paramsCopy[k] = v
213+
}
214+
return paramsCopy
215+
}
216+
217+
// GetParams returns the current params (for testing purposes)
218+
func (f *sessionTestClientWithParams) GetParams() map[string]string {
219+
return f.Params()
220+
}
221+
222+
// SetParams sets the params (for testing purposes)
223+
func (f *sessionTestClientWithParams) SetParams(params map[string]string) {
224+
f.mu.Lock()
225+
defer f.mu.Unlock()
226+
227+
// Create a copy of the map to prevent concurrent modification
228+
if params == nil {
229+
f.params = nil
230+
return
231+
}
232+
233+
paramsCopy := make(map[string]string, len(params))
234+
for k, v := range params {
235+
paramsCopy[k] = v
236+
}
237+
f.params = paramsCopy
238+
}
239+
175240
// Verify that all implementations satisfy their respective interfaces
176241
var (
177242
_ ClientSession = (*sessionTestClient)(nil)
178243
_ SessionWithTools = (*sessionTestClientWithTools)(nil)
179244
_ SessionWithLogging = (*sessionTestClientWithLogging)(nil)
180245
_ SessionWithClientInfo = (*sessionTestClientWithClientInfo)(nil)
246+
_ SessionWithParams = (*sessionTestClientWithParams)(nil)
181247
)
182248

183249
func TestSessionWithTools_Integration(t *testing.T) {
@@ -1507,3 +1573,112 @@ func TestMCPServer_LoggingNotificationFormat(t *testing.T) {
15071573
})
15081574
}
15091575
}
1576+
1577+
func TestSessionWithParams_Integration(t *testing.T) {
1578+
server := NewMCPServer("test-server", "1.0.0")
1579+
1580+
// Create a session with params
1581+
testParams := map[string]string{
1582+
"tenant_id": "test-tenant-123",
1583+
"user_id": "user-456",
1584+
"environment": "development",
1585+
}
1586+
1587+
session := &sessionTestClientWithParams{
1588+
sessionID: "session-1",
1589+
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
1590+
initialized: true,
1591+
params: testParams,
1592+
}
1593+
1594+
// Register the session
1595+
err := server.RegisterSession(context.Background(), session)
1596+
require.NoError(t, err)
1597+
1598+
// Test that we can access the session from context
1599+
sessionCtx := server.WithContext(context.Background(), session)
1600+
1601+
// Check if the session was stored in the context correctly
1602+
s := ClientSessionFromContext(sessionCtx)
1603+
require.NotNil(t, s, "Session should be available from context")
1604+
assert.Equal(t, session.SessionID(), s.SessionID(), "Session ID should match")
1605+
1606+
// Check if the session can be cast to SessionWithParams
1607+
swp, ok := s.(SessionWithParams)
1608+
require.True(t, ok, "Session should implement SessionWithParams")
1609+
1610+
// Test accessing params
1611+
params := swp.Params()
1612+
require.NotNil(t, params, "Session params should be available")
1613+
require.Len(t, params, 3, "Should have 3 params")
1614+
assert.Equal(t, "test-tenant-123", params["tenant_id"], "tenant_id should match")
1615+
assert.Equal(t, "user-456", params["user_id"], "user_id should match")
1616+
assert.Equal(t, "development", params["environment"], "environment should match")
1617+
1618+
// Test that params are returned as a copy (not the original map)
1619+
params["tenant_id"] = "modified-tenant"
1620+
originalParams := swp.Params()
1621+
assert.Equal(t, "test-tenant-123", originalParams["tenant_id"], "Original params should not be modified")
1622+
1623+
t.Run("test concurrent access", func(t *testing.T) {
1624+
// Test concurrent access to params
1625+
var wg sync.WaitGroup
1626+
errors := make(chan error, 10)
1627+
1628+
// Start multiple goroutines reading params
1629+
for i := 0; i < 5; i++ {
1630+
wg.Add(1)
1631+
go func() {
1632+
defer wg.Done()
1633+
for j := 0; j < 10; j++ {
1634+
params := swp.Params()
1635+
if params == nil {
1636+
errors <- fmt.Errorf("params should not be nil")
1637+
return
1638+
}
1639+
if len(params) != 3 {
1640+
errors <- fmt.Errorf("expected 3 params, got %d", len(params))
1641+
return
1642+
}
1643+
}
1644+
}()
1645+
}
1646+
1647+
// Start goroutines modifying params
1648+
for i := 0; i < 3; i++ {
1649+
wg.Add(1)
1650+
go func(idx int) {
1651+
defer wg.Done()
1652+
newParams := map[string]string{
1653+
"tenant_id": fmt.Sprintf("tenant-%d", idx),
1654+
"user_id": fmt.Sprintf("user-%d", idx),
1655+
"environment": "test",
1656+
}
1657+
session.SetParams(newParams)
1658+
}(i)
1659+
}
1660+
1661+
wg.Wait()
1662+
close(errors)
1663+
1664+
// Check for any errors during concurrent access
1665+
for err := range errors {
1666+
t.Error(err)
1667+
}
1668+
})
1669+
1670+
t.Run("test nil params", func(t *testing.T) {
1671+
// Test with nil params
1672+
session.SetParams(nil)
1673+
params := swp.Params()
1674+
assert.Nil(t, params, "Params should be nil when set to nil")
1675+
})
1676+
1677+
t.Run("test empty params", func(t *testing.T) {
1678+
// Test with empty params
1679+
session.SetParams(map[string]string{})
1680+
params := swp.Params()
1681+
require.NotNil(t, params, "Params should not be nil for empty map")
1682+
assert.Len(t, params, 0, "Params should be empty")
1683+
})
1684+
}

server/sse.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ type sseSession struct {
3030
loggingLevel atomic.Value
3131
tools sync.Map // stores session-specific tools
3232
clientInfo atomic.Value // stores session-specific client info
33+
params map[string]string
3334
}
3435

3536
// SSEContextFunc is a function that takes an existing context and the current
@@ -48,6 +49,10 @@ func (s *sseSession) SessionID() string {
4849
return s.sessionID
4950
}
5051

52+
func (s *sseSession) Params() map[string]string {
53+
return s.params
54+
}
55+
5156
func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
5257
return s.notificationChannel
5358
}
@@ -347,12 +352,18 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
347352
return
348353
}
349354

355+
params := make(map[string]string)
356+
for k, v := range r.URL.Query() {
357+
params[k] = v[0]
358+
}
359+
350360
sessionID := uuid.New().String()
351361
session := &sseSession{
352362
done: make(chan struct{}),
353363
eventQueue: make(chan string, 100), // Buffer for events
354364
sessionID: sessionID,
355365
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
366+
params: params,
356367
}
357368

358369
s.sessions.Store(sessionID, session)

0 commit comments

Comments
 (0)