Skip to content

Commit bcf53b7

Browse files
committed
Add unit tests for streamable HTTP SSE handling
1 parent dc90d27 commit bcf53b7

File tree

3 files changed

+313
-6
lines changed

3 files changed

+313
-6
lines changed

tests/server/fastmcp/resources/test_file_resources.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ async def test_missing_file_error(self, temp_file: Path):
103103

104104
@pytest.mark.skipif(os.name == "nt", reason="File permissions behave differently on Windows")
105105
@pytest.mark.anyio
106-
async def test_permission_error(temp_file: Path):
106+
async def test_permission_error(temp_file: Path): # pragma: no cover - skipped on Windows and root
107107
"""Test reading a file without permissions."""
108108
if os.geteuid() == 0: # pragma: no cover
109109
pytest.skip("Permission test not reliable when running as root")

tests/shared/test_streamable_http.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -359,21 +359,21 @@ def basic_server(basic_server_port: int) -> Generator[None, None, None]:
359359

360360

361361
@pytest.fixture
362-
def event_store() -> SimpleEventStore:
362+
def event_store() -> SimpleEventStore: # pragma: no cover - exercised only on non-Windows platforms
363363
"""Create a test event store."""
364364
return SimpleEventStore()
365365

366366

367367
@pytest.fixture
368-
def event_server_port() -> int:
368+
def event_server_port() -> int: # pragma: no cover - exercised only on non-Windows platforms
369369
"""Find an available port for the event store server."""
370370
with socket.socket() as s:
371371
s.bind(("127.0.0.1", 0))
372372
return s.getsockname()[1]
373373

374374

375375
@pytest.fixture
376-
def event_server(
376+
def event_server( # pragma: no cover - exercised only on non-Windows platforms
377377
event_server_port: int, event_store: SimpleEventStore
378378
) -> Generator[tuple[SimpleEventStore, str], None, None]:
379379
"""Start a server with event store enabled."""
@@ -395,7 +395,9 @@ def event_server(
395395

396396

397397
@pytest.fixture
398-
def json_response_server(json_server_port: int) -> Generator[None, None, None]:
398+
def json_response_server( # pragma: no cover - exercised only on non-Windows platforms
399+
json_server_port: int,
400+
) -> Generator[None, None, None]:
399401
"""Start a server with JSON response enabled."""
400402
proc = multiprocessing.Process(
401403
target=run_server,
@@ -1105,7 +1107,9 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt
11051107

11061108
@pytest.mark.anyio
11071109
@pytest.mark.skipif(sys.platform == "win32", reason="Resumption unstable on Windows")
1108-
async def test_streamablehttp_client_resumption(event_server: tuple[SimpleEventStore, str]):
1110+
async def test_streamablehttp_client_resumption( # pragma: no cover - skipped on Windows builds
1111+
event_server: tuple[SimpleEventStore, str]
1112+
):
11091113
"""Test client session resumption using sync primitives for reliable coordination."""
11101114
_, server_url = event_server
11111115

Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
"""Focused unit tests for :mod:`mcp.client.streamable_http`."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import AsyncIterator
6+
7+
import anyio
8+
import pytest
9+
from httpx import Timeout
10+
from httpx_sse import ServerSentEvent
11+
12+
from mcp.client.streamable_http import (
13+
LAST_EVENT_ID,
14+
RequestContext,
15+
ResumptionError,
16+
StreamableHTTPTransport,
17+
)
18+
from mcp.shared.message import ClientMessageMetadata, SessionMessage
19+
from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse
20+
21+
22+
SessionMessageOrError = SessionMessage | Exception
23+
24+
25+
@pytest.mark.anyio
26+
async def test_handle_sse_event_initialization_sets_protocol_and_restores_id() -> None:
27+
"""Initialization responses should update protocol version and preserve request IDs."""
28+
29+
transport = StreamableHTTPTransport("http://example.test")
30+
send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessageOrError](10)
31+
32+
initialization_payload = {
33+
"protocolVersion": "1.2",
34+
"capabilities": {},
35+
"serverInfo": {"name": "unit", "version": "0.0.0"},
36+
}
37+
response_message = JSONRPCMessage(
38+
JSONRPCResponse(jsonrpc="2.0", id="server-id", result=initialization_payload)
39+
)
40+
sse = ServerSentEvent(event="message", data=response_message.model_dump_json())
41+
42+
async with send_stream, receive_stream:
43+
complete = await transport._handle_sse_event( # noqa: SLF001 - exercising private helper
44+
sse,
45+
send_stream,
46+
original_request_id="original-id",
47+
is_initialization=True,
48+
)
49+
50+
assert complete is True
51+
received = await receive_stream.receive()
52+
assert isinstance(received, SessionMessage)
53+
assert received.message.root.id == "original-id"
54+
assert transport.protocol_version == "1.2"
55+
56+
57+
@pytest.mark.anyio
58+
async def test_handle_sse_event_notification_invokes_resumption_callback() -> None:
59+
"""Notifications should forward resumption tokens and keep the stream open."""
60+
61+
transport = StreamableHTTPTransport("http://example.test")
62+
send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessageOrError](10)
63+
64+
notification_message = JSONRPCMessage(
65+
JSONRPCNotification(jsonrpc="2.0", method="test/notification", params=None)
66+
)
67+
sse = ServerSentEvent(event="message", data=notification_message.model_dump_json(), id=" resume ")
68+
69+
captured_token: list[str] = []
70+
71+
async def on_resumption_token_update(token: str) -> None:
72+
captured_token.append(token)
73+
74+
async with send_stream, receive_stream:
75+
complete = await transport._handle_sse_event( # noqa: SLF001 - exercising private helper
76+
sse,
77+
send_stream,
78+
resumption_callback=on_resumption_token_update,
79+
)
80+
81+
assert complete is False
82+
received = await receive_stream.receive()
83+
assert isinstance(received, SessionMessage)
84+
assert isinstance(received.message.root, JSONRPCNotification)
85+
assert captured_token == ["resume"]
86+
87+
88+
class _FakeResponse:
89+
def __init__(self) -> None:
90+
self.raised = False
91+
self.closed = False
92+
93+
def raise_for_status(self) -> None:
94+
self.raised = True
95+
96+
async def aclose(self) -> None:
97+
self.closed = True
98+
99+
100+
class _FakeEventSource:
101+
def __init__(self, events: list[ServerSentEvent], response: _FakeResponse | None = None) -> None:
102+
self._events = events
103+
self.response = response or _FakeResponse()
104+
105+
async def __aenter__(self) -> "_FakeEventSource":
106+
return self
107+
108+
async def __aexit__(self, exc_type, exc, tb) -> None: # type: ignore[override]
109+
return None
110+
111+
async def aiter_sse(self) -> AsyncIterator[ServerSentEvent]:
112+
for event in self._events:
113+
yield event
114+
115+
116+
@pytest.mark.anyio
117+
async def test_handle_get_stream_processes_events(monkeypatch: pytest.MonkeyPatch) -> None:
118+
"""The GET stream helper should consume SSE events when a session exists."""
119+
120+
transport = StreamableHTTPTransport("http://example.test")
121+
transport.session_id = "session-123"
122+
123+
send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessageOrError](10)
124+
fake_events = [ServerSentEvent(event="message", data="{}")]
125+
126+
captured_headers: dict[str, str] | None = None
127+
128+
def fake_aconnect_sse(
129+
client: object, method: str, url: str, headers: dict[str, str], timeout: Timeout
130+
) -> _FakeEventSource:
131+
nonlocal captured_headers
132+
captured_headers = headers
133+
assert method == "GET"
134+
assert url == "http://example.test"
135+
return _FakeEventSource(fake_events)
136+
137+
call_count = 0
138+
139+
async def fake_handle_sse_event(*args, **kwargs) -> bool: # type: ignore[unused-argument]
140+
nonlocal call_count
141+
call_count += 1
142+
return True
143+
144+
monkeypatch.setattr("mcp.client.streamable_http.aconnect_sse", fake_aconnect_sse)
145+
monkeypatch.setattr(
146+
StreamableHTTPTransport, "_handle_sse_event", fake_handle_sse_event
147+
)
148+
149+
async with send_stream, receive_stream:
150+
await transport.handle_get_stream(object(), send_stream)
151+
152+
assert call_count == 1
153+
assert captured_headers is not None
154+
assert captured_headers.get("mcp-session-id") == "session-123"
155+
156+
157+
@pytest.mark.anyio
158+
async def test_handle_resumption_request_requires_token() -> None:
159+
"""Resumption requests without a token must fail fast."""
160+
161+
transport = StreamableHTTPTransport("http://example.test")
162+
send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessageOrError](10)
163+
164+
session_message = SessionMessage(
165+
JSONRPCMessage(JSONRPCRequest(jsonrpc="2.0", id="1", method="test"))
166+
)
167+
ctx = RequestContext(
168+
client=object(),
169+
headers={},
170+
session_id=None,
171+
session_message=session_message,
172+
metadata=ClientMessageMetadata(resumption_token=None),
173+
read_stream_writer=send_stream,
174+
sse_read_timeout=1.0,
175+
)
176+
177+
async with send_stream, receive_stream:
178+
with pytest.raises(ResumptionError):
179+
await transport._handle_resumption_request(ctx) # noqa: SLF001
180+
181+
182+
@pytest.mark.anyio
183+
async def test_handle_resumption_request_stream(monkeypatch: pytest.MonkeyPatch) -> None:
184+
"""Resumption requests should forward the original ID and close the SSE response."""
185+
186+
transport = StreamableHTTPTransport("http://example.test")
187+
transport.session_id = "session-123"
188+
send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessageOrError](10)
189+
190+
metadata = ClientMessageMetadata(resumption_token=" token ")
191+
session_message = SessionMessage(
192+
JSONRPCMessage(
193+
JSONRPCRequest(jsonrpc="2.0", id="original", method="tool", params={})
194+
),
195+
metadata=metadata,
196+
)
197+
ctx = RequestContext(
198+
client=object(),
199+
headers={"custom": "header"},
200+
session_id="session-123",
201+
session_message=session_message,
202+
metadata=metadata,
203+
read_stream_writer=send_stream,
204+
sse_read_timeout=1.0,
205+
)
206+
207+
fake_events = [ServerSentEvent(event="message", data="{}") for _ in range(2)]
208+
fake_event_source = _FakeEventSource(fake_events)
209+
210+
captured_headers: dict[str, str] | None = None
211+
212+
def fake_aconnect_sse(
213+
client: object, method: str, url: str, headers: dict[str, str], timeout: Timeout
214+
) -> _FakeEventSource:
215+
nonlocal captured_headers
216+
captured_headers = headers
217+
assert client is ctx.client
218+
assert method == "GET"
219+
assert url == "http://example.test"
220+
return fake_event_source
221+
222+
call_args: list[dict[str, object]] = []
223+
224+
async def fake_handle_sse_event(
225+
self,
226+
sse,
227+
read_stream_writer,
228+
original_request_id=None,
229+
resumption_callback=None,
230+
is_initialization=False,
231+
) -> bool:
232+
call_args.append(
233+
{
234+
"original_request_id": original_request_id,
235+
"resumption_callback": resumption_callback,
236+
}
237+
)
238+
return len(call_args) >= 2
239+
240+
monkeypatch.setattr("mcp.client.streamable_http.aconnect_sse", fake_aconnect_sse)
241+
monkeypatch.setattr(StreamableHTTPTransport, "_handle_sse_event", fake_handle_sse_event)
242+
243+
async with send_stream, receive_stream:
244+
await transport._handle_resumption_request(ctx) # noqa: SLF001
245+
246+
assert captured_headers is not None
247+
assert captured_headers.get(LAST_EVENT_ID) == "token"
248+
assert fake_event_source.response.raised is True
249+
assert fake_event_source.response.closed is True
250+
assert call_args
251+
assert call_args[0]["original_request_id"] == "original"
252+
253+
254+
@pytest.mark.anyio
255+
async def test_handle_sse_response_closes_after_completion(monkeypatch: pytest.MonkeyPatch) -> None:
256+
"""SSE POST responses should stop reading once a response has been emitted."""
257+
258+
transport = StreamableHTTPTransport("http://example.test")
259+
send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessageOrError](10)
260+
261+
metadata = ClientMessageMetadata()
262+
session_message = SessionMessage(
263+
JSONRPCMessage(JSONRPCRequest(jsonrpc="2.0", id="42", method="ping")),
264+
metadata=metadata,
265+
)
266+
ctx = RequestContext(
267+
client=object(),
268+
headers={},
269+
session_id=None,
270+
session_message=session_message,
271+
metadata=metadata,
272+
read_stream_writer=send_stream,
273+
sse_read_timeout=1.0,
274+
)
275+
276+
events = [ServerSentEvent(event="message", data="{}") for _ in range(2)]
277+
278+
created_sources: list[_FakeEventSource] = []
279+
280+
class FakeEventSourceFactory:
281+
def __call__(self, response: _FakeResponse) -> _FakeEventSource:
282+
source = _FakeEventSource(events, response)
283+
created_sources.append(source)
284+
return source
285+
286+
fake_response = _FakeResponse()
287+
288+
async def fake_handle_sse_event(*args, **kwargs) -> bool: # type: ignore[unused-argument]
289+
fake_handle_sse_event.call_count += 1
290+
return fake_handle_sse_event.call_count >= 2
291+
292+
fake_handle_sse_event.call_count = 0
293+
294+
monkeypatch.setattr("mcp.client.streamable_http.EventSource", FakeEventSourceFactory())
295+
monkeypatch.setattr(StreamableHTTPTransport, "_handle_sse_event", fake_handle_sse_event)
296+
297+
async with send_stream, receive_stream:
298+
await transport._handle_sse_response(fake_response, ctx, is_initialization=True)
299+
300+
assert fake_handle_sse_event.call_count == 2
301+
assert created_sources and created_sources[0].response is fake_response
302+
assert fake_response.closed is True
303+

0 commit comments

Comments
 (0)