Skip to content

Commit 4998639

Browse files
committed
feat: add SessionWithParams interface for URL query parameter access
- Add SessionWithParams interface to session.go for accessing URL query parameters - Implement Params() method in SSE and StreamableHTTP sessions to parse query parameters - Add comprehensive tests for parameter parsing with concurrent access protection - Support special characters and empty parameter scenarios This enables server.WithToolFilter and other middleware to access session-specific parameters like tenant_id, user_id, or environment from URL query strings, allowing for fine-grained filtering and context-aware tool execution.
1 parent 7c38b56 commit 4998639

File tree

6 files changed

+548
-3
lines changed

6 files changed

+548
-3
lines changed

server/session.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ type SessionWithClientInfo interface {
4848
SetClientInfo(clientInfo mcp.Implementation)
4949
}
5050

51+
// SessionWithParams is an extension of ClientSession that can store session parameters
52+
type SessionWithParams interface {
53+
ClientSession
54+
// Params returns the parameters associated with the session.
55+
Params() map[string]string
56+
}
57+
5158
// SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations
5259
type SessionWithStreamableHTTPConfig interface {
5360
ClientSession

server/session_test.go

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"errors"
7+
"fmt"
78
"sync"
89
"sync/atomic"
910
"testing"
@@ -172,12 +173,72 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel {
172173
return level.(mcp.LoggingLevel)
173174
}
174175

176+
// sessionTestClientWithParams implements the SessionWithParams interface for testing
177+
type sessionTestClientWithParams struct {
178+
sessionID string
179+
notificationChannel chan mcp.JSONRPCNotification
180+
initialized bool
181+
params map[string]string
182+
mu sync.RWMutex // Mutex to protect concurrent access to params
183+
}
184+
185+
func (f *sessionTestClientWithParams) SessionID() string {
186+
return f.sessionID
187+
}
188+
189+
func (f *sessionTestClientWithParams) NotificationChannel() chan<- mcp.JSONRPCNotification {
190+
return f.notificationChannel
191+
}
192+
193+
func (f *sessionTestClientWithParams) Initialize() {
194+
f.initialized = true
195+
}
196+
197+
func (f *sessionTestClientWithParams) Initialized() bool {
198+
return f.initialized
199+
}
200+
201+
func (f *sessionTestClientWithParams) Params() map[string]string {
202+
f.mu.RLock()
203+
defer f.mu.RUnlock()
204+
205+
// Return a copy of the map to prevent concurrent modification
206+
if f.params == nil {
207+
return nil
208+
}
209+
210+
paramsCopy := make(map[string]string, len(f.params))
211+
for k, v := range f.params {
212+
paramsCopy[k] = v
213+
}
214+
return paramsCopy
215+
}
216+
217+
// SetParams sets the params (for testing purposes)
218+
func (f *sessionTestClientWithParams) SetParams(params map[string]string) {
219+
f.mu.Lock()
220+
defer f.mu.Unlock()
221+
222+
// Create a copy of the map to prevent concurrent modification
223+
if params == nil {
224+
f.params = nil
225+
return
226+
}
227+
228+
paramsCopy := make(map[string]string, len(params))
229+
for k, v := range params {
230+
paramsCopy[k] = v
231+
}
232+
f.params = paramsCopy
233+
}
234+
175235
// Verify that all implementations satisfy their respective interfaces
176236
var (
177237
_ ClientSession = (*sessionTestClient)(nil)
178238
_ SessionWithTools = (*sessionTestClientWithTools)(nil)
179239
_ SessionWithLogging = (*sessionTestClientWithLogging)(nil)
180240
_ SessionWithClientInfo = (*sessionTestClientWithClientInfo)(nil)
241+
_ SessionWithParams = (*sessionTestClientWithParams)(nil)
181242
)
182243

183244
func TestSessionWithTools_Integration(t *testing.T) {
@@ -1507,3 +1568,112 @@ func TestMCPServer_LoggingNotificationFormat(t *testing.T) {
15071568
})
15081569
}
15091570
}
1571+
1572+
func TestSessionWithParams_Integration(t *testing.T) {
1573+
server := NewMCPServer("test-server", "1.0.0")
1574+
1575+
// Create a session with params
1576+
testParams := map[string]string{
1577+
"tenant_id": "test-tenant-123",
1578+
"user_id": "user-456",
1579+
"environment": "development",
1580+
}
1581+
1582+
session := &sessionTestClientWithParams{
1583+
sessionID: "session-1",
1584+
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
1585+
initialized: true,
1586+
params: testParams,
1587+
}
1588+
1589+
// Register the session
1590+
err := server.RegisterSession(context.Background(), session)
1591+
require.NoError(t, err)
1592+
1593+
// Test that we can access the session from context
1594+
sessionCtx := server.WithContext(context.Background(), session)
1595+
1596+
// Check if the session was stored in the context correctly
1597+
s := ClientSessionFromContext(sessionCtx)
1598+
require.NotNil(t, s, "Session should be available from context")
1599+
assert.Equal(t, session.SessionID(), s.SessionID(), "Session ID should match")
1600+
1601+
// Check if the session can be cast to SessionWithParams
1602+
swp, ok := s.(SessionWithParams)
1603+
require.True(t, ok, "Session should implement SessionWithParams")
1604+
1605+
// Test accessing params
1606+
params := swp.Params()
1607+
require.NotNil(t, params, "Session params should be available")
1608+
require.Len(t, params, 3, "Should have 3 params")
1609+
assert.Equal(t, "test-tenant-123", params["tenant_id"], "tenant_id should match")
1610+
assert.Equal(t, "user-456", params["user_id"], "user_id should match")
1611+
assert.Equal(t, "development", params["environment"], "environment should match")
1612+
1613+
// Test that params are returned as a copy (not the original map)
1614+
params["tenant_id"] = "modified-tenant"
1615+
originalParams := swp.Params()
1616+
assert.Equal(t, "test-tenant-123", originalParams["tenant_id"], "Original params should not be modified")
1617+
1618+
t.Run("test concurrent access", func(t *testing.T) {
1619+
// Test concurrent access to params
1620+
var wg sync.WaitGroup
1621+
errors := make(chan error, 10)
1622+
1623+
// Start multiple goroutines reading params
1624+
for i := 0; i < 5; i++ {
1625+
wg.Add(1)
1626+
go func() {
1627+
defer wg.Done()
1628+
for j := 0; j < 10; j++ {
1629+
params := swp.Params()
1630+
if params == nil {
1631+
errors <- fmt.Errorf("params should not be nil")
1632+
return
1633+
}
1634+
if len(params) != 3 {
1635+
errors <- fmt.Errorf("expected 3 params, got %d", len(params))
1636+
return
1637+
}
1638+
}
1639+
}()
1640+
}
1641+
1642+
// Start goroutines modifying params
1643+
for i := 0; i < 3; i++ {
1644+
wg.Add(1)
1645+
go func(idx int) {
1646+
defer wg.Done()
1647+
newParams := map[string]string{
1648+
"tenant_id": fmt.Sprintf("tenant-%d", idx),
1649+
"user_id": fmt.Sprintf("user-%d", idx),
1650+
"environment": "test",
1651+
}
1652+
session.SetParams(newParams)
1653+
}(i)
1654+
}
1655+
1656+
wg.Wait()
1657+
close(errors)
1658+
1659+
// Check for any errors during concurrent access
1660+
for err := range errors {
1661+
t.Error(err)
1662+
}
1663+
})
1664+
1665+
t.Run("test nil params", func(t *testing.T) {
1666+
// Test with nil params
1667+
session.SetParams(nil)
1668+
params := swp.Params()
1669+
assert.Nil(t, params, "Params should be nil when set to nil")
1670+
})
1671+
1672+
t.Run("test empty params", func(t *testing.T) {
1673+
// Test with empty params
1674+
session.SetParams(map[string]string{})
1675+
params := swp.Params()
1676+
require.NotNil(t, params, "Params should not be nil for empty map")
1677+
assert.Len(t, params, 0, "Params should be empty")
1678+
})
1679+
}

server/sse.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ type sseSession struct {
3030
loggingLevel atomic.Value
3131
tools sync.Map // stores session-specific tools
3232
clientInfo atomic.Value // stores session-specific client info
33+
params map[string]string
3334
}
3435

3536
// SSEContextFunc is a function that takes an existing context and the current
@@ -48,6 +49,10 @@ func (s *sseSession) SessionID() string {
4849
return s.sessionID
4950
}
5051

52+
func (s *sseSession) Params() map[string]string {
53+
return s.params
54+
}
55+
5156
func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
5257
return s.notificationChannel
5358
}
@@ -347,12 +352,18 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
347352
return
348353
}
349354

355+
params := make(map[string]string)
356+
for k, v := range r.URL.Query() {
357+
params[k] = v[0]
358+
}
359+
350360
sessionID := uuid.New().String()
351361
session := &sseSession{
352362
done: make(chan struct{}),
353363
eventQueue: make(chan string, 100), // Buffer for events
354364
sessionID: sessionID,
355365
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
366+
params: params,
356367
}
357368

358369
s.sessions.Store(sessionID, session)

0 commit comments

Comments
 (0)