diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py index b46c082e..a60bf9ec 100644 --- a/python/cocoindex/convert.py +++ b/python/cocoindex/convert.py @@ -7,6 +7,7 @@ import dataclasses import datetime import inspect +import warnings from enum import Enum from typing import Any, Callable, Mapping, get_origin @@ -286,6 +287,65 @@ def decode_scalar(value: Any) -> Any | None: return lambda value: value +def _get_auto_default_for_type( + annotation: Any, field_name: str, field_path: list[str] +) -> tuple[Any, bool]: + """ + Get an auto-default value for a type annotation if it's safe to do so. + + Returns: + A tuple of (default_value, is_supported) where: + - default_value: The default value if auto-defaulting is supported + - is_supported: True if auto-defaulting is supported for this type + """ + if annotation is None or annotation is inspect.Parameter.empty or annotation is Any: + return None, False + + try: + type_info = analyze_type_info(annotation) + + # Case 1: Nullable types (Optional[T] or T | None) + if type_info.nullable: + return None, True + + # Case 2: Table types (KTable or LTable) - check if it's a list or dict type + if isinstance(type_info.variant, AnalyzedListType): + return [], True + elif isinstance(type_info.variant, AnalyzedDictType): + return {}, True + + # For all other types, don't auto-default to avoid ambiguity + return None, False + + except (ValueError, TypeError): + return None, False + + +def _handle_missing_field_with_auto_default( + param: inspect.Parameter, name: str, field_path: list[str] +) -> Any: + """ + Handle missing field by trying auto-default or raising an error. + + Returns the auto-default value if supported, otherwise raises ValueError. + """ + auto_default, is_supported = _get_auto_default_for_type( + param.annotation, name, field_path + ) + if is_supported: + warnings.warn( + f"Field '{name}' (type {param.annotation}) without default value is missing in input: " + f"{''.join(field_path)}. Auto-assigning default value: {auto_default}", + UserWarning, + stacklevel=4, + ) + return auto_default + + raise ValueError( + f"Field '{name}' (type {param.annotation}) without default value is missing in input: {''.join(field_path)}" + ) + + def make_engine_struct_decoder( field_path: list[str], src_fields: list[dict[str, Any]], @@ -349,19 +409,28 @@ def make_closure_for_value( field_decoder = make_engine_value_decoder( field_path, src_fields[src_idx]["type"], param.annotation ) - return ( - lambda values: field_decoder(values[src_idx]) - if len(values) > src_idx - else param.default - ) + + def field_value_getter(values: list[Any]) -> Any: + if src_idx is not None and len(values) > src_idx: + return field_decoder(values[src_idx]) + default_value = param.default + if default_value is not inspect.Parameter.empty: + return default_value + + return _handle_missing_field_with_auto_default( + param, name, field_path + ) + + return field_value_getter default_value = param.default - if default_value is inspect.Parameter.empty: - raise ValueError( - f"Field without default value is missing in input: {''.join(field_path)}" - ) + if default_value is not inspect.Parameter.empty: + return lambda _: default_value - return lambda _: default_value + auto_default = _handle_missing_field_with_auto_default( + param, name, field_path + ) + return lambda _: auto_default field_value_decoder = [ make_closure_for_value(name, param) for (name, param) in parameters.items() diff --git a/python/cocoindex/tests/test_convert.py b/python/cocoindex/tests/test_convert.py index 71710981..4f87dafe 100644 --- a/python/cocoindex/tests/test_convert.py +++ b/python/cocoindex/tests/test_convert.py @@ -1,6 +1,6 @@ import datetime import uuid -from dataclasses import dataclass, make_dataclass +from dataclasses import dataclass, make_dataclass, field from typing import Annotated, Any, Callable, Literal, NamedTuple import numpy as np @@ -1489,3 +1489,48 @@ class Team: # Test Any annotation validate_full_roundtrip(teams, dict[str, Team], (expected_dict_dict, Any)) + + +def test_auto_default_for_supported_and_unsupported_types() -> None: + @dataclass + class Base: + a: int + + @dataclass + class NullableField: + a: int + b: int | None + + @dataclass + class LTableField: + a: int + b: list[Base] + + @dataclass + class KTableField: + a: int + b: dict[str, Base] + + @dataclass + class UnsupportedField: + a: int + b: int + + engine_val = [1] + + validate_full_roundtrip(NullableField(1, None), NullableField) + + validate_full_roundtrip(LTableField(1, []), LTableField) + + decoder = build_engine_value_decoder(KTableField) + result = decoder(engine_val) + assert result == KTableField(1, {}) + + # validate_full_roundtrip(KTableField(1, {}), KTableField) + + with pytest.raises( + ValueError, + match=r"Field 'b' \(type \) without default value is missing in input: ", + ): + decoder = build_engine_value_decoder(Base, UnsupportedField) + decoder(engine_val)