diff --git a/docs/development/tests.mdx b/docs/development/tests.mdx index 7263d7dea..d317bb6dc 100644 --- a/docs/development/tests.mdx +++ b/docs/development/tests.mdx @@ -297,7 +297,53 @@ async def test_database_tool(): ### Testing Network Transports -While in-memory testing covers most unit testing needs, you'll occasionally need to test actual network transports. Use the `run_server_in_process` utility to spawn a server in a separate process for testing: +While in-memory testing covers most unit testing needs, you'll occasionally need to test actual network transports like HTTP or SSE. FastMCP provides two approaches: in-process async servers using AnyIO task groups (preferred), and separate subprocess servers (for special cases). + +#### In-Process Network Testing (Preferred) + +For most network transport tests, use `run_server_async` with AnyIO task groups. This runs the server as a task in the same process, providing fast, deterministic tests with full debugger support: + +```python +import pytest +from anyio.abc import TaskGroup +from fastmcp import FastMCP, Client +from fastmcp.client.transports import StreamableHttpTransport +from fastmcp.utilities.tests import run_server_async + +def create_test_server() -> FastMCP: + """Create a test server instance.""" + server = FastMCP("TestServer") + + @server.tool + def greet(name: str) -> str: + return f"Hello, {name}!" + + return server + +@pytest.fixture +async def http_server(task_group: TaskGroup) -> str: + """Start server in-process using task group.""" + server = create_test_server() + url = await run_server_async(task_group, server, transport="http") + return url + +async def test_http_transport(http_server: str): + """Test actual HTTP transport behavior.""" + async with Client( + transport=StreamableHttpTransport(http_server) + ) as client: + result = await client.ping() + assert result is True + + greeting = await client.call_tool("greet", {"name": "World"}) + assert greeting.data == "Hello, World!" +``` + +The `task_group` fixture is provided globally by `conftest.py` and automatically handles server lifecycle and cleanup. This approach is faster than subprocess-based testing and provides better error messages. + +#### Subprocess Testing (Special Cases) + +For tests that require complete process isolation (like STDIO transport or testing subprocess behavior), use `run_server_in_process`: ```python import pytest @@ -328,12 +374,9 @@ async def test_http_transport(http_server: str): ) as client: result = await client.ping() assert result is True - - greeting = await client.call_tool("greet", {"name": "World"}) - assert greeting.data == "Hello, World!" ``` -The `run_server_in_process` utility handles server lifecycle, port allocation, and cleanup automatically. This pattern is essential for testing transport-specific behavior like timeouts, headers, and authentication. Note that FastMCP often uses the `client_process` marker to isolate tests that spawn processes, as they can create contention in CI. +The `run_server_in_process` utility handles server lifecycle, port allocation, and cleanup automatically. Use this only when subprocess isolation is truly necessary, as it's slower and harder to debug than in-process testing. FastMCP uses the `client_process` marker to isolate these tests in CI. ### Documentation Testing diff --git a/pyproject.toml b/pyproject.toml index ac755433f..d98a1f83c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,6 @@ dev = [ "pyinstrument>=5.0.2", "pyperclip>=1.9.0", "pytest>=8.3.3", - "pytest-asyncio>=0.23.5", "pytest-cov>=6.1.1", "pytest-env>=1.1.5", "pytest-flakefinder", @@ -98,9 +97,7 @@ fallback-version = "0.0.0" [tool.pytest.ini_options] -asyncio_mode = "auto" -asyncio_default_fixture_loop_scope = "session" -asyncio_default_test_loop_scope = "session" +anyio_mode = "auto" # filterwarnings = ["error::DeprecationWarning"] timeout = 5 env = [ diff --git a/src/fastmcp/server/context.py b/src/fastmcp/server/context.py index 8b46b7a0d..9c3feaf07 100644 --- a/src/fastmcp/server/context.py +++ b/src/fastmcp/server/context.py @@ -6,7 +6,6 @@ import logging import warnings import weakref -from asyncio.locks import Lock from collections.abc import Generator, Mapping, Sequence from contextlib import contextmanager from contextvars import ContextVar, Token @@ -15,6 +14,7 @@ from logging import Logger from typing import Any, Literal, cast, get_origin, overload +import anyio from mcp import LoggingLevel, ServerSession from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.lowlevel.server import request_ctx @@ -61,7 +61,7 @@ T = TypeVar("T", default=Any) _current_context: ContextVar[Context | None] = ContextVar("context", default=None) # type: ignore[assignment] -_flush_lock: Lock = asyncio.Lock() +_flush_lock = anyio.Lock() @dataclass diff --git a/src/fastmcp/server/middleware/rate_limiting.py b/src/fastmcp/server/middleware/rate_limiting.py index 42a0533f7..703fd393a 100644 --- a/src/fastmcp/server/middleware/rate_limiting.py +++ b/src/fastmcp/server/middleware/rate_limiting.py @@ -1,11 +1,11 @@ """Rate limiting middleware for protecting FastMCP servers from abuse.""" -import asyncio import time from collections import defaultdict, deque from collections.abc import Callable from typing import Any +import anyio from mcp import McpError from mcp.types import ErrorData @@ -33,7 +33,7 @@ def __init__(self, capacity: int, refill_rate: float): self.refill_rate = refill_rate self.tokens = capacity self.last_refill = time.time() - self._lock = asyncio.Lock() + self._lock = anyio.Lock() async def consume(self, tokens: int = 1) -> bool: """Try to consume tokens from the bucket. @@ -71,7 +71,7 @@ def __init__(self, max_requests: int, window_seconds: int): self.max_requests = max_requests self.window_seconds = window_seconds self.requests = deque() - self._lock = asyncio.Lock() + self._lock = anyio.Lock() async def is_allowed(self) -> bool: """Check if a request is allowed.""" diff --git a/src/fastmcp/utilities/tests.py b/src/fastmcp/utilities/tests.py index ce175e221..9b1dbfb92 100644 --- a/src/fastmcp/utilities/tests.py +++ b/src/fastmcp/utilities/tests.py @@ -140,6 +140,85 @@ def run_server_in_process( raise RuntimeError("Server process failed to terminate even after kill") +async def run_server_async( + task_group, + server: FastMCP, + port: int | None = None, + transport: Literal["http", "streamable-http", "sse"] = "http", + path: str = "/mcp", + host: str = "127.0.0.1", +) -> str: + """ + Start a FastMCP server in an AnyIO task group for in-process async testing. + + This is the recommended way to test FastMCP servers. It runs the server + as an async task in the same process, eliminating subprocess coordination, + sleeps, and cleanup issues. + + Args: + task_group: AnyIO task group to run the server in + server: FastMCP server instance + port: Port to bind to (default: find available port) + transport: Transport type ("http", "streamable-http", or "sse") + path: URL path for the server (default: "/mcp") + host: Host to bind to (default: "127.0.0.1") + + Returns: + Server URL string + + Example: + ```python + import anyio + import pytest + from anyio.abc import TaskGroup + from fastmcp import FastMCP, Client + from fastmcp.client.transports import StreamableHttpTransport + from fastmcp.utilities.tests import run_server_async + + @pytest.fixture + async def server(task_group: TaskGroup): + mcp = FastMCP("test") + + @mcp.tool() + def greet(name: str) -> str: + return f"Hello, {name}!" + + url = await run_server_async(task_group, mcp) + return url + + async def test_greet(server: str): + async with Client(StreamableHttpTransport(server)) as client: + result = await client.call_tool("greet", {"name": "World"}) + assert result.content[0].text == "Hello, World!" + ``` + """ + from functools import partial + + import anyio + + if port is None: + port = find_available_port() + + # Wait a tiny bit for the port to be released if it was just used + await anyio.sleep(0.01) + + task_group.start_soon( + partial( + server.run_http_async, + host=host, + port=port, + transport=transport, + path=path, + show_banner=False, + ) + ) + + # Give the server a moment to start + await anyio.sleep(0.1) + + return f"http://{host}:{port}{path}" + + @contextmanager def caplog_for_fastmcp(caplog): """Context manager to capture logs from FastMCP loggers even when propagation is disabled.""" diff --git a/tests/cli/test_mcp_server_config_integration.py b/tests/cli/test_mcp_server_config_integration.py index 219532b53..483c9a450 100644 --- a/tests/cli/test_mcp_server_config_integration.py +++ b/tests/cli/test_mcp_server_config_integration.py @@ -84,7 +84,6 @@ def test_detect_test_fastmcp_json(self, tmp_path): class TestConfigWithClient: """Test fastmcp.json configuration with client connections.""" - @pytest.mark.asyncio async def test_config_server_with_client(self, server_with_config): """Test that a server loaded from config works with a client.""" # Load the config diff --git a/tests/cli/test_server_args.py b/tests/cli/test_server_args.py index 8f1782dbf..a0962ef0c 100644 --- a/tests/cli/test_server_args.py +++ b/tests/cli/test_server_args.py @@ -11,7 +11,6 @@ class TestServerArguments: """Test passing arguments to servers.""" - @pytest.mark.asyncio async def test_server_with_argparse(self, tmp_path): """Test a server that uses argparse with command line arguments.""" server_file = tmp_path / "argparse_server.py" @@ -53,7 +52,6 @@ def get_config() -> dict: tools = await server.get_tools() assert "get_config" in tools - @pytest.mark.asyncio async def test_server_with_no_args(self, tmp_path): """Test a server that uses argparse with no arguments (defaults).""" server_file = tmp_path / "default_server.py" @@ -79,7 +77,6 @@ async def test_server_with_no_args(self, tmp_path): assert server.name == "DefaultName" - @pytest.mark.asyncio async def test_server_with_sys_argv_access(self, tmp_path): """Test a server that directly accesses sys.argv.""" server_file = tmp_path / "sysargv_server.py" @@ -112,7 +109,6 @@ async def test_server_with_sys_argv_access(self, tmp_path): assert server.name == "DirectServer" - @pytest.mark.asyncio async def test_config_server_example(self): """Test the actual config_server.py example.""" # Find the examples directory diff --git a/tests/client/auth/test_oauth_client.py b/tests/client/auth/test_oauth_client.py index a689da36f..28de46872 100644 --- a/tests/client/auth/test_oauth_client.py +++ b/tests/client/auth/test_oauth_client.py @@ -1,15 +1,16 @@ -from collections.abc import Generator from urllib.parse import urlparse import httpx import pytest +from anyio.abc import TaskGroup from fastmcp.client import Client from fastmcp.client.transports import StreamableHttpTransport from fastmcp.server.auth.auth import ClientRegistrationOptions from fastmcp.server.auth.providers.in_memory import InMemoryOAuthProvider from fastmcp.server.server import FastMCP -from fastmcp.utilities.tests import HeadlessOAuth, run_server_in_process +from fastmcp.utilities.http import find_available_port +from fastmcp.utilities.tests import HeadlessOAuth, run_server_async def fastmcp_server(issuer_url: str): @@ -35,31 +36,27 @@ def get_test_resource() -> str: return server -def run_server(host: str, port: int, **kwargs) -> None: - fastmcp_server(f"http://{host}:{port}").run(host=host, port=port, **kwargs) - - @pytest.fixture -def streamable_http_server() -> Generator[str, None, None]: - with run_server_in_process(run_server, transport="http") as url: - yield f"{url}/mcp" +async def streamable_http_server(task_group: TaskGroup): + """Start OAuth-enabled server.""" + port = find_available_port() + server = fastmcp_server(f"http://127.0.0.1:{port}") + url = await run_server_async(task_group, server, port=port, transport="http") + return url -@pytest.fixture() +@pytest.fixture def client_unauthorized(streamable_http_server: str) -> Client: return Client(transport=StreamableHttpTransport(streamable_http_server)) -@pytest.fixture() -def client_with_headless_oauth( - streamable_http_server: str, -) -> Generator[Client, None, None]: +@pytest.fixture +def client_with_headless_oauth(streamable_http_server: str) -> Client: """Client with headless OAuth that bypasses browser interaction.""" - client = Client( + return Client( transport=StreamableHttpTransport(streamable_http_server), auth=HeadlessOAuth(mcp_url=streamable_http_server), ) - yield client async def test_unauthorized(client_unauthorized: Client): diff --git a/tests/client/test_openapi_experimental.py b/tests/client/test_openapi_experimental.py index c4b0636a9..e59c2f3ba 100644 --- a/tests/client/test_openapi_experimental.py +++ b/tests/client/test_openapi_experimental.py @@ -1,19 +1,17 @@ import json -from collections.abc import Generator import pytest +from anyio.abc import TaskGroup from fastapi import FastAPI, Request -import fastmcp from fastmcp import Client, FastMCP from fastmcp.client.transports import SSETransport, StreamableHttpTransport from fastmcp.experimental.server.openapi import MCPType, RouteMap -from fastmcp.utilities.tests import run_server_in_process +from fastmcp.utilities.tests import run_server_async, temporary_settings -def fastmcp_server_for_headers() -> FastMCP: - fastmcp.settings.experimental.enable_new_openapi_parser = True - +def create_fastmcp_server_for_headers() -> FastMCP: + """Create a FastMCP server from FastAPI app with experimental parser.""" app = FastAPI() @app.get("/headers") @@ -46,35 +44,30 @@ def post_headers(request: Request): return mcp -def run_server(host: str, port: int, **kwargs) -> None: - fastmcp_server_for_headers().run(host=host, port=port, **kwargs) - - -def run_proxy_server(host: str, port: int, shttp_url: str, **kwargs) -> None: - app = FastMCP.as_proxy(StreamableHttpTransport(shttp_url)) - app.run(host=host, port=port, **kwargs) - - @pytest.fixture -def shttp_server() -> Generator[str, None, None]: - with run_server_in_process(run_server, transport="http") as url: - yield f"{url}/mcp" +async def shttp_server(task_group: TaskGroup): + """Start a test server with StreamableHttp transport.""" + with temporary_settings(experimental__enable_new_openapi_parser=True): + server = create_fastmcp_server_for_headers() + url = await run_server_async(task_group, server, transport="http") + return url @pytest.fixture -def sse_server() -> Generator[str, None, None]: - with run_server_in_process(run_server, transport="sse") as url: - yield f"{url}/sse" +async def sse_server(task_group: TaskGroup): + """Start a test server with SSE transport.""" + with temporary_settings(experimental__enable_new_openapi_parser=True): + server = create_fastmcp_server_for_headers() + url = await run_server_async(task_group, server, transport="sse") + return url @pytest.fixture -def proxy_server(shttp_server: str) -> Generator[str, None, None]: - with run_server_in_process( - run_proxy_server, - shttp_url=shttp_server, - transport="http", - ) as url: - yield f"{url}/mcp" +async def proxy_server(task_group: TaskGroup, shttp_server: str): + """Start a proxy server.""" + proxy = FastMCP.as_proxy(StreamableHttpTransport(shttp_server)) + url = await run_server_async(task_group, proxy, transport="http") + return url async def test_fastapi_client_headers_streamable_http_resource(shttp_server: str): diff --git a/tests/client/test_openapi_legacy.py b/tests/client/test_openapi_legacy.py index 07915c851..0029db8cf 100644 --- a/tests/client/test_openapi_legacy.py +++ b/tests/client/test_openapi_legacy.py @@ -1,13 +1,13 @@ import json -from collections.abc import Generator import pytest +from anyio.abc import TaskGroup from fastapi import FastAPI, Request from fastmcp import Client, FastMCP from fastmcp.client.transports import SSETransport, StreamableHttpTransport from fastmcp.server.openapi import MCPType, RouteMap -from fastmcp.utilities.tests import run_server_in_process +from fastmcp.utilities.tests import run_server_async def fastmcp_server_for_headers() -> FastMCP: @@ -43,35 +43,28 @@ def post_headers(request: Request): return mcp -def run_server(host: str, port: int, **kwargs) -> None: - fastmcp_server_for_headers().run(host=host, port=port, **kwargs) - - -def run_proxy_server(host: str, port: int, shttp_url: str, **kwargs) -> None: - app = FastMCP.as_proxy(StreamableHttpTransport(shttp_url)) - app.run(host=host, port=port, **kwargs) - - @pytest.fixture -def shttp_server() -> Generator[str, None, None]: - with run_server_in_process(run_server, transport="http") as url: - yield f"{url}/mcp" +async def shttp_server(task_group: TaskGroup): + """Start a test server with StreamableHttp transport.""" + server = fastmcp_server_for_headers() + url = await run_server_async(task_group, server, transport="http") + return url @pytest.fixture -def sse_server() -> Generator[str, None, None]: - with run_server_in_process(run_server, transport="sse") as url: - yield f"{url}/sse" +async def sse_server(task_group: TaskGroup): + """Start a test server with SSE transport.""" + server = fastmcp_server_for_headers() + url = await run_server_async(task_group, server, transport="sse") + return url @pytest.fixture -def proxy_server(shttp_server: str) -> Generator[str, None, None]: - with run_server_in_process( - run_proxy_server, - shttp_url=shttp_server, - transport="http", - ) as url: - yield f"{url}/mcp" +async def proxy_server(task_group: TaskGroup, shttp_server: str): + """Start a proxy server.""" + proxy = FastMCP.as_proxy(StreamableHttpTransport(shttp_server)) + url = await run_server_async(task_group, proxy, transport="http") + return url async def test_fastapi_client_headers_streamable_http_resource(shttp_server: str): diff --git a/tests/client/test_sse.py b/tests/client/test_sse.py index 818f233fd..c2628f959 100644 --- a/tests/client/test_sse.py +++ b/tests/client/test_sse.py @@ -1,32 +1,28 @@ import asyncio import json import sys -from collections.abc import Generator +import anyio import pytest -import uvicorn +from anyio.abc import TaskGroup from mcp import McpError -from starlette.applications import Starlette -from starlette.routing import Mount from fastmcp.client import Client from fastmcp.client.transports import SSETransport from fastmcp.server.dependencies import get_http_request from fastmcp.server.server import FastMCP -from fastmcp.utilities.tests import run_server_in_process +from fastmcp.utilities.tests import run_server_async -def fastmcp_server(): - """Fixture that creates a FastMCP server with tools, resources, and prompts.""" +def create_test_server() -> FastMCP: + """Create a FastMCP server with tools, resources, and prompts.""" server = FastMCP("TestServer") - # Add a tool @server.tool def greet(name: str) -> str: """Greet someone by name.""" return f"Hello, {name}!" - # Add a second tool @server.tool def add(a: int, b: int) -> int: """Add two numbers together.""" @@ -38,12 +34,10 @@ async def sleep(seconds: float) -> str: await asyncio.sleep(seconds) return f"Slept for {seconds} seconds" - # Add a resource @server.resource(uri="data://users") async def get_users(): return ["Alice", "Bob", "Charlie"] - # Add a resource template @server.resource(uri="data://user/{user_id}") async def get_user(user_id: str): return {"id": user_id, "name": f"User {user_id}", "active": True} @@ -51,10 +45,8 @@ async def get_user(user_id: str): @server.resource(uri="request://headers") async def get_headers() -> dict[str, str]: request = get_http_request() - return dict(request.headers) - # Add a prompt @server.prompt def welcome(name: str) -> str: """Example greeting prompt.""" @@ -63,14 +55,12 @@ def welcome(name: str) -> str: return server -def run_server(host: str, port: int, **kwargs) -> None: - fastmcp_server().run(host=host, port=port, **kwargs) - - -@pytest.fixture(autouse=True) -def sse_server() -> Generator[str, None, None]: - with run_server_in_process(run_server, transport="sse") as url: - yield f"{url}/sse" +@pytest.fixture +async def sse_server(task_group: TaskGroup): + """Start a test server with SSE transport and return its URL.""" + server = create_test_server() + url = await run_server_async(task_group, server, transport="sse") + return url async def test_ping(sse_server: str): @@ -91,36 +81,59 @@ async def test_http_headers(sse_server: str): assert json_result["x-demo-header"] == "ABC" -def run_nested_server(host: str, port: int) -> None: - fastmcp = fastmcp_server() - app = fastmcp.sse_app(path="/mcp/sse/", message_path="/mcp/messages") - mount = Starlette(routes=[Mount("/nest-inner", app=app)]) - mount2 = Starlette(routes=[Mount("/nest-outer", app=mount)]) - server = uvicorn.Server( - config=uvicorn.Config( - app=mount2, host=host, port=port, log_level="error", ws="websockets-sansio" - ) +@pytest.fixture +async def sse_server_custom_path(task_group: TaskGroup): + """Start a test server with SSE on a custom path.""" + server = create_test_server() + url = await run_server_async(task_group, server, transport="sse", path="/help") + return url + + +@pytest.fixture +async def nested_sse_server(task_group: TaskGroup): + """Test nested server mounts with SSE.""" + import uvicorn + from starlette.applications import Starlette + from starlette.routing import Mount + + from fastmcp.utilities.http import find_available_port + + server = create_test_server() + sse_app = server.sse_app(path="/mcp/sse/", message_path="/mcp/messages") + + # Nest the app under multiple mounts to test URL resolution + inner = Starlette(routes=[Mount("/nest-inner", app=sse_app)]) + outer = Starlette(routes=[Mount("/nest-outer", app=inner)]) + + # Run uvicorn with the nested ASGI app + port = find_available_port() + + config = uvicorn.Config( + app=outer, + host="127.0.0.1", + port=port, + log_level="critical", + ws="websockets-sansio", ) - server.run() + task_group.start_soon(uvicorn.Server(config).serve) + await anyio.sleep(0.1) -async def test_run_server_on_path(): - with run_server_in_process(run_server, transport="sse", path="/help") as url: - async with Client(transport=SSETransport(f"{url}/help")) as client: - result = await client.ping() - assert result is True + return f"http://127.0.0.1:{port}/nest-outer/nest-inner/mcp/sse/" -async def test_nested_sse_server_resolves_correctly(): - # tests patch for - # https://github.com/modelcontextprotocol/python-sdk/pull/659 +async def test_run_server_on_path(sse_server_custom_path: str): + """Test running server on a custom path.""" + async with Client(transport=SSETransport(sse_server_custom_path)) as client: + result = await client.ping() + assert result is True - with run_server_in_process(run_nested_server) as url: - async with Client( - transport=SSETransport(f"{url}/nest-outer/nest-inner/mcp/sse/") - ) as client: - result = await client.ping() - assert result is True + +async def test_nested_sse_server_resolves_correctly(nested_sse_server: str): + """Test patch for https://github.com/modelcontextprotocol/python-sdk/pull/659""" + async with Client(transport=SSETransport(nested_sse_server)) as client: + result = await client.ping() + assert result is True @pytest.mark.skipif( diff --git a/tests/client/test_streamable_http.py b/tests/client/test_streamable_http.py index 9c7896330..8df36b0b5 100644 --- a/tests/client/test_streamable_http.py +++ b/tests/client/test_streamable_http.py @@ -1,28 +1,23 @@ import asyncio import json import sys -from collections.abc import AsyncGenerator from unittest.mock import AsyncMock, call import pytest -import uvicorn from mcp import McpError -from starlette.applications import Starlette -from starlette.routing import Mount from fastmcp import Context from fastmcp.client import Client from fastmcp.client.transports import StreamableHttpTransport from fastmcp.server.dependencies import get_http_request from fastmcp.server.server import FastMCP -from fastmcp.utilities.tests import run_server_in_process +from fastmcp.utilities.tests import run_server_async -def fastmcp_server(): - """Fixture that creates a FastMCP server with tools, resources, and prompts.""" +def create_test_server() -> FastMCP: + """Create a FastMCP server with tools, resources, and prompts.""" server = FastMCP("TestServer") - # Add a tool @server.tool def greet(name: str) -> str: """Greet someone by name.""" @@ -38,7 +33,6 @@ async def elicit(ctx: Context) -> str: else: return "No name provided" - # Add a second tool @server.tool def add(a: int, b: int) -> int: """Add two numbers together.""" @@ -57,12 +51,10 @@ async def greet_with_progress(name: str, ctx: Context) -> str: await ctx.report_progress(0.75, 1.0, "Almost there!") return f"Hello, {name}!" - # Add a resource @server.resource(uri="data://users") async def get_users(): return ["Alice", "Bob", "Charlie"] - # Add a resource template @server.resource(uri="data://user/{user_id}") async def get_user(user_id: str): return {"id": user_id, "name": f"User {user_id}", "active": True} @@ -70,10 +62,8 @@ async def get_user(user_id: str): @server.resource(uri="request://headers") async def get_headers() -> dict[str, str]: request = get_http_request() - return dict(request.headers) - # Add a prompt @server.prompt def welcome(name: str) -> str: """Example greeting prompt.""" @@ -82,51 +72,72 @@ def welcome(name: str) -> str: return server -def run_server(host: str, port: int, stateless_http: bool = False, **kwargs) -> None: - server = fastmcp_server() - server.settings.stateless_http = stateless_http - server.run(host=host, port=port, **kwargs) +@pytest.fixture +async def streamable_http_server(request, task_group): + """Start a test server and return its URL.""" + import fastmcp + stateless_http = getattr(request, "param", False) + if stateless_http: + fastmcp.settings.stateless_http = True -def run_nested_server(host: str, port: int) -> None: - mcp_app = fastmcp_server().http_app(path="/final/mcp") + server = create_test_server() + url = await run_server_async(task_group, server) + yield url - mount = Starlette(routes=[Mount("/nest-inner", app=mcp_app)]) - mount2 = Starlette( - routes=[Mount("/nest-outer", app=mount)], - lifespan=mcp_app.lifespan, - ) - server = uvicorn.Server( - config=uvicorn.Config( - app=mount2, - host=host, - port=port, - log_level="error", - lifespan="on", - ws="websockets-sansio", - ) + if stateless_http: + fastmcp.settings.stateless_http = False + + +@pytest.fixture +async def streamable_http_server_with_streamable_http_alias(task_group): + """Test that the "streamable-http" transport alias works.""" + server = create_test_server() + url = await run_server_async(task_group, server, transport="streamable-http") + yield url + + +@pytest.fixture +async def nested_server(): + """Test nested server mounts with Starlette.""" + import uvicorn + from starlette.applications import Starlette + from starlette.routing import Mount + + from fastmcp.utilities.http import find_available_port + + server = create_test_server() + mcp_app = server.http_app(path="/final/mcp") + + # Nest the app under multiple mounts to test URL resolution + inner = Starlette(routes=[Mount("/nest-inner", app=mcp_app)]) + outer = Starlette( + routes=[Mount("/nest-outer", app=inner)], lifespan=mcp_app.lifespan ) - server.run() + # Run uvicorn with the nested ASGI app + port = find_available_port() -@pytest.fixture() -async def streamable_http_server( - request, -) -> AsyncGenerator[str, None]: - stateless_http = getattr(request, "param", False) - with run_server_in_process( - run_server, stateless_http=stateless_http, transport="http" - ) as url: - yield f"{url}/mcp" + config = uvicorn.Config( + app=outer, + host="127.0.0.1", + port=port, + log_level="critical", + ws="websockets-sansio", + ) + # Use the simple asyncio pattern + server_task = asyncio.create_task(uvicorn.Server(config).serve()) + await asyncio.sleep(0.1) -@pytest.fixture() -async def streamable_http_server_with_streamable_http_alias() -> AsyncGenerator[ - str, None -]: - """Test that the "streamable-http" transport alias works.""" - with run_server_in_process(run_server, transport="streamable-http") as url: - yield f"{url}/mcp" + yield f"http://127.0.0.1:{port}/nest-outer/nest-inner/final/mcp" + + # Cleanup + server_task.cancel() + try: + await server_task + except asyncio.CancelledError: + pass async def test_ping(streamable_http_server: str): @@ -203,16 +214,11 @@ async def elicitation_handler(message, response_type, params, ctx): assert result.data == "You said your name was: Alice!" -async def test_nested_streamable_http_server_resolves_correctly(): - # tests patch for - # https://github.com/modelcontextprotocol/python-sdk/pull/659 - - with run_server_in_process(run_nested_server) as url: - async with Client( - transport=StreamableHttpTransport(f"{url}/nest-outer/nest-inner/final/mcp") - ) as client: - result = await client.ping() - assert result is True +async def test_nested_streamable_http_server_resolves_correctly(nested_server: str): + """Test patch for https://github.com/modelcontextprotocol/python-sdk/pull/659""" + async with Client(transport=StreamableHttpTransport(nested_server)) as client: + result = await client.ping() + assert result is True @pytest.mark.skipif( diff --git a/tests/conftest.py b/tests/conftest.py index 0d7e7c090..f58dac3f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ from collections.abc import Callable from typing import Any +import anyio import pytest @@ -57,3 +58,11 @@ def get_port(): return port return get_port + + +@pytest.fixture +async def task_group(): + """Provide an anyio task group for running async servers in tests.""" + async with anyio.create_task_group() as tg: + yield tg + tg.cancel_scope.cancel() diff --git a/tests/experimental/openapi_parser/server/openapi/test_comprehensive.py b/tests/experimental/openapi_parser/server/openapi/test_comprehensive.py index c613cd0a0..2473914e6 100644 --- a/tests/experimental/openapi_parser/server/openapi/test_comprehensive.py +++ b/tests/experimental/openapi_parser/server/openapi/test_comprehensive.py @@ -352,7 +352,6 @@ def openapi_31_spec(self): }, } - @pytest.mark.asyncio async def test_comprehensive_server_initialization( self, comprehensive_openapi_spec ): @@ -387,7 +386,6 @@ async def test_comprehensive_server_initialization( assert tool_names == expected_operations - @pytest.mark.asyncio async def test_openapi_31_compatibility(self, openapi_31_spec): """Test that OpenAPI 3.1 specs work correctly.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: @@ -404,7 +402,6 @@ async def test_openapi_31_compatibility(self, openapi_31_spec): tool = tools[0] assert tool.name == "get_item_31" - @pytest.mark.asyncio async def test_parameter_collision_handling(self, comprehensive_openapi_spec): """Test that parameter collisions are handled correctly.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: @@ -436,7 +433,6 @@ async def test_parameter_collision_handling(self, comprehensive_openapi_spec): # Should have other parameters assert "data" in param_names - @pytest.mark.asyncio async def test_deep_object_parameters(self, comprehensive_openapi_spec): """Test deepObject parameter handling.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: @@ -462,7 +458,6 @@ async def test_deep_object_parameters(self, comprehensive_openapi_spec): filter_params = [name for name in properties.keys() if "filter" in name] assert len(filter_params) > 0 - @pytest.mark.asyncio async def test_request_building_and_execution(self, comprehensive_openapi_spec): """Test that requests are built and executed correctly.""" # Create a mock client that tracks requests @@ -503,7 +498,6 @@ async def test_request_building_and_execution(self, comprehensive_openapi_spec): assert "123" in str(request.url) assert "users/123" in str(request.url) - @pytest.mark.asyncio async def test_complex_request_with_body_and_parameters( self, comprehensive_openapi_spec ): @@ -554,7 +548,6 @@ async def test_complex_request_with_body_and_parameters( assert body_data["email"] == "new@example.com" assert body_data["age"] == 25 - @pytest.mark.asyncio async def test_query_parameters(self, comprehensive_openapi_spec): """Test query parameter handling.""" mock_client = Mock(spec=httpx.AsyncClient) @@ -593,7 +586,6 @@ async def test_query_parameters(self, comprehensive_openapi_spec): assert "offset=10" in url_str assert "sort=name" in url_str - @pytest.mark.asyncio async def test_error_handling(self, comprehensive_openapi_spec): """Test error handling for HTTP errors.""" mock_client = Mock(spec=httpx.AsyncClient) @@ -630,7 +622,6 @@ def raise_for_status(): error_message = str(exc_info.value) assert "404" in error_message - @pytest.mark.asyncio async def test_schema_refs_resolution(self, comprehensive_openapi_spec): """Test that schema references are resolved correctly.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: @@ -652,7 +643,6 @@ async def test_schema_refs_resolution(self, comprehensive_openapi_spec): assert "email" in properties # May also have id and age depending on implementation - @pytest.mark.asyncio async def test_optional_vs_required_parameters(self, comprehensive_openapi_spec): """Test handling of optional vs required parameters.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: @@ -680,7 +670,6 @@ async def test_optional_vs_required_parameters(self, comprehensive_openapi_spec) # Should have some required parameters assert len(search_schema["properties"]) > 0 - @pytest.mark.asyncio async def test_server_performance_no_latency(self, comprehensive_openapi_spec): """Test that server initialization is fast (no code generation latency).""" import time diff --git a/tests/experimental/openapi_parser/server/openapi/test_deepobject_style.py b/tests/experimental/openapi_parser/server/openapi/test_deepobject_style.py index a4d4be824..54327b1ba 100644 --- a/tests/experimental/openapi_parser/server/openapi/test_deepobject_style.py +++ b/tests/experimental/openapi_parser/server/openapi/test_deepobject_style.py @@ -174,7 +174,6 @@ def deepobject_spec(self): }, } - @pytest.mark.asyncio async def test_deepobject_style_parsing_from_spec(self, deepobject_spec): """Test that deepObject style parameters are correctly parsed from OpenAPI spec.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: @@ -208,7 +207,6 @@ async def test_deepobject_style_parsing_from_spec(self, deepobject_spec): # Should have some structure, exact format may vary assert target_param is not None - @pytest.mark.asyncio async def test_deepobject_explode_true_handling(self, deepobject_spec): """Test deepObject with explode=true parameter handling.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: @@ -234,7 +232,6 @@ async def test_deepobject_explode_true_handling(self, deepobject_spec): assert "type" in target_properties assert target_properties["type"]["enum"] == ["location", "organisation"] - @pytest.mark.asyncio async def test_deepobject_explode_false_handling(self, deepobject_spec): """Test deepObject with explode=false parameter handling.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: @@ -263,7 +260,6 @@ async def test_deepobject_explode_false_handling(self, deepobject_spec): if "type" in compact_param: assert compact_param["type"] == "object" - @pytest.mark.asyncio async def test_nested_object_structure_in_request_body(self, deepobject_spec): """Test nested object structures in request body are preserved.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: @@ -308,7 +304,6 @@ async def test_nested_object_structure_in_request_body(self, deepobject_spec): assert "push" in notif_props assert "frequency" in notif_props - @pytest.mark.asyncio async def test_deepobject_tool_functionality(self, deepobject_spec): """Test that tools with deepObject parameters maintain basic functionality.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: diff --git a/tests/experimental/openapi_parser/server/openapi/test_openapi_features.py b/tests/experimental/openapi_parser/server/openapi/test_openapi_features.py index c656c25d3..2dfac618e 100644 --- a/tests/experimental/openapi_parser/server/openapi/test_openapi_features.py +++ b/tests/experimental/openapi_parser/server/openapi/test_openapi_features.py @@ -124,7 +124,6 @@ def parameter_spec(self): }, } - @pytest.mark.asyncio async def test_query_parameters_in_tools(self, parameter_spec): """Test that query parameters are properly included in tool parameters.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: @@ -173,7 +172,6 @@ async def test_query_parameters_in_tools(self, parameter_spec): assert "query" in required assert "X-API-Key" in required - @pytest.mark.asyncio async def test_path_parameters_in_tools(self, parameter_spec): """Test that path parameters are properly included in tool parameters.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: @@ -279,7 +277,6 @@ def request_body_spec(self): }, } - @pytest.mark.asyncio async def test_request_body_properties_in_tool(self, request_body_spec): """Test that request body properties are included in tool parameters.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: @@ -381,7 +378,6 @@ def response_schema_spec(self): }, } - @pytest.mark.asyncio async def test_tool_has_output_schema(self, response_schema_spec): """Test that tools have output schemas from response definitions.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: diff --git a/tests/experimental/openapi_parser/server/openapi/test_parameter_collisions.py b/tests/experimental/openapi_parser/server/openapi/test_parameter_collisions.py index 572f3cb26..fb87ffb5a 100644 --- a/tests/experimental/openapi_parser/server/openapi/test_parameter_collisions.py +++ b/tests/experimental/openapi_parser/server/openapi/test_parameter_collisions.py @@ -118,7 +118,6 @@ def collision_spec(self): }, } - @pytest.mark.asyncio async def test_path_body_collision_handling(self, collision_spec): """Test that path and body parameters with same name are handled correctly.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: @@ -159,7 +158,6 @@ async def test_path_body_collision_handling(self, collision_spec): id_required = any("id" in req for req in required) assert id_required - @pytest.mark.asyncio async def test_query_header_collision_handling(self, collision_spec): """Test that query and header parameters with same name are handled correctly.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: @@ -190,7 +188,6 @@ async def test_query_header_collision_handling(self, collision_spec): query_required = any("query" in req for req in required) assert query_required - @pytest.mark.asyncio async def test_collision_resolution_maintains_functionality(self, collision_spec): """Test that collision resolution doesn't break basic tool functionality.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: diff --git a/tests/experimental/openapi_parser/server/openapi/test_server.py b/tests/experimental/openapi_parser/server/openapi/test_server.py index 01720bd47..6767e4e69 100644 --- a/tests/experimental/openapi_parser/server/openapi/test_server.py +++ b/tests/experimental/openapi_parser/server/openapi/test_server.py @@ -112,7 +112,6 @@ def test_server_initialization_with_custom_name(self, simple_openapi_spec): # Should use default name assert server.name == "OpenAPI FastMCP" - @pytest.mark.asyncio async def test_server_creates_tools_from_spec(self, simple_openapi_spec): """Test that server creates tools from OpenAPI spec.""" async with httpx.AsyncClient(base_url="https://api.example.com") as client: @@ -131,7 +130,6 @@ async def test_server_creates_tools_from_spec(self, simple_openapi_spec): assert "get_user" in tool_names assert "create_user" in tool_names - @pytest.mark.asyncio async def test_server_tool_execution_fallback_to_http(self, simple_openapi_spec): """Test tool execution falls back to HTTP when callables aren't available.""" # Use a mock client that will be used for HTTP fallback @@ -203,7 +201,6 @@ def test_server_with_empty_spec(self): assert hasattr(server, "_director") assert hasattr(server, "_spec") - @pytest.mark.asyncio async def test_clean_schema_output_no_unused_defs(self): """Test that unused schema definitions are removed from tool schemas.""" # Create a spec with unused HTTPValidationError-like definitions diff --git a/tests/integration_tests/auth/test_github_provider_integration.py b/tests/integration_tests/auth/test_github_provider_integration.py index 54d159fbd..141846c6f 100644 --- a/tests/integration_tests/auth/test_github_provider_integration.py +++ b/tests/integration_tests/auth/test_github_provider_integration.py @@ -12,16 +12,16 @@ import os import re -from collections.abc import Generator from urllib.parse import parse_qs, urlparse import httpx import pytest +from anyio.abc import TaskGroup from fastmcp import FastMCP from fastmcp.client import Client from fastmcp.server.auth.providers.github import GitHubProvider -from fastmcp.utilities.tests import HeadlessOAuth, run_server_in_process +from fastmcp.utilities.tests import HeadlessOAuth, run_server_async FASTMCP_TEST_AUTH_GITHUB_CLIENT_ID = os.getenv("FASTMCP_TEST_AUTH_GITHUB_CLIENT_ID") FASTMCP_TEST_AUTH_GITHUB_CLIENT_SECRET = os.getenv( @@ -36,7 +36,7 @@ ) -def create_github_server(host: str = "127.0.0.1", port: int = 9100, **kwargs) -> None: +def create_github_server(base_url: str) -> FastMCP: """Create FastMCP server with GitHub OAuth protection.""" assert FASTMCP_TEST_AUTH_GITHUB_CLIENT_ID is not None assert FASTMCP_TEST_AUTH_GITHUB_CLIENT_SECRET is not None @@ -45,7 +45,7 @@ def create_github_server(host: str = "127.0.0.1", port: int = 9100, **kwargs) -> auth = GitHubProvider( client_id=FASTMCP_TEST_AUTH_GITHUB_CLIENT_ID, client_secret=FASTMCP_TEST_AUTH_GITHUB_CLIENT_SECRET, - base_url=f"http://{host}:{port}", + base_url=base_url, ) # Create FastMCP server with GitHub authentication @@ -61,13 +61,10 @@ def get_user_info() -> str: """Returns user info from OAuth context.""" return "📝 GitHub OAuth user authenticated successfully" - # Run the server - server.run(host=host, port=port, **kwargs) + return server -def create_github_server_with_mock_callback( - host: str = "127.0.0.1", port: int = 9100, **kwargs -) -> None: +def create_github_server_with_mock_callback(base_url: str) -> FastMCP: """Create FastMCP server with GitHub OAuth that mocks the callback for testing.""" assert FASTMCP_TEST_AUTH_GITHUB_CLIENT_ID is not None assert FASTMCP_TEST_AUTH_GITHUB_CLIENT_SECRET is not None @@ -76,7 +73,7 @@ def create_github_server_with_mock_callback( auth = GitHubProvider( client_id=FASTMCP_TEST_AUTH_GITHUB_CLIENT_ID, client_secret=FASTMCP_TEST_AUTH_GITHUB_CLIENT_SECRET, - base_url=f"http://{host}:{port}", + base_url=base_url, ) # Mock the authorize method to return a fake code instead of redirecting to GitHub @@ -159,29 +156,25 @@ def get_user_info() -> str: """Returns user info from OAuth context.""" return "📝 GitHub OAuth user authenticated successfully" - # Run the server - server.run(host=host, port=port, **kwargs) - - -@pytest.fixture(scope="module") -def github_server() -> Generator[str, None, None]: - """Start GitHub OAuth server in background process on fixed port 9100.""" - with run_server_in_process( - create_github_server, transport="http", host="127.0.0.1", port=9100 - ) as url: - yield f"{url}/mcp" - - -@pytest.fixture(scope="module") -def github_server_with_mock() -> Generator[str, None, None]: - """Start GitHub OAuth server with mocked callback in background process on port 9101.""" - with run_server_in_process( - create_github_server_with_mock_callback, - transport="http", - host="127.0.0.1", - port=9101, - ) as url: - yield f"{url}/mcp" + return server + + +@pytest.fixture +async def github_server(task_group: TaskGroup) -> str: + """Start GitHub OAuth server with AnyIO task group on fixed port 9100.""" + base_url = "http://127.0.0.1:9100" + server = create_github_server(base_url) + url = await run_server_async(task_group, server, port=9100, transport="http") + return url + + +@pytest.fixture +async def github_server_with_mock(task_group: TaskGroup) -> str: + """Start GitHub OAuth server with mocked callback on port 9101.""" + base_url = "http://127.0.0.1:9101" + server = create_github_server_with_mock_callback(base_url) + url = await run_server_async(task_group, server, port=9101, transport="http") + return url @pytest.fixture diff --git a/tests/server/auth/providers/test_descope.py b/tests/server/auth/providers/test_descope.py index 3df78e052..0af08356d 100644 --- a/tests/server/auth/providers/test_descope.py +++ b/tests/server/auth/providers/test_descope.py @@ -1,16 +1,16 @@ """Tests for Descope OAuth provider.""" import os -from collections.abc import Generator from unittest.mock import patch import httpx import pytest +from anyio.abc import TaskGroup from fastmcp import Client, FastMCP from fastmcp.client.transports import StreamableHttpTransport from fastmcp.server.auth.providers.descope import DescopeProvider -from fastmcp.utilities.tests import HeadlessOAuth, run_server_in_process +from fastmcp.utilities.tests import HeadlessOAuth, run_server_async class TestDescopeProvider: @@ -118,7 +118,9 @@ def test_jwt_verifier_configured_correctly(self): assert provider.token_verifier.audience == "P2abc123" # type: ignore[attr-defined] -def run_mcp_server(host: str, port: int) -> None: +@pytest.fixture +async def mcp_server_url(task_group: TaskGroup): + """Start Descope server.""" mcp = FastMCP( auth=DescopeProvider( project_id="P2test123", @@ -131,25 +133,17 @@ def run_mcp_server(host: str, port: int) -> None: def add(a: int, b: int) -> int: return a + b - mcp.run(host=host, port=port, transport="http") + url = await run_server_async(task_group, mcp, transport="http") + return url @pytest.fixture -def mcp_server_url() -> Generator[str]: - with run_server_in_process(run_mcp_server) as url: - yield f"{url}/mcp" - - -@pytest.fixture() -def client_with_headless_oauth( - mcp_server_url: str, -) -> Generator[Client, None, None]: +def client_with_headless_oauth(mcp_server_url: str) -> Client: """Client with headless OAuth that bypasses browser interaction.""" - client = Client( + return Client( transport=StreamableHttpTransport(mcp_server_url), auth=HeadlessOAuth(mcp_url=mcp_server_url), ) - yield client class TestDescopeProviderIntegration: diff --git a/tests/server/auth/providers/test_github.py b/tests/server/auth/providers/test_github.py index 65ee66de0..45a343bd5 100644 --- a/tests/server/auth/providers/test_github.py +++ b/tests/server/auth/providers/test_github.py @@ -177,7 +177,6 @@ def test_init_defaults(self): ) # Parent TokenVerifier sets empty list as default assert verifier.timeout_seconds == 10 - @pytest.mark.asyncio async def test_verify_token_github_api_failure(self): """Test token verification when GitHub API returns error.""" verifier = GitHubTokenVerifier() @@ -196,7 +195,6 @@ async def test_verify_token_github_api_failure(self): result = await verifier.verify_token("invalid_token") assert result is None - @pytest.mark.asyncio async def test_verify_token_success(self): """Test successful token verification.""" from unittest.mock import AsyncMock diff --git a/tests/server/auth/providers/test_scalekit.py b/tests/server/auth/providers/test_scalekit.py index d7f7a6546..af4b34026 100644 --- a/tests/server/auth/providers/test_scalekit.py +++ b/tests/server/auth/providers/test_scalekit.py @@ -1,16 +1,16 @@ """Tests for Scalekit OAuth provider.""" import os -from collections.abc import Generator from unittest.mock import patch import httpx import pytest +from anyio.abc import TaskGroup from fastmcp import Client, FastMCP from fastmcp.client.transports import StreamableHttpTransport from fastmcp.server.auth.providers.scalekit import ScalekitProvider -from fastmcp.utilities.tests import HeadlessOAuth, run_server_in_process +from fastmcp.utilities.tests import HeadlessOAuth, run_server_async class TestScalekitProvider: @@ -109,7 +109,9 @@ def test_authorization_servers_configuration(self): ) -def run_mcp_server(host: str, port: int) -> None: +@pytest.fixture +async def mcp_server_url(task_group: TaskGroup): + """Start Scalekit server.""" mcp = FastMCP( auth=ScalekitProvider( environment_url="https://test-env.scalekit.com", @@ -123,25 +125,17 @@ def run_mcp_server(host: str, port: int) -> None: def add(a: int, b: int) -> int: return a + b - mcp.run(host=host, port=port, transport="http") + url = await run_server_async(task_group, mcp, transport="http") + return url @pytest.fixture -def mcp_server_url() -> Generator[str]: - with run_server_in_process(run_mcp_server) as url: - yield f"{url}/mcp" - - -@pytest.fixture() -def client_with_headless_oauth( - mcp_server_url: str, -) -> Generator[Client, None, None]: +def client_with_headless_oauth(mcp_server_url: str) -> Client: """Client with headless OAuth that bypasses browser interaction.""" - client = Client( + return Client( transport=StreamableHttpTransport(mcp_server_url), auth=HeadlessOAuth(mcp_url=mcp_server_url), ) - yield client class TestScalekitProviderIntegration: diff --git a/tests/server/auth/providers/test_workos.py b/tests/server/auth/providers/test_workos.py index b4d7f3c6f..e52ff2798 100644 --- a/tests/server/auth/providers/test_workos.py +++ b/tests/server/auth/providers/test_workos.py @@ -1,17 +1,17 @@ """Tests for WorkOS OAuth provider.""" import os -from collections.abc import Generator from unittest.mock import patch from urllib.parse import urlparse import httpx import pytest +from anyio.abc import TaskGroup from fastmcp import Client, FastMCP from fastmcp.client.transports import StreamableHttpTransport from fastmcp.server.auth.providers.workos import AuthKitProvider, WorkOSProvider -from fastmcp.utilities.tests import HeadlessOAuth, run_server_in_process +from fastmcp.utilities.tests import HeadlessOAuth, run_server_async class TestWorkOSProvider: @@ -157,7 +157,9 @@ def test_oauth_endpoints_configured_correctly(self): ) # WorkOS doesn't support revocation -def run_mcp_server(host: str, port: int) -> None: +@pytest.fixture +async def mcp_server_url(task_group: TaskGroup): + """Start AuthKit server.""" mcp = FastMCP( auth=AuthKitProvider( authkit_domain="https://respectful-lullaby-34-staging.authkit.app", @@ -169,25 +171,17 @@ def run_mcp_server(host: str, port: int) -> None: def add(a: int, b: int) -> int: return a + b - mcp.run(host=host, port=port, transport="http") + url = await run_server_async(task_group, mcp, transport="http") + return url @pytest.fixture -def mcp_server_url() -> Generator[str]: - with run_server_in_process(run_mcp_server) as url: - yield f"{url}/mcp" - - -@pytest.fixture() -def client_with_headless_oauth( - mcp_server_url: str, -) -> Generator[Client, None, None]: +def client_with_headless_oauth(mcp_server_url: str) -> Client: """Client with headless OAuth that bypasses browser interaction.""" - client = Client( + return Client( transport=StreamableHttpTransport(mcp_server_url), auth=HeadlessOAuth(mcp_url=mcp_server_url), ) - yield client class TestAuthKitProvider: diff --git a/tests/server/auth/test_jwt_provider.py b/tests/server/auth/test_jwt_provider.py index 7d9d9837b..0ea0352dc 100644 --- a/tests/server/auth/test_jwt_provider.py +++ b/tests/server/auth/test_jwt_provider.py @@ -1,14 +1,14 @@ -from collections.abc import Generator from typing import Any import httpx import pytest +from anyio.abc import TaskGroup from pytest_httpx import HTTPXMock from fastmcp import Client, FastMCP from fastmcp.client.auth.bearer import BearerAuth from fastmcp.server.auth.providers.jwt import JWKData, JWKSData, JWTVerifier, RSAKeyPair -from fastmcp.utilities.tests import run_server_in_process +from fastmcp.utilities.tests import run_server_async class SymmetricKeyHelper: @@ -111,13 +111,10 @@ def symmetric_provider(symmetric_key_helper: SymmetricKeyHelper) -> JWTVerifier: ) -def run_mcp_server( +def create_mcp_server( public_key: str, - host: str, - port: int, auth_kwargs: dict[str, Any] | None = None, - run_kwargs: dict[str, Any] | None = None, -) -> None: +) -> FastMCP: mcp = FastMCP( auth=JWTVerifier( public_key=public_key, @@ -129,21 +126,20 @@ def run_mcp_server( def add(a: int, b: int) -> int: return a + b - mcp.run(host=host, port=port, **run_kwargs or {}) + return mcp @pytest.fixture -def mcp_server_url(rsa_key_pair: RSAKeyPair) -> Generator[str]: - with run_server_in_process( - run_mcp_server, +async def mcp_server_url(task_group: TaskGroup, rsa_key_pair: RSAKeyPair) -> str: + server = create_mcp_server( public_key=rsa_key_pair.public_key, auth_kwargs=dict( issuer="https://test.example.com", audience="https://api.example.com", ), - run_kwargs=dict(transport="http"), - ) as url: - yield f"{url}/mcp" + ) + url = await run_server_async(task_group, server, transport="http") + return url class TestRSAKeyPair: @@ -1022,7 +1018,7 @@ async def test_token_with_bad_signature(self, mcp_server_url: str): assert "tools" not in locals() async def test_token_with_insufficient_scopes( - self, mcp_server_url: str, rsa_key_pair: RSAKeyPair + self, task_group: TaskGroup, rsa_key_pair: RSAKeyPair ): token = rsa_key_pair.create_token( subject="test-user", @@ -1031,25 +1027,24 @@ async def test_token_with_insufficient_scopes( scopes=["read"], ) - with run_server_in_process( - run_mcp_server, + server = create_mcp_server( public_key=rsa_key_pair.public_key, auth_kwargs=dict(required_scopes=["read", "write"]), - run_kwargs=dict(transport="http"), - ) as url: - mcp_server_url = f"{url}/mcp/" - with pytest.raises(httpx.HTTPStatusError) as exc_info: - async with Client(mcp_server_url, auth=BearerAuth(token)) as client: - tools = await client.list_tools() # noqa: F841 - # JWTVerifier returns 401 when verify_token returns None (invalid token) - # This is correct behavior - when TokenVerifier.verify_token returns None, - # it indicates the token is invalid (not just insufficient permissions) - assert isinstance(exc_info.value, httpx.HTTPStatusError) - assert exc_info.value.response.status_code == 401 - assert "tools" not in locals() + ) + mcp_server_url = await run_server_async(task_group, server, transport="http") + + with pytest.raises(httpx.HTTPStatusError) as exc_info: + async with Client(mcp_server_url, auth=BearerAuth(token)) as client: + tools = await client.list_tools() # noqa: F841 + # JWTVerifier returns 401 when verify_token returns None (invalid token) + # This is correct behavior - when TokenVerifier.verify_token returns None, + # it indicates the token is invalid (not just insufficient permissions) + assert isinstance(exc_info.value, httpx.HTTPStatusError) + assert exc_info.value.response.status_code == 401 + assert "tools" not in locals() async def test_token_with_sufficient_scopes( - self, mcp_server_url: str, rsa_key_pair: RSAKeyPair + self, task_group: TaskGroup, rsa_key_pair: RSAKeyPair ): token = rsa_key_pair.create_token( subject="test-user", @@ -1058,15 +1053,14 @@ async def test_token_with_sufficient_scopes( scopes=["read", "write"], ) - with run_server_in_process( - run_mcp_server, + server = create_mcp_server( public_key=rsa_key_pair.public_key, auth_kwargs=dict(required_scopes=["read", "write"]), - run_kwargs=dict(transport="http"), - ) as url: - mcp_server_url = f"{url}/mcp/" - async with Client(mcp_server_url, auth=BearerAuth(token)) as client: - tools = await client.list_tools() + ) + mcp_server_url = await run_server_async(task_group, server, transport="http") + + async with Client(mcp_server_url, auth=BearerAuth(token)) as client: + tools = await client.list_tools() assert tools diff --git a/tests/server/auth/test_oauth_proxy.py b/tests/server/auth/test_oauth_proxy.py index 7e526b52e..59cb0a38a 100644 --- a/tests/server/auth/test_oauth_proxy.py +++ b/tests/server/auth/test_oauth_proxy.py @@ -617,7 +617,6 @@ def test_token_auth_method_initialization(self, jwt_verifier): ) assert proxy_default._token_endpoint_auth_method is None - @pytest.mark.asyncio async def test_token_auth_method_passed_to_client(self, jwt_verifier): """Test that auth method is passed to AsyncOAuth2Client.""" proxy = OAuthProxy( @@ -738,7 +737,6 @@ async def test_token_auth_method_passed_to_client(self, jwt_verifier): class TestOAuthProxyE2E: """End-to-end tests using mock OAuth provider.""" - @pytest.mark.asyncio async def test_full_oauth_flow_with_mock_provider(self, mock_oauth_provider): """Test complete OAuth flow with mock provider.""" # Create proxy pointing to mock provider @@ -793,7 +791,6 @@ def protected_tool() -> str: # Transaction ID itself is used as upstream state parameter assert transaction.txn_id == txn_id - @pytest.mark.asyncio async def test_token_refresh_with_mock_provider(self, mock_oauth_provider): """Test token refresh flow with mock provider.""" proxy = OAuthProxy( @@ -906,7 +903,6 @@ async def mock_refresh(*args, **kwargs): assert len(result.access_token.split(".")) == 3 assert mock_oauth_provider.refresh_called - @pytest.mark.asyncio async def test_pkce_validation_with_mock_provider(self, mock_oauth_provider): """Test PKCE validation with mock provider.""" mock_oauth_provider.require_pkce = True @@ -1174,7 +1170,6 @@ async def test_multiple_extra_params(self, jwt_verifier): assert proxy._extra_authorize_params.get("prompt") == "consent" assert proxy._extra_authorize_params.get("max_age") == "3600" - @pytest.mark.asyncio async def test_token_endpoint_invalid_client_error(self, jwt_verifier): """Test that invalid client_id returns OAuth 2.1 compliant error response. diff --git a/tests/server/auth/test_oauth_proxy_redirect_validation.py b/tests/server/auth/test_oauth_proxy_redirect_validation.py index a4185538f..4ef54db67 100644 --- a/tests/server/auth/test_oauth_proxy_redirect_validation.py +++ b/tests/server/auth/test_oauth_proxy_redirect_validation.py @@ -145,7 +145,6 @@ def test_proxy_empty_list_validation(self): assert proxy._allowed_client_redirect_uris == [] - @pytest.mark.asyncio async def test_proxy_register_client_uses_patterns(self): """Test that registered clients use the configured patterns.""" custom_patterns = ["https://app.example.com/*"] @@ -178,7 +177,6 @@ async def test_proxy_register_client_uses_patterns(self): assert isinstance(registered, ProxyDCRClient) assert registered.allowed_redirect_uri_patterns == custom_patterns - @pytest.mark.asyncio async def test_proxy_unregistered_client_returns_none(self): """Test that unregistered clients return None.""" custom_patterns = ["http://localhost:*", "http://127.0.0.1:*"] diff --git a/tests/server/http/test_http_dependencies.py b/tests/server/http/test_http_dependencies.py index d0b3afb96..f1ec5b77f 100644 --- a/tests/server/http/test_http_dependencies.py +++ b/tests/server/http/test_http_dependencies.py @@ -1,13 +1,13 @@ import json -from collections.abc import Generator import pytest +from anyio.abc import TaskGroup from fastmcp.client import Client from fastmcp.client.transports import SSETransport, StreamableHttpTransport from fastmcp.server.dependencies import get_http_request from fastmcp.server.server import FastMCP -from fastmcp.utilities.tests import run_server_in_process +from fastmcp.utilities.tests import run_server_async def fastmcp_server(): @@ -38,20 +38,20 @@ def get_headers_prompt() -> str: return server -def run_server(host: str, port: int, **kwargs) -> None: - fastmcp_server().run(host=host, port=port, **kwargs) +@pytest.fixture +async def shttp_server(task_group: TaskGroup): + """Start a test server with StreamableHttp transport.""" + server = fastmcp_server() + url = await run_server_async(task_group, server, transport="http") + return url -@pytest.fixture(autouse=True) -def shttp_server() -> Generator[str, None, None]: - with run_server_in_process(run_server, transport="http") as url: - yield f"{url}/mcp" - - -@pytest.fixture(autouse=True) -def sse_server() -> Generator[str, None, None]: - with run_server_in_process(run_server, transport="sse") as url: - yield f"{url}/sse" +@pytest.fixture +async def sse_server(task_group: TaskGroup): + """Start a test server with SSE transport.""" + server = fastmcp_server() + url = await run_server_async(task_group, server, transport="sse") + return url async def test_http_headers_resource_shttp(shttp_server: str): diff --git a/tests/server/test_context.py b/tests/server/test_context.py index b0a7fc4c4..2a2e38e74 100644 --- a/tests/server/test_context.py +++ b/tests/server/test_context.py @@ -137,7 +137,6 @@ def test_session_id_without_http_headers(self, context): class TestContextState: """Test suite for Context state functionality.""" - @pytest.mark.asyncio async def test_context_state(self): """Test that state modifications in child contexts don't affect parent.""" mock_fastmcp = MagicMock() @@ -152,7 +151,6 @@ async def test_context_state(self): context.set_state("test1", "new_value") assert context.get_state("test1") == "new_value" - @pytest.mark.asyncio async def test_context_state_inheritance(self): """Test that child contexts inherit parent state.""" mock_fastmcp = MagicMock() diff --git a/uv.lock b/uv.lock index 447ba2eac..7c74dc05b 100644 --- a/uv.lock +++ b/uv.lock @@ -60,15 +60,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/58/cc6a08053f822f98f334d38a27687b69c6655fb05cd74a7a5e70a2aeed95/authlib-1.6.1-py2.py3-none-any.whl", hash = "sha256:e9d2031c34c6309373ab845afc24168fe9e93dc52d252631f52642f21f5ed06e", size = 239299, upload-time = "2025-07-20T07:38:39.259Z" }, ] -[[package]] -name = "backports-asyncio-runner" -version = "1.2.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, -] - [[package]] name = "beartype" version = "0.22.2" @@ -580,7 +571,6 @@ dev = [ { name = "pyinstrument" }, { name = "pyperclip" }, { name = "pytest" }, - { name = "pytest-asyncio" }, { name = "pytest-cov" }, { name = "pytest-env" }, { name = "pytest-flakefinder" }, @@ -625,7 +615,6 @@ dev = [ { name = "pyinstrument", specifier = ">=5.0.2" }, { name = "pyperclip", specifier = ">=1.9.0" }, { name = "pytest", specifier = ">=8.3.3" }, - { name = "pytest-asyncio", specifier = ">=0.23.5" }, { name = "pytest-cov", specifier = ">=6.1.1" }, { name = "pytest-env", specifier = ">=1.1.5" }, { name = "pytest-flakefinder" }, @@ -1589,19 +1578,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, ] -[[package]] -name = "pytest-asyncio" -version = "1.1.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" }, - { name = "pytest" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/4e/51/f8794af39eeb870e87a8c8068642fc07bce0c854d6865d7dd0f2a9d338c2/pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea", size = 46652, upload-time = "2025-07-16T04:29:26.393Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c7/9d/bf86eddabf8c6c9cb1ea9a869d6873b46f105a5d292d3a6f7071f5b07935/pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf", size = 15157, upload-time = "2025-07-16T04:29:24.929Z" }, -] - [[package]] name = "pytest-cov" version = "6.2.1"