55
66from abc import ABC
77from collections import namedtuple
8- from typing import Mapping , Union
8+ from typing import Mapping , Optional , Type , TypeVar , Union , cast , overload
9+ from typing_extensions import Literal
910
1011from marshmallow import Schema , post_dump , pre_load , post_load , ValidationError , EXCLUDE
1112
1718SerDe = namedtuple ("SerDe" , "ser de" )
1819
1920
20- def resolve_class (the_cls , relative_cls : type = None ):
21+ def resolve_class (the_cls , relative_cls : Optional [ type ] = None ) -> type :
2122 """
2223 Resolve a class.
2324
@@ -38,6 +39,10 @@ def resolve_class(the_cls, relative_cls: type = None):
3839 elif isinstance (the_cls , str ):
3940 default_module = relative_cls and relative_cls .__module__
4041 resolved = ClassLoader .load_class (the_cls , default_module )
42+ else :
43+ raise TypeError (
44+ f"Could not resolve class from { the_cls } ; incorrect type { type (the_cls )} "
45+ )
4146 return resolved
4247
4348
@@ -53,7 +58,10 @@ def resolve_meta_property(obj, prop_name: str, defval=None):
5358 The meta property
5459
5560 """
56- cls = obj .__class__
61+ if isinstance (obj , type ):
62+ cls = obj
63+ else :
64+ cls = obj .__class__
5765 found = defval
5866 while cls :
5967 Meta = getattr (cls , "Meta" , None )
@@ -70,6 +78,9 @@ class BaseModelError(BaseError):
7078 """Base exception class for base model errors."""
7179
7280
81+ ModelType = TypeVar ("ModelType" , bound = "BaseModel" )
82+
83+
7384class BaseModel (ABC ):
7485 """Base model that provides convenience methods."""
7586
@@ -94,18 +105,24 @@ def __init__(self):
94105 )
95106
96107 @classmethod
97- def _get_schema_class (cls ):
108+ def _get_schema_class (cls ) -> Type [ "BaseModelSchema" ] :
98109 """
99110 Get the schema class.
100111
101112 Returns:
102113 The resolved schema class
103114
104115 """
105- return resolve_class (cls .Meta .schema_class , cls )
116+ resolved = resolve_class (cls .Meta .schema_class , cls )
117+ if issubclass (resolved , BaseModelSchema ):
118+ return resolved
119+
120+ raise TypeError (
121+ f"Resolved class is not a subclass of BaseModelSchema: { resolved } "
122+ )
106123
107124 @property
108- def Schema (self ) -> type :
125+ def Schema (self ) -> Type [ "BaseModelSchema" ] :
109126 """
110127 Accessor for the model's schema class.
111128
@@ -115,8 +132,49 @@ def Schema(self) -> type:
115132 """
116133 return self ._get_schema_class ()
117134
135+ @overload
136+ @classmethod
137+ def deserialize (
138+ cls : Type [ModelType ],
139+ obj ,
140+ * ,
141+ unknown : Optional [str ] = None ,
142+ ) -> ModelType :
143+ """Convert from JSON representation to a model instance."""
144+ ...
145+
146+ @overload
118147 @classmethod
119- def deserialize (cls , obj , unknown : str = None , none2none : str = False ):
148+ def deserialize (
149+ cls : Type [ModelType ],
150+ obj ,
151+ * ,
152+ none2none : Literal [False ],
153+ unknown : Optional [str ] = None ,
154+ ) -> ModelType :
155+ """Convert from JSON representation to a model instance."""
156+ ...
157+
158+ @overload
159+ @classmethod
160+ def deserialize (
161+ cls : Type [ModelType ],
162+ obj ,
163+ * ,
164+ none2none : Literal [True ],
165+ unknown : Optional [str ] = None ,
166+ ) -> Optional [ModelType ]:
167+ """Convert from JSON representation to a model instance."""
168+ ...
169+
170+ @classmethod
171+ def deserialize (
172+ cls : Type [ModelType ],
173+ obj ,
174+ * ,
175+ unknown : Optional [str ] = None ,
176+ none2none : bool = False ,
177+ ) -> Optional [ModelType ]:
120178 """
121179 Convert from JSON representation to a model instance.
122180
@@ -132,18 +190,45 @@ def deserialize(cls, obj, unknown: str = None, none2none: str = False):
132190 if obj is None and none2none :
133191 return None
134192
135- schema = cls ._get_schema_class ()(unknown = unknown or EXCLUDE )
193+ schema_cls = cls ._get_schema_class ()
194+ schema = schema_cls (
195+ unknown = unknown or resolve_meta_property (schema_cls , "unknown" , EXCLUDE )
196+ )
197+
136198 try :
137- return schema .loads (obj ) if isinstance (obj , str ) else schema .load (obj )
199+ return cast (
200+ ModelType ,
201+ schema .loads (obj ) if isinstance (obj , str ) else schema .load (obj ),
202+ )
138203 except (AttributeError , ValidationError ) as err :
139204 LOGGER .exception (f"{ cls .__name__ } message validation error:" )
140205 raise BaseModelError (f"{ cls .__name__ } schema validation failed" ) from err
141206
207+ @overload
208+ def serialize (
209+ self ,
210+ * ,
211+ as_string : Literal [True ],
212+ unknown : Optional [str ] = None ,
213+ ) -> str :
214+ """Create a JSON-compatible dict representation of the model instance."""
215+ ...
216+
217+ @overload
142218 def serialize (
143219 self ,
144- as_string = False ,
145- unknown : str = None ,
220+ * ,
221+ unknown : Optional [ str ] = None ,
146222 ) -> dict :
223+ """Create a JSON-compatible dict representation of the model instance."""
224+ ...
225+
226+ def serialize (
227+ self ,
228+ * ,
229+ as_string : bool = False ,
230+ unknown : Optional [str ] = None ,
231+ ) -> Union [str , dict ]:
147232 """
148233 Create a JSON-compatible dict representation of the model instance.
149234
@@ -154,7 +239,10 @@ def serialize(
154239 A dict representation of this model, or a JSON string if as_string is True
155240
156241 """
157- schema = self .Schema (unknown = unknown or EXCLUDE )
242+ schema_cls = self ._get_schema_class ()
243+ schema = schema_cls (
244+ unknown = unknown or resolve_meta_property (schema_cls , "unknown" , EXCLUDE )
245+ )
158246 try :
159247 return (
160248 schema .dumps (self , separators = ("," , ":" ))
@@ -168,18 +256,17 @@ def serialize(
168256 ) from err
169257
170258 @classmethod
171- def serde (cls , obj : Union ["BaseModel" , Mapping ]) -> SerDe :
259+ def serde (cls , obj : Union ["BaseModel" , Mapping ]) -> Optional [ SerDe ] :
172260 """Return serialized, deserialized representations of input object."""
261+ if obj is None :
262+ return None
173263
174- return (
175- SerDe (obj .serialize (), obj )
176- if isinstance (obj , BaseModel )
177- else None
178- if obj is None
179- else SerDe (obj , cls .deserialize (obj ))
180- )
264+ if isinstance (obj , BaseModel ):
265+ return SerDe (obj .serialize (), obj )
266+
267+ return SerDe (obj , cls .deserialize (obj ))
181268
182- def validate (self , unknown : str = None ):
269+ def validate (self , unknown : Optional [ str ] = None ):
183270 """Validate a constructed model."""
184271 schema = self .Schema (unknown = unknown )
185272 errors = schema .validate (self .serialize ())
@@ -191,7 +278,7 @@ def validate(self, unknown: str = None):
191278 def from_json (
192279 cls ,
193280 json_repr : Union [str , bytes ],
194- unknown : str = None ,
281+ unknown : Optional [ str ] = None ,
195282 ):
196283 """
197284 Parse a JSON string into a model instance.
@@ -218,7 +305,7 @@ def to_json(self, unknown: str = None) -> str:
218305 A JSON representation of this message
219306
220307 """
221- return json .dumps (self .serialize (unknown = unknown or EXCLUDE ))
308+ return json .dumps (self .serialize (unknown = unknown ))
222309
223310 def __repr__ (self ) -> str :
224311 """
0 commit comments