diff --git a/design/design.md b/design/design.md index 8804292e..bfabeac7 100644 --- a/design/design.md +++ b/design/design.md @@ -748,13 +748,26 @@ Server sessions also support the spec methods `ListResources` and `ListResourceT #### Subscriptions -ClientSessions can manage change notifications on particular resources: +##### Client-Side Usage + +Use the Subscribe and Unsubscribe methods on a ClientSession to start or stop receiving updates for a specific resource. ```go func (*ClientSession) Subscribe(context.Context, *SubscribeParams) error func (*ClientSession) Unsubscribe(context.Context, *UnsubscribeParams) error ``` +To process incoming update notifications, you must provide a ResourceUpdatedHandler in your ClientOptions. The SDK calls this function automatically whenever the server sends a notification for a resource you're subscribed to. + +```go +type ClientOptions struct { + ... + ResourceUpdatedHandler func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams) +} +``` + +##### Server-Side Implementation + The server does not implement resource subscriptions. It passes along subscription requests to the user, and supplies a method to notify clients of changes. It tracks which sessions have subscribed to which resources so the user doesn't have to. If a server author wants to support resource subscriptions, they must provide handlers to be called when clients subscribe and unsubscribe. It is an error to provide only one of these handlers. @@ -772,7 +785,7 @@ type ServerOptions struct { User code should call `ResourceUpdated` when a subscribed resource changes. ```go -func (*Server) ResourceUpdated(context.Context, *ResourceUpdatedNotification) error +func (*Server) ResourceUpdated(context.Context, *ResourceUpdatedNotificationParams) error ``` The server routes these notifications to the server sessions that subscribed to the resource. diff --git a/mcp/client.go b/mcp/client.go index b48ad7a1..b386294c 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -60,6 +60,7 @@ type ClientOptions struct { ToolListChangedHandler func(context.Context, *ClientSession, *ToolListChangedParams) PromptListChangedHandler func(context.Context, *ClientSession, *PromptListChangedParams) ResourceListChangedHandler func(context.Context, *ClientSession, *ResourceListChangedParams) + ResourceUpdatedHandler func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams) LoggingMessageHandler func(context.Context, *ClientSession, *LoggingMessageParams) ProgressNotificationHandler func(context.Context, *ClientSession, *ProgressNotificationParams) // If non-zero, defines an interval for regular "ping" requests. @@ -293,6 +294,7 @@ var clientMethodInfos = map[string]methodInfo{ notificationToolListChanged: newMethodInfo(clientMethod((*Client).callToolChangedHandler)), notificationPromptListChanged: newMethodInfo(clientMethod((*Client).callPromptChangedHandler)), notificationResourceListChanged: newMethodInfo(clientMethod((*Client).callResourceChangedHandler)), + notificationResourceUpdated: newMethodInfo(clientMethod((*Client).callResourceUpdatedHandler)), notificationLoggingMessage: newMethodInfo(clientMethod((*Client).callLoggingHandler)), notificationProgress: newMethodInfo(sessionMethod((*ClientSession).callProgressNotificationHandler)), } @@ -386,6 +388,20 @@ func (cs *ClientSession) Complete(ctx context.Context, params *CompleteParams) ( return handleSend[*CompleteResult](ctx, cs, methodComplete, orZero[Params](params)) } +// Subscribe sends a "resources/subscribe" request to the server, asking for +// notifications when the specified resource changes. +func (cs *ClientSession) Subscribe(ctx context.Context, params *SubscribeParams) error { + _, err := handleSend[*emptyResult](ctx, cs, methodSubscribe, orZero[Params](params)) + return err +} + +// Unsubscribe sends a "resources/unsubscribe" request to the server, cancelling +// a previous subscription. +func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribeParams) error { + _, err := handleSend[*emptyResult](ctx, cs, methodUnsubscribe, orZero[Params](params)) + return err +} + func (c *Client) callToolChangedHandler(ctx context.Context, s *ClientSession, params *ToolListChangedParams) (Result, error) { return callNotificationHandler(ctx, c.opts.ToolListChangedHandler, s, params) } @@ -398,6 +414,10 @@ func (c *Client) callResourceChangedHandler(ctx context.Context, s *ClientSessio return callNotificationHandler(ctx, c.opts.ResourceListChangedHandler, s, params) } +func (c *Client) callResourceUpdatedHandler(ctx context.Context, s *ClientSession, params *ResourceUpdatedNotificationParams) (Result, error) { + return callNotificationHandler(ctx, c.opts.ResourceUpdatedHandler, s, params) +} + func (c *Client) callLoggingHandler(ctx context.Context, cs *ClientSession, params *LoggingMessageParams) (Result, error) { if h := c.opts.LoggingMessageHandler; h != nil { h(ctx, cs, params) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 7da2b857..032181a1 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -60,7 +60,7 @@ func TestEndToEnd(t *testing.T) { // Channels to check if notification callbacks happened. notificationChans := map[string]chan int{} - for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client"} { + for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client", "resource_updated", "subscribe", "unsubscribe"} { notificationChans[name] = make(chan int, 1) } waitForNotification := func(t *testing.T, name string) { @@ -78,6 +78,14 @@ func TestEndToEnd(t *testing.T) { ProgressNotificationHandler: func(context.Context, *ServerSession, *ProgressNotificationParams) { notificationChans["progress_server"] <- 0 }, + SubscribeHandler: func(context.Context, *SubscribeParams) error { + notificationChans["subscribe"] <- 0 + return nil + }, + UnsubscribeHandler: func(context.Context, *UnsubscribeParams) error { + notificationChans["unsubscribe"] <- 0 + return nil + }, } s := NewServer(testImpl, sopts) AddTool(s, &Tool{ @@ -128,6 +136,9 @@ func TestEndToEnd(t *testing.T) { ProgressNotificationHandler: func(context.Context, *ClientSession, *ProgressNotificationParams) { notificationChans["progress_client"] <- 0 }, + ResourceUpdatedHandler: func(context.Context, *ClientSession, *ResourceUpdatedNotificationParams) { + notificationChans["resource_updated"] <- 0 + }, } c := NewClient(testImpl, opts) rootAbs, err := filepath.Abs(filepath.FromSlash("testdata/files")) @@ -421,6 +432,37 @@ func TestEndToEnd(t *testing.T) { waitForNotification(t, "progress_server") }) + t.Run("resource_subscriptions", func(t *testing.T) { + err := cs.Subscribe(ctx, &SubscribeParams{ + URI: "test", + }) + if err != nil { + t.Fatal(err) + } + waitForNotification(t, "subscribe") + s.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{ + URI: "test", + }) + waitForNotification(t, "resource_updated") + err = cs.Unsubscribe(ctx, &UnsubscribeParams{ + URI: "test", + }) + if err != nil { + t.Fatal(err) + } + waitForNotification(t, "unsubscribe") + + // Verify the client does not receive the update after unsubscribing. + s.ResourceUpdated(ctx, &ResourceUpdatedNotificationParams{ + URI: "test", + }) + select { + case <-notificationChans["resource_updated"]: + t.Fatalf("resource updated after unsubscription") + case <-time.After(time.Second): + } + }) + // Disconnect. cs.Close() clientWG.Wait() diff --git a/mcp/protocol.go b/mcp/protocol.go index 4f47c961..00dcd14d 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -859,6 +859,38 @@ type ToolListChangedParams struct { func (x *ToolListChangedParams) GetProgressToken() any { return getProgressToken(x) } func (x *ToolListChangedParams) SetProgressToken(t any) { setProgressToken(x, t) } +// Sent from the client to request resources/updated notifications from the +// server whenever a particular resource changes. +type SubscribeParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource to subscribe to. + URI string `json:"uri"` +} + +// Sent from the client to request cancellation of resources/updated +// notifications from the server. This should follow a previous +// resources/subscribe request. +type UnsubscribeParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource to unsubscribe from. + URI string `json:"uri"` +} + +// A notification from the server to the client, informing it that a resource +// has changed and may need to be read again. This should only be sent if the +// client previously sent a resources/subscribe request. +type ResourceUpdatedNotificationParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to. + URI string `json:"uri"` +} + // TODO(jba): add CompleteRequest and related types. // TODO(jba): add ElicitRequest and related types. diff --git a/mcp/server.go b/mcp/server.go index e0f691dc..75e66dbc 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -13,6 +13,7 @@ import ( "fmt" "iter" "log" + "maps" "net/url" "path/filepath" "slices" @@ -43,6 +44,7 @@ type Server struct { sessions []*ServerSession sendingMethodHandler_ MethodHandler[*ServerSession] receivingMethodHandler_ MethodHandler[*ServerSession] + resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool } // ServerOptions is used to configure behavior of the server. @@ -64,6 +66,10 @@ type ServerOptions struct { // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. KeepAlive time.Duration + // Function called when a client session subscribes to a resource. + SubscribeHandler func(context.Context, *SubscribeParams) error + // Function called when a client session unsubscribes from a resource. + UnsubscribeHandler func(context.Context, *UnsubscribeParams) error } // NewServer creates a new MCP server. The resulting server has no features: @@ -89,7 +95,12 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server { if opts.PageSize == 0 { opts.PageSize = DefaultPageSize } - + if opts.SubscribeHandler != nil && opts.UnsubscribeHandler == nil { + panic("SubscribeHandler requires UnsubscribeHandler") + } + if opts.UnsubscribeHandler != nil && opts.SubscribeHandler == nil { + panic("UnsubscribeHandler requires SubscribeHandler") + } return &Server{ impl: impl, opts: *opts, @@ -99,6 +110,7 @@ func NewServer(impl *Implementation, opts *ServerOptions) *Server { resourceTemplates: newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }), sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession], receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], + resourceSubscriptions: make(map[string]map[*ServerSession]bool), } } @@ -225,6 +237,9 @@ func (s *Server) capabilities() *serverCapabilities { } if s.resources.len() > 0 || s.resourceTemplates.len() > 0 { caps.Resources = &resourceCapabilities{ListChanged: true} + if s.opts.SubscribeHandler != nil { + caps.Resources.Subscribe = true + } } return caps } @@ -428,6 +443,57 @@ func fileResourceHandler(dir string) ResourceHandler { } } +// ResourceUpdated sends a notification to all clients that have subscribed to the +// resource specified in params. This method is the primary way for a +// server author to signal that a resource has changed. +func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNotificationParams) error { + s.mu.Lock() + subscribedSessions := s.resourceSubscriptions[params.URI] + sessions := slices.Collect(maps.Keys(subscribedSessions)) + s.mu.Unlock() + notifySessions(sessions, notificationResourceUpdated, params) + return nil +} + +func (s *Server) subscribe(ctx context.Context, ss *ServerSession, params *SubscribeParams) (*emptyResult, error) { + if s.opts.SubscribeHandler == nil { + return nil, fmt.Errorf("%w: server does not support resource subscriptions", jsonrpc2.ErrMethodNotFound) + } + if err := s.opts.SubscribeHandler(ctx, params); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + if s.resourceSubscriptions[params.URI] == nil { + s.resourceSubscriptions[params.URI] = make(map[*ServerSession]bool) + } + s.resourceSubscriptions[params.URI][ss] = true + + return &emptyResult{}, nil +} + +func (s *Server) unsubscribe(ctx context.Context, ss *ServerSession, params *UnsubscribeParams) (*emptyResult, error) { + if s.opts.UnsubscribeHandler == nil { + return nil, jsonrpc2.ErrMethodNotFound + } + + if err := s.opts.UnsubscribeHandler(ctx, params); err != nil { + return nil, err + } + + s.mu.Lock() + defer s.mu.Unlock() + if subscribedSessions, ok := s.resourceSubscriptions[params.URI]; ok { + delete(subscribedSessions, ss) + if len(subscribedSessions) == 0 { + delete(s.resourceSubscriptions, params.URI) + } + } + + return &emptyResult{}, nil +} + // Run runs the server over the given transport, which must be persistent. // // Run blocks until the client terminates the connection or the provided @@ -475,6 +541,10 @@ func (s *Server) disconnect(cc *ServerSession) { s.sessions = slices.DeleteFunc(s.sessions, func(cc2 *ServerSession) bool { return cc2 == cc }) + + for _, subscribedSessions := range s.resourceSubscriptions { + delete(subscribedSessions, cc) + } } // Connect connects the MCP server over the given transport and starts handling @@ -616,6 +686,8 @@ var serverMethodInfos = map[string]methodInfo{ methodListResourceTemplates: newMethodInfo(serverMethod((*Server).listResourceTemplates)), methodReadResource: newMethodInfo(serverMethod((*Server).readResource)), methodSetLevel: newMethodInfo(sessionMethod((*ServerSession).setLevel)), + methodSubscribe: newMethodInfo(serverMethod((*Server).subscribe)), + methodUnsubscribe: newMethodInfo(serverMethod((*Server).unsubscribe)), notificationInitialized: newMethodInfo(serverMethod((*Server).callInitializedHandler)), notificationRootsListChanged: newMethodInfo(serverMethod((*Server).callRootsListChangedHandler)), notificationProgress: newMethodInfo(sessionMethod((*ServerSession).callProgressNotificationHandler)), diff --git a/mcp/server_test.go b/mcp/server_test.go index d4243d7c..bb539772 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -5,6 +5,7 @@ package mcp import ( + "context" "log" "slices" "testing" @@ -232,6 +233,7 @@ func TestServerCapabilities(t *testing.T) { testCases := []struct { name string configureServer func(s *Server) + serverOpts ServerOptions wantCapabilities *serverCapabilities }{ { @@ -275,6 +277,25 @@ func TestServerCapabilities(t *testing.T) { Resources: &resourceCapabilities{ListChanged: true}, }, }, + { + name: "With resource subscriptions", + configureServer: func(s *Server) { + s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) + }, + serverOpts: ServerOptions{ + SubscribeHandler: func(ctx context.Context, sp *SubscribeParams) error { + return nil + }, + UnsubscribeHandler: func(ctx context.Context, up *UnsubscribeParams) error { + return nil + }, + }, + wantCapabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Resources: &resourceCapabilities{ListChanged: true, Subscribe: true}, + }, + }, { name: "With tools", configureServer: func(s *Server) { @@ -294,11 +315,19 @@ func TestServerCapabilities(t *testing.T) { s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) s.AddTool(&Tool{Name: "t"}, nil) }, + serverOpts: ServerOptions{ + SubscribeHandler: func(ctx context.Context, sp *SubscribeParams) error { + return nil + }, + UnsubscribeHandler: func(ctx context.Context, up *UnsubscribeParams) error { + return nil + }, + }, wantCapabilities: &serverCapabilities{ Completions: &completionCapabilities{}, Logging: &loggingCapabilities{}, Prompts: &promptCapabilities{ListChanged: true}, - Resources: &resourceCapabilities{ListChanged: true}, + Resources: &resourceCapabilities{ListChanged: true, Subscribe: true}, Tools: &toolCapabilities{ListChanged: true}, }, }, @@ -306,7 +335,7 @@ func TestServerCapabilities(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - server := NewServer(testImpl, nil) + server := NewServer(testImpl, &tc.serverOpts) tc.configureServer(server) gotCapabilities := server.capabilities() if diff := cmp.Diff(tc.wantCapabilities, gotCapabilities); diff != "" {