|
4 | 4 | "context"
|
5 | 5 | "encoding/json"
|
6 | 6 | "errors"
|
| 7 | + "fmt" |
7 | 8 | "sync"
|
8 | 9 | "sync/atomic"
|
9 | 10 | "testing"
|
@@ -172,12 +173,72 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel {
|
172 | 173 | return level.(mcp.LoggingLevel)
|
173 | 174 | }
|
174 | 175 |
|
| 176 | +// sessionTestClientWithParams implements the SessionWithParams interface for testing |
| 177 | +type sessionTestClientWithParams struct { |
| 178 | + sessionID string |
| 179 | + notificationChannel chan mcp.JSONRPCNotification |
| 180 | + initialized bool |
| 181 | + params map[string]string |
| 182 | + mu sync.RWMutex // Mutex to protect concurrent access to params |
| 183 | +} |
| 184 | + |
| 185 | +func (f *sessionTestClientWithParams) SessionID() string { |
| 186 | + return f.sessionID |
| 187 | +} |
| 188 | + |
| 189 | +func (f *sessionTestClientWithParams) NotificationChannel() chan<- mcp.JSONRPCNotification { |
| 190 | + return f.notificationChannel |
| 191 | +} |
| 192 | + |
| 193 | +func (f *sessionTestClientWithParams) Initialize() { |
| 194 | + f.initialized = true |
| 195 | +} |
| 196 | + |
| 197 | +func (f *sessionTestClientWithParams) Initialized() bool { |
| 198 | + return f.initialized |
| 199 | +} |
| 200 | + |
| 201 | +func (f *sessionTestClientWithParams) Params() map[string]string { |
| 202 | + f.mu.RLock() |
| 203 | + defer f.mu.RUnlock() |
| 204 | + |
| 205 | + // Return a copy of the map to prevent concurrent modification |
| 206 | + if f.params == nil { |
| 207 | + return nil |
| 208 | + } |
| 209 | + |
| 210 | + paramsCopy := make(map[string]string, len(f.params)) |
| 211 | + for k, v := range f.params { |
| 212 | + paramsCopy[k] = v |
| 213 | + } |
| 214 | + return paramsCopy |
| 215 | +} |
| 216 | + |
| 217 | +// SetParams sets the params (for testing purposes) |
| 218 | +func (f *sessionTestClientWithParams) SetParams(params map[string]string) { |
| 219 | + f.mu.Lock() |
| 220 | + defer f.mu.Unlock() |
| 221 | + |
| 222 | + // Create a copy of the map to prevent concurrent modification |
| 223 | + if params == nil { |
| 224 | + f.params = nil |
| 225 | + return |
| 226 | + } |
| 227 | + |
| 228 | + paramsCopy := make(map[string]string, len(params)) |
| 229 | + for k, v := range params { |
| 230 | + paramsCopy[k] = v |
| 231 | + } |
| 232 | + f.params = paramsCopy |
| 233 | +} |
| 234 | + |
175 | 235 | // Verify that all implementations satisfy their respective interfaces
|
176 | 236 | var (
|
177 | 237 | _ ClientSession = (*sessionTestClient)(nil)
|
178 | 238 | _ SessionWithTools = (*sessionTestClientWithTools)(nil)
|
179 | 239 | _ SessionWithLogging = (*sessionTestClientWithLogging)(nil)
|
180 | 240 | _ SessionWithClientInfo = (*sessionTestClientWithClientInfo)(nil)
|
| 241 | + _ SessionWithParams = (*sessionTestClientWithParams)(nil) |
181 | 242 | )
|
182 | 243 |
|
183 | 244 | func TestSessionWithTools_Integration(t *testing.T) {
|
@@ -1507,3 +1568,112 @@ func TestMCPServer_LoggingNotificationFormat(t *testing.T) {
|
1507 | 1568 | })
|
1508 | 1569 | }
|
1509 | 1570 | }
|
| 1571 | + |
| 1572 | +func TestSessionWithParams_Integration(t *testing.T) { |
| 1573 | + server := NewMCPServer("test-server", "1.0.0") |
| 1574 | + |
| 1575 | + // Create a session with params |
| 1576 | + testParams := map[string]string{ |
| 1577 | + "tenant_id": "test-tenant-123", |
| 1578 | + "user_id": "user-456", |
| 1579 | + "environment": "development", |
| 1580 | + } |
| 1581 | + |
| 1582 | + session := &sessionTestClientWithParams{ |
| 1583 | + sessionID: "session-1", |
| 1584 | + notificationChannel: make(chan mcp.JSONRPCNotification, 10), |
| 1585 | + initialized: true, |
| 1586 | + params: testParams, |
| 1587 | + } |
| 1588 | + |
| 1589 | + // Register the session |
| 1590 | + err := server.RegisterSession(context.Background(), session) |
| 1591 | + require.NoError(t, err) |
| 1592 | + |
| 1593 | + // Test that we can access the session from context |
| 1594 | + sessionCtx := server.WithContext(context.Background(), session) |
| 1595 | + |
| 1596 | + // Check if the session was stored in the context correctly |
| 1597 | + s := ClientSessionFromContext(sessionCtx) |
| 1598 | + require.NotNil(t, s, "Session should be available from context") |
| 1599 | + assert.Equal(t, session.SessionID(), s.SessionID(), "Session ID should match") |
| 1600 | + |
| 1601 | + // Check if the session can be cast to SessionWithParams |
| 1602 | + swp, ok := s.(SessionWithParams) |
| 1603 | + require.True(t, ok, "Session should implement SessionWithParams") |
| 1604 | + |
| 1605 | + // Test accessing params |
| 1606 | + params := swp.Params() |
| 1607 | + require.NotNil(t, params, "Session params should be available") |
| 1608 | + require.Len(t, params, 3, "Should have 3 params") |
| 1609 | + assert.Equal(t, "test-tenant-123", params["tenant_id"], "tenant_id should match") |
| 1610 | + assert.Equal(t, "user-456", params["user_id"], "user_id should match") |
| 1611 | + assert.Equal(t, "development", params["environment"], "environment should match") |
| 1612 | + |
| 1613 | + // Test that params are returned as a copy (not the original map) |
| 1614 | + params["tenant_id"] = "modified-tenant" |
| 1615 | + originalParams := swp.Params() |
| 1616 | + assert.Equal(t, "test-tenant-123", originalParams["tenant_id"], "Original params should not be modified") |
| 1617 | + |
| 1618 | + t.Run("test concurrent access", func(t *testing.T) { |
| 1619 | + // Test concurrent access to params |
| 1620 | + var wg sync.WaitGroup |
| 1621 | + errors := make(chan error, 10) |
| 1622 | + |
| 1623 | + // Start multiple goroutines reading params |
| 1624 | + for i := 0; i < 5; i++ { |
| 1625 | + wg.Add(1) |
| 1626 | + go func() { |
| 1627 | + defer wg.Done() |
| 1628 | + for j := 0; j < 10; j++ { |
| 1629 | + params := swp.Params() |
| 1630 | + if params == nil { |
| 1631 | + errors <- fmt.Errorf("params should not be nil") |
| 1632 | + return |
| 1633 | + } |
| 1634 | + if len(params) != 3 { |
| 1635 | + errors <- fmt.Errorf("expected 3 params, got %d", len(params)) |
| 1636 | + return |
| 1637 | + } |
| 1638 | + } |
| 1639 | + }() |
| 1640 | + } |
| 1641 | + |
| 1642 | + // Start goroutines modifying params |
| 1643 | + for i := 0; i < 3; i++ { |
| 1644 | + wg.Add(1) |
| 1645 | + go func(idx int) { |
| 1646 | + defer wg.Done() |
| 1647 | + newParams := map[string]string{ |
| 1648 | + "tenant_id": fmt.Sprintf("tenant-%d", idx), |
| 1649 | + "user_id": fmt.Sprintf("user-%d", idx), |
| 1650 | + "environment": "test", |
| 1651 | + } |
| 1652 | + session.SetParams(newParams) |
| 1653 | + }(i) |
| 1654 | + } |
| 1655 | + |
| 1656 | + wg.Wait() |
| 1657 | + close(errors) |
| 1658 | + |
| 1659 | + // Check for any errors during concurrent access |
| 1660 | + for err := range errors { |
| 1661 | + t.Error(err) |
| 1662 | + } |
| 1663 | + }) |
| 1664 | + |
| 1665 | + t.Run("test nil params", func(t *testing.T) { |
| 1666 | + // Test with nil params |
| 1667 | + session.SetParams(nil) |
| 1668 | + params := swp.Params() |
| 1669 | + assert.Nil(t, params, "Params should be nil when set to nil") |
| 1670 | + }) |
| 1671 | + |
| 1672 | + t.Run("test empty params", func(t *testing.T) { |
| 1673 | + // Test with empty params |
| 1674 | + session.SetParams(map[string]string{}) |
| 1675 | + params := swp.Params() |
| 1676 | + require.NotNil(t, params, "Params should not be nil for empty map") |
| 1677 | + assert.Len(t, params, 0, "Params should be empty") |
| 1678 | + }) |
| 1679 | +} |
0 commit comments