diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py index ebe0474a..24cfd0d5 100644 --- a/python/cocoindex/convert.py +++ b/python/cocoindex/convert.py @@ -9,26 +9,26 @@ import inspect import warnings from enum import Enum -from typing import Any, Callable, Mapping, get_origin +from typing import Any, Callable, Mapping, Type, get_origin import numpy as np from .typing import ( KEY_FIELD_NAME, TABLE_TYPES, - analyze_type_info, - encode_enriched_type, - is_namedtuple_type, - is_struct_type, - AnalyzedTypeInfo, AnalyzedAnyType, + AnalyzedBasicType, AnalyzedDictType, AnalyzedListType, - AnalyzedBasicType, + AnalyzedStructType, + AnalyzedTypeInfo, AnalyzedUnionType, AnalyzedUnknownType, - AnalyzedStructType, + analyze_type_info, + encode_enriched_type, + is_namedtuple_type, is_numpy_number_type, + is_struct_type, ) @@ -50,34 +50,6 @@ def __exit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None: self._field_path.pop() -def encode_engine_value(value: Any) -> Any: - """Encode a Python value to an engine value.""" - if dataclasses.is_dataclass(value): - return [ - encode_engine_value(getattr(value, f.name)) - for f in dataclasses.fields(value) - ] - if is_namedtuple_type(type(value)): - return [encode_engine_value(getattr(value, name)) for name in value._fields] - if isinstance(value, np.number): - return value.item() - if isinstance(value, np.ndarray): - return value - if isinstance(value, (list, tuple)): - return [encode_engine_value(v) for v in value] - if isinstance(value, dict): - if not value: - return {} - - first_val = next(iter(value.values())) - if is_struct_type(type(first_val)): # KTable - return [ - [encode_engine_value(k)] + encode_engine_value(v) - for k, v in value.items() - ] - return value - - _CONVERTIBLE_KINDS = { ("Float32", "Float64"), ("LocalDateTime", "OffsetDateTime"), @@ -91,6 +63,145 @@ def _is_type_kind_convertible_to(src_type_kind: str, dst_type_kind: str) -> bool ) +# Pre-computed type info for missing/Any type annotations +ANY_TYPE_INFO = analyze_type_info(inspect.Parameter.empty) + + +def _make_encoder_closure(type_info: AnalyzedTypeInfo) -> Callable[[Any], Any]: + """ + Create an encoder closure for a specific type. + """ + variant = type_info.variant + + if isinstance(variant, AnalyzedListType): + elem_type_info = ( + analyze_type_info(variant.elem_type) if variant.elem_type else ANY_TYPE_INFO + ) + if isinstance(elem_type_info.variant, AnalyzedStructType): + elem_encoder = _make_encoder_closure(elem_type_info) + + def encode_struct_list(value: Any) -> Any: + return None if value is None else [elem_encoder(v) for v in value] + + return encode_struct_list + + if isinstance(variant, AnalyzedDictType): + if not variant.value_type: + return lambda value: value + + value_type_info = analyze_type_info(variant.value_type) + if isinstance(value_type_info.variant, AnalyzedStructType): + + def encode_struct_dict(value: Any) -> Any: + if not isinstance(value, dict): + return value + if not value: + return [] + + sample_key, sample_val = next(iter(value.items())) + key_type, val_type = type(sample_key), type(sample_val) + + # Handle KTable case + if value and is_struct_type(val_type): + key_encoder = ( + _make_encoder_closure(analyze_type_info(key_type)) + if is_struct_type(key_type) + else _make_encoder_closure(ANY_TYPE_INFO) + ) + value_encoder = _make_encoder_closure(analyze_type_info(val_type)) + return [ + [key_encoder(k)] + value_encoder(v) for k, v in value.items() + ] + return {key_encoder(k): value_encoder(v) for k, v in value.items()} + + return encode_struct_dict + + if isinstance(variant, AnalyzedStructType): + struct_type = variant.struct_type + + if dataclasses.is_dataclass(struct_type): + fields = dataclasses.fields(struct_type) + field_encoders = [ + _make_encoder_closure(analyze_type_info(f.type)) for f in fields + ] + field_names = [f.name for f in fields] + + def encode_dataclass(value: Any) -> Any: + if not dataclasses.is_dataclass(value): + return value + return [ + encoder(getattr(value, name)) + for encoder, name in zip(field_encoders, field_names) + ] + + return encode_dataclass + + elif is_namedtuple_type(struct_type): + annotations = struct_type.__annotations__ + field_names = list(getattr(struct_type, "_fields", ())) + field_encoders = [ + _make_encoder_closure( + analyze_type_info(annotations[name]) + if name in annotations + else ANY_TYPE_INFO + ) + for name in field_names + ] + + def encode_namedtuple(value: Any) -> Any: + if not is_namedtuple_type(type(value)): + return value + return [ + encoder(getattr(value, name)) + for encoder, name in zip(field_encoders, field_names) + ] + + return encode_namedtuple + + def encode_basic_value(value: Any) -> Any: + if isinstance(value, np.number): + return value.item() + if isinstance(value, np.ndarray): + return value + if isinstance(value, (list, tuple)): + return [encode_basic_value(v) for v in value] + return value + + return encode_basic_value + + +def make_engine_value_encoder(type_hint: Type[Any] | str) -> Callable[[Any], Any]: + """ + Create an encoder closure for converting Python values to engine values. + + Args: + type_hint: Type annotation for the values to encode + + Returns: + A closure that encodes Python values to engine values + """ + type_info = analyze_type_info(type_hint) + if isinstance(type_info.variant, AnalyzedUnknownType): + raise ValueError(f"Type annotation `{type_info.core_type}` is unsupported") + + return _make_encoder_closure(type_info) + + +def encode_engine_value(value: Any, type_hint: Type[Any] | str) -> Any: + """ + Encode a Python value to an engine value. + + Args: + value: The Python value to encode + type_hint: Type annotation for the value. This should always be provided. + + Returns: + The encoded engine value + """ + encoder = make_engine_value_encoder(type_hint) + return encoder(value) + + def make_engine_value_decoder( field_path: list[str], src_type: dict[str, Any], diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py index 8e184a69..8a93b1af 100644 --- a/python/cocoindex/flow.py +++ b/python/cocoindex/flow.py @@ -9,15 +9,6 @@ import functools import inspect import re - -from .validation import ( - validate_flow_name, - NamingError, - validate_full_flow_name, - validate_target_name, -) -from .typing import analyze_type_info - from dataclasses import dataclass from enum import Enum from threading import Lock @@ -25,13 +16,13 @@ Any, Callable, Generic, + Iterable, NamedTuple, Sequence, TypeVar, cast, get_args, get_origin, - Iterable, ) from rich.text import Text @@ -45,7 +36,12 @@ from .op import FunctionSpec from .runtime import execution_context from .setup import SetupChangeBundle -from .typing import encode_enriched_type +from .typing import analyze_type_info, encode_enriched_type +from .validation import ( + validate_flow_name, + validate_full_flow_name, + validate_target_name, +) class _NameBuilder: @@ -1099,11 +1095,16 @@ async def eval_async(self, *args: Any, **kwargs: Any) -> T: """ flow_info = await self._flow_info_async() params = [] - for i, arg in enumerate(self._param_names): + for i, (arg, arg_type) in enumerate( + zip(self._param_names, self._flow_arg_types) + ): + param_type = ( + self._flow_arg_types[i] if i < len(self._flow_arg_types) else Any + ) if i < len(args): - params.append(encode_engine_value(args[i])) + params.append(encode_engine_value(args[i], type_hint=param_type)) elif arg in kwargs: - params.append(encode_engine_value(kwargs[arg])) + params.append(encode_engine_value(kwargs[arg], type_hint=param_type)) else: raise ValueError(f"Parameter {arg} is not provided") engine_result = await flow_info.engine_flow.evaluate_async(params) diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py index 0de3a1df..88084207 100644 --- a/python/cocoindex/op.py +++ b/python/cocoindex/op.py @@ -317,7 +317,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any: output = await self._acall(*decoded_args, **decoded_kwargs) else: output = await self._acall(*decoded_args, **decoded_kwargs) - return encode_engine_value(output) + return encode_engine_value(output, type_hint=expected_return) _WrappedClass.__name__ = executor_cls.__name__ _WrappedClass.__doc__ = executor_cls.__doc__ diff --git a/python/cocoindex/tests/test_convert.py b/python/cocoindex/tests/test_convert.py index 637e6d83..77cc1599 100644 --- a/python/cocoindex/tests/test_convert.py +++ b/python/cocoindex/tests/test_convert.py @@ -1,7 +1,7 @@ import datetime import inspect import uuid -from dataclasses import dataclass, make_dataclass, field +from dataclasses import dataclass, make_dataclass from typing import Annotated, Any, Callable, Literal, NamedTuple import numpy as np @@ -19,8 +19,8 @@ Float64, TypeKind, Vector, - encode_enriched_type, analyze_type_info, + encode_enriched_type, ) @@ -99,7 +99,7 @@ def eq(a: Any, b: Any) -> bool: return np.array_equal(a, b) return type(a) is type(b) and not not (a == b) - encoded_value = encode_engine_value(value) + encoded_value = encode_engine_value(value, value_type) value_type = value_type or type(value) encoded_output_type = encode_enriched_type(value_type)["type"] value_from_engine = _engine.testutil.seder_roundtrip( @@ -133,24 +133,24 @@ def validate_full_roundtrip( def test_encode_engine_value_basic_types() -> None: - assert encode_engine_value(123) == 123 - assert encode_engine_value(3.14) == 3.14 - assert encode_engine_value("hello") == "hello" - assert encode_engine_value(True) is True + assert encode_engine_value(123, int) == 123 + assert encode_engine_value(3.14, float) == 3.14 + assert encode_engine_value("hello", str) == "hello" + assert encode_engine_value(True, bool) is True def test_encode_engine_value_uuid() -> None: u = uuid.uuid4() - assert encode_engine_value(u) == u + assert encode_engine_value(u, uuid.UUID) == u def test_encode_engine_value_date_time_types() -> None: d = datetime.date(2024, 1, 1) - assert encode_engine_value(d) == d + assert encode_engine_value(d, datetime.date) == d t = datetime.time(12, 30) - assert encode_engine_value(t) == t + assert encode_engine_value(t, datetime.time) == t dt = datetime.datetime(2024, 1, 1, 12, 30) - assert encode_engine_value(dt) == dt + assert encode_engine_value(dt, datetime.datetime) == dt def test_encode_scalar_numpy_values() -> None: @@ -161,17 +161,22 @@ def test_encode_scalar_numpy_values() -> None: (np.float64(2.718), pytest.approx(2.718)), ] for np_value, expected in test_cases: - encoded = encode_engine_value(np_value) + encoded = encode_engine_value(np_value, type(np_value)) assert encoded == expected assert isinstance(encoded, (int, float)) def test_encode_engine_value_struct() -> None: order = Order(order_id="O123", name="mixed nuts", price=25.0) - assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"] + assert encode_engine_value(order, Order) == [ + "O123", + "mixed nuts", + 25.0, + "default_extra", + ] order_nt = OrderNamedTuple(order_id="O123", name="mixed nuts", price=25.0) - assert encode_engine_value(order_nt) == [ + assert encode_engine_value(order_nt, OrderNamedTuple) == [ "O123", "mixed nuts", 25.0, @@ -181,7 +186,7 @@ def test_encode_engine_value_struct() -> None: def test_encode_engine_value_list_of_structs() -> None: orders = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)] - assert encode_engine_value(orders) == [ + assert encode_engine_value(orders, list[Order]) == [ ["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"], ] @@ -190,7 +195,7 @@ def test_encode_engine_value_list_of_structs() -> None: OrderNamedTuple("O1", "item1", 10.0), OrderNamedTuple("O2", "item2", 20.0), ] - assert encode_engine_value(orders_nt) == [ + assert encode_engine_value(orders_nt, list[OrderNamedTuple]) == [ ["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"], ] @@ -198,12 +203,12 @@ def test_encode_engine_value_list_of_structs() -> None: def test_encode_engine_value_struct_with_list() -> None: basket = Basket(items=["apple", "banana"]) - assert encode_engine_value(basket) == [["apple", "banana"]] + assert encode_engine_value(basket, Basket) == [["apple", "banana"]] def test_encode_engine_value_nested_struct() -> None: customer = Customer(name="Alice", order=Order("O1", "item1", 10.0)) - assert encode_engine_value(customer) == [ + assert encode_engine_value(customer, Customer) == [ "Alice", ["O1", "item1", 10.0, "default_extra"], None, @@ -212,7 +217,7 @@ def test_encode_engine_value_nested_struct() -> None: customer_nt = CustomerNamedTuple( name="Alice", order=OrderNamedTuple("O1", "item1", 10.0) ) - assert encode_engine_value(customer_nt) == [ + assert encode_engine_value(customer_nt, CustomerNamedTuple) == [ "Alice", ["O1", "item1", 10.0, "default_extra"], None, @@ -220,20 +225,20 @@ def test_encode_engine_value_nested_struct() -> None: def test_encode_engine_value_empty_list() -> None: - assert encode_engine_value([]) == [] - assert encode_engine_value([[]]) == [[]] + assert encode_engine_value([], list) == [] + assert encode_engine_value([[]], list[list[Any]]) == [[]] def test_encode_engine_value_tuple() -> None: - assert encode_engine_value(()) == [] - assert encode_engine_value((1, 2, 3)) == [1, 2, 3] - assert encode_engine_value(((1, 2), (3, 4))) == [[1, 2], [3, 4]] - assert encode_engine_value(([],)) == [[]] - assert encode_engine_value(((),)) == [[]] + assert encode_engine_value((), Any) == [] + assert encode_engine_value((1, 2, 3), Any) == [1, 2, 3] + assert encode_engine_value(((1, 2), (3, 4)), Any) == [[1, 2], [3, 4]] + assert encode_engine_value(([],), Any) == [[]] + assert encode_engine_value(((),), Any) == [[]] def test_encode_engine_value_none() -> None: - assert encode_engine_value(None) is None + assert encode_engine_value(None, Any) is None def test_roundtrip_basic_types() -> None: @@ -743,7 +748,7 @@ class OrderKey: def test_vector_as_vector() -> None: value = np.array([1, 2, 3, 4, 5], dtype=np.int64) - encoded = encode_engine_value(value) + encoded = encode_engine_value(value, IntVectorType) assert np.array_equal(encoded, value) decoded = build_engine_value_decoder(IntVectorType)(encoded) assert np.array_equal(decoded, value) @@ -754,7 +759,7 @@ def test_vector_as_vector() -> None: def test_vector_as_list() -> None: value: ListIntType = [1, 2, 3, 4, 5] - encoded = encode_engine_value(value) + encoded = encode_engine_value(value, ListIntType) assert encoded == [1, 2, 3, 4, 5] decoded = build_engine_value_decoder(ListIntType)(encoded) assert np.array_equal(decoded, value) @@ -772,13 +777,19 @@ def test_vector_as_list() -> None: def test_encode_engine_value_ndarray() -> None: """Test encoding NDArray vectors to lists for the Rust engine.""" vec_f32: Float32VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float32) - assert np.array_equal(encode_engine_value(vec_f32), [1.0, 2.0, 3.0]) + assert np.array_equal( + encode_engine_value(vec_f32, Float32VectorType), [1.0, 2.0, 3.0] + ) vec_f64: Float64VectorType = np.array([1.0, 2.0, 3.0], dtype=np.float64) - assert np.array_equal(encode_engine_value(vec_f64), [1.0, 2.0, 3.0]) + assert np.array_equal( + encode_engine_value(vec_f64, Float64VectorType), [1.0, 2.0, 3.0] + ) vec_i64: Int64VectorType = np.array([1, 2, 3], dtype=np.int64) - assert np.array_equal(encode_engine_value(vec_i64), [1, 2, 3]) + assert np.array_equal(encode_engine_value(vec_i64, Int64VectorType), [1, 2, 3]) vec_nd_f32: NDArrayFloat32Type = np.array([1.0, 2.0, 3.0], dtype=np.float32) - assert np.array_equal(encode_engine_value(vec_nd_f32), [1.0, 2.0, 3.0]) + assert np.array_equal( + encode_engine_value(vec_nd_f32, NDArrayFloat32Type), [1.0, 2.0, 3.0] + ) def test_make_engine_value_decoder_ndarray() -> None: @@ -808,21 +819,21 @@ def test_make_engine_value_decoder_ndarray() -> None: def test_roundtrip_ndarray_vector() -> None: """Test roundtrip encoding and decoding of NDArray vectors.""" value_f32 = np.array([1.0, 2.0, 3.0], dtype=np.float32) - encoded_f32 = encode_engine_value(value_f32) + encoded_f32 = encode_engine_value(value_f32, Float32VectorType) np.array_equal(encoded_f32, [1.0, 2.0, 3.0]) decoded_f32 = build_engine_value_decoder(Float32VectorType)(encoded_f32) assert isinstance(decoded_f32, np.ndarray) assert decoded_f32.dtype == np.float32 assert np.array_equal(decoded_f32, value_f32) value_i64 = np.array([1, 2, 3], dtype=np.int64) - encoded_i64 = encode_engine_value(value_i64) + encoded_i64 = encode_engine_value(value_i64, Int64VectorType) assert np.array_equal(encoded_i64, [1, 2, 3]) decoded_i64 = build_engine_value_decoder(Int64VectorType)(encoded_i64) assert isinstance(decoded_i64, np.ndarray) assert decoded_i64.dtype == np.int64 assert np.array_equal(decoded_i64, value_i64) value_nd_f64: NDArrayFloat64Type = np.array([1.0, 2.0, 3.0], dtype=np.float64) - encoded_nd_f64 = encode_engine_value(value_nd_f64) + encoded_nd_f64 = encode_engine_value(value_nd_f64, NDArrayFloat64Type) assert np.array_equal(encoded_nd_f64, [1.0, 2.0, 3.0]) decoded_nd_f64 = build_engine_value_decoder(NDArrayFloat64Type)(encoded_nd_f64) assert isinstance(decoded_nd_f64, np.ndarray) @@ -833,7 +844,7 @@ def test_roundtrip_ndarray_vector() -> None: def test_ndarray_dimension_mismatch() -> None: """Test dimension enforcement for Vector with specified dimension.""" value = np.array([1.0, 2.0], dtype=np.float32) - encoded = encode_engine_value(value) + encoded = encode_engine_value(value, NDArray[np.float32]) assert np.array_equal(encoded, [1.0, 2.0]) with pytest.raises(ValueError, match="Vector dimension mismatch"): build_engine_value_decoder(Float32VectorType)(encoded) @@ -842,14 +853,14 @@ def test_ndarray_dimension_mismatch() -> None: def test_list_vector_backward_compatibility() -> None: """Test that list-based vectors still work for backward compatibility.""" value = [1, 2, 3, 4, 5] - encoded = encode_engine_value(value) + encoded = encode_engine_value(value, list[int]) assert encoded == [1, 2, 3, 4, 5] decoded = build_engine_value_decoder(IntVectorType)(encoded) assert isinstance(decoded, np.ndarray) assert decoded.dtype == np.int64 assert np.array_equal(decoded, np.array([1, 2, 3, 4, 5], dtype=np.int64)) value_list: ListIntType = [1, 2, 3, 4, 5] - encoded = encode_engine_value(value_list) + encoded = encode_engine_value(value_list, ListIntType) assert np.array_equal(encoded, [1, 2, 3, 4, 5]) decoded = build_engine_value_decoder(ListIntType)(encoded) assert np.array_equal(decoded, [1, 2, 3, 4, 5]) @@ -867,7 +878,7 @@ class MyStructWithNDArray: original = MyStructWithNDArray( name="test_np", data=np.array([1.0, 0.5], dtype=np.float32), value=100 ) - encoded = encode_engine_value(original) + encoded = encode_engine_value(original, MyStructWithNDArray) assert encoded[0] == original.name assert np.array_equal(encoded[1], original.data) @@ -1026,7 +1037,7 @@ def test_full_roundtrip_vector_of_vector() -> None: ), ( value_f32, - np.typing.NDArray[np.typing.NDArray[np.float32]], + np.typing.NDArray[np.float32], ), ) @@ -1515,7 +1526,7 @@ class UnsupportedField: validate_full_roundtrip(LTableField(1, []), LTableField) - # validate_full_roundtrip(KTableField(1, {}), KTableField) + validate_full_roundtrip(KTableField(1, {}), KTableField) with pytest.raises( ValueError,