Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 96 additions & 14 deletions python/cocoindex/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import dataclasses
import datetime
import inspect
import warnings
from enum import Enum
from typing import Any, Callable, Mapping, get_origin

Expand Down Expand Up @@ -67,6 +68,7 @@ def make_engine_value_decoder(
field_path: list[str],
src_type: dict[str, Any],
dst_annotation: Any,
auto_default_missing_fields: bool = False,
) -> Callable[[Any], Any]:
"""
Make a decoder from an engine value to a Python value.
Expand All @@ -90,15 +92,17 @@ def make_engine_value_decoder(
if src_type_kind == "Union":
return lambda value: value[1]
if src_type_kind == "Struct":
return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"])
return _make_engine_struct_to_dict_decoder(
field_path, src_type["fields"], auto_default_missing_fields
)
if src_type_kind in TABLE_TYPES:
if src_type_kind == "LTable":
return _make_engine_ltable_to_list_dict_decoder(
field_path, src_type["row"]["fields"]
field_path, src_type["row"]["fields"], auto_default_missing_fields
)
elif src_type_kind == "KTable":
return _make_engine_ktable_to_dict_dict_decoder(
field_path, src_type["row"]["fields"]
field_path, src_type["row"]["fields"], auto_default_missing_fields
)
return lambda value: value

Expand All @@ -111,7 +115,9 @@ def make_engine_value_decoder(
if args == (str, Any):
is_dict_annotation = True
if is_dict_annotation and src_type_kind == "Struct":
return _make_engine_struct_to_dict_decoder(field_path, src_type["fields"])
return _make_engine_struct_to_dict_decoder(
field_path, src_type["fields"], auto_default_missing_fields
)

dst_type_info = analyze_type_info(dst_annotation)

Expand All @@ -129,7 +135,10 @@ def make_engine_value_decoder(
for dst_type_variant in dst_type_variants:
try:
decoder = make_engine_value_decoder(
src_field_path, src_type_variant, dst_type_variant
src_field_path,
src_type_variant,
dst_type_variant,
auto_default_missing_fields,
)
break
except ValueError:
Expand Down Expand Up @@ -175,6 +184,7 @@ def decode_scalar(value: Any) -> Any | None:
field_path + ["[*]"],
src_type["element_type"],
dst_type_info.elem_type,
auto_default_missing_fields,
)
else: # for NDArray vector
scalar_dtype = extract_ndarray_scalar_dtype(dst_type_info.np_number_type)
Expand Down Expand Up @@ -206,7 +216,10 @@ def decode_vector(value: Any) -> Any | None:

if dst_type_info.struct_type is not None:
return _make_engine_struct_value_decoder(
field_path, src_type["fields"], dst_type_info.struct_type
field_path,
src_type["fields"],
dst_type_info.struct_type,
auto_default_missing_fields,
)

if src_type_kind in TABLE_TYPES:
Expand All @@ -222,11 +235,17 @@ def decode_vector(value: Any) -> Any | None:
key_field_schema = engine_fields_schema[0]
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
key_decoder = make_engine_value_decoder(
field_path, key_field_schema["type"], elem_type_info.key_type
field_path,
key_field_schema["type"],
elem_type_info.key_type,
auto_default_missing_fields,
)
field_path.pop()
value_decoder = _make_engine_struct_value_decoder(
field_path, engine_fields_schema[1:], elem_type_info.struct_type
field_path,
engine_fields_schema[1:],
elem_type_info.struct_type,
auto_default_missing_fields,
)

def decode(value: Any) -> Any | None:
Expand All @@ -235,7 +254,10 @@ def decode(value: Any) -> Any | None:
return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
else:
elem_decoder = _make_engine_struct_value_decoder(
field_path, engine_fields_schema, elem_type_info.struct_type
field_path,
engine_fields_schema,
elem_type_info.struct_type,
auto_default_missing_fields,
)

def decode(value: Any) -> Any | None:
Expand All @@ -249,10 +271,44 @@ def decode(value: Any) -> Any | None:
return lambda value: value


def _get_auto_default_for_type(
    annotation: Any, field_name: str, field_path: list[str]
) -> Any:
    """
    Choose a stand-in value for a field based solely on its type annotation.

    Args:
        annotation: The declared type of the field. ``None``,
            ``inspect.Parameter.empty`` and ``Any`` are treated as
            "no usable annotation".
        field_name: Name of the field. Not used by this function.
        field_path: Path to the field. Not used by this function.

    Returns:
        A new empty list when the analyzed kind is "LTable", a new empty
        dict when it is "KTable", and None in every other case: no usable
        annotation, a nullable type, a non-table kind, or an annotation whose
        analysis raises ValueError or TypeError.

        Note that None is returned both for nullable annotations and for
        annotations that get no auto-default, so the result alone does not
        distinguish the two.
    """
    # Identity comparison on purpose: these are sentinels, not values.
    unusable_annotations = (None, inspect.Parameter.empty, Any)
    if any(annotation is sentinel for sentinel in unusable_annotations):
        return None

    try:
        info = analyze_type_info(annotation)

        # Nullable types (Optional[T] or T | None) default to None.
        if info.nullable:
            return None

        # Anything that is not a table gets no auto-default, to avoid
        # silently inventing values for scalars and structs.
        if info.kind not in TABLE_TYPES:
            return None

        # Table types default to their empty container.
        if info.kind == "LTable":
            return []
        if info.kind == "KTable":
            return {}
        return None
    except (ValueError, TypeError):
        # The annotation could not be analyzed: treat as no auto-default.
        return None


def _make_engine_struct_value_decoder(
field_path: list[str],
src_fields: list[dict[str, Any]],
dst_struct_type: type,
auto_default_missing_fields: bool = False,
) -> Callable[[list[Any]], Any]:
"""Make a decoder from an engine field values to a Python value."""

Expand Down Expand Up @@ -285,7 +341,10 @@ def make_closure_for_value(
if src_idx is not None:
field_path.append(f".{name}")
field_decoder = make_engine_value_decoder(
field_path, src_fields[src_idx]["type"], param.annotation
field_path,
src_fields[src_idx]["type"],
param.annotation,
auto_default_missing_fields,
)
field_path.pop()
return (
Expand All @@ -296,8 +355,21 @@ def make_closure_for_value(

default_value = param.default
if default_value is inspect.Parameter.empty:
if auto_default_missing_fields:
auto_default = _get_auto_default_for_type(
param.annotation, name, field_path
)
if auto_default is not None:
warnings.warn(
f"Field '{name}' (type {param.annotation}) without default value is missing in input: "
f"{''.join(field_path)}. Auto-assigning default value: {auto_default}",
UserWarning,
stacklevel=3,
)
return lambda _: auto_default

raise ValueError(
f"Field without default value is missing in input: {''.join(field_path)}"
f"Field '{name}' (type {param.annotation}) without default value is missing in input: {''.join(field_path)}"
)

return lambda _: default_value
Expand All @@ -314,6 +386,7 @@ def make_closure_for_value(
def _make_engine_struct_to_dict_decoder(
field_path: list[str],
src_fields: list[dict[str, Any]],
auto_default_missing_fields: bool = False,
) -> Callable[[list[Any] | None], dict[str, Any] | None]:
"""Make a decoder from engine field values to a Python dict."""

Expand All @@ -325,6 +398,7 @@ def _make_engine_struct_to_dict_decoder(
field_path,
field_schema["type"],
Any, # Use Any for recursive decoding
auto_default_missing_fields,
)
field_path.pop()
field_decoders.append((field_name, field_decoder))
Expand All @@ -347,11 +421,14 @@ def decode_to_dict(values: list[Any] | None) -> dict[str, Any] | None:
def _make_engine_ltable_to_list_dict_decoder(
field_path: list[str],
src_fields: list[dict[str, Any]],
auto_default_missing_fields: bool = False,
) -> Callable[[list[Any] | None], list[dict[str, Any]] | None]:
"""Make a decoder from engine LTable values to a list of dicts."""

# Create a decoder for each row (struct) to dict
row_decoder = _make_engine_struct_to_dict_decoder(field_path, src_fields)
row_decoder = _make_engine_struct_to_dict_decoder(
field_path, src_fields, auto_default_missing_fields
)

def decode_to_list_dict(values: list[Any] | None) -> list[dict[str, Any]] | None:
if values is None:
Expand All @@ -372,6 +449,7 @@ def decode_to_list_dict(values: list[Any] | None) -> list[dict[str, Any]] | None
def _make_engine_ktable_to_dict_dict_decoder(
field_path: list[str],
src_fields: list[dict[str, Any]],
auto_default_missing_fields: bool = False,
) -> Callable[[list[Any] | None], dict[Any, dict[str, Any]] | None]:
"""Make a decoder from engine KTable values to a dict of dicts."""

Expand All @@ -384,10 +462,14 @@ def _make_engine_ktable_to_dict_dict_decoder(

# Create decoders
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
key_decoder = make_engine_value_decoder(field_path, key_field_schema["type"], Any)
key_decoder = make_engine_value_decoder(
field_path, key_field_schema["type"], Any, auto_default_missing_fields
)
field_path.pop()

value_decoder = _make_engine_struct_to_dict_decoder(field_path, value_fields_schema)
value_decoder = _make_engine_struct_to_dict_decoder(
field_path, value_fields_schema, auto_default_missing_fields
)

def decode_to_dict_dict(
values: list[Any] | None,
Expand Down
Loading