Skip to content

Commit 81483c8

Browse files
authored
Handle union and literal typing correctly in annotations (#478)
1 parent a58eb8c commit 81483c8

File tree

8 files changed

+277
-27
lines changed

8 files changed

+277
-27
lines changed

csp/impl/struct.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def __new__(cls, name, bases, dct):
3535
# Lists need to be normalized too as potentially we need to add a boolean flag to use FastList
3636
if v == FastList:
3737
raise TypeError(f"{v} annotation is not supported without args")
38-
if CspTypingUtils.is_generic_container(v) or CspTypingUtils.is_union_type(v):
38+
if (
39+
CspTypingUtils.is_generic_container(v)
40+
or CspTypingUtils.is_union_type(v)
41+
or CspTypingUtils.is_literal_type(v)
42+
):
3943
actual_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(v)
4044
if CspTypingUtils.is_generic_container(actual_type):
4145
raise TypeError(f"{v} annotation is not supported as a struct field [{actual_type}]")
@@ -191,7 +195,8 @@ def _obj_from_python(cls, json, obj_type):
191195
if CspTypingUtils.is_generic_container(obj_type):
192196
if CspTypingUtils.get_origin(obj_type) in (typing.List, typing.Set, typing.Tuple, FastList):
193197
return_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(obj_type)
194-
(expected_item_type,) = obj_type.__args__
198+
# Only the first type argument is used; for a Tuple, any arguments after the first are ignored
199+
expected_item_type = obj_type.__args__[0]
195200
return_type = list if isinstance(return_type, list) else return_type
196201
return return_type(cls._obj_from_python(v, expected_item_type) for v in json)
197202
elif CspTypingUtils.get_origin(obj_type) is typing.Dict:
@@ -206,6 +211,13 @@ def _obj_from_python(cls, json, obj_type):
206211
return json
207212
else:
208213
raise NotImplementedError(f"Can not deserialize {obj_type} from json")
214+
elif CspTypingUtils.is_union_type(obj_type):
215+
return json ## no checks, just let it through
216+
elif CspTypingUtils.is_literal_type(obj_type):
217+
return_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(obj_type)
218+
if isinstance(json, return_type):
219+
return json
220+
raise ValueError(f"Expected type {return_type} received {json.__class__}")
209221
elif issubclass(obj_type, Struct):
210222
if not isinstance(json, dict):
211223
raise TypeError("Representation of struct as json is expected to be of dict type")

csp/impl/types/container_type_normalizer.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,21 @@ def normalized_type_to_actual_python_type(cls, typ, level=0):
8181
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1), True]
8282
if origin is typing.List and level == 0:
8383
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1)]
84-
if origin is typing.Literal:
85-
# Import here to prevent circular import
86-
from csp.impl.types.instantiation_type_resolver import UpcastRegistry
87-
88-
args = typing.get_args(typ)
89-
typ = type(args[0])
90-
for arg in args[1:]:
91-
typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False)
92-
if typ:
93-
return typ
94-
else:
95-
return object
9684
return cls._NORMALIZED_TYPE_MAPPING.get(CspTypingUtils.get_origin(typ), typ)
9785
elif CspTypingUtils.is_union_type(typ):
9886
return object
87+
elif CspTypingUtils.is_literal_type(typ):
88+
# Import here to prevent circular import
89+
from csp.impl.types.instantiation_type_resolver import UpcastRegistry
90+
91+
args = typing.get_args(typ)
92+
typ = type(args[0])
93+
for arg in args[1:]:
94+
typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False)
95+
if typ:
96+
return typ
97+
else:
98+
return object
9999
else:
100100
return typ
101101

csp/impl/types/pydantic_types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sys
22
import types
33
import typing
4-
from typing import Any, ForwardRef, Generic, Optional, Type, TypeVar, Union, get_args, get_origin
4+
from typing import Any, ForwardRef, Generic, Literal, Optional, Type, TypeVar, Union, get_args, get_origin
55

66
from pydantic import GetCoreSchemaHandler, ValidationInfo, ValidatorFunctionWrapHandler
77
from pydantic_core import CoreSchema, core_schema
@@ -184,6 +184,8 @@ def adjust_annotations(
184184
return TsType[
185185
adjust_annotations(args[0], top_level=False, in_ts=True, make_optional=False, forced_tvars=forced_tvars)
186186
]
187+
if origin is Literal: # for literals, we stop converting
188+
return Optional[annotation] if make_optional else annotation
187189
else:
188190
try:
189191
if origin is CspTypeVar or origin is CspTypeVarType:

csp/impl/types/type_annotation_normalizer_transformer.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def visit_arg(self, node):
5151
return node
5252

5353
def visit_Subscript(self, node):
54+
# We choose to avoid parsing here
55+
# to maintain current behavior of allowing empty lists in our types
5456
return node
5557

5658
def visit_List(self, node):
@@ -98,17 +100,13 @@ def visit_Call(self, node):
98100
return node
99101

100102
def visit_Constant(self, node):
101-
if not self._cur_arg:
102-
return node
103-
104-
if self._cur_arg:
105-
return ast.Call(
106-
func=ast.Attribute(value=ast.Name(id="typing", ctx=ast.Load()), attr="TypeVar", ctx=ast.Load()),
107-
args=[node],
108-
keywords=[],
109-
)
110-
else:
103+
if not self._cur_arg or not isinstance(node.value, str):
111104
return node
105+
return ast.Call(
106+
func=ast.Attribute(value=ast.Name(id="typing", ctx=ast.Load()), attr="TypeVar", ctx=ast.Load()),
107+
args=[node],
108+
keywords=[],
109+
)
112110

113111
def visit_Str(self, node):
114112
return self.visit_Constant(node)

csp/impl/types/typing_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class CspTypingUtils39:
2323

2424
@classmethod
2525
def is_generic_container(cls, typ):
26-
return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ is not typing.Union
26+
return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ not in (typing.Union, typing.Literal)
2727

2828
@classmethod
2929
def is_type_spec(cls, val):
@@ -56,6 +56,10 @@ def is_numpy_nd_array_type(cls, typ):
5656
def is_union_type(cls, typ):
5757
return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Union
5858

59+
@classmethod
60+
def is_literal_type(cls, typ):
61+
return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Literal
62+
5963
@classmethod
6064
def is_forward_ref(cls, typ):
6165
return isinstance(typ, typing.ForwardRef)

csp/tests/impl/test_struct.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import enum
22
import json
33
import pickle
4+
import sys
45
import typing
56
import unittest
67
from datetime import date, datetime, time, timedelta
@@ -3903,6 +3904,129 @@ class DataPoint(csp.Struct):
39033904
self.assertNotIn("_last_updated", json_data)
39043905
self.assertNotIn("_source", json_data["data"])
39053906

3907+
def test_literal_types_validation(self):
3908+
"""Test that Literal type annotations correctly validate input values in CSP Structs"""
3909+
3910+
# Define a simple class with various Literal types
3911+
class StructWithLiterals(csp.Struct):
3912+
# String literals
3913+
color: Literal["red", "green", "blue"]
3914+
# Integer literals
3915+
size: Literal[1, 2, 3]
3916+
# Mixed type literals
3917+
status: Literal["on", "off", 0, 1, True, False]
3918+
# Optional literal with default
3919+
mode: Optional[Literal["fast", "slow"]] = "fast"
3920+
3921+
# Test valid assignments
3922+
s1 = StructWithLiterals(color="red", size=2, status="on")
3923+
self.assertEqual(s1.color, "red")
3924+
self.assertEqual(s1.size, 2)
3925+
self.assertEqual(s1.status, "on")
3926+
self.assertEqual(s1.mode, "fast") # Default value
3927+
3928+
s2 = StructWithLiterals.from_dict(dict(color="blue", size=1, status=True, mode="slow"))
3929+
s2_dump = s2.to_json()
3930+
s2_looped = TypeAdapter(StructWithLiterals).validate_json(s2_dump)
3931+
self.assertEqual(s2, s2_looped)
3932+
s2_dict = s2.to_dict()
3933+
s2_looped_dict = s2.from_dict(s2_dict)
3934+
self.assertEqual(s2_looped_dict, s2)
3935+
3936+
# Invalid color, but from_dict still accepts
3937+
StructWithLiterals.from_dict(dict(color="yellow", size=1, status="on"))
3938+
3939+
# Invalid size but from_dict still accepts
3940+
StructWithLiterals.from_dict(dict(color="red", size=4, status="on"))
3941+
3942+
# Invalid status but from_dict still accepts
3943+
StructWithLiterals.from_dict(dict(color="red", size=1, status="standby"))
3944+
3945+
# Invalid mode but from_dict still accepts
3946+
StructWithLiterals.from_dict(dict(color="red", size=1, mode=12))
3947+
3948+
# Invalid size and since the literals are all the same type
3949+
# If we give an incorrect type, we catch the error
3950+
with self.assertRaises(ValueError) as exc_info:
3951+
StructWithLiterals.from_dict(dict(color="red", size="adasd", mode=12))
3952+
self.assertIn("Expected type <class 'int'> received <class 'str'>", str(exc_info.exception))
3953+
3954+
# Test valid values
3955+
result = TypeAdapter(StructWithLiterals).validate_python({"color": "green", "size": 3, "status": 0})
3956+
self.assertEqual(result.color, "green")
3957+
self.assertEqual(result.size, 3)
3958+
self.assertEqual(result.status, 0)
3959+
3960+
# Test invalid color with Pydantic validation
3961+
with self.assertRaises(ValidationError) as exc_info:
3962+
TypeAdapter(StructWithLiterals).validate_python({"color": "yellow", "size": 1, "status": "on"})
3963+
self.assertIn("1 validation error for", str(exc_info.exception))
3964+
self.assertIn("color", str(exc_info.exception))
3965+
3966+
# Test invalid size with Pydantic validation
3967+
with self.assertRaises(ValidationError) as exc_info:
3968+
TypeAdapter(StructWithLiterals).validate_python({"color": "red", "size": 4, "status": "on"})
3969+
self.assertIn("1 validation error for", str(exc_info.exception))
3970+
self.assertIn("size", str(exc_info.exception))
3971+
3972+
# Test invalid status with Pydantic validation
3973+
with self.assertRaises(ValidationError) as exc_info:
3974+
TypeAdapter(StructWithLiterals).validate_python({"color": "red", "size": 1, "status": "standby"})
3975+
self.assertIn("1 validation error for", str(exc_info.exception))
3976+
self.assertIn("status", str(exc_info.exception))
3977+
3978+
# Test invalid mode with Pydantic validation
3979+
with self.assertRaises(ValidationError) as exc_info:
3980+
TypeAdapter(StructWithLiterals).validate_python(
3981+
{"color": "red", "size": 1, "status": "on", "mode": "medium"}
3982+
)
3983+
self.assertIn("1 validation error for", str(exc_info.exception))
3984+
self.assertIn("mode", str(exc_info.exception))
3985+
3986+
def test_pipe_operator_types(self):
3987+
"""Test using the pipe operator for union types in Python 3.10+"""
3988+
if sys.version_info >= (3, 10): # Only run on Python 3.10+
3989+
# Define a class using various pipe operator combinations
3990+
class PipeTypesConfig(csp.Struct):
3991+
# Basic primitive types with pipe
3992+
id_field: str | int
3993+
# Pipe with None (similar to Optional)
3994+
description: str | None = None
3995+
# Multiple types with pipe
3996+
value: str | int | float | bool
3997+
# Container with pipe
3998+
tags: List[str] | Dict[str, str] | None = None
3999+
# Pipe with literal for comparison
4000+
status: Literal["active", "inactive"] | None = "active"
4001+
4002+
# Test all valid types
4003+
valid_cases = [
4004+
{"id_field": "string_id", "value": "string_value"},
4005+
{"id_field": 42, "value": 123},
4006+
{"id_field": "mixed", "value": 3.14},
4007+
{"id_field": 999, "value": True},
4008+
{"id_field": "with_desc", "value": 1, "description": "Description"},
4009+
{"id_field": "with_dict", "value": 1, "tags": None},
4010+
]
4011+
4012+
for case in valid_cases:
4013+
result = PipeTypesConfig.from_dict(case)
4014+
# use the other route to get back the result
4015+
result_to_dict_loop = TypeAdapter(PipeTypesConfig).validate_python(result.to_dict())
4016+
self.assertEqual(result, result_to_dict_loop)
4017+
4018+
# Test invalid values
4019+
invalid_cases = [
4020+
{"id_field": 3.14, "value": 1}, # Float for id_field
4021+
{"id_field": None, "value": 1}, # None for required id_field
4022+
{"id_field": "test", "value": {}}, # Dict for value
4023+
{"id_field": "test", "value": None}, # None for required value
4024+
{"id_field": "test", "value": 1, "status": "unknown"}, # Invalid literal
4025+
]
4026+
for case in invalid_cases:
4027+
with self.assertRaises(ValidationError):
4028+
TypeAdapter(PipeTypesConfig).validate_python(case)
4029+
39064030

39074031
if __name__ == "__main__":
39084032
unittest.main()

csp/tests/impl/types/test_pydantic_types.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import sys
22
from inspect import isclass
3-
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union, get_args, get_origin
3+
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union, get_args, get_origin
44
from unittest import TestCase
55

66
import csp
@@ -160,3 +160,12 @@ def test_force_tvars(self):
160160
self.assertAnnotationsEqual(
161161
adjust_annotations(CspTypeVarType[T], forced_tvars={"T": float}), Union[Type[float], Type[int]]
162162
)
163+
164+
def test_literal(self):
165+
self.assertAnnotationsEqual(adjust_annotations(Literal["a", "b"]), Literal["a", "b"])
166+
self.assertAnnotationsEqual(
167+
adjust_annotations(Literal["a", "b"], make_optional=True), Optional[Literal["a", "b"]]
168+
)
169+
self.assertAnnotationsEqual(adjust_annotations(Literal[123, "a"]), Literal[123, "a"])
170+
self.assertAnnotationsEqual(adjust_annotations(Literal[123, None]), Literal[123, None])
171+
self.assertAnnotationsEqual(adjust_annotations(ts[Literal[123, None]]), ts[Literal[123, None]])

0 commit comments

Comments
 (0)