Skip to content

Commit 8f1c683

Browse files
fix(event_handler): parse single list items in form data (#7415)
* fix: parse single list items in form data * fix: reduce cognitive complexity * feat: add tests for requests with empty body * feat: test empty body events --------- Co-authored-by: Leandro Damascena <[email protected]>
1 parent d725d9d commit 8f1c683

File tree

2 files changed

+166
-39
lines changed

2 files changed

+166
-39
lines changed

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 58 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import dataclasses
44
import json
55
import logging
6-
from copy import deepcopy
76
from typing import TYPE_CHECKING, Any, Callable, Mapping, MutableMapping, Sequence
87
from urllib.parse import parse_qs
98

@@ -15,6 +14,7 @@
1514
_normalize_errors,
1615
_regenerate_error_with_loc,
1716
get_missing_field_error,
17+
is_sequence_field,
1818
)
1919
from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field
2020
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
@@ -150,11 +150,10 @@ def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]:
150150
"""Parse URL-encoded form data from the request body."""
151151
try:
152152
body = app.current_event.decoded_body or ""
153-
# parse_qs returns dict[str, list[str]], but we want dict[str, str] for single values
153+
# NOTE: Keep values as lists; we'll normalize per-field later based on the expected type.
154+
# This avoids breaking List[...] fields when only a single value is provided.
154155
parsed = parse_qs(body, keep_blank_values=True)
155-
156-
result: dict[str, Any] = {key: values[0] if len(values) == 1 else values for key, values in parsed.items()}
157-
return result
156+
return parsed
158157

159158
except Exception as e: # pragma: no cover
160159
raise RequestValidationError( # pragma: no cover
@@ -314,12 +313,12 @@ def _prepare_response_content(
314313
def _request_params_to_args(
315314
required_params: Sequence[ModelField],
316315
received_params: Mapping[str, Any],
317-
) -> tuple[dict[str, Any], list[Any]]:
316+
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
318317
"""
319318
Convert the request params to a dictionary of values using validation, and returns a list of errors.
320319
"""
321-
values = {}
322-
errors = []
320+
values: dict[str, Any] = {}
321+
errors: list[dict[str, Any]] = []
323322

324323
for field in required_params:
325324
field_info = field.field_info
@@ -328,16 +327,12 @@ def _request_params_to_args(
328327
if not isinstance(field_info, Param):
329328
raise AssertionError(f"Expected Param field_info, got {field_info}")
330329

331-
value = received_params.get(field.alias)
332-
333330
loc = (field_info.in_.value, field.alias)
331+
value = received_params.get(field.alias)
334332

335333
# If we don't have a value, see if it's required or has a default
336334
if value is None:
337-
if field.required:
338-
errors.append(get_missing_field_error(loc=loc))
339-
else:
340-
values[field.name] = deepcopy(field.default)
335+
_handle_missing_field_value(field, values, errors, loc)
341336
continue
342337

343338
# Finally, validate the value
@@ -363,39 +358,64 @@ def _request_body_to_args(
363358
)
364359

365360
for field in required_params:
366-
# This sets the location to:
367-
# { "user": { object } } if field.alias == user
368-
# { { object } if field_alias is omitted
369-
loc: tuple[str, ...] = ("body", field.alias)
370-
if field_alias_omitted:
371-
loc = ("body",)
361+
loc = _get_body_field_location(field, field_alias_omitted)
362+
value = _extract_field_value_from_body(field, received_body, loc, errors)
372363

373-
value: Any | None = None
374-
375-
# Now that we know what to look for, try to get the value from the received body
376-
if received_body is not None:
377-
try:
378-
value = received_body.get(field.alias)
379-
except AttributeError:
380-
errors.append(get_missing_field_error(loc))
381-
continue
382-
383-
# Determine if the field is required
364+
# If we don't have a value, see if it's required or has a default
384365
if value is None:
385-
if field.required:
386-
errors.append(get_missing_field_error(loc))
387-
else:
388-
values[field.name] = deepcopy(field.default)
366+
_handle_missing_field_value(field, values, errors, loc)
389367
continue
390368

391-
# MAINTENANCE: Handle byte and file fields
392-
393-
# Finally, validate the value
369+
value = _normalize_field_value(field, value)
394370
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)
395371

396372
return values, errors
397373

398374

375+
def _get_body_field_location(field: ModelField, field_alias_omitted: bool) -> tuple[str, ...]:
    """
    Build the error-location tuple used when reporting problems for a body field.

    Returns ``("body",)`` when the field alias is omitted, otherwise
    ``("body", field.alias)``.
    """
    # The conditional expression is lazy: field.alias is only read when the alias is kept.
    return ("body",) if field_alias_omitted else ("body", field.alias)
380+
381+
382+
def _extract_field_value_from_body(
    field: ModelField,
    received_body: dict[str, Any] | None,
    loc: tuple[str, ...],
    errors: list[dict[str, Any]],
) -> Any | None:
    """
    Look up the value for ``field.alias`` in the received body.

    Returns ``None`` when there is no body, when the alias is not present, or when
    the lookup raises ``AttributeError`` (for instance a body object without a
    ``.get`` method). In the ``AttributeError`` case a missing-field error for
    ``loc`` is appended to ``errors`` before returning.

    NOTE(review): callers cannot tell "absent" from "lookup failed" by the return
    value alone; a caller that treats ``None`` as missing may record a second
    missing-field error for the same location - confirm this is intended.
    """
    if received_body is not None:
        try:
            return received_body.get(field.alias)
        except AttributeError:
            # The body does not behave like a mapping; report the field as missing.
            errors.append(get_missing_field_error(loc))

    return None
397+
398+
399+
def _handle_missing_field_value(
    field: ModelField,
    values: dict[str, Any],
    errors: list[dict[str, Any]],
    loc: tuple[str, ...],
) -> None:
    """
    Record the outcome for a field whose value was not supplied.

    Optional fields get their default (from ``field.get_default()``) stored in
    ``values`` under ``field.name``; required fields get a missing-field error
    for ``loc`` appended to ``errors``. Both containers are mutated in place.
    """
    if not field.required:
        values[field.name] = field.get_default()
        return

    errors.append(get_missing_field_error(loc))
410+
411+
412+
def _normalize_field_value(field: ModelField, value: Any) -> Any:
    """
    Normalize a raw field value, collapsing a list to its first item for non-sequence fields.

    Values that are not lists are returned unchanged. Lists are returned unchanged
    when ``is_sequence_field(field)`` is true. For any other field a non-empty list
    is reduced to its first element.

    An empty list is returned as-is: it has no first element, so it is handed to
    validation unchanged instead of raising ``IndexError`` here.
    """
    # Nothing to collapse for non-list values; an empty list cannot be indexed.
    if not isinstance(value, list) or not value:
        return value

    # Sequence fields keep the whole list, even when only one item was provided.
    if is_sequence_field(field):
        return value

    return value[0]
417+
418+
399419
def _validate_field(
400420
*,
401421
field: ModelField,

tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,100 @@ def handler(user: Annotated[Model, Body(embed=True)]) -> Model:
366366
assert result["statusCode"] == 200
367367

368368

369+
def test_validate_body_param_with_missing_body(gw_event):
370+
# GIVEN an APIGatewayRestResolver with validation enabled
371+
app = APIGatewayRestResolver(enable_validation=True)
372+
373+
# WHEN a handler is defined with multiple body parameters
374+
@app.post("/")
375+
def handler(name: str, age: int):
376+
return {"name": name, "age": age}
377+
378+
# WHEN the event has no body
379+
gw_event["httpMethod"] = "POST"
380+
gw_event["path"] = "/"
381+
gw_event["body"] = None # simulate event without body
382+
gw_event["headers"]["content-type"] = "application/json"
383+
384+
result = app(gw_event, {})
385+
386+
# THEN the request should be rejected with a 422 validation error
387+
assert result["statusCode"] == 422
388+
assert "missing" in result["body"]
389+
390+
391+
def test_validate_body_param_with_empty_body(gw_event):
392+
# GIVEN an APIGatewayRestResolver with validation enabled
393+
app = APIGatewayRestResolver(enable_validation=True)
394+
395+
# WHEN a handler is defined with multiple body parameters
396+
@app.post("/")
397+
def handler(name: str, age: int):
398+
return {"name": name, "age": age}
399+
400+
# WHEN the event body is a JSON array instead of an object
401+
gw_event["httpMethod"] = "POST"
402+
gw_event["path"] = "/"
403+
gw_event["body"] = "[]" # JSON array -> received_body is a list (no .get)
404+
gw_event["headers"]["content-type"] = "application/json"
405+
406+
result = app(gw_event, {})
407+
408+
# THEN the request should be rejected with a 422 validation error
409+
assert result["statusCode"] == 422
410+
assert "missing" in result["body"]
411+
412+
413+
def test_validate_embed_body_param_with_missing_body(gw_event):
414+
# GIVEN an APIGatewayRestResolver with validation enabled
415+
app = APIGatewayRestResolver(enable_validation=True)
416+
417+
class Model(BaseModel):
418+
name: str
419+
420+
# WHEN a handler is defined with a body parameter
421+
@app.post("/")
422+
def handler(user: Annotated[Model, Body(embed=True)]) -> Model:
423+
return user
424+
425+
# WHEN the event has no body
426+
gw_event["httpMethod"] = "POST"
427+
gw_event["path"] = "/"
428+
gw_event["body"] = None # simulate event without body
429+
gw_event["headers"]["content-type"] = "application/json"
430+
431+
result = app(gw_event, {})
432+
433+
# THEN the request should be rejected with a 422 validation error
434+
assert result["statusCode"] == 422
435+
assert "missing" in result["body"]
436+
437+
438+
def test_validate_embed_body_param_with_empty_body(gw_event):
439+
# GIVEN an APIGatewayRestResolver with validation enabled
440+
app = APIGatewayRestResolver(enable_validation=True)
441+
442+
class Model(BaseModel):
443+
name: str
444+
445+
# WHEN a handler is defined with a body parameter
446+
@app.post("/")
447+
def handler(user: Annotated[Model, Body(embed=True)]) -> Model:
448+
return user
449+
450+
# WHEN the event body is a JSON array instead of an object
451+
gw_event["httpMethod"] = "POST"
452+
gw_event["path"] = "/"
453+
gw_event["body"] = "[]" # JSON array -> received_body is a list (no .get)
454+
gw_event["headers"]["content-type"] = "application/json"
455+
456+
result = app(gw_event, {})
457+
458+
# THEN the request should be rejected with a 422 validation error
459+
assert result["statusCode"] == 422
460+
assert "missing" in result["body"]
461+
462+
369463
def test_validate_response_return(gw_event):
370464
# GIVEN an APIGatewayRestResolver with validation enabled
371465
app = APIGatewayRestResolver(enable_validation=True)
@@ -1481,20 +1575,33 @@ def handler_custom_route_response_validation_error() -> Model:
14811575

14821576
def test_parse_form_data_url_encoded(gw_event):
14831577
"""Test _parse_form_data method with URL-encoded form data"""
1484-
1578+
# GIVEN an APIGatewayRestResolver with validation enabled
14851579
app = APIGatewayRestResolver(enable_validation=True)
14861580

14871581
@app.post("/form")
14881582
def post_form(name: Annotated[str, Form()], tags: Annotated[List[str], Form()]):
14891583
return {"name": name, "tags": tags}
14901584

1585+
# WHEN sending a POST request with URL-encoded form data
14911586
gw_event["httpMethod"] = "POST"
14921587
gw_event["path"] = "/form"
14931588
gw_event["headers"]["content-type"] = "application/x-www-form-urlencoded"
14941589
gw_event["body"] = "name=test&tags=tag1&tags=tag2"
14951590

14961591
result = app(gw_event, {})
1592+
1593+
# THEN it should parse the form data correctly
1594+
assert result["statusCode"] == 200
1595+
assert result["body"] == '{"name":"test","tags":["tag1","tag2"]}'
1596+
1597+
# WHEN sending a POST request with a single value for a list field
1598+
gw_event["body"] = "name=test&tags=tag1"
1599+
1600+
result = app(gw_event, {})
1601+
1602+
# THEN it should parse the form data correctly
14971603
assert result["statusCode"] == 200
1604+
assert result["body"] == '{"name":"test","tags":["tag1"]}'
14981605

14991606

15001607
def test_parse_form_data_wrong_value(gw_event):

0 commit comments

Comments
 (0)