Skip to content

Commit 21881bd

Browse files
Tuple fix (#684)
The Pydantic transformer that supports non-native Flyte types was erroring on `complex_typeddict_workflow` and `optional_fields_workflow`. All of the complex workflow examples now run. --------- Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
1 parent cd48acf commit 21881bd

File tree

2 files changed

+122
-15
lines changed

2 files changed

+122
-15
lines changed

src/flyte/types/_tuple_dict.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ def _convert_value_from_pydantic(self, value: Any, expected_type: Optional[Type]
256256
return [self._convert_value_from_pydantic(v, args[0]) for v in value]
257257
elif isinstance(value, dict) and origin is dict and len(args) >= 2:
258258
return {k: self._convert_value_from_pydantic(v, args[1]) for k, v in value.items()}
259+
elif isinstance(value, (list, tuple)) and (origin is tuple or origin is Tuple) and args:
260+
return tuple(self._convert_value_from_pydantic(v, t) for v, t in zip(value, args))
259261

260262
return value
261263

@@ -611,8 +613,13 @@ def _create_wrapper_model(
611613
return _model_cache[t]
612614

613615
# Get field names and types from the TypedDict
614-
annotations = getattr(t, "__annotations__", {})
615-
required_keys: frozenset[str] = getattr(t, "__required_keys__", frozenset())
616+
# Use get_type_hints(include_extras=True) to preserve NotRequired/Required wrappers,
617+
# which we need to detect optional fields. We can't rely on __required_keys__ because
618+
# it is wrong when `from __future__ import annotations` is used (all fields appear required).
619+
try:
620+
annotations_with_extras = typing.get_type_hints(t, include_extras=True)
621+
except Exception:
622+
annotations_with_extras = getattr(t, "__annotations__", {})
616623

617624
field_definitions: Dict[str, Any] = {}
618625

@@ -621,11 +628,18 @@ def _create_wrapper_model(
621628

622629
model_name = f"{self._WRAPPER_PREFIX}{t.__name__}"
623630

624-
# Create a placeholder in the cache before processing fields to handle self-referential types
625-
# We'll update it after the model is created
626-
_model_cache[t] = None # type: ignore
631+
# Use the model name as a placeholder in the cache before processing fields.
632+
# For self-referential types (e.g. TreeNode with children: List[TreeNode]),
633+
# the recursive call to _create_wrapper_model will return this string, which
634+
# becomes a forward reference (e.g. List["TypedDictWrapper_TreeNode"]) that
635+
# Pydantic resolves via model_rebuild() after model creation.
636+
_model_cache[t] = model_name # type: ignore
637+
638+
for field_name, field_type in annotations_with_extras.items():
639+
# Check if the field is NotRequired before unwrapping
640+
origin = get_origin(field_type)
641+
is_not_required = origin is NotRequired
627642

628-
for field_name, field_type in annotations.items():
629643
# Unwrap NotRequired and Required type hints to get the inner type
630644
# These are only used by TypedDict to mark optional/required fields
631645
# and should not be passed to Pydantic
@@ -635,14 +649,19 @@ def _create_wrapper_model(
635649
# This is necessary because isinstance() doesn't work with TypedDict on Python < 3.12
636650
pydantic_type = self._convert_field_type_for_pydantic(inner_type, _model_cache)
637651

638-
if field_name in required_keys:
639-
field_definitions[field_name] = (pydantic_type, ...)
640-
else:
652+
if is_not_required:
641653
# Optional fields get a default of None
642654
field_definitions[field_name] = (typing.Optional[pydantic_type], None)
655+
else:
656+
field_definitions[field_name] = (pydantic_type, ...)
643657

644658
model = create_model(model_name, **field_definitions)
645659
_model_cache[t] = model
660+
661+
# Rebuild to resolve any forward references from self-referential types
662+
rebuild_ns = {f"{self._WRAPPER_PREFIX}{k.__name__}": v for k, v in _model_cache.items() if isinstance(v, type)}
663+
model.model_rebuild(_types_namespace=rebuild_ns)
664+
646665
return model
647666

648667
def _convert_field_type_for_pydantic(self, field_type: Type, _model_cache: Dict[Type, Type[BaseModel]]) -> Type:
@@ -701,7 +720,10 @@ def _unwrap_typeddict_field_type(self, field_type: Type) -> Type:
701720

702721
def _value_to_model(self, python_val: dict, model_class: Type[BaseModel], python_type: Type) -> BaseModel:
703722
"""Convert a TypedDict to a Pydantic model instance."""
704-
annotations = getattr(python_type, "__annotations__", {})
723+
try:
724+
annotations = typing.get_type_hints(python_type)
725+
except Exception:
726+
annotations = getattr(python_type, "__annotations__", {})
705727

706728
# Convert nested values that might be TypedDicts, dataclasses, or Pydantic models
707729
# to a format Pydantic can validate (dicts)
@@ -714,19 +736,24 @@ def _value_to_model(self, python_val: dict, model_class: Type[BaseModel], python
714736

715737
def _model_to_value(self, model_instance: BaseModel, expected_type: Type) -> dict:
716738
"""Convert a Pydantic model instance back to a TypedDict."""
717-
annotations = getattr(expected_type, "__annotations__", {})
718-
required_keys: frozenset[str] = getattr(expected_type, "__required_keys__", frozenset())
739+
try:
740+
annotations_with_extras = typing.get_type_hints(expected_type, include_extras=True)
741+
except Exception:
742+
annotations_with_extras = getattr(expected_type, "__annotations__", {})
719743
result = {}
720-
for name, field_type in annotations.items():
744+
for name, field_type in annotations_with_extras.items():
721745
if hasattr(model_instance, name):
722746
value = getattr(model_instance, name)
723747
# Skip NotRequired fields when value is None
724748
# This ensures that optional fields not provided in the input
725749
# are absent from the output dict (not set to None)
726-
if name not in required_keys and value is None:
750+
origin = get_origin(field_type)
751+
if origin is NotRequired and value is None:
727752
continue
753+
# Unwrap NotRequired/Required before converting
754+
inner_type = self._unwrap_typeddict_field_type(field_type)
728755
# Recursively convert nested values
729-
converted_value = self._convert_value_from_pydantic(value, field_type)
756+
converted_value = self._convert_value_from_pydantic(value, inner_type)
730757
result[name] = converted_value
731758
return result
732759

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""
2+
Regression tests for the Pydantic-based transformer for non-native Flyte types, covering:
3+
4+
1. Tuple containing TypedDict elements not round-tripping (elements stayed as Pydantic wrappers)
5+
2. NotRequired fields broken with `from __future__ import annotations` (__required_keys__ wrong)
6+
3. Self-referential TypedDicts failing (None placeholder used as a type)
7+
"""
8+
9+
from __future__ import annotations
10+
11+
from typing import List
12+
13+
import pytest
14+
from typing_extensions import NotRequired, TypedDict
15+
16+
from flyte.types._type_engine import TypeEngine
17+
18+
19+
class Coord(TypedDict):
20+
x: float
21+
y: float
22+
23+
24+
class Node(TypedDict):
25+
value: str
26+
children: NotRequired[List[Node]]
27+
28+
29+
class Message(TypedDict):
30+
text: str
31+
tags: NotRequired[List[str]]
32+
33+
34+
@pytest.mark.asyncio
35+
async def test_tuple_containing_typeddict_elements():
36+
"""Tuple elements that are TypedDicts must be dicts after round-trip, not Pydantic wrappers."""
37+
pt = tuple[Coord, Coord]
38+
lt = TypeEngine.to_literal_type(pt)
39+
value = (Coord(x=1.0, y=2.0), Coord(x=3.0, y=4.0))
40+
41+
lv = await TypeEngine.to_literal(value, pt, lt)
42+
result = await TypeEngine.to_python_value(lv, pt)
43+
44+
assert result[0]["x"] == 1.0
45+
assert result[1]["y"] == 4.0
46+
47+
48+
@pytest.mark.asyncio
49+
async def test_notrequired_field_with_future_annotations():
50+
"""NotRequired fields must be optional even with `from __future__ import annotations`."""
51+
pt = Message
52+
lt = TypeEngine.to_literal_type(pt)
53+
54+
# Without optional field
55+
value_without = Message(text="hello")
56+
lv = await TypeEngine.to_literal(value_without, pt, lt)
57+
result = await TypeEngine.to_python_value(lv, pt)
58+
assert result == {"text": "hello"}
59+
assert "tags" not in result
60+
61+
# With optional field
62+
value_with = Message(text="hello", tags=["a", "b"])
63+
lv = await TypeEngine.to_literal(value_with, pt, lt)
64+
result = await TypeEngine.to_python_value(lv, pt)
65+
assert result == {"text": "hello", "tags": ["a", "b"]}
66+
67+
68+
@pytest.mark.asyncio
69+
async def test_self_referential_typeddict():
70+
"""Self-referential TypedDicts must round-trip correctly."""
71+
pt = Node
72+
lt = TypeEngine.to_literal_type(pt)
73+
value = Node(value="root", children=[Node(value="child1"), Node(value="child2")])
74+
75+
lv = await TypeEngine.to_literal(value, pt, lt)
76+
result = await TypeEngine.to_python_value(lv, pt)
77+
78+
assert result["value"] == "root"
79+
assert len(result["children"]) == 2
80+
assert result["children"][0]["value"] == "child1"

0 commit comments

Comments
 (0)