diff --git a/localstack_snapshot/snapshots/prototype.py b/localstack_snapshot/snapshots/prototype.py index 533f285..e7dde9a 100644 --- a/localstack_snapshot/snapshots/prototype.py +++ b/localstack_snapshot/snapshots/prototype.py @@ -2,7 +2,9 @@ import json import logging import os +from collections.abc import Iterator from datetime import datetime, timezone +from enum import Enum from json import JSONDecodeError from pathlib import Path from re import Pattern @@ -167,16 +169,27 @@ def _update(self, key: str, obj_state: dict) -> None: def match_object(self, key: str, obj: object) -> None: def _convert_object_to_dict(obj_): if isinstance(obj_, dict): - for key in list(obj_.keys()): - if key.startswith("_"): - del obj_[key] - else: - obj_[key] = _convert_object_to_dict(obj_[key]) - elif isinstance(obj_, list): - for idx, val in enumerate(obj_): - obj_[idx] = _convert_object_to_dict(val) + # Serialize the values of the dictionary, while skipping any private keys (starting with '_') + return { + key_: _convert_object_to_dict(obj_[key_]) + for key_ in obj_ + if not key_.startswith("_") + } + elif isinstance(obj_, (list, Iterator)): + return [_convert_object_to_dict(val) for val in obj_] + elif isinstance(obj_, Enum): + return obj_.value elif hasattr(obj_, "__dict__"): - return _convert_object_to_dict(obj_.__dict__) + # This is an object - let's try to convert it to a dictionary + # A naive approach would be to use the '__dict__' object directly, but that only lists the attributes + # In order to also serialize the properties, we use the __dir__() method + # Filtering by everything that is not a method gives us both attributes and properties + # We also (still) skip private attributes/properties, so everything that starts with an underscore + return { + k: _convert_object_to_dict(getattr(obj_, k)) + for k in obj_.__dir__() + if not k.startswith("_") and type(getattr(obj_, k, "")).__name__ != "method" + } return obj_ return self.match(key, _convert_object_to_dict(obj)) diff --git a/tests/test_snapshots.py b/tests/test_snapshots.py index e4436fe..398e844 100644 --- a/tests/test_snapshots.py +++ b/tests/test_snapshots.py @@ -1,4 +1,5 @@ import io +from enum import Enum import pytest @@ -75,6 +76,67 @@ def __init__(self, name): sm.match_object("key_a", CustomObject(name="myname")) sm._assert_all() + def test_match_object_lists_and_iterators(self): + class CustomObject: + def __init__(self, name): + self.name = name + self.my_list = [9, 8, 7, 6, 5] + self.my_iterator = (x for x in range(5)) + + sm = SnapshotSession(scope_key="A", verify=True, base_file_path="", update=False) + sm.recorded_state = { + "key_a": {"name": "myname", "my_iterator": [0, 1, 2, 3, 4], "my_list": [9, 8, 7, 6, 5]} + } + sm.match_object("key_a", CustomObject(name="myname")) + sm._assert_all() + + def test_match_object_include_properties(self): + class CustomObject: + def __init__(self, name): + self.name = name + self._internal = "n/a" + + def some_method(self): + # method should not be serialized + return False + + @property + def some_prop(self): + # properties should be serialized + return True + + @property + def some_iterator(self): + for i in range(3): + yield i + + @property + def _private_prop(self): + # private properties should be ignored + return False + + sm = SnapshotSession(scope_key="A", verify=True, base_file_path="", update=False) + sm.recorded_state = { + "key_a": {"name": "myname", "some_prop": True, "some_iterator": [0, 1, 2]} + } + sm.match_object("key_a", CustomObject(name="myname")) + sm._assert_all() + + def test_match_object_enums(self): + class TestEnum(Enum): + value1 = "Value 1" + value2 = "Value 2" + + class CustomObject: + def __init__(self, name): + self.name = name + self.my_enum = TestEnum.value2 + + sm = SnapshotSession(scope_key="A", verify=True, base_file_path="", update=False) + sm.recorded_state = {"key_a": {"name": "myname", "my_enum": "Value 2"}} + sm.match_object("key_a", CustomObject(name="myname")) + sm._assert_all() + def test_match_object_change(self): class CustomObject: def __init__(self, name):