Skip to content

Commit 86ff69f

Browse files
committed
GH-9: Handle claims mapping for certain provider
1 parent ead1086 commit 86ff69f

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

src/fastapi_oauth2/middleware.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from datetime import datetime
22
from datetime import timedelta
3+
from typing import Any
34
from typing import Awaitable
45
from typing import Callable
56
from typing import Dict
67
from typing import List
78
from typing import Optional
9+
from typing import Sequence
810
from typing import Tuple
911
from typing import Union
1012

@@ -21,6 +23,7 @@
2123
from starlette.types import Scope
2224
from starlette.types import Send
2325

26+
from .claims import Claims
2427
from .client import OAuth2Client
2528
from .config import OAuth2Config
2629
from .core import OAuth2Core
@@ -36,6 +39,17 @@ class Auth(AuthCredentials):
3639
scopes: List[str]
3740
clients: Dict[str, OAuth2Core] = {}
3841

42+
provider: str
43+
default_provider: str = "local"
44+
45+
def __init__(
46+
self,
47+
scopes: Optional[Sequence[str]] = None,
48+
provider: str = default_provider,
49+
) -> None:
50+
super().__init__(scopes)
51+
self.provider = provider
52+
3953
@classmethod
4054
def set_http(cls, http: bool) -> None:
4155
cls.http = http
@@ -79,19 +93,29 @@ def is_authenticated(self) -> bool:
7993

8094
@property
8195
def display_name(self) -> str:
82-
return self.get("display_name", "") # name
96+
return self.__getprop__("display_name")
8397

8498
@property
8599
def identity(self) -> str:
86-
return self.get("identity", "") # username
100+
return self.__getprop__("identity")
87101

88102
@property
89103
def picture(self) -> str:
90-
return self.get("picture", "") # image
104+
return self.__getprop__("picture")
91105

92106
@property
93107
def email(self) -> str:
94-
return self.get("email", "") # email
108+
return self.__getprop__("email")
109+
110+
def use_claims(self, claims: Claims) -> "User":
111+
for attr, item in claims.items():
112+
self[attr] = self.__getprop__(item)
113+
return self
114+
115+
def __getprop__(self, item, default="") -> Any:
116+
if callable(item):
117+
return item(self)
118+
return self.get(item, default)
95119

96120

97121
class OAuth2Backend(AuthenticationBackend):
@@ -120,8 +144,12 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]:
120144
if not scheme or not param:
121145
return Auth(), User()
122146

123-
user = Auth.jwt_decode(param)
124-
auth, user = Auth(user.pop("scope", [])), User(user)
147+
user = User(Auth.jwt_decode(param))
148+
user.update(provider=user.get("provider", Auth.default_provider))
149+
auth = Auth(user.pop("scope", []), user.get("provider"))
150+
client = Auth.clients.get(auth.provider)
151+
claims = client.claims if client else Claims()
152+
user = user.use_claims(claims)
125153

126154
# Call the callback function on authentication
127155
if callable(self.callback):

0 commit comments

Comments
 (0)