Description
Summary
get_access_token() returns the bearer token from the session-creating request for the entire lifetime of a stateful streamable-HTTP session, regardless of what Authorization header later requests send.
Repro
```python
import multiprocessing
import socket
import time

import httpx
import pytest
import uvicorn
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.routing import Mount

from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamable_http_client
from mcp.server import Server, ServerRequestContext
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
from mcp.server.auth.provider import AccessToken
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.server.transport_security import TransportSecuritySettings
from mcp.types import (
    CallToolRequestParams,
    CallToolResult,
    ListToolsResult,
    PaginatedRequestParams,
    TextContent,
    Tool,
)


class _EchoTokenVerifier:
    """Accepts any bearer and echoes it back so we can tell tokens apart."""

    async def verify_token(self, token: str) -> AccessToken | None:
        return AccessToken(token=token, client_id=token, scopes=[], expires_at=int(time.time()) + 3600)


async def _handle_whoami(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
    # The user-facing contract of get_access_token(): call it from a handler,
    # get the token for the current request.
    access = get_access_token()
    text = access.token if access else "<none>"
    return CallToolResult(content=[TextContent(type="text", text=text)])


async def _handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult:
    return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object", "properties": {}})])


def _run_auth_server(port: int) -> None:
    server = Server(name="auth_test_server", on_call_tool=_handle_whoami, on_list_tools=_handle_list_tools)
    security = TransportSecuritySettings(allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"])
    session_manager = StreamableHTTPSessionManager(app=server, security_settings=security, stateless=False)
    # Same middleware chain lowlevel Server.streamable_http_app builds when auth is on
    asgi_app = Starlette(
        routes=[Mount("/mcp", app=session_manager.handle_request)],
        middleware=[
            Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(_EchoTokenVerifier())),
            Middleware(AuthContextMiddleware),
        ],
        lifespan=lambda app: session_manager.run(),
    )
    uvicorn.run(asgi_app, host="127.0.0.1", port=port, log_level="error")


class _MutableBearerAuth(httpx.Auth):
    """Reads the bearer from a mutable attribute at send-time so we can swap mid-session."""

    def __init__(self, token: str) -> None:
        self.token = token

    def auth_flow(self, request: httpx.Request):
        request.headers["Authorization"] = f"Bearer {self.token}"
        yield request


@pytest.mark.anyio
async def test_get_access_token_reflects_current_request_in_stateful_session() -> None:
    # Grab a free port, then release it for the server process to bind
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]

    proc = multiprocessing.Process(target=_run_auth_server, args=(port,), daemon=True)
    proc.start()
    try:
        # Wait for the server to accept connections
        for _ in range(200):
            try:
                with socket.socket() as s:
                    s.connect(("127.0.0.1", port))
                break
            except OSError:
                time.sleep(0.01)

        url = f"http://127.0.0.1:{port}/mcp"
        auth = _MutableBearerAuth("token-A")
        async with httpx.AsyncClient(auth=auth, timeout=httpx.Timeout(30, read=30), follow_redirects=True) as http_client:
            async with streamable_http_client(url, http_client=http_client) as (read_stream, write_stream):
                async with ClientSession(read_stream, write_stream) as session:
                    await session.initialize()

                    # Request 1: session created, _receive_loop spawns with token-A in its context
                    r1 = await session.call_tool("whoami", {})
                    assert isinstance(r1.content[0], TextContent)
                    assert r1.content[0].text == "token-A"

                    # Request 2: same session, different bearer; reuses the existing _receive_loop
                    auth.token = "token-B"
                    r2 = await session.call_tool("whoami", {})
                    assert isinstance(r2.content[0], TextContent)
                    # EXPECTED: "token-B" (the handler should see the token sent with THIS request)
                    # ACTUAL: "token-A" (the handler sees the session-creating request's token)
                    assert r2.content[0].text == "token-B"
    finally:
        proc.kill()
        proc.join(timeout=2)
```

Result on main:

```
AssertionError: assert 'token-A' == 'token-B'
- token-B
?       ^
+ token-A
?       ^
```
Root cause
AuthContextMiddleware sets auth_context_var inside the ASGI request's task. But the tool handler doesn't run in that task. It runs in a task spawned by Server.run()'s tg.start_soon(_handle_message, ...), which is itself inside run_server, which was spawned at session creation:

```python
await self._task_group.start(run_server)
```

tg.start() copies the caller's contextvars.Context at call time, so run_server (and every task it spawns) carries a snapshot from request 1. Requests 2..N write to the transport's read stream from their own ASGI tasks, but the reader is _receive_loop, which is still running with the request-1 snapshot. The ContextVar set in request N's ASGI task therefore never reaches the handler.
```text
ASGI req 1 (auth_context_var=A)
└─ tg.start(run_server)                          ← context copied: A
   └─ ServerSession.__aenter__
      └─ tg.start_soon(_receive_loop)            ← inherits A
         └─ async for msg in session.incoming_messages:
            └─ tg.start_soon(_handle_message)    ← inherits A
               └─ tool handler: get_access_token() → A ✓

ASGI req 2 (auth_context_var=B)
└─ transport.handle_request(...)                 ← writes to read_stream
   ... _receive_loop (still ctx A) reads it ...
   ... tg.start_soon(_handle_message)            ← inherits A, not B
       └─ tool handler: get_access_token() → A ✗
```
The existing unit test passes because MockApp runs inline in the same task as the middleware, so no stream is crossed:

```python
self.access_token_during_call = get_access_token()
```
Impact
- Token refresh mid-session: the refreshed token is invisible to handlers.
- Compounds #2100 (Bind authenticated identity to sessions in StreamableHTTPSessionManager): if a session ID is hijacked, the attacker's requests execute with the victim's token visible to get_access_token().
- Stateless HTTP is unaffected: each request creates a new session, so each gets a fresh context snapshot.
The correct path already exists
The Starlette Request is threaded explicitly through ServerMessageMetadata.request_context → ServerRequestContext.request. Inside a handler:
```python
from starlette.requests import Request

request: Request = ctx.request_context.request
user = request.user  # set by AuthenticationMiddleware
token = user.access_token  # per-request, correct
```

Proposed fix
Given get_access_token() has no callers in src/ or examples/ and isn't documented: remove auth_context_var, get_access_token(), and AuthContextMiddleware. Expose auth on Context as part of #2098 using the explicit request threading above.
Alternative (if we want to keep the API): thread the AuthenticatedUser alongside request on ServerMessageMetadata and set the contextvar at the tg.start_soon site in Server.run(). But that's re-inventing what request.user already provides.
Related
- #2098: Expose session, auth, and transport information on handler Context
- #2100: Bind authenticated identity to sessions in StreamableHTTPSessionManager
- #1743: Extract OAuth flow logic into reusable components for proxy use cases (auth rework)
- #1969: Propagate ContextVars to Transport Layer in MCP Clients (general user-ContextVar propagation through streams; different scope: arbitrary user vars, not just auth)