diff --git a/examples/01_basic_usage_example.py b/examples/01_basic_usage_example.py index 2f693d4..b0a36e6 100644 --- a/examples/01_basic_usage_example.py +++ b/examples/01_basic_usage_example.py @@ -13,6 +13,8 @@ if __name__ == "__main__": + import os import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) + # Set HOST=0.0.0.0 to expose on your network. + uvicorn.run(app, host=os.getenv("HOST", "127.0.0.1"), port=8000) diff --git a/examples/02_full_schema_description_example.py b/examples/02_full_schema_description_example.py index 766d4d3..cbeccab 100644 --- a/examples/02_full_schema_description_example.py +++ b/examples/02_full_schema_description_example.py @@ -21,6 +21,8 @@ mcp.mount_http() if __name__ == "__main__": + import os import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) + # Set HOST=0.0.0.0 to expose on your network. + uvicorn.run(app, host=os.getenv("HOST", "127.0.0.1"), port=8000) diff --git a/examples/03_custom_exposed_endpoints_example.py b/examples/03_custom_exposed_endpoints_example.py index 60d13f4..b08829c 100644 --- a/examples/03_custom_exposed_endpoints_example.py +++ b/examples/03_custom_exposed_endpoints_example.py @@ -60,6 +60,7 @@ combined_include_mcp.mount_http(mount_path="/combined-include-mcp") if __name__ == "__main__": + import os import uvicorn print("Server is running with multiple MCP endpoints:") @@ -68,4 +69,5 @@ print(" - /include-tags-mcp: Only operations with the 'items' tag") print(" - /exclude-tags-mcp: All operations except those with the 'search' tag") print(" - /combined-include-mcp: Operations with 'search' tag or delete_item operation") - uvicorn.run(app, host="0.0.0.0", port=8000) + # Set HOST=0.0.0.0 to expose on your network. + uvicorn.run(app, host=os.getenv("HOST", "127.0.0.1"), port=8000) diff --git a/examples/04_separate_server_example.py b/examples/04_separate_server_example.py index 4b200bc..2d3c244 100644 --- a/examples/04_separate_server_example.py +++ b/examples/04_separate_server_example.py @@ -29,6 +29,8 @@ # It still works 🚀 # Your original API is **not exposed**, only via the MCP server. if __name__ == "__main__": + import os import uvicorn - uvicorn.run(mcp_app, host="0.0.0.0", port=8000) + # Set HOST=0.0.0.0 to expose on your network. + uvicorn.run(mcp_app, host=os.getenv("HOST", "127.0.0.1"), port=8000) diff --git a/examples/05_reregister_tools_example.py b/examples/05_reregister_tools_example.py index b7133cd..e1d5b21 100644 --- a/examples/05_reregister_tools_example.py +++ b/examples/05_reregister_tools_example.py @@ -24,6 +24,8 @@ async def new_endpoint(): if __name__ == "__main__": + import os import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) + # Set HOST=0.0.0.0 to expose on your network. + uvicorn.run(app, host=os.getenv("HOST", "127.0.0.1"), port=8000) diff --git a/examples/06_custom_mcp_router_example.py b/examples/06_custom_mcp_router_example.py index 47b0f21..920d7f0 100644 --- a/examples/06_custom_mcp_router_example.py +++ b/examples/06_custom_mcp_router_example.py @@ -21,6 +21,8 @@ if __name__ == "__main__": + import os import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) + # Set HOST=0.0.0.0 to expose on your network. + uvicorn.run(app, host=os.getenv("HOST", "127.0.0.1"), port=8000) diff --git a/examples/07_configure_http_timeout_example.py b/examples/07_configure_http_timeout_example.py index f225774..8d75865 100644 --- a/examples/07_configure_http_timeout_example.py +++ b/examples/07_configure_http_timeout_example.py @@ -18,6 +18,8 @@ if __name__ == "__main__": + import os import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) + # Set HOST=0.0.0.0 to expose on your network. + uvicorn.run(app, host=os.getenv("HOST", "127.0.0.1"), port=8000) diff --git a/examples/08_auth_example_token_passthrough.py b/examples/08_auth_example_token_passthrough.py index dbc5b1d..9e13ea9 100644 --- a/examples/08_auth_example_token_passthrough.py +++ b/examples/08_auth_example_token_passthrough.py @@ -56,6 +56,8 @@ async def private(token=Depends(token_auth_scheme)): if __name__ == "__main__": + import os import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) + # Set HOST=0.0.0.0 to expose on your network. + uvicorn.run(app, host=os.getenv("HOST", "127.0.0.1"), port=8000) diff --git a/examples/09_auth_example_auth0.py b/examples/09_auth_example_auth0.py index bf1cde2..76f3eb1 100644 --- a/examples/09_auth_example_auth0.py +++ b/examples/09_auth_example_auth0.py @@ -132,6 +132,8 @@ async def protected(user_id: str = Depends(get_current_user_id)): if __name__ == "__main__": + import os import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) + # Set HOST=0.0.0.0 to expose on your network. + uvicorn.run(app, host=os.getenv("HOST", "127.0.0.1"), port=8000) diff --git a/fastapi_mcp/transport/sse.py b/fastapi_mcp/transport/sse.py index 4b64a43..20260c2 100644 --- a/fastapi_mcp/transport/sse.py +++ b/fastapi_mcp/transport/sse.py @@ -1,5 +1,5 @@ -from uuid import UUID import logging +from uuid import UUID from typing import Union from anyio.streams.memory import MemoryObjectSendStream @@ -13,8 +13,36 @@ logger = logging.getLogger(__name__) +DEFAULT_MAX_BODY_BYTES = 1_000_000 + class FastApiSseTransport(SseServerTransport): + def __init__(self, endpoint: str, max_body_bytes: int = DEFAULT_MAX_BODY_BYTES): + super().__init__(endpoint) + + max_body_bytes_int = int(max_body_bytes) + if max_body_bytes_int <= 0: + raise ValueError("max_body_bytes must be positive") + + self._max_body_bytes = max_body_bytes_int + + async def _read_body_with_limit(self, request: Request) -> bytes: + content_length = request.headers.get("content-length") + if content_length is not None: + try: + if int(content_length) > self._max_body_bytes: + raise HTTPException(status_code=413, detail="Payload too large") + except ValueError: + # Ignore invalid Content-Length and fall back to streaming limit. + pass + + out = bytearray() + async for chunk in request.stream(): + out.extend(chunk) + if len(out) > self._max_body_bytes: + raise HTTPException(status_code=413, detail="Payload too large") + return bytes(out) + async def handle_fastapi_post_message(self, request: Request) -> Response: """ A reimplementation of the handle_post_message method of SseServerTransport @@ -53,8 +81,8 @@ async def handle_fastapi_post_message(self, request: Request) -> Response: logger.warning(f"Could not find session for ID: {session_id}") raise HTTPException(status_code=404, detail="Could not find session") - body = await request.body() - logger.debug(f"Received JSON: {body.decode()}") + body = await self._read_body_with_limit(request) + logger.debug(f"Received JSON: {body.decode(errors='replace')}") try: message = JSONRPCMessage.model_validate_json(body) diff --git a/tests/test_sse_mock_transport.py b/tests/test_sse_mock_transport.py index e833e5d..dde4bdd 100644 --- a/tests/test_sse_mock_transport.py +++ b/tests/test_sse_mock_transport.py @@ -11,6 +11,10 @@ from mcp.types import JSONRPCMessage, JSONRPCError +async def _bytes_stream(data: bytes): + yield data + + @pytest.fixture def mock_transport() -> FastApiSseTransport: # Initialize transport with a mock endpoint @@ -88,7 +92,8 @@ async def test_handle_post_message_validation_error( # Create a mock request with valid session_id but invalid body mock_request = MagicMock(spec=Request) mock_request.query_params = {"session_id": valid_session_id.hex} - mock_request.body = AsyncMock(return_value=b'{"invalid": "json"}') + mock_request.headers = {} + mock_request.stream = lambda: _bytes_stream(b'{"invalid": "json"}') # Mock BackgroundTasks with patch("fastapi_mcp.transport.sse.BackgroundTasks") as MockBackgroundTasks: @@ -119,7 +124,8 @@ async def test_handle_post_message_general_exception( # Instead of mocking the body method to raise an exception, # we'll patch the body method to return a normal value and then # patch JSONRPCMessage.model_validate_json to raise the exception - mock_request.body = AsyncMock(return_value=b'{"jsonrpc": "2.0", "method": "test", "id": "1"}') + mock_request.headers = {} + mock_request.stream = lambda: _bytes_stream(b'{"jsonrpc": "2.0", "method": "test", "id": "1"}') # Mock the model_validate_json method to raise an Exception with patch("mcp.types.JSONRPCMessage.model_validate_json", side_effect=Exception("Test exception")): @@ -131,6 +137,22 @@ async def test_handle_post_message_general_exception( assert "Invalid request body" in excinfo.value.detail +@pytest.mark.anyio +async def test_handle_post_message_payload_too_large(valid_session_id: UUID, mock_writer: AsyncMock) -> None: + transport = FastApiSseTransport("/messages", max_body_bytes=10) + transport._read_stream_writers = {valid_session_id: mock_writer} + + mock_request = MagicMock(spec=Request) + mock_request.query_params = {"session_id": valid_session_id.hex} + mock_request.headers = {} + mock_request.stream = lambda: _bytes_stream(b"x" * 11) + + with pytest.raises(HTTPException) as excinfo: + await transport.handle_fastapi_post_message(mock_request) + + assert excinfo.value.status_code == 413 + + @pytest.mark.anyio async def test_send_message_safely_with_validation_error( mock_transport: FastApiSseTransport, mock_writer: AsyncMock