Skip to content

Commit 38252c0

Browse files
authored
Merge branch 'main' into feature/standalone-docker
2 parents 417f435 + 1fb0bca commit 38252c0

File tree

4 files changed

+180
-36
lines changed

4 files changed

+180
-36
lines changed

aries_cloudagent/messaging/models/base.py

Lines changed: 110 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from abc import ABC
77
from collections import namedtuple
8-
from typing import Mapping, Union
8+
from typing import Mapping, Optional, Type, TypeVar, Union, cast, overload
9+
from typing_extensions import Literal
910

1011
from marshmallow import Schema, post_dump, pre_load, post_load, ValidationError, EXCLUDE
1112

@@ -17,7 +18,7 @@
1718
SerDe = namedtuple("SerDe", "ser de")
1819

1920

20-
def resolve_class(the_cls, relative_cls: type = None):
21+
def resolve_class(the_cls, relative_cls: Optional[type] = None) -> type:
2122
"""
2223
Resolve a class.
2324
@@ -38,6 +39,10 @@ def resolve_class(the_cls, relative_cls: type = None):
3839
elif isinstance(the_cls, str):
3940
default_module = relative_cls and relative_cls.__module__
4041
resolved = ClassLoader.load_class(the_cls, default_module)
42+
else:
43+
raise TypeError(
44+
f"Could not resolve class from {the_cls}; incorrect type {type(the_cls)}"
45+
)
4146
return resolved
4247

4348

@@ -53,7 +58,10 @@ def resolve_meta_property(obj, prop_name: str, defval=None):
5358
The meta property
5459
5560
"""
56-
cls = obj.__class__
61+
if isinstance(obj, type):
62+
cls = obj
63+
else:
64+
cls = obj.__class__
5765
found = defval
5866
while cls:
5967
Meta = getattr(cls, "Meta", None)
@@ -70,6 +78,9 @@ class BaseModelError(BaseError):
7078
"""Base exception class for base model errors."""
7179

7280

81+
ModelType = TypeVar("ModelType", bound="BaseModel")
82+
83+
7384
class BaseModel(ABC):
7485
"""Base model that provides convenience methods."""
7586

@@ -94,18 +105,24 @@ def __init__(self):
94105
)
95106

96107
@classmethod
97-
def _get_schema_class(cls):
108+
def _get_schema_class(cls) -> Type["BaseModelSchema"]:
98109
"""
99110
Get the schema class.
100111
101112
Returns:
102113
The resolved schema class
103114
104115
"""
105-
return resolve_class(cls.Meta.schema_class, cls)
116+
resolved = resolve_class(cls.Meta.schema_class, cls)
117+
if issubclass(resolved, BaseModelSchema):
118+
return resolved
119+
120+
raise TypeError(
121+
f"Resolved class is not a subclass of BaseModelSchema: {resolved}"
122+
)
106123

107124
@property
108-
def Schema(self) -> type:
125+
def Schema(self) -> Type["BaseModelSchema"]:
109126
"""
110127
Accessor for the model's schema class.
111128
@@ -115,8 +132,49 @@ def Schema(self) -> type:
115132
"""
116133
return self._get_schema_class()
117134

135+
@overload
136+
@classmethod
137+
def deserialize(
138+
cls: Type[ModelType],
139+
obj,
140+
*,
141+
unknown: Optional[str] = None,
142+
) -> ModelType:
143+
"""Convert from JSON representation to a model instance."""
144+
...
145+
146+
@overload
118147
@classmethod
119-
def deserialize(cls, obj, unknown: str = None, none2none: str = False):
148+
def deserialize(
149+
cls: Type[ModelType],
150+
obj,
151+
*,
152+
none2none: Literal[False],
153+
unknown: Optional[str] = None,
154+
) -> ModelType:
155+
"""Convert from JSON representation to a model instance."""
156+
...
157+
158+
@overload
159+
@classmethod
160+
def deserialize(
161+
cls: Type[ModelType],
162+
obj,
163+
*,
164+
none2none: Literal[True],
165+
unknown: Optional[str] = None,
166+
) -> Optional[ModelType]:
167+
"""Convert from JSON representation to a model instance."""
168+
...
169+
170+
@classmethod
171+
def deserialize(
172+
cls: Type[ModelType],
173+
obj,
174+
*,
175+
unknown: Optional[str] = None,
176+
none2none: bool = False,
177+
) -> Optional[ModelType]:
120178
"""
121179
Convert from JSON representation to a model instance.
122180
@@ -132,18 +190,45 @@ def deserialize(cls, obj, unknown: str = None, none2none: str = False):
132190
if obj is None and none2none:
133191
return None
134192

135-
schema = cls._get_schema_class()(unknown=unknown or EXCLUDE)
193+
schema_cls = cls._get_schema_class()
194+
schema = schema_cls(
195+
unknown=unknown or resolve_meta_property(schema_cls, "unknown", EXCLUDE)
196+
)
197+
136198
try:
137-
return schema.loads(obj) if isinstance(obj, str) else schema.load(obj)
199+
return cast(
200+
ModelType,
201+
schema.loads(obj) if isinstance(obj, str) else schema.load(obj),
202+
)
138203
except (AttributeError, ValidationError) as err:
139204
LOGGER.exception(f"{cls.__name__} message validation error:")
140205
raise BaseModelError(f"{cls.__name__} schema validation failed") from err
141206

207+
@overload
208+
def serialize(
209+
self,
210+
*,
211+
as_string: Literal[True],
212+
unknown: Optional[str] = None,
213+
) -> str:
214+
"""Create a JSON-compatible dict representation of the model instance."""
215+
...
216+
217+
@overload
142218
def serialize(
143219
self,
144-
as_string=False,
145-
unknown: str = None,
220+
*,
221+
unknown: Optional[str] = None,
146222
) -> dict:
223+
"""Create a JSON-compatible dict representation of the model instance."""
224+
...
225+
226+
def serialize(
227+
self,
228+
*,
229+
as_string: bool = False,
230+
unknown: Optional[str] = None,
231+
) -> Union[str, dict]:
147232
"""
148233
Create a JSON-compatible dict representation of the model instance.
149234
@@ -154,7 +239,10 @@ def serialize(
154239
A dict representation of this model, or a JSON string if as_string is True
155240
156241
"""
157-
schema = self.Schema(unknown=unknown or EXCLUDE)
242+
schema_cls = self._get_schema_class()
243+
schema = schema_cls(
244+
unknown=unknown or resolve_meta_property(schema_cls, "unknown", EXCLUDE)
245+
)
158246
try:
159247
return (
160248
schema.dumps(self, separators=(",", ":"))
@@ -168,18 +256,17 @@ def serialize(
168256
) from err
169257

170258
@classmethod
171-
def serde(cls, obj: Union["BaseModel", Mapping]) -> SerDe:
259+
def serde(cls, obj: Union["BaseModel", Mapping]) -> Optional[SerDe]:
172260
"""Return serialized, deserialized representations of input object."""
261+
if obj is None:
262+
return None
173263

174-
return (
175-
SerDe(obj.serialize(), obj)
176-
if isinstance(obj, BaseModel)
177-
else None
178-
if obj is None
179-
else SerDe(obj, cls.deserialize(obj))
180-
)
264+
if isinstance(obj, BaseModel):
265+
return SerDe(obj.serialize(), obj)
266+
267+
return SerDe(obj, cls.deserialize(obj))
181268

182-
def validate(self, unknown: str = None):
269+
def validate(self, unknown: Optional[str] = None):
183270
"""Validate a constructed model."""
184271
schema = self.Schema(unknown=unknown)
185272
errors = schema.validate(self.serialize())
@@ -191,7 +278,7 @@ def validate(self, unknown: str = None):
191278
def from_json(
192279
cls,
193280
json_repr: Union[str, bytes],
194-
unknown: str = None,
281+
unknown: Optional[str] = None,
195282
):
196283
"""
197284
Parse a JSON string into a model instance.
@@ -218,7 +305,7 @@ def to_json(self, unknown: str = None) -> str:
218305
A JSON representation of this message
219306
220307
"""
221-
return json.dumps(self.serialize(unknown=unknown or EXCLUDE))
308+
return json.dumps(self.serialize(unknown=unknown))
222309

223310
def __repr__(self) -> str:
224311
"""

aries_cloudagent/messaging/models/tests/test_base.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,6 @@
1-
import json
2-
31
from asynctest import TestCase as AsyncTestCase, mock as async_mock
42

5-
from marshmallow import EXCLUDE, fields, validates_schema, ValidationError
6-
7-
from ....cache.base import BaseCache
8-
from ....config.injection_context import InjectionContext
9-
from ....storage.base import BaseStorage, StorageRecord
10-
11-
from ...responder import BaseResponder, MockResponder
12-
from ...util import time_now
3+
from marshmallow import EXCLUDE, INCLUDE, fields, validates_schema, ValidationError
134

145
from ..base import BaseModel, BaseModelError, BaseModelSchema
156

@@ -35,6 +26,48 @@ def validate_fields(self, data, **kwargs):
3526
raise ValidationError("")
3627

3728

29+
class ModelImplWithUnknown(BaseModel):
30+
class Meta:
31+
schema_class = "SchemaImplWithUnknown"
32+
33+
def __init__(self, *, attr=None, **kwargs):
34+
self.attr = attr
35+
self.extra = kwargs
36+
37+
38+
class SchemaImplWithUnknown(BaseModelSchema):
39+
class Meta:
40+
model_class = ModelImplWithUnknown
41+
unknown = INCLUDE
42+
43+
attr = fields.String(required=True)
44+
45+
@validates_schema
46+
def validate_fields(self, data, **kwargs):
47+
if data["attr"] != "succeeds":
48+
raise ValidationError("")
49+
50+
51+
class ModelImplWithoutUnknown(BaseModel):
52+
class Meta:
53+
schema_class = "SchemaImplWithoutUnknown"
54+
55+
def __init__(self, *, attr=None):
56+
self.attr = attr
57+
58+
59+
class SchemaImplWithoutUnknown(BaseModelSchema):
60+
class Meta:
61+
model_class = ModelImplWithoutUnknown
62+
63+
attr = fields.String(required=True)
64+
65+
@validates_schema
66+
def validate_fields(self, data, **kwargs):
67+
if data["attr"] != "succeeds":
68+
raise ValidationError("")
69+
70+
3871
class TestBase(AsyncTestCase):
3972
def test_model_validate_fails(self):
4073
model = ModelImpl(attr="string")
@@ -63,3 +96,24 @@ def test_from_json_x(self):
6396
data = "{}{}"
6497
with self.assertRaises(BaseModelError):
6598
ModelImpl.from_json(data)
99+
100+
def test_model_with_unknown(self):
101+
model = ModelImplWithUnknown(attr="succeeds")
102+
model = model.validate()
103+
assert model.attr == "succeeds"
104+
105+
model = ModelImplWithUnknown.deserialize(
106+
{"attr": "succeeds", "another": "value"}
107+
)
108+
assert model.extra
109+
assert model.extra["another"] == "value"
110+
assert model.attr == "succeeds"
111+
112+
def test_model_without_unknown_default_exclude(self):
113+
model = ModelImplWithoutUnknown(attr="succeeds")
114+
model = model.validate()
115+
assert model.attr == "succeeds"
116+
117+
assert ModelImplWithoutUnknown.deserialize(
118+
{"attr": "succeeds", "another": "value"}
119+
)

aries_cloudagent/utils/classloader.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from importlib import import_module
88
from importlib.util import find_spec, resolve_name
99
from types import ModuleType
10-
from typing import Sequence, Type
10+
from typing import Optional, Sequence, Type
1111

1212
from ..core.error import BaseError
1313

@@ -75,7 +75,10 @@ def load_module(cls, mod_path: str, package: str = None) -> ModuleType:
7575

7676
@classmethod
7777
def load_class(
78-
cls, class_name: str, default_module: str = None, package: str = None
78+
cls,
79+
class_name: str,
80+
default_module: Optional[str] = None,
81+
package: Optional[str] = None,
7982
):
8083
"""
8184
Resolve a complete class path (ie. typing.Dict) to the class itself.

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ markupsafe==2.0.1
1313
marshmallow==3.5.1
1414
msgpack~=1.0
1515
prompt_toolkit~=2.0.9
16-
pynacl~=1.4.0
16+
pynacl~=1.5.0
1717
requests~=2.25.0
1818
packaging~=20.4
1919
pyld~=2.0.3

0 commit comments

Comments
 (0)