Skip to content

Commit 56c94c3

Browse files
committed
add union type and subtypes check in schema model signature
1 parent dd41a42 commit 56c94c3

File tree

1 file changed

+48
-19
lines changed

1 file changed

+48
-19
lines changed

ninja/signature/details.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -207,13 +207,20 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]]
207207

208208
def _model_flatten_map(self, model: TModel, prefix: str) -> Generator:
209209
field: FieldInfo
210-
for attr, field in model.model_fields.items():
211-
field_name = field.alias or attr
212-
name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}"
213-
if is_pydantic_model(field.annotation):
214-
yield from self._model_flatten_map(field.annotation, name) # type: ignore
215-
else:
216-
yield field_name, name
210+
if get_origin(model) in UNION_TYPES:
211+
# If the model is a union type, process each type in the union
212+
for arg in get_args(model):
213+
if arg is type(None):
214+
continue # Skip NoneType
215+
yield from self._model_flatten_map(arg, prefix)
216+
else:
217+
for attr, field in model.model_fields.items():
218+
field_name = field.alias or attr
219+
name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}"
220+
if is_pydantic_model(field.annotation):
221+
yield from self._model_flatten_map(field.annotation, name) # type: ignore
222+
else:
223+
yield field_name, name
217224

218225
def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
219226
# _EMPTY = self.signature.empty
@@ -260,9 +267,9 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
260267

261268
# 2) if param name is a part of the path parameter
262269
elif name in self.path_params_names:
263-
assert (
264-
default == self.signature.empty
265-
), f"'{name}' is a path param, default not allowed"
270+
assert default == self.signature.empty, (
271+
f"'{name}' is a path param, default not allowed"
272+
)
266273
param_source = Path(...)
267274

268275
# 3) if param is a collection, or annotation is part of pydantic model:
@@ -295,7 +302,11 @@ def is_pydantic_model(cls: Any) -> bool:
295302

296303
# Handle Union types
297304
if origin in UNION_TYPES:
298-
return any(issubclass(arg, pydantic.BaseModel) for arg in get_args(cls))
305+
return any(
306+
issubclass(arg, pydantic.BaseModel)
307+
for arg in get_args(cls)
308+
if arg is not type(None)
309+
)
299310
return issubclass(cls, pydantic.BaseModel)
300311
except TypeError: # pragma: no cover
301312
return False
@@ -338,14 +349,32 @@ def detect_collection_fields(
338349
for attr in path[1:]:
339350
if hasattr(annotation_or_field, "annotation"):
340351
annotation_or_field = annotation_or_field.annotation
341-
annotation_or_field = next(
342-
(
343-
a
344-
for a in annotation_or_field.model_fields.values()
345-
if a.alias == attr
346-
),
347-
annotation_or_field.model_fields.get(attr),
348-
) # pragma: no cover
352+
353+
# check union types
354+
if get_origin(annotation_or_field) in UNION_TYPES:
355+
for arg in get_args(annotation_or_field):
356+
if arg is type(None):
357+
continue # Skip NoneType
358+
if hasattr(arg, "model_fields"):
359+
annotation_or_field = next(
360+
(
361+
a
362+
for a in arg.model_fields.values()
363+
if a.alias == attr
364+
),
365+
arg.model_fields.get(attr),
366+
) # pragma: no cover
367+
else:
368+
continue
369+
else:
370+
annotation_or_field = next(
371+
(
372+
a
373+
for a in annotation_or_field.model_fields.values()
374+
if a.alias == attr
375+
),
376+
annotation_or_field.model_fields.get(attr),
377+
) # pragma: no cover
349378

350379
annotation_or_field = getattr(
351380
annotation_or_field, "outer_type_", annotation_or_field

0 commit comments

Comments
 (0)