Skip to content

Commit fb92792

Browse files
authored
feat(convert): support explicit type hints in engine value encoding (#807)
* feat(convert): improve dict handling with type hints in value encoding
* feat(convert): refine JSON type handling
* feat(convert): support explicit type hints in engine value encoding
* fix(format): pass ruff format checks
* feat(convert): make type information cached to avoid recompute in encoding
* chore: unblock the previously failed test case
* refactor(convert): extract caching logic into unified function
* fix(convert): make type hints required in engine value encoding
* feat(convert): use built-in cache for type info before encoding
* feat(convert): create an encoder closure for efficient type converting
* refactor: improve closure structure for each type-specific encoder
* refactor(convert): streamline encoder closure logic and enhance type handling
* refactor(convert): remove caching for type info and streamline encoder logic
1 parent 68558b1 commit fb92792

File tree

4 files changed

+216
-93
lines changed

4 files changed

+216
-93
lines changed

python/cocoindex/convert.py

Lines changed: 147 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,26 @@
99
import inspect
1010
import warnings
1111
from enum import Enum
12-
from typing import Any, Callable, Mapping, get_origin
12+
from typing import Any, Callable, Mapping, Type, get_origin
1313

1414
import numpy as np
1515

1616
from .typing import (
1717
KEY_FIELD_NAME,
1818
TABLE_TYPES,
19-
analyze_type_info,
20-
encode_enriched_type,
21-
is_namedtuple_type,
22-
is_struct_type,
23-
AnalyzedTypeInfo,
2419
AnalyzedAnyType,
20+
AnalyzedBasicType,
2521
AnalyzedDictType,
2622
AnalyzedListType,
27-
AnalyzedBasicType,
23+
AnalyzedStructType,
24+
AnalyzedTypeInfo,
2825
AnalyzedUnionType,
2926
AnalyzedUnknownType,
30-
AnalyzedStructType,
27+
analyze_type_info,
28+
encode_enriched_type,
29+
is_namedtuple_type,
3130
is_numpy_number_type,
31+
is_struct_type,
3232
)
3333

3434

@@ -50,34 +50,6 @@ def __exit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None:
5050
self._field_path.pop()
5151

5252

53-
def encode_engine_value(value: Any) -> Any:
    """
    Encode a Python value to an engine value.

    Dataclass instances and namedtuples become a list of their encoded field
    values, numpy scalars are unwrapped to plain Python scalars, lists and
    tuples are encoded element by element, and a non-empty dict whose first
    value is a struct is flattened into a list of ``[key, *fields]`` rows.
    Every other value (including numpy arrays) is returned as-is.
    """
    if dataclasses.is_dataclass(value):
        encoded_fields = []
        for fld in dataclasses.fields(value):
            encoded_fields.append(encode_engine_value(getattr(value, fld.name)))
        return encoded_fields

    if is_namedtuple_type(type(value)):
        return [
            encode_engine_value(getattr(value, field_name))
            for field_name in value._fields
        ]

    if isinstance(value, np.number):
        return value.item()

    # numpy arrays are handed to the engine untouched.
    if isinstance(value, np.ndarray):
        return value

    if isinstance(value, (list, tuple)):
        return [encode_engine_value(item) for item in value]

    if isinstance(value, dict):
        if not value:
            return {}

        # Only the first value is inspected to decide whether this is a KTable.
        sample = next(iter(value.values()))
        if is_struct_type(type(sample)):  # KTable
            rows = []
            for key, row in value.items():
                rows.append([encode_engine_value(key)] + encode_engine_value(row))
            return rows

    return value
79-
80-
8153
_CONVERTIBLE_KINDS = {
8254
("Float32", "Float64"),
8355
("LocalDateTime", "OffsetDateTime"),
@@ -91,6 +63,145 @@ def _is_type_kind_convertible_to(src_type_kind: str, dst_type_kind: str) -> bool
9163
)
9264

9365

66+
# Pre-computed type info for missing/Any type annotations
67+
ANY_TYPE_INFO = analyze_type_info(inspect.Parameter.empty)
68+
69+
70+
def _make_encoder_closure(type_info: AnalyzedTypeInfo) -> Callable[[Any], Any]:
    """
    Create an encoder closure for a specific type.

    The returned callable converts a Python value to its engine value,
    dispatching once on ``type_info.variant``:

    * List whose element type is a struct: every element is encoded with the
      struct encoder; ``None`` is passed through.
    * Dict without a value type: returned unchanged.
    * Dict whose value type is a struct: encoded as a list of rows, each row
      being the encoded key followed by the encoded struct fields. An empty
      dict encodes to ``[]``.
    * Struct (dataclass or namedtuple): a list of the encoded field values,
      in field order. Values that are not instances of such a struct are
      returned unchanged.
    * Anything else: numpy scalars are unwrapped with ``.item()``, lists and
      tuples are encoded element-wise into lists, and all other values
      (including numpy arrays) are returned unchanged.

    Args:
        type_info: Analyzed type information describing the values to encode.

    Returns:
        A callable taking a Python value and returning the engine value.
    """
    variant = type_info.variant

    if isinstance(variant, AnalyzedListType):
        elem_type_info = (
            analyze_type_info(variant.elem_type) if variant.elem_type else ANY_TYPE_INFO
        )
        if isinstance(elem_type_info.variant, AnalyzedStructType):
            elem_encoder = _make_encoder_closure(elem_type_info)

            def encode_struct_list(value: Any) -> Any:
                return None if value is None else [elem_encoder(v) for v in value]

            return encode_struct_list
        # Lists of non-struct elements fall through to the basic encoder.

    if isinstance(variant, AnalyzedDictType):
        if not variant.value_type:
            return lambda value: value

        value_type_info = analyze_type_info(variant.value_type)
        if isinstance(value_type_info.variant, AnalyzedStructType):

            def encode_struct_dict(value: Any) -> Any:
                if not isinstance(value, dict):
                    return value
                if not value:
                    return []

                # The runtime types of the first entry decide how the whole
                # mapping is encoded.
                sample_key, sample_val = next(iter(value.items()))
                key_type, val_type = type(sample_key), type(sample_val)

                # The annotation promises struct values, but the actual values
                # are not structs: no per-field encoders exist for them, so
                # return the mapping unchanged instead of referencing encoders
                # that were never built.
                if not is_struct_type(val_type):
                    return value

                # Handle KTable case
                key_encoder = (
                    _make_encoder_closure(analyze_type_info(key_type))
                    if is_struct_type(key_type)
                    else _make_encoder_closure(ANY_TYPE_INFO)
                )
                value_encoder = _make_encoder_closure(analyze_type_info(val_type))
                return [[key_encoder(k)] + value_encoder(v) for k, v in value.items()]

            return encode_struct_dict
        # Dicts with non-struct values fall through to the basic encoder.

    if isinstance(variant, AnalyzedStructType):
        struct_type = variant.struct_type

        if dataclasses.is_dataclass(struct_type):
            fields = dataclasses.fields(struct_type)
            # Field encoders are built once here, not on every encode call.
            field_encoders = [
                _make_encoder_closure(analyze_type_info(f.type)) for f in fields
            ]
            field_names = [f.name for f in fields]

            def encode_dataclass(value: Any) -> Any:
                if not dataclasses.is_dataclass(value):
                    return value
                return [
                    encoder(getattr(value, name))
                    for encoder, name in zip(field_encoders, field_names)
                ]

            return encode_dataclass

        elif is_namedtuple_type(struct_type):
            annotations = struct_type.__annotations__
            field_names = list(getattr(struct_type, "_fields", ()))
            # Fields without an annotation are encoded with the Any encoder.
            field_encoders = [
                _make_encoder_closure(
                    analyze_type_info(annotations[name])
                    if name in annotations
                    else ANY_TYPE_INFO
                )
                for name in field_names
            ]

            def encode_namedtuple(value: Any) -> Any:
                if not is_namedtuple_type(type(value)):
                    return value
                return [
                    encoder(getattr(value, name))
                    for encoder, name in zip(field_encoders, field_names)
                ]

            return encode_namedtuple

    def encode_basic_value(value: Any) -> Any:
        if isinstance(value, np.number):
            return value.item()
        if isinstance(value, np.ndarray):
            return value
        if isinstance(value, (list, tuple)):
            return [encode_basic_value(v) for v in value]
        return value

    return encode_basic_value
171+
172+
173+
def make_engine_value_encoder(type_hint: Type[Any] | str) -> Callable[[Any], Any]:
    """
    Build a reusable encoder that turns Python values into engine values.

    Args:
        type_hint: Type annotation describing the values the encoder will see.

    Returns:
        A closure that encodes Python values to engine values.

    Raises:
        ValueError: If the analyzed annotation is of an unknown type.
    """
    analyzed = analyze_type_info(type_hint)
    if not isinstance(analyzed.variant, AnalyzedUnknownType):
        return _make_encoder_closure(analyzed)
    raise ValueError(f"Type annotation `{analyzed.core_type}` is unsupported")
188+
189+
190+
def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any:
    """
    Encode a single Python value to an engine value.

    A fresh encoder is built from ``type_hint`` on every call and applied to
    ``value`` immediately.

    Args:
        value: The Python value to encode.
        type_hint: Type annotation for the value. This should always be provided.

    Returns:
        The encoded engine value.
    """
    return make_engine_value_encoder(type_hint)(value)
203+
204+
94205
def make_engine_value_decoder(
95206
field_path: list[str],
96207
src_type: dict[str, Any],

python/cocoindex/flow.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,20 @@
99
import functools
1010
import inspect
1111
import re
12-
13-
from .validation import (
14-
validate_flow_name,
15-
NamingError,
16-
validate_full_flow_name,
17-
validate_target_name,
18-
)
19-
from .typing import analyze_type_info
20-
2112
from dataclasses import dataclass
2213
from enum import Enum
2314
from threading import Lock
2415
from typing import (
2516
Any,
2617
Callable,
2718
Generic,
19+
Iterable,
2820
NamedTuple,
2921
Sequence,
3022
TypeVar,
3123
cast,
3224
get_args,
3325
get_origin,
34-
Iterable,
3526
)
3627

3728
from rich.text import Text
@@ -45,7 +36,12 @@
4536
from .op import FunctionSpec
4637
from .runtime import execution_context
4738
from .setup import SetupChangeBundle
48-
from .typing import encode_enriched_type
39+
from .typing import analyze_type_info, encode_enriched_type
40+
from .validation import (
41+
validate_flow_name,
42+
validate_full_flow_name,
43+
validate_target_name,
44+
)
4945

5046

5147
class _NameBuilder:
@@ -1099,11 +1095,16 @@ async def eval_async(self, *args: Any, **kwargs: Any) -> T:
10991095
"""
11001096
flow_info = await self._flow_info_async()
11011097
params = []
1102-
for i, arg in enumerate(self._param_names):
1098+
for i, (arg, arg_type) in enumerate(
1099+
zip(self._param_names, self._flow_arg_types)
1100+
):
1101+
param_type = (
1102+
self._flow_arg_types[i] if i < len(self._flow_arg_types) else Any
1103+
)
11031104
if i < len(args):
1104-
params.append(encode_engine_value(args[i]))
1105+
params.append(encode_engine_value(args[i], type_hint=param_type))
11051106
elif arg in kwargs:
1106-
params.append(encode_engine_value(kwargs[arg]))
1107+
params.append(encode_engine_value(kwargs[arg], type_hint=param_type))
11071108
else:
11081109
raise ValueError(f"Parameter {arg} is not provided")
11091110
engine_result = await flow_info.engine_flow.evaluate_async(params)

python/cocoindex/op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
343343
output = await self._acall(*decoded_args, **decoded_kwargs)
344344
else:
345345
output = await self._acall(*decoded_args, **decoded_kwargs)
346-
return encode_engine_value(output)
346+
return encode_engine_value(output, type_hint=expected_return)
347347

348348
_WrappedClass.__name__ = executor_cls.__name__
349349
_WrappedClass.__doc__ = executor_cls.__doc__

0 commit comments

Comments
 (0)