-
Notifications
You must be signed in to change notification settings - Fork 3.1k
[Core] Add Serializable protocol #43207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,14 +6,16 @@ | |
# -------------------------------------------------------------------------- | ||
import base64 | ||
from json import JSONEncoder | ||
from typing import Dict, List, Optional, Union, cast, Any | ||
from typing import Dict, List, Optional, Union, cast, Any, Protocol, Type, TypeVar, runtime_checkable | ||
from datetime import datetime, date, time, timedelta | ||
from datetime import timezone | ||
|
||
|
||
__all__ = ["NULL", "AzureJSONEncoder", "is_generated_model", "as_attribute_dict", "attribute_list"] | ||
TZ_UTC = timezone.utc | ||
|
||
T = TypeVar("T", bound="Serializable") | ||
|
||
|
||
class _Null: | ||
"""To create a Falsy object""" | ||
|
@@ -29,6 +31,69 @@ def __bool__(self) -> bool: | |
""" | ||
|
||
|
||
@runtime_checkable | ||
class Serializable(Protocol): | ||
"""A protocol for objects that can be serialized to and deserialized from a dictionary representation. | ||
|
||
This protocol defines a standard interface for custom models to integrate with the Azure SDK serialization | ||
and deserialization mechanisms. By implementing the `to_dict` and `from_dict` methods, a custom type can | ||
control how it is converted for REST API calls and reconstituted from API responses. | ||
|
||
Examples: | ||
|
||
.. code-block:: python | ||
|
||
from typing import Dict, Any, Type | ||
|
||
class CustomModel: | ||
|
||
foo: str | ||
bar: str | ||
|
||
def __init__(self, *, foo: str, bar: str): | ||
self.foo = foo | ||
self.bar = bar | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
return { | ||
"foo": self.foo, | ||
"bar": self.bar | ||
} | ||
|
||
@classmethod | ||
def from_dict(cls: Type["CustomModel"], data: Dict[str, Any]) -> "CustomModel": | ||
return cls( | ||
foo=data["foo"], | ||
bar=data["bar"] | ||
) | ||
""" | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are we expecting people to implement any internal recursion themselves? Or do we do that for them? I was wondering if instead of using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The SdkJSONEncoder does automatically handle nested custom types and will call the nested type to_dict methods internally. E.g.: if someone implements to_dict like: Class FooModel:
foo: str
bar: BarModel
...
def to_dict(self) -> Dict[str, Any]:
return {"foo": self.foo, "bar": self.bar}
For Regarding the typing, I think keeping it simpler with |
||
"""Returns a dictionary representation of the object. | ||
|
||
The keys of the dictionary should correspond to the REST API's JSON field names. This method is responsible | ||
for mapping the object's attributes to the correct wire format. | ||
|
||
:return: A dictionary representing the object. | ||
:rtype: dict[str, any] | ||
""" | ||
... | ||
|
||
@classmethod | ||
def from_dict(cls: Type[T], data: Dict[str, Any]) -> T: | ||
"""Creates an instance of the class from a dictionary. | ||
|
||
The dictionary keys are expected to be the REST API's JSON field names. This method is responsible for | ||
mapping the incoming dictionary to the object's attributes. | ||
|
||
:param data: A dictionary containing the object's data. | ||
:type data: dict[str, any] | ||
:return: An instance of the class. | ||
:rtype: ~T | ||
""" | ||
... | ||
|
||
|
||
def _timedelta_as_isostr(td: timedelta) -> str: | ||
"""Converts a datetime.timedelta object into an ISO 8601 formatted string, e.g. 'P4DT12H30M05S' | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -28,7 +28,7 @@ | |||||||||||||||
from azure.core.exceptions import DeserializationError | ||||||||||||||||
from azure.core import CaseInsensitiveEnumMeta | ||||||||||||||||
from azure.core.pipeline import PipelineResponse | ||||||||||||||||
from azure.core.serialization import _Null | ||||||||||||||||
from azure.core.serialization import _Null, Serializable | ||||||||||||||||
|
||||||||||||||||
_LOGGER = logging.getLogger(__name__) | ||||||||||||||||
|
||||||||||||||||
|
@@ -161,6 +161,11 @@ def default(self, o): # pylint: disable=too-many-return-statements | |||||||||||||||
except AttributeError: | ||||||||||||||||
# This will be raised when it hits value.total_seconds in the method above | ||||||||||||||||
pass | ||||||||||||||||
|
||||||||||||||||
# Check if the object implements the Serializable protocol | ||||||||||||||||
if isinstance(o, Serializable): | ||||||||||||||||
return o.to_dict() | ||||||||||||||||
|
||||||||||||||||
return super(SdkJSONEncoder, self).default(o) | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
|
@@ -510,6 +515,9 @@ def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-m | |||||||||||||||
except AttributeError: | ||||||||||||||||
# This will be raised when it hits value.total_seconds in the method above | ||||||||||||||||
pass | ||||||||||||||||
|
||||||||||||||||
if isinstance(o, Serializable): | ||||||||||||||||
return o.to_dict() | ||||||||||||||||
return o | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
|
@@ -886,6 +894,9 @@ def _deserialize_default( | |||||||||||||||
if get_deserializer(annotation, rf): | ||||||||||||||||
return functools.partial(_deserialize_default, get_deserializer(annotation, rf)) | ||||||||||||||||
|
||||||||||||||||
if isinstance(annotation, Serializable): | ||||||||||||||||
return annotation.from_dict # type: ignore | ||||||||||||||||
Comment on lines
+897
to
+898
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This check is incorrect.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||
|
||||||||||||||||
return functools.partial(_deserialize_default, annotation) | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -10,9 +10,16 @@ | |||
from typing import Any, Dict, List, Optional | ||||
from io import BytesIO | ||||
|
||||
from azure.core.serialization import AzureJSONEncoder, NULL, as_attribute_dict, is_generated_model, attribute_list | ||||
from azure.core.serialization import ( | ||||
AzureJSONEncoder, | ||||
NULL, | ||||
as_attribute_dict, | ||||
is_generated_model, | ||||
attribute_list, | ||||
Serializable, | ||||
) | ||||
import pytest | ||||
from modeltypes._utils.model_base import Model as HybridModel, rest_field | ||||
from modeltypes._utils.model_base import Model as HybridModel, SdkJSONEncoder, rest_field, _deserialize | ||||
from modeltypes._utils.serialization import Model as MsrestModel | ||||
from modeltypes import models | ||||
|
||||
|
@@ -972,3 +979,120 @@ def _tests(model): | |||
"birthdate": "2017-12-13T02:29:51Z", | ||||
"complexProperty": {"color": "Red"}, | ||||
} | ||||
|
||||
|
||||
class TestSerializableProtocol: | ||||
|
||||
class FooModel: | ||||
|
||||
foo: str | ||||
bar: int | ||||
baz: float | ||||
|
||||
def __init__(self, foo: str, bar: int, baz: float): | ||||
self.foo = foo | ||||
self.bar = bar | ||||
self.baz = baz | ||||
|
||||
def to_dict(self) -> Dict[str, Any]: | ||||
return {"foo": self.foo, "bar": self.bar, "baz": self.baz} | ||||
|
||||
@classmethod | ||||
def from_dict(cls, data: Dict[str, Any]) -> "TestSerializableProtocol.FooModel": | ||||
return cls(foo=data["foo"], bar=data["bar"], baz=data["baz"]) | ||||
|
||||
def test_is_serializable_protocol(self): | ||||
model = TestSerializableProtocol.FooModel("hello", 42, 3.14) | ||||
assert isinstance(model, Serializable) | ||||
assert issubclass(TestSerializableProtocol.FooModel, Serializable) | ||||
|
||||
assert not isinstance(models.Fish(kind="goldfish", age=1), Serializable) | ||||
assert not issubclass(models.Fish, Serializable) | ||||
|
||||
assert hasattr(model, "to_dict") | ||||
assert hasattr(TestSerializableProtocol.FooModel, "from_dict") | ||||
|
||||
def test_serialization(self): | ||||
model = TestSerializableProtocol.FooModel("hello", 42, 3.14) | ||||
|
||||
json_str = json.dumps(model, cls=SdkJSONEncoder, exclude_readonly=True) | ||||
assert json.loads(json_str) == {"foo": "hello", "bar": 42, "baz": 3.14} | ||||
|
||||
def test_serialize_custom_model_in_generated_model(self): | ||||
|
||||
class GeneratedModel(HybridModel): | ||||
dog: models.HybridDog = rest_field(visibility=["read", "create", "update", "delete", "query"]) | ||||
external: TestSerializableProtocol.FooModel = rest_field( | ||||
visibility=["read", "create", "update", "delete", "query"] | ||||
) | ||||
|
||||
model = GeneratedModel( | ||||
dog=models.HybridDog(name="doggy", species="dog", breed="samoyed", is_best_boy=True), | ||||
external=TestSerializableProtocol.FooModel(foo="foo", bar=42, baz=3.14), | ||||
) | ||||
|
||||
json_str = json.dumps(model, cls=SdkJSONEncoder, exclude_readonly=True) | ||||
|
||||
expected_dict = { | ||||
"dog": { | ||||
"name": "doggy", | ||||
"species": "dog", | ||||
"breed": "samoyed", | ||||
"isBestBoy": True, | ||||
}, | ||||
"external": { | ||||
"foo": "foo", | ||||
"bar": 42, | ||||
"baz": 3.14, | ||||
}, | ||||
} | ||||
|
||||
json_str = json.dumps(model, cls=SdkJSONEncoder, exclude_readonly=True) | ||||
assert json.loads(json_str) == expected_dict | ||||
assert model.as_dict() == expected_dict | ||||
|
||||
def test_deserialize_custom_model(self): | ||||
json_dict = { | ||||
"foo": "foo", | ||||
"bar": 42, | ||||
"baz": 3.14, | ||||
} | ||||
json_dict = {"foo": "foo", "bar": 42, "baz": 3.14} | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line duplicates the dictionary definition from line 1055-1059. Remove this duplicate assignment.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||
deserialized = _deserialize(TestSerializableProtocol.FooModel, json_dict) | ||||
assert isinstance(deserialized, TestSerializableProtocol.FooModel) | ||||
assert deserialized.foo == "foo" | ||||
assert deserialized.bar == 42 | ||||
assert deserialized.baz == 3.14 | ||||
|
||||
def test_deserialize_custom_model_in_generated_model(self): | ||||
|
||||
class GeneratedModel(HybridModel): | ||||
dog: models.HybridDog = rest_field(visibility=["read", "create", "update", "delete", "query"]) | ||||
external: TestSerializableProtocol.FooModel = rest_field( | ||||
visibility=["read", "create", "update", "delete", "query"] | ||||
) | ||||
|
||||
json_dict = { | ||||
"dog": { | ||||
"name": "doggy", | ||||
"species": "dog", | ||||
"breed": "samoyed", | ||||
"isBestBoy": True, | ||||
}, | ||||
"external": { | ||||
"foo": "foo", | ||||
"bar": 42, | ||||
"baz": 3.14, | ||||
}, | ||||
} | ||||
deserialized = _deserialize(GeneratedModel, json_dict) | ||||
assert isinstance(deserialized, GeneratedModel) | ||||
assert isinstance(deserialized.dog, models.HybridDog) | ||||
assert deserialized.dog.name == "doggy" | ||||
assert deserialized.dog.species == "dog" | ||||
assert deserialized.dog.breed == "samoyed" | ||||
assert deserialized.dog.is_best_boy is True | ||||
assert isinstance(deserialized.external, TestSerializableProtocol.FooModel) | ||||
assert deserialized.external.foo == "foo" | ||||
assert deserialized.external.bar == 42 | ||||
assert deserialized.external.baz == 3.14 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
Serializable
protocol should be added to__all__
to make it part of the public API since it's intended for external use by custom models.Copilot uses AI. Check for mistakes.