Skip to content

Commit 4a0269e

Browse files
dap0amleandrodamascenadreamorosi
authored
feat(event-handler): add support for Pydantic Field discriminator in validation (#7227)
* feat(event-handler): add support for Pydantic Field discriminator in validation (#5953) Enable use of Field(discriminator='...') with tagged unions in event handler validation. This allows developers to use Pydantic's native discriminator syntax instead of requiring Powertools-specific Param annotations. - Handle Field(discriminator) + Body() combination in get_field_info_annotated_type - Preserve discriminator metadata when creating TypeAdapter in ModelField - Add comprehensive tests for discriminator validation and Field features * style(tests): remove inline comments to match project test style * style: run make format to fix CI formatting issues * fix(event-handler): preserve FieldInfo subclass types in copy_field_info Fix regression where copy_field_info was losing custom FieldInfo subclass types (Body, Query, etc.) by using shallow copy instead of from_annotation. This resolves the failing test_validate_embed_body_param while maintaining the discriminator functionality. * refactor(event-handler): reduce cognitive complexity and address SonarCloud issues - Refactor get_field_info_annotated_type function by extracting helper functions to reduce cognitive complexity from 29 to below 15 - Fix copy_field_info to preserve FieldInfo subclass types using shallow copy instead of from_annotation - Rename variable Action to action_type to follow Python naming conventions - Resolve failing test_validate_embed_body_param by maintaining Body parameter type recognition - Add helper functions: _has_discriminator, _handle_discriminator_with_body, _create_field_info, _set_field_default - Maintain full backward compatibility and discriminator functionality * style: fix formatting to pass CI format check Apply ruff formatting to params.py to resolve failing format check in CI * fix: resolve mypy type error in _create_field_info function Add explicit type annotation for field_info variable to fix mypy error about incompatible types between FieldInfo and Body. This ensures type checking passes across all Python versions (3.9-3.13). * fix: use Union syntax for Python 3.9 compatibility * feat(event-handler): add documentation and example for Field discriminator support * style: run make format to fix CI formatting issues * small changes * small changes --------- Co-authored-by: Leandro Damascena <[email protected]> Co-authored-by: Andrea Amorosi <[email protected]>
1 parent e771849 commit 4a0269e

File tree

8 files changed

+265
-27
lines changed

8 files changed

+265
-27
lines changed

aws_lambda_powertools/event_handler/openapi/compat.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33

44
from collections import deque
55
from collections.abc import Mapping, Sequence
6-
7-
# MAINTENANCE: remove when deprecating Pydantic v1. Mypy doesn't handle two different code paths that import different
8-
# versions of a module, so we need to ignore errors here.
6+
from copy import copy
97
from dataclasses import dataclass, is_dataclass
108
from typing import TYPE_CHECKING, Any, Deque, FrozenSet, List, Set, Tuple, Union
119

@@ -80,9 +78,19 @@ def type_(self) -> Any:
8078
return self.field_info.annotation
8179

8280
def __post_init__(self) -> None:
83-
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
84-
Annotated[self.field_info.annotation, self.field_info],
85-
)
81+
# If the field_info.annotation is already an Annotated type with discriminator metadata,
82+
# use it directly instead of wrapping it again
83+
annotation = self.field_info.annotation
84+
if (
85+
get_origin(annotation) is Annotated
86+
and hasattr(self.field_info, "discriminator")
87+
and self.field_info.discriminator is not None
88+
):
89+
self._type_adapter: TypeAdapter[Any] = TypeAdapter(annotation)
90+
else:
91+
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
92+
Annotated[annotation, self.field_info],
93+
)
8694

8795
def get_default(self) -> Any:
8896
if self.field_info.is_required():
@@ -176,7 +184,11 @@ def model_rebuild(model: type[BaseModel]) -> None:
176184

177185

178186
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
179-
return type(field_info).from_annotation(annotation)
187+
# Create a shallow copy of the field_info to preserve its type and all attributes
188+
new_field = copy(field_info)
189+
# Update only the annotation to the new one
190+
new_field.annotation = annotation
191+
return new_field
180192

181193

182194
def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:

aws_lambda_powertools/event_handler/openapi/params.py

Lines changed: 84 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,35 +1037,101 @@ def get_field_info_response_type(annotation, value) -> tuple[FieldInfo | None, A
10371037
return get_field_info_and_type_annotation(inner_type, value, False, True)
10381038

10391039

1040+
def _has_discriminator(field_info: FieldInfo) -> bool:
1041+
"""Check if a FieldInfo has a discriminator."""
1042+
return hasattr(field_info, "discriminator") and field_info.discriminator is not None
1043+
1044+
1045+
def _handle_discriminator_with_param(
1046+
annotations: list[FieldInfo],
1047+
annotation: Any,
1048+
) -> tuple[FieldInfo | None, Any, bool]:
1049+
"""
1050+
Handle the special case of Field(discriminator) + Body() combination.
1051+
1052+
Returns:
1053+
tuple of (powertools_annotation, type_annotation, has_discriminator_with_body)
1054+
"""
1055+
field_obj = None
1056+
body_obj = None
1057+
1058+
for ann in annotations:
1059+
if isinstance(ann, Body):
1060+
body_obj = ann
1061+
elif _has_discriminator(ann):
1062+
field_obj = ann
1063+
1064+
if field_obj and body_obj:
1065+
# Use Body as the primary annotation, preserve full annotation for validation
1066+
return body_obj, annotation, True
1067+
1068+
raise AssertionError("Only one FieldInfo can be used per parameter")
1069+
1070+
1071+
def _create_field_info(
1072+
powertools_annotation: FieldInfo,
1073+
type_annotation: Any,
1074+
has_discriminator_with_body: bool,
1075+
) -> FieldInfo:
1076+
"""Create or copy FieldInfo based on the annotation type."""
1077+
field_info: FieldInfo
1078+
if has_discriminator_with_body:
1079+
# For discriminator + Body case, create a new Body instance directly
1080+
field_info = Body()
1081+
field_info.annotation = type_annotation
1082+
else:
1083+
# Copy field_info because we mutate field_info.default later
1084+
field_info = copy_field_info(
1085+
field_info=powertools_annotation,
1086+
annotation=type_annotation,
1087+
)
1088+
return field_info
1089+
1090+
1091+
def _set_field_default(field_info: FieldInfo, value: Any, is_path_param: bool) -> None:
1092+
"""Set the default value for a field."""
1093+
if field_info.default not in [Undefined, Required]:
1094+
raise AssertionError("FieldInfo needs to have a default value of Undefined or Required")
1095+
1096+
if value is not inspect.Signature.empty:
1097+
if is_path_param:
1098+
raise AssertionError("Cannot use a FieldInfo as a path parameter and pass a value")
1099+
field_info.default = value
1100+
else:
1101+
field_info.default = Required
1102+
1103+
10401104
def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> tuple[FieldInfo | None, Any]:
10411105
"""
10421106
Get the FieldInfo and type annotation from an Annotated type.
10431107
"""
1044-
field_info: FieldInfo | None = None
10451108
annotated_args = get_args(annotation)
10461109
type_annotation = annotated_args[0]
10471110
powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)]
10481111

1049-
if len(powertools_annotations) > 1:
1050-
raise AssertionError("Only one FieldInfo can be used per parameter")
1051-
1052-
powertools_annotation = next(iter(powertools_annotations), None)
1112+
# Determine which annotation to use
1113+
powertools_annotation: FieldInfo | None = None
1114+
has_discriminator_with_param = False
10531115

1054-
if isinstance(powertools_annotation, FieldInfo):
1055-
# Copy `field_info` because we mutate `field_info.default` later
1056-
field_info = copy_field_info(
1057-
field_info=powertools_annotation,
1058-
annotation=annotation,
1116+
if len(powertools_annotations) == 2:
1117+
powertools_annotation, type_annotation, has_discriminator_with_param = _handle_discriminator_with_param(
1118+
powertools_annotations,
1119+
annotation,
10591120
)
1060-
if field_info.default not in [Undefined, Required]:
1061-
raise AssertionError("FieldInfo needs to have a default value of Undefined or Required")
1121+
elif len(powertools_annotations) > 1:
1122+
raise AssertionError("Only one FieldInfo can be used per parameter")
1123+
else:
1124+
powertools_annotation = next(iter(powertools_annotations), None)
10621125

1063-
if value is not inspect.Signature.empty:
1064-
if is_path_param:
1065-
raise AssertionError("Cannot use a FieldInfo as a path parameter and pass a value")
1066-
field_info.default = value
1067-
else:
1068-
field_info.default = Required
1126+
# Process the annotation if it exists
1127+
field_info: FieldInfo | None = None
1128+
if isinstance(powertools_annotation, FieldInfo): # pragma: no cover
1129+
field_info = _create_field_info(powertools_annotation, type_annotation, has_discriminator_with_param)
1130+
_set_field_default(field_info, value, is_path_param)
1131+
1132+
# Preserve full annotated type for discriminated unions
1133+
if _has_discriminator(powertools_annotation): # pragma: no cover
1134+
type_annotation = annotation # pragma: no cover
10691135

10701136
return field_info, type_annotation
10711137

docs/core/event_handler/api_gateway.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,17 @@ We use the `Annotated` and OpenAPI `Body` type to instruct Event Handler that ou
428428
--8<-- "examples/event_handler_rest/src/validating_payload_subset_output.json"
429429
```
430430

431+
##### Discriminated unions
432+
433+
You can use Pydantic's `Field(discriminator="...")` with union types to create discriminated unions (also known as tagged unions). This allows the Event Handler to automatically determine which model to use based on a discriminator field in the request body.
434+
435+
```python hl_lines="3 4 8 31 36" title="discriminated_unions.py"
436+
--8<-- "examples/event_handler_rest/src/discriminated_unions.py"
437+
```
438+
439+
1. `Field(discriminator="action")` tells Pydantic to use the `action` field to determine which model to instantiate
440+
2. `Body()` annotation tells the Event Handler to parse the request body using the discriminated union
441+
431442
#### Validating responses
432443

433444
You can use `response_validation_error_http_code` to set a custom HTTP code for failed response validation. When this field is set, we will raise a `ResponseValidationError` instead of a `RequestValidationError`.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from typing import Literal, Union
2+
3+
from pydantic import BaseModel, Field
4+
from typing_extensions import Annotated
5+
6+
from aws_lambda_powertools import Logger, Tracer
7+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
8+
from aws_lambda_powertools.event_handler.openapi.params import Body
9+
from aws_lambda_powertools.logging import correlation_paths
10+
from aws_lambda_powertools.utilities.typing import LambdaContext
11+
12+
tracer = Tracer()
13+
logger = Logger()
14+
app = APIGatewayRestResolver(enable_validation=True)
15+
16+
17+
class FooAction(BaseModel):
18+
"""Action type for foo operations."""
19+
20+
action: Literal["foo"] = "foo"
21+
foo_data: str
22+
23+
24+
class BarAction(BaseModel):
25+
"""Action type for bar operations."""
26+
27+
action: Literal["bar"] = "bar"
28+
bar_data: int
29+
30+
31+
ActionType = Annotated[Union[FooAction, BarAction], Field(discriminator="action")] # (1)!
32+
33+
34+
@app.post("/actions")
35+
@tracer.capture_method
36+
def handle_action(action: Annotated[ActionType, Body(description="Action to perform")]): # (2)!
37+
"""Handle different action types using discriminated unions."""
38+
if isinstance(action, FooAction):
39+
return {"message": f"Handling foo action with data: {action.foo_data}"}
40+
elif isinstance(action, BarAction):
41+
return {"message": f"Handling bar action with data: {action.bar_data}"}
42+
43+
44+
@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP)
45+
@tracer.capture_lambda_handler
46+
def lambda_handler(event: dict, context: LambdaContext) -> dict:
47+
return app.resolve(event, context)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from __future__ import annotations
2+
3+
from typing import Annotated, Literal
4+
5+
from pydantic import BaseModel, Field
6+
7+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
8+
from aws_lambda_powertools.event_handler.openapi.params import Body
9+
10+
app = APIGatewayRestResolver(enable_validation=True)
11+
app.enable_swagger()
12+
13+
14+
class FooAction(BaseModel):
15+
action: Literal["foo"]
16+
foo_data: str
17+
18+
19+
class BarAction(BaseModel):
20+
action: Literal["bar"]
21+
bar_data: int
22+
23+
24+
Action = Annotated[FooAction | BarAction, Field(discriminator="action")]
25+
26+
27+
@app.post("/data_validation_with_fields")
28+
def create_action(action: Annotated[Action, Body(discriminator="action")]):
29+
return {"message": "Powertools e2e API"}
30+
31+
32+
def lambda_handler(event, context):
33+
print(event)
34+
return app.resolve(event, context)

tests/e2e/event_handler/infrastructure.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def create_resources(self):
2424
functions["OpenapiHandler"],
2525
functions["OpenapiHandlerWithPep563"],
2626
functions["DataValidationAndMiddleware"],
27+
functions["DataValidationWithFields"],
2728
],
2829
)
2930
self._create_api_gateway_http(function=functions["ApiGatewayHttpHandler"])
@@ -105,6 +106,9 @@ def _create_api_gateway_rest(self, function: list[Function]):
105106
openapi_schema = apigw.root.add_resource("data_validation_middleware")
106107
openapi_schema.add_method("GET", apigwv1.LambdaIntegration(function[3], proxy=True))
107108

109+
openapi_schema = apigw.root.add_resource("data_validation_with_fields")
110+
openapi_schema.add_method("POST", apigwv1.LambdaIntegration(function[4], proxy=True))
111+
108112
CfnOutput(self.stack, "APIGatewayRestUrl", value=apigw.url)
109113

110114
def _create_lambda_function_url(self, function: Function):

tests/e2e/event_handler/test_openapi.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,20 @@ def test_get_openapi_validation_and_middleware(apigw_rest_endpoint):
5959
)
6060

6161
assert response.status_code == 202
62+
63+
64+
def test_openapi_with_fields_discriminator(apigw_rest_endpoint):
65+
# GIVEN
66+
url = f"{apigw_rest_endpoint}data_validation_with_fields"
67+
68+
# WHEN
69+
response = data_fetcher.get_http_response(
70+
Request(
71+
method="POST",
72+
url=url,
73+
json={"action": "foo", "foo_data": "foo data working"},
74+
),
75+
)
76+
77+
assert "Powertools e2e API" in response.text
78+
assert response.status_code == 200

tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from dataclasses import dataclass
44
from enum import Enum
55
from pathlib import PurePath
6-
from typing import List, Optional, Tuple
6+
from typing import List, Literal, Optional, Tuple, Union
77

88
import pytest
9-
from pydantic import BaseModel
9+
from pydantic import BaseModel, Field
1010
from typing_extensions import Annotated
1111

1212
from aws_lambda_powertools.event_handler import (
@@ -1983,3 +1983,50 @@ def get_user(user_id: int) -> UserModel:
19831983
assert response_body["name"] == "User123"
19841984
assert response_body["age"] == 143
19851985
assert response_body["email"] == "[email protected]"
1986+
1987+
1988+
def test_field_discriminator_validation(gw_event):
1989+
"""Test that Pydantic Field discriminator works with event_handler validation"""
1990+
app = APIGatewayRestResolver(enable_validation=True)
1991+
1992+
class FooAction(BaseModel):
1993+
action: Literal["foo"]
1994+
foo_data: str
1995+
1996+
class BarAction(BaseModel):
1997+
action: Literal["bar"]
1998+
bar_data: int
1999+
2000+
action_type = Annotated[Union[FooAction, BarAction], Field(discriminator="action")]
2001+
2002+
@app.post("/actions")
2003+
def create_action(action: Annotated[action_type, Body()]):
2004+
return {"received_action": action.action, "data": action.model_dump()}
2005+
2006+
gw_event["path"] = "/actions"
2007+
gw_event["httpMethod"] = "POST"
2008+
gw_event["headers"]["content-type"] = "application/json"
2009+
gw_event["body"] = '{"action": "foo", "foo_data": "test"}'
2010+
2011+
result = app(gw_event, {})
2012+
assert result["statusCode"] == 200
2013+
2014+
response_body = json.loads(result["body"])
2015+
assert response_body["received_action"] == "foo"
2016+
assert response_body["data"]["action"] == "foo"
2017+
assert response_body["data"]["foo_data"] == "test"
2018+
2019+
gw_event["body"] = '{"action": "bar", "bar_data": 123}'
2020+
2021+
result = app(gw_event, {})
2022+
assert result["statusCode"] == 200
2023+
2024+
response_body = json.loads(result["body"])
2025+
assert response_body["received_action"] == "bar"
2026+
assert response_body["data"]["action"] == "bar"
2027+
assert response_body["data"]["bar_data"] == 123
2028+
2029+
gw_event["body"] = '{"action": "invalid", "some_data": "test"}'
2030+
2031+
result = app(gw_event, {})
2032+
assert result["statusCode"] == 422

0 commit comments

Comments
 (0)