Skip to content

Commit 6e6d28f

Browse files
authored
added default values for field decoding (#788) (#792)
* added default values for field decoding * removed few comments * made suggested changes * added new test case and changed existing lambda function to better handle errors * commented full_roundtrip test * fixed errors after merging * Addressed new comments
1 parent 7e6a6f4 commit 6e6d28f

File tree

2 files changed

+125
-11
lines changed

2 files changed

+125
-11
lines changed

python/cocoindex/convert.py

Lines changed: 79 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import dataclasses
88
import datetime
99
import inspect
10+
import warnings
1011
from enum import Enum
1112
from typing import Any, Callable, Mapping, get_origin
1213

@@ -286,6 +287,65 @@ def decode_scalar(value: Any) -> Any | None:
286287
return lambda value: value
287288

288289

290+
def _get_auto_default_for_type(
291+
annotation: Any, field_name: str, field_path: list[str]
292+
) -> tuple[Any, bool]:
293+
"""
294+
Get an auto-default value for a type annotation if it's safe to do so.
295+
296+
Returns:
297+
A tuple of (default_value, is_supported) where:
298+
- default_value: The default value if auto-defaulting is supported
299+
- is_supported: True if auto-defaulting is supported for this type
300+
"""
301+
if annotation is None or annotation is inspect.Parameter.empty or annotation is Any:
302+
return None, False
303+
304+
try:
305+
type_info = analyze_type_info(annotation)
306+
307+
# Case 1: Nullable types (Optional[T] or T | None)
308+
if type_info.nullable:
309+
return None, True
310+
311+
# Case 2: Table types (KTable or LTable) - check if it's a list or dict type
312+
if isinstance(type_info.variant, AnalyzedListType):
313+
return [], True
314+
elif isinstance(type_info.variant, AnalyzedDictType):
315+
return {}, True
316+
317+
# For all other types, don't auto-default to avoid ambiguity
318+
return None, False
319+
320+
except (ValueError, TypeError):
321+
return None, False
322+
323+
324+
def _handle_missing_field_with_auto_default(
325+
param: inspect.Parameter, name: str, field_path: list[str]
326+
) -> Any:
327+
"""
328+
Handle missing field by trying auto-default or raising an error.
329+
330+
Returns the auto-default value if supported, otherwise raises ValueError.
331+
"""
332+
auto_default, is_supported = _get_auto_default_for_type(
333+
param.annotation, name, field_path
334+
)
335+
if is_supported:
336+
warnings.warn(
337+
f"Field '{name}' (type {param.annotation}) without default value is missing in input: "
338+
f"{''.join(field_path)}. Auto-assigning default value: {auto_default}",
339+
UserWarning,
340+
stacklevel=4,
341+
)
342+
return auto_default
343+
344+
raise ValueError(
345+
f"Field '{name}' (type {param.annotation}) without default value is missing in input: {''.join(field_path)}"
346+
)
347+
348+
289349
def make_engine_struct_decoder(
290350
field_path: list[str],
291351
src_fields: list[dict[str, Any]],
@@ -349,19 +409,28 @@ def make_closure_for_value(
349409
field_decoder = make_engine_value_decoder(
350410
field_path, src_fields[src_idx]["type"], param.annotation
351411
)
352-
return (
353-
lambda values: field_decoder(values[src_idx])
354-
if len(values) > src_idx
355-
else param.default
356-
)
412+
413+
def field_value_getter(values: list[Any]) -> Any:
414+
if src_idx is not None and len(values) > src_idx:
415+
return field_decoder(values[src_idx])
416+
default_value = param.default
417+
if default_value is not inspect.Parameter.empty:
418+
return default_value
419+
420+
return _handle_missing_field_with_auto_default(
421+
param, name, field_path
422+
)
423+
424+
return field_value_getter
357425

358426
default_value = param.default
359-
if default_value is inspect.Parameter.empty:
360-
raise ValueError(
361-
f"Field without default value is missing in input: {''.join(field_path)}"
362-
)
427+
if default_value is not inspect.Parameter.empty:
428+
return lambda _: default_value
363429

364-
return lambda _: default_value
430+
auto_default = _handle_missing_field_with_auto_default(
431+
param, name, field_path
432+
)
433+
return lambda _: auto_default
365434

366435
field_value_decoder = [
367436
make_closure_for_value(name, param) for (name, param) in parameters.items()

python/cocoindex/tests/test_convert.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import datetime
22
import uuid
3-
from dataclasses import dataclass, make_dataclass
3+
from dataclasses import dataclass, make_dataclass, field
44
from typing import Annotated, Any, Callable, Literal, NamedTuple
55

66
import numpy as np
@@ -1489,3 +1489,48 @@ class Team:
14891489

14901490
# Test Any annotation
14911491
validate_full_roundtrip(teams, dict[str, Team], (expected_dict_dict, Any))
1492+
1493+
1494+
def test_auto_default_for_supported_and_unsupported_types() -> None:
1495+
@dataclass
1496+
class Base:
1497+
a: int
1498+
1499+
@dataclass
1500+
class NullableField:
1501+
a: int
1502+
b: int | None
1503+
1504+
@dataclass
1505+
class LTableField:
1506+
a: int
1507+
b: list[Base]
1508+
1509+
@dataclass
1510+
class KTableField:
1511+
a: int
1512+
b: dict[str, Base]
1513+
1514+
@dataclass
1515+
class UnsupportedField:
1516+
a: int
1517+
b: int
1518+
1519+
engine_val = [1]
1520+
1521+
validate_full_roundtrip(NullableField(1, None), NullableField)
1522+
1523+
validate_full_roundtrip(LTableField(1, []), LTableField)
1524+
1525+
decoder = build_engine_value_decoder(KTableField)
1526+
result = decoder(engine_val)
1527+
assert result == KTableField(1, {})
1528+
1529+
# validate_full_roundtrip(KTableField(1, {}), KTableField)
1530+
1531+
with pytest.raises(
1532+
ValueError,
1533+
match=r"Field 'b' \(type <class 'int'>\) without default value is missing in input: ",
1534+
):
1535+
decoder = build_engine_value_decoder(Base, UnsupportedField)
1536+
decoder(engine_val)

0 commit comments

Comments
 (0)