|
17 | 17 | ) |
18 | 18 | from starlette.requests import HTTPConnection |
19 | 19 |
|
20 | | -from fastapi_azure_auth.exceptions import InvalidAuth, InvalidAuthHttp, InvalidAuthWebSocket |
| 20 | +from fastapi_azure_auth.exceptions import ( |
| 21 | + Forbidden, |
| 22 | + ForbiddenHttp, |
| 23 | + ForbiddenWebSocket, |
| 24 | + InvalidRequest, |
| 25 | + InvalidRequestHttp, |
| 26 | + Unauthorized, |
| 27 | + UnauthorizedHttp, |
| 28 | + UnauthorizedWebSocket, |
| 29 | +) |
21 | 30 | from fastapi_azure_auth.openid_config import OpenIdConfig |
22 | 31 | from fastapi_azure_auth.user import User |
23 | 32 | from fastapi_azure_auth.utils import get_unverified_claims, get_unverified_header, is_guest |
@@ -148,28 +157,28 @@ async def __call__(self, request: HTTPConnection, security_scopes: SecurityScope |
148 | 157 | access_token = await self.extract_access_token(request) |
149 | 158 | try: |
150 | 159 | if access_token is None: |
151 | | - raise InvalidAuth('No access token provided', request=request) |
| 160 | + raise InvalidRequest('No access token provided', request=request) |
152 | 161 | # Extract header information of the token. |
153 | 162 | header: dict[str, Any] = get_unverified_header(access_token) |
154 | 163 | claims: dict[str, Any] = get_unverified_claims(access_token) |
155 | 164 | except Exception as error: |
156 | 165 | log.warning('Malformed token received. %s. Error: %s', access_token, error, exc_info=True) |
157 | | - raise InvalidAuth(detail='Invalid token format', request=request) from error |
| 166 | + raise Unauthorized(detail='Invalid token format', request=request) from error |
158 | 167 |
|
159 | 168 | user_is_guest: bool = is_guest(claims=claims) |
160 | 169 | if not self.allow_guest_users and user_is_guest: |
161 | 170 | log.info('User denied, is a guest user', claims) |
162 | | - raise InvalidAuth(detail='Guest users not allowed', request=request) |
| 171 | + raise Forbidden(detail='Guest users not allowed', request=request) |
163 | 172 |
|
164 | 173 | for scope in security_scopes.scopes: |
165 | 174 | token_scope_string = claims.get('scp', '') |
166 | 175 | log.debug('Scopes: %s', token_scope_string) |
167 | 176 | if not isinstance(token_scope_string, str): |
168 | | - raise InvalidAuth('Token contains invalid formatted scopes', request=request) |
| 177 | + raise Forbidden('Token contains invalid formatted scopes', request=request) |
169 | 178 |
|
170 | 179 | token_scopes = token_scope_string.split(' ') |
171 | 180 | if scope not in token_scopes: |
172 | | - raise InvalidAuth('Required scope missing', request=request) |
| 181 | + raise Forbidden('Required scope missing', request=request) |
173 | 182 | # Load new config if old |
174 | 183 | await self.openid_config.load_config() |
175 | 184 |
|
@@ -211,27 +220,34 @@ async def __call__(self, request: HTTPConnection, security_scopes: SecurityScope |
211 | 220 | MissingRequiredClaimError, |
212 | 221 | ) as error: |
213 | 222 | log.info('Token contains invalid claims. %s', error) |
214 | | - raise InvalidAuth(detail='Token contains invalid claims', request=request) from error |
| 223 | + raise Unauthorized(detail='Token contains invalid claims', request=request) from error |
215 | 224 | except ExpiredSignatureError as error: |
216 | 225 | log.info('Token signature has expired. %s', error) |
217 | | - raise InvalidAuth(detail='Token signature has expired', request=request) from error |
| 226 | + raise Unauthorized(detail='Token signature has expired', request=request) from error |
218 | 227 | except InvalidTokenError as error: |
219 | 228 | log.warning('Invalid token. Error: %s', error, exc_info=True) |
220 | | - raise InvalidAuth(detail='Unable to validate token', request=request) from error |
| 229 | + raise Unauthorized(detail='Unable to validate token', request=request) from error |
221 | 230 | except Exception as error: |
222 | 231 | # Extra failsafe in case of a bug in a future version of the jwt library |
223 | 232 | log.exception('Unable to process jwt token. Uncaught error: %s', error) |
224 | | - raise InvalidAuth(detail='Unable to process token', request=request) from error |
| 233 | + raise Unauthorized(detail='Unable to process token', request=request) from error |
225 | 234 | log.warning('Unable to verify token. No signing keys found') |
226 | | - raise InvalidAuth(detail='Unable to verify token, no signing keys found', request=request) |
227 | | - except (InvalidAuthHttp, InvalidAuthWebSocket, HTTPException): |
| 235 | + raise Unauthorized(detail='Unable to verify token, no signing keys found', request=request) |
| 236 | + except ( |
| 237 | + InvalidRequestHttp, |
| 238 | + UnauthorizedHttp, |
| 239 | + UnauthorizedWebSocket, |
| 240 | + ForbiddenHttp, |
| 241 | + ForbiddenWebSocket, |
| 242 | + HTTPException, |
| 243 | + ): |
228 | 244 | if not self.auto_error: |
229 | 245 | return None |
230 | 246 | raise |
231 | 247 | except Exception as error: |
232 | 248 | if not self.auto_error: |
233 | 249 | return None |
234 | | - raise InvalidAuth(detail='Unable to validate token', request=request) from error |
| 250 | + raise InvalidRequest(detail='Unable to validate token', request=request) from error |
235 | 251 |
|
236 | 252 | async def extract_access_token(self, request: HTTPConnection) -> Optional[str]: |
237 | 253 | """ |
|
0 commit comments