|
1 | 1 | import enum |
2 | 2 | import json |
3 | 3 | import pickle |
| 4 | +import sys |
4 | 5 | import unittest |
5 | 6 | from datetime import date, datetime, time, timedelta |
6 | 7 | from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union |
@@ -3940,6 +3941,129 @@ class DataPoint(csp.Struct): |
3940 | 3941 | self.assertNotIn("_last_updated", json_data) |
3941 | 3942 | self.assertNotIn("_source", json_data["data"]) |
3942 | 3943 |
|
| 3944 | + def test_literal_types_validation(self): |
| 3945 | + """Test that Literal type annotations correctly validate input values in CSP Structs""" |
| 3946 | + |
| 3947 | + # Define a simple class with various Literal types |
| 3948 | + class StructWithLiterals(csp.Struct): |
| 3949 | + # String literals |
| 3950 | + color: Literal["red", "green", "blue"] |
| 3951 | + # Integer literals |
| 3952 | + size: Literal[1, 2, 3] |
| 3953 | + # Mixed type literals |
| 3954 | + status: Literal["on", "off", 0, 1, True, False] |
| 3955 | + # Optional literal with default |
| 3956 | + mode: Optional[Literal["fast", "slow"]] = "fast" |
| 3957 | + |
| 3958 | + # Test valid assignments |
| 3959 | + s1 = StructWithLiterals(color="red", size=2, status="on") |
| 3960 | + self.assertEqual(s1.color, "red") |
| 3961 | + self.assertEqual(s1.size, 2) |
| 3962 | + self.assertEqual(s1.status, "on") |
| 3963 | + self.assertEqual(s1.mode, "fast") # Default value |
| 3964 | + |
| 3965 | + s2 = StructWithLiterals.from_dict(dict(color="blue", size=1, status=True, mode="slow")) |
| 3966 | + s2_dump = s2.to_json() |
| 3967 | + s2_looped = TypeAdapter(StructWithLiterals).validate_json(s2_dump) |
| 3968 | + self.assertEqual(s2, s2_looped) |
| 3969 | + s2_dict = s2.to_dict() |
| 3970 | + s2_looped_dict = s2.from_dict(s2_dict) |
| 3971 | + self.assertEqual(s2_looped_dict, s2) |
| 3972 | + |
| 3973 | + # Invalid color, but from_dict still accepts |
| 3974 | + StructWithLiterals.from_dict(dict(color="yellow", size=1, status="on")) |
| 3975 | + |
| 3976 | + # Invalid size but from_dict still accepts |
| 3977 | + StructWithLiterals.from_dict(dict(color="red", size=4, status="on")) |
| 3978 | + |
| 3979 | + # Invalid status but from_dict still accepts |
| 3980 | + StructWithLiterals.from_dict(dict(color="red", size=1, status="standby")) |
| 3981 | + |
| 3982 | + # Invalid mode but from_dict still accepts |
| 3983 | + StructWithLiterals.from_dict(dict(color="red", size=1, mode=12)) |
| 3984 | + |
| 3985 | + # Invalid size and since the literals are all the same type |
| 3986 | + # If we give an incorrect type, we catch the error |
| 3987 | + with self.assertRaises(ValueError) as exc_info: |
| 3988 | + StructWithLiterals.from_dict(dict(color="red", size="adasd", mode=12)) |
| 3989 | + self.assertIn("Expected type <class 'int'> received <class 'str'>", str(exc_info.exception)) |
| 3990 | + |
| 3991 | + # Test valid values |
| 3992 | + result = TypeAdapter(StructWithLiterals).validate_python({"color": "green", "size": 3, "status": 0}) |
| 3993 | + self.assertEqual(result.color, "green") |
| 3994 | + self.assertEqual(result.size, 3) |
| 3995 | + self.assertEqual(result.status, 0) |
| 3996 | + |
| 3997 | + # Test invalid color with Pydantic validation |
| 3998 | + with self.assertRaises(ValidationError) as exc_info: |
| 3999 | + TypeAdapter(StructWithLiterals).validate_python({"color": "yellow", "size": 1, "status": "on"}) |
| 4000 | + self.assertIn("1 validation error for", str(exc_info.exception)) |
| 4001 | + self.assertIn("color", str(exc_info.exception)) |
| 4002 | + |
| 4003 | + # Test invalid size with Pydantic validation |
| 4004 | + with self.assertRaises(ValidationError) as exc_info: |
| 4005 | + TypeAdapter(StructWithLiterals).validate_python({"color": "red", "size": 4, "status": "on"}) |
| 4006 | + self.assertIn("1 validation error for", str(exc_info.exception)) |
| 4007 | + self.assertIn("size", str(exc_info.exception)) |
| 4008 | + |
| 4009 | + # Test invalid status with Pydantic validation |
| 4010 | + with self.assertRaises(ValidationError) as exc_info: |
| 4011 | + TypeAdapter(StructWithLiterals).validate_python({"color": "red", "size": 1, "status": "standby"}) |
| 4012 | + self.assertIn("1 validation error for", str(exc_info.exception)) |
| 4013 | + self.assertIn("status", str(exc_info.exception)) |
| 4014 | + |
| 4015 | + # Test invalid mode with Pydantic validation |
| 4016 | + with self.assertRaises(ValidationError) as exc_info: |
| 4017 | + TypeAdapter(StructWithLiterals).validate_python( |
| 4018 | + {"color": "red", "size": 1, "status": "on", "mode": "medium"} |
| 4019 | + ) |
| 4020 | + self.assertIn("1 validation error for", str(exc_info.exception)) |
| 4021 | + self.assertIn("mode", str(exc_info.exception)) |
| 4022 | + |
| 4023 | + def test_pipe_operator_types(self): |
| 4024 | + """Test using the pipe operator for union types in Python 3.10+""" |
| 4025 | + if sys.version_info >= (3, 10): # Only run on Python 3.10+ |
| 4026 | + # Define a class using various pipe operator combinations |
| 4027 | + class PipeTypesConfig(csp.Struct): |
| 4028 | + # Basic primitive types with pipe |
| 4029 | + id_field: str | int |
| 4030 | + # Pipe with None (similar to Optional) |
| 4031 | + description: str | None = None |
| 4032 | + # Multiple types with pipe |
| 4033 | + value: str | int | float | bool |
| 4034 | + # Container with pipe |
| 4035 | + tags: List[str] | Dict[str, str] | None = None |
| 4036 | + # Pipe with literal for comparison |
| 4037 | + status: Literal["active", "inactive"] | None = "active" |
| 4038 | + |
| 4039 | + # Test all valid types |
| 4040 | + valid_cases = [ |
| 4041 | + {"id_field": "string_id", "value": "string_value"}, |
| 4042 | + {"id_field": 42, "value": 123}, |
| 4043 | + {"id_field": "mixed", "value": 3.14}, |
| 4044 | + {"id_field": 999, "value": True}, |
| 4045 | + {"id_field": "with_desc", "value": 1, "description": "Description"}, |
| 4046 | + {"id_field": "with_dict", "value": 1, "tags": None}, |
| 4047 | + ] |
| 4048 | + |
| 4049 | + for case in valid_cases: |
| 4050 | + result = PipeTypesConfig.from_dict(case) |
| 4051 | + # use the other route to get back the result |
| 4052 | + result_to_dict_loop = TypeAdapter(PipeTypesConfig).validate_python(result.to_dict()) |
| 4053 | + self.assertEqual(result, result_to_dict_loop) |
| 4054 | + |
| 4055 | + # Test invalid values |
| 4056 | + invalid_cases = [ |
| 4057 | + {"id_field": 3.14, "value": 1}, # Float for id_field |
| 4058 | + {"id_field": None, "value": 1}, # None for required id_field |
| 4059 | + {"id_field": "test", "value": {}}, # Dict for value |
| 4060 | + {"id_field": "test", "value": None}, # None for required value |
| 4061 | + {"id_field": "test", "value": 1, "status": "unknown"}, # Invalid literal |
| 4062 | + ] |
| 4063 | + for case in invalid_cases: |
| 4064 | + with self.assertRaises(ValidationError): |
| 4065 | + TypeAdapter(PipeTypesConfig).validate_python(case) |
| 4066 | + |
3943 | 4067 |
|
3944 | 4068 | if __name__ == "__main__": |
3945 | 4069 | unittest.main() |
0 commit comments