Skip to content

added default values for field decoding (#788) #792

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 79 additions & 10 deletions python/cocoindex/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import dataclasses
import datetime
import inspect
import warnings
from enum import Enum
from typing import Any, Callable, Mapping, get_origin

Expand Down Expand Up @@ -286,6 +287,65 @@ def decode_scalar(value: Any) -> Any | None:
return lambda value: value


def _get_auto_default_for_type(
    annotation: Any, field_name: str, field_path: list[str]
) -> tuple[Any, bool]:
    """
    Get an auto-default value for a type annotation if it's safe to do so.

    Args:
        annotation: The declared type annotation of the target field.
        field_name: Name of the field. Not used by this function at present;
            kept so the signature stays stable for callers.
        field_path: Path of the field. Not used by this function at present;
            kept so the signature stays stable for callers.

    Returns:
        A tuple of (default_value, is_supported) where:
        - default_value: The default value if auto-defaulting is supported
          (``None`` for nullable types, a new ``[]`` for list types, a new
          ``{}`` for dict types); ``None`` when it is not supported.
        - is_supported: True if auto-defaulting is supported for this type
    """
    # Missing or fully dynamic annotations carry no information to default from.
    if annotation is None or annotation is inspect.Parameter.empty or annotation is Any:
        return None, False

    # Guard only the analysis call: an annotation that cannot be analyzed is
    # treated as "no auto-default" rather than as an error.
    try:
        type_info = analyze_type_info(annotation)
    except (ValueError, TypeError):
        return None, False

    # Case 1: Nullable types (Optional[T] or T | None)
    if type_info.nullable:
        return None, True

    # Case 2: Table types (KTable or LTable) - check if it's a list or dict type.
    # A fresh container is created on every call, so callers never share one.
    variant = type_info.variant
    if isinstance(variant, AnalyzedListType):
        return [], True
    if isinstance(variant, AnalyzedDictType):
        return {}, True

    # For all other types, don't auto-default to avoid ambiguity
    return None, False


def _handle_missing_field_with_auto_default(
    param: inspect.Parameter, name: str, field_path: list[str]
) -> Any:
    """
    Produce a value for a field that is missing in the input.

    Asks `_get_auto_default_for_type` whether the parameter's annotation
    supports an automatic default.

    Args:
        param: The parameter describing the target field; its annotation
            decides whether an auto-default exists.
        name: Field name, used in the warning / error message.
        field_path: Path segments, concatenated as-is into the message.

    Returns:
        The auto-default value when the annotation supports one. A
        `UserWarning` is emitted in that case.

    Raises:
        ValueError: If no auto-default is available for the annotation.
    """
    fallback, can_default = _get_auto_default_for_type(
        param.annotation, name, field_path
    )

    # Guard clause: nothing sensible to substitute, so fail loudly.
    if not can_default:
        raise ValueError(
            f"Field '{name}' (type {param.annotation}) without default value is missing in input: {''.join(field_path)}"
        )

    # A default is being substituted silently otherwise; surface it to the user.
    warnings.warn(
        f"Field '{name}' (type {param.annotation}) without default value is missing in input: "
        f"{''.join(field_path)}. Auto-assigning default value: {fallback}",
        UserWarning,
        stacklevel=4,
    )
    return fallback


def make_engine_struct_decoder(
field_path: list[str],
src_fields: list[dict[str, Any]],
Expand Down Expand Up @@ -349,19 +409,28 @@ def make_closure_for_value(
field_decoder = make_engine_value_decoder(
field_path, src_fields[src_idx]["type"], param.annotation
)
return (
lambda values: field_decoder(values[src_idx])
if len(values) > src_idx
else param.default
)

def field_value_getter(values: list[Any]) -> Any:
if src_idx is not None and len(values) > src_idx:
return field_decoder(values[src_idx])
default_value = param.default
if default_value is not inspect.Parameter.empty:
return default_value

return _handle_missing_field_with_auto_default(
param, name, field_path
)

return field_value_getter

default_value = param.default
if default_value is inspect.Parameter.empty:
raise ValueError(
f"Field without default value is missing in input: {''.join(field_path)}"
)
if default_value is not inspect.Parameter.empty:
return lambda _: default_value

return lambda _: default_value
auto_default = _handle_missing_field_with_auto_default(
param, name, field_path
)
return lambda _: auto_default

field_value_decoder = [
make_closure_for_value(name, param) for (name, param) in parameters.items()
Expand Down
47 changes: 46 additions & 1 deletion python/cocoindex/tests/test_convert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import uuid
from dataclasses import dataclass, make_dataclass
from dataclasses import dataclass, make_dataclass, field
from typing import Annotated, Any, Callable, Literal, NamedTuple

import numpy as np
Expand Down Expand Up @@ -1489,3 +1489,48 @@ class Team:

# Test Any annotation
validate_full_roundtrip(teams, dict[str, Team], (expected_dict_dict, Any))


def test_auto_default_for_supported_and_unsupported_types() -> None:
    """
    Exercise decoding for dataclasses whose second field `b` is nullable,
    a list, a dict, or a plain `int`.
    """

    @dataclass
    class Base:
        a: int

    @dataclass
    class NullableField:
        a: int
        b: int | None

    @dataclass
    class LTableField:
        a: int
        b: list[Base]

    @dataclass
    class KTableField:
        a: int
        b: dict[str, Base]

    @dataclass
    class UnsupportedField:
        a: int
        b: int

    # Engine-side value carrying only the first field.
    single_field_values = [1]

    # Nullable and list-typed fields go through the full roundtrip helper.
    validate_full_roundtrip(NullableField(1, None), NullableField)
    validate_full_roundtrip(LTableField(1, []), LTableField)

    # Dict-typed field: decode directly and expect an empty dict for `b`.
    # NOTE(review): the full roundtrip call for KTableField was commented out
    # in the original test; only direct decoding is checked here.
    ktable_decoder = build_engine_value_decoder(KTableField)
    assert ktable_decoder(single_field_values) == KTableField(1, {})

    # A plain `int` field has no auto-default, so a ValueError is expected.
    with pytest.raises(
        ValueError,
        match=r"Field 'b' \(type <class 'int'>\) without default value is missing in input: ",
    ):
        unsupported_decoder = build_engine_value_decoder(Base, UnsupportedField)
        unsupported_decoder(single_field_values)
Loading