1
1
from datetime import datetime
2
2
from datetime import timedelta
3
+ from typing import Any
3
4
from typing import Awaitable
4
5
from typing import Callable
5
6
from typing import Dict
6
7
from typing import List
7
8
from typing import Optional
9
+ from typing import Sequence
8
10
from typing import Tuple
9
11
from typing import Union
10
12
21
23
from starlette .types import Scope
22
24
from starlette .types import Send
23
25
26
+ from .claims import Claims
24
27
from .client import OAuth2Client
25
28
from .config import OAuth2Config
26
29
from .core import OAuth2Core
@@ -36,6 +39,17 @@ class Auth(AuthCredentials):
36
39
scopes : List [str ]
37
40
clients : Dict [str , OAuth2Core ] = {}
38
41
42
+ provider : str
43
+ default_provider : str = "local"
44
+
45
+ def __init__ (
46
+ self ,
47
+ scopes : Optional [Sequence [str ]] = None ,
48
+ provider : str = default_provider ,
49
+ ) -> None :
50
+ super ().__init__ (scopes )
51
+ self .provider = provider
52
+
39
53
@classmethod
40
54
def set_http (cls , http : bool ) -> None :
41
55
cls .http = http
@@ -79,19 +93,29 @@ def is_authenticated(self) -> bool:
79
93
80
94
@property
81
95
def display_name (self ) -> str :
82
- return self .get ("display_name" , "" ) # name
96
+ return self .__getprop__ ("display_name" )
83
97
84
98
@property
85
99
def identity (self ) -> str :
86
- return self .get ("identity" , "" ) # username
100
+ return self .__getprop__ ("identity" )
87
101
88
102
@property
89
103
def picture (self ) -> str :
90
- return self .get ("picture" , "" ) # image
104
+ return self .__getprop__ ("picture" )
91
105
92
106
@property
93
107
def email (self ) -> str :
94
- return self .get ("email" , "" ) # email
108
+ return self .__getprop__ ("email" )
109
+
110
+ def use_claims (self , claims : Claims ) -> "User" :
111
+ for attr , item in claims .items ():
112
+ self [attr ] = self .__getprop__ (item )
113
+ return self
114
+
115
+ def __getprop__ (self , item , default = "" ) -> Any :
116
+ if callable (item ):
117
+ return item (self )
118
+ return self .get (item , default )
95
119
96
120
97
121
class OAuth2Backend (AuthenticationBackend ):
@@ -120,8 +144,12 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]:
120
144
if not scheme or not param :
121
145
return Auth (), User ()
122
146
123
- user = Auth .jwt_decode (param )
124
- auth , user = Auth (user .pop ("scope" , [])), User (user )
147
+ user = User (Auth .jwt_decode (param ))
148
+ user .update (provider = user .get ("provider" , Auth .default_provider ))
149
+ auth = Auth (user .pop ("scope" , []), user .get ("provider" ))
150
+ client = Auth .clients .get (auth .provider )
151
+ claims = client .claims if client else Claims ()
152
+ user = user .use_claims (claims )
125
153
126
154
# Call the callback function on authentication
127
155
if callable (self .callback ):
0 commit comments