From 416520004fb5465a9e6a447b3bdd8e82265f5f1b Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 12 Jul 2025 05:54:35 +0000 Subject: [PATCH 01/12] add methods to enable call tool requests to be started and joined at a later state and cancelled --- src/mcp/client/session.py | 44 ++- src/mcp/shared/session.py | 307 ++++++++++++---- tests/client/test_resource_cleanup.py | 13 +- tests/client/test_session.py | 495 +++++++++++++++++++++++++- tests/server/test_session.py | 9 +- 5 files changed, 788 insertions(+), 80 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 1853ce7c1..85ea87c11 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -10,7 +10,7 @@ import mcp.types as types from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder +from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder, RequestStateManager from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") @@ -118,6 +118,7 @@ def __init__( logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, + request_state_manager: RequestStateManager[types.ClientRequest, types.ClientResult] | None = None, ) -> None: super().__init__( read_stream, @@ -125,6 +126,7 @@ def __init__( types.ServerRequest, types.ServerNotification, read_timeout_seconds=read_timeout_seconds, + request_state_manager=request_state_manager, ) self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback @@ -281,6 +283,46 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: types.EmptyResult, ) + async def request_call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + progress_callback: ProgressFnT | None = None, + ) -> types.RequestId: + return await self.start_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name=name, + arguments=arguments, + ), + ) + ), + progress_callback=progress_callback, + ) + + async def join_call_tool( + self, + request_id: types.RequestId, + progress_callback: ProgressFnT | None = None, + request_read_timeout_seconds: timedelta | None = None, + fail_on_timeout: bool = True, + ) -> types.CallToolResult | None: + return await self.join_request( + request_id, + types.CallToolResult, + request_read_timeout_seconds=request_read_timeout_seconds, + progress_callback=progress_callback, + fail_on_timeout=fail_on_timeout, + ) + + async def cancel_call_tool( + self, + request_id: types.RequestId, + ) -> bool: + return await self.cancel_request(request_id) + async def call_tool( self, name: str, diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b2f49fc8b..5f3d6be5b 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -17,6 +17,7 @@ CONNECTION_CLOSED, INVALID_PARAMS, CancelledNotification, + CancelledNotificationParams, ClientNotification, ClientRequest, ClientResult, @@ -156,6 +157,152 @@ def cancelled(self) -> bool: return self._cancel_scope.cancel_called +class RequestStateManager( + Generic[ + SendRequestT, + SendResultT, + ], +): + def new_request(self, request: SendRequestT) -> RequestId: ... + + def add_progress_callback(self, request_id: RequestId, progress_callback: ProgressFnT): ... + + async def send_progress( + self, + request_id: RequestId, + progress: float, + total: float | None, + message: str | None, + ): ... + + async def receive_response( + self, request_id: RequestId, timeout: float | None = None, fail_on_timeout: bool = True + ) -> JSONRPCResponse | JSONRPCError | None: ... + + async def handle_response(self, message: JSONRPCResponse | JSONRPCError) -> bool: ... + + async def close_request(self, request_id: RequestId) -> bool: ... + + async def close(self) -> None: ... + + +class ImMemoryRequestStateManager( + RequestStateManager[ + SendRequestT, + SendResultT, + ], +): + _request_id: int + _response_streams: dict[ + RequestId, + tuple[ + SendRequestT, + MemoryObjectSendStream[JSONRPCResponse | JSONRPCError], + MemoryObjectReceiveStream[JSONRPCResponse | JSONRPCError], + ], + ] + _progress_callbacks: dict[RequestId, list[ProgressFnT]] + + def __init__(self): + self._request_id = 0 + self._response_streams = {} + self._progress_callbacks = {} + + def new_request(self, request: SendRequestT) -> RequestId: + request_id = self._request_id + self._request_id = request_id + 1 + + response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) + self._response_streams[request_id] = request, response_stream, response_stream_reader + + return request_id + + def add_progress_callback(self, request_id: RequestId, progress_callback: ProgressFnT): + progress_list = self._progress_callbacks.get(request_id) + if progress_list is None: + progress_list = [] + self._progress_callbacks[request_id] = progress_list + + progress_list.append(progress_callback) + + async def send_progress( + self, + request_id: RequestId, + progress: float, + total: float | None, + message: str | None, + ): + if request_id in self._progress_callbacks: + callbacks = self._progress_callbacks[request_id] + for callback in callbacks: + await callback( + progress, + total, + message, + ) + + async def receive_response( + self, request_id: RequestId, timeout: float | None = None, fail_on_timeout: bool = True + ) -> JSONRPCResponse | JSONRPCError | None: + request, _, response_stream_reader = self._response_streams.get(request_id, [None, None, None]) + + if response_stream_reader is None: + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message=(f"Unknown request {request_id}"), + ) + ) + + if fail_on_timeout: + try: + with anyio.fail_after(timeout): + return await response_stream_reader.receive() + except TimeoutError: + raise McpError( + ErrorData( + code=httpx.codes.REQUEST_TIMEOUT, + message=( + f"Timed out while waiting for response to " + f"{request.__class__.__name__}. Waited " + f"{timeout} seconds." + ), + ) + ) + else: + with anyio.move_on_after(timeout): + return await response_stream_reader.receive() + + async def handle_response(self, message: JSONRPCResponse | JSONRPCError) -> bool: + _, stream, _ = self._response_streams.get(message.id, [None, None, None]) + if stream: + await stream.send(message) + return True + else: + return False + + async def close_request(self, request_id: RequestId) -> bool: + _, response_stream, response_stream_reader = self._response_streams.pop(request_id, [None, None, None]) + if response_stream is not None: + await response_stream.aclose() + if response_stream_reader is not None: + await response_stream_reader.aclose() + + self._progress_callbacks.pop(request_id, None) + + return response_stream is not None + + async def close(self): + for id, [_, stream, _] in self._response_streams.items(): + error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") + try: + await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) + await stream.aclose() + except Exception: + # Stream might already be closed + pass + + class BaseSession( Generic[ SendRequestT, @@ -173,10 +320,7 @@ class BaseSession( messages when entered. """ - _response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]] - _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] - _progress_callbacks: dict[RequestId, ProgressFnT] def __init__( self, @@ -186,17 +330,16 @@ def __init__( receive_notification_type: type[ReceiveNotificationT], # If none, reading will never time out read_timeout_seconds: timedelta | None = None, + request_state_manager: RequestStateManager[SendRequestT, SendResultT] | None = None, ) -> None: self._read_stream = read_stream self._write_stream = write_stream - self._response_streams = {} - self._request_id = 0 self._receive_request_type = receive_request_type self._receive_notification_type = receive_notification_type self._session_read_timeout_seconds = read_timeout_seconds - self._in_flight = {} - self._progress_callbacks = {} self._exit_stack = AsyncExitStack() + self._in_flight = {} + self._request_state_manager = request_state_manager or ImMemoryRequestStateManager() async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() @@ -217,27 +360,19 @@ async def __aexit__( self._task_group.cancel_scope.cancel() return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) - async def send_request( + async def start_request( self, request: SendRequestT, - result_type: type[ReceiveResultT], - request_read_timeout_seconds: timedelta | None = None, metadata: MessageMetadata = None, progress_callback: ProgressFnT | None = None, - ) -> ReceiveResultT: + ) -> RequestId: """ - Sends a request and wait for a response. Raises an McpError if the - response contains an error. If a request read timeout is provided, it - will take precedence over the session read timeout. + Starts a request. Do not use this method to emit notifications! Use send_notification() instead. """ - request_id = self._request_id - self._request_id = request_id + 1 - - response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) - self._response_streams[request_id] = response_stream + request_id = self._request_state_manager.new_request(request) # Set up progress token if progress callback is provided request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) @@ -249,49 +384,90 @@ async def send_request( request_data["params"]["_meta"] = {} request_data["params"]["_meta"]["progressToken"] = request_id # Store the callback for this request - self._progress_callbacks[request_id] = progress_callback + self._request_state_manager.add_progress_callback(request_id, progress_callback) - try: - jsonrpc_request = JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - **request_data, - ) + jsonrpc_request = JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + **request_data, + ) + try: await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) + return request_id + except Exception as e: + await self._request_state_manager.close_request(request_id) + raise e - # request read timeout takes precedence over session read timeout - timeout = None - if request_read_timeout_seconds is not None: - timeout = request_read_timeout_seconds.total_seconds() - elif self._session_read_timeout_seconds is not None: - timeout = self._session_read_timeout_seconds.total_seconds() + async def join_request( + self, + request_id: RequestId, + result_type: type[ReceiveResultT], + request_read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, + fail_on_timeout: bool = True, + ) -> ReceiveResultT | None: + """ + Joins a request previously started via start_request + """ + if progress_callback is not None: + self._request_state_manager.add_progress_callback(request_id, progress_callback) - try: - with anyio.fail_after(timeout): - response_or_error = await response_stream_reader.receive() - except TimeoutError: - raise McpError( - ErrorData( - code=httpx.codes.REQUEST_TIMEOUT, - message=( - f"Timed out while waiting for response to " - f"{request.__class__.__name__}. Waited " - f"{timeout} seconds." - ), - ) - ) + # request read timeout takes precedence over session read timeout + timeout = None + if request_read_timeout_seconds is not None: + timeout = request_read_timeout_seconds.total_seconds() + elif self._session_read_timeout_seconds is not None: + timeout = self._session_read_timeout_seconds.total_seconds() + + response_or_error = await self._request_state_manager.receive_response(request_id, timeout, fail_on_timeout) + if response_or_error is None: + return None + else: + await self._request_state_manager.close_request(request_id) if isinstance(response_or_error, JSONRPCError): raise McpError(response_or_error.error) else: return result_type.model_validate(response_or_error.result) - finally: - self._response_streams.pop(request_id, None) - self._progress_callbacks.pop(request_id, None) - await response_stream.aclose() - await response_stream_reader.aclose() + async def cancel_request(self, request_id: RequestId) -> bool: + """ + Cancels a request previously started via start_request + """ + closed = await self._request_state_manager.close_request(request_id) + + if closed: + notification = CancelledNotification( + method="notifications/cancelled", + params=CancelledNotificationParams(requestId=request_id, reason="cancelled"), + ) + await self.send_notification(notification, request_id) # type: ignore + return True + else: + return False + + async def send_request( + self, + request: SendRequestT, + result_type: type[ReceiveResultT], + request_read_timeout_seconds: timedelta | None = None, + metadata: MessageMetadata = None, + progress_callback: ProgressFnT | None = None, + ) -> ReceiveResultT: + """ + Sends a request and wait for a response. Raises an McpError if the + response contains an error. If a request read timeout is provided, it + will take precedence over the session read timeout. + + Do not use this method to emit notifications! Use send_notification() + instead. + """ + request_id = await self.start_request(request, metadata, progress_callback) + result = await self.join_request(request_id, result_type, request_read_timeout_seconds) + if result is None: + raise RuntimeError("Should not be possible") + return result async def send_notification( self, @@ -390,13 +566,12 @@ async def _receive_loop(self) -> None: progress_token = notification.root.params.progressToken # If there is a progress callback for this token, # call it with the progress information - if progress_token in self._progress_callbacks: - callback = self._progress_callbacks[progress_token] - await callback( - notification.root.params.progress, - notification.root.params.total, - notification.root.params.message, - ) + await self._request_state_manager.send_progress( + progress_token, + notification.root.params.progress, + notification.root.params.total, + notification.root.params.message, + ) await self._received_notification(notification) await self._handle_incoming(notification) except Exception as e: @@ -405,10 +580,8 @@ async def _receive_loop(self) -> None: f"Failed to validate notification: {e}. Message was: {message.message.root}" ) else: # Response or error - stream = self._response_streams.pop(message.message.root.id, None) - if stream: - await stream.send(message.message.root) - else: + handled = await self._request_state_manager.handle_response(message.message.root) + if not handled: await self._handle_incoming( RuntimeError(f"Received response with an unknown request ID: {message}") ) @@ -425,15 +598,7 @@ async def _receive_loop(self) -> None: finally: # after the read stream is closed, we need to send errors # to any pending requests - for id, stream in self._response_streams.items(): - error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") - try: - await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) - await stream.aclose() - except Exception: - # Stream might already be closed - pass - self._response_streams.clear() + await self._request_state_manager.close() async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: """ diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py index 527884219..bb7e22192 100644 --- a/tests/client/test_resource_cleanup.py +++ b/tests/client/test_resource_cleanup.py @@ -3,7 +3,7 @@ import anyio import pytest -from mcp.shared.session import BaseSession +from mcp.shared.session import BaseSession, ImMemoryRequestStateManager from mcp.types import ( ClientRequest, EmptyResult, @@ -28,12 +28,14 @@ async def _send_response(self, request_id, response): write_stream_send, write_stream_receive = anyio.create_memory_object_stream(1) read_stream_send, read_stream_receive = anyio.create_memory_object_stream(1) + request_io_manager = ImMemoryRequestStateManager() # Create the session session = TestSession( read_stream_receive, write_stream_send, object, # Request type doesn't matter for this test - object, # Notification type doesn't matter for this test + object, # Notification type doesn't matter for this test, + request_state_manager=request_io_manager, ) # Create a test request @@ -48,7 +50,7 @@ async def mock_send(*args, **kwargs): raise RuntimeError("Simulated network error") # Record the response streams before the test - initial_stream_count = len(session._response_streams) + initial_stream_count = len(request_io_manager._response_streams) # Run the test with the patched method with patch.object(session._write_stream, "send", mock_send): @@ -56,8 +58,9 @@ async def mock_send(*args, **kwargs): await session.send_request(request, EmptyResult) # Verify that no response streams were leaked - assert len(session._response_streams) == initial_stream_count, ( - f"Expected {initial_stream_count} response streams after request, but found {len(session._response_streams)}" + assert len(request_io_manager._response_streams) == initial_stream_count, ( + f"Expected {initial_stream_count} response streams after request, " + "but found {len(request_io_manager._response_streams)}" ) # Clean up diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 327d1a9e4..b8588a951 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1,3 +1,4 @@ +from datetime import timedelta from typing import Any import anyio @@ -7,10 +8,13 @@ from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder +from mcp.shared.session import ImMemoryRequestStateManager, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( LATEST_PROTOCOL_VERSION, + CallToolRequest, + CallToolResult, + CancelledNotification, ClientNotification, ClientRequest, Implementation, @@ -23,6 +27,7 @@ JSONRPCResponse, ServerCapabilities, ServerResult, + TextContent, ) @@ -495,3 +500,491 @@ async def mock_server(): assert received_capabilities.roots is not None # Custom list_roots callback provided assert isinstance(received_capabilities.roots, types.RootsCapability) assert received_capabilities.roots.listChanged is True # Should be True for custom callback + + +@pytest.mark.anyio +async def test_client_session_request_call_tool(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + async def mock_server(): + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, CallToolRequest) + + request = request.root + assert "hello" == request.params.name + assert request.params.arguments is not None + assert "name" in request.params.arguments + name = request.params.arguments["name"] + + result = ServerResult(CallToolResult(content=[TextContent(type="text", text=f"hello {name}")])) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + async def progress_callback(progress: float, total: float | None, message: str | None) -> None: + pass + + request_id = await session.request_call_tool("hello", {"name": "world"}, progress_callback) + + with anyio.fail_after(1): + result = await session.join_call_tool(request_id) + + # Assert the result + assert isinstance(result, CallToolResult) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "hello world" + + +@pytest.mark.anyio +async def test_client_session_request_call_tool_join_timeout(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + send_result = anyio.Event() + + async def mock_server(): + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, CallToolRequest) + + request = request.root + assert "hello" == request.params.name + assert request.params.arguments is not None + assert "name" in request.params.arguments + name = request.params.arguments["name"] + + await send_result.wait() + + result = ServerResult(CallToolResult(content=[TextContent(type="text", text=f"hello {name}")])) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + request_id = await session.request_call_tool("hello", {"name": "world"}) + + with anyio.fail_after(1): + result = await session.join_call_tool( + request_id, request_read_timeout_seconds=timedelta(microseconds=1), fail_on_timeout=False + ) + assert result is None + send_result.set() + result = await session.join_call_tool( + request_id, request_read_timeout_seconds=timedelta(seconds=1), fail_on_timeout=False + ) + + # Assert the result + assert isinstance(result, CallToolResult) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "hello world" + + # Assert resources tidied up + assert len(session._request_state_manager._response_streams) == 0 # type: ignore + + +@pytest.mark.anyio +async def test_client_session_request_call_tool_with_progress(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + client_2_joined = anyio.Event() + + async def mock_server(): + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, CallToolRequest) + + request = request.root + + assert "hello" == request.params.name + assert request.params.arguments is not None + assert "name" in request.params.arguments + name = request.params.arguments["name"] + assert request.params.meta is not None + assert request.params.meta.progressToken is not None + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken=request.params.meta.progressToken, + progress=1, + total=2, + message="event 1", + ).model_dump(), + ) + ) + ) + ) + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken=request.params.meta.progressToken, + progress=2, + total=2, + message="event 2", + ).model_dump(), + ) + ) + ) + ) + + result = ServerResult(CallToolResult(content=[TextContent(type="text", text=f"hello {name}")])) + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + progress_1 = anyio.Event() + progress_2 = anyio.Event() + + async def progress_callback1(progress: float, total: float | None, message: str | None) -> None: + if progress == 1: + progress_1.set() + elif progress == 2: + progress_2.set() + else: + raise RuntimeError("Unexpected progress value") + + request_id = await session.request_call_tool("hello", {"name": "world"}, progress_callback1) + + with anyio.fail_after(3): + await progress_1.wait() + result = await session.join_call_tool(request_id) + client_2_joined.set() + await progress_2.wait() + + # Assert the result + assert isinstance(result, CallToolResult) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "hello world" + + +@pytest.mark.anyio +async def test_client_session_request_call_tool_with_rejoin(): + client_1_to_server_send, client_1_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_1_send, server_to_client_1_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_2_to_server_send, client_2_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_2_send, server_to_client_2_receive = anyio.create_memory_object_stream[SessionMessage](1) + + async def mock_server(): + session_message = await client_1_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, CallToolRequest) + + request = request.root + + assert "hello" == request.params.name + assert request.params.arguments is not None + assert "name" in request.params.arguments + name = request.params.arguments["name"] + assert request.params.meta is not None + assert request.params.meta.progressToken is not None + + async with server_to_client_1_send, server_to_client_2_send: + await server_to_client_1_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken=request.params.meta.progressToken, + progress=1, + total=2, + message="event 1", + ).model_dump(), + ) + ) + ) + ) + + await server_to_client_2_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken=request.params.meta.progressToken, + progress=2, + total=2, + message="event 2", + ).model_dump(), + ) + ) + ) + ) + + result = ServerResult(CallToolResult(content=[TextContent(type="text", text=f"hello {name}")])) + + await server_to_client_2_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + request_state_manager_1 = ImMemoryRequestStateManager() + request_state_manager_2 = ImMemoryRequestStateManager() + + async with ( + ClientSession( + server_to_client_1_receive, + client_1_to_server_send, + message_handler=message_handler, + request_state_manager=request_state_manager_1, + ) as session1, + ClientSession( + server_to_client_2_receive, + client_2_to_server_send, + message_handler=message_handler, + request_state_manager=request_state_manager_2, + ) as session2, + anyio.create_task_group() as tg, + client_1_to_server_send, + client_1_to_server_receive, + server_to_client_1_send, + server_to_client_1_receive, + client_2_to_server_send, + client_2_to_server_receive, + server_to_client_2_send, + server_to_client_2_receive, + ): + tg.start_soon(mock_server) + + progress_1_1 = anyio.Event() + progress_1_2 = anyio.Event() + progress_2_1 = anyio.Event() + progress_2_2 = anyio.Event() + + async def progress_callback1(progress: float, total: float | None, message: str | None) -> None: + if progress == 1: + progress_1_1.set() + elif progress == 2: + progress_1_2.set() + else: + raise RuntimeError("Unexpected progress value") + + async def progress_callback2(progress: float, total: float | None, message: str | None) -> None: + if progress == 1: + progress_2_1.set() + elif progress == 2: + progress_2_2.set() + else: + raise RuntimeError("Unexpected progress value") + + request_id = await session1.request_call_tool("hello", {"name": "world"}, progress_callback1) + with anyio.fail_after(1): + await progress_1_1.wait() + + # initialise io manager 2 to state of io manager 1 + for request, _, _ in request_state_manager_1._response_streams.values(): + request_state_manager_2.new_request(request) + + # simulate network disconnect and rejoin + await request_state_manager_1.close_request(request_id) + result = await session2.join_call_tool(request_id, progress_callback2) + + await progress_2_2.wait() + + assert not progress_1_2.is_set() + assert not progress_2_1.is_set() + # Assert the result + assert isinstance(result, CallToolResult) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "hello world" + + +@pytest.mark.anyio +async def test_client_session_cancel_call_tool(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + cancelled = anyio.Event() + + async def mock_server(): + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, CallToolRequest) + + request = request.root + assert "hello" == request.params.name + assert request.params.arguments is not None + assert "name" in request.params.arguments + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCNotification) + notification = ClientNotification.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(notification.root, CancelledNotification) + cancelled.set() + + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + async def progress_callback(progress: float, total: float | None, message: str | None) -> None: + pass + + request_id = await session.request_call_tool("hello", {"name": "world"}, progress_callback) + assert await session.cancel_call_tool(request_id) + with anyio.fail_after(1): + await cancelled.wait() + assert not await session.cancel_call_tool(request_id) diff --git a/tests/server/test_session.py b/tests/server/test_session.py index d00eda875..621db1bf2 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -13,12 +13,13 @@ ClientNotification, Completion, CompletionArgument, + CompletionContext, CompletionsCapability, InitializedNotification, PromptReference, PromptsCapability, - ResourceReference, ResourcesCapability, + ResourceTemplateReference, ServerCapabilities, ) @@ -109,7 +110,11 @@ async def list_resources(): # Add a complete handler @server.completion() - async def complete(ref: PromptReference | ResourceReference, argument: CompletionArgument): + async def complete( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, + ): return Completion( values=["completion1", "completion2"], ) From 04ff73a13076307a3370cdeaea88987377f27cbc Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 12 Jul 2025 08:25:41 +0000 Subject: [PATCH 02/12] refactor args for clearer meaning, use error vs returning none on timeout --- src/mcp/client/session.py | 4 +-- src/mcp/shared/session.py | 70 ++++++++++++++++++------------------ tests/client/test_session.py | 16 ++++++--- 3 files changed, 49 insertions(+), 41 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 85ea87c11..1666d1afa 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -307,14 +307,14 @@ async def join_call_tool( request_id: types.RequestId, progress_callback: ProgressFnT | None = None, request_read_timeout_seconds: timedelta | None = None, - fail_on_timeout: bool = True, + done_on_timeout: bool = True, ) -> types.CallToolResult | None: return await self.join_request( request_id, types.CallToolResult, request_read_timeout_seconds=request_read_timeout_seconds, progress_callback=progress_callback, - fail_on_timeout=fail_on_timeout, + done_on_timeout=done_on_timeout, ) async def cancel_call_tool( diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 5f3d6be5b..11f1a5c84 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -176,8 +176,10 @@ async def send_progress( ): ... async def receive_response( - self, request_id: RequestId, timeout: float | None = None, fail_on_timeout: bool = True - ) -> JSONRPCResponse | JSONRPCError | None: ... + self, + request_id: RequestId, + timeout: float | None = None, + ) -> JSONRPCResponse | JSONRPCError: ... async def handle_response(self, message: JSONRPCResponse | JSONRPCError) -> bool: ... @@ -242,8 +244,10 @@ async def send_progress( ) async def receive_response( - self, request_id: RequestId, timeout: float | None = None, fail_on_timeout: bool = True - ) -> JSONRPCResponse | JSONRPCError | None: + self, + request_id: RequestId, + timeout: float | None = None, + ) -> JSONRPCResponse | JSONRPCError: request, _, response_stream_reader = self._response_streams.get(request_id, [None, None, None]) if response_stream_reader is None: @@ -254,24 +258,20 @@ async def receive_response( ) ) - if fail_on_timeout: - try: - with anyio.fail_after(timeout): - return await response_stream_reader.receive() - except TimeoutError: - raise McpError( - ErrorData( - code=httpx.codes.REQUEST_TIMEOUT, - message=( - f"Timed out while waiting for response to " - f"{request.__class__.__name__}. Waited " - f"{timeout} seconds." - ), - ) - ) - else: - with anyio.move_on_after(timeout): + try: + with anyio.fail_after(timeout): return await response_stream_reader.receive() + except TimeoutError: + raise McpError( + ErrorData( + code=httpx.codes.REQUEST_TIMEOUT, + message=( + f"Timed out while waiting for response to " + f"{request.__class__.__name__}. Waited " + f"{timeout} seconds." + ), + ) + ) async def handle_response(self, message: JSONRPCResponse | JSONRPCError) -> bool: _, stream, _ = self._response_streams.get(message.id, [None, None, None]) @@ -405,8 +405,8 @@ async def join_request( result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, progress_callback: ProgressFnT | None = None, - fail_on_timeout: bool = True, - ) -> ReceiveResultT | None: + done_on_timeout: bool = True, + ) -> ReceiveResultT: """ Joins a request previously started via start_request """ @@ -420,16 +420,18 @@ async def join_request( elif self._session_read_timeout_seconds is not None: timeout = self._session_read_timeout_seconds.total_seconds() - response_or_error = await self._request_state_manager.receive_response(request_id, timeout, fail_on_timeout) + response_or_error = await self._request_state_manager.receive_response(request_id, timeout) - if response_or_error is None: - return None + if isinstance(response_or_error, JSONRPCError): + if response_or_error.error.code == httpx.codes.REQUEST_TIMEOUT.value: + if done_on_timeout: + await self._request_state_manager.close_request(request_id) + else: + await self._request_state_manager.close_request(request_id) + raise McpError(response_or_error.error) else: await self._request_state_manager.close_request(request_id) - if isinstance(response_or_error, JSONRPCError): - raise McpError(response_or_error.error) - else: - return result_type.model_validate(response_or_error.result) + return result_type.model_validate(response_or_error.result) async def cancel_request(self, request_id: RequestId) -> bool: """ @@ -464,10 +466,10 @@ async def send_request( instead. """ request_id = await self.start_request(request, metadata, progress_callback) - result = await self.join_request(request_id, result_type, request_read_timeout_seconds) - if result is None: - raise RuntimeError("Should not be possible") - return result + try: + return await self.join_request(request_id, result_type, request_read_timeout_seconds) + finally: + await self._request_state_manager.close_request(request_id) async def send_notification( self, diff --git a/tests/client/test_session.py b/tests/client/test_session.py index b8588a951..e8776f092 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -2,11 +2,13 @@ from typing import Any import anyio +import httpx import pytest import mcp.types as types from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession from mcp.shared.context import RequestContext +from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage from mcp.shared.session import ImMemoryRequestStateManager, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -636,13 +638,17 @@ async def message_handler( request_id = await session.request_call_tool("hello", {"name": "world"}) with anyio.fail_after(1): - result = await session.join_call_tool( - request_id, request_read_timeout_seconds=timedelta(microseconds=1), fail_on_timeout=False - ) - assert result is None + try: + result = await session.join_call_tool( + request_id, request_read_timeout_seconds=timedelta(microseconds=1), done_on_timeout=False + ) + except McpError as e: + if not e.error.code == httpx.codes.REQUEST_TIMEOUT: + raise e + send_result.set() result = await session.join_call_tool( - request_id, request_read_timeout_seconds=timedelta(seconds=1), fail_on_timeout=False + request_id, request_read_timeout_seconds=timedelta(seconds=1), done_on_timeout=False ) # Assert the result From 288ebe33946d38170ca22b2e3cd9f4735c704d02 Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 14 Jul 2025 08:34:59 +0000 Subject: [PATCH 03/12] add resume logic to request/join call_tool functions --- src/mcp/client/session.py | 29 ++++++- src/mcp/shared/session.py | 83 ++++++++++++------ tests/client/test_session.py | 160 +++++++++++++++++++++++++++++------ 3 files changed, 222 insertions(+), 50 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 1666d1afa..a253b59cf 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -9,7 +9,7 @@ import mcp.types as types from mcp.shared.context import RequestContext -from mcp.shared.message import SessionMessage +from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder, RequestStateManager from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -289,7 +289,11 @@ async def request_call_tool( arguments: dict[str, Any] | None = None, progress_callback: ProgressFnT | None = None, ) -> types.RequestId: - return await self.start_request( + write, read = anyio.create_memory_object_stream[str]() + + metadata = ClientMessageMetadata(on_resumption_token_update=write.send) + + request_id = await self.start_request( types.ClientRequest( types.CallToolRequest( method="tools/call", @@ -300,8 +304,25 @@ async def request_call_tool( ) ), progress_callback=progress_callback, + metadata=metadata, ) + async def update_token() -> None: + try: + async for token in read: + self._request_state_manager.update_resume_token(request_id, token) + except anyio.ClosedResourceError: + pass + + async def close() -> None: + await write.aclose() + await read.aclose() + + self._exit_stack.push_async_callback(update_token) + self._exit_stack.push_async_callback(close) + + return request_id + async def join_call_tool( self, request_id: types.RequestId, @@ -309,9 +330,13 @@ async def join_call_tool( request_read_timeout_seconds: timedelta | None = None, done_on_timeout: bool = True, ) -> types.CallToolResult | None: + resume_token = self._request_state_manager.get_resume_token(request_id) + metadata = ClientMessageMetadata(resumption_token=resume_token) + return await self.join_request( request_id, types.CallToolResult, + metadata=metadata, request_read_timeout_seconds=request_read_timeout_seconds, progress_callback=progress_callback, done_on_timeout=done_on_timeout, diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 11f1a5c84..21e59626b 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -21,12 +21,14 @@ ClientNotification, ClientRequest, ClientResult, + EmptyResult, ErrorData, JSONRPCError, JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + PingRequest, ProgressNotification, RequestParams, ServerNotification, @@ -165,6 +167,10 @@ class RequestStateManager( ): def new_request(self, request: SendRequestT) -> RequestId: ... + def update_resume_token(self, request_id: RequestId, token: str) -> None: ... + + def get_resume_token(self, request_id: RequestId) -> str | None: ... + def add_progress_callback(self, request_id: RequestId, progress_callback: ProgressFnT): ... async def send_progress( @@ -204,11 +210,13 @@ class ImMemoryRequestStateManager( ], ] _progress_callbacks: dict[RequestId, list[ProgressFnT]] + _resume_tokens: dict[RequestId, str] def __init__(self): self._request_id = 0 self._response_streams = {} self._progress_callbacks = {} + self._resume_tokens = {} def new_request(self, request: SendRequestT) -> RequestId: request_id = self._request_id @@ -219,6 +227,12 @@ def new_request(self, request: SendRequestT) -> RequestId: return request_id + def update_resume_token(self, request_id: RequestId, token: str) -> None: + self._resume_tokens[request_id] = token + + def get_resume_token(self, request_id: RequestId) -> str | None: + return self._resume_tokens.get(request_id) + def add_progress_callback(self, request_id: RequestId, progress_callback: ProgressFnT): progress_list = self._progress_callbacks.get(request_id) if progress_list is None: @@ -289,6 +303,7 @@ async def close_request(self, request_id: RequestId) -> bool: await response_stream_reader.aclose() self._progress_callbacks.pop(request_id, None) + self._resume_tokens.pop(request_id, None) return response_stream is not None @@ -373,38 +388,17 @@ async def start_request( instead. """ request_id = self._request_state_manager.new_request(request) - - # Set up progress token if progress callback is provided - request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) - if progress_callback is not None: - # Use request_id as progress token - if "params" not in request_data: - request_data["params"] = {} - if "_meta" not in request_data["params"]: - request_data["params"]["_meta"] = {} - request_data["params"]["_meta"]["progressToken"] = request_id - # Store the callback for this request - self._request_state_manager.add_progress_callback(request_id, progress_callback) - - jsonrpc_request = JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - **request_data, + return await self._send_request( + request_id=request_id, request=request, metadata=metadata, progress_callback=progress_callback ) - try: - await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) - return request_id - except Exception as e: - await self._request_state_manager.close_request(request_id) - raise e - async def join_request( self, request_id: RequestId, result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, progress_callback: ProgressFnT | None = None, + metadata: MessageMetadata | None = None, done_on_timeout: bool = True, ) -> ReceiveResultT: """ @@ -420,6 +414,15 @@ async def join_request( elif self._session_read_timeout_seconds is not None: timeout = self._session_read_timeout_seconds.total_seconds() + if metadata: + # need to resend metadata - primary use case is client resume support + await self.send_request( + request=PingRequest(method="ping"), # type: ignore + result_type=EmptyResult, + request_read_timeout_seconds=None if timeout is None else timedelta(seconds=timeout), + metadata=metadata, + ) + response_or_error = await self._request_state_manager.receive_response(request_id, timeout) if isinstance(response_or_error, JSONRPCError): @@ -433,6 +436,38 @@ async def join_request( await self._request_state_manager.close_request(request_id) return result_type.model_validate(response_or_error.result) + async def _send_request( + self, + request_id: RequestId, + request: SendRequestT, + metadata: MessageMetadata = None, + progress_callback: ProgressFnT | None = None, + ): + # Set up progress token if progress callback is provided + request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) + if progress_callback is not None: + # Use request_id as progress token + if "params" not in request_data: + request_data["params"] = {} + if "_meta" not in request_data["params"]: + request_data["params"]["_meta"] = {} + request_data["params"]["_meta"]["progressToken"] = request_id + # Store the callback for this request + self._request_state_manager.add_progress_callback(request_id, progress_callback) + + jsonrpc_request = JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + **request_data, + ) + + try: + await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) + return request_id + except Exception as e: + await self._request_state_manager.close_request(request_id) + raise e + async def cancel_request(self, request_id: RequestId) -> bool: """ Cancels a request previously started via start_request diff --git a/tests/client/test_session.py b/tests/client/test_session.py index e8776f092..b0e462cde 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -19,6 +19,7 @@ CancelledNotification, ClientNotification, ClientRequest, + EmptyResult, Implementation, InitializedNotification, InitializeRequest, @@ -27,6 +28,7 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + PingRequest, ServerCapabilities, ServerResult, TextContent, @@ -513,26 +515,49 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request.root, JSONRPCRequest) + call_id = jsonrpc_request.root.id request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) assert isinstance(request.root, CallToolRequest) - request = request.root assert "hello" == request.params.name assert request.params.arguments is not None assert "name" in request.params.arguments name = request.params.arguments["name"] - result = ServerResult(CallToolResult(content=[TextContent(type="text", text=f"hello {name}")])) + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + ping_id = jsonrpc_request.root.id + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, PingRequest) async with server_to_client_send: + result = ServerResult(EmptyResult()) + await server_to_client_send.send( SessionMessage( JSONRPCMessage( JSONRPCResponse( jsonrpc="2.0", - id=jsonrpc_request.root.id, + id=ping_id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + result = ServerResult(CallToolResult(content=[TextContent(type="text", text=f"hello {name}")])) + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=call_id, result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) @@ -580,39 +605,80 @@ async def test_client_session_request_call_tool_join_timeout(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - send_result = anyio.Event() - async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request.root, JSONRPCRequest) + call_id = jsonrpc_request.root.id request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) assert isinstance(request.root, CallToolRequest) - request = request.root assert "hello" == request.params.name assert request.params.arguments is not None assert "name" in request.params.arguments name = request.params.arguments["name"] - await send_result.wait() - - result = ServerResult(CallToolResult(content=[TextContent(type="text", text=f"hello {name}")])) + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + ping_id_1 = jsonrpc_request.root.id + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, PingRequest) async with server_to_client_send: + result = ServerResult(EmptyResult()) + await server_to_client_send.send( SessionMessage( JSONRPCMessage( JSONRPCResponse( jsonrpc="2.0", - id=jsonrpc_request.root.id, + id=ping_id_1, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + ping_id_2 = jsonrpc_request.root.id + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, PingRequest) + + result = ServerResult(EmptyResult()) + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=ping_id_2, result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) ) + result = ServerResult(CallToolResult(content=[TextContent(type="text", text=f"hello {name}")])) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=call_id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) # Create a message handler to catch exceptions async def message_handler( @@ -637,16 +703,15 @@ async def message_handler( request_id = await session.request_call_tool("hello", {"name": "world"}) - with anyio.fail_after(1): + with anyio.fail_after(3): try: result = await session.join_call_tool( - request_id, request_read_timeout_seconds=timedelta(microseconds=1), done_on_timeout=False + request_id, request_read_timeout_seconds=timedelta(seconds=0.5), done_on_timeout=False ) except McpError as e: if not e.error.code == httpx.codes.REQUEST_TIMEOUT: raise e - send_result.set() result = await session.join_call_tool( request_id, request_read_timeout_seconds=timedelta(seconds=1), done_on_timeout=False ) @@ -666,25 +731,23 @@ async def test_client_session_request_call_tool_with_progress(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_2_joined = anyio.Event() - async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request.root, JSONRPCRequest) + call_id = jsonrpc_request.root.id request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) assert isinstance(request.root, CallToolRequest) - request = request.root - assert "hello" == request.params.name assert request.params.arguments is not None assert "name" in request.params.arguments name = request.params.arguments["name"] assert request.params.meta is not None assert request.params.meta.progressToken is not None + progrss_token = request.params.meta.progressToken async with server_to_client_send: await server_to_client_send.send( @@ -694,7 +757,7 @@ async def mock_server(): jsonrpc="2.0", method="notifications/progress", params=types.ProgressNotificationParams( - progressToken=request.params.meta.progressToken, + progressToken=progrss_token, progress=1, total=2, message="event 1", @@ -704,6 +767,29 @@ async def mock_server(): ) ) + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + ping_id = jsonrpc_request.root.id + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, PingRequest) + + result = ServerResult(EmptyResult()) + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=ping_id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + await server_to_client_send.send( SessionMessage( JSONRPCMessage( @@ -711,7 +797,7 @@ async def mock_server(): jsonrpc="2.0", method="notifications/progress", params=types.ProgressNotificationParams( - progressToken=request.params.meta.progressToken, + progressToken=progrss_token, progress=2, total=2, message="event 2", @@ -728,7 +814,7 @@ async def mock_server(): JSONRPCMessage( JSONRPCResponse( jsonrpc="2.0", - id=jsonrpc_request.root.id, + id=call_id, result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) @@ -772,7 +858,6 @@ async def progress_callback1(progress: float, total: float | None, message: str with anyio.fail_after(3): await progress_1.wait() result = await session.join_call_tool(request_id) - client_2_joined.set() await progress_2.wait() # Assert the result @@ -793,6 +878,7 @@ async def mock_server(): session_message = await client_1_to_server_receive.receive() jsonrpc_request = session_message.message assert isinstance(jsonrpc_request.root, JSONRPCRequest) + call_tool_id = jsonrpc_request.root.id request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -806,6 +892,7 @@ async def mock_server(): name = request.params.arguments["name"] assert request.params.meta is not None assert request.params.meta.progressToken is not None + progress_token = request.params.meta.progressToken async with server_to_client_1_send, server_to_client_2_send: await server_to_client_1_send.send( @@ -815,7 +902,7 @@ async def mock_server(): jsonrpc="2.0", method="notifications/progress", params=types.ProgressNotificationParams( - progressToken=request.params.meta.progressToken, + progressToken=progress_token, progress=1, total=2, message="event 1", @@ -825,6 +912,29 @@ async def mock_server(): ) ) + session_message = await client_2_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + ping_id = jsonrpc_request.root.id + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, PingRequest) + + result = ServerResult(EmptyResult()) + + await server_to_client_2_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=ping_id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + await server_to_client_2_send.send( SessionMessage( JSONRPCMessage( @@ -832,7 +942,7 @@ async def mock_server(): jsonrpc="2.0", method="notifications/progress", params=types.ProgressNotificationParams( - progressToken=request.params.meta.progressToken, + progressToken=progress_token, progress=2, total=2, message="event 2", @@ -849,7 +959,7 @@ async def mock_server(): JSONRPCMessage( JSONRPCResponse( jsonrpc="2.0", - id=jsonrpc_request.root.id, + id=call_tool_id, result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) @@ -919,6 +1029,8 @@ async def progress_callback2(progress: float, total: float | None, message: str # initialise io manager 2 to state of io manager 1 for request, _, _ in request_state_manager_1._response_streams.values(): request_state_manager_2.new_request(request) + for request, token in request_state_manager_1._resume_tokens.items(): + request_state_manager_2._resume_tokens[request] = token # simulate network disconnect and rejoin await request_state_manager_1.close_request(request_id) From 40028daa6880f086b713a3e220aa03c710aad2be Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 14 Jul 2025 09:03:53 +0000 Subject: [PATCH 04/12] Remove None as valid return type from join_call_tool, fix typo ImMemory -> InMemory --- src/mcp/client/session.py | 2 +- src/mcp/shared/session.py | 4 ++-- tests/client/test_resource_cleanup.py | 4 ++-- tests/client/test_session.py | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index a253b59cf..a5b0ff979 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -329,7 +329,7 @@ async def join_call_tool( progress_callback: ProgressFnT | None = None, request_read_timeout_seconds: timedelta | None = None, done_on_timeout: bool = True, - ) -> types.CallToolResult | None: + ) -> types.CallToolResult: resume_token = self._request_state_manager.get_resume_token(request_id) metadata = ClientMessageMetadata(resumption_token=resume_token) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 21e59626b..3014a23a9 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -194,7 +194,7 @@ async def close_request(self, request_id: RequestId) -> bool: ... async def close(self) -> None: ... -class ImMemoryRequestStateManager( +class InMemoryRequestStateManager( RequestStateManager[ SendRequestT, SendResultT, @@ -354,7 +354,7 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._exit_stack = AsyncExitStack() self._in_flight = {} - self._request_state_manager = request_state_manager or ImMemoryRequestStateManager() + self._request_state_manager = request_state_manager or InMemoryRequestStateManager() async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py index bb7e22192..2d4c18343 100644 --- a/tests/client/test_resource_cleanup.py +++ b/tests/client/test_resource_cleanup.py @@ -3,7 +3,7 @@ import anyio import pytest -from mcp.shared.session import BaseSession, ImMemoryRequestStateManager +from mcp.shared.session import BaseSession, InMemoryRequestStateManager from mcp.types import ( ClientRequest, EmptyResult, @@ -28,7 +28,7 @@ async def _send_response(self, request_id, response): write_stream_send, write_stream_receive = anyio.create_memory_object_stream(1) read_stream_send, read_stream_receive = anyio.create_memory_object_stream(1) - request_io_manager = ImMemoryRequestStateManager() + request_io_manager = InMemoryRequestStateManager() # Create the session session = TestSession( read_stream_receive, diff --git a/tests/client/test_session.py b/tests/client/test_session.py index b0e462cde..2238b0603 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -10,7 +10,7 @@ from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage -from mcp.shared.session import ImMemoryRequestStateManager, RequestResponder +from mcp.shared.session import InMemoryRequestStateManager, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( LATEST_PROTOCOL_VERSION, @@ -973,8 +973,8 @@ async def message_handler( if isinstance(message, Exception): raise message - request_state_manager_1 = ImMemoryRequestStateManager() - request_state_manager_2 = ImMemoryRequestStateManager() + request_state_manager_1 = InMemoryRequestStateManager() + request_state_manager_2 = InMemoryRequestStateManager() async with ( ClientSession( From 161da461428d34e60122a2d012bfaafcc13a69d7 Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 14 Jul 2025 14:58:35 +0000 Subject: [PATCH 05/12] send resume on init rather than part of join, refactor resume to be global to session rather than per request (read the spec) --- src/mcp/client/session.py | 40 +++++------ src/mcp/shared/session.py | 85 +++++++++--------------- tests/client/test_session.py | 125 +++-------------------------------- 3 files changed, 57 insertions(+), 193 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index a5b0ff979..570ced353 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -176,6 +176,20 @@ async def initialize(self) -> types.InitializeResult: types.ClientNotification(types.InitializedNotification(method="notifications/initialized")) ) + resume_token = await self._request_state_manager.get_resume_token() + if resume_token: + metadata = ClientMessageMetadata(resumption_token=resume_token) + timeout = None + if self._session_read_timeout_seconds is not None: + timeout = self._session_read_timeout_seconds.total_seconds() + + await self.send_request( + request=PingRequest(method="ping"), # type: ignore + result_type=types.EmptyResult, + request_read_timeout_seconds=None if timeout is None else timedelta(seconds=timeout), + metadata=metadata, + ) + return result async def send_ping(self) -> types.EmptyResult: @@ -289,11 +303,9 @@ async def request_call_tool( arguments: dict[str, Any] | None = None, progress_callback: ProgressFnT | None = None, ) -> types.RequestId: - write, read = anyio.create_memory_object_stream[str]() - - metadata = ClientMessageMetadata(on_resumption_token_update=write.send) + metadata = ClientMessageMetadata(on_resumption_token_update=self._request_state_manager.update_resume_token) - request_id = await self.start_request( + return await self.start_request( types.ClientRequest( types.CallToolRequest( method="tools/call", @@ -307,22 +319,6 @@ async def request_call_tool( metadata=metadata, ) - async def update_token() -> None: - try: - async for token in read: - self._request_state_manager.update_resume_token(request_id, token) - except anyio.ClosedResourceError: - pass - - async def close() -> None: - await write.aclose() - await read.aclose() - - self._exit_stack.push_async_callback(update_token) - self._exit_stack.push_async_callback(close) - - return request_id - async def join_call_tool( self, request_id: types.RequestId, @@ -330,13 +326,9 @@ async def join_call_tool( request_read_timeout_seconds: timedelta | None = None, done_on_timeout: bool = True, ) -> types.CallToolResult: - resume_token = self._request_state_manager.get_resume_token(request_id) - metadata = ClientMessageMetadata(resumption_token=resume_token) - return await self.join_request( request_id, types.CallToolResult, - metadata=metadata, request_read_timeout_seconds=request_read_timeout_seconds, progress_callback=progress_callback, done_on_timeout=done_on_timeout, diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 3014a23a9..30106f187 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -21,14 +21,12 @@ ClientNotification, ClientRequest, ClientResult, - EmptyResult, ErrorData, JSONRPCError, JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, - PingRequest, ProgressNotification, RequestParams, ServerNotification, @@ -167,9 +165,9 @@ class RequestStateManager( ): def new_request(self, request: SendRequestT) -> RequestId: ... - def update_resume_token(self, request_id: RequestId, token: str) -> None: ... + async def update_resume_token(self, token: str) -> None: ... - def get_resume_token(self, request_id: RequestId) -> str | None: ... + async def get_resume_token(self) -> str | None: ... def add_progress_callback(self, request_id: RequestId, progress_callback: ProgressFnT): ... @@ -210,13 +208,13 @@ class InMemoryRequestStateManager( ], ] _progress_callbacks: dict[RequestId, list[ProgressFnT]] - _resume_tokens: dict[RequestId, str] + _resume_token: str | None def __init__(self): self._request_id = 0 self._response_streams = {} self._progress_callbacks = {} - self._resume_tokens = {} + self._resume_token = None def new_request(self, request: SendRequestT) -> RequestId: request_id = self._request_id @@ -227,11 +225,11 @@ def new_request(self, request: SendRequestT) -> RequestId: return request_id - def update_resume_token(self, request_id: RequestId, token: str) -> None: - self._resume_tokens[request_id] = token + async def update_resume_token(self, token: str) -> None: + self._resume_token = token - def get_resume_token(self, request_id: RequestId) -> str | None: - return self._resume_tokens.get(request_id) + async def get_resume_token(self) -> str | None: + return self._resume_token def add_progress_callback(self, request_id: RequestId, progress_callback: ProgressFnT): progress_list = self._progress_callbacks.get(request_id) @@ -303,7 +301,6 @@ async def close_request(self, request_id: RequestId) -> bool: await response_stream_reader.aclose() self._progress_callbacks.pop(request_id, None) - self._resume_tokens.pop(request_id, None) return response_stream is not None @@ -388,17 +385,37 @@ async def start_request( instead. """ request_id = self._request_state_manager.new_request(request) - return await self._send_request( - request_id=request_id, request=request, metadata=metadata, progress_callback=progress_callback + # Set up progress token if progress callback is provided + request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) + if progress_callback is not None: + # Use request_id as progress token + if "params" not in request_data: + request_data["params"] = {} + if "_meta" not in request_data["params"]: + request_data["params"]["_meta"] = {} + request_data["params"]["_meta"]["progressToken"] = request_id + # Store the callback for this request + self._request_state_manager.add_progress_callback(request_id, progress_callback) + + jsonrpc_request = JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + **request_data, ) + try: + await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) + return request_id + except Exception as e: + await self._request_state_manager.close_request(request_id) + raise e + async def join_request( self, request_id: RequestId, result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, progress_callback: ProgressFnT | None = None, - metadata: MessageMetadata | None = None, done_on_timeout: bool = True, ) -> ReceiveResultT: """ @@ -414,15 +431,6 @@ async def join_request( elif self._session_read_timeout_seconds is not None: timeout = self._session_read_timeout_seconds.total_seconds() - if metadata: - # need to resend metadata - primary use case is client resume support - await self.send_request( - request=PingRequest(method="ping"), # type: ignore - result_type=EmptyResult, - request_read_timeout_seconds=None if timeout is None else timedelta(seconds=timeout), - metadata=metadata, - ) - response_or_error = await self._request_state_manager.receive_response(request_id, timeout) if isinstance(response_or_error, JSONRPCError): @@ -436,37 +444,6 @@ async def join_request( await self._request_state_manager.close_request(request_id) return result_type.model_validate(response_or_error.result) - async def _send_request( - self, - request_id: RequestId, - request: SendRequestT, - metadata: MessageMetadata = None, - progress_callback: ProgressFnT | None = None, - ): - # Set up progress token if progress callback is provided - request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) - if progress_callback is not None: - # Use request_id as progress token - if "params" not in request_data: - request_data["params"] = {} - if "_meta" not in request_data["params"]: - request_data["params"]["_meta"] = {} - request_data["params"]["_meta"]["progressToken"] = request_id - # Store the callback for this request - self._request_state_manager.add_progress_callback(request_id, progress_callback) - - jsonrpc_request = JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - **request_data, - ) - - try: - await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) - return request_id - except Exception as e: - await self._request_state_manager.close_request(request_id) - raise e async def cancel_request(self, request_id: RequestId) -> bool: """ diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 2238b0603..9c9e30d51 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -526,30 +526,7 @@ async def mock_server(): assert "name" in request.params.arguments name = request.params.arguments["name"] - session_message = await client_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) - ping_id = jsonrpc_request.root.id - request = ClientRequest.model_validate( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request.root, PingRequest) - async with server_to_client_send: - result = ServerResult(EmptyResult()) - - await server_to_client_send.send( - SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=ping_id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - ) - result = ServerResult(CallToolResult(content=[TextContent(type="text", text=f"hello {name}")])) await server_to_client_send.send( @@ -605,6 +582,8 @@ async def test_client_session_request_call_tool_join_timeout(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + send_result = anyio.Event() + async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message @@ -620,51 +599,9 @@ async def mock_server(): assert "name" in request.params.arguments name = request.params.arguments["name"] - session_message = await client_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) - ping_id_1 = jsonrpc_request.root.id - request = ClientRequest.model_validate( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request.root, PingRequest) + await send_result.wait() async with server_to_client_send: - result = ServerResult(EmptyResult()) - - await server_to_client_send.send( - SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=ping_id_1, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - ) - session_message = await client_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) - ping_id_2 = jsonrpc_request.root.id - request = ClientRequest.model_validate( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request.root, PingRequest) - - result = ServerResult(EmptyResult()) - - await server_to_client_send.send( - SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=ping_id_2, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - ) result = ServerResult(CallToolResult(content=[TextContent(type="text", text=f"hello {name}")])) async with server_to_client_send: @@ -708,10 +645,11 @@ async def message_handler( result = await session.join_call_tool( request_id, request_read_timeout_seconds=timedelta(seconds=0.5), done_on_timeout=False ) + # raise RuntimeError("Expected fail") except McpError as e: if not e.error.code == httpx.codes.REQUEST_TIMEOUT: raise e - + send_result.set() result = await session.join_call_tool( request_id, request_read_timeout_seconds=timedelta(seconds=1), done_on_timeout=False ) @@ -731,6 +669,8 @@ async def test_client_session_request_call_tool_with_progress(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + send_progress_2 = anyio.Event() + async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message @@ -767,28 +707,7 @@ async def mock_server(): ) ) - session_message = await client_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) - ping_id = jsonrpc_request.root.id - request = ClientRequest.model_validate( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request.root, PingRequest) - - result = ServerResult(EmptyResult()) - - await server_to_client_send.send( - SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=ping_id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - ) + # await send_progress_2.wait() await server_to_client_send.send( SessionMessage( @@ -858,6 +777,7 @@ async def progress_callback1(progress: float, total: float | None, message: str with anyio.fail_after(3): await progress_1.wait() result = await session.join_call_tool(request_id) + send_progress_2.set() await progress_2.wait() # Assert the result @@ -912,29 +832,6 @@ async def mock_server(): ) ) - session_message = await client_2_to_server_receive.receive() - jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) - ping_id = jsonrpc_request.root.id - request = ClientRequest.model_validate( - jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - assert isinstance(request.root, PingRequest) - - result = ServerResult(EmptyResult()) - - await server_to_client_2_send.send( - SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=ping_id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - ) - await server_to_client_2_send.send( SessionMessage( JSONRPCMessage( @@ -1029,9 +926,7 @@ async def progress_callback2(progress: float, total: float | None, message: str # initialise io manager 2 to state of io manager 1 for request, _, _ in request_state_manager_1._response_streams.values(): request_state_manager_2.new_request(request) - for request, token in request_state_manager_1._resume_tokens.items(): - request_state_manager_2._resume_tokens[request] = token - + # simulate network disconnect and rejoin await request_state_manager_1.close_request(request_id) result = await session2.join_call_tool(request_id, progress_callback2) From aa2cbec9340226c2b3a3a53804467a8121c30acd Mon Sep 17 00:00:00 2001 From: David Savage Date: Wed, 16 Jul 2025 07:37:54 +0000 Subject: [PATCH 06/12] fix import error --- src/mcp/client/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 570ced353..c2bb9018a 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -184,7 +184,7 @@ async def initialize(self) -> types.InitializeResult: timeout = self._session_read_timeout_seconds.total_seconds() await self.send_request( - request=PingRequest(method="ping"), # type: ignore + request=types.PingRequest(method="ping"), # type: ignore result_type=types.EmptyResult, request_read_timeout_seconds=None if timeout is None else timedelta(seconds=timeout), metadata=metadata, From 7329cba1f0369eb95d24a0a9322545d697a24ddb Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 26 Jul 2025 08:35:26 +0000 Subject: [PATCH 07/12] Refactor code to send resume as part of join call rather than it, this results in the response being consumed prior to the join, also added a capability that identifies whether the server/transport supports resumption that is passed during initialisation --- src/mcp/client/session.py | 80 +++++++++++------ src/mcp/client/streamable_http.py | 23 ++++- src/mcp/shared/session.py | 120 ++++++++++++++++++------- src/mcp/types.py | 10 +++ tests/client/test_session.py | 13 +-- tests/shared/test_streamable_http.py | 128 ++++++++++++++++++++++++++- 6 files changed, 299 insertions(+), 75 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index c2bb9018a..c5b8928ae 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -135,6 +135,7 @@ def __init__( self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} + self._resumable = False async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None @@ -172,24 +173,12 @@ async def initialize(self) -> types.InitializeResult: if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}") + self._resumable = result.capabilities.resume and result.capabilities.resume.resumable + await self.send_notification( types.ClientNotification(types.InitializedNotification(method="notifications/initialized")) ) - resume_token = await self._request_state_manager.get_resume_token() - if resume_token: - metadata = ClientMessageMetadata(resumption_token=resume_token) - timeout = None - if self._session_read_timeout_seconds is not None: - timeout = self._session_read_timeout_seconds.total_seconds() - - await self.send_request( - request=types.PingRequest(method="ping"), # type: ignore - result_type=types.EmptyResult, - request_read_timeout_seconds=None if timeout is None else timedelta(seconds=timeout), - metadata=metadata, - ) - return result async def send_ping(self) -> types.EmptyResult: @@ -303,21 +292,58 @@ async def request_call_tool( arguments: dict[str, Any] | None = None, progress_callback: ProgressFnT | None = None, ) -> types.RequestId: - metadata = ClientMessageMetadata(on_resumption_token_update=self._request_state_manager.update_resume_token) - - return await self.start_request( - types.ClientRequest( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams( - name=name, - arguments=arguments, + if self._resumable: + send_stream, receive_stream = anyio.create_memory_object_stream[str](1) + + async def close() -> None: + await send_stream.aclose() + await receive_stream.aclose() + + self._exit_stack.push_async_callback(close) + + with send_stream, receive_stream: + + async def send_token(token: str): + try: + await send_stream.send(token) + except anyio.BrokenResourceError as e: + raise e + + metadata = ClientMessageMetadata(on_resumption_token_update=send_token) + + request_id = await self.start_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name=name, + arguments=arguments, + ), + ) ), + progress_callback=progress_callback, + metadata=metadata, ) - ), - progress_callback=progress_callback, - metadata=metadata, - ) + + await anyio.lowlevel.checkpoint() + + token = await receive_stream.receive() + await self._request_state_manager.update_resume_token(request_id, token) + + return request_id + else: + return await self.start_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name=name, + arguments=arguments, + ), + ) + ), + progress_callback=progress_callback, + ) async def join_call_tool( self, diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 39ac34d8a..d57c4b4aa 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -29,6 +29,7 @@ JSONRPCRequest, JSONRPCResponse, RequestId, + ResumeCapability, ) logger = logging.getLogger(__name__) @@ -136,7 +137,7 @@ def _maybe_extract_session_id_from_response( def _maybe_extract_protocol_version_from_message( self, message: JSONRPCMessage, - ) -> None: + ) -> JSONRPCMessage: """Extract protocol version from initialization response message.""" if isinstance(message.root, JSONRPCResponse) and message.root.result: try: @@ -144,10 +145,18 @@ def _maybe_extract_protocol_version_from_message( init_result = InitializeResult.model_validate(message.root.result) self.protocol_version = str(init_result.protocolVersion) logger.info(f"Negotiated protocol version: {self.protocol_version}") + if init_result.capabilities.resume is None: + # resumeablity is predicated on the server and the transport + # this assumes that if the server hasn't explicitly configured + # that streamable http transports are resumeable + init_result.capabilities.resume = ResumeCapability(resumable=True) + message.root.result = init_result.model_dump() except Exception as exc: logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}") logger.warning(f"Raw result: {message.root.result}") + return message + async def _handle_sse_event( self, sse: ServerSentEvent, @@ -183,7 +192,10 @@ async def _handle_sse_event( except Exception as exc: logger.exception("Error parsing SSE message") - await read_stream_writer.send(exc) + try: + await read_stream_writer.send(exc) + except anyio.BrokenResourceError: + pass return False else: logger.warning(f"Unknown SSE event: {sse.event}") @@ -303,7 +315,7 @@ async def _handle_json_response( # Extract protocol version from initialization response if is_initialization: - self._maybe_extract_protocol_version_from_message(message) + message = self._maybe_extract_protocol_version_from_message(message) session_message = SessionMessage(message) await read_stream_writer.send(session_message) @@ -333,7 +345,10 @@ async def _handle_sse_response( break except Exception as e: logger.exception("Error reading SSE stream:") - await ctx.read_stream_writer.send(e) + try: + await ctx.read_stream_writer.send(e) + except anyio.ClosedResourceError: + pass async def _handle_unexpected_content_type( self, diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 30106f187..0ee21fa8b 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -12,7 +12,7 @@ from typing_extensions import Self from mcp.shared.exceptions import McpError -from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.message import ClientMessageMetadata, MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.types import ( CONNECTION_CLOSED, INVALID_PARAMS, @@ -27,6 +27,7 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + PingRequest, ProgressNotification, RequestParams, ServerNotification, @@ -165,9 +166,11 @@ class RequestStateManager( ): def new_request(self, request: SendRequestT) -> RequestId: ... - async def update_resume_token(self, token: str) -> None: ... + def resume(self, request_id: RequestId) -> bool: ... - async def get_resume_token(self) -> str | None: ... + async def update_resume_token(self, request_id: RequestId, token: str) -> None: ... + + async def get_resume_token(self, request_id: RequestId) -> str | None: ... def add_progress_callback(self, request_id: RequestId, progress_callback: ProgressFnT): ... @@ -199,37 +202,53 @@ class InMemoryRequestStateManager( ], ): _request_id: int + _requests: dict[ + RequestId, + SendRequestT, + ] _response_streams: dict[ RequestId, tuple[ - SendRequestT, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError], MemoryObjectReceiveStream[JSONRPCResponse | JSONRPCError], ], ] _progress_callbacks: dict[RequestId, list[ProgressFnT]] - _resume_token: str | None + _resume_tokens: dict[RequestId, str] def __init__(self): self._request_id = 0 + self._requests = {} self._response_streams = {} self._progress_callbacks = {} - self._resume_token = None + self._resume_tokens = {} def new_request(self, request: SendRequestT) -> RequestId: request_id = self._request_id self._request_id = request_id + 1 - response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) - self._response_streams[request_id] = request, response_stream, response_stream_reader + send_stream, receive_stream = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) + self._response_streams[request_id] = send_stream, receive_stream + self._requests[request_id] = request return request_id - async def update_resume_token(self, token: str) -> None: - self._resume_token = token + def resume(self, request_id: RequestId) -> bool: + if self._requests.get(request_id) is None: + raise RuntimeError(f"Unknown request {request_id}") + + if request_id in self._response_streams: + return False + else: + send_stream, receive_stream = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) + self._response_streams[request_id] = send_stream, receive_stream + return True + + async def update_resume_token(self, request_id: RequestId, token: str) -> None: + self._resume_tokens[request_id] = token - async def get_resume_token(self) -> str | None: - return self._resume_token + async def get_resume_token(self, request_id: RequestId) -> str | None: + return self._resume_tokens.get(request_id) def add_progress_callback(self, request_id: RequestId, progress_callback: ProgressFnT): progress_list = self._progress_callbacks.get(request_id) @@ -260,9 +279,8 @@ async def receive_response( request_id: RequestId, timeout: float | None = None, ) -> JSONRPCResponse | JSONRPCError: - request, _, response_stream_reader = self._response_streams.get(request_id, [None, None, None]) - - if response_stream_reader is None: + _, receive_stream = self._response_streams.get(request_id, [None, None]) + if receive_stream is None: raise McpError( ErrorData( code=INVALID_PARAMS, @@ -270,9 +288,19 @@ async def receive_response( ) ) + request = self._requests.get(request_id, None) + assert request is not None + try: with anyio.fail_after(timeout): - return await response_stream_reader.receive() + return await receive_stream.receive() + except anyio.EndOfStream: + raise McpError( + ErrorData( + code=CONNECTION_CLOSED, + message=("Connection closed"), + ) + ) except TimeoutError: raise McpError( ErrorData( @@ -286,33 +314,41 @@ async def receive_response( ) async def handle_response(self, message: JSONRPCResponse | JSONRPCError) -> bool: - _, stream, _ = self._response_streams.get(message.id, [None, None, None]) - if stream: - await stream.send(message) + send_stream, _ = self._response_streams.get(message.id, [None, None]) + if send_stream: + await send_stream.send(message) return True else: return False async def close_request(self, request_id: RequestId) -> bool: - _, response_stream, response_stream_reader = self._response_streams.pop(request_id, [None, None, None]) - if response_stream is not None: - await response_stream.aclose() - if response_stream_reader is not None: - await response_stream_reader.aclose() - + send_stream, receive_stream = self._response_streams.pop(request_id, [None, None]) + if send_stream is not None: + await send_stream.aclose() + if receive_stream is not None: + await receive_stream.aclose() + + self._requests.pop(request_id, None) + self._resume_tokens.pop(request_id, None) self._progress_callbacks.pop(request_id, None) - return response_stream is not None + return send_stream is not None async def close(self): - for id, [_, stream, _] in self._response_streams.items(): - error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") + for id, [send_stream, receive_stream] in self._response_streams.copy().items(): + await receive_stream.aclose() try: - await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) - await stream.aclose() - except Exception: - # Stream might already be closed + error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") + await send_stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) + except anyio.BrokenResourceError: + # Stream already be closed pass + except anyio.ClosedResourceError: + # Stream already be closed + pass + finally: + await send_stream.aclose() + self._response_streams.pop(id) class BaseSession( @@ -421,6 +457,8 @@ async def join_request( """ Joins a request previously started via start_request """ + resume = self._request_state_manager.resume(request_id) + if progress_callback is not None: self._request_state_manager.add_progress_callback(request_id, progress_callback) @@ -431,6 +469,23 @@ async def join_request( elif self._session_read_timeout_seconds is not None: timeout = self._session_read_timeout_seconds.total_seconds() + if resume: + resume_token = await self._request_state_manager.get_resume_token(request_id) + if resume_token is not None: + metadata = ClientMessageMetadata(resumption_token=resume_token) + + request_data = PingRequest(method="ping").model_dump(by_alias=True, mode="json", exclude_none=True) + + jsonrpc_request = JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + **request_data, + ) + + await self._write_stream.send( + SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata) + ) + response_or_error = await self._request_state_manager.receive_response(request_id, timeout) if isinstance(response_or_error, JSONRPCError): @@ -444,7 +499,6 @@ async def join_request( await self._request_state_manager.close_request(request_id) return result_type.model_validate(response_or_error.result) - async def cancel_request(self, request_id: RequestId) -> bool: """ Cancels a request previously started via start_request diff --git a/src/mcp/types.py b/src/mcp/types.py index 91432d69c..f6388ef1e 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -262,6 +262,14 @@ class PromptsCapability(BaseModel): model_config = ConfigDict(extra="allow") +class ResumeCapability(BaseModel): + """Capability for resume operations.""" + + resumable: bool | None = None + """Whether this server supports resume operations.""" + model_config = ConfigDict(extra="allow") + + class ResourcesCapability(BaseModel): """Capability for resources operations.""" @@ -303,6 +311,8 @@ class ServerCapabilities(BaseModel): """Present if the server offers any prompt templates.""" resources: ResourcesCapability | None = None """Present if the server offers any resources to read.""" + resume: ResumeCapability | None = None + """Present if the server offers resume capability.""" tools: ToolsCapability | None = None """Present if the server offers any tools to call.""" completions: CompletionsCapability | None = None diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 9c9e30d51..c857c2ebe 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -19,7 +19,6 @@ CancelledNotification, ClientNotification, ClientRequest, - EmptyResult, Implementation, InitializedNotification, InitializeRequest, @@ -28,7 +27,6 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, - PingRequest, ServerCapabilities, ServerResult, TextContent, @@ -562,11 +560,7 @@ async def message_handler( ): tg.start_soon(mock_server) - async def progress_callback(progress: float, total: float | None, message: str | None) -> None: - pass - - request_id = await session.request_call_tool("hello", {"name": "world"}, progress_callback) - + request_id = await session.request_call_tool("hello", {"name": "world"}) with anyio.fail_after(1): result = await session.join_call_tool(request_id) @@ -924,9 +918,8 @@ async def progress_callback2(progress: float, total: float | None, message: str await progress_1_1.wait() # initialise io manager 2 to state of io manager 1 - for request, _, _ in request_state_manager_1._response_streams.values(): - request_state_manager_2.new_request(request) - + request_state_manager_2._requests = request_state_manager_1._requests.copy() + # simulate network disconnect and rejoin await request_state_manager_1.close_request(request_id) result = await session2.join_call_tool(request_id, progress_callback2) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 1ffcc13b0..8f0cf2638 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -43,7 +43,7 @@ from mcp.shared.message import ( ClientMessageMetadata, ) -from mcp.shared.session import RequestResponder +from mcp.shared.session import InMemoryRequestStateManager, RequestResponder from mcp.types import ( InitializeResult, TextContent, @@ -1169,6 +1169,132 @@ async def run_tool(): assert not any(n in captured_notifications_pre for n in captured_notifications) +@pytest.mark.anyio +async def test_streamablehttp_client_resumption_non_blocking(event_server): + """Test client session to resume a long running tool via non blocking api.""" + _, server_url = event_server + + with anyio.fail_after(10): + # Variables to track the state + captured_session_id = None + captured_notifications = [] + tool_started = False + captured_protocol_version = None + captured_request_id = None + request_state_manager_1 = InMemoryRequestStateManager() + request_state_manager_2 = InMemoryRequestStateManager() + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, types.ServerNotification): + captured_notifications.append(message) + # Look for our special notification that indicates the tool is running + if isinstance(message.root, types.LoggingMessageNotification): + if message.root.params.data == "Tool started": + nonlocal tool_started + tool_started = True + + # First, start the client session and begin the long-running tool + async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as ( + read_stream, + write_stream, + get_session_id, + ): + async with ClientSession( + read_stream, + write_stream, + message_handler=message_handler, + request_state_manager=request_state_manager_1, + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + captured_session_id = get_session_id() + assert captured_session_id is not None + # Capture the negotiated protocol version + captured_protocol_version = result.protocolVersion + + # Start a long-running tool in a task + async with anyio.create_task_group() as tg: + + async def run_tool(): + nonlocal captured_request_id + captured_request_id = await session.request_call_tool( + "long_running_with_checkpoints", arguments={} + ) + + tg.start_soon(run_tool) + + # Wait for the tool to start and at least one notification + # and then kill the task group + while ( + not tool_started or not captured_request_id or len(request_state_manager_1._resume_tokens) == 0 + ): + await anyio.sleep(0.1) + + tg.cancel_scope.cancel() + + # Store pre notifications and clear the captured notifications + # for the post-resumption check + captured_notifications_pre = captured_notifications.copy() + captured_notifications = [] + + # Now resume the session with the same mcp-session-id and protocol version + headers = {} + if captured_session_id: + headers[MCP_SESSION_ID_HEADER] = captured_session_id + if captured_protocol_version: + headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version + + assert len(request_state_manager_1._requests) == 1, str(request_state_manager_1._requests) + assert len(request_state_manager_1._resume_tokens) == 1 + + request_state_manager_2._requests = request_state_manager_1._requests.copy() + request_state_manager_2._resume_tokens = request_state_manager_1._resume_tokens.copy() + + async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + message_handler=message_handler, + request_state_manager=request_state_manager_2, + ) as session: + # Don't initialize - just use the existing session + + # Resume the tool with the resumption token + assert captured_request_id is not None + + result = await session.join_call_tool(captured_request_id) + + # We should get a complete result + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert "Completed" in result.content[0].text + + # We should have received the remaining notifications + assert len(captured_notifications) > 0 + + # Should not have the first notification + # Check that "Tool started" notification isn't repeated when resuming + assert not any( + isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started" + for n in captured_notifications + ) + # there is no intersection between pre and post notifications + assert not any(n in captured_notifications_pre for n in captured_notifications) + + assert len(request_state_manager_1._progress_callbacks) == 0 + assert len(request_state_manager_1._response_streams) == 0 + assert len(request_state_manager_2._progress_callbacks) == 0 + assert len(request_state_manager_2._resume_tokens) == 0 + assert len(request_state_manager_2._response_streams) == 0 + + @pytest.mark.anyio async def test_streamablehttp_server_sampling(basic_server, basic_server_url): """Test server-initiated sampling request through streamable HTTP transport.""" From e4c25b738bacd5dde0e29f89666adb224f8dd4b3 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 27 Jul 2025 11:41:37 +0000 Subject: [PATCH 08/12] simplify token capture using events rather than streams, add test for timeout on join and subsequent rejoin --- src/mcp/client/session.py | 56 +++++------ src/mcp/client/streamable_http.py | 17 ++-- tests/shared/test_streamable_http.py | 141 +++++++++++++++++++++++++++ 3 files changed, 172 insertions(+), 42 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index c5b8928ae..4721a17b3 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -293,44 +293,36 @@ async def request_call_tool( progress_callback: ProgressFnT | None = None, ) -> types.RequestId: if self._resumable: - send_stream, receive_stream = anyio.create_memory_object_stream[str](1) + captured_token = None + captured = anyio.Event() - async def close() -> None: - await send_stream.aclose() - await receive_stream.aclose() + async def capture_token(token: str): + nonlocal captured_token + captured_token = token + captured.set() - self._exit_stack.push_async_callback(close) + metadata = ClientMessageMetadata(on_resumption_token_update=capture_token) - with send_stream, receive_stream: - - async def send_token(token: str): - try: - await send_stream.send(token) - except anyio.BrokenResourceError as e: - raise e - - metadata = ClientMessageMetadata(on_resumption_token_update=send_token) - - request_id = await self.start_request( - types.ClientRequest( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams( - name=name, - arguments=arguments, - ), - ) - ), - progress_callback=progress_callback, - metadata=metadata, - ) + request_id = await self.start_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name=name, + arguments=arguments, + ), + ) + ), + progress_callback=progress_callback, + metadata=metadata, + ) - await anyio.lowlevel.checkpoint() + while captured_token is None: + await captured.wait() - token = await receive_stream.receive() - await self._request_state_manager.update_resume_token(request_id, token) + await self._request_state_manager.update_resume_token(request_id, captured_token) - return request_id + return request_id else: return await self.start_request( types.ClientRequest( diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index d57c4b4aa..2d60983b2 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -173,7 +173,7 @@ async def _handle_sse_event( # Extract protocol version from initialization response if is_initialization: - self._maybe_extract_protocol_version_from_message(message) + message = self._maybe_extract_protocol_version_from_message(message) # If this is a response and we have original_request_id, replace it if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError): @@ -192,10 +192,7 @@ async def _handle_sse_event( except Exception as exc: logger.exception("Error parsing SSE message") - try: - await read_stream_writer.send(exc) - except anyio.BrokenResourceError: - pass + await read_stream_writer.send(exc) return False else: logger.warning(f"Unknown SSE event: {sse.event}") @@ -486,8 +483,8 @@ async def streamablehttp_client( read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) - async with anyio.create_task_group() as tg: - try: + try: + async with anyio.create_task_group() as tg: logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") async with httpx_client_factory( @@ -519,6 +516,6 @@ def start_get_stream() -> None: if transport.session_id and terminate_on_close: await transport.terminate_session(client) tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await write_stream.aclose() + finally: + await read_stream_writer.aclose() + await write_stream.aclose() diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 8f0cf2638..68c3ce249 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -9,6 +9,7 @@ import socket import time from collections.abc import Generator +from datetime import timedelta from typing import Any import anyio @@ -1295,6 +1296,146 @@ async def run_tool(): assert len(request_state_manager_2._response_streams) == 0 +@pytest.mark.anyio +async def test_streamablehttp_client_resumption_timeout(event_server): + """Test client session to resume a long running tool via non blocking api.""" + _, server_url = event_server + + with anyio.fail_after(10): + # Variables to track the state + captured_session_id = None + captured_notifications = [] + tool_started = False + captured_protocol_version = None + captured_request_id = None + request_state_manager_1 = InMemoryRequestStateManager() + request_state_manager_2 = InMemoryRequestStateManager() + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, types.ServerNotification): + captured_notifications.append(message) + # Look for our special notification that indicates the tool is running + if isinstance(message.root, types.LoggingMessageNotification): + if message.root.params.data == "Tool started": + nonlocal tool_started + tool_started = True + + # First, start the client session and begin the long-running tool + async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as ( + read_stream, + write_stream, + get_session_id, + ): + async with ClientSession( + read_stream, + write_stream, + message_handler=message_handler, + request_state_manager=request_state_manager_1, + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + captured_session_id = get_session_id() + assert captured_session_id is not None + # Capture the negotiated protocol version + captured_protocol_version = result.protocolVersion + + # Start a long-running tool in a task + async with anyio.create_task_group() as tg: + timed_out = anyio.Event() + + async def run_tool(): + nonlocal captured_request_id + captured_request_id = await session.request_call_tool( + "long_running_with_checkpoints", arguments={} + ) + try: + await session.join_call_tool( + captured_request_id, request_read_timeout_seconds=timedelta(seconds=0.01) + ) + raise RuntimeError("Expected timeout") + except McpError as e: + assert e.error.code == httpx.codes.REQUEST_TIMEOUT.value + + timed_out.set() + + tg.start_soon(run_tool) + + # Wait for the tool to start and at least one notification + # and then kill the task group + while ( + not tool_started or not captured_request_id or len(request_state_manager_1._resume_tokens) == 0 + ): + await anyio.sleep(0.1) + + await timed_out.wait() + + tg.cancel_scope.cancel() + + # Store pre notifications and clear the captured notifications + # for the post-resumption check + captured_notifications_pre = captured_notifications.copy() + captured_notifications = [] + + # Now resume the session with the same mcp-session-id and protocol version + headers = {} + if captured_session_id: + headers[MCP_SESSION_ID_HEADER] = captured_session_id + if captured_protocol_version: + headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version + + assert len(request_state_manager_1._requests) == 1, str(request_state_manager_1._requests) + assert len(request_state_manager_1._resume_tokens) == 1 + + request_state_manager_2._requests = request_state_manager_1._requests.copy() + request_state_manager_2._resume_tokens = request_state_manager_1._resume_tokens.copy() + + async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + message_handler=message_handler, + request_state_manager=request_state_manager_2, + ) as session: + # Don't initialize - just use the existing session + + # Resume the tool with the resumption token + assert captured_request_id is not None + + result = await session.join_call_tool(captured_request_id) + + # We should get a complete result + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert "Completed" in result.content[0].text + + # We should have received the remaining notifications + assert len(captured_notifications) > 0 + + # Should not have the first notification + # Check that "Tool started" notification isn't repeated when resuming + assert not any( + isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started" + for n in captured_notifications + ) + # there is no intersection between pre and post notifications + assert not any(n in captured_notifications_pre for n in captured_notifications), ( + f"{captured_notifications_pre} -> {captured_notifications}" + ) + + assert len(request_state_manager_1._progress_callbacks) == 0 + assert len(request_state_manager_1._response_streams) == 0 + assert len(request_state_manager_2._progress_callbacks) == 0 + assert len(request_state_manager_2._resume_tokens) == 0 + assert len(request_state_manager_2._response_streams) == 0 + + @pytest.mark.anyio async def test_streamablehttp_server_sampling(basic_server, basic_server_url): """Test server-initiated sampling request through streamable HTTP transport.""" From 79f3c4e9d33872400ca788bdc006cf7b2e091fd5 Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 28 Jul 2025 15:45:00 +0000 Subject: [PATCH 09/12] avoid exceptions during join call tool on timeout as this is expected behaviour use None when no result retrieved instead --- src/mcp/client/session.py | 2 +- src/mcp/server/streamable_http.py | 2 +- src/mcp/shared/session.py | 48 +++++++++++++++++----------- tests/client/test_session.py | 12 +++---- tests/shared/test_streamable_http.py | 15 +++++---- 5 files changed, 44 insertions(+), 35 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 4721a17b3..049fa982d 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -343,7 +343,7 @@ async def join_call_tool( progress_callback: ProgressFnT | None = None, request_read_timeout_seconds: timedelta | None = None, done_on_timeout: bool = True, - ) -> types.CallToolResult: + ) -> types.CallToolResult | None: return await self.join_request( request_id, types.CallToolResult, diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 1b37acd43..7b43f2da2 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -763,8 +763,8 @@ async def send_event(event_message: EventMessage) -> None: async with msg_reader: async for event_message in msg_reader: event_data = self._create_event_data(event_message) - await sse_stream_writer.send(event_data) + except Exception as e: logger.exception(f"Error in replay sender: {e}") diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 0ee21fa8b..456f96685 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -186,7 +186,7 @@ async def receive_response( self, request_id: RequestId, timeout: float | None = None, - ) -> JSONRPCResponse | JSONRPCError: ... + ) -> JSONRPCResponse | JSONRPCError | None: ... async def handle_response(self, message: JSONRPCResponse | JSONRPCError) -> bool: ... @@ -278,7 +278,7 @@ async def receive_response( self, request_id: RequestId, timeout: float | None = None, - ) -> JSONRPCResponse | JSONRPCError: + ) -> JSONRPCResponse | JSONRPCError | None: _, receive_stream = self._response_streams.get(request_id, [None, None]) if receive_stream is None: raise McpError( @@ -302,16 +302,7 @@ async def receive_response( ) ) except TimeoutError: - raise McpError( - ErrorData( - code=httpx.codes.REQUEST_TIMEOUT, - message=( - f"Timed out while waiting for response to " - f"{request.__class__.__name__}. Waited " - f"{timeout} seconds." - ), - ) - ) + return None async def handle_response(self, message: JSONRPCResponse | JSONRPCError) -> bool: send_stream, _ = self._response_streams.get(message.id, [None, None]) @@ -453,9 +444,11 @@ async def join_request( request_read_timeout_seconds: timedelta | None = None, progress_callback: ProgressFnT | None = None, done_on_timeout: bool = True, - ) -> ReceiveResultT: + ) -> ReceiveResultT | None: """ - Joins a request previously started via start_request + Joins a request previously started via start_request. + + Returns the result or None if timeout is reached. """ resume = self._request_state_manager.resume(request_id) @@ -488,17 +481,23 @@ async def join_request( response_or_error = await self._request_state_manager.receive_response(request_id, timeout) - if isinstance(response_or_error, JSONRPCError): + if response_or_error is None: + if done_on_timeout: + await self._request_state_manager.close_request(request_id) + return None + elif isinstance(response_or_error, JSONRPCError): if response_or_error.error.code == httpx.codes.REQUEST_TIMEOUT.value: if done_on_timeout: await self._request_state_manager.close_request(request_id) + return None else: await self._request_state_manager.close_request(request_id) - raise McpError(response_or_error.error) - else: + raise McpError(response_or_error.error) + else : await self._request_state_manager.close_request(request_id) return result_type.model_validate(response_or_error.result) + async def cancel_request(self, request_id: RequestId) -> bool: """ Cancels a request previously started via start_request @@ -533,7 +532,20 @@ async def send_request( """ request_id = await self.start_request(request, metadata, progress_callback) try: - return await self.join_request(request_id, result_type, request_read_timeout_seconds) + result = await self.join_request(request_id, result_type, request_read_timeout_seconds) + if result is None: + raise McpError( + ErrorData( + code=httpx.codes.REQUEST_TIMEOUT, + message=( + f"Timed out while waiting for response to " + f"{request.__class__.__name__}. Waited " + f"{request_read_timeout_seconds} seconds." + ), + ) + ) + else: + return result finally: await self._request_state_manager.close_request(request_id) diff --git a/tests/client/test_session.py b/tests/client/test_session.py index c857c2ebe..525f047e7 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -635,14 +635,10 @@ async def message_handler( request_id = await session.request_call_tool("hello", {"name": "world"}) with anyio.fail_after(3): - try: - result = await session.join_call_tool( - request_id, request_read_timeout_seconds=timedelta(seconds=0.5), done_on_timeout=False - ) - # raise RuntimeError("Expected fail") - except McpError as e: - if not e.error.code == httpx.codes.REQUEST_TIMEOUT: - raise e + result = await session.join_call_tool( + request_id, request_read_timeout_seconds=timedelta(seconds=0.5), done_on_timeout=False + ) + assert result is None send_result.set() result = await session.join_call_tool( request_id, request_read_timeout_seconds=timedelta(seconds=1), done_on_timeout=False diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 68c3ce249..cd284b755 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1351,13 +1351,13 @@ async def run_tool(): captured_request_id = await session.request_call_tool( "long_running_with_checkpoints", arguments={} ) - try: - await session.join_call_tool( - captured_request_id, request_read_timeout_seconds=timedelta(seconds=0.01) - ) - raise RuntimeError("Expected timeout") - except McpError as e: - assert e.error.code == httpx.codes.REQUEST_TIMEOUT.value + + result = await session.join_call_tool( + captured_request_id, request_read_timeout_seconds=timedelta(seconds=0.01), + done_on_timeout=False + ) + + assert result is None timed_out.set() @@ -1409,6 +1409,7 @@ async def run_tool(): assert captured_request_id is not None result = await session.join_call_tool(captured_request_id) + assert result is not None # We should get a complete result assert len(result.content) == 1 From f262bb69dd83c2f27078f819cef29528cb060a51 Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 28 Jul 2025 16:03:05 +0000 Subject: [PATCH 10/12] uv ruff fixes --- src/mcp/shared/session.py | 7 +++---- tests/client/test_session.py | 2 -- tests/shared/test_streamable_http.py | 9 +++++---- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 456f96685..60ed1387e 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -447,7 +447,7 @@ async def join_request( ) -> ReceiveResultT | None: """ Joins a request previously started via start_request. - + Returns the result or None if timeout is reached. """ resume = self._request_state_manager.resume(request_id) @@ -483,7 +483,7 @@ async def join_request( if response_or_error is None: if done_on_timeout: - await self._request_state_manager.close_request(request_id) + await self._request_state_manager.close_request(request_id) return None elif isinstance(response_or_error, JSONRPCError): if response_or_error.error.code == httpx.codes.REQUEST_TIMEOUT.value: @@ -493,11 +493,10 @@ async def join_request( else: await self._request_state_manager.close_request(request_id) raise McpError(response_or_error.error) - else : + else: await self._request_state_manager.close_request(request_id) return result_type.model_validate(response_or_error.result) - async def cancel_request(self, request_id: RequestId) -> bool: """ Cancels a request previously started via start_request diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 525f047e7..39bed193c 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -2,13 +2,11 @@ from typing import Any import anyio -import httpx import pytest import mcp.types as types from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession from mcp.shared.context import RequestContext -from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage from mcp.shared.session import InMemoryRequestStateManager, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 110e4938e..d0f004803 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1396,10 +1396,11 @@ async def run_tool(): captured_request_id = await session.request_call_tool( "long_running_with_checkpoints", arguments={} ) - + result = await session.join_call_tool( - captured_request_id, request_read_timeout_seconds=timedelta(seconds=0.01), - done_on_timeout=False + captured_request_id, + request_read_timeout_seconds=timedelta(seconds=0.01), + done_on_timeout=False, ) assert result is None @@ -1474,7 +1475,7 @@ async def run_tool(): assert not any(n in captured_notifications_pre for n in captured_notifications), ( f"{captured_notifications_pre} -> {captured_notifications}" ) - + assert len(request_state_manager_1._progress_callbacks) == 0 assert len(request_state_manager_1._response_streams) == 0 assert len(request_state_manager_2._progress_callbacks) == 0 From c5eab90c9b1ab68d80105b0e060712e2bb2ffc74 Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 28 Jul 2025 16:12:23 +0000 Subject: [PATCH 11/12] add assert for pyright checks --- tests/shared/test_streamable_http.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index d0f004803..936bafe23 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1316,6 +1316,7 @@ async def run_tool(): assert captured_request_id is not None result = await session.join_call_tool(captured_request_id) + assert result is not None # We should get a complete result assert len(result.content) == 1 From 79eb3c9dbdfdb5976bc97538037524c32eace719 Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 28 Jul 2025 16:12:42 +0000 Subject: [PATCH 12/12] update test description --- tests/shared/test_streamable_http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 936bafe23..59c4388ec 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1344,7 +1344,7 @@ async def run_tool(): @pytest.mark.anyio async def test_streamablehttp_client_resumption_timeout(event_server): - """Test client session to resume a long running tool via non blocking api.""" + """Test client session to resume a long running tool via non blocking api with timeout.""" _, server_url = event_server with anyio.fail_after(10):