Skip to content

Commit 9f73a1b

Browse files
fix(union-type): properly decode Union types for non-trivial cases (#688)
* fix(union-type): properly decode Union types for non-trivial cases * style: format fix --------- Co-authored-by: Jiangzhou He <[email protected]>
1 parent c1ce446 commit 9f73a1b

File tree

3 files changed

+60
-23
lines changed

3 files changed

+60
-23
lines changed

docs/docs/ai/llm.mdx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ cocoindex.LlmSpec(
309309

310310
You can find the full list of models supported by OpenRouter [here](https://openrouter.ai/models).
311311

312-
### vLLM
312+
### vLLM
313313

314314
Install vLLM:
315315

@@ -338,4 +338,4 @@ cocoindex.LlmSpec(
338338
```
339339

340340
</TabItem>
341-
</Tabs>
341+
</Tabs>

python/cocoindex/convert.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import dataclasses
66
import datetime
77
import inspect
8-
import uuid
98
from enum import Enum
109
from typing import Any, Callable, Mapping, get_origin
1110

@@ -14,7 +13,6 @@
1413
from .typing import (
1514
KEY_FIELD_NAME,
1615
TABLE_TYPES,
17-
AnalyzedTypeInfo,
1816
DtypeRegistry,
1917
analyze_type_info,
2018
encode_enriched_type,
@@ -74,30 +72,58 @@ def make_engine_value_decoder(
7472
Returns:
7573
A decoder from an engine value to a Python value.
7674
"""
77-
7875
src_type_kind = src_type["kind"]
7976

80-
dst_type_info: AnalyzedTypeInfo | None = None
81-
if (
82-
dst_annotation is not None
83-
and dst_annotation is not inspect.Parameter.empty
84-
and dst_annotation is not Any
85-
):
86-
dst_type_info = analyze_type_info(dst_annotation)
87-
if not _is_type_kind_convertible_to(src_type_kind, dst_type_info.kind):
88-
raise ValueError(
89-
f"Type mismatch for `{''.join(field_path)}`: "
90-
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})"
91-
)
92-
93-
if dst_type_info is None:
77+
dst_is_any = (
78+
dst_annotation is None
79+
or dst_annotation is inspect.Parameter.empty
80+
or dst_annotation is Any
81+
)
82+
if dst_is_any:
83+
if src_type_kind == "Union":
84+
return lambda value: value[1]
9485
if src_type_kind == "Struct" or src_type_kind in TABLE_TYPES:
9586
raise ValueError(
9687
f"Missing type annotation for `{''.join(field_path)}`."
9788
f"It's required for {src_type_kind} type."
9889
)
9990
return lambda value: value
10091

92+
dst_type_info = analyze_type_info(dst_annotation)
93+
94+
if src_type_kind == "Union":
95+
dst_type_variants = (
96+
dst_type_info.union_variant_types
97+
if dst_type_info.union_variant_types is not None
98+
else [dst_annotation]
99+
)
100+
src_type_variants = src_type["types"]
101+
decoders = []
102+
for i, src_type_variant in enumerate(src_type_variants):
103+
src_field_path = field_path + [f"[{i}]"]
104+
decoder = None
105+
for dst_type_variant in dst_type_variants:
106+
try:
107+
decoder = make_engine_value_decoder(
108+
src_field_path, src_type_variant, dst_type_variant
109+
)
110+
break
111+
except ValueError:
112+
pass
113+
if decoder is None:
114+
raise ValueError(
115+
f"Type mismatch for `{''.join(field_path)}`: "
116+
f"cannot find matched target type for source type variant {src_type_variant}"
117+
)
118+
decoders.append(decoder)
119+
return lambda value: decoders[value[0]](value[1])
120+
121+
if not _is_type_kind_convertible_to(src_type_kind, dst_type_info.kind):
122+
raise ValueError(
123+
f"Type mismatch for `{''.join(field_path)}`: "
124+
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})"
125+
)
126+
101127
if dst_type_info.kind in ("Float32", "Float64", "Int64"):
102128
dst_core_type = dst_type_info.core_type
103129

@@ -196,9 +222,6 @@ def decode(value: Any) -> Any | None:
196222
field_path.pop()
197223
return decode
198224

199-
if src_type_kind == "Union":
200-
return lambda value: value[1]
201-
202225
return lambda value: value
203226

204227

python/cocoindex/tests/test_convert.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ def eq(a: Any, b: Any) -> bool:
104104
)
105105
decoder = make_engine_value_decoder([], encoded_output_type, value_type)
106106
decoded_value = decoder(value_from_engine)
107-
assert eq(decoded_value, value)
107+
assert eq(decoded_value, value), (
108+
f"{decoded_value} != {value}; {encoded_value}; {value_type}; {encoded_output_type}"
109+
)
108110

109111
if other_decoded_values is not None:
110112
for other_value, other_type in other_decoded_values:
@@ -613,6 +615,18 @@ def test_roundtrip_union_timedelta() -> None:
613615
validate_full_roundtrip(value, t)
614616

615617

618+
def test_roundtrip_vector_of_union() -> None:
619+
t = list[str | int]
620+
value = ["a", 1]
621+
validate_full_roundtrip(value, t)
622+
623+
624+
def test_roundtrip_union_with_vector() -> None:
625+
t = NDArray[np.float32] | str
626+
value = np.array([1.0, 2.0, 3.0], dtype=np.float32)
627+
validate_full_roundtrip(value, t, ([1.0, 2.0, 3.0], list[float] | str))
628+
629+
616630
def test_roundtrip_ltable() -> None:
617631
t = list[Order]
618632
value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]

0 commit comments

Comments
 (0)