Skip to content

Commit fc8331c

Browse files
committed
Implement RFC 7523 authorization grant flow
1 parent 6677894 commit fc8331c

File tree

5 files changed

+252
-19
lines changed

5 files changed

+252
-19
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ dependencies = [
3333
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
3434
"jsonschema>=4.20.0",
3535
"pywin32>=310; sys_platform == 'win32'",
36+
"pyjwt[crypto]>=2.10.1",
3637
]
3738

3839
[project.optional-dependencies]

src/mcp/client/auth.py

Lines changed: 91 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
import time
1414
from collections.abc import AsyncGenerator, Awaitable, Callable
1515
from dataclasses import dataclass, field
16-
from typing import Protocol
16+
from typing import Any, Protocol
1717
from urllib.parse import urlencode, urljoin, urlparse
18+
from uuid import uuid4
1819

1920
import anyio
2021
import httpx
22+
import jwt
2123
from pydantic import BaseModel, Field, ValidationError
2224

2325
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
@@ -61,6 +63,23 @@ def generate(cls) -> "PKCEParameters":
6163
return cls(code_verifier=code_verifier, code_challenge=code_challenge)
6264

6365

66+
class JWTParameters(BaseModel):
67+
"""JWT parameters."""
68+
69+
assertion: str | None = Field(
70+
default=None,
71+
description="JWT assertion for JWT authentication. "
72+
"Will be used instead of generating a new assertion if provided.",
73+
)
74+
75+
issuer: str | None = Field(default=None, description="Issuer for JWT assertions.")
76+
subject: str | None = Field(default=None, description="Subject identifier for JWT assertions.")
77+
claims: dict[str, Any] | None = Field(default=None, description="Additional claims for JWT assertions.")
78+
jwt_signing_algorithm: str | None = Field(default="RS256", description="Algorithm for signing JWT assertions.")
79+
jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.")
80+
jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")
81+
82+
6483
class TokenStorage(Protocol):
6584
"""Protocol for token storage implementations."""
6685

@@ -91,6 +110,7 @@ class OAuthContext:
91110
redirect_handler: Callable[[str], Awaitable[None]] | None
92111
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None
93112
timeout: float = 300.0
113+
jwt_parameters: JWTParameters | None = None
94114

95115
# Discovered metadata
96116
protected_resource_metadata: ProtectedResourceMetadata | None = None
@@ -192,6 +212,7 @@ def __init__(
192212
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
193213
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
194214
timeout: float = 300.0,
215+
jwt_parameters: JWTParameters | None = None,
195216
):
196217
"""Initialize OAuth2 authentication."""
197218
self.context = OAuthContext(
@@ -201,6 +222,7 @@ def __init__(
201222
redirect_handler=redirect_handler,
202223
callback_handler=callback_handler,
203224
timeout=timeout,
225+
jwt_parameters=jwt_parameters,
204226
)
205227
self._initialized = False
206228

@@ -314,6 +336,9 @@ async def _perform_authorization(self) -> httpx.Request:
314336
if "client_credentials" in self.context.client_metadata.grant_types:
315337
token_request = await self._exchange_token_client_credentials()
316338
return token_request
339+
elif "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
340+
token_request = await self._exchange_token_jwt_bearer()
341+
return token_request
317342
else:
318343
auth_code, code_verifier = await self._perform_authorization_code_grant()
319344
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
@@ -372,19 +397,22 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
372397
# Return auth code and code verifier for token exchange
373398
return auth_code, pkce_params.code_verifier
374399

400+
def _get_token_endpoint(self) -> str:
401+
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
402+
token_url = str(self.context.oauth_metadata.token_endpoint)
403+
else:
404+
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
405+
token_url = urljoin(auth_base_url, "/token")
406+
return token_url
407+
375408
async def _exchange_token_authorization_code(self, auth_code: str, code_verifier: str) -> httpx.Request:
376409
"""Build token exchange request for authorization_code flow."""
377410
if self.context.client_metadata.redirect_uris is None:
378411
raise OAuthFlowError("No redirect URIs provided for authorization code grant")
379412
if not self.context.client_info:
380413
raise OAuthFlowError("Missing client info")
381414

382-
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
383-
token_url = str(self.context.oauth_metadata.token_endpoint)
384-
else:
385-
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
386-
token_url = urljoin(auth_base_url, "/token")
387-
415+
token_url = self._get_token_endpoint()
388416
token_data = {
389417
"grant_type": "authorization_code",
390418
"code": auth_code,
@@ -409,19 +437,17 @@ async def _exchange_token_client_credentials(self) -> httpx.Request:
409437
if not self.context.client_info:
410438
raise OAuthFlowError("Missing client info")
411439

412-
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
413-
token_url = str(self.context.oauth_metadata.token_endpoint)
414-
else:
415-
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
416-
token_url = urljoin(auth_base_url, "/token")
417-
440+
token_url = self._get_token_endpoint()
418441
token_data = {
419442
"grant_type": "client_credentials",
420-
"resource": self.context.get_resource_url(), # RFC 8707
421443
}
422444

423445
headers = {"Content-Type": "application/x-www-form-urlencoded"}
424446

447+
# Only include resource param if conditions are met
448+
if self.context.should_include_resource_param(self.context.protocol_version):
449+
token_data["resource"] = self.context.get_resource_url() # RFC 8707
450+
425451
if self.context.client_metadata.scope:
426452
token_data["scope"] = self.context.client_metadata.scope
427453

@@ -442,6 +468,57 @@ async def _exchange_token_client_credentials(self) -> httpx.Request:
442468

443469
return httpx.Request("POST", token_url, data=token_data, headers=headers)
444470

471+
async def _exchange_token_jwt_bearer(self) -> httpx.Request:
472+
"""Build token exchange request for JWT bearer grant."""
473+
if not self.context.client_info:
474+
raise OAuthFlowError("Missing client info")
475+
if not self.context.jwt_parameters:
476+
raise OAuthFlowError("Missing JWT parameters")
477+
478+
token_url = self._get_token_endpoint()
479+
480+
if self.context.jwt_parameters.assertion is not None:
481+
assertion = self.context.jwt_parameters.assertion
482+
else:
483+
if not self.context.jwt_parameters.jwt_signing_key:
484+
raise OAuthFlowError("Missing signing key for JWT bearer grant")
485+
if not self.context.jwt_parameters.issuer:
486+
raise OAuthFlowError("Missing issuer for JWT bearer grant")
487+
if not self.context.jwt_parameters.subject:
488+
raise OAuthFlowError("Missing subject for JWT bearer grant")
489+
490+
now = int(time.time())
491+
claims = {
492+
"iss": self.context.jwt_parameters.issuer,
493+
"sub": self.context.jwt_parameters.subject,
494+
"aud": token_url,
495+
"exp": now + self.context.jwt_parameters.jwt_lifetime_seconds,
496+
"iat": now,
497+
"jti": str(uuid4()),
498+
}
499+
claims.update(self.context.jwt_parameters.claims or {})
500+
501+
assertion = jwt.encode(
502+
claims,
503+
self.context.jwt_parameters.jwt_signing_key,
504+
algorithm=self.context.jwt_parameters.jwt_signing_algorithm or "RS256",
505+
)
506+
507+
token_data = {
508+
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
509+
"assertion": assertion,
510+
}
511+
512+
if self.context.should_include_resource_param(self.context.protocol_version):
513+
token_data["resource"] = self.context.get_resource_url()
514+
515+
if self.context.client_metadata.scope:
516+
token_data["scope"] = self.context.client_metadata.scope
517+
518+
return httpx.Request(
519+
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
520+
)
521+
445522
async def _handle_token_response(self, response: httpx.Response) -> None:
446523
"""Handle token exchange response."""
447524
if response.status_code != 200:

src/mcp/shared/auth.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,15 @@ class OAuthClientMetadata(BaseModel):
4343

4444
redirect_uris: list[AnyUrl] | None = Field(..., min_length=1)
4545
# supported auth methods for the token endpoint
46-
token_endpoint_auth_method: Literal["none", "client_secret_basic", "client_secret_post"] = "client_secret_post"
46+
token_endpoint_auth_method: Literal["none", "client_secret_basic", "client_secret_post", "private_key_jwt"] = (
47+
"client_secret_post"
48+
)
4749
# supported grant_types of this implementation
48-
grant_types: list[Literal["authorization_code", "client_credentials", "refresh_token"]] = [
50+
grant_types: list[
51+
Literal[
52+
"authorization_code", "client_credentials", "refresh_token", "urn:ietf:params:oauth:grant-type:jwt-bearer"
53+
]
54+
] = [
4955
"authorization_code",
5056
"refresh_token",
5157
]

tests/client/test_auth.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
from unittest import mock
1010

1111
import httpx
12+
import jwt
1213
import pytest
1314
from inline_snapshot import Is, snapshot
1415
from pydantic import AnyHttpUrl, AnyUrl
1516

16-
from mcp.client.auth import OAuthClientProvider, PKCEParameters
17+
from mcp.client.auth import JWTParameters, OAuthClientProvider, PKCEParameters
1718
from mcp.shared.auth import (
1819
OAuthClientInformationFull,
1920
OAuthClientMetadata,
@@ -434,7 +435,7 @@ async def test_token_exchange_request_authorization_code(self, oauth_provider):
434435
assert "client_secret=test_secret" in content
435436

436437
@pytest.mark.anyio
437-
async def test_token_exchange_request_client_credentials_basic(self, oauth_provider):
438+
async def test_token_exchange_request_client_credentials_basic(self, oauth_provider: OAuthClientProvider):
438439
"""Test token exchange request building."""
439440
# Set up required context
440441
oauth_provider.context.client_info = oauth_provider.context.client_metadata = OAuthClientInformationFull(
@@ -445,6 +446,7 @@ async def test_token_exchange_request_client_credentials_basic(self, oauth_provi
445446
redirect_uris=None,
446447
scope="read write",
447448
)
449+
oauth_provider.context.protocol_version = "2025-06-18"
448450

449451
request = await oauth_provider._exchange_token_client_credentials()
450452

@@ -466,7 +468,7 @@ async def test_token_exchange_request_client_credentials_basic(self, oauth_provi
466468
assert base64.b64decode(request.headers["Authorization"].split(" ")[1]).decode() == "test_client:test_secret"
467469

468470
@pytest.mark.anyio
469-
async def test_token_exchange_request_client_credentials_post(self, oauth_provider):
471+
async def test_token_exchange_request_client_credentials_post(self, oauth_provider: OAuthClientProvider):
470472
"""Test token exchange request building."""
471473
# Set up required context
472474
oauth_provider.context.client_info = oauth_provider.context.client_metadata = OAuthClientInformationFull(
@@ -477,6 +479,7 @@ async def test_token_exchange_request_client_credentials_post(self, oauth_provid
477479
redirect_uris=None,
478480
scope="read write",
479481
)
482+
oauth_provider.context.protocol_version = "2025-06-18"
480483

481484
request = await oauth_provider._exchange_token_client_credentials()
482485

@@ -492,6 +495,89 @@ async def test_token_exchange_request_client_credentials_post(self, oauth_provid
492495
assert "client_id=test_client" in content
493496
assert "client_secret=test_secret" in content
494497

498+
@pytest.mark.anyio
499+
async def test_token_exchange_request_jwt_predefined(self, oauth_provider: OAuthClientProvider):
500+
"""Test token exchange request building with a predefined JWT assertion."""
501+
# Set up required context
502+
oauth_provider.context.client_info = oauth_provider.context.client_metadata = OAuthClientInformationFull(
503+
grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
504+
token_endpoint_auth_method="private_key_jwt",
505+
redirect_uris=None,
506+
scope="read write",
507+
)
508+
oauth_provider.context.protocol_version = "2025-06-18"
509+
oauth_provider.context.jwt_parameters = JWTParameters(
510+
# https://www.jwt.io
511+
assertion="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.KMUFsIDTnFmyG3nMiGM6H9FNFUROf3wh7SmqJp-QV30"
512+
)
513+
514+
request = await oauth_provider._exchange_token_jwt_bearer()
515+
516+
assert request.method == "POST"
517+
assert str(request.url) == "https://api.example.com/token"
518+
assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
519+
520+
# Check form data
521+
content = urllib.parse.unquote_plus(request.content.decode())
522+
assert "grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer" in content
523+
assert "scope=read write" in content
524+
assert "resource=https://api.example.com/v1/mcp" in content
525+
assert (
526+
"assertion=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.KMUFsIDTnFmyG3nMiGM6H9FNFUROf3wh7SmqJp-QV30"
527+
in content
528+
)
529+
530+
@pytest.mark.anyio
531+
async def test_token_exchange_request_jwt(self, oauth_provider: OAuthClientProvider):
532+
"""Test token exchange request building wiith a generated JWT assertion."""
533+
# Set up required context
534+
oauth_provider.context.client_info = oauth_provider.context.client_metadata = OAuthClientInformationFull(
535+
grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
536+
token_endpoint_auth_method="private_key_jwt",
537+
redirect_uris=None,
538+
scope="read write",
539+
)
540+
oauth_provider.context.protocol_version = "2025-06-18"
541+
oauth_provider.context.jwt_parameters = JWTParameters(
542+
issuer="foo",
543+
subject="1234567890",
544+
claims={
545+
"name": "John Doe",
546+
"admin": True,
547+
"iat": 1516239022,
548+
},
549+
jwt_signing_algorithm="HS256",
550+
jwt_signing_key="a-string-secret-at-least-256-bits-long",
551+
jwt_lifetime_seconds=300,
552+
)
553+
554+
request = await oauth_provider._exchange_token_jwt_bearer()
555+
556+
assert request.method == "POST"
557+
assert str(request.url) == "https://api.example.com/token"
558+
assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
559+
560+
# Check form data
561+
content = urllib.parse.unquote_plus(request.content.decode()).split("&")
562+
assert "grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer" in content
563+
assert "scope=read write" in content
564+
assert "resource=https://api.example.com/v1/mcp" in content
565+
566+
# Check assertion
567+
assertion = next(param for param in content if param.startswith("assertion="))[len("assertion=") :]
568+
claims = jwt.decode(
569+
assertion,
570+
key="a-string-secret-at-least-256-bits-long",
571+
algorithms=["HS256"],
572+
audience="https://api.example.com/token",
573+
subject="1234567890",
574+
issuer="foo",
575+
verify=True,
576+
)
577+
assert claims["name"] == "John Doe"
578+
assert claims["admin"]
579+
assert claims["iat"] == 1516239022
580+
495581
@pytest.mark.anyio
496582
async def test_refresh_token_request(self, oauth_provider, valid_tokens):
497583
"""Test refresh token request building."""

0 commit comments

Comments
 (0)