Skip to content

Commit 8f388af

Browse files
davidhuserJonasKs
authored andcommitted
fix: mypy issue ony incompatible types in assignment
1 parent 52a5bc0 commit 8f388af

File tree

14 files changed

+299
-118
lines changed

14 files changed

+299
-118
lines changed

demo_project/api/dependencies.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
MultiTenantAzureAuthorizationCodeBearer,
1212
SingleTenantAzureAuthorizationCodeBearer,
1313
)
14-
from fastapi_azure_auth.exceptions import InvalidAuthHttp
14+
from fastapi_azure_auth.exceptions import ForbiddenHttp, UnauthorizedHttp
1515
from fastapi_azure_auth.user import User
1616

1717
log = logging.getLogger(__name__)
@@ -30,7 +30,7 @@ async def validate_is_admin_user(user: User = Depends(azure_scheme)) -> None:
3030
Raises a 401 authentication error if not.
3131
"""
3232
if 'AdminUser' not in user.roles:
33-
raise InvalidAuthHttp('User is not an AdminUser')
33+
raise ForbiddenHttp('User is not an AdminUser')
3434

3535

3636
class IssuerFetcher:
@@ -44,7 +44,7 @@ def __init__(self) -> None:
4444
async def __call__(self, tid: str) -> str:
4545
"""
4646
Check if memory cache needs to be updated or not, and then returns an issuer for a given tenant
47-
:raises InvalidAuth when it's not a valid tenant
47+
:raises Unauthorized when it's not a valid tenant
4848
"""
4949
refresh_time = datetime.now() - timedelta(hours=1)
5050
if not self._config_timestamp or self._config_timestamp < refresh_time:
@@ -58,7 +58,7 @@ async def __call__(self, tid: str) -> str:
5858
return self.tid_to_iss[tid]
5959
except Exception as error:
6060
log.exception('`iss` not found for `tid` %s. Error %s', tid, error)
61-
raise InvalidAuthHttp('You must be an Intility customer to access this resource')
61+
raise UnauthorizedHttp('You must be an Intility customer to access this resource')
6262

6363

6464
issuer_fetcher = IssuerFetcher()
@@ -101,7 +101,7 @@ async def multi_auth(
101101
return azure_auth
102102
if api_key == 'JonasIsCool':
103103
return api_key
104-
raise InvalidAuthHttp('You must either provide a valid bearer token or API key')
104+
raise UnauthorizedHttp('You must either provide a valid bearer token or API key')
105105

106106

107107
async def multi_auth_b2c(
@@ -115,4 +115,4 @@ async def multi_auth_b2c(
115115
return azure_auth
116116
if api_key == 'JonasIsCool':
117117
return api_key
118-
raise InvalidAuthHttp('You must either provide a valid bearer token or API key')
118+
raise UnauthorizedHttp('You must either provide a valid bearer token or API key')

demo_project/core/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ class AzureActiveDirectory(BaseSettings): # type: ignore[misc, valid-type]
1313
OPENAPI_CLIENT_ID: str = Field(default='')
1414
TENANT_ID: str = Field(default='')
1515
APP_CLIENT_ID: str = Field(default='')
16-
AUTH_URL: AnyHttpUrl = Field(default='https://dummy.com/')
17-
CONFIG_URL: AnyHttpUrl = Field(default='https://dummy.com/')
18-
TOKEN_URL: AnyHttpUrl = Field(default='https://dummy.com/')
16+
AUTH_URL: AnyHttpUrl = Field(default=AnyHttpUrl('https://dummy.com/'))
17+
CONFIG_URL: AnyHttpUrl = Field(default=AnyHttpUrl('https://dummy.com/'))
18+
TOKEN_URL: AnyHttpUrl = Field(default=AnyHttpUrl('https://dummy.com/'))
1919
GRAPH_SECRET: str = Field(default='')
2020
CLIENT_SECRET: str = Field(default='')
2121

docs/docs/multi-tenant/accept_specific_tenants_only.mdx

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ from fastapi.middleware.cors import CORSMiddleware
1515
from pydantic import AnyHttpUrl
1616
from pydantic_settings import BaseSettings, SettingsConfigDict
1717
from fastapi_azure_auth import MultiTenantAzureAuthorizationCodeBearer
18-
from fastapi_azure_auth.exceptions import InvalidAuth
18+
from fastapi_azure_auth.exceptions import Unauthorized
1919

2020

2121
class Settings(BaseSettings):
@@ -56,7 +56,7 @@ async def check_if_valid_tenant(tid: str) -> str:
5656
try:
5757
return tid_to_iss_mapping[tid]
5858
except KeyError:
59-
raise InvalidAuth('Tenant not allowed')
59+
raise Unauthorized('Tenant not allowed')
6060

6161
azure_scheme = MultiTenantAzureAuthorizationCodeBearer(
6262
app_client_id=settings.APP_CLIENT_ID,
@@ -86,7 +86,7 @@ if __name__ == '__main__':
8686
```
8787

8888
We're first creating an `async function`, which takes a `tid` as an argument, and returns the tenant ID's `iss` if it's a valid tenant.
89-
If it's not a valid tenant, it has to raise an `InvalidAuth()` exception.
89+
If it's not a valid tenant, it has to raise an `Unauthorized()` exception.
9090

9191
## More sophisticated callable
9292
If you want to cache these results in memory, you can do so by creating a more sophisticated callable:
@@ -103,7 +103,7 @@ class IssuerFetcher:
103103
async def __call__(self, tid: str) -> str:
104104
"""
105105
Check if memory cache needs to be updated or not, and then returns an issuer for a given tenant
106-
:raises InvalidAuth when it's not a valid tenant
106+
:raises Unauthorized when it's not a valid tenant
107107
"""
108108
refresh_time = datetime.now() - timedelta(hours=1)
109109
if not self._config_timestamp or self._config_timestamp < refresh_time:
@@ -117,7 +117,7 @@ class IssuerFetcher:
117117
return self.tid_to_iss[tid]
118118
except Exception as error:
119119
log.exception('`iss` not found for `tid` %s. Error %s', tid, error)
120-
raise InvalidAuth('You must be an Intility customer to access this resource')
120+
raise Unauthorized('You must be an Intility customer to access this resource')
121121

122122

123123
issuer_fetcher = IssuerFetcher()

docs/docs/usage-and-faq/guest_users.mdx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@ would like to lock down specific endpoints.
3939

4040
```python title="security.py"
4141
from fastapi import Depends
42-
from fastapi_azure_auth.exceptions import InvalidAuth
42+
from fastapi_azure_auth.exceptions import Unauthorized
4343
from fastapi_azure_auth.user import User
4444

4545
async def deny_guest_users(user: User = Depends(azure_scheme)) -> None:
4646
"""
4747
Deny guest users
4848
"""
4949
if user.is_guest:
50-
raise InvalidAuth('Guest user not allowed')
50+
raise Unauthorized('Guest user not allowed')
5151
```
5252

5353

@@ -57,15 +57,15 @@ Alternatively, after [FastAPI 0.95.0](https://github.com/tiangolo/fastapi/releas
5757
```python title="security.py"
5858
from typing import Annotated
5959
from fastapi import Depends
60-
from fastapi_azure_auth.exceptions import InvalidAuth
60+
from fastapi_azure_auth.exceptions import Unauthorized
6161
from fastapi_azure_auth.user import User
6262

6363
async def deny_guest_users(user: User = Depends(azure_scheme)) -> None:
6464
"""
6565
Deny guest users
6666
"""
6767
if user.is_guest:
68-
raise InvalidAuth('Guest user not allowed')
68+
raise Unauthorized('Guest user not allowed')
6969

7070
NonGuestUser = Annotated[User, Depends(deny_guest_users)]
7171
```

docs/docs/usage-and-faq/locking_down_on_roles.mdx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ You can lock down on roles by creating your own wrapper dependency:
3030

3131
```python title="dependencies.py"
3232
from fastapi import Depends
33-
from fastapi_azure_auth.exceptions import InvalidAuth
33+
from fastapi_azure_auth.exceptions import Unauthorized
3434
from fastapi_azure_auth.user import User
3535

3636
async def validate_is_admin_user(user: User = Depends(azure_scheme)) -> None:
@@ -39,7 +39,7 @@ async def validate_is_admin_user(user: User = Depends(azure_scheme)) -> None:
3939
Raises a 401 authentication error if not.
4040
"""
4141
if 'AdminUser' not in user.roles:
42-
raise InvalidAuth('User is not an AdminUser')
42+
raise Unauthorized('User is not an AdminUser')
4343
```
4444

4545
and then use this dependency over `azure_scheme`.
@@ -51,7 +51,7 @@ Alternatively, after [FastAPI 0.95.0](https://github.com/tiangolo/fastapi/releas
5151
```python title="security.py"
5252
from typing import Annotated
5353
from fastapi import Depends
54-
from fastapi_azure_auth.exceptions import InvalidAuth
54+
from fastapi_azure_auth.exceptions import Unauthorized
5555
from fastapi_azure_auth.user import User
5656

5757
async def validate_is_admin_user(user: User = Depends(azure_scheme)) -> None:
@@ -60,7 +60,7 @@ async def validate_is_admin_user(user: User = Depends(azure_scheme)) -> None:
6060
Raises a 401 authentication error if not.
6161
"""
6262
if 'AdminUser' not in user.roles:
63-
raise InvalidAuth('User is not an AdminUser')
63+
raise Unauthorized('User is not an AdminUser')
6464

6565
AdminUser = Annotated[User, Depends(validate_is_admin_user)]
6666
```

fastapi_azure_auth/auth.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,16 @@
1717
)
1818
from starlette.requests import HTTPConnection
1919

20-
from fastapi_azure_auth.exceptions import InvalidAuth, InvalidAuthHttp, InvalidAuthWebSocket
20+
from fastapi_azure_auth.exceptions import (
21+
Forbidden,
22+
ForbiddenHttp,
23+
ForbiddenWebSocket,
24+
InvalidRequest,
25+
InvalidRequestHttp,
26+
Unauthorized,
27+
UnauthorizedHttp,
28+
UnauthorizedWebSocket,
29+
)
2130
from fastapi_azure_auth.openid_config import OpenIdConfig
2231
from fastapi_azure_auth.user import User
2332
from fastapi_azure_auth.utils import get_unverified_claims, get_unverified_header, is_guest
@@ -148,28 +157,28 @@ async def __call__(self, request: HTTPConnection, security_scopes: SecurityScope
148157
access_token = await self.extract_access_token(request)
149158
try:
150159
if access_token is None:
151-
raise InvalidAuth('No access token provided', request=request)
160+
raise InvalidRequest('No access token provided', request=request)
152161
# Extract header information of the token.
153162
header: dict[str, Any] = get_unverified_header(access_token)
154163
claims: dict[str, Any] = get_unverified_claims(access_token)
155164
except Exception as error:
156165
log.warning('Malformed token received. %s. Error: %s', access_token, error, exc_info=True)
157-
raise InvalidAuth(detail='Invalid token format', request=request) from error
166+
raise Unauthorized(detail='Invalid token format', request=request) from error
158167

159168
user_is_guest: bool = is_guest(claims=claims)
160169
if not self.allow_guest_users and user_is_guest:
161170
log.info('User denied, is a guest user', claims)
162-
raise InvalidAuth(detail='Guest users not allowed', request=request)
171+
raise Forbidden(detail='Guest users not allowed', request=request)
163172

164173
for scope in security_scopes.scopes:
165174
token_scope_string = claims.get('scp', '')
166175
log.debug('Scopes: %s', token_scope_string)
167176
if not isinstance(token_scope_string, str):
168-
raise InvalidAuth('Token contains invalid formatted scopes', request=request)
177+
raise Forbidden('Token contains invalid formatted scopes', request=request)
169178

170179
token_scopes = token_scope_string.split(' ')
171180
if scope not in token_scopes:
172-
raise InvalidAuth('Required scope missing', request=request)
181+
raise Forbidden('Required scope missing', request=request)
173182
# Load new config if old
174183
await self.openid_config.load_config()
175184

@@ -211,27 +220,34 @@ async def __call__(self, request: HTTPConnection, security_scopes: SecurityScope
211220
MissingRequiredClaimError,
212221
) as error:
213222
log.info('Token contains invalid claims. %s', error)
214-
raise InvalidAuth(detail='Token contains invalid claims', request=request) from error
223+
raise Unauthorized(detail='Token contains invalid claims', request=request) from error
215224
except ExpiredSignatureError as error:
216225
log.info('Token signature has expired. %s', error)
217-
raise InvalidAuth(detail='Token signature has expired', request=request) from error
226+
raise Unauthorized(detail='Token signature has expired', request=request) from error
218227
except InvalidTokenError as error:
219228
log.warning('Invalid token. Error: %s', error, exc_info=True)
220-
raise InvalidAuth(detail='Unable to validate token', request=request) from error
229+
raise Unauthorized(detail='Unable to validate token', request=request) from error
221230
except Exception as error:
222231
# Extra failsafe in case of a bug in a future version of the jwt library
223232
log.exception('Unable to process jwt token. Uncaught error: %s', error)
224-
raise InvalidAuth(detail='Unable to process token', request=request) from error
233+
raise Unauthorized(detail='Unable to process token', request=request) from error
225234
log.warning('Unable to verify token. No signing keys found')
226-
raise InvalidAuth(detail='Unable to verify token, no signing keys found', request=request)
227-
except (InvalidAuthHttp, InvalidAuthWebSocket, HTTPException):
235+
raise Unauthorized(detail='Unable to verify token, no signing keys found', request=request)
236+
except (
237+
InvalidRequestHttp,
238+
UnauthorizedHttp,
239+
UnauthorizedWebSocket,
240+
ForbiddenHttp,
241+
ForbiddenWebSocket,
242+
HTTPException,
243+
):
228244
if not self.auto_error:
229245
return None
230246
raise
231247
except Exception as error:
232248
if not self.auto_error:
233249
return None
234-
raise InvalidAuth(detail='Unable to validate token', request=request) from error
250+
raise InvalidRequest(detail='Unable to validate token', request=request) from error
235251

236252
async def extract_access_token(self, request: HTTPConnection) -> Optional[str]:
237253
"""

fastapi_azure_auth/exceptions.py

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,80 @@
44
from starlette.requests import HTTPConnection
55

66

7-
class InvalidAuthHttp(HTTPException):
8-
"""
9-
Exception raised when the user is not authorized over HTTP
10-
"""
7+
class InvalidRequestHttp(HTTPException):
8+
"""HTTP exception for malformed/invalid requests"""
119

1210
def __init__(self, detail: str) -> None:
1311
super().__init__(
14-
status_code=status.HTTP_401_UNAUTHORIZED, detail=detail, headers={'WWW-Authenticate': 'Bearer'}
12+
status_code=status.HTTP_400_BAD_REQUEST, detail={"error": "invalid_request", "message": detail}
1513
)
1614

1715

18-
class InvalidAuthWebSocket(WebSocketException):
19-
"""
20-
Exception raised when the user is not authorized over WebSockets
21-
"""
16+
class InvalidRequestWebSocket(WebSocketException):
17+
"""WebSocket exception for malformed/invalid requests"""
2218

2319
def __init__(self, detail: str) -> None:
2420
super().__init__(
25-
code=status.WS_1008_POLICY_VIOLATION,
26-
reason=detail,
21+
code=status.WS_1008_POLICY_VIOLATION, reason=str({"error": "invalid_request", "message": detail})
2722
)
2823

2924

30-
def InvalidAuth(detail: str, request: HTTPConnection) -> InvalidAuthHttp | InvalidAuthWebSocket:
31-
"""
32-
Returns the correct exception based on the connection type
33-
"""
25+
class UnauthorizedHttp(HTTPException):
26+
"""HTTP exception for authentication failures"""
27+
28+
def __init__(self, detail: str) -> None:
29+
super().__init__(
30+
status_code=status.HTTP_401_UNAUTHORIZED,
31+
detail={"error": "invalid_token", "message": detail},
32+
headers={"WWW-Authenticate": "Bearer"},
33+
)
34+
35+
36+
class UnauthorizedWebSocket(WebSocketException):
37+
"""WebSocket exception for authentication failures"""
38+
39+
def __init__(self, detail: str) -> None:
40+
super().__init__(
41+
code=status.WS_1008_POLICY_VIOLATION, reason=str({"error": "invalid_token", "message": detail})
42+
)
43+
44+
45+
class ForbiddenHttp(HTTPException):
46+
"""HTTP exception for insufficient permissions"""
47+
48+
def __init__(self, detail: str) -> None:
49+
super().__init__(
50+
status_code=status.HTTP_403_FORBIDDEN,
51+
detail={"error": "insufficient_scope", "message": detail},
52+
headers={"WWW-Authenticate": "Bearer"},
53+
)
54+
55+
56+
class ForbiddenWebSocket(WebSocketException):
57+
"""WebSocket exception for insufficient permissions"""
58+
59+
def __init__(self, detail: str) -> None:
60+
super().__init__(
61+
code=status.WS_1008_POLICY_VIOLATION, reason=str({"error": "insufficient_scope", "message": detail})
62+
)
63+
64+
65+
def InvalidRequest(detail: str, request: HTTPConnection) -> InvalidRequestHttp | InvalidRequestWebSocket:
66+
"""Factory function for invalid request exceptions (HTTP only, as request validation happens pre-connection)"""
3467
if request.scope['type'] == 'http':
35-
return InvalidAuthHttp(detail)
36-
return InvalidAuthWebSocket(detail)
68+
return InvalidRequestHttp(detail)
69+
return InvalidRequestWebSocket(detail)
70+
71+
72+
def Unauthorized(detail: str, request: HTTPConnection) -> UnauthorizedHttp | UnauthorizedWebSocket:
73+
"""Factory function for unauthorized exceptions"""
74+
if request.scope["type"] == "http":
75+
return UnauthorizedHttp(detail)
76+
return UnauthorizedWebSocket(detail)
77+
78+
79+
def Forbidden(detail: str, request: HTTPConnection) -> ForbiddenHttp | ForbiddenWebSocket:
80+
"""Factory function for forbidden exceptions"""
81+
if request.scope["type"] == "http":
82+
return ForbiddenHttp(detail)
83+
return ForbiddenWebSocket(detail)

0 commit comments

Comments
 (0)