@@ -207,13 +207,20 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]]
207
207
208
208
def _model_flatten_map (self , model : TModel , prefix : str ) -> Generator :
209
209
field : FieldInfo
210
- for attr , field in model .model_fields .items ():
211
- field_name = field .alias or attr
212
- name = f"{ prefix } { self .FLATTEN_PATH_SEP } { field_name } "
213
- if is_pydantic_model (field .annotation ):
214
- yield from self ._model_flatten_map (field .annotation , name ) # type: ignore
215
- else :
216
- yield field_name , name
210
+ if get_origin (model ) in UNION_TYPES :
211
+ # If the model is a union type, process each type in the union
212
+ for arg in get_args (model ):
213
+ if arg is type (None ):
214
+ continue # Skip NoneType
215
+ yield from self ._model_flatten_map (arg , prefix )
216
+ else :
217
+ for attr , field in model .model_fields .items ():
218
+ field_name = field .alias or attr
219
+ name = f"{ prefix } { self .FLATTEN_PATH_SEP } { field_name } "
220
+ if is_pydantic_model (field .annotation ):
221
+ yield from self ._model_flatten_map (field .annotation , name ) # type: ignore
222
+ else :
223
+ yield field_name , name
217
224
218
225
def _get_param_type (self , name : str , arg : inspect .Parameter ) -> FuncParam :
219
226
# _EMPTY = self.signature.empty
@@ -260,9 +267,9 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
260
267
261
268
# 2) if param name is a part of the path parameter
262
269
elif name in self .path_params_names :
263
- assert (
264
- default == self . signature . empty
265
- ), f"' { name } ' is a path param, default not allowed"
270
+ assert default == self . signature . empty , (
271
+ f"' { name } ' is a path param, default not allowed"
272
+ )
266
273
param_source = Path (...)
267
274
268
275
# 3) if param is a collection, or annotation is part of pydantic model:
@@ -295,7 +302,11 @@ def is_pydantic_model(cls: Any) -> bool:
295
302
296
303
# Handle Union types
297
304
if origin in UNION_TYPES :
298
- return any (issubclass (arg , pydantic .BaseModel ) for arg in get_args (cls ))
305
+ return any (
306
+ issubclass (arg , pydantic .BaseModel )
307
+ for arg in get_args (cls )
308
+ if arg is not type (None )
309
+ )
299
310
return issubclass (cls , pydantic .BaseModel )
300
311
except TypeError : # pragma: no cover
301
312
return False
@@ -338,14 +349,32 @@ def detect_collection_fields(
338
349
for attr in path [1 :]:
339
350
if hasattr (annotation_or_field , "annotation" ):
340
351
annotation_or_field = annotation_or_field .annotation
341
- annotation_or_field = next (
342
- (
343
- a
344
- for a in annotation_or_field .model_fields .values ()
345
- if a .alias == attr
346
- ),
347
- annotation_or_field .model_fields .get (attr ),
348
- ) # pragma: no cover
352
+
353
+ # check union types
354
+ if get_origin (annotation_or_field ) in UNION_TYPES :
355
+ for arg in get_args (annotation_or_field ):
356
+ if arg is type (None ):
357
+ continue # Skip NoneType
358
+ if hasattr (arg , "model_fields" ):
359
+ annotation_or_field = next (
360
+ (
361
+ a
362
+ for a in arg .model_fields .values ()
363
+ if a .alias == attr
364
+ ),
365
+ arg .model_fields .get (attr ),
366
+ ) # pragma: no cover
367
+ else :
368
+ continue
369
+ else :
370
+ annotation_or_field = next (
371
+ (
372
+ a
373
+ for a in annotation_or_field .model_fields .values ()
374
+ if a .alias == attr
375
+ ),
376
+ annotation_or_field .model_fields .get (attr ),
377
+ ) # pragma: no cover
349
378
350
379
annotation_or_field = getattr (
351
380
annotation_or_field , "outer_type_" , annotation_or_field
0 commit comments