Skip to content

Feature: Non-Blocking call_tool and request state externalisation #1209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

import mcp.types as types
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder, RequestStateManager
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS

DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
Expand Down Expand Up @@ -118,13 +118,15 @@ def __init__(
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
request_state_manager: RequestStateManager[types.ClientRequest, types.ClientResult] | None = None,
) -> None:
super().__init__(
read_stream,
write_stream,
types.ServerRequest,
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
request_state_manager=request_state_manager,
)
self._client_info = client_info or DEFAULT_CLIENT_INFO
self._sampling_callback = sampling_callback or _default_sampling_callback
Expand All @@ -133,6 +135,7 @@ def __init__(
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
self._resumable = False

async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
Expand Down Expand Up @@ -170,6 +173,8 @@ async def initialize(self) -> types.InitializeResult:
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")

self._resumable = result.capabilities.resume and result.capabilities.resume.resumable

await self.send_notification(
types.ClientNotification(types.InitializedNotification(method="notifications/initialized"))
)
Expand Down Expand Up @@ -281,6 +286,78 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
types.EmptyResult,
)

async def request_call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
progress_callback: ProgressFnT | None = None,
) -> types.RequestId:
if self._resumable:
captured_token = None
captured = anyio.Event()

async def capture_token(token: str):
nonlocal captured_token
captured_token = token
captured.set()

metadata = ClientMessageMetadata(on_resumption_token_update=capture_token)

request_id = await self.start_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name=name,
arguments=arguments,
),
)
),
progress_callback=progress_callback,
metadata=metadata,
)

while captured_token is None:
await captured.wait()

await self._request_state_manager.update_resume_token(request_id, captured_token)

return request_id
else:
return await self.start_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name=name,
arguments=arguments,
),
)
),
progress_callback=progress_callback,
)

async def join_call_tool(
self,
request_id: types.RequestId,
progress_callback: ProgressFnT | None = None,
request_read_timeout_seconds: timedelta | None = None,
done_on_timeout: bool = True,
) -> types.CallToolResult | None:
return await self.join_request(
request_id,
types.CallToolResult,
request_read_timeout_seconds=request_read_timeout_seconds,
progress_callback=progress_callback,
done_on_timeout=done_on_timeout,
)

async def cancel_call_tool(
self,
request_id: types.RequestId,
) -> bool:
return await self.cancel_request(request_id)

async def call_tool(
self,
name: str,
Expand Down
30 changes: 21 additions & 9 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
JSONRPCRequest,
JSONRPCResponse,
RequestId,
ResumeCapability,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -136,18 +137,26 @@ def _maybe_extract_session_id_from_response(
def _maybe_extract_protocol_version_from_message(
self,
message: JSONRPCMessage,
) -> None:
) -> JSONRPCMessage:
"""Extract protocol version from initialization response message."""
if isinstance(message.root, JSONRPCResponse) and message.root.result:
try:
# Parse the result as InitializeResult for type safety
init_result = InitializeResult.model_validate(message.root.result)
self.protocol_version = str(init_result.protocolVersion)
logger.info(f"Negotiated protocol version: {self.protocol_version}")
if init_result.capabilities.resume is None:
# resumeablity is predicated on the server and the transport
# this assumes that if the server hasn't explicitly configured
# that streamable http transports are resumeable
init_result.capabilities.resume = ResumeCapability(resumable=True)
message.root.result = init_result.model_dump()
except Exception as exc:
logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}")
logger.warning(f"Raw result: {message.root.result}")

return message

async def _handle_sse_event(
self,
sse: ServerSentEvent,
Expand All @@ -164,7 +173,7 @@ async def _handle_sse_event(

# Extract protocol version from initialization response
if is_initialization:
self._maybe_extract_protocol_version_from_message(message)
message = self._maybe_extract_protocol_version_from_message(message)

# If this is a response and we have original_request_id, replace it
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
Expand Down Expand Up @@ -303,7 +312,7 @@ async def _handle_json_response(

# Extract protocol version from initialization response
if is_initialization:
self._maybe_extract_protocol_version_from_message(message)
message = self._maybe_extract_protocol_version_from_message(message)

session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
Expand Down Expand Up @@ -333,7 +342,10 @@ async def _handle_sse_response(
break
except Exception as e:
logger.exception("Error reading SSE stream:")
await ctx.read_stream_writer.send(e)
try:
await ctx.read_stream_writer.send(e)
except anyio.ClosedResourceError:
pass

async def _handle_unexpected_content_type(
self,
Expand Down Expand Up @@ -471,8 +483,8 @@ async def streamablehttp_client(
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

async with anyio.create_task_group() as tg:
try:
try:
async with anyio.create_task_group() as tg:
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")

async with httpx_client_factory(
Expand Down Expand Up @@ -504,6 +516,6 @@ def start_get_stream() -> None:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
1 change: 0 additions & 1 deletion src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,6 @@ async def send_event(event_message: EventMessage) -> None:
async with msg_reader:
async for event_message in msg_reader:
event_data = self._create_event_data(event_message)

await sse_stream_writer.send(event_data)
except Exception:
logger.exception("Error in replay sender")
Expand Down
Loading
Loading