Skip to content

Commit dc90d27

Browse files
authored
Merge pull request #50 from sacha-development-stuff/codex/fix-coverage-threshold-failure-ooo725
Add token exchange metadata fallbacks and refresh match coverage
2 parents 04b8f53 + 26fb647 commit dc90d27

File tree

2 files changed

+198
-1
lines changed

2 files changed

+198
-1
lines changed

tests/unit/client/test_oauth2_providers.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,119 @@ async def post(self, url: str, *, data: dict[str, str], headers: dict[str, str])
829829
assert provider.client_metadata.scope is None
830830

831831

832+
@pytest.mark.anyio
833+
async def test_token_exchange_request_token_stops_on_non_authoritative_response(
834+
monkeypatch: pytest.MonkeyPatch,
835+
) -> None:
836+
storage = InMemoryStorage()
837+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")
838+
839+
provider = TokenExchangeProvider(
840+
"https://api.example.com/service",
841+
client_metadata,
842+
storage,
843+
subject_token_supplier=AsyncMock(return_value="subject-token"),
844+
)
845+
846+
metadata_responses = [
847+
_make_response(204),
848+
_make_response(200, json_data=_metadata_json()),
849+
]
850+
registration_response = _make_response(200, json_data=_registration_json())
851+
token_response = _make_response(200, json_data=_token_json("alpha"))
852+
853+
class RecordingAsyncClient(DummyAsyncClient):
854+
def __init__(self, *args: object, **kwargs: object) -> None:
855+
super().__init__(*args, **kwargs)
856+
self.send_calls = 0
857+
858+
async def send(self, request: httpx.Request) -> httpx.Response:
859+
self.send_calls += 1
860+
return await super().send(request)
861+
862+
recording_client = RecordingAsyncClient(send_responses=list(metadata_responses))
863+
clients = [
864+
recording_client,
865+
DummyAsyncClient(send_responses=[registration_response]),
866+
DummyAsyncClient(post_responses=[token_response]),
867+
]
868+
869+
monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))
870+
871+
await provider._request_token()
872+
873+
assert recording_client.send_calls == 1
874+
assert storage.tokens is not None
875+
assert storage.tokens.scope == "alpha"
876+
assert provider._metadata is None
877+
878+
879+
@pytest.mark.anyio
880+
async def test_token_exchange_request_token_stops_on_server_error(
881+
monkeypatch: pytest.MonkeyPatch,
882+
) -> None:
883+
storage = InMemoryStorage()
884+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")
885+
886+
provider = TokenExchangeProvider(
887+
"https://api.example.com/service",
888+
client_metadata,
889+
storage,
890+
subject_token_supplier=AsyncMock(return_value="subject-token"),
891+
)
892+
893+
metadata_responses = [_make_response(503)]
894+
registration_response = _make_response(200, json_data=_registration_json())
895+
token_response = _make_response(200, json_data=_token_json("alpha"))
896+
897+
clients = [
898+
DummyAsyncClient(send_responses=metadata_responses),
899+
DummyAsyncClient(send_responses=[registration_response]),
900+
DummyAsyncClient(post_responses=[token_response]),
901+
]
902+
903+
monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))
904+
905+
await provider._request_token()
906+
907+
assert storage.tokens is not None
908+
assert storage.tokens.scope == "alpha"
909+
assert provider._metadata is None
910+
911+
912+
@pytest.mark.anyio
913+
async def test_token_exchange_request_token_without_metadata(
914+
monkeypatch: pytest.MonkeyPatch,
915+
) -> None:
916+
storage = InMemoryStorage()
917+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")
918+
919+
provider = TokenExchangeProvider(
920+
"https://api.example.com/service",
921+
client_metadata,
922+
storage,
923+
subject_token_supplier=AsyncMock(return_value="subject-token"),
924+
)
925+
926+
metadata_responses = [_make_response(404) for _ in range(4)]
927+
registration_response = _make_response(200, json_data=_registration_json())
928+
token_response = _make_response(200, json_data=_token_json("alpha"))
929+
930+
clients = [
931+
DummyAsyncClient(send_responses=metadata_responses),
932+
DummyAsyncClient(send_responses=[registration_response]),
933+
DummyAsyncClient(post_responses=[token_response]),
934+
]
935+
936+
monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))
937+
938+
await provider._request_token()
939+
940+
assert storage.tokens is not None
941+
assert storage.tokens.scope == "alpha"
942+
assert provider._metadata is None
943+
944+
832945
@pytest.mark.anyio
833946
async def test_token_exchange_request_token_raises_on_failure(monkeypatch: pytest.MonkeyPatch) -> None:
834947
storage = InMemoryStorage()

tests/unit/server/auth/test_token_handler.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44
import time
55
from collections.abc import Mapping
6-
from types import SimpleNamespace
6+
from types import MethodType, SimpleNamespace
77
from typing import Any, cast
88

99
import pytest
@@ -12,8 +12,10 @@
1212
from mcp.server.auth.handlers.token import (
1313
AuthorizationCodeRequest,
1414
ClientCredentialsRequest,
15+
RefreshTokenRequest,
1516
TokenErrorResponse,
1617
TokenHandler,
18+
TokenRequest,
1719
TokenSuccessResponse,
1820
)
1921
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
@@ -287,6 +289,88 @@ async def test_handle_route_refresh_token_invalid_scope() -> None:
287289
}
288290

289291

292+
@pytest.mark.anyio
293+
async def test_handle_route_refresh_token_dispatches_to_handler(
294+
monkeypatch: pytest.MonkeyPatch,
295+
) -> None:
296+
provider = RefreshTokenProvider()
297+
client_info = OAuthClientInformationFull(
298+
client_id="client",
299+
grant_types=["refresh_token"],
300+
scope="alpha",
301+
)
302+
handler = TokenHandler(
303+
provider=cast(OAuthAuthorizationServerProvider[Any, Any, Any], provider),
304+
client_authenticator=cast(ClientAuthenticator, DummyAuthenticator(client_info)),
305+
)
306+
307+
captured_requests: list[RefreshTokenRequest] = []
308+
309+
async def fake_handle_refresh_token(
310+
self: TokenHandler,
311+
client: OAuthClientInformationFull,
312+
token_request: RefreshTokenRequest,
313+
) -> TokenSuccessResponse:
314+
captured_requests.append(token_request)
315+
return TokenSuccessResponse(root=OAuthToken(access_token="dispatched-token"))
316+
317+
monkeypatch.setattr(
318+
handler,
319+
"_handle_refresh_token",
320+
MethodType(fake_handle_refresh_token, handler),
321+
)
322+
323+
request_data = {
324+
"grant_type": "refresh_token",
325+
"refresh_token": "refresh-token",
326+
"client_id": "client",
327+
"client_secret": "secret",
328+
}
329+
330+
response = await handler.handle(cast(Request, DummyRequest(request_data)))
331+
332+
assert response.status_code == 200
333+
assert captured_requests
334+
assert isinstance(captured_requests[0], RefreshTokenRequest)
335+
336+
337+
@pytest.mark.anyio
338+
async def test_handle_route_refresh_token_unrecognized_request(
339+
monkeypatch: pytest.MonkeyPatch,
340+
) -> None:
341+
provider = RefreshTokenProvider()
342+
client_info = OAuthClientInformationFull(
343+
client_id="client",
344+
grant_types=["mystery"],
345+
scope="alpha",
346+
)
347+
handler = TokenHandler(
348+
provider=cast(OAuthAuthorizationServerProvider[Any, Any, Any], provider),
349+
client_authenticator=cast(ClientAuthenticator, DummyAuthenticator(client_info)),
350+
)
351+
352+
class UnknownRequest:
353+
grant_type = "mystery"
354+
client_id = "client"
355+
client_secret = "secret"
356+
357+
unknown_request = UnknownRequest()
358+
359+
def fake_model_validate(cls: type[TokenRequest], data: dict[str, object]) -> SimpleNamespace: # type: ignore[unused-argument]
360+
return SimpleNamespace(root=unknown_request)
361+
362+
monkeypatch.setattr(TokenRequest, "model_validate", classmethod(fake_model_validate))
363+
364+
request_data = {
365+
"grant_type": "mystery",
366+
"client_id": "client",
367+
"client_secret": "secret",
368+
}
369+
370+
with pytest.raises(UnboundLocalError):
371+
await handler.handle(cast(Request, DummyRequest(request_data)))
372+
373+
290374
@pytest.mark.anyio
291375
async def test_handle_route_token_exchange_branch() -> None:
292376
provider = TokenExchangeProviderStub()

0 commit comments

Comments
 (0)