Skip to content

Commit 4894249

Browse files
committed
mcp: update resourceSubscription map type and resourceCapabilities logic
1 parent f4fd5c2 commit 4894249

File tree

2 files changed

+31
-23
lines changed

2 files changed

+31
-23
lines changed

mcp/server.go

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"fmt"
1414
"iter"
1515
"log"
16+
"maps"
1617
"net/url"
1718
"path/filepath"
1819
"slices"
@@ -43,7 +44,7 @@ type Server struct {
4344
sessions []*ServerSession
4445
sendingMethodHandler_ MethodHandler[*ServerSession]
4546
receivingMethodHandler_ MethodHandler[*ServerSession]
46-
resourceSubscriptions map[string][]*ServerSession // uri -> session
47+
resourceSubscriptions map[string]map[string]bool // uri -> session ID -> bool
4748
}
4849

4950
// ServerOptions is used to configure behavior of the server.
@@ -109,7 +110,7 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server {
109110
resourceTemplates: newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }),
110111
sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession],
111112
receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession],
112-
resourceSubscriptions: make(map[string][]*ServerSession),
113+
resourceSubscriptions: make(map[string]map[string]bool),
113114
}
114115
}
115116

@@ -236,12 +237,9 @@ func (s *Server) capabilities() *serverCapabilities {
236237
}
237238
if s.resources.len() > 0 || s.resourceTemplates.len() > 0 {
238239
caps.Resources = &resourceCapabilities{ListChanged: true}
239-
}
240-
if s.opts.SubscribeHandler != nil {
241-
if caps.Resources == nil {
242-
caps.Resources = &resourceCapabilities{}
240+
if s.opts.SubscribeHandler != nil {
241+
caps.Resources.Subscribe = true
243242
}
244-
caps.Resources.Subscribe = true
245243
}
246244
return caps
247245
}
@@ -447,11 +445,22 @@ func fileResourceHandler(dir string) ResourceHandler {
447445

448446
func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNotificationParams) error {
449447
s.mu.Lock()
450-
sessions := slices.Clone(s.resourceSubscriptions[params.URI])
448+
subscribedSessionIDs := maps.Clone(s.resourceSubscriptions[params.URI])
451449
s.mu.Unlock()
452-
if len(sessions) == 0 {
450+
if len(subscribedSessionIDs) == 0 {
453451
return nil
454452
}
453+
sessions := make([]*ServerSession, 0, len(subscribedSessionIDs))
454+
for sessionID, active := range subscribedSessionIDs {
455+
if !active {
456+
continue
457+
}
458+
for session := range s.Sessions() {
459+
if session.ID() == sessionID {
460+
sessions = append(sessions, session)
461+
}
462+
}
463+
}
455464
notifySessions(sessions, notificationResourceUpdated, params)
456465
return nil
457466
}
@@ -466,10 +475,10 @@ func (s *Server) subscribe(ctx context.Context, ss *ServerSession, params *Subsc
466475
s.mu.Lock()
467476
defer s.mu.Unlock()
468477
uri := params.URI
469-
subscribers := s.resourceSubscriptions[uri]
470-
if !slices.Contains(subscribers, ss) {
471-
s.resourceSubscriptions[uri] = append(subscribers, ss)
478+
if s.resourceSubscriptions[uri] == nil {
479+
s.resourceSubscriptions[uri] = make(map[string]bool)
472480
}
481+
s.resourceSubscriptions[uri][ss.ID()] = true
473482
return &emptyResult{}, nil
474483
}
475484

@@ -486,10 +495,8 @@ func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *Uns
486495
defer s.mu.Unlock()
487496

488497
uri := params.URI
489-
if sessions, ok := s.resourceSubscriptions[uri]; ok {
490-
s.resourceSubscriptions[uri] = slices.DeleteFunc(sessions, func(s *ServerSession) bool {
491-
return s == ss
492-
})
498+
if subscribedSessionIDs, ok := s.resourceSubscriptions[uri]; ok {
499+
subscribedSessionIDs[ss.ID()] = false
493500
}
494501
return &emptyResult{}, nil
495502
}
@@ -541,10 +548,9 @@ func (s *Server) disconnect(cc *ServerSession) {
541548
s.sessions = slices.DeleteFunc(s.sessions, func(cc2 *ServerSession) bool {
542549
return cc2 == cc
543550
})
544-
for uri, sessions := range s.resourceSubscriptions {
545-
s.resourceSubscriptions[uri] = slices.DeleteFunc(sessions, func(cc2 *ServerSession) bool {
546-
return cc2 == cc
547-
})
551+
552+
for _, subscribedSessionIDs := range s.resourceSubscriptions {
553+
delete(subscribedSessionIDs, cc.ID())
548554
}
549555
}
550556

mcp/server_test.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,10 @@ func TestServerCapabilities(t *testing.T) {
278278
},
279279
},
280280
{
281-
name: "With resource subscriptions",
282-
configureServer: func(s *Server) {},
281+
name: "With resource subscriptions",
282+
configureServer: func(s *Server) {
283+
s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil)
284+
},
283285
serverOpts: ServerOptions{
284286
SubscribeHandler: func(ctx context.Context, sp *SubscribeParams) error {
285287
return nil
@@ -291,7 +293,7 @@ func TestServerCapabilities(t *testing.T) {
291293
wantCapabilities: &serverCapabilities{
292294
Completions: &completionCapabilities{},
293295
Logging: &loggingCapabilities{},
294-
Resources: &resourceCapabilities{Subscribe: true},
296+
Resources: &resourceCapabilities{ListChanged: true, Subscribe: true},
295297
},
296298
},
297299
{

0 commit comments

Comments
 (0)