15
15
from concurrent .futures import FIRST_COMPLETED , ThreadPoolExecutor , wait
16
16
from copy import copy
17
17
from functools import cache
18
+ from importlib import import_module
18
19
from importlib .metadata import PackageNotFoundError , version
19
20
from json import JSONDecodeError
20
21
from math import ceil
21
- from typing import TYPE_CHECKING , Generic , TypeVar
22
+ from typing import (
23
+ TYPE_CHECKING ,
24
+ ForwardRef ,
25
+ Generic ,
26
+ TypeVar ,
27
+ get_args ,
28
+ )
22
29
from urllib .parse import quote , urljoin
23
30
24
31
import requests
@@ -65,7 +72,7 @@ class BaseRester(Generic[T]):
65
72
"""Base client class with core stubs."""
66
73
67
74
suffix : str = ""
68
- document_model : BaseModel = None # type: ignore
75
+ document_model : type [ BaseModel ] | None = None
69
76
supports_versions : bool = False
70
77
primary_key : str = "material_id"
71
78
@@ -1070,10 +1077,24 @@ def _convert_to_model(self, data: list[dict]):
1070
1077
1071
1078
def _generate_returned_model (self , doc ):
1072
1079
model_fields = self .document_model .model_fields
1080
+
1073
1081
set_fields = doc .model_fields_set
1074
1082
unset_fields = [field for field in model_fields if field not in set_fields ]
1083
+
1084
+ # Update with locals() from external module if needed
1085
+ other_vars = {}
1086
+ if any (
1087
+ isinstance (typ , ForwardRef )
1088
+ for field_meta in model_fields .values ()
1089
+ for typ in get_args (field_meta .annotation )
1090
+ ):
1091
+ other_vars = vars (import_module (self .document_model .__module__ ))
1092
+
1075
1093
include_fields = {
1076
- name : (model_fields [name ].annotation , model_fields [name ])
1094
+ name : (
1095
+ model_fields [name ].annotation ,
1096
+ model_fields [name ],
1097
+ )
1077
1098
for name in set_fields
1078
1099
}
1079
1100
@@ -1085,6 +1106,8 @@ def _generate_returned_model(self, doc):
1085
1106
fields_not_requested = (list [str ], unset_fields ),
1086
1107
__base__ = self .document_model ,
1087
1108
)
1109
+ if other_vars :
1110
+ data_model .model_rebuild (_types_namespace = other_vars )
1088
1111
1089
1112
def new_repr (self ) -> str :
1090
1113
extra = ",\n " .join (
0 commit comments