Skip to content
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
default_stages: [commit]
default_stages: [pre-commit]
default_install_hook_types: [pre-commit, commit-msg]

ci:
Expand Down Expand Up @@ -37,6 +37,6 @@ repos:
rev: v2.2.6
hooks:
- id: codespell
stages: [commit, commit-msg]
stages: [pre-commit, commit-msg]
exclude_types: [json, bib, svg]
args: [--ignore-words-list, "mater,fwe,te"]
28 changes: 25 additions & 3 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,18 @@
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from copy import copy
from functools import cache
from importlib import import_module
from importlib.metadata import PackageNotFoundError, version
from json import JSONDecodeError
from math import ceil
from typing import TYPE_CHECKING, Generic, TypeVar
from typing import (
TYPE_CHECKING,
ForwardRef,
Generic,
TypeVar,
_eval_type,
get_args,
)
from urllib.parse import quote, urljoin

import requests
Expand Down Expand Up @@ -65,7 +73,7 @@ class BaseRester(Generic[T]):
"""Base client class with core stubs."""

suffix: str = ""
document_model: BaseModel = None # type: ignore
document_model: type[BaseModel] | None = None
supports_versions: bool = False
primary_key: str = "material_id"

Expand Down Expand Up @@ -1070,10 +1078,24 @@ def _convert_to_model(self, data: list[dict]):

def _generate_returned_model(self, doc):
model_fields = self.document_model.model_fields

set_fields = doc.model_fields_set
unset_fields = [field for field in model_fields if field not in set_fields]

# Update with locals() from external module if needed
other_vars = {}
if any(
isinstance(typ, ForwardRef)
for name in set_fields
for typ in get_args(model_fields[name].annotation)
):
other_vars = vars(import_module(self.document_model.__module__))

include_fields = {
name: (model_fields[name].annotation, model_fields[name])
name: (
_eval_type(model_fields[name].annotation, other_vars, {}, frozenset()),
model_fields[name],
)
for name in set_fields
}

Expand Down
13 changes: 9 additions & 4 deletions mp_api/client/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def api_sanitize(

for model in models:
model_fields_to_leave = {f[1] for f in fields_tuples if model.__name__ == f[0]}
for name in model.model_fields:
field = model.model_fields[name]
field_json_extra = field.json_schema_extra
for name, field in model.model_fields.items():
field_type = field.annotation

if field_type is not None and allow_dict_msonable:
Expand All @@ -88,7 +86,14 @@ def api_sanitize(
new_field = FieldInfo.from_annotated_attribute(
Optional[field_type], None
)
new_field.json_schema_extra = field_json_extra or {}

for attr in (
"json_schema_extra",
"exclude",
):
if (val := getattr(field, attr)) is not None:
setattr(new_field, attr, val)

model.model_fields[name] = new_field

model.model_rebuild(force=True)
Expand Down
Loading