Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion fastpubsub/api/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""FastAPI application setup and configuration."""

from fastapi import FastAPI, Request, status
from fastapi.responses import ORJSONResponse
from fastapi.responses import JSONResponse, ORJSONResponse
from prometheus_fastapi_instrumentator import Instrumentator

from fastpubsub import models
Expand Down Expand Up @@ -132,6 +132,25 @@ def invalid_client_token_exception_handler(request: Request, exc: InvalidClientT
"""
return _create_error_response(models.GenericError, status.HTTP_403_FORBIDDEN, exc)

@app.exception_handler(Exception)
def generic_exception_handler(request: Request, exc: Exception):
"""Handle generic Exception instances.

Catches any unhandled exceptions that don't have specific handlers.
Returns a generic 500 Internal Server Error response to avoid leaking
sensitive information about the application internals.

Args:
request: The incoming HTTP request that caused the exception.
exc: The unhandled exception that was raised.

Returns:
JSON error response with 500 status code and generic error message.
"""
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "internal server error"}
)

# Add routers
app.include_router(topics.router)
app.include_router(subscriptions.router)
Expand Down
47 changes: 44 additions & 3 deletions fastpubsub/services/auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Authentication and authorization services for fastpubsub."""

import time
from typing import Annotated

from fastapi import Depends, Request
Expand All @@ -8,8 +9,11 @@
from fastpubsub import services
from fastpubsub.config import settings
from fastpubsub.exceptions import InvalidClientToken
from fastpubsub.logger import get_logger
from fastpubsub.models import DecodedClientToken

logger = get_logger(__name__)

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/oauth/token", auto_error=False)


Expand Down Expand Up @@ -54,9 +58,46 @@ async def get_current_token(token: str | None = Depends(oauth2_scheme)) -> Decod
Raises:
InvalidClientToken: If token is invalid or authentication fails.
"""
if token is None:
token = ""
return await services.decode_jwt_client_token(token, auth_enabled=settings.auth_enabled)
start_time = time.perf_counter()

try:
if token is None:
token = ""

decoded_token = await services.decode_jwt_client_token(token, auth_enabled=settings.auth_enabled)

duration = time.perf_counter() - start_time
logger.debug(
"token validated",
extra={
"client_id": str(decoded_token.client_id),
"scopes": list(decoded_token.scopes),
"duration": f"{duration:.4f}s",
},
)
return decoded_token
except InvalidClientToken as e:
duration = time.perf_counter() - start_time
logger.warning(
"token validation failed",
extra={
"error": str(e),
"has_token": token is not None and token != "",
"duration": f"{duration:.4f}s",
},
)
raise
except Exception as e:
duration = time.perf_counter() - start_time
logger.error(
"token validation error",
extra={
"error": str(e),
"has_token": token is not None and token != "",
"duration": f"{duration:.4f}s",
},
)
raise


def require_scope(resource: str, action: str):
Expand Down
206 changes: 152 additions & 54 deletions fastpubsub/services/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datetime
import secrets
import time
import uuid

from jose import jwt
Expand All @@ -13,6 +14,7 @@
from fastpubsub.database import Client as DBClient
from fastpubsub.database import SessionLocal
from fastpubsub.exceptions import InvalidClient
from fastpubsub.logger import get_logger
from fastpubsub.models import (
Client,
ClientToken,
Expand All @@ -24,6 +26,7 @@
from fastpubsub.services.helpers import _delete_entity, _get_entity, utc_now

password_hash = PasswordHash.recommended()
logger = get_logger(__name__)


def generate_secret() -> str:
Expand Down Expand Up @@ -54,25 +57,44 @@ async def create_client(data: CreateClient) -> CreateClientResult:
AlreadyExistsError: If a client with the same ID already exists.
ValueError: If client data validation fails.
"""
async with SessionLocal() as session:
now = utc_now()
secret = generate_secret()
secret_hash = password_hash.hash(secret)
db_client = DBClient(
id=uuid.uuid7(),
name=data.name,
scopes=data.scopes,
is_active=data.is_active,
secret_hash=secret_hash,
token_version=1,
created_at=now,
updated_at=now,
)
session.add(db_client)

await session.commit()
start_time = time.perf_counter()
logger.info(
"creating client",
extra={"client_name": data.name, "scopes": data.scopes, "is_active": data.is_active},
)

return CreateClientResult(id=db_client.id, secret=secret)
try:
async with SessionLocal() as session:
now = utc_now()
secret = generate_secret()
secret_hash = password_hash.hash(secret)
db_client = DBClient(
id=uuid.uuid7(),
name=data.name,
scopes=data.scopes,
is_active=data.is_active,
secret_hash=secret_hash,
token_version=1,
created_at=now,
updated_at=now,
)
session.add(db_client)

await session.commit()

duration = time.perf_counter() - start_time
logger.info(
"client created",
extra={"client_id": str(db_client.id), "client_name": data.name, "duration": f"{duration:.4f}s"},
)
return CreateClientResult(id=db_client.id, secret=secret)
except Exception as e:
duration = time.perf_counter() - start_time
logger.error(
"client creation failed",
extra={"client_name": data.name, "error": str(e), "duration": f"{duration:.4f}s"},
)
raise


async def get_client(client_id: uuid.UUID) -> Client:
Expand Down Expand Up @@ -175,31 +197,66 @@ async def issue_jwt_client_token(client_id: uuid.UUID, client_secret: str) -> Cl
Raises:
InvalidClient: If client credentials are invalid or client is disabled.
"""
async with SessionLocal() as session:
db_client = await _get_entity(session, DBClient, client_id, "Client not found", raise_exception=False)
if not db_client:
raise InvalidClient("Client not found") from None
if not db_client.is_active:
raise InvalidClient("Client disabled") from None
if password_hash.verify(client_secret, db_client.secret_hash) is False:
raise InvalidClient("Client secret is invalid") from None

now = utc_now()
expires_in = now + datetime.timedelta(minutes=settings.auth_access_token_expire_minutes)
payload = {
"sub": str(client_id),
"exp": expires_in,
"iat": now,
"scope": db_client.scopes,
"ver": db_client.token_version,
}
access_token = jwt.encode(payload, key=settings.auth_secret_key, algorithm=settings.auth_algorithm)

return ClientToken(
access_token=access_token,
expires_in=int((expires_in - now).total_seconds()),
scope=db_client.scopes,
)
start_time = time.perf_counter()
logger.info("issuing jwt token", extra={"client_id": str(client_id)})

try:
async with SessionLocal() as session:
db_client = await _get_entity(
session, DBClient, client_id, "Client not found", raise_exception=False
)
if not db_client:
logger.warning("token issuance failed: client not found", extra={"client_id": str(client_id)})
raise InvalidClient("Client not found") from None
if not db_client.is_active:
logger.warning(
"token issuance failed: client disabled",
extra={"client_id": str(client_id), "client_name": db_client.name},
)
raise InvalidClient("Client disabled") from None
if password_hash.verify(client_secret, db_client.secret_hash) is False:
logger.warning(
"token issuance failed: invalid secret",
extra={"client_id": str(client_id), "client_name": db_client.name},
)
raise InvalidClient("Client secret is invalid") from None

now = utc_now()
expires_in = now + datetime.timedelta(minutes=settings.auth_access_token_expire_minutes)
payload = {
"sub": str(client_id),
"exp": expires_in,
"iat": now,
"scope": db_client.scopes,
"ver": db_client.token_version,
}
access_token = jwt.encode(
payload, key=settings.auth_secret_key, algorithm=settings.auth_algorithm
)

duration = time.perf_counter() - start_time
logger.info(
"jwt token issued",
extra={
"client_id": str(client_id),
"client_name": db_client.name,
"scopes": db_client.scopes,
"expires_in_minutes": settings.auth_access_token_expire_minutes,
"duration": f"{duration:.4f}s",
},
)
return ClientToken(
access_token=access_token,
expires_in=int((expires_in - now).total_seconds()),
scope=db_client.scopes,
)
except Exception as e:
duration = time.perf_counter() - start_time
logger.error(
"token issuance failed",
extra={"client_id": str(client_id), "error": str(e), "duration": f"{duration:.4f}s"},
)
raise


async def decode_jwt_client_token(access_token: str, auth_enabled: bool = True) -> DecodedClientToken:
Expand All @@ -219,28 +276,69 @@ async def decode_jwt_client_token(access_token: str, auth_enabled: bool = True)
InvalidClient: If token is invalid, expired, or client is disabled/revoked.
"""
if not auth_enabled:
logger.debug("authentication disabled, returning test token")
return DecodedClientToken(client_id=uuid.uuid7(), scopes={"*"})

start_time = time.perf_counter()
logger.debug("decoding jwt token")

try:
payload = jwt.decode(
access_token,
key=settings.auth_secret_key,
algorithms=[settings.auth_algorithm],
)
except JWTError:
except JWTError as e:
logger.warning("jwt token decode failed: invalid token", extra={"error": str(e)})
raise InvalidClient("Invalid jwt token") from None

client_id = payload["sub"]
scopes = payload["scope"]
token_version = payload["ver"]

async with SessionLocal() as session:
db_client = await _get_entity(session, DBClient, client_id, "Client not found", raise_exception=False)
if not db_client:
raise InvalidClient("Client not found") from None
if not db_client.is_active:
raise InvalidClient("Client disabled") from None
if token_version != db_client.token_version:
raise InvalidClient("Token revoked") from None

return DecodedClientToken(client_id=uuid.UUID(client_id), scopes={scope for scope in scopes.split()})
try:
async with SessionLocal() as session:
db_client = await _get_entity(
session, DBClient, client_id, "Client not found", raise_exception=False
)
if not db_client:
logger.warning(
"jwt token validation failed: client not found", extra={"client_id": client_id}
)
raise InvalidClient("Client not found") from None
if not db_client.is_active:
logger.warning(
"jwt token validation failed: client disabled",
extra={"client_id": client_id, "client_name": db_client.name},
)
raise InvalidClient("Client disabled") from None
if token_version != db_client.token_version:
logger.warning(
"jwt token validation failed: token revoked",
extra={
"client_id": client_id,
"client_name": db_client.name,
"token_version": token_version,
"current_version": db_client.token_version,
},
)
raise InvalidClient("Token revoked") from None

duration = time.perf_counter() - start_time
logger.debug(
"jwt token validated",
extra={
"client_id": client_id,
"client_name": db_client.name,
"scopes": scopes,
"duration": f"{duration:.4f}s",
},
)
return DecodedClientToken(client_id=uuid.UUID(client_id), scopes={scope for scope in scopes.split()})
except Exception as e:
duration = time.perf_counter() - start_time
logger.error(
"jwt token validation failed",
extra={"client_id": client_id, "error": str(e), "duration": f"{duration:.4f}s"},
)
raise
Loading