13
13
import time
14
14
from collections .abc import AsyncGenerator , Awaitable , Callable
15
15
from dataclasses import dataclass , field
16
- from typing import Protocol
16
+ from typing import Any , Protocol
17
17
from urllib .parse import urlencode , urljoin , urlparse
18
+ from uuid import uuid4
18
19
19
20
import anyio
20
21
import httpx
22
+ import jwt
21
23
from pydantic import BaseModel , Field , ValidationError
22
24
23
25
from mcp .client .streamable_http import MCP_PROTOCOL_VERSION
@@ -61,6 +63,23 @@ def generate(cls) -> "PKCEParameters":
61
63
return cls (code_verifier = code_verifier , code_challenge = code_challenge )
62
64
63
65
66
+ class JWTParameters (BaseModel ):
67
+ """JWT parameters."""
68
+
69
+ assertion : str | None = Field (
70
+ default = None ,
71
+ description = "JWT assertion for JWT authentication. "
72
+ "Will be used instead of generating a new assertion if provided." ,
73
+ )
74
+
75
+ issuer : str | None = Field (default = None , description = "Issuer for JWT assertions." )
76
+ subject : str | None = Field (default = None , description = "Subject identifier for JWT assertions." )
77
+ claims : dict [str , Any ] | None = Field (default = None , description = "Additional claims for JWT assertions." )
78
+ jwt_signing_algorithm : str | None = Field (default = "RS256" , description = "Algorithm for signing JWT assertions." )
79
+ jwt_signing_key : str | None = Field (default = None , description = "Private key for JWT signing." )
80
+ jwt_lifetime_seconds : int = Field (default = 300 , description = "Lifetime of generated JWT in seconds." )
81
+
82
+
64
83
class TokenStorage (Protocol ):
65
84
"""Protocol for token storage implementations."""
66
85
@@ -91,6 +110,7 @@ class OAuthContext:
91
110
redirect_handler : Callable [[str ], Awaitable [None ]] | None
92
111
callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]] | None
93
112
timeout : float = 300.0
113
+ jwt_parameters : JWTParameters | None = None
94
114
95
115
# Discovered metadata
96
116
protected_resource_metadata : ProtectedResourceMetadata | None = None
@@ -192,6 +212,7 @@ def __init__(
192
212
redirect_handler : Callable [[str ], Awaitable [None ]] | None = None ,
193
213
callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]] | None = None ,
194
214
timeout : float = 300.0 ,
215
+ jwt_parameters : JWTParameters | None = None ,
195
216
):
196
217
"""Initialize OAuth2 authentication."""
197
218
self .context = OAuthContext (
@@ -201,6 +222,7 @@ def __init__(
201
222
redirect_handler = redirect_handler ,
202
223
callback_handler = callback_handler ,
203
224
timeout = timeout ,
225
+ jwt_parameters = jwt_parameters ,
204
226
)
205
227
self ._initialized = False
206
228
@@ -314,6 +336,9 @@ async def _perform_authorization(self) -> httpx.Request:
314
336
if "client_credentials" in self .context .client_metadata .grant_types :
315
337
token_request = await self ._exchange_token_client_credentials ()
316
338
return token_request
339
+ elif "urn:ietf:params:oauth:grant-type:jwt-bearer" in self .context .client_metadata .grant_types :
340
+ token_request = await self ._exchange_token_jwt_bearer ()
341
+ return token_request
317
342
else :
318
343
auth_code , code_verifier = await self ._perform_authorization_code_grant ()
319
344
token_request = await self ._exchange_token_authorization_code (auth_code , code_verifier )
@@ -372,19 +397,22 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
372
397
# Return auth code and code verifier for token exchange
373
398
return auth_code , pkce_params .code_verifier
374
399
400
+ def _get_token_endpoint (self ) -> str :
401
+ if self .context .oauth_metadata and self .context .oauth_metadata .token_endpoint :
402
+ token_url = str (self .context .oauth_metadata .token_endpoint )
403
+ else :
404
+ auth_base_url = self .context .get_authorization_base_url (self .context .server_url )
405
+ token_url = urljoin (auth_base_url , "/token" )
406
+ return token_url
407
+
375
408
async def _exchange_token_authorization_code (self , auth_code : str , code_verifier : str ) -> httpx .Request :
376
409
"""Build token exchange request for authorization_code flow."""
377
410
if self .context .client_metadata .redirect_uris is None :
378
411
raise OAuthFlowError ("No redirect URIs provided for authorization code grant" )
379
412
if not self .context .client_info :
380
413
raise OAuthFlowError ("Missing client info" )
381
414
382
- if self .context .oauth_metadata and self .context .oauth_metadata .token_endpoint :
383
- token_url = str (self .context .oauth_metadata .token_endpoint )
384
- else :
385
- auth_base_url = self .context .get_authorization_base_url (self .context .server_url )
386
- token_url = urljoin (auth_base_url , "/token" )
387
-
415
+ token_url = self ._get_token_endpoint ()
388
416
token_data = {
389
417
"grant_type" : "authorization_code" ,
390
418
"code" : auth_code ,
@@ -409,19 +437,17 @@ async def _exchange_token_client_credentials(self) -> httpx.Request:
409
437
if not self .context .client_info :
410
438
raise OAuthFlowError ("Missing client info" )
411
439
412
- if self .context .oauth_metadata and self .context .oauth_metadata .token_endpoint :
413
- token_url = str (self .context .oauth_metadata .token_endpoint )
414
- else :
415
- auth_base_url = self .context .get_authorization_base_url (self .context .server_url )
416
- token_url = urljoin (auth_base_url , "/token" )
417
-
440
+ token_url = self ._get_token_endpoint ()
418
441
token_data = {
419
442
"grant_type" : "client_credentials" ,
420
- "resource" : self .context .get_resource_url (), # RFC 8707
421
443
}
422
444
423
445
headers = {"Content-Type" : "application/x-www-form-urlencoded" }
424
446
447
+ # Only include resource param if conditions are met
448
+ if self .context .should_include_resource_param (self .context .protocol_version ):
449
+ token_data ["resource" ] = self .context .get_resource_url () # RFC 8707
450
+
425
451
if self .context .client_metadata .scope :
426
452
token_data ["scope" ] = self .context .client_metadata .scope
427
453
@@ -442,6 +468,57 @@ async def _exchange_token_client_credentials(self) -> httpx.Request:
442
468
443
469
return httpx .Request ("POST" , token_url , data = token_data , headers = headers )
444
470
471
+ async def _exchange_token_jwt_bearer (self ) -> httpx .Request :
472
+ """Build token exchange request for JWT bearer grant."""
473
+ if not self .context .client_info :
474
+ raise OAuthFlowError ("Missing client info" )
475
+ if not self .context .jwt_parameters :
476
+ raise OAuthFlowError ("Missing JWT parameters" )
477
+
478
+ token_url = self ._get_token_endpoint ()
479
+
480
+ if self .context .jwt_parameters .assertion is not None :
481
+ assertion = self .context .jwt_parameters .assertion
482
+ else :
483
+ if not self .context .jwt_parameters .jwt_signing_key :
484
+ raise OAuthFlowError ("Missing signing key for JWT bearer grant" )
485
+ if not self .context .jwt_parameters .issuer :
486
+ raise OAuthFlowError ("Missing issuer for JWT bearer grant" )
487
+ if not self .context .jwt_parameters .subject :
488
+ raise OAuthFlowError ("Missing subject for JWT bearer grant" )
489
+
490
+ now = int (time .time ())
491
+ claims = {
492
+ "iss" : self .context .jwt_parameters .issuer ,
493
+ "sub" : self .context .jwt_parameters .subject ,
494
+ "aud" : token_url ,
495
+ "exp" : now + self .context .jwt_parameters .jwt_lifetime_seconds ,
496
+ "iat" : now ,
497
+ "jti" : str (uuid4 ()),
498
+ }
499
+ claims .update (self .context .jwt_parameters .claims or {})
500
+
501
+ assertion = jwt .encode (
502
+ claims ,
503
+ self .context .jwt_parameters .jwt_signing_key ,
504
+ algorithm = self .context .jwt_parameters .jwt_signing_algorithm or "RS256" ,
505
+ )
506
+
507
+ token_data = {
508
+ "grant_type" : "urn:ietf:params:oauth:grant-type:jwt-bearer" ,
509
+ "assertion" : assertion ,
510
+ }
511
+
512
+ if self .context .should_include_resource_param (self .context .protocol_version ):
513
+ token_data ["resource" ] = self .context .get_resource_url ()
514
+
515
+ if self .context .client_metadata .scope :
516
+ token_data ["scope" ] = self .context .client_metadata .scope
517
+
518
+ return httpx .Request (
519
+ "POST" , token_url , data = token_data , headers = {"Content-Type" : "application/x-www-form-urlencoded" }
520
+ )
521
+
445
522
async def _handle_token_response (self , response : httpx .Response ) -> None :
446
523
"""Handle token exchange response."""
447
524
if response .status_code != 200 :
0 commit comments