|
5 | 5 | import dataclasses
|
6 | 6 | import datetime
|
7 | 7 | import inspect
|
8 |
| -import uuid |
9 | 8 | from enum import Enum
|
10 | 9 | from typing import Any, Callable, Mapping, get_origin
|
11 | 10 |
|
|
14 | 13 | from .typing import (
|
15 | 14 | KEY_FIELD_NAME,
|
16 | 15 | TABLE_TYPES,
|
17 |
| - AnalyzedTypeInfo, |
18 | 16 | DtypeRegistry,
|
19 | 17 | analyze_type_info,
|
20 | 18 | encode_enriched_type,
|
@@ -74,30 +72,58 @@ def make_engine_value_decoder(
|
74 | 72 | Returns:
|
75 | 73 | A decoder from an engine value to a Python value.
|
76 | 74 | """
|
77 |
| - |
78 | 75 | src_type_kind = src_type["kind"]
|
79 | 76 |
|
80 |
| - dst_type_info: AnalyzedTypeInfo | None = None |
81 |
| - if ( |
82 |
| - dst_annotation is not None |
83 |
| - and dst_annotation is not inspect.Parameter.empty |
84 |
| - and dst_annotation is not Any |
85 |
| - ): |
86 |
| - dst_type_info = analyze_type_info(dst_annotation) |
87 |
| - if not _is_type_kind_convertible_to(src_type_kind, dst_type_info.kind): |
88 |
| - raise ValueError( |
89 |
| - f"Type mismatch for `{''.join(field_path)}`: " |
90 |
| - f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})" |
91 |
| - ) |
92 |
| - |
93 |
| - if dst_type_info is None: |
| 77 | + dst_is_any = ( |
| 78 | + dst_annotation is None |
| 79 | + or dst_annotation is inspect.Parameter.empty |
| 80 | + or dst_annotation is Any |
| 81 | + ) |
| 82 | + if dst_is_any: |
| 83 | + if src_type_kind == "Union": |
| 84 | + return lambda value: value[1] |
94 | 85 | if src_type_kind == "Struct" or src_type_kind in TABLE_TYPES:
|
95 | 86 | raise ValueError(
|
96 | 87 | f"Missing type annotation for `{''.join(field_path)}`."
|
97 | 88 | f"It's required for {src_type_kind} type."
|
98 | 89 | )
|
99 | 90 | return lambda value: value
|
100 | 91 |
|
| 92 | + dst_type_info = analyze_type_info(dst_annotation) |
| 93 | + |
| 94 | + if src_type_kind == "Union": |
| 95 | + dst_type_variants = ( |
| 96 | + dst_type_info.union_variant_types |
| 97 | + if dst_type_info.union_variant_types is not None |
| 98 | + else [dst_annotation] |
| 99 | + ) |
| 100 | + src_type_variants = src_type["types"] |
| 101 | + decoders = [] |
| 102 | + for i, src_type_variant in enumerate(src_type_variants): |
| 103 | + src_field_path = field_path + [f"[{i}]"] |
| 104 | + decoder = None |
| 105 | + for dst_type_variant in dst_type_variants: |
| 106 | + try: |
| 107 | + decoder = make_engine_value_decoder( |
| 108 | + src_field_path, src_type_variant, dst_type_variant |
| 109 | + ) |
| 110 | + break |
| 111 | + except ValueError: |
| 112 | + pass |
| 113 | + if decoder is None: |
| 114 | + raise ValueError( |
| 115 | + f"Type mismatch for `{''.join(field_path)}`: " |
| 116 | + f"cannot find matched target type for source type variant {src_type_variant}" |
| 117 | + ) |
| 118 | + decoders.append(decoder) |
| 119 | + return lambda value: decoders[value[0]](value[1]) |
| 120 | + |
| 121 | + if not _is_type_kind_convertible_to(src_type_kind, dst_type_info.kind): |
| 122 | + raise ValueError( |
| 123 | + f"Type mismatch for `{''.join(field_path)}`: " |
| 124 | + f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})" |
| 125 | + ) |
| 126 | + |
101 | 127 | if dst_type_info.kind in ("Float32", "Float64", "Int64"):
|
102 | 128 | dst_core_type = dst_type_info.core_type
|
103 | 129 |
|
@@ -196,9 +222,6 @@ def decode(value: Any) -> Any | None:
|
196 | 222 | field_path.pop()
|
197 | 223 | return decode
|
198 | 224 |
|
199 |
| - if src_type_kind == "Union": |
200 |
| - return lambda value: value[1] |
201 |
| - |
202 | 225 | return lambda value: value
|
203 | 226 |
|
204 | 227 |
|
|
0 commit comments