From 10636349b43a13d49d4477a318ebeaea168eb31d Mon Sep 17 00:00:00 2001 From: Greg Spencer Date: Thu, 26 Jun 2025 11:33:46 -0700 Subject: [PATCH 1/2] Adding support for roots changed notification and initialized notification. --- src/mcp/server/lowlevel/server.py | 62 +++++- ...notifications.py => test_notifications.py} | 191 +++++++++++++++++- 2 files changed, 242 insertions(+), 11 deletions(-) rename tests/shared/{test_progress_notifications.py => test_notifications.py} (63%) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index faad95aca..8ae261117 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -68,6 +68,7 @@ async def main(): from __future__ import annotations as _annotations import contextvars +import inspect import json import logging import warnings @@ -104,6 +105,9 @@ async def main(): # This will be properly typed in each Server instance's context request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx") +# Context variable to hold the current ServerSession, accessible by notification handlers +current_session_ctx: contextvars.ContextVar[ServerSession] = contextvars.ContextVar("current_server_session") + class NotificationOptions: def __init__( @@ -520,6 +524,36 @@ async def handler(req: types.ProgressNotification): return decorator + def initialized_notification(self): + """Decorator to register a handler for InitializedNotification.""" + + def decorator( + func: ( + Callable[[types.InitializedNotification, ServerSession], Awaitable[None]] + | Callable[[types.InitializedNotification], Awaitable[None]] + ), + ): + logger.debug("Registering handler for InitializedNotification") + self.notification_handlers[types.InitializedNotification] = func + return func + + return decorator + + def roots_list_changed_notification(self): + """Decorator to register a handler for RootsListChangedNotification.""" + + def decorator( + func: ( + Callable[[types.RootsListChangedNotification, ServerSession], Awaitable[None]] + | Callable[[types.RootsListChangedNotification], Awaitable[None]] + ), + ): + logger.debug("Registering handler for RootsListChangedNotification") + self.notification_handlers[types.RootsListChangedNotification] = func + return func + + return decorator + def completion(self): """Provides completions for prompts and resource templates""" @@ -591,22 +625,26 @@ async def run( async def _handle_message( self, - message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception, + message: (RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception), session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool = False, ): - with warnings.catch_warnings(record=True) as w: - # TODO(Marcelo): We should be checking if message is Exception here. - match message: # type: ignore[reportMatchNotExhaustive] - case RequestResponder(request=types.ClientRequest(root=req)) as responder: - with responder: - await self._handle_request(message, req, session, lifespan_context, raise_exceptions) - case types.ClientNotification(root=notify): - await self._handle_notification(notify) + session_token = current_session_ctx.set(session) + try: + with warnings.catch_warnings(record=True) as w: + # TODO(Marcelo): We should be checking if message is Exception here. + match message: # type: ignore[reportMatchNotExhaustive] + case RequestResponder(request=types.ClientRequest(root=req)) as responder: + with responder: + await self._handle_request(message, req, session, lifespan_context, raise_exceptions) + case types.ClientNotification(root=notify): + await self._handle_notification(notify) for warning in w: logger.info("Warning: %s: %s", warning.category.__name__, warning.message) + finally: + current_session_ctx.reset(session_token) async def _handle_request( self, @@ -666,7 +704,11 @@ async def _handle_notification(self, notify: Any): logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - await handler(notify) + sig = inspect.signature(handler) + if "session" in sig.parameters: + await handler(notify, current_session_ctx.get()) + else: + await handler(notify) except Exception: logger.exception("Uncaught exception in notification handler") diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_notifications.py similarity index 63% rename from tests/shared/test_progress_notifications.py rename to tests/shared/test_notifications.py index 08bcb2662..fe835cd9e 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_notifications.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, cast import anyio @@ -10,11 +11,11 @@ from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared.context import RequestContext +from mcp.shared.message import SessionMessage from mcp.shared.progress import progress from mcp.shared.session import ( BaseSession, RequestResponder, - SessionMessage, ) @@ -333,3 +334,191 @@ async def handle_client_message( assert server_progress_updates[3]["progress"] == 100 assert server_progress_updates[3]["total"] == 100 assert server_progress_updates[3]["message"] == "Processing results..." + + +@pytest.mark.anyio +async def test_initialized_notification(): + """Test that the server receives and handles InitializedNotification.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + server = Server("test") + initialized_received = asyncio.Event() + + @server.initialized_notification() + async def handle_initialized(notification: types.InitializedNotification): + initialized_received.set() + + async def run_server(): + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def message_handler( + message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception), + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + await client_session.initialize() + await initialized_received.wait() + tg.cancel_scope.cancel() + + assert initialized_received.is_set() + + +@pytest.mark.anyio +async def test_roots_list_changed_notification(): + """Test that the server receives and handles RootsListChangedNotification.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + server = Server("test") + roots_list_changed_received = asyncio.Event() + + @server.roots_list_changed_notification() + async def handle_roots_list_changed( + notification: types.RootsListChangedNotification, + ): + roots_list_changed_received.set() + + async def run_server(): + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def message_handler( + message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception), + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + await client_session.initialize() + await client_session.send_notification( + types.ClientNotification( + root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None) + ) + ) + await roots_list_changed_received.wait() + tg.cancel_scope.cancel() + + assert roots_list_changed_received.is_set() + + +@pytest.mark.anyio +async def test_initialized_notification_with_session(): + """Test that the server receives and handles InitializedNotification with a session.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + server = Server("test") + initialized_received = asyncio.Event() + received_session = None + + @server.initialized_notification() + async def handle_initialized(notification: types.InitializedNotification, session: ServerSession): + nonlocal received_session + received_session = session + initialized_received.set() + + async def run_server(): + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def message_handler( + message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception), + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + await client_session.initialize() + await initialized_received.wait() + tg.cancel_scope.cancel() + + assert initialized_received.is_set() + assert isinstance(received_session, ServerSession) + + +@pytest.mark.anyio +async def test_roots_list_changed_notification_with_session(): + """Test that the server receives and handles RootsListChangedNotification with a session.""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + server = Server("test") + roots_list_changed_received = asyncio.Event() + received_session = None + + @server.roots_list_changed_notification() + async def handle_roots_list_changed(notification: types.RootsListChangedNotification, session: ServerSession): + nonlocal received_session + received_session = session + roots_list_changed_received.set() + + async def run_server(): + await server.run( + client_to_server_receive, + server_to_client_send, + server.create_initialization_options(), + ) + + async def message_handler( + message: (RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception), + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + await client_session.initialize() + await client_session.send_notification( + types.ClientNotification( + root=types.RootsListChangedNotification(method="notifications/roots/list_changed", params=None) + ) + ) + await roots_list_changed_received.wait() + tg.cancel_scope.cancel() + + assert roots_list_changed_received.is_set() + assert isinstance(received_session, ServerSession) From 7b3a96acd8acec1a220f0656dcd32a576a7b6f2a Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 26 Jul 2025 10:37:48 +0200 Subject: [PATCH 2/2] Unpack settings in FastMCP (#1198) --- src/mcp/server/fastmcp/server.py | 118 +++++++++++++++++------------ src/mcp/server/lowlevel/server.py | 4 +- tests/shared/test_notifications.py | 4 +- 3 files changed, 75 insertions(+), 51 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 2fe7c1224..924baaa9b 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -4,16 +4,13 @@ import inspect import re -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence -from contextlib import ( - AbstractAsyncContextManager, - asynccontextmanager, -) +from collections.abc import AsyncIterator, Awaitable, Callable, Collection, Iterable, Sequence +from contextlib import AbstractAsyncContextManager, asynccontextmanager from typing import Any, Generic, Literal import anyio import pydantic_core -from pydantic import BaseModel, Field +from pydantic import BaseModel from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.applications import Starlette @@ -25,10 +22,7 @@ from starlette.types import Receive, Scope, Send from mcp.server.auth.middleware.auth_context import AuthContextMiddleware -from mcp.server.auth.middleware.bearer_auth import ( - BearerAuthBackend, - RequireAuthMiddleware, -) +from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier from mcp.server.auth.settings import AuthSettings from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation @@ -48,12 +42,7 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.context import LifespanContextT, RequestContext, RequestT -from mcp.types import ( - AnyFunction, - ContentBlock, - GetPromptResult, - ToolAnnotations, -) +from mcp.types import AnyFunction, ContentBlock, GetPromptResult, ToolAnnotations from mcp.types import Prompt as MCPPrompt from mcp.types import PromptArgument as MCPPromptArgument from mcp.types import Resource as MCPResource @@ -79,58 +68,57 @@ class Settings(BaseSettings, Generic[LifespanResultT]): ) # Server settings - debug: bool = False - log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" + debug: bool + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] # HTTP settings - host: str = "127.0.0.1" - port: int = 8000 - mount_path: str = "/" # Mount path (e.g. "/github", defaults to root path) - sse_path: str = "/sse" - message_path: str = "/messages/" - streamable_http_path: str = "/mcp" + host: str + port: int + mount_path: str + sse_path: str + message_path: str + streamable_http_path: str # StreamableHTTP settings - json_response: bool = False - stateless_http: bool = False # If True, uses true stateless mode (new transport per request) + json_response: bool + stateless_http: bool + """Define if the server should create a new transport per request.""" # resource settings - warn_on_duplicate_resources: bool = True + warn_on_duplicate_resources: bool # tool settings - warn_on_duplicate_tools: bool = True + warn_on_duplicate_tools: bool # prompt settings - warn_on_duplicate_prompts: bool = True + warn_on_duplicate_prompts: bool - dependencies: list[str] = Field( - default_factory=list, - description="List of dependencies to install in the server environment", - ) + # TODO(Marcelo): Investigate if this is used. If it is, it's probably a good idea to remove it. + dependencies: list[str] + """A list of dependencies to install in the server environment.""" - lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None = Field( - None, description="Lifespan context manager" - ) + lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None + """A async context manager that will be called when the server is started.""" - auth: AuthSettings | None = None + auth: AuthSettings | None # Transport security settings (DNS rebinding protection) - transport_security: TransportSecuritySettings | None = None + transport_security: TransportSecuritySettings | None def lifespan_wrapper( - app: FastMCP, - lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]], -) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]]: + app: FastMCP[LifespanResultT], + lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]], +) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[LifespanResultT]]: @asynccontextmanager - async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]: + async def wrap(_: MCPServer[LifespanResultT, Request]) -> AsyncIterator[LifespanResultT]: async with lifespan(app) as context: yield context return wrap -class FastMCP: +class FastMCP(Generic[LifespanResultT]): def __init__( self, name: str | None = None, @@ -140,14 +128,50 @@ def __init__( event_store: EventStore | None = None, *, tools: list[Tool] | None = None, - **settings: Any, + debug: bool = False, + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO", + host: str = "127.0.0.1", + port: int = 8000, + mount_path: str = "/", + sse_path: str = "/sse", + message_path: str = "/messages/", + streamable_http_path: str = "/mcp", + json_response: bool = False, + stateless_http: bool = False, + warn_on_duplicate_resources: bool = True, + warn_on_duplicate_tools: bool = True, + warn_on_duplicate_prompts: bool = True, + dependencies: Collection[str] = (), + lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None, + auth: AuthSettings | None = None, + transport_security: TransportSecuritySettings | None = None, ): - self.settings = Settings(**settings) + self.settings = Settings( + debug=debug, + log_level=log_level, + host=host, + port=port, + mount_path=mount_path, + sse_path=sse_path, + message_path=message_path, + streamable_http_path=streamable_http_path, + json_response=json_response, + stateless_http=stateless_http, + warn_on_duplicate_resources=warn_on_duplicate_resources, + warn_on_duplicate_tools=warn_on_duplicate_tools, + warn_on_duplicate_prompts=warn_on_duplicate_prompts, + dependencies=list(dependencies), + lifespan=lifespan, + auth=auth, + transport_security=transport_security, + ) self._mcp_server = MCPServer( name=name or "FastMCP", instructions=instructions, - lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), + # TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an FastMCP and Server. + # We need to create a Lifespan type that is a generic on the server type, like Starlette does. + lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore ) self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) @@ -257,7 +281,7 @@ async def list_tools(self) -> list[MCPTool]: for info in tools ] - def get_context(self) -> Context[ServerSession, object, Request]: + def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: """ Returns a Context object. Note that the context will only be valid during a request; outside a request, most methods will error. diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index e13e7106a..b139e1ef4 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -94,7 +94,7 @@ async def main(): logger = logging.getLogger(__name__) -LifespanResultT = TypeVar("LifespanResultT") +LifespanResultT = TypeVar("LifespanResultT", default=Any) RequestT = TypeVar("RequestT", default=Any) # type aliases for tool call results @@ -122,7 +122,7 @@ def __init__( @asynccontextmanager -async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]: +async def lifespan(_: Server[LifespanResultT, RequestT]) -> AsyncIterator[dict[str, Any]]: """Default lifespan context manager that does nothing. Args: diff --git a/tests/shared/test_notifications.py b/tests/shared/test_notifications.py index fe835cd9e..28204c8f4 100644 --- a/tests/shared/test_notifications.py +++ b/tests/shared/test_notifications.py @@ -43,7 +43,7 @@ async def run_server(): serv_sesh = server_session async for message in server_session.incoming_messages: try: - await server._handle_message(message, server_session, ()) + await server._handle_message(message, server_session, {}) except Exception as e: raise e @@ -253,7 +253,7 @@ async def run_server(): ) as server_session: async for message in server_session.incoming_messages: try: - await server._handle_message(message, server_session, ()) + await server._handle_message(message, server_session, {}) except Exception as e: raise e