Skip to content

Commit abd72f5

Browse files
committed
Merge branch 'main' into nk/from_dict_pydantic
2 parents 08a4409 + 81483c8 commit abd72f5

File tree

10 files changed

+281
-27
lines changed

10 files changed

+281
-27
lines changed

cpp/csp/adapters/utils/MessageStructConverter.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ bool MessageStructConverterCache::registerConverter( std::string protocol, Creat
2626
return true;
2727
}
2828

29+
bool MessageStructConverterCache::hasConverter( std::string protocol ) const
30+
{
31+
return m_creators.find( protocol ) != m_creators.end();
32+
}
33+
2934
MessageStructConverterCache & MessageStructConverterCache::instance()
3035
{
3136
static MessageStructConverterCache s_instance;

cpp/csp/adapters/utils/MessageStructConverter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class MessageStructConverterCache
5151
using Creator = std::function<MessageStructConverter*( const CspTypePtr &, const Dictionary & )>;
5252

5353
bool registerConverter( std::string protocol, Creator creator );
54-
54+
bool hasConverter( std::string protocol ) const;
5555
private:
5656
using CacheKey = std::pair<const CspType*,Dictionary>;
5757
using Cache = std::unordered_map<CacheKey,MessageStructConverterPtr,csp::hash::hash_pair>;

csp/impl/struct.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def __new__(cls, name, bases, dct):
3535
# Lists need to be normalized too as potentially we need to add a boolean flag to use FastList
3636
if v == FastList:
3737
raise TypeError(f"{v} annotation is not supported without args")
38-
if CspTypingUtils.is_generic_container(v) or CspTypingUtils.is_union_type(v):
38+
if (
39+
CspTypingUtils.is_generic_container(v)
40+
or CspTypingUtils.is_union_type(v)
41+
or CspTypingUtils.is_literal_type(v)
42+
):
3943
actual_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(v)
4044
if CspTypingUtils.is_generic_container(actual_type):
4145
raise TypeError(f"{v} annotation is not supported as a struct field [{actual_type}]")
@@ -219,6 +223,13 @@ def _obj_from_python(cls, json, obj_type):
219223
return json
220224
else:
221225
raise NotImplementedError(f"Can not deserialize {obj_type} from json")
226+
elif CspTypingUtils.is_union_type(obj_type):
227+
return json ## no checks, just let it through
228+
elif CspTypingUtils.is_literal_type(obj_type):
229+
return_type = ContainerTypeNormalizer.normalized_type_to_actual_python_type(obj_type)
230+
if isinstance(json, return_type):
231+
return json
232+
raise ValueError(f"Expected type {return_type} received {json.__class__}")
222233
elif issubclass(obj_type, Struct):
223234
if not isinstance(json, dict):
224235
raise TypeError("Representation of struct as json is expected to be of dict type")

csp/impl/types/container_type_normalizer.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,21 +81,21 @@ def normalized_type_to_actual_python_type(cls, typ, level=0):
8181
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1), True]
8282
if origin is typing.List and level == 0:
8383
return [cls.normalized_type_to_actual_python_type(typ.__args__[0], level + 1)]
84-
if origin is typing.Literal:
85-
# Import here to prevent circular import
86-
from csp.impl.types.instantiation_type_resolver import UpcastRegistry
87-
88-
args = typing.get_args(typ)
89-
typ = type(args[0])
90-
for arg in args[1:]:
91-
typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False)
92-
if typ:
93-
return typ
94-
else:
95-
return object
9684
return cls._NORMALIZED_TYPE_MAPPING.get(CspTypingUtils.get_origin(typ), typ)
9785
elif CspTypingUtils.is_union_type(typ):
9886
return object
87+
elif CspTypingUtils.is_literal_type(typ):
88+
# Import here to prevent circular import
89+
from csp.impl.types.instantiation_type_resolver import UpcastRegistry
90+
91+
args = typing.get_args(typ)
92+
typ = type(args[0])
93+
for arg in args[1:]:
94+
typ = UpcastRegistry.instance().resolve_type(typ, type(arg), raise_on_error=False)
95+
if typ:
96+
return typ
97+
else:
98+
return object
9999
else:
100100
return typ
101101

csp/impl/types/pydantic_types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sys
22
import types
33
import typing
4-
from typing import Any, ForwardRef, Generic, Optional, Type, TypeVar, Union, get_args, get_origin
4+
from typing import Any, ForwardRef, Generic, Literal, Optional, Type, TypeVar, Union, get_args, get_origin
55

66
from pydantic import GetCoreSchemaHandler, ValidationInfo, ValidatorFunctionWrapHandler
77
from pydantic_core import CoreSchema, core_schema
@@ -184,6 +184,8 @@ def adjust_annotations(
184184
return TsType[
185185
adjust_annotations(args[0], top_level=False, in_ts=True, make_optional=False, forced_tvars=forced_tvars)
186186
]
187+
if origin is Literal: # for literals, we stop converting
188+
return Optional[annotation] if make_optional else annotation
187189
else:
188190
try:
189191
if origin is CspTypeVar or origin is CspTypeVarType:

csp/impl/types/type_annotation_normalizer_transformer.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def visit_arg(self, node):
5151
return node
5252

5353
def visit_Subscript(self, node):
54+
# We choose to avoid parsing here
55+
# to maintain current behavior of allowing empty lists in our types
5456
return node
5557

5658
def visit_List(self, node):
@@ -98,17 +100,13 @@ def visit_Call(self, node):
98100
return node
99101

100102
def visit_Constant(self, node):
101-
if not self._cur_arg:
102-
return node
103-
104-
if self._cur_arg:
105-
return ast.Call(
106-
func=ast.Attribute(value=ast.Name(id="typing", ctx=ast.Load()), attr="TypeVar", ctx=ast.Load()),
107-
args=[node],
108-
keywords=[],
109-
)
110-
else:
103+
if not self._cur_arg or not isinstance(node.value, str):
111104
return node
105+
return ast.Call(
106+
func=ast.Attribute(value=ast.Name(id="typing", ctx=ast.Load()), attr="TypeVar", ctx=ast.Load()),
107+
args=[node],
108+
keywords=[],
109+
)
112110

113111
def visit_Str(self, node):
114112
return self.visit_Constant(node)

csp/impl/types/typing_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class CspTypingUtils39:
5050

5151
@classmethod
5252
def is_generic_container(cls, typ):
53-
return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ is not typing.Union
53+
return isinstance(typ, cls._GENERIC_ALIASES) and typ.__origin__ not in (typing.Union, typing.Literal)
5454

5555
@classmethod
5656
def is_type_spec(cls, val):
@@ -83,6 +83,10 @@ def is_numpy_nd_array_type(cls, typ):
8383
def is_union_type(cls, typ):
8484
return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Union
8585

86+
@classmethod
87+
def is_literal_type(cls, typ):
88+
return isinstance(typ, typing._GenericAlias) and typ.__origin__ is typing.Literal
89+
8690
@classmethod
8791
def is_forward_ref(cls, typ):
8892
return isinstance(typ, typing.ForwardRef)

csp/tests/impl/test_struct.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import enum
22
import json
33
import pickle
4+
import sys
45
import unittest
56
from datetime import date, datetime, time, timedelta
67
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union
@@ -3940,6 +3941,129 @@ class DataPoint(csp.Struct):
39403941
self.assertNotIn("_last_updated", json_data)
39413942
self.assertNotIn("_source", json_data["data"])
39423943

3944+
def test_literal_types_validation(self):
3945+
"""Test that Literal type annotations correctly validate input values in CSP Structs"""
3946+
3947+
# Define a simple class with various Literal types
3948+
class StructWithLiterals(csp.Struct):
3949+
# String literals
3950+
color: Literal["red", "green", "blue"]
3951+
# Integer literals
3952+
size: Literal[1, 2, 3]
3953+
# Mixed type literals
3954+
status: Literal["on", "off", 0, 1, True, False]
3955+
# Optional literal with default
3956+
mode: Optional[Literal["fast", "slow"]] = "fast"
3957+
3958+
# Test valid assignments
3959+
s1 = StructWithLiterals(color="red", size=2, status="on")
3960+
self.assertEqual(s1.color, "red")
3961+
self.assertEqual(s1.size, 2)
3962+
self.assertEqual(s1.status, "on")
3963+
self.assertEqual(s1.mode, "fast") # Default value
3964+
3965+
s2 = StructWithLiterals.from_dict(dict(color="blue", size=1, status=True, mode="slow"))
3966+
s2_dump = s2.to_json()
3967+
s2_looped = TypeAdapter(StructWithLiterals).validate_json(s2_dump)
3968+
self.assertEqual(s2, s2_looped)
3969+
s2_dict = s2.to_dict()
3970+
s2_looped_dict = s2.from_dict(s2_dict)
3971+
self.assertEqual(s2_looped_dict, s2)
3972+
3973+
# Invalid color, but from_dict still accepts
3974+
StructWithLiterals.from_dict(dict(color="yellow", size=1, status="on"))
3975+
3976+
# Invalid size but from_dict still accepts
3977+
StructWithLiterals.from_dict(dict(color="red", size=4, status="on"))
3978+
3979+
# Invalid status but from_dict still accepts
3980+
StructWithLiterals.from_dict(dict(color="red", size=1, status="standby"))
3981+
3982+
# Invalid mode but from_dict still accepts
3983+
StructWithLiterals.from_dict(dict(color="red", size=1, mode=12))
3984+
3985+
# Invalid size and since the literals are all the same type
3986+
# If we give an incorrect type, we catch the error
3987+
with self.assertRaises(ValueError) as exc_info:
3988+
StructWithLiterals.from_dict(dict(color="red", size="adasd", mode=12))
3989+
self.assertIn("Expected type <class 'int'> received <class 'str'>", str(exc_info.exception))
3990+
3991+
# Test valid values
3992+
result = TypeAdapter(StructWithLiterals).validate_python({"color": "green", "size": 3, "status": 0})
3993+
self.assertEqual(result.color, "green")
3994+
self.assertEqual(result.size, 3)
3995+
self.assertEqual(result.status, 0)
3996+
3997+
# Test invalid color with Pydantic validation
3998+
with self.assertRaises(ValidationError) as exc_info:
3999+
TypeAdapter(StructWithLiterals).validate_python({"color": "yellow", "size": 1, "status": "on"})
4000+
self.assertIn("1 validation error for", str(exc_info.exception))
4001+
self.assertIn("color", str(exc_info.exception))
4002+
4003+
# Test invalid size with Pydantic validation
4004+
with self.assertRaises(ValidationError) as exc_info:
4005+
TypeAdapter(StructWithLiterals).validate_python({"color": "red", "size": 4, "status": "on"})
4006+
self.assertIn("1 validation error for", str(exc_info.exception))
4007+
self.assertIn("size", str(exc_info.exception))
4008+
4009+
# Test invalid status with Pydantic validation
4010+
with self.assertRaises(ValidationError) as exc_info:
4011+
TypeAdapter(StructWithLiterals).validate_python({"color": "red", "size": 1, "status": "standby"})
4012+
self.assertIn("1 validation error for", str(exc_info.exception))
4013+
self.assertIn("status", str(exc_info.exception))
4014+
4015+
# Test invalid mode with Pydantic validation
4016+
with self.assertRaises(ValidationError) as exc_info:
4017+
TypeAdapter(StructWithLiterals).validate_python(
4018+
{"color": "red", "size": 1, "status": "on", "mode": "medium"}
4019+
)
4020+
self.assertIn("1 validation error for", str(exc_info.exception))
4021+
self.assertIn("mode", str(exc_info.exception))
4022+
4023+
def test_pipe_operator_types(self):
4024+
"""Test using the pipe operator for union types in Python 3.10+"""
4025+
if sys.version_info >= (3, 10): # Only run on Python 3.10+
4026+
# Define a class using various pipe operator combinations
4027+
class PipeTypesConfig(csp.Struct):
4028+
# Basic primitive types with pipe
4029+
id_field: str | int
4030+
# Pipe with None (similar to Optional)
4031+
description: str | None = None
4032+
# Multiple types with pipe
4033+
value: str | int | float | bool
4034+
# Container with pipe
4035+
tags: List[str] | Dict[str, str] | None = None
4036+
# Pipe with literal for comparison
4037+
status: Literal["active", "inactive"] | None = "active"
4038+
4039+
# Test all valid types
4040+
valid_cases = [
4041+
{"id_field": "string_id", "value": "string_value"},
4042+
{"id_field": 42, "value": 123},
4043+
{"id_field": "mixed", "value": 3.14},
4044+
{"id_field": 999, "value": True},
4045+
{"id_field": "with_desc", "value": 1, "description": "Description"},
4046+
{"id_field": "with_dict", "value": 1, "tags": None},
4047+
]
4048+
4049+
for case in valid_cases:
4050+
result = PipeTypesConfig.from_dict(case)
4051+
# use the other route to get back the result
4052+
result_to_dict_loop = TypeAdapter(PipeTypesConfig).validate_python(result.to_dict())
4053+
self.assertEqual(result, result_to_dict_loop)
4054+
4055+
# Test invalid values
4056+
invalid_cases = [
4057+
{"id_field": 3.14, "value": 1}, # Float for id_field
4058+
{"id_field": None, "value": 1}, # None for required id_field
4059+
{"id_field": "test", "value": {}}, # Dict for value
4060+
{"id_field": "test", "value": None}, # None for required value
4061+
{"id_field": "test", "value": 1, "status": "unknown"}, # Invalid literal
4062+
]
4063+
for case in invalid_cases:
4064+
with self.assertRaises(ValidationError):
4065+
TypeAdapter(PipeTypesConfig).validate_python(case)
4066+
39434067

39444068
if __name__ == "__main__":
39454069
unittest.main()

csp/tests/impl/types/test_pydantic_types.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import sys
22
from inspect import isclass
3-
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union, get_args, get_origin
3+
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union, get_args, get_origin
44
from unittest import TestCase
55

66
import csp
@@ -160,3 +160,12 @@ def test_force_tvars(self):
160160
self.assertAnnotationsEqual(
161161
adjust_annotations(CspTypeVarType[T], forced_tvars={"T": float}), Union[Type[float], Type[int]]
162162
)
163+
164+
def test_literal(self):
165+
self.assertAnnotationsEqual(adjust_annotations(Literal["a", "b"]), Literal["a", "b"])
166+
self.assertAnnotationsEqual(
167+
adjust_annotations(Literal["a", "b"], make_optional=True), Optional[Literal["a", "b"]]
168+
)
169+
self.assertAnnotationsEqual(adjust_annotations(Literal[123, "a"]), Literal[123, "a"])
170+
self.assertAnnotationsEqual(adjust_annotations(Literal[123, None]), Literal[123, None])
171+
self.assertAnnotationsEqual(adjust_annotations(ts[Literal[123, None]]), ts[Literal[123, None]])

0 commit comments

Comments
 (0)