From ed73966c5cac973f7b543c448218e65e6f64dd39 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Wed, 1 Oct 2025 02:29:48 +0000 Subject: [PATCH] [Core] Add Serializable protocol This is a protocol that custom models can implement to allow it to work with the SDK serialization/deserialization mechanisms. Signed-off-by: Paul Van Eck --- .../azure-core/azure/core/serialization.py | 67 ++++++++- .../modeltypes/_utils/model_base.py | 13 +- .../azure-core/tests/test_serialization.py | 128 +++++++++++++++++- 3 files changed, 204 insertions(+), 4 deletions(-) diff --git a/sdk/core/azure-core/azure/core/serialization.py b/sdk/core/azure-core/azure/core/serialization.py index a339c0c58f8b..4cb0d3b7b853 100644 --- a/sdk/core/azure-core/azure/core/serialization.py +++ b/sdk/core/azure-core/azure/core/serialization.py @@ -6,7 +6,7 @@ # -------------------------------------------------------------------------- import base64 from json import JSONEncoder -from typing import Dict, List, Optional, Union, cast, Any +from typing import Dict, List, Optional, Union, cast, Any, Protocol, Type, TypeVar, runtime_checkable from datetime import datetime, date, time, timedelta from datetime import timezone @@ -14,6 +14,8 @@ __all__ = ["NULL", "AzureJSONEncoder", "is_generated_model", "as_attribute_dict", "attribute_list"] TZ_UTC = timezone.utc +T = TypeVar("T", bound="Serializable") + class _Null: """To create a Falsy object""" @@ -29,6 +31,69 @@ def __bool__(self) -> bool: """ +@runtime_checkable +class Serializable(Protocol): + """A protocol for objects that can be serialized to and deserialized from a dictionary representation. + + This protocol defines a standard interface for custom models to integrate with the Azure SDK serialization + and deserialization mechanisms. By implementing the `to_dict` and `from_dict` methods, a custom type can + control how it is converted for REST API calls and reconstituted from API responses. + + Examples: + + .. code-block:: python + + from typing import Dict, Any, Type + + class CustomModel: + + foo: str + bar: str + + def __init__(self, *, foo: str, bar: str): + self.foo = foo + self.bar = bar + + def to_dict(self) -> Dict[str, Any]: + return { + "foo": self.foo, + "bar": self.bar + } + + @classmethod + def from_dict(cls: Type["CustomModel"], data: Dict[str, Any]) -> "CustomModel": + return cls( + foo=data["foo"], + bar=data["bar"] + ) + """ + + def to_dict(self) -> Dict[str, Any]: + """Returns a dictionary representation of the object. + + The keys of the dictionary should correspond to the REST API's JSON field names. This method is responsible + for mapping the object's attributes to the correct wire format. + + :return: A dictionary representing the object. + :rtype: dict[str, any] + """ + ... + + @classmethod + def from_dict(cls: Type[T], data: Dict[str, Any]) -> T: + """Creates an instance of the class from a dictionary. + + The dictionary keys are expected to be the REST API's JSON field names. This method is responsible for + mapping the incoming dictionary to the object's attributes. + + :param data: A dictionary containing the object's data. + :type data: dict[str, any] + :return: An instance of the class. + :rtype: ~T + """ + ... + + def _timedelta_as_isostr(td: timedelta) -> str: """Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S' diff --git a/sdk/core/azure-core/tests/specs_sdk/modeltypes/modeltypes/_utils/model_base.py b/sdk/core/azure-core/tests/specs_sdk/modeltypes/modeltypes/_utils/model_base.py index 49d5c7259389..1dfa002c3e79 100644 --- a/sdk/core/azure-core/tests/specs_sdk/modeltypes/modeltypes/_utils/model_base.py +++ b/sdk/core/azure-core/tests/specs_sdk/modeltypes/modeltypes/_utils/model_base.py @@ -28,7 +28,7 @@ from azure.core.exceptions import DeserializationError from azure.core import CaseInsensitiveEnumMeta from azure.core.pipeline import PipelineResponse -from azure.core.serialization import _Null +from azure.core.serialization import _Null, Serializable _LOGGER = logging.getLogger(__name__) @@ -161,6 +161,11 @@ def default(self, o): # pylint: disable=too-many-return-statements except AttributeError: # This will be raised when it hits value.total_seconds in the method above pass + + # Check if the object implements the Serializable protocol + if isinstance(o, Serializable): + return o.to_dict() + return super(SdkJSONEncoder, self).default(o) @@ -510,6 +515,9 @@ def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-m except AttributeError: # This will be raised when it hits value.total_seconds in the method above pass + + if isinstance(o, Serializable): + return o.to_dict() return o @@ -886,6 +894,9 @@ def _deserialize_default( if get_deserializer(annotation, rf): return functools.partial(_deserialize_default, get_deserializer(annotation, rf)) + if isinstance(annotation, Serializable): + return annotation.from_dict # type: ignore + return functools.partial(_deserialize_default, annotation) diff --git a/sdk/core/azure-core/tests/test_serialization.py b/sdk/core/azure-core/tests/test_serialization.py index b1c808068523..d5fd65217a81 100644 --- a/sdk/core/azure-core/tests/test_serialization.py +++ b/sdk/core/azure-core/tests/test_serialization.py @@ -10,9 +10,16 @@ from typing import Any, Dict, List, Optional from io import BytesIO -from azure.core.serialization import AzureJSONEncoder, NULL, as_attribute_dict, is_generated_model, attribute_list +from azure.core.serialization import ( + AzureJSONEncoder, + NULL, + as_attribute_dict, + is_generated_model, + attribute_list, + Serializable, +) import pytest -from modeltypes._utils.model_base import Model as HybridModel, rest_field +from modeltypes._utils.model_base import Model as HybridModel, SdkJSONEncoder, rest_field, _deserialize from modeltypes._utils.serialization import Model as MsrestModel from modeltypes import models @@ -972,3 +979,120 @@ def _tests(model): "birthdate": "2017-12-13T02:29:51Z", "complexProperty": {"color": "Red"}, } + + +class TestSerializableProtocol: + + class FooModel: + + foo: str + bar: int + baz: float + + def __init__(self, foo: str, bar: int, baz: float): + self.foo = foo + self.bar = bar + self.baz = baz + + def to_dict(self) -> Dict[str, Any]: + return {"foo": self.foo, "bar": self.bar, "baz": self.baz} + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "TestSerializableProtocol.FooModel": + return cls(foo=data["foo"], bar=data["bar"], baz=data["baz"]) + + def test_is_serializable_protocol(self): + model = TestSerializableProtocol.FooModel("hello", 42, 3.14) + assert isinstance(model, Serializable) + assert issubclass(TestSerializableProtocol.FooModel, Serializable) + + assert not isinstance(models.Fish(kind="goldfish", age=1), Serializable) + assert not issubclass(models.Fish, Serializable) + + assert hasattr(model, "to_dict") + assert hasattr(TestSerializableProtocol.FooModel, "from_dict") + + def test_serialization(self): + model = TestSerializableProtocol.FooModel("hello", 42, 3.14) + + json_str = json.dumps(model, cls=SdkJSONEncoder, exclude_readonly=True) + assert json.loads(json_str) == {"foo": "hello", "bar": 42, "baz": 3.14} + + def test_serialize_custom_model_in_generated_model(self): + + class GeneratedModel(HybridModel): + dog: models.HybridDog = rest_field(visibility=["read", "create", "update", "delete", "query"]) + external: TestSerializableProtocol.FooModel = rest_field( + visibility=["read", "create", "update", "delete", "query"] + ) + + model = GeneratedModel( + dog=models.HybridDog(name="doggy", species="dog", breed="samoyed", is_best_boy=True), + external=TestSerializableProtocol.FooModel(foo="foo", bar=42, baz=3.14), + ) + + json_str = json.dumps(model, cls=SdkJSONEncoder, exclude_readonly=True) + + expected_dict = { + "dog": { + "name": "doggy", + "species": "dog", + "breed": "samoyed", + "isBestBoy": True, + }, + "external": { + "foo": "foo", + "bar": 42, + "baz": 3.14, + }, + } + + json_str = json.dumps(model, cls=SdkJSONEncoder, exclude_readonly=True) + assert json.loads(json_str) == expected_dict + assert model.as_dict() == expected_dict + + def test_deserialize_custom_model(self): + json_dict = { + "foo": "foo", + "bar": 42, + "baz": 3.14, + } + json_dict = {"foo": "foo", "bar": 42, "baz": 3.14} + deserialized = _deserialize(TestSerializableProtocol.FooModel, json_dict) + assert isinstance(deserialized, TestSerializableProtocol.FooModel) + assert deserialized.foo == "foo" + assert deserialized.bar == 42 + assert deserialized.baz == 3.14 + + def test_deserialize_custom_model_in_generated_model(self): + + class GeneratedModel(HybridModel): + dog: models.HybridDog = rest_field(visibility=["read", "create", "update", "delete", "query"]) + external: TestSerializableProtocol.FooModel = rest_field( + visibility=["read", "create", "update", "delete", "query"] + ) + + json_dict = { + "dog": { + "name": "doggy", + "species": "dog", + "breed": "samoyed", + "isBestBoy": True, + }, + "external": { + "foo": "foo", + "bar": 42, + "baz": 3.14, + }, + } + deserialized = _deserialize(GeneratedModel, json_dict) + assert isinstance(deserialized, GeneratedModel) + assert isinstance(deserialized.dog, models.HybridDog) + assert deserialized.dog.name == "doggy" + assert deserialized.dog.species == "dog" + assert deserialized.dog.breed == "samoyed" + assert deserialized.dog.is_best_boy is True + assert isinstance(deserialized.external, TestSerializableProtocol.FooModel) + assert deserialized.external.foo == "foo" + assert deserialized.external.bar == 42 + assert deserialized.external.baz == 3.14