Skip to content

feat(convert): support explicit type hints in engine value encoding #807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Aug 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 147 additions & 36 deletions python/cocoindex/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,26 @@
import inspect
import warnings
from enum import Enum
from typing import Any, Callable, Mapping, get_origin
from typing import Any, Callable, Mapping, Type, get_origin

import numpy as np

from .typing import (
KEY_FIELD_NAME,
TABLE_TYPES,
analyze_type_info,
encode_enriched_type,
is_namedtuple_type,
is_struct_type,
AnalyzedTypeInfo,
AnalyzedAnyType,
AnalyzedBasicType,
AnalyzedDictType,
AnalyzedListType,
AnalyzedBasicType,
AnalyzedStructType,
AnalyzedTypeInfo,
AnalyzedUnionType,
AnalyzedUnknownType,
AnalyzedStructType,
analyze_type_info,
encode_enriched_type,
is_namedtuple_type,
is_numpy_number_type,
is_struct_type,
)


Expand All @@ -50,34 +50,6 @@ def __exit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None:
self._field_path.pop()


def encode_engine_value(value: Any) -> Any:
"""Encode a Python value to an engine value."""
if dataclasses.is_dataclass(value):
return [
encode_engine_value(getattr(value, f.name))
for f in dataclasses.fields(value)
]
if is_namedtuple_type(type(value)):
return [encode_engine_value(getattr(value, name)) for name in value._fields]
if isinstance(value, np.number):
return value.item()
if isinstance(value, np.ndarray):
return value
if isinstance(value, (list, tuple)):
return [encode_engine_value(v) for v in value]
if isinstance(value, dict):
if not value:
return {}

first_val = next(iter(value.values()))
if is_struct_type(type(first_val)): # KTable
return [
[encode_engine_value(k)] + encode_engine_value(v)
for k, v in value.items()
]
return value


_CONVERTIBLE_KINDS = {
("Float32", "Float64"),
("LocalDateTime", "OffsetDateTime"),
Expand All @@ -91,6 +63,145 @@ def _is_type_kind_convertible_to(src_type_kind: str, dst_type_kind: str) -> bool
)


# Pre-computed type info for missing/Any type annotations
ANY_TYPE_INFO = analyze_type_info(inspect.Parameter.empty)


def _make_encoder_closure(type_info: AnalyzedTypeInfo) -> Callable[[Any], Any]:
"""
Create an encoder closure for a specific type.
"""
variant = type_info.variant

if isinstance(variant, AnalyzedListType):
elem_type_info = (
analyze_type_info(variant.elem_type) if variant.elem_type else ANY_TYPE_INFO
)
if isinstance(elem_type_info.variant, AnalyzedStructType):
elem_encoder = _make_encoder_closure(elem_type_info)

def encode_struct_list(value: Any) -> Any:
return None if value is None else [elem_encoder(v) for v in value]

return encode_struct_list

if isinstance(variant, AnalyzedDictType):
if not variant.value_type:
return lambda value: value

value_type_info = analyze_type_info(variant.value_type)
if isinstance(value_type_info.variant, AnalyzedStructType):

def encode_struct_dict(value: Any) -> Any:
if not isinstance(value, dict):
return value
if not value:
return []

sample_key, sample_val = next(iter(value.items()))
key_type, val_type = type(sample_key), type(sample_val)

# Handle KTable case
if value and is_struct_type(val_type):
key_encoder = (
_make_encoder_closure(analyze_type_info(key_type))
if is_struct_type(key_type)
else _make_encoder_closure(ANY_TYPE_INFO)
)
value_encoder = _make_encoder_closure(analyze_type_info(val_type))
return [
[key_encoder(k)] + value_encoder(v) for k, v in value.items()
]
return {key_encoder(k): value_encoder(v) for k, v in value.items()}

return encode_struct_dict

if isinstance(variant, AnalyzedStructType):
struct_type = variant.struct_type

if dataclasses.is_dataclass(struct_type):
fields = dataclasses.fields(struct_type)
field_encoders = [
_make_encoder_closure(analyze_type_info(f.type)) for f in fields
]
field_names = [f.name for f in fields]

def encode_dataclass(value: Any) -> Any:
if not dataclasses.is_dataclass(value):
return value
return [
encoder(getattr(value, name))
for encoder, name in zip(field_encoders, field_names)
]

return encode_dataclass

elif is_namedtuple_type(struct_type):
annotations = struct_type.__annotations__
field_names = list(getattr(struct_type, "_fields", ()))
field_encoders = [
_make_encoder_closure(
analyze_type_info(annotations[name])
if name in annotations
else ANY_TYPE_INFO
)
for name in field_names
]

def encode_namedtuple(value: Any) -> Any:
if not is_namedtuple_type(type(value)):
return value
return [
encoder(getattr(value, name))
for encoder, name in zip(field_encoders, field_names)
]

return encode_namedtuple

def encode_basic_value(value: Any) -> Any:
if isinstance(value, np.number):
return value.item()
if isinstance(value, np.ndarray):
return value
if isinstance(value, (list, tuple)):
return [encode_basic_value(v) for v in value]
return value

return encode_basic_value


def make_engine_value_encoder(type_hint: Type[Any] | str) -> Callable[[Any], Any]:
"""
Create an encoder closure for converting Python values to engine values.

Args:
type_hint: Type annotation for the values to encode

Returns:
A closure that encodes Python values to engine values
"""
type_info = analyze_type_info(type_hint)
if isinstance(type_info.variant, AnalyzedUnknownType):
raise ValueError(f"Type annotation `{type_info.core_type}` is unsupported")

return _make_encoder_closure(type_info)


def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any:
"""
Encode a Python value to an engine value.

Args:
value: The Python value to encode
type_hint: Type annotation for the value. This should always be provided.

Returns:
The encoded engine value
"""
encoder = make_engine_value_encoder(type_hint)
return encoder(value)


def make_engine_value_decoder(
field_path: list[str],
src_type: dict[str, Any],
Expand Down
29 changes: 15 additions & 14 deletions python/cocoindex/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,20 @@
import functools
import inspect
import re

from .validation import (
validate_flow_name,
NamingError,
validate_full_flow_name,
validate_target_name,
)
from .typing import analyze_type_info

from dataclasses import dataclass
from enum import Enum
from threading import Lock
from typing import (
Any,
Callable,
Generic,
Iterable,
NamedTuple,
Sequence,
TypeVar,
cast,
get_args,
get_origin,
Iterable,
)

from rich.text import Text
Expand All @@ -45,7 +36,12 @@
from .op import FunctionSpec
from .runtime import execution_context
from .setup import SetupChangeBundle
from .typing import encode_enriched_type
from .typing import analyze_type_info, encode_enriched_type
from .validation import (
validate_flow_name,
validate_full_flow_name,
validate_target_name,
)


class _NameBuilder:
Expand Down Expand Up @@ -1099,11 +1095,16 @@ async def eval_async(self, *args: Any, **kwargs: Any) -> T:
"""
flow_info = await self._flow_info_async()
params = []
for i, arg in enumerate(self._param_names):
for i, (arg, arg_type) in enumerate(
zip(self._param_names, self._flow_arg_types)
):
param_type = (
self._flow_arg_types[i] if i < len(self._flow_arg_types) else Any
)
if i < len(args):
params.append(encode_engine_value(args[i]))
params.append(encode_engine_value(args[i], type_hint=param_type))
elif arg in kwargs:
params.append(encode_engine_value(kwargs[arg]))
params.append(encode_engine_value(kwargs[arg], type_hint=param_type))
else:
raise ValueError(f"Parameter {arg} is not provided")
engine_result = await flow_info.engine_flow.evaluate_async(params)
Expand Down
2 changes: 1 addition & 1 deletion python/cocoindex/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
output = await self._acall(*decoded_args, **decoded_kwargs)
else:
output = await self._acall(*decoded_args, **decoded_kwargs)
return encode_engine_value(output)
return encode_engine_value(output, type_hint=expected_return)

_WrappedClass.__name__ = executor_cls.__name__
_WrappedClass.__doc__ = executor_cls.__doc__
Expand Down
Loading
Loading