Skip to content

Commit 96d02cc

Browse files
committed
feat(convert): improve dict handling with type hints in value encoding
1 parent 0eaa7a2 commit 96d02cc

File tree

1 file changed

+34
-13
lines changed

1 file changed

+34
-13
lines changed

python/cocoindex/convert.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,52 +6,73 @@
66
import datetime
77
import inspect
88
from enum import Enum
9-
from typing import Any, Callable, Mapping, get_origin
9+
from typing import Any, Callable, Mapping, Type, get_args, get_origin
1010

1111
import numpy as np
1212

1313
from .typing import (
1414
KEY_FIELD_NAME,
1515
TABLE_TYPES,
16-
analyze_type_info,
17-
encode_enriched_type,
18-
is_namedtuple_type,
19-
is_struct_type,
20-
AnalyzedTypeInfo,
2116
AnalyzedAnyType,
17+
AnalyzedBasicType,
2218
AnalyzedDictType,
2319
AnalyzedListType,
24-
AnalyzedBasicType,
20+
AnalyzedStructType,
21+
AnalyzedTypeInfo,
2522
AnalyzedUnionType,
2623
AnalyzedUnknownType,
27-
AnalyzedStructType,
24+
TypeKind,
25+
analyze_type_info,
26+
encode_enriched_type,
27+
is_namedtuple_type,
2828
is_numpy_number_type,
29+
is_struct_type,
2930
)
3031

3132

32-
def encode_engine_value(value: Any) -> Any:
33+
def encode_engine_value(
34+
value: Any, _in_struct: bool = False, type_hint: Type[Any] | str | None = None
35+
) -> Any:
3336
"""Encode a Python value to an engine value."""
3437
if dataclasses.is_dataclass(value):
3538
return [
36-
encode_engine_value(getattr(value, f.name))
39+
encode_engine_value(
40+
getattr(value, f.name), _in_struct=True, type_hint=f.type
41+
)
3742
for f in dataclasses.fields(value)
3843
]
3944
if is_namedtuple_type(type(value)):
40-
return [encode_engine_value(getattr(value, name)) for name in value._fields]
45+
annotations = type(value).__annotations__
46+
return [
47+
encode_engine_value(
48+
getattr(value, name), _in_struct=True, type_hint=annotations.get(name)
49+
)
50+
for name in value._fields
51+
]
4152
if isinstance(value, np.number):
4253
return value.item()
4354
if isinstance(value, np.ndarray):
4455
return value
4556
if isinstance(value, (list, tuple)):
46-
return [encode_engine_value(v) for v in value]
57+
return [encode_engine_value(v, _in_struct) for v in value]
4758
if isinstance(value, dict):
59+
is_json_type = type_hint and any(
60+
isinstance(arg, TypeKind) and arg.kind == "Json"
61+
for arg in get_args(type_hint)[1:]
62+
)
63+
64+
# For empty dicts, check type hints if in a struct context
65+
# when no contexts are provided, return an empty dict as default
4866
if not value:
67+
if _in_struct:
68+
return value if is_json_type else []
4969
return {}
5070

5171
first_val = next(iter(value.values()))
5272
if is_struct_type(type(first_val)): # KTable
5373
return [
54-
[encode_engine_value(k)] + encode_engine_value(v)
74+
[encode_engine_value(k, _in_struct)]
75+
+ encode_engine_value(v, _in_struct)
5576
for k, v in value.items()
5677
]
5778
return value

0 commit comments

Comments
 (0)