Skip to content

Commit 28dcb79

Browse files
committed
fix: Support user defined http and request validation exception handlers.
1 parent 96ae8d8 commit 28dcb79

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

src/fastapi_problem/handler.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,12 @@ def add_exception_handler( # noqa: PLR0913
199199
generic_swagger_defaults: bool = True,
200200
strict_rfc9457: bool = False,
201201
) -> ExceptionHandler:
202+
handlers_ = {
203+
HTTPException: http_exception_handler,
204+
RequestValidationError: request_validation_handler,
205+
}
202206
handlers = handlers or {}
203-
handlers.update(
204-
{
205-
HTTPException: http_exception_handler,
206-
RequestValidationError: request_validation_handler,
207-
},
208-
)
207+
handlers_.update(handlers)
209208
pre_hooks = pre_hooks or []
210209
post_hooks = post_hooks or []
211210

@@ -216,7 +215,7 @@ def add_exception_handler( # noqa: PLR0913
216215
eh = ExceptionHandler(
217216
logger=logger,
218217
unhandled_wrappers=unhandled_wrappers,
219-
handlers=handlers,
218+
handlers=handlers_,
220219
pre_hooks=pre_hooks,
221220
post_hooks=post_hooks,
222221
documentation_uri_template=documentation_uri_template,

tests/test_handler.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,3 +679,25 @@ async def status(_a: str) -> dict:
679679
"description": "Validation Error",
680680
},
681681
}
682+
683+
684+
async def test_custom_http_exception_handler_in_app():
685+
def custom_handler(_eh, _request, _exc) -> error.Problem:
686+
return error.Problem("a problem")
687+
688+
app = FastAPI()
689+
690+
handler.add_exception_handler(
691+
app=app,
692+
handlers={HTTPException: custom_handler},
693+
)
694+
695+
transport = httpx.ASGITransport(app=app, raise_app_exceptions=False, client=("1.2.3.4", 123))
696+
client = httpx.AsyncClient(transport=transport, base_url="https://test")
697+
698+
r = await client.get("/endpoint")
699+
assert r.json() == {
700+
"type": "problem",
701+
"title": "a problem",
702+
"status": 500,
703+
}

0 commit comments

Comments
 (0)