Skip to content

Commit 6376bb6

Browse files
Allow for model excluded fields in MPDataDoc construction (#989)
1 parent cb30cd5 commit 6376bb6

10 files changed

+78
-37
lines changed

.github/workflows/testing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jobs:
6060
#MP_API_ENDPOINT: https://api-preview.materialsproject.org/
6161
run: |
6262
pip install -e .
63-
pytest -x --cov=mp_api --cov-report=xml
63+
pytest -n auto -x --cov=mp_api --cov-report=xml
6464
- uses: codecov/codecov-action@v1
6565
with:
6666
token: ${{ secrets.CODECOV_TOKEN }}

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
default_stages: [commit]
1+
default_stages: [pre-commit]
22
default_install_hook_types: [pre-commit, commit-msg]
33

44
ci:
@@ -37,6 +37,6 @@ repos:
3737
rev: v2.2.6
3838
hooks:
3939
- id: codespell
40-
stages: [commit, commit-msg]
40+
stages: [pre-commit, commit-msg]
4141
exclude_types: [json, bib, svg]
4242
args: [--ignore-words-list, "mater,fwe,te"]

mp_api/client/core/client.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,17 @@
1515
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
1616
from copy import copy
1717
from functools import cache
18+
from importlib import import_module
1819
from importlib.metadata import PackageNotFoundError, version
1920
from json import JSONDecodeError
2021
from math import ceil
21-
from typing import TYPE_CHECKING, Generic, TypeVar
22+
from typing import (
23+
TYPE_CHECKING,
24+
ForwardRef,
25+
Generic,
26+
TypeVar,
27+
get_args,
28+
)
2229
from urllib.parse import quote, urljoin
2330

2431
import requests
@@ -65,7 +72,7 @@ class BaseRester(Generic[T]):
6572
"""Base client class with core stubs."""
6673

6774
suffix: str = ""
68-
document_model: BaseModel = None # type: ignore
75+
document_model: type[BaseModel] | None = None
6976
supports_versions: bool = False
7077
primary_key: str = "material_id"
7178

@@ -1070,10 +1077,24 @@ def _convert_to_model(self, data: list[dict]):
10701077

10711078
def _generate_returned_model(self, doc):
10721079
model_fields = self.document_model.model_fields
1080+
10731081
set_fields = doc.model_fields_set
10741082
unset_fields = [field for field in model_fields if field not in set_fields]
1083+
1084+
# Update with locals() from external module if needed
1085+
other_vars = {}
1086+
if any(
1087+
isinstance(typ, ForwardRef)
1088+
for field_meta in model_fields.values()
1089+
for typ in get_args(field_meta.annotation)
1090+
):
1091+
other_vars = vars(import_module(self.document_model.__module__))
1092+
10751093
include_fields = {
1076-
name: (model_fields[name].annotation, model_fields[name])
1094+
name: (
1095+
model_fields[name].annotation,
1096+
model_fields[name],
1097+
)
10771098
for name in set_fields
10781099
}
10791100

@@ -1085,6 +1106,8 @@ def _generate_returned_model(self, doc):
10851106
fields_not_requested=(list[str], unset_fields),
10861107
__base__=self.document_model,
10871108
)
1109+
if other_vars:
1110+
data_model.model_rebuild(_types_namespace=other_vars)
10881111

10891112
def new_repr(self) -> str:
10901113
extra = ",\n".join(

mp_api/client/core/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,7 @@ def api_sanitize(
7171

7272
for model in models:
7373
model_fields_to_leave = {f[1] for f in fields_tuples if model.__name__ == f[0]}
74-
for name in model.model_fields:
75-
field = model.model_fields[name]
76-
field_json_extra = field.json_schema_extra
74+
for name, field in model.model_fields.items():
7775
field_type = field.annotation
7876

7977
if field_type is not None and allow_dict_msonable:
@@ -88,7 +86,14 @@ def api_sanitize(
8886
new_field = FieldInfo.from_annotated_attribute(
8987
Optional[field_type], None
9088
)
91-
new_field.json_schema_extra = field_json_extra or {}
89+
90+
for attr in (
91+
"json_schema_extra",
92+
"exclude",
93+
):
94+
if (val := getattr(field, attr)) is not None:
95+
setattr(new_field, attr, val)
96+
9297
model.model_fields[name] = new_field
9398

9499
model.model_rebuild(force=True)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ test = [
4040
"pytest-asyncio",
4141
"pytest-cov",
4242
"pytest-mock",
43+
"pytest-xdist",
4344
"flake8",
4445
"pycodestyle",
4546
"mypy",

requirements/requirements-ubuntu-latest_py3.11.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ bcrypt==4.3.0
1717
# via paramiko
1818
bibtexparser==1.4.3
1919
# via pymatgen
20-
boto3==1.40.29
20+
boto3==1.40.31
2121
# via maggma
22-
botocore==1.40.29
22+
botocore==1.40.31
2323
# via
2424
# boto3
2525
# s3transfer
@@ -83,7 +83,7 @@ msgpack==1.1.1
8383
# via
8484
# maggma
8585
# mp-api (pyproject.toml)
86-
narwhals==2.4.0
86+
narwhals==2.5.0
8787
# via plotly
8888
networkx==3.5
8989
# via pymatgen
@@ -123,7 +123,7 @@ pybtex==0.25.1
123123
# via emmet-core
124124
pycparser==2.23
125125
# via cffi
126-
pydantic==2.11.7
126+
pydantic==2.11.9
127127
# via
128128
# emmet-core
129129
# maggma
@@ -149,7 +149,7 @@ pymongo==4.10.1
149149
# via maggma
150150
pynacl==1.6.0
151151
# via paramiko
152-
pyparsing==3.2.3
152+
pyparsing==3.2.4
153153
# via
154154
# bibtexparser
155155
# matplotlib

requirements/requirements-ubuntu-latest_py3.11_extras.txt

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ bibtexparser==1.4.3
2929
# via pymatgen
3030
boltons==25.0.0
3131
# via mpcontribs-client
32-
boto3==1.40.29
32+
boto3==1.40.31
3333
# via
3434
# maggma
3535
# mp-api (pyproject.toml)
36-
botocore==1.40.29
36+
botocore==1.40.31
3737
# via
3838
# boto3
3939
# s3transfer
@@ -76,6 +76,8 @@ docutils==0.21.2
7676
# via sphinx
7777
emmet-core[all]==0.84.10rc2
7878
# via mp-api (pyproject.toml)
79+
execnet==2.1.1
80+
# via pytest-xdist
7981
executing==2.2.1
8082
# via stack-data
8183
filelock==3.19.1
@@ -221,7 +223,7 @@ mypy-extensions==1.1.0
221223
# via
222224
# mp-api (pyproject.toml)
223225
# mypy
224-
narwhals==2.4.0
226+
narwhals==2.5.0
225227
# via plotly
226228
networkx==3.5
227229
# via
@@ -343,7 +345,7 @@ pycodestyle==2.14.0
343345
# mp-api (pyproject.toml)
344346
pycparser==2.23
345347
# via cffi
346-
pydantic==2.11.7
348+
pydantic==2.11.9
347349
# via
348350
# emmet-core
349351
# maggma
@@ -395,7 +397,7 @@ pymongo==4.10.1
395397
# mpcontribs-client
396398
pynacl==1.6.0
397399
# via paramiko
398-
pyparsing==3.2.3
400+
pyparsing==3.2.4
399401
# via
400402
# bibtexparser
401403
# matplotlib
@@ -405,13 +407,16 @@ pytest==8.4.2
405407
# pytest-asyncio
406408
# pytest-cov
407409
# pytest-mock
410+
# pytest-xdist
408411
# solvation-analysis
409-
pytest-asyncio==1.1.0
412+
pytest-asyncio==1.2.0
410413
# via mp-api (pyproject.toml)
411414
pytest-cov==7.0.0
412415
# via mp-api (pyproject.toml)
413416
pytest-mock==3.15.0
414417
# via mp-api (pyproject.toml)
418+
pytest-xdist==3.8.0
419+
# via mp-api (pyproject.toml)
415420
python-dateutil==2.9.0.post0
416421
# via
417422
# arrow
@@ -582,7 +587,7 @@ typeguard==4.4.4
582587
# via inflect
583588
types-python-dateutil==2.9.0.20250822
584589
# via arrow
585-
types-requests==2.32.4.20250809
590+
types-requests==2.32.4.20250913
586591
# via mp-api (pyproject.toml)
587592
types-setuptools==80.9.0.20250822
588593
# via mp-api (pyproject.toml)
@@ -599,6 +604,7 @@ typing-extensions==4.15.0
599604
# pydantic
600605
# pydantic-core
601606
# pydash
607+
# pytest-asyncio
602608
# referencing
603609
# spglib
604610
# swagger-spec-validator

requirements/requirements-ubuntu-latest_py3.12.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ bcrypt==4.3.0
1717
# via paramiko
1818
bibtexparser==1.4.3
1919
# via pymatgen
20-
boto3==1.40.29
20+
boto3==1.40.31
2121
# via maggma
22-
botocore==1.40.29
22+
botocore==1.40.31
2323
# via
2424
# boto3
2525
# s3transfer
@@ -83,7 +83,7 @@ msgpack==1.1.1
8383
# via
8484
# maggma
8585
# mp-api (pyproject.toml)
86-
narwhals==2.4.0
86+
narwhals==2.5.0
8787
# via plotly
8888
networkx==3.5
8989
# via pymatgen
@@ -123,7 +123,7 @@ pybtex==0.25.1
123123
# via emmet-core
124124
pycparser==2.23
125125
# via cffi
126-
pydantic==2.11.7
126+
pydantic==2.11.9
127127
# via
128128
# emmet-core
129129
# maggma
@@ -149,7 +149,7 @@ pymongo==4.10.1
149149
# via maggma
150150
pynacl==1.6.0
151151
# via paramiko
152-
pyparsing==3.2.3
152+
pyparsing==3.2.4
153153
# via
154154
# bibtexparser
155155
# matplotlib

requirements/requirements-ubuntu-latest_py3.12_extras.txt

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ bibtexparser==1.4.3
2929
# via pymatgen
3030
boltons==25.0.0
3131
# via mpcontribs-client
32-
boto3==1.40.29
32+
boto3==1.40.31
3333
# via
3434
# maggma
3535
# mp-api (pyproject.toml)
36-
botocore==1.40.29
36+
botocore==1.40.31
3737
# via
3838
# boto3
3939
# s3transfer
@@ -76,6 +76,8 @@ docutils==0.21.2
7676
# via sphinx
7777
emmet-core[all]==0.84.10rc2
7878
# via mp-api (pyproject.toml)
79+
execnet==2.1.1
80+
# via pytest-xdist
7981
executing==2.2.1
8082
# via stack-data
8183
filelock==3.19.1
@@ -221,7 +223,7 @@ mypy-extensions==1.1.0
221223
# via
222224
# mp-api (pyproject.toml)
223225
# mypy
224-
narwhals==2.4.0
226+
narwhals==2.5.0
225227
# via plotly
226228
networkx==3.5
227229
# via
@@ -343,7 +345,7 @@ pycodestyle==2.14.0
343345
# mp-api (pyproject.toml)
344346
pycparser==2.23
345347
# via cffi
346-
pydantic==2.11.7
348+
pydantic==2.11.9
347349
# via
348350
# emmet-core
349351
# maggma
@@ -395,7 +397,7 @@ pymongo==4.10.1
395397
# mpcontribs-client
396398
pynacl==1.6.0
397399
# via paramiko
398-
pyparsing==3.2.3
400+
pyparsing==3.2.4
399401
# via
400402
# bibtexparser
401403
# matplotlib
@@ -405,13 +407,16 @@ pytest==8.4.2
405407
# pytest-asyncio
406408
# pytest-cov
407409
# pytest-mock
410+
# pytest-xdist
408411
# solvation-analysis
409-
pytest-asyncio==1.1.0
412+
pytest-asyncio==1.2.0
410413
# via mp-api (pyproject.toml)
411414
pytest-cov==7.0.0
412415
# via mp-api (pyproject.toml)
413416
pytest-mock==3.15.0
414417
# via mp-api (pyproject.toml)
418+
pytest-xdist==3.8.0
419+
# via mp-api (pyproject.toml)
415420
python-dateutil==2.9.0.post0
416421
# via
417422
# arrow
@@ -582,7 +587,7 @@ typeguard==4.4.4
582587
# via inflect
583588
types-python-dateutil==2.9.0.20250822
584589
# via arrow
585-
types-requests==2.32.4.20250809
590+
types-requests==2.32.4.20250913
586591
# via mp-api (pyproject.toml)
587592
types-setuptools==80.9.0.20250822
588593
# via mp-api (pyproject.toml)
@@ -598,6 +603,7 @@ typing-extensions==4.15.0
598603
# pydantic
599604
# pydantic-core
600605
# pydash
606+
# pytest-asyncio
601607
# referencing
602608
# spglib
603609
# swagger-spec-validator

tests/materials/test_electronic_structure.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def es_rester():
4747

4848

4949
@pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.")
50-
@pytest.mark.skip(reason="magnetic ordering fields not build correctly")
50+
@pytest.mark.skip(reason="magnetic ordering fields not built correctly")
5151
def test_es_client(es_rester):
5252
search_method = es_rester.search
5353

@@ -81,7 +81,7 @@ def bs_rester():
8181

8282

8383
@pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.")
84-
@pytest.mark.skip(reason="magnetic ordering fields not build correctly")
84+
@pytest.mark.skip(reason="magnetic ordering fields not built correctly")
8585
def test_bs_client(bs_rester):
8686
# Get specific search method
8787
search_method = bs_rester.search
@@ -127,7 +127,7 @@ def dos_rester():
127127

128128

129129
@pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.")
130-
@pytest.mark.skip(reason="magnetic ordering fields not build correctly")
130+
@pytest.mark.skip(reason="magnetic ordering fields not built correctly")
131131
def test_dos_client(dos_rester):
132132
search_method = dos_rester.search
133133

0 commit comments

Comments
 (0)