From f1a147e39caef0cc86d4ec7c276aa79ec8f57df1 Mon Sep 17 00:00:00 2001 From: Steven Loria Date: Mon, 20 Jan 2025 18:51:29 -0500 Subject: [PATCH 1/3] Add pre/post_load parameters to Field --- src/marshmallow/fields.py | 55 +++++++++++++---- src/marshmallow/types.py | 6 ++ tests/test_fields.py | 122 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 172 insertions(+), 11 deletions(-) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 0f0e735a8..597c28e8b 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -80,6 +80,11 @@ "Url", ] +_ProcessorT = typing.TypeVar( + "_ProcessorT", + bound=typing.Union[types.PostLoadCallable, types.PreLoadCallable, types.Validator], +) + class Field(FieldABC): """Base field from which other fields inherit. @@ -153,6 +158,12 @@ def __init__( data_key: str | None = None, attribute: str | None = None, validate: types.Validator | typing.Iterable[types.Validator] | None = None, + pre_load: types.PreLoadCallable + | typing.Iterable[types.PreLoadCallable] + | None = None, + post_load: types.PostLoadCallable + | typing.Iterable[types.PostLoadCallable] + | None = None, required: bool = False, allow_none: bool | None = None, load_only: bool = False, @@ -193,17 +204,9 @@ def __init__( self.attribute = attribute self.data_key = data_key self.validate = validate - if validate is None: - self.validators = [] - elif callable(validate): - self.validators = [validate] - elif utils.is_iterable_but_not_string(validate): - self.validators = list(validate) - else: - raise ValueError( - "The 'validate' parameter must be a callable " - "or a collection of callables." - ) + self.validators = self._normalize_processors(validate, param="validate") + self.pre_load = self._normalize_processors(pre_load, param="pre_load") + self.post_load = self._normalize_processors(post_load, param="post_load") # If allow_none is None and load_default is None # None should be considered valid by default @@ -369,10 +372,23 @@ def deserialize( if value is missing_: _miss = self.load_default return _miss() if callable(_miss) else _miss + + # Apply pre_load functions + for func in self.pre_load: + if func is not None: + value = func(value) + if self.allow_none and value is None: return None + output = self._deserialize(value, attr, data, **kwargs) + # Apply validators self._validate(output) + + # Apply post_load functions + for func in self.post_load: + if func is not None: + output = func(output) return output # Methods for concrete classes to override. @@ -484,6 +500,23 @@ def missing(self, value): ) self.load_default = value + @staticmethod + def _normalize_processors( + processors: _ProcessorT | typing.Iterable[_ProcessorT] | None, + *, + param: str, + ) -> list[_ProcessorT]: + if processors is None: + return [] + if callable(processors): + return [typing.cast(_ProcessorT, processors)] + if not utils.is_iterable_but_not_string(processors): + raise ValueError( + f"The '{param}' parameter must be a callable " + "or an iterable of callables." + ) + return list(processors) + class Raw(Field): """Field that applies no formatting.""" diff --git a/src/marshmallow/types.py b/src/marshmallow/types.py index 599f6b49e..054fa70cf 100644 --- a/src/marshmallow/types.py +++ b/src/marshmallow/types.py @@ -9,11 +9,17 @@ import typing +T = typing.TypeVar("T") + #: A type that can be either a sequence of strings or a set of strings StrSequenceOrSet = typing.Union[typing.Sequence[str], typing.AbstractSet[str]] #: Type for validator functions Validator = typing.Callable[[typing.Any], typing.Any] +#: Type for field-level pre-load functions +PreLoadCallable = typing.Callable[[typing.Any], typing.Any] +#: Type for field-level post-load functions +PostLoadCallable = typing.Callable[[T], T] class SchemaValidator(typing.Protocol): diff --git a/tests/test_fields.py b/tests/test_fields.py index a2763db35..9b0dde8cd 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -663,3 +663,125 @@ class Family(Schema): "daughter": {"value": {"age": ["Missing data for required field."]}} } } + + +class TestFieldPreAndPostLoad: + def test_field_pre_load(self): + class UserSchema(Schema): + name = fields.Str(pre_load=str) + + schema = UserSchema() + result = schema.load({"name": 808}) + assert result["name"] == "808" + + def test_field_pre_load_multiple(self): + def decrement(value): + return value - 1 + + def add_prefix(value): + return "test_" + value + + class UserSchema(Schema): + name = fields.Str(pre_load=[decrement, str, add_prefix]) + + schema = UserSchema() + result = schema.load({"name": 809}) + assert result["name"] == "test_808" + + def test_field_post_load(self): + class UserSchema(Schema): + age = fields.Int(post_load=str) + + schema = UserSchema() + result = schema.load({"age": 42}) + assert result["age"] == "42" + + def test_field_post_load_multiple(self): + def to_string(value): + return str(value) + + def add_suffix(value): + return value + " years" + + class UserSchema(Schema): + age = fields.Int(post_load=[to_string, add_suffix]) + + schema = UserSchema() + result = schema.load({"age": 42}) + assert result["age"] == "42 years" + + def test_field_pre_and_post_load(self): + def multiply_by_2(value): + return value * 2 + + class UserSchema(Schema): + age = fields.Int(pre_load=[str.strip, int], post_load=[multiply_by_2]) + + schema = UserSchema() + result = schema.load({"age": " 21 "}) + assert result["age"] == 42 + + def test_field_pre_load_validation_error(self): + def always_fail(value): + raise ValidationError("oops") + + class UserSchema(Schema): + age = fields.Int(pre_load=always_fail) + + schema = UserSchema() + with pytest.raises(ValidationError) as exc: + schema.load({"age": 42}) + assert exc.value.messages == {"age": ["oops"]} + + def test_field_post_load_validation_error(self): + def always_fail(value): + raise ValidationError("oops") + + class UserSchema(Schema): + age = fields.Int(post_load=always_fail) + + schema = UserSchema() + with pytest.raises(ValidationError) as exc: + schema.load({"age": 42}) + assert exc.value.messages == {"age": ["oops"]} + + def test_field_pre_load_none(self): + def handle_none(value): + if value is None: + return 0 + return value + + class UserSchema(Schema): + age = fields.Int(pre_load=handle_none, allow_none=True) + + schema = UserSchema() + result = schema.load({"age": None}) + assert result["age"] == 0 + + def test_field_post_load_not_called_with_none_input_when_not_allowed(self): + def handle_none(value): + if value is None: + return 0 + return value + + class UserSchema(Schema): + age = fields.Int(post_load=handle_none, allow_none=False) + + schema = UserSchema() + with pytest.raises(ValidationError) as exc: + schema.load({"age": None}) + assert exc.value.messages == {"age": ["Field may not be null."]} + + def test_invalid_type_passed_to_pre_load(self): + with pytest.raises( + ValueError, + match="The 'pre_load' parameter must be a callable or an iterable of callables.", + ): + fields.Int(pre_load="not_callable") + + def test_invalid_type_passed_to_post_load(self): + with pytest.raises( + ValueError, + match="The 'post_load' parameter must be a callable or an iterable of callables.", + ): + fields.Int(post_load="not_callable") From db222228cbc876d7e7863ce463488f5b12e0509c Mon Sep 17 00:00:00 2001 From: Steven Loria Date: Mon, 20 Jan 2025 18:57:54 -0500 Subject: [PATCH 2/3] formatting --- src/marshmallow/fields.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 597c28e8b..0582a1b37 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -158,12 +158,12 @@ def __init__( data_key: str | None = None, attribute: str | None = None, validate: types.Validator | typing.Iterable[types.Validator] | None = None, - pre_load: types.PreLoadCallable - | typing.Iterable[types.PreLoadCallable] - | None = None, - post_load: types.PostLoadCallable - | typing.Iterable[types.PostLoadCallable] - | None = None, + pre_load: ( + types.PreLoadCallable | typing.Iterable[types.PreLoadCallable] | None + ) = None, + post_load: ( + types.PostLoadCallable | typing.Iterable[types.PostLoadCallable] | None + ) = None, required: bool = False, allow_none: bool | None = None, load_only: bool = False, From 0e104c33b1945d97cc0e19af4d46a0d796235b7f Mon Sep 17 00:00:00 2001 From: Steven Loria Date: Mon, 20 Jan 2025 19:01:06 -0500 Subject: [PATCH 3/3] Improve test to be more type safe --- tests/test_fields.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_fields.py b/tests/test_fields.py index 9b0dde8cd..749153a27 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -697,18 +697,18 @@ class UserSchema(Schema): assert result["age"] == "42" def test_field_post_load_multiple(self): - def to_string(value): - return str(value) + def multiply_by_2(value): + return value * 2 - def add_suffix(value): - return value + " years" + def decrement(value): + return value - 1 class UserSchema(Schema): - age = fields.Int(post_load=[to_string, add_suffix]) + age = fields.Float(post_load=[multiply_by_2, decrement]) schema = UserSchema() - result = schema.load({"age": 42}) - assert result["age"] == "42 years" + result = schema.load({"age": 21.5}) + assert result["age"] == 42.0 def test_field_pre_and_post_load(self): def multiply_by_2(value):