diff --git a/CHANGELOG.md b/CHANGELOG.md index 409030f..c18f53f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## 0.8 + +### Added + +- **Headers**: Added `HTTPSRedirectMiddleware` that redirects HTTP to HTTPS while bypassing requests from `localhost` (matching `^localhost(:\d+)?$`), allowing Kubernetes liveness/readiness probes to work over plain HTTP. + ## 0.7 ### ⚠️ Breaking Changes & Migration Guide diff --git a/README.md b/README.md index 3c5e361..f530d10 100644 --- a/README.md +++ b/README.md @@ -392,8 +392,7 @@ Add in your application: ```python import c2casgiutils -from c2casgiutils import broadcast -from c2casgiutils import config +from c2casgiutils import broadcast, config, headers from prometheus_client import start_http_server from prometheus_fastapi_instrumentator import Instrumentator from contextlib import asynccontextmanager @@ -419,9 +418,9 @@ app.add_middleware( allowed_hosts=["*"], # Configure with specific hosts in production ) -# Add HTTPSRedirectMiddleware -if config.settings.http: - app.add_middleware(HTTPSRedirectMiddleware) +# Redirect HTTP to HTTPS (except for localhost) +if not config.settings.http: + app.add_middleware(headers.HTTPSRedirectMiddleware) # Add GZipMiddleware app.add_middleware(GZipMiddleware, minimum_size=1000) diff --git a/acceptance_tests/fastapi_app/fastapi_app/main.py b/acceptance_tests/fastapi_app/fastapi_app/main.py index 648f433..8fe80d7 100644 --- a/acceptance_tests/fastapi_app/fastapi_app/main.py +++ b/acceptance_tests/fastapi_app/fastapi_app/main.py @@ -9,7 +9,6 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware -from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware from fastapi.middleware.trustedhost import TrustedHostMiddleware from prometheus_client import start_http_server from prometheus_fastapi_instrumentator import Instrumentator @@ -55,9 +54,9 @@ async def _lifespan(main_app: FastAPI) -> AsyncGenerator[None, None]: allowed_hosts=["*"], # Configure with specific hosts in production ) -# Add HTTPSRedirectMiddleware +# Redirect HTTP to HTTPS (except for localhost) if not config.settings.http: - app.add_middleware(HTTPSRedirectMiddleware) + app.add_middleware(headers.HTTPSRedirectMiddleware) # Add GZipMiddleware app.add_middleware(GZipMiddleware, minimum_size=1000) diff --git a/c2casgiutils/headers.py b/c2casgiutils/headers.py index 7acca89..6e8016c 100644 --- a/c2casgiutils/headers.py +++ b/c2casgiutils/headers.py @@ -6,15 +6,17 @@ from typing import TypedDict from pydantic import BaseModel +from starlette.datastructures import URL from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request -from starlette.responses import Response -from starlette.types import ASGIApp +from starlette.responses import RedirectResponse, Response +from starlette.types import ASGIApp, Receive, Scope, Send _LOGGER = logging.getLogger(__name__) # Content type matcher _HTML_CONTENT_TYPE_MATCH = r"^text/html(?:;|$)" +_LOCALHOST_NETLOC_RE = re.compile(r"^localhost(:\d+)?$") Header = str | list[str] | dict[str, str] | dict[str, list[str]] | None @@ -192,6 +194,35 @@ def _build_header( } +class HTTPSRedirectMiddleware: + r""" + Middleware that redirects HTTP requests to HTTPS and WebSocket requests to WSS. + + Requests from ``localhost`` (matching ``^localhost(:\d+)?$``) are passed through + without a redirect so that Kubernetes liveness/readiness probes sent over plain HTTP + continue to work even when HTTPS enforcement is enabled. + """ + + def __init__(self, app: ASGIApp) -> None: + """Initialize the middleware.""" + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Handle ASGI requests, redirecting HTTP to HTTPS and WS to WSS unless the host is localhost.""" + if scope["type"] in ("http", "websocket"): + url = URL(scope=scope) + if not _LOCALHOST_NETLOC_RE.match(url.netloc): + # Map schemes: http -> https, ws -> wss + scheme_map = {"http": "https", "ws": "wss"} + new_scheme = scheme_map.get(url.scheme) + if new_scheme: + redirect_url = url.replace(scheme=new_scheme) + response = RedirectResponse(url=str(redirect_url), status_code=307) + await response(scope, receive, send) + return + await self.app(scope, receive, send) + + class ArmorHeaderMiddleware(BaseHTTPMiddleware): """Middleware to add headers to responses based on request netloc (host:port) and path.""" diff --git a/scaffold/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/main.py b/scaffold/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/main.py index abda12c..58df076 100644 --- a/scaffold/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/main.py +++ b/scaffold/{{cookiecutter.project_slug}}/{{cookiecutter.project_slug}}/main.py @@ -9,7 +9,6 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware -from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware from fastapi.middleware.trustedhost import TrustedHostMiddleware from prometheus_client import start_http_server from prometheus_fastapi_instrumentator import Instrumentator @@ -56,9 +55,9 @@ async def _lifespan(main_app: FastAPI) -> AsyncGenerator[None, None]: allowed_hosts=["*"], # Configure with specific hosts in production ) -# Add HTTPSRedirectMiddleware +# Redirect HTTP to HTTPS (except for localhost) if not config.settings.http: - app.add_middleware(HTTPSRedirectMiddleware) + app.add_middleware(headers.HTTPSRedirectMiddleware) # Add GZipMiddleware app.add_middleware(GZipMiddleware, minimum_size=1000) diff --git a/test/test_header.py b/test/test_header.py index df9a161..27534ca 100644 --- a/test/test_header.py +++ b/test/test_header.py @@ -6,9 +6,11 @@ import pytest from starlette.applications import Starlette from starlette.requests import Request -from starlette.responses import Response +from starlette.responses import PlainTextResponse, Response +from starlette.routing import Route +from starlette.testclient import TestClient -from c2casgiutils.headers import ArmorHeaderMiddleware, _build_header +from c2casgiutils.headers import ArmorHeaderMiddleware, HTTPSRedirectMiddleware, _build_header class State: @@ -813,3 +815,49 @@ async def simple_app(scope, receive, send): valid_base64 = False assert valid_base64, "Nonce should be valid base64" + + +# Tests for HTTPSRedirectMiddleware + + +def _make_app(): + """Create a simple test app wrapped with HTTPSRedirectMiddleware.""" + + def homepage(request): + return PlainTextResponse("OK") + + inner = Starlette(routes=[Route("/", homepage)]) + return HTTPSRedirectMiddleware(inner) + + +def test_https_redirect_middleware_localhost_no_redirect(): + """Requests from localhost should pass through without redirect.""" + app = _make_app() + client = TestClient(app, base_url="http://localhost", follow_redirects=False) + response = client.get("/") + assert response.status_code == 200 + + +def test_https_redirect_middleware_localhost_with_port_no_redirect(): + """Requests from localhost:port should pass through without redirect.""" + app = _make_app() + client = TestClient(app, base_url="http://localhost:8080", follow_redirects=False) + response = client.get("/") + assert response.status_code == 200 + + +def test_https_redirect_middleware_external_host_redirects(): + """Requests from non-localhost hosts should be redirected to HTTPS.""" + app = _make_app() + client = TestClient(app, base_url="http://example.com", follow_redirects=False) + response = client.get("/") + assert response.status_code == 307 + assert response.headers["location"] == "https://example.com/" + + +def test_https_redirect_middleware_already_https_no_redirect(): + """Requests already using HTTPS should not be redirected.""" + app = _make_app() + client = TestClient(app, base_url="https://example.com", follow_redirects=False) + response = client.get("/") + assert response.status_code == 200