Skip to content

Commit 6399e3a

Browse files
authored
Merge pull request #43 from sacha-development-stuff/codex/investigate-missing-coverage-in-oauth2.py
Increase coverage for OAuth client credential and refresh flows
2 parents 699f011 + 1286359 commit 6399e3a

File tree

2 files changed

+108
-0
lines changed

2 files changed

+108
-0
lines changed

tests/unit/client/test_oauth2_providers.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,83 @@ async def test_client_credentials_request_token_without_metadata(monkeypatch: py
464464
assert provider._metadata is None
465465

466466

467+
@pytest.mark.anyio
468+
async def test_client_credentials_request_token_omits_scope_when_not_registered(
469+
monkeypatch: pytest.MonkeyPatch,
470+
) -> None:
471+
storage = InMemoryStorage()
472+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
473+
provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage)
474+
475+
metadata_json = _metadata_json().copy()
476+
metadata_json.pop("scopes_supported")
477+
metadata_response = _make_response(200, json_data=metadata_json)
478+
registration_response = _make_response(200, json_data=_registration_json())
479+
token_response = _make_response(200, json_data=_token_json())
480+
481+
class CapturingAsyncClient(DummyAsyncClient):
482+
def __init__(self, *args: object, **kwargs: object) -> None:
483+
super().__init__(*args, **kwargs)
484+
self.captured_data: dict[str, str] | None = None
485+
self.captured_headers: dict[str, str] | None = None
486+
487+
async def post(
488+
self,
489+
url: str,
490+
*,
491+
data: dict[str, str],
492+
headers: dict[str, str],
493+
) -> httpx.Response:
494+
self.captured_data = dict(data)
495+
self.captured_headers = dict(headers)
496+
assert self._post_responses, "Unexpected post() call"
497+
return self._post_responses.pop(0)
498+
499+
capturing_client = CapturingAsyncClient(post_responses=[token_response])
500+
clients = [
501+
DummyAsyncClient(send_responses=[metadata_response]),
502+
DummyAsyncClient(send_responses=[registration_response]),
503+
capturing_client,
504+
]
505+
monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))
506+
507+
await provider._request_token()
508+
509+
assert capturing_client.captured_data is not None
510+
assert capturing_client.captured_headers == {
511+
"Content-Type": "application/x-www-form-urlencoded"
512+
}
513+
assert capturing_client.captured_data["grant_type"] == "client_credentials"
514+
assert capturing_client.captured_data["resource"] == provider.resource
515+
assert "scope" not in capturing_client.captured_data
516+
517+
518+
@pytest.mark.anyio
519+
async def test_client_credentials_request_token_stops_on_server_error(
520+
monkeypatch: pytest.MonkeyPatch,
521+
) -> None:
522+
storage = InMemoryStorage()
523+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")
524+
provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage)
525+
526+
metadata_responses = [_make_response(503)]
527+
registration_response = _make_response(200, json_data=_registration_json())
528+
token_response = _make_response(200, json_data=_token_json("alpha"))
529+
530+
clients = [
531+
DummyAsyncClient(send_responses=metadata_responses),
532+
DummyAsyncClient(send_responses=[registration_response]),
533+
DummyAsyncClient(post_responses=[token_response]),
534+
]
535+
monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))
536+
537+
await provider._request_token()
538+
539+
assert storage.tokens is not None
540+
assert storage.tokens.scope == "alpha"
541+
assert provider._metadata is None
542+
543+
467544
@pytest.mark.anyio
468545
async def test_client_credentials_ensure_token_returns_when_valid() -> None:
469546
storage = InMemoryStorage()

tests/unit/server/auth/test_token_handler.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,37 @@ async def test_handle_route_refresh_token_branch() -> None:
256256
assert payload["access_token"] == "refreshed-token"
257257

258258

259+
@pytest.mark.anyio
260+
async def test_handle_route_refresh_token_invalid_scope() -> None:
261+
provider = RefreshTokenProvider()
262+
client_info = OAuthClientInformationFull(
263+
client_id="client",
264+
grant_types=["refresh_token"],
265+
scope="alpha",
266+
)
267+
handler = TokenHandler(
268+
provider=cast(OAuthAuthorizationServerProvider[Any, Any, Any], provider),
269+
client_authenticator=cast(ClientAuthenticator, DummyAuthenticator(client_info)),
270+
)
271+
272+
request_data = {
273+
"grant_type": "refresh_token",
274+
"refresh_token": "refresh-token",
275+
"scope": "beta",
276+
"client_id": "client",
277+
"client_secret": "secret",
278+
}
279+
280+
response = await handler.handle(cast(Request, DummyRequest(request_data)))
281+
282+
assert response.status_code == 400
283+
payload = json.loads(bytes(response.body).decode())
284+
assert payload == {
285+
"error": "invalid_scope",
286+
"error_description": "cannot request scope `beta` not provided by refresh token",
287+
}
288+
289+
259290
@pytest.mark.anyio
260291
async def test_handle_route_token_exchange_branch() -> None:
261292
provider = TokenExchangeProviderStub()

0 commit comments

Comments
 (0)