Skip to content

Commit 4d65676

Browse files
committed
mcp: update resourceSubscription map type and resourceCapabilities logic
1 parent 00af015 commit 4d65676

File tree

2 files changed

+31
-23
lines changed

2 files changed

+31
-23
lines changed

mcp/server.go

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"fmt"
1414
"iter"
1515
"log"
16+
"maps"
1617
"net/url"
1718
"path/filepath"
1819
"slices"
@@ -43,7 +44,7 @@ type Server struct {
4344
sessions []*ServerSession
4445
sendingMethodHandler_ MethodHandler[*ServerSession]
4546
receivingMethodHandler_ MethodHandler[*ServerSession]
46-
resourceSubscriptions map[string][]*ServerSession // uri -> session
47+
resourceSubscriptions map[string]map[string]bool // uri -> session ID -> bool
4748
}
4849

4950
// ServerOptions is used to configure behavior of the server.
@@ -108,7 +109,7 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server {
108109
resourceTemplates: newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }),
109110
sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession],
110111
receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession],
111-
resourceSubscriptions: make(map[string][]*ServerSession),
112+
resourceSubscriptions: make(map[string]map[string]bool),
112113
}
113114
}
114115

@@ -235,12 +236,9 @@ func (s *Server) capabilities() *serverCapabilities {
235236
}
236237
if s.resources.len() > 0 || s.resourceTemplates.len() > 0 {
237238
caps.Resources = &resourceCapabilities{ListChanged: true}
238-
}
239-
if s.opts.SubscribeHandler != nil {
240-
if caps.Resources == nil {
241-
caps.Resources = &resourceCapabilities{}
239+
if s.opts.SubscribeHandler != nil {
240+
caps.Resources.Subscribe = true
242241
}
243-
caps.Resources.Subscribe = true
244242
}
245243
return caps
246244
}
@@ -446,11 +444,22 @@ func fileResourceHandler(dir string) ResourceHandler {
446444

447445
func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNotificationParams) error {
448446
s.mu.Lock()
449-
sessions := slices.Clone(s.resourceSubscriptions[params.URI])
447+
subscribedSessionIDs := maps.Clone(s.resourceSubscriptions[params.URI])
450448
s.mu.Unlock()
451-
if len(sessions) == 0 {
449+
if len(subscribedSessionIDs) == 0 {
452450
return nil
453451
}
452+
sessions := make([]*ServerSession, 0, len(subscribedSessionIDs))
453+
for sessionID, active := range subscribedSessionIDs {
454+
if !active {
455+
continue
456+
}
457+
for session := range s.Sessions() {
458+
if session.ID() == sessionID {
459+
sessions = append(sessions, session)
460+
}
461+
}
462+
}
454463
notifySessions(sessions, notificationResourceUpdated, params)
455464
return nil
456465
}
@@ -465,10 +474,10 @@ func (s *Server) subscribe(ctx context.Context, ss *ServerSession, params *Subsc
465474
s.mu.Lock()
466475
defer s.mu.Unlock()
467476
uri := params.URI
468-
subscribers := s.resourceSubscriptions[uri]
469-
if !slices.Contains(subscribers, ss) {
470-
s.resourceSubscriptions[uri] = append(subscribers, ss)
477+
if s.resourceSubscriptions[uri] == nil {
478+
s.resourceSubscriptions[uri] = make(map[string]bool)
471479
}
480+
s.resourceSubscriptions[uri][ss.ID()] = true
472481
return &emptyResult{}, nil
473482
}
474483

@@ -485,10 +494,8 @@ func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *Uns
485494
defer s.mu.Unlock()
486495

487496
uri := params.URI
488-
if sessions, ok := s.resourceSubscriptions[uri]; ok {
489-
s.resourceSubscriptions[uri] = slices.DeleteFunc(sessions, func(s *ServerSession) bool {
490-
return s == ss
491-
})
497+
if subscribedSessionIDs, ok := s.resourceSubscriptions[uri]; ok {
498+
subscribedSessionIDs[ss.ID()] = false
492499
}
493500
return &emptyResult{}, nil
494501
}
@@ -540,10 +547,9 @@ func (s *Server) disconnect(cc *ServerSession) {
540547
s.sessions = slices.DeleteFunc(s.sessions, func(cc2 *ServerSession) bool {
541548
return cc2 == cc
542549
})
543-
for uri, sessions := range s.resourceSubscriptions {
544-
s.resourceSubscriptions[uri] = slices.DeleteFunc(sessions, func(cc2 *ServerSession) bool {
545-
return cc2 == cc
546-
})
550+
551+
for _, subscribedSessionIDs := range s.resourceSubscriptions {
552+
delete(subscribedSessionIDs, cc.ID())
547553
}
548554
}
549555

mcp/server_test.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,10 @@ func TestServerCapabilities(t *testing.T) {
278278
},
279279
},
280280
{
281-
name: "With resource subscriptions",
282-
configureServer: func(s *Server) {},
281+
name: "With resource subscriptions",
282+
configureServer: func(s *Server) {
283+
s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil)
284+
},
283285
serverOpts: ServerOptions{
284286
SubscribeHandler: func(ctx context.Context, sp *SubscribeParams) error {
285287
return nil
@@ -291,7 +293,7 @@ func TestServerCapabilities(t *testing.T) {
291293
wantCapabilities: &serverCapabilities{
292294
Completions: &completionCapabilities{},
293295
Logging: &loggingCapabilities{},
294-
Resources: &resourceCapabilities{Subscribe: true},
296+
Resources: &resourceCapabilities{ListChanged: true, Subscribe: true},
295297
},
296298
},
297299
{

0 commit comments

Comments
 (0)