Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions tests/client/test_http_unicode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from tests.test_helpers import get_worker_specific_port

# Test constants with various Unicode characters
UNICODE_TEST_STRINGS = {
Expand Down Expand Up @@ -145,11 +146,9 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:


@pytest.fixture
def unicode_server_port() -> int:
def unicode_server_port(worker_id: str) -> int:
"""Find an available port for the Unicode test server."""
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
return get_worker_specific_port(worker_id)


@pytest.fixture
Expand Down
7 changes: 3 additions & 4 deletions tests/client/test_notification_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.session import RequestResponder
from mcp.types import ClientNotification, RootsListChangedNotification
from tests.test_helpers import get_worker_specific_port


def create_non_sdk_server_app() -> Starlette:
Expand Down Expand Up @@ -81,11 +82,9 @@ def run_non_sdk_server(port: int) -> None:


@pytest.fixture
def non_sdk_server_port() -> int:
def non_sdk_server_port(worker_id: str) -> int:
"""Get an available port for the test server."""
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
return get_worker_specific_port(worker_id)


@pytest.fixture
Expand Down
20 changes: 15 additions & 5 deletions tests/server/fastmcp/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
TextResourceContents,
ToolListChangedNotification,
)
from tests.test_helpers import get_worker_specific_port


class NotificationCollector:
Expand Down Expand Up @@ -88,11 +89,20 @@ async def handle_generic_notification(

# Common fixtures
@pytest.fixture
def server_port() -> int:
"""Get a free port for testing."""
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def server_port(worker_id: str) -> int:
"""Get a free port for testing with worker-specific ranges.

Uses worker-specific port ranges to prevent port conflicts when running
tests in parallel with pytest-xdist. Each worker gets a dedicated range
of ports, eliminating race conditions.

Args:
worker_id: pytest-xdist worker ID (injected by pytest)

Returns:
An available port in this worker's range
"""
return get_worker_specific_port(worker_id)


@pytest.fixture
Expand Down
9 changes: 3 additions & 6 deletions tests/server/test_sse_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging
import multiprocessing
import socket

import httpx
import pytest
Expand All @@ -16,17 +15,15 @@
from mcp.server.sse import SseServerTransport
from mcp.server.transport_security import TransportSecuritySettings
from mcp.types import Tool
from tests.test_helpers import wait_for_server
from tests.test_helpers import get_worker_specific_port, wait_for_server

logger = logging.getLogger(__name__)
SERVER_NAME = "test_sse_security_server"


@pytest.fixture
def server_port() -> int:
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def server_port(worker_id: str) -> int:
return get_worker_specific_port(worker_id)


@pytest.fixture
Expand Down
9 changes: 3 additions & 6 deletions tests/server/test_streamable_http_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging
import multiprocessing
import socket
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

Expand All @@ -17,17 +16,15 @@
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.server.transport_security import TransportSecuritySettings
from mcp.types import Tool
from tests.test_helpers import wait_for_server
from tests.test_helpers import get_worker_specific_port, wait_for_server

logger = logging.getLogger(__name__)
SERVER_NAME = "test_streamable_http_security_server"


@pytest.fixture
def server_port() -> int:
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def server_port(worker_id: str) -> int:
return get_worker_specific_port(worker_id)


@pytest.fixture
Expand Down
8 changes: 3 additions & 5 deletions tests/shared/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,14 @@
TextResourceContents,
Tool,
)
from tests.test_helpers import wait_for_server
from tests.test_helpers import get_worker_specific_port, wait_for_server

SERVER_NAME = "test_server_for_SSE"


@pytest.fixture
def server_port() -> int:
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def server_port(worker_id: str) -> int:
return get_worker_specific_port(worker_id)


@pytest.fixture
Expand Down
21 changes: 7 additions & 14 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import json
import multiprocessing
import socket
from collections.abc import Generator
from typing import Any

Expand Down Expand Up @@ -42,7 +41,7 @@
from mcp.shared.message import ClientMessageMetadata
from mcp.shared.session import RequestResponder
from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool
from tests.test_helpers import wait_for_server
from tests.test_helpers import get_worker_specific_port, wait_for_server

# Test constants
SERVER_NAME = "test_streamable_http_server"
Expand Down Expand Up @@ -322,19 +321,15 @@ def run_server(port: int, is_json_response_enabled: bool = False, event_store: E

# Test fixtures - using same approach as SSE tests
@pytest.fixture
def basic_server_port() -> int:
def basic_server_port(worker_id: str) -> int:
"""Find an available port for the basic server."""
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
return get_worker_specific_port(worker_id)


@pytest.fixture
def json_server_port() -> int:
def json_server_port(worker_id: str) -> int:
"""Find an available port for the JSON response server."""
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
return get_worker_specific_port(worker_id)


@pytest.fixture
Expand All @@ -360,11 +355,9 @@ def event_store() -> SimpleEventStore:


@pytest.fixture
def event_server_port() -> int:
def event_server_port(worker_id: str) -> int:
"""Find an available port for the event store server."""
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
return get_worker_specific_port(worker_id)


@pytest.fixture
Expand Down
9 changes: 3 additions & 6 deletions tests/shared/test_ws.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import multiprocessing
import socket
import time
from collections.abc import AsyncGenerator, Generator
from typing import Any
Expand All @@ -26,16 +25,14 @@
TextResourceContents,
Tool,
)
from tests.test_helpers import wait_for_server
from tests.test_helpers import get_worker_specific_port, wait_for_server

SERVER_NAME = "test_server_for_WS"


@pytest.fixture
def server_port() -> int:
with socket.socket() as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def server_port(worker_id: str) -> int:
return get_worker_specific_port(worker_id)


@pytest.fixture
Expand Down
116 changes: 116 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Common test utilities for MCP server tests."""

import os
import socket
import time

Expand Down Expand Up @@ -29,3 +30,118 @@ def wait_for_server(port: int, timeout: float = 5.0) -> None:
# Server not ready yet, retry quickly
time.sleep(0.01)
raise TimeoutError(f"Server on port {port} did not start within {timeout} seconds")


def parse_worker_index(worker_id: str) -> int:
"""Parse worker index from pytest-xdist worker ID.

Extracts the numeric worker index from worker_id strings. Handles standard
formats ('master', 'gwN') with fallback for unexpected formats.

Args:
worker_id: pytest-xdist worker ID string (e.g., 'master', 'gw0', 'gw1')

Returns:
Worker index: 0 for 'master', N for 'gwN', hash-based fallback otherwise

Examples:
>>> parse_worker_index('master')
0
>>> parse_worker_index('gw0')
0
>>> parse_worker_index('gw5')
5
>>> parse_worker_index('unexpected_format') # Returns consistent hash-based value
42 # (example - actual value depends on hash)
"""
if worker_id == "master":
return 0

try:
# Try to extract number from 'gwN' format
return int(worker_id.replace("gw", ""))
except (ValueError, AttributeError):
# Fallback: if parsing fails, use hash of worker_id to avoid collisions
# Modulo 100 to keep worker indices reasonable
return abs(hash(worker_id)) % 100


def calculate_port_range(
worker_index: int, worker_count: int, base_port: int = 40000, total_ports: int = 20000
) -> tuple[int, int]:
"""Calculate non-overlapping port range for a worker.

Divides the total port range equally among workers, ensuring each worker
gets an exclusive range. Guarantees minimum of 100 ports per worker.

Args:
worker_index: Zero-based worker index
worker_count: Total number of workers in the test session
base_port: Starting port of the total range (default: 40000)
total_ports: Total number of ports available (default: 20000)

Returns:
Tuple of (start_port, end_port) where end_port is exclusive

Examples:
>>> calculate_port_range(0, 4) # 4 workers, first worker
(40000, 45000)
>>> calculate_port_range(1, 4) # 4 workers, second worker
(45000, 50000)
>>> calculate_port_range(0, 1) # Single worker gets all ports
(40000, 60000)
"""
# Calculate ports per worker (minimum 100 ports per worker)
ports_per_worker = max(100, total_ports // worker_count)

# Calculate this worker's port range
worker_base_port = base_port + (worker_index * ports_per_worker)
worker_max_port = min(worker_base_port + ports_per_worker, base_port + total_ports)

return worker_base_port, worker_max_port


def get_worker_specific_port(worker_id: str) -> int:
"""Get a free port specific to this pytest-xdist worker.

Allocates non-overlapping port ranges to each worker to prevent port conflicts
when running tests in parallel. This eliminates race conditions where multiple
workers try to bind to the same port.

Args:
worker_id: pytest-xdist worker ID string (e.g., 'master', 'gw0', 'gw1')

Returns:
An available port in this worker's range

Raises:
RuntimeError: If no available ports found in the worker's range
"""
# Parse worker index from worker_id
worker_index = parse_worker_index(worker_id)

# Get total number of workers from environment variable
worker_count = 1
worker_count_str = os.environ.get("PYTEST_XDIST_WORKER_COUNT")
if worker_count_str:
try:
worker_count = int(worker_count_str)
except ValueError:
# Fallback to single worker if parsing fails
worker_count = 1

# Calculate this worker's port range
worker_base_port, worker_max_port = calculate_port_range(worker_index, worker_count)

# Try to find an available port in this worker's range
for port in range(worker_base_port, worker_max_port):
try:
with socket.socket() as s:
s.bind(("127.0.0.1", port))
# Port is available, return it immediately
return port
except OSError:
# Port in use, try next one
continue

raise RuntimeError(f"No available ports in range {worker_base_port}-{worker_max_port - 1} for worker {worker_id}")
Loading