|
7 | 7 | import dataclasses
|
8 | 8 | import datetime
|
9 | 9 | import inspect
|
| 10 | +import warnings |
10 | 11 | from enum import Enum
|
11 | 12 | from typing import Any, Callable, Mapping, get_origin
|
12 | 13 |
|
@@ -286,6 +287,65 @@ def decode_scalar(value: Any) -> Any | None:
|
286 | 287 | return lambda value: value
|
287 | 288 |
|
288 | 289 |
|
| 290 | +def _get_auto_default_for_type( |
| 291 | + annotation: Any, field_name: str, field_path: list[str] |
| 292 | +) -> tuple[Any, bool]: |
| 293 | + """ |
| 294 | + Get an auto-default value for a type annotation if it's safe to do so. |
| 295 | +
|
| 296 | + Returns: |
| 297 | + A tuple of (default_value, is_supported) where: |
| 298 | + - default_value: The default value if auto-defaulting is supported |
| 299 | + - is_supported: True if auto-defaulting is supported for this type |
| 300 | + """ |
| 301 | + if annotation is None or annotation is inspect.Parameter.empty or annotation is Any: |
| 302 | + return None, False |
| 303 | + |
| 304 | + try: |
| 305 | + type_info = analyze_type_info(annotation) |
| 306 | + |
| 307 | + # Case 1: Nullable types (Optional[T] or T | None) |
| 308 | + if type_info.nullable: |
| 309 | + return None, True |
| 310 | + |
| 311 | + # Case 2: Table types (KTable or LTable) - check if it's a list or dict type |
| 312 | + if isinstance(type_info.variant, AnalyzedListType): |
| 313 | + return [], True |
| 314 | + elif isinstance(type_info.variant, AnalyzedDictType): |
| 315 | + return {}, True |
| 316 | + |
| 317 | + # For all other types, don't auto-default to avoid ambiguity |
| 318 | + return None, False |
| 319 | + |
| 320 | + except (ValueError, TypeError): |
| 321 | + return None, False |
| 322 | + |
| 323 | + |
| 324 | +def _handle_missing_field_with_auto_default( |
| 325 | + param: inspect.Parameter, name: str, field_path: list[str] |
| 326 | +) -> Any: |
| 327 | + """ |
| 328 | + Handle missing field by trying auto-default or raising an error. |
| 329 | +
|
| 330 | + Returns the auto-default value if supported, otherwise raises ValueError. |
| 331 | + """ |
| 332 | + auto_default, is_supported = _get_auto_default_for_type( |
| 333 | + param.annotation, name, field_path |
| 334 | + ) |
| 335 | + if is_supported: |
| 336 | + warnings.warn( |
| 337 | + f"Field '{name}' (type {param.annotation}) without default value is missing in input: " |
| 338 | + f"{''.join(field_path)}. Auto-assigning default value: {auto_default}", |
| 339 | + UserWarning, |
| 340 | + stacklevel=4, |
| 341 | + ) |
| 342 | + return auto_default |
| 343 | + |
| 344 | + raise ValueError( |
| 345 | + f"Field '{name}' (type {param.annotation}) without default value is missing in input: {''.join(field_path)}" |
| 346 | + ) |
| 347 | + |
| 348 | + |
289 | 349 | def make_engine_struct_decoder(
|
290 | 350 | field_path: list[str],
|
291 | 351 | src_fields: list[dict[str, Any]],
|
@@ -349,19 +409,28 @@ def make_closure_for_value(
|
349 | 409 | field_decoder = make_engine_value_decoder(
|
350 | 410 | field_path, src_fields[src_idx]["type"], param.annotation
|
351 | 411 | )
|
352 |
| - return ( |
353 |
| - lambda values: field_decoder(values[src_idx]) |
354 |
| - if len(values) > src_idx |
355 |
| - else param.default |
356 |
| - ) |
| 412 | + |
| 413 | + def field_value_getter(values: list[Any]) -> Any: |
| 414 | + if src_idx is not None and len(values) > src_idx: |
| 415 | + return field_decoder(values[src_idx]) |
| 416 | + default_value = param.default |
| 417 | + if default_value is not inspect.Parameter.empty: |
| 418 | + return default_value |
| 419 | + |
| 420 | + return _handle_missing_field_with_auto_default( |
| 421 | + param, name, field_path |
| 422 | + ) |
| 423 | + |
| 424 | + return field_value_getter |
357 | 425 |
|
358 | 426 | default_value = param.default
|
359 |
| - if default_value is inspect.Parameter.empty: |
360 |
| - raise ValueError( |
361 |
| - f"Field without default value is missing in input: {''.join(field_path)}" |
362 |
| - ) |
| 427 | + if default_value is not inspect.Parameter.empty: |
| 428 | + return lambda _: default_value |
363 | 429 |
|
364 |
| - return lambda _: default_value |
| 430 | + auto_default = _handle_missing_field_with_auto_default( |
| 431 | + param, name, field_path |
| 432 | + ) |
| 433 | + return lambda _: auto_default |
365 | 434 |
|
366 | 435 | field_value_decoder = [
|
367 | 436 | make_closure_for_value(name, param) for (name, param) in parameters.items()
|
|
0 commit comments