diff --git a/src/fastmcp/client/transports.py b/src/fastmcp/client/transports.py index 3664c21d3..2cadd0631 100644 --- a/src/fastmcp/client/transports.py +++ b/src/fastmcp/client/transports.py @@ -852,29 +852,42 @@ async def connect_session( # Create a cancel scope for the server task async with anyio.create_task_group() as tg: - tg.start_soon( - lambda: self.server._mcp_server.run( - server_read, - server_write, - self.server._mcp_server.create_initialization_options(), - raise_exceptions=self.raise_exceptions, + async with _enter_server_lifespan(server=self.server): + tg.start_soon( + lambda: self.server._mcp_server.run( + server_read, + server_write, + self.server._mcp_server.create_initialization_options(), + raise_exceptions=self.raise_exceptions, + ) ) - ) - try: - async with ClientSession( - read_stream=client_read, - write_stream=client_write, - **session_kwargs, - ) as client_session: - yield client_session - finally: - tg.cancel_scope.cancel() + try: + async with ClientSession( + read_stream=client_read, + write_stream=client_write, + **session_kwargs, + ) as client_session: + yield client_session + finally: + tg.cancel_scope.cancel() def __repr__(self) -> str: return f"" +@contextlib.asynccontextmanager +async def _enter_server_lifespan( + server: FastMCP | FastMCP1Server, +) -> AsyncIterator[None]: + """Enters the server's lifespan context for FastMCP servers and does nothing for FastMCP 1 servers.""" + if isinstance(server, FastMCP): + async with server._lifespan_manager(): + yield + else: + yield + + class MCPConfigTransport(ClientTransport): """Transport for connecting to one or more MCP servers defined in an MCPConfig. diff --git a/src/fastmcp/server/http.py b/src/fastmcp/server/http.py index a5e41daf6..25264ce05 100644 --- a/src/fastmcp/server/http.py +++ b/src/fastmcp/server/http.py @@ -224,11 +224,17 @@ async def sse_endpoint(request: Request) -> Response: if middleware: server_middleware.extend(middleware) + @asynccontextmanager + async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: + async with server._lifespan_manager(): + yield + # Create and return the app app = create_base_app( routes=server_routes, middleware=server_middleware, debug=debug, + lifespan=lifespan, ) # Store the FastMCP server instance on the Starlette app state app.state.fastmcp_server = server @@ -320,8 +326,9 @@ def create_streamable_http_app( # Create a lifespan manager to start and stop the session manager @asynccontextmanager async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: - async with session_manager.run(): - yield + async with server._lifespan_manager(): + async with session_manager.run(): + yield # Create and return the app with lifespan app = create_base_app( diff --git a/src/fastmcp/server/server.py b/src/fastmcp/server/server.py index 25db16cd1..a8823a6d5 100644 --- a/src/fastmcp/server/server.py +++ b/src/fastmcp/server/server.py @@ -89,6 +89,10 @@ # Compiled URI parsing regex to split a URI into protocol and path components URI_PATTERN = re.compile(r"^([^:]+://)(.*?)$") +LifespanCallable = Callable[ + ["FastMCP[LifespanResultT]"], AbstractAsyncContextManager[LifespanResultT] +] + @asynccontextmanager async def default_lifespan(server: FastMCP[LifespanResultT]) -> AsyncIterator[Any]: @@ -98,26 +102,31 @@ async def default_lifespan(server: FastMCP[LifespanResultT]) -> AsyncIterator[An server: The server instance this lifespan is managing Returns: - An empty context object + An empty dictionary as the lifespan result. """ yield {} -def _lifespan_wrapper( - app: FastMCP[LifespanResultT], - lifespan: Callable[ - [FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT] - ], +def _lifespan_proxy( + fastmcp_server: FastMCP[LifespanResultT], ) -> Callable[ [LowLevelServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT] ]: @asynccontextmanager async def wrap( - s: LowLevelServer[LifespanResultT], + low_level_server: LowLevelServer[LifespanResultT], ) -> AsyncIterator[LifespanResultT]: - async with AsyncExitStack() as stack: - context = await stack.enter_async_context(lifespan(app)) - yield context + if fastmcp_server._lifespan is default_lifespan: + yield {} + return + + if not fastmcp_server._lifespan_result_set: + raise RuntimeError( + "FastMCP server has a lifespan defined but no lifespan result is set, which means the server's context manager was not entered. " + + " Are you running the server in a way that supports lifespans? If so, please file an issue at https://github.com/jlowin/fastmcp/issues." + ) + + yield fastmcp_server._lifespan_result return wrap @@ -131,13 +140,7 @@ def __init__( version: str | None = None, auth: AuthProvider | None | NotSetT = NotSet, middleware: list[Middleware] | None = None, - lifespan: ( - Callable[ - [FastMCP[LifespanResultT]], - AbstractAsyncContextManager[LifespanResultT], - ] - | None - ) = None, + lifespan: LifespanCallable | None = None, dependencies: list[str] | None = None, resource_prefix_format: Literal["protocol", "path"] | None = None, mask_error_details: bool | None = None, @@ -188,18 +191,17 @@ def __init__( ) self._tool_serializer = tool_serializer - if lifespan is None: - self._has_lifespan = False - lifespan = default_lifespan - else: - self._has_lifespan = True + self._lifespan: LifespanCallable[LifespanResultT] = lifespan or default_lifespan + self._lifespan_result: LifespanResultT | None = None + self._lifespan_result_set = False + # Generate random ID if no name provided self._mcp_server = LowLevelServer[LifespanResultT]( fastmcp=self, name=name or self.generate_name(), version=version or fastmcp.__version__, instructions=instructions, - lifespan=_lifespan_wrapper(self, lifespan), + lifespan=_lifespan_proxy(fastmcp_server=self), ) # if auth is `NotSet`, try to create a provider from the environment @@ -334,6 +336,27 @@ def instructions(self, value: str | None) -> None: def version(self) -> str | None: return self._mcp_server.version + @asynccontextmanager + async def _lifespan_manager(self) -> AsyncIterator[None]: + if self._lifespan_result_set: + yield + return + + async with self._lifespan(self) as lifespan_result: + self._lifespan_result = lifespan_result + self._lifespan_result_set = True + + async with AsyncExitStack[bool | None]() as stack: + for server in self._mounted_servers: + await stack.enter_async_context( + cm=server.server._lifespan_manager() + ) + + yield + + self._lifespan_result_set = False + self._lifespan_result = None + async def run_async( self, transport: Transport | None = None, @@ -1880,15 +1903,18 @@ async def run_stdio_async( ) with temporary_log_level(log_level): - async with stdio_server() as (read_stream, write_stream): - logger.info(f"Starting MCP server {self.name!r} with transport 'stdio'") - await self._mcp_server.run( - read_stream, - write_stream, - self._mcp_server.create_initialization_options( - NotificationOptions(tools_changed=True) - ), - ) + async with self._lifespan_manager(): + async with stdio_server() as (read_stream, write_stream): + logger.info( + f"Starting MCP server {self.name!r} with transport 'stdio'" + ) + await self._mcp_server.run( + read_stream, + write_stream, + self._mcp_server.create_initialization_options( + NotificationOptions(tools_changed=True) + ), + ) async def run_http_async( self, @@ -1959,14 +1985,15 @@ async def run_http_async( config_kwargs["log_level"] = default_log_level_to_use with temporary_log_level(log_level): - config = uvicorn.Config(app, host=host, port=port, **config_kwargs) - server = uvicorn.Server(config) - path = app.state.path.lstrip("/") # type: ignore - logger.info( - f"Starting MCP server {self.name!r} with transport {transport!r} on http://{host}:{port}/{path}" - ) + async with self._lifespan_manager(): + config = uvicorn.Config(app, host=host, port=port, **config_kwargs) + server = uvicorn.Server(config) + path = app.state.path.lstrip("/") # type: ignore + logger.info( + f"Starting MCP server {self.name!r} with transport {transport!r} on http://{host}:{port}/{path}" + ) - await server.serve() + await server.serve() async def run_sse_async( self, @@ -2228,7 +2255,7 @@ def mount( # if as_proxy is not specified and the server has a custom lifespan, # we should treat it as a proxy if as_proxy is None: - as_proxy = server._has_lifespan + as_proxy = server._lifespan != default_lifespan if as_proxy and not isinstance(server, FastMCPProxy): server = FastMCP.as_proxy(server) @@ -2362,6 +2389,15 @@ async def import_server( prompt = prompt.model_copy(key=f"{prefix}_{key}") self._prompt_manager.add_prompt(prompt) + if server._lifespan != default_lifespan: + from warnings import warn + + warn( + message="When importing from a server with a lifespan, the lifespan from the imported server will not be used.", + category=RuntimeWarning, + stacklevel=2, + ) + if prefix: logger.debug( f"[{self.name}] Imported server {server.name} with prefix '{prefix}'" diff --git a/tests/client/test_sse.py b/tests/client/test_sse.py index f2fe86605..818f233fd 100644 --- a/tests/client/test_sse.py +++ b/tests/client/test_sse.py @@ -92,7 +92,8 @@ async def test_http_headers(sse_server: str): def run_nested_server(host: str, port: int) -> None: - app = fastmcp_server().sse_app(path="/mcp/sse/", message_path="/mcp/messages") + fastmcp = fastmcp_server() + app = fastmcp.sse_app(path="/mcp/sse/", message_path="/mcp/messages") mount = Starlette(routes=[Mount("/nest-inner", app=app)]) mount2 = Starlette(routes=[Mount("/nest-outer", app=mount)]) server = uvicorn.Server( diff --git a/tests/server/test_mount.py b/tests/server/test_mount.py index e3fd3abd1..d02e86b6c 100644 --- a/tests/server/test_mount.py +++ b/tests/server/test_mount.py @@ -888,15 +888,18 @@ async def test_as_proxy_true(self): assert isinstance(mcp._mounted_servers[0].server, FastMCPProxy) async def test_as_proxy_defaults_true_if_lifespan(self): + """Test that as_proxy defaults to True when server_lifespan is provided.""" + @asynccontextmanager - async def lifespan(mcp: FastMCP): + async def server_lifespan(mcp: FastMCP): yield mcp = FastMCP("Main") - sub = FastMCP("Sub", lifespan=lifespan) + sub = FastMCP("Sub", lifespan=server_lifespan) mcp.mount(sub, "sub") + # Should auto-proxy because lifespan is set assert mcp._mounted_servers[0].server is not sub assert isinstance(mcp._mounted_servers[0].server, FastMCPProxy) diff --git a/tests/server/test_server_lifespan.py b/tests/server/test_server_lifespan.py new file mode 100644 index 000000000..a4373a4a7 --- /dev/null +++ b/tests/server/test_server_lifespan.py @@ -0,0 +1,68 @@ +"""Tests for server_lifespan and session_lifespan behavior.""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +from fastmcp import Client, FastMCP +from fastmcp.server.context import Context + + +class TestServerLifespan: + """Test server_lifespan functionality.""" + + async def test_server_lifespan_basic(self): + """Test that server_lifespan is entered once and persists across sessions.""" + lifespan_events: list[str] = [] + + @asynccontextmanager + async def server_lifespan(mcp: FastMCP) -> AsyncIterator[dict[str, Any]]: + _ = lifespan_events.append("enter") + yield {"initialized": True} + _ = lifespan_events.append("exit") + + mcp = FastMCP("TestServer", lifespan=server_lifespan) + + @mcp.tool + def get_value() -> str: + return "test" + + # Server lifespan should be entered when run_async starts + assert lifespan_events == [] + + # Connect first client session + async with Client(mcp) as client1: + result1 = await client1.call_tool("get_value", {}) + assert result1.data == "test" + # Server lifespan should have been entered once + assert lifespan_events == ["enter"] + + # Connect second client session while first is still active + async with Client(mcp) as client2: + result2 = await client2.call_tool("get_value", {}) + assert result2.data == "test" + # Server lifespan should still only have been entered once + assert lifespan_events == ["enter"] + + # Because we're using a fastmcptransport, the server lifespan should be exited + # when the client session closes + assert lifespan_events == ["enter", "exit"] + + async def test_server_lifespan_context_available(self): + """Test that server_lifespan context is available to tools.""" + + @asynccontextmanager + async def server_lifespan(mcp: FastMCP) -> AsyncIterator[dict]: + yield {"db_connection": "mock_db"} + + mcp = FastMCP("TestServer", lifespan=server_lifespan) + + @mcp.tool + def get_db_info(ctx: Context) -> str: + # Access the server lifespan context + lifespan_context = ctx.request_context.lifespan_context + return lifespan_context.get("db_connection", "no_db") + + async with Client(mcp) as client: + result = await client.call_tool("get_db_info", {}) + assert result.data == "mock_db"