Skip to content
45 changes: 29 additions & 16 deletions src/fastmcp/client/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,29 +811,42 @@ async def connect_session(

# Create a cancel scope for the server task
async with anyio.create_task_group() as tg:
tg.start_soon(
lambda: self.server._mcp_server.run(
server_read,
server_write,
self.server._mcp_server.create_initialization_options(),
raise_exceptions=self.raise_exceptions,
async with _enter_server_lifespan(server=self.server):
tg.start_soon(
lambda: self.server._mcp_server.run(
server_read,
server_write,
self.server._mcp_server.create_initialization_options(),
raise_exceptions=self.raise_exceptions,
)
)
)

try:
async with ClientSession(
read_stream=client_read,
write_stream=client_write,
**session_kwargs,
) as client_session:
yield client_session
finally:
tg.cancel_scope.cancel()
try:
async with ClientSession(
read_stream=client_read,
write_stream=client_write,
**session_kwargs,
) as client_session:
yield client_session
finally:
tg.cancel_scope.cancel()

def __repr__(self) -> str:
return f"<FastMCPTransport(server='{self.server.name}')>"


@contextlib.asynccontextmanager
async def _enter_server_lifespan(
server: FastMCP | FastMCP1Server,
) -> AsyncIterator[None]:
"""Enters the server's lifespan context for FastMCP servers and does nothing for FastMCP 1 servers."""
if isinstance(server, FastMCP):
async with server._lifespan_manager():
yield
else:
yield


class MCPConfigTransport(ClientTransport):
"""Transport for connecting to one or more MCP servers defined in an MCPConfig.

Expand Down
11 changes: 9 additions & 2 deletions src/fastmcp/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,17 @@ async def sse_endpoint(request: Request) -> Response:
if middleware:
server_middleware.extend(middleware)

@asynccontextmanager
async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
async with server._lifespan_manager():
yield

# Create and return the app
app = create_base_app(
routes=server_routes,
middleware=server_middleware,
debug=debug,
lifespan=lifespan,
)
# Store the FastMCP server instance on the Starlette app state
app.state.fastmcp_server = server
Expand Down Expand Up @@ -320,8 +326,9 @@ def create_streamable_http_app(
# Create a lifespan manager to start and stop the session manager
@asynccontextmanager
async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
async with session_manager.run():
yield
async with server._lifespan_manager():
async with session_manager.run():
yield

# Create and return the app with lifespan
app = create_base_app(
Expand Down
116 changes: 74 additions & 42 deletions src/fastmcp/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,35 +89,40 @@
# Compiled URI parsing regex to split a URI into protocol and path components
URI_PATTERN = re.compile(r"^([^:]+://)(.*?)$")

LifespanCallable = Callable[
["FastMCP[LifespanResultT]"], AbstractAsyncContextManager[LifespanResultT]
]


@asynccontextmanager
async def default_lifespan(server: FastMCP[LifespanResultT]) -> AsyncIterator[Any]:
"""Default lifespan context manager that does nothing.

Args:
server: The server instance this lifespan is managing

Returns:
An empty context object
"""
yield {}


def _lifespan_wrapper(
app: FastMCP[LifespanResultT],
lifespan: Callable[
[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]
],
def _lifespan_proxy(
fastmcp_server: FastMCP[LifespanResultT],
) -> Callable[
[LowLevelServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]
]:
@asynccontextmanager
async def wrap(
s: LowLevelServer[LifespanResultT],
low_level_server: LowLevelServer[LifespanResultT],
) -> AsyncIterator[LifespanResultT]:
async with AsyncExitStack() as stack:
context = await stack.enter_async_context(lifespan(app))
yield context
if fastmcp_server._lifespan == default_lifespan:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this equality work for context managers? does it need to be is?

yield {}
return

if not fastmcp_server._lifespan_result_set:
raise RuntimeError(
"FastMCP server has a lifespan but no lifespan result is set, which means the server's context manager was not entered."
)

yield fastmcp_server._lifespan_result

return wrap

Expand All @@ -131,13 +136,7 @@ def __init__(
version: str | None = None,
auth: AuthProvider | None | NotSetT = NotSet,
middleware: list[Middleware] | None = None,
lifespan: (
Callable[
[FastMCP[LifespanResultT]],
AbstractAsyncContextManager[LifespanResultT],
]
| None
) = None,
lifespan: LifespanCallable | None = None,
dependencies: list[str] | None = None,
resource_prefix_format: Literal["protocol", "path"] | None = None,
mask_error_details: bool | None = None,
Expand Down Expand Up @@ -188,18 +187,17 @@ def __init__(
)
self._tool_serializer = tool_serializer

if lifespan is None:
self._has_lifespan = False
lifespan = default_lifespan
else:
self._has_lifespan = True
self._lifespan: LifespanCallable[LifespanResultT] = lifespan or default_lifespan
self._lifespan_result: LifespanResultT | None = None
self._lifespan_result_set = False

# Generate random ID if no name provided
self._mcp_server = LowLevelServer[LifespanResultT](
fastmcp=self,
name=name or self.generate_name(),
version=version or fastmcp.__version__,
instructions=instructions,
lifespan=_lifespan_wrapper(self, lifespan),
lifespan=_lifespan_proxy(fastmcp_server=self),
)

# if auth is `NotSet`, try to create a provider from the environment
Expand Down Expand Up @@ -334,6 +332,27 @@ def instructions(self, value: str | None) -> None:
def version(self) -> str | None:
return self._mcp_server.version

@asynccontextmanager
async def _lifespan_manager(self) -> AsyncIterator[None]:
if self._lifespan_result_set:
yield
return

async with self._lifespan(self) as lifespan_result:
self._lifespan_result = lifespan_result
self._lifespan_result_set = True

async with AsyncExitStack[bool | None]() as stack:
for server in self._mounted_servers:
await stack.enter_async_context(
cm=server.server._lifespan_manager()
)

yield

self._lifespan_result_set = False
self._lifespan_result = None

async def run_async(
self,
transport: Transport | None = None,
Expand Down Expand Up @@ -1856,15 +1875,18 @@ async def run_stdio_async(
)

with temporary_log_level(log_level):
async with stdio_server() as (read_stream, write_stream):
logger.info(f"Starting MCP server {self.name!r} with transport 'stdio'")
await self._mcp_server.run(
read_stream,
write_stream,
self._mcp_server.create_initialization_options(
NotificationOptions(tools_changed=True)
),
)
async with self._lifespan_manager():
async with stdio_server() as (read_stream, write_stream):
logger.info(
f"Starting MCP server {self.name!r} with transport 'stdio'"
)
await self._mcp_server.run(
read_stream,
write_stream,
self._mcp_server.create_initialization_options(
NotificationOptions(tools_changed=True)
),
)

async def run_http_async(
self,
Expand Down Expand Up @@ -1935,14 +1957,15 @@ async def run_http_async(
config_kwargs["log_level"] = default_log_level_to_use

with temporary_log_level(log_level):
config = uvicorn.Config(app, host=host, port=port, **config_kwargs)
server = uvicorn.Server(config)
path = app.state.path.lstrip("/") # type: ignore
logger.info(
f"Starting MCP server {self.name!r} with transport {transport!r} on http://{host}:{port}/{path}"
)
async with self._lifespan_manager():
config = uvicorn.Config(app, host=host, port=port, **config_kwargs)
server = uvicorn.Server(config)
path = app.state.path.lstrip("/") # type: ignore
logger.info(
f"Starting MCP server {self.name!r} with transport {transport!r} on http://{host}:{port}/{path}"
)

await server.serve()
await server.serve()

async def run_sse_async(
self,
Expand Down Expand Up @@ -2204,7 +2227,7 @@ def mount(
# if as_proxy is not specified and the server has a custom lifespan,
# we should treat it as a proxy
if as_proxy is None:
as_proxy = server._has_lifespan
as_proxy = server._lifespan != default_lifespan

if as_proxy and not isinstance(server, FastMCPProxy):
server = FastMCP.as_proxy(server)
Expand Down Expand Up @@ -2338,6 +2361,15 @@ async def import_server(
prompt = prompt.model_copy(key=f"{prefix}_{key}")
self._prompt_manager.add_prompt(prompt)

if server._lifespan != default_lifespan:
from warnings import warn

warn(
message="When importing from a server with a lifespan, the lifespan from the imported server will not be used. ",
category=RuntimeWarning,
stacklevel=2,
)

if prefix:
logger.debug(
f"[{self.name}] Imported server {server.name} with prefix '{prefix}'"
Expand Down
2 changes: 1 addition & 1 deletion src/fastmcp/utilities/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def run_server_in_process(
elif attempt < 15:
time.sleep(0.1)
else:
time.sleep(0.2)
time.sleep(2)
attempt += 1
else:
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
Expand Down
3 changes: 2 additions & 1 deletion tests/client/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ async def test_http_headers(sse_server: str):


def run_nested_server(host: str, port: int) -> None:
app = fastmcp_server().sse_app(path="/mcp/sse/", message_path="/mcp/messages")
fastmcp = fastmcp_server()
app = fastmcp.sse_app(path="/mcp/sse/", message_path="/mcp/messages")
mount = Starlette(routes=[Mount("/nest-inner", app=app)])
mount2 = Starlette(routes=[Mount("/nest-outer", app=mount)])
server = uvicorn.Server(
Expand Down
7 changes: 5 additions & 2 deletions tests/server/test_mount.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,15 +888,18 @@ async def test_as_proxy_true(self):
assert isinstance(mcp._mounted_servers[0].server, FastMCPProxy)

async def test_as_proxy_defaults_true_if_lifespan(self):
"""Test that as_proxy defaults to True when server_lifespan is provided."""

@asynccontextmanager
async def lifespan(mcp: FastMCP):
async def server_lifespan(mcp: FastMCP):
yield

mcp = FastMCP("Main")
sub = FastMCP("Sub", lifespan=lifespan)
sub = FastMCP("Sub", lifespan=server_lifespan)

mcp.mount(sub, "sub")

# Should auto-proxy because lifespan is set
assert mcp._mounted_servers[0].server is not sub
assert isinstance(mcp._mounted_servers[0].server, FastMCPProxy)

Expand Down
68 changes: 68 additions & 0 deletions tests/server/test_server_lifespan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Tests for server_lifespan and session_lifespan behavior."""

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any

from fastmcp import Client, FastMCP
from fastmcp.server.context import Context


class TestServerLifespan:
"""Test server_lifespan functionality."""

async def test_server_lifespan_basic(self):
"""Test that server_lifespan is entered once and persists across sessions."""
lifespan_events: list[str] = []

@asynccontextmanager
async def server_lifespan(mcp: FastMCP) -> AsyncIterator[dict[str, Any]]:
_ = lifespan_events.append("enter")
yield {"initialized": True}
_ = lifespan_events.append("exit")

mcp = FastMCP("TestServer", lifespan=server_lifespan)

@mcp.tool
def get_value() -> str:
return "test"

# Server lifespan should be entered when run_async starts
assert lifespan_events == []

# Connect first client session
async with Client(mcp) as client1:
result1 = await client1.call_tool("get_value", {})
assert result1.data == "test"
# Server lifespan should have been entered once
assert lifespan_events == ["enter"]

# Connect second client session while first is still active
async with Client(mcp) as client2:
result2 = await client2.call_tool("get_value", {})
assert result2.data == "test"
# Server lifespan should still only have been entered once
assert lifespan_events == ["enter"]

# Because we're using a fastmcptransport, the server lifespan should be exited
# when the client session closes
assert lifespan_events == ["enter", "exit"]

async def test_server_lifespan_context_available(self):
"""Test that server_lifespan context is available to tools."""

@asynccontextmanager
async def server_lifespan(mcp: FastMCP) -> AsyncIterator[dict]:
yield {"db_connection": "mock_db"}

mcp = FastMCP("TestServer", lifespan=server_lifespan)

@mcp.tool
def get_db_info(ctx: Context) -> str:
# Access the server lifespan context
lifespan_context = ctx.request_context.lifespan_context
return lifespan_context.get("db_connection", "no_db")

async with Client(mcp) as client:
result = await client.call_tool("get_db_info", {})
assert result.data == "mock_db"