Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/fastmcp/client/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,8 @@ async def connect_session(

# Create a cancel scope for the server task
async with anyio.create_task_group() as tg:
if isinstance(self.server, FastMCP):
await self.server.__aenter__()
tg.start_soon(
lambda: self.server._mcp_server.run(
server_read,
Expand All @@ -827,9 +829,16 @@ async def connect_session(
**session_kwargs,
) as client_session:
yield client_session
except Exception as e:
if isinstance(self.server, FastMCP):
_ = await self.server.__aexit__(type(e), e, None)
raise e
finally:
tg.cancel_scope.cancel()

if isinstance(self.server, FastMCP):
_ = await self.server.__aexit__(None, None, None)

def __repr__(self) -> str:
return f"<FastMCPTransport(server='{self.server.name}')>"

Expand Down
157 changes: 126 additions & 31 deletions src/fastmcp/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload

import anyio
Expand Down Expand Up @@ -89,6 +90,10 @@
# Compiled URI parsing regex to split a URI into protocol and path components
URI_PATTERN = re.compile(r"^([^:]+://)(.*?)$")

LifespanCallable = Callable[
["FastMCP[LifespanResultT]"], AbstractAsyncContextManager[LifespanResultT]
]


@asynccontextmanager
async def default_lifespan(server: FastMCP[LifespanResultT]) -> AsyncIterator[Any]:
Expand Down Expand Up @@ -131,13 +136,9 @@ def __init__(
version: str | None = None,
auth: AuthProvider | None | NotSetT = NotSet,
middleware: list[Middleware] | None = None,
lifespan: (
Callable[
[FastMCP[LifespanResultT]],
AbstractAsyncContextManager[LifespanResultT],
]
| None
) = None,
lifespan: LifespanCallable | None = None,
session_lifespan: LifespanCallable | None = None,
server_lifespan: LifespanCallable | None = None,
dependencies: list[str] | None = None,
resource_prefix_format: Literal["protocol", "path"] | None = None,
mask_error_details: bool | None = None,
Expand Down Expand Up @@ -188,18 +189,25 @@ def __init__(
)
self._tool_serializer = tool_serializer

if lifespan is None:
self._has_lifespan = False
lifespan = default_lifespan
else:
self._has_lifespan = True
session_lifespan: LifespanCallable = self._handle_lifespan_settings(
lifespan=lifespan,
session_lifespan=session_lifespan,
server_lifespan=server_lifespan,
)

self._session_lifespan: LifespanCallable | None = session_lifespan

self._server_lifespan: LifespanCallable | None = server_lifespan
self._server_lifespan_result: LifespanResultT | None = None
self._server_lifespan_stack: AsyncExitStack | None = None

# Generate random ID if no name provided
self._mcp_server = LowLevelServer[LifespanResultT](
fastmcp=self,
name=name or self.generate_name(),
version=version,
instructions=instructions,
lifespan=_lifespan_wrapper(self, lifespan),
lifespan=_lifespan_wrapper(app=self, lifespan=session_lifespan),
)

# if auth is `NotSet`, try to create a provider from the environment
Expand Down Expand Up @@ -268,6 +276,42 @@ def __init__(
def __repr__(self) -> str:
return f"{type(self).__name__}({self.name!r})"

def _handle_lifespan_settings(
self,
lifespan: LifespanCallable | None = None,
session_lifespan: LifespanCallable | None = None,
server_lifespan: LifespanCallable | None = None,
) -> LifespanCallable:
"""Handle lifespan settings including deprecation and proxying server_lifespan into a session_lifespan.

Args:
lifespan: The lifespan callable to use for the server.
session_lifespan: The lifespan callable to use for the session.
server_lifespan: The lifespan callable to use for the server.
"""
if lifespan is not None:
if fastmcp.settings.deprecation_warnings:
import warnings

warnings.warn(
"The lifespan parameter is deprecated (as of 2.13.0). For the same behavior, "
+ "use the session_lifespan parameter instead.",
DeprecationWarning,
stacklevel=2,
)

session_lifespan = lifespan or session_lifespan

if session_lifespan and server_lifespan:
raise ValueError(
"Cannot specify both session_lifespan (or lifespan) and server_lifespan."
)

if server_lifespan is not None:
return _server_lifespan_proxy_factory(server=self)

return session_lifespan or default_lifespan

def _handle_deprecated_settings(
self,
log_level: str | None,
Expand Down Expand Up @@ -334,6 +378,41 @@ def instructions(self, value: str | None) -> None:
def version(self) -> str | None:
return self._mcp_server.version

async def __aenter__(self) -> None:
"""Enter the server-wide lifespan context."""
if self._server_lifespan_stack is not None:
# Already entered
return

self._server_lifespan_stack = AsyncExitStack()
if self._server_lifespan is not None:
self._server_lifespan_result = (
await self._server_lifespan_stack.enter_async_context(
self._server_lifespan(self)
)
)

for server in self._mounted_servers:
await self._server_lifespan_stack.enter_async_context(server.server)

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exit the server-wide lifespan context."""
if self._server_lifespan is None:
return

if self._server_lifespan_stack is not None:
await self._server_lifespan_stack.aclose()
self._server_lifespan_stack = None
self._server_lifespan_result = None

for server in self._mounted_servers:
await server.server.__aexit__(exc_type, exc_val, exc_tb)

async def run_async(
self,
transport: Transport | None = None,
Expand All @@ -345,24 +424,25 @@ async def run_async(
Args:
transport: Transport protocol to use ("stdio", "sse", or "streamable-http")
"""
if transport is None:
transport = "stdio"
if transport not in {"stdio", "http", "sse", "streamable-http"}:
raise ValueError(f"Unknown transport: {transport}")
async with self:
if transport is None:
transport = "stdio"
if transport not in {"stdio", "http", "sse", "streamable-http"}:
raise ValueError(f"Unknown transport: {transport}")

if transport == "stdio":
await self.run_stdio_async(
show_banner=show_banner,
**transport_kwargs,
)
elif transport in {"http", "sse", "streamable-http"}:
await self.run_http_async(
transport=transport,
show_banner=show_banner,
**transport_kwargs,
)
else:
raise ValueError(f"Unknown transport: {transport}")
if transport == "stdio":
await self.run_stdio_async(
show_banner=show_banner,
**transport_kwargs,
)
elif transport in {"http", "sse", "streamable-http"}:
await self.run_http_async(
transport=transport,
show_banner=show_banner,
**transport_kwargs,
)
else:
raise ValueError(f"Unknown transport: {transport}")

def run(
self,
Expand Down Expand Up @@ -1889,7 +1969,7 @@ def mount(
# if as_proxy is not specified and the server has a custom lifespan,
# we should treat it as a proxy
if as_proxy is None:
as_proxy = server._has_lifespan
as_proxy = server._session_lifespan != default_lifespan

if as_proxy and not isinstance(server, FastMCPProxy):
server = FastMCP.as_proxy(server)
Expand Down Expand Up @@ -2460,3 +2540,18 @@ def has_resource_prefix(
return bool(re.match(prefix_pattern, path))
else:
raise ValueError(f"Invalid prefix format: {prefix_format}")


def _server_lifespan_proxy_factory(
server: FastMCP[LifespanResultT],
) -> LifespanCallable:
@asynccontextmanager
async def lifespan_proxy(
app: FastMCP[LifespanResultT],
) -> AsyncIterator[LifespanResultT | None]:
# Return the already-initialized server context
if server._server_lifespan_stack is None:
raise RuntimeError("Server lifespan result is not initialized")
yield server._server_lifespan_result

return lifespan_proxy
11 changes: 7 additions & 4 deletions tests/server/test_mount.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,16 +890,19 @@ async def test_as_proxy_true(self):
assert mcp._tool_manager._mounted_servers[0].server is not sub
assert isinstance(mcp._tool_manager._mounted_servers[0].server, FastMCPProxy)

async def test_as_proxy_defaults_true_if_lifespan(self):
async def test_as_proxy_defaults_true_if_server_lifespan(self):
"""Test that as_proxy defaults to True when server_lifespan is provided."""

@asynccontextmanager
async def lifespan(mcp: FastMCP):
async def server_lifespan(mcp: FastMCP):
yield

mcp = FastMCP("Main")
sub = FastMCP("Sub", lifespan=lifespan)
sub = FastMCP("Sub", server_lifespan=server_lifespan)

mcp.mount(sub, "sub")

# Should auto-proxy because server_lifespan is set
assert mcp._tool_manager._mounted_servers[0].server is not sub
assert isinstance(mcp._tool_manager._mounted_servers[0].server, FastMCPProxy)

Expand Down Expand Up @@ -953,7 +956,7 @@ async def lifespan(mcp: FastMCP):
yield

mcp = FastMCP("Main")
sub = FastMCP("Sub", lifespan=lifespan)
sub = FastMCP("Sub", server_lifespan=lifespan)

@sub.tool
def hello():
Expand Down
Loading