Skip to content

Commit 6677894

Browse files
committed
Merge branch 'main' of https://github.com/modelcontextprotocol/python-sdk into feat/client-credentials
2 parents 4a4c007 + 959d4e3 commit 6677894

File tree

4 files changed

+192
-51
lines changed

4 files changed

+192
-51
lines changed

src/mcp/server/fastmcp/server.py

Lines changed: 71 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,13 @@
44

55
import inspect
66
import re
7-
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
8-
from contextlib import (
9-
AbstractAsyncContextManager,
10-
asynccontextmanager,
11-
)
7+
from collections.abc import AsyncIterator, Awaitable, Callable, Collection, Iterable, Sequence
8+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
129
from typing import Any, Generic, Literal
1310

1411
import anyio
1512
import pydantic_core
16-
from pydantic import BaseModel, Field
13+
from pydantic import BaseModel
1714
from pydantic.networks import AnyUrl
1815
from pydantic_settings import BaseSettings, SettingsConfigDict
1916
from starlette.applications import Starlette
@@ -25,10 +22,7 @@
2522
from starlette.types import Receive, Scope, Send
2623

2724
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
28-
from mcp.server.auth.middleware.bearer_auth import (
29-
BearerAuthBackend,
30-
RequireAuthMiddleware,
31-
)
25+
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware
3226
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier
3327
from mcp.server.auth.settings import AuthSettings
3428
from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation
@@ -48,12 +42,7 @@
4842
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
4943
from mcp.server.transport_security import TransportSecuritySettings
5044
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
51-
from mcp.types import (
52-
AnyFunction,
53-
ContentBlock,
54-
GetPromptResult,
55-
ToolAnnotations,
56-
)
45+
from mcp.types import AnyFunction, ContentBlock, GetPromptResult, ToolAnnotations
5746
from mcp.types import Prompt as MCPPrompt
5847
from mcp.types import PromptArgument as MCPPromptArgument
5948
from mcp.types import Resource as MCPResource
@@ -79,58 +68,57 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
7968
)
8069

8170
# Server settings
82-
debug: bool = False
83-
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
71+
debug: bool
72+
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
8473

8574
# HTTP settings
86-
host: str = "127.0.0.1"
87-
port: int = 8000
88-
mount_path: str = "/" # Mount path (e.g. "/github", defaults to root path)
89-
sse_path: str = "/sse"
90-
message_path: str = "/messages/"
91-
streamable_http_path: str = "/mcp"
75+
host: str
76+
port: int
77+
mount_path: str
78+
sse_path: str
79+
message_path: str
80+
streamable_http_path: str
9281

9382
# StreamableHTTP settings
94-
json_response: bool = False
95-
stateless_http: bool = False # If True, uses true stateless mode (new transport per request)
83+
json_response: bool
84+
stateless_http: bool
85+
"""Define if the server should create a new transport per request."""
9686

9787
# resource settings
98-
warn_on_duplicate_resources: bool = True
88+
warn_on_duplicate_resources: bool
9989

10090
# tool settings
101-
warn_on_duplicate_tools: bool = True
91+
warn_on_duplicate_tools: bool
10292

10393
# prompt settings
104-
warn_on_duplicate_prompts: bool = True
94+
warn_on_duplicate_prompts: bool
10595

106-
dependencies: list[str] = Field(
107-
default_factory=list,
108-
description="List of dependencies to install in the server environment",
109-
)
96+
# TODO(Marcelo): Investigate if this is used. If it is, it's probably a good idea to remove it.
97+
dependencies: list[str]
98+
"""A list of dependencies to install in the server environment."""
11099

111-
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None = Field(
112-
None, description="Lifespan context manager"
113-
)
100+
lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None
101+
"""A async context manager that will be called when the server is started."""
114102

115-
auth: AuthSettings | None = None
103+
auth: AuthSettings | None
116104

117105
# Transport security settings (DNS rebinding protection)
118-
transport_security: TransportSecuritySettings | None = None
106+
transport_security: TransportSecuritySettings | None
119107

120108

121109
def lifespan_wrapper(
122-
app: FastMCP,
123-
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
124-
) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]]:
110+
app: FastMCP[LifespanResultT],
111+
lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]],
112+
) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[LifespanResultT]]:
125113
@asynccontextmanager
126-
async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
114+
async def wrap(_: MCPServer[LifespanResultT, Request]) -> AsyncIterator[LifespanResultT]:
127115
async with lifespan(app) as context:
128116
yield context
129117

130118
return wrap
131119

132120

133-
class FastMCP:
121+
class FastMCP(Generic[LifespanResultT]):
134122
def __init__(
135123
self,
136124
name: str | None = None,
@@ -140,14 +128,50 @@ def __init__(
140128
event_store: EventStore | None = None,
141129
*,
142130
tools: list[Tool] | None = None,
143-
**settings: Any,
131+
debug: bool = False,
132+
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO",
133+
host: str = "127.0.0.1",
134+
port: int = 8000,
135+
mount_path: str = "/",
136+
sse_path: str = "/sse",
137+
message_path: str = "/messages/",
138+
streamable_http_path: str = "/mcp",
139+
json_response: bool = False,
140+
stateless_http: bool = False,
141+
warn_on_duplicate_resources: bool = True,
142+
warn_on_duplicate_tools: bool = True,
143+
warn_on_duplicate_prompts: bool = True,
144+
dependencies: Collection[str] = (),
145+
lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None,
146+
auth: AuthSettings | None = None,
147+
transport_security: TransportSecuritySettings | None = None,
144148
):
145-
self.settings = Settings(**settings)
149+
self.settings = Settings(
150+
debug=debug,
151+
log_level=log_level,
152+
host=host,
153+
port=port,
154+
mount_path=mount_path,
155+
sse_path=sse_path,
156+
message_path=message_path,
157+
streamable_http_path=streamable_http_path,
158+
json_response=json_response,
159+
stateless_http=stateless_http,
160+
warn_on_duplicate_resources=warn_on_duplicate_resources,
161+
warn_on_duplicate_tools=warn_on_duplicate_tools,
162+
warn_on_duplicate_prompts=warn_on_duplicate_prompts,
163+
dependencies=list(dependencies),
164+
lifespan=lifespan,
165+
auth=auth,
166+
transport_security=transport_security,
167+
)
146168

147169
self._mcp_server = MCPServer(
148170
name=name or "FastMCP",
149171
instructions=instructions,
150-
lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan),
172+
# TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an FastMCP and Server.
173+
# We need to create a Lifespan type that is a generic on the server type, like Starlette does.
174+
lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore
151175
)
152176
self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools)
153177
self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources)
@@ -257,7 +281,7 @@ async def list_tools(self) -> list[MCPTool]:
257281
for info in tools
258282
]
259283

260-
def get_context(self) -> Context[ServerSession, object, Request]:
284+
def get_context(self) -> Context[ServerSession, LifespanResultT, Request]:
261285
"""
262286
Returns a Context object. Note that the context will only be valid
263287
during a request; outside a request, most methods will error.

src/mcp/server/lowlevel/server.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ async def main():
9393

9494
logger = logging.getLogger(__name__)
9595

96-
LifespanResultT = TypeVar("LifespanResultT")
96+
LifespanResultT = TypeVar("LifespanResultT", default=Any)
9797
RequestT = TypeVar("RequestT", default=Any)
9898

9999
# type aliases for tool call results
@@ -118,7 +118,7 @@ def __init__(
118118

119119

120120
@asynccontextmanager
121-
async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]:
121+
async def lifespan(_: Server[LifespanResultT, RequestT]) -> AsyncIterator[dict[str, Any]]:
122122
"""Default lifespan context manager that does nothing.
123123
124124
Args:
@@ -647,6 +647,12 @@ async def _handle_request(
647647
response = await handler(req)
648648
except McpError as err:
649649
response = err.error
650+
except anyio.get_cancelled_exc_class():
651+
logger.info(
652+
"Request %s cancelled - duplicate response suppressed",
653+
message.request_id,
654+
)
655+
return
650656
except Exception as err:
651657
if raise_exceptions:
652658
raise err

tests/server/test_cancel_handling.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""Test that cancelled requests don't cause double responses."""
2+
3+
import anyio
4+
import pytest
5+
6+
import mcp.types as types
7+
from mcp.server.lowlevel.server import Server
8+
from mcp.shared.exceptions import McpError
9+
from mcp.shared.memory import create_connected_server_and_client_session
10+
from mcp.types import (
11+
CallToolRequest,
12+
CallToolRequestParams,
13+
CallToolResult,
14+
CancelledNotification,
15+
CancelledNotificationParams,
16+
ClientNotification,
17+
ClientRequest,
18+
Tool,
19+
)
20+
21+
22+
@pytest.mark.anyio
23+
async def test_server_remains_functional_after_cancel():
24+
"""Verify server can handle new requests after a cancellation."""
25+
26+
server = Server("test-server")
27+
28+
# Track tool calls
29+
call_count = 0
30+
ev_first_call = anyio.Event()
31+
first_request_id = None
32+
33+
@server.list_tools()
34+
async def handle_list_tools() -> list[Tool]:
35+
return [
36+
Tool(
37+
name="test_tool",
38+
description="Tool for testing",
39+
inputSchema={},
40+
)
41+
]
42+
43+
@server.call_tool()
44+
async def handle_call_tool(name: str, arguments: dict | None) -> list:
45+
nonlocal call_count, first_request_id
46+
if name == "test_tool":
47+
call_count += 1
48+
if call_count == 1:
49+
first_request_id = server.request_context.request_id
50+
ev_first_call.set()
51+
await anyio.sleep(5) # First call is slow
52+
return [types.TextContent(type="text", text=f"Call number: {call_count}")]
53+
raise ValueError(f"Unknown tool: {name}")
54+
55+
async with create_connected_server_and_client_session(server) as client:
56+
# First request (will be cancelled)
57+
async def first_request():
58+
try:
59+
await client.send_request(
60+
ClientRequest(
61+
CallToolRequest(
62+
method="tools/call",
63+
params=CallToolRequestParams(name="test_tool", arguments={}),
64+
)
65+
),
66+
CallToolResult,
67+
)
68+
pytest.fail("First request should have been cancelled")
69+
except McpError:
70+
pass # Expected
71+
72+
# Start first request
73+
async with anyio.create_task_group() as tg:
74+
tg.start_soon(first_request)
75+
76+
# Wait for it to start
77+
await ev_first_call.wait()
78+
79+
# Cancel it
80+
assert first_request_id is not None
81+
await client.send_notification(
82+
ClientNotification(
83+
CancelledNotification(
84+
method="notifications/cancelled",
85+
params=CancelledNotificationParams(
86+
requestId=first_request_id,
87+
reason="Testing server recovery",
88+
),
89+
)
90+
)
91+
)
92+
93+
# Second request (should work normally)
94+
result = await client.send_request(
95+
ClientRequest(
96+
CallToolRequest(
97+
method="tools/call",
98+
params=CallToolRequestParams(name="test_tool", arguments={}),
99+
)
100+
),
101+
CallToolResult,
102+
)
103+
104+
# Verify second request completed successfully
105+
assert len(result.content) == 1
106+
# Type narrowing for pyright
107+
content = result.content[0]
108+
assert content.type == "text"
109+
assert isinstance(content, types.TextContent)
110+
assert content.text == "Call number: 2"
111+
assert call_count == 2

tests/shared/test_progress_notifications.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ async def run_server():
4242
serv_sesh = server_session
4343
async for message in server_session.incoming_messages:
4444
try:
45-
await server._handle_message(message, server_session, ())
45+
await server._handle_message(message, server_session, {})
4646
except Exception as e:
4747
raise e
4848

@@ -252,7 +252,7 @@ async def run_server():
252252
) as server_session:
253253
async for message in server_session.incoming_messages:
254254
try:
255-
await server._handle_message(message, server_session, ())
255+
await server._handle_message(message, server_session, {})
256256
except Exception as e:
257257
raise e
258258

0 commit comments

Comments
 (0)