diff --git a/python/packages/core/tests/core/test_memory.py b/python/packages/core/tests/core/test_memory.py index 6cc7ba436e..bcc299ed37 100644 --- a/python/packages/core/tests/core/test_memory.py +++ b/python/packages/core/tests/core/test_memory.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. +import sys from collections.abc import MutableSequence from typing import Any @@ -44,6 +45,18 @@ async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], * return context +class MinimalContextProvider(ContextProvider): + """Minimal ContextProvider that only implements the required abstract method. + + Used to test the base class default implementations of thread_created, + invoked, __aenter__, and __aexit__. + """ + + async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context: + """Return empty context.""" + return Context() + + class TestContext: """Tests for Context class.""" @@ -91,3 +104,33 @@ async def test_invoking(self) -> None: assert context.messages is not None assert len(context.messages) == 1 assert context.messages[0].text == "Context message" + + async def test_base_thread_created_does_nothing(self) -> None: + """Test that base ContextProvider.thread_created does nothing by default.""" + provider = MinimalContextProvider() + await provider.thread_created("some-thread-id") + await provider.thread_created(None) + + async def test_base_invoked_does_nothing(self) -> None: + """Test that base ContextProvider.invoked does nothing by default.""" + provider = MinimalContextProvider() + message = ChatMessage(role=Role.USER, text="Test") + await provider.invoked(message) + await provider.invoked(message, response_messages=message) + await provider.invoked(message, invoke_exception=Exception("test")) + + async def test_base_aenter_returns_self(self) -> None: + """Test that base ContextProvider.__aenter__ returns self.""" + provider = MinimalContextProvider() + async with provider as p: + assert p is provider + + async def test_base_aexit_does_nothing(self) -> None: + """Test that base ContextProvider.__aexit__ handles exceptions gracefully.""" + provider = MinimalContextProvider() + await provider.__aexit__(None, None, None) + try: + raise ValueError("test error") + except ValueError: + exc_info = sys.exc_info() + await provider.__aexit__(exc_info[0], exc_info[1], exc_info[2]) diff --git a/python/packages/core/tests/core/test_serializable_mixin.py b/python/packages/core/tests/core/test_serializable_mixin.py index 0472f881cf..05ece1072b 100644 --- a/python/packages/core/tests/core/test_serializable_mixin.py +++ b/python/packages/core/tests/core/test_serializable_mixin.py @@ -190,3 +190,240 @@ def __init__(self, value: str, number: int, client: Any = None): assert restored.value == "test" assert restored.number == 42 assert restored.client == mock_client + + def test_exclude_none_in_to_dict(self): + """Test that exclude_none parameter removes None values from to_dict().""" + + class TestClass(SerializationMixin): + def __init__(self, value: str, optional: str | None = None): + self.value = value + self.optional = optional + + obj = TestClass(value="test", optional=None) + data = obj.to_dict(exclude_none=True) + + assert data["value"] == "test" + assert "optional" not in data + + def test_to_dict_with_nested_serialization_protocol(self): + """Test to_dict handles nested SerializationProtocol objects.""" + + class InnerClass(SerializationMixin): + def __init__(self, inner_value: str): + self.inner_value = inner_value + + class OuterClass(SerializationMixin): + def __init__(self, outer_value: str, inner: Any = None): + self.outer_value = outer_value + self.inner = inner + + inner = InnerClass(inner_value="inner_test") + outer = OuterClass(outer_value="outer_test", inner=inner) + data = outer.to_dict() + + assert data["outer_value"] == "outer_test" + assert data["inner"]["inner_value"] == "inner_test" + + def test_to_dict_with_list_of_serialization_protocol(self): + """Test to_dict handles lists containing SerializationProtocol objects.""" + + class ItemClass(SerializationMixin): + def __init__(self, name: str): + self.name = name + + class ContainerClass(SerializationMixin): + def __init__(self, items: list): + self.items = items + + items = [ItemClass(name="item1"), ItemClass(name="item2")] + container = ContainerClass(items=items) + data = container.to_dict() + + assert len(data["items"]) == 2 + assert data["items"][0]["name"] == "item1" + assert data["items"][1]["name"] == "item2" + + def test_to_dict_skips_non_serializable_in_list(self, caplog): + """Test to_dict skips non-serializable items in lists with debug logging.""" + + class NonSerializable: + pass + + class TestClass(SerializationMixin): + def __init__(self, items: list): + self.items = items + + obj = TestClass(items=["serializable", NonSerializable()]) + + with caplog.at_level(logging.DEBUG): + data = obj.to_dict() + + # Should only contain the serializable item + assert len(data["items"]) == 1 + assert data["items"][0] == "serializable" + + def test_to_dict_with_dict_containing_serialization_protocol(self): + """Test to_dict handles dicts containing SerializationProtocol values.""" + + class ItemClass(SerializationMixin): + def __init__(self, name: str): + self.name = name + + class ContainerClass(SerializationMixin): + def __init__(self, items_dict: dict): + self.items_dict = items_dict + + items = {"a": ItemClass(name="item1"), "b": ItemClass(name="item2")} + container = ContainerClass(items_dict=items) + data = container.to_dict() + + assert data["items_dict"]["a"]["name"] == "item1" + assert data["items_dict"]["b"]["name"] == "item2" + + def test_to_dict_with_datetime_in_dict(self): + """Test to_dict converts datetime objects in dicts to strings.""" + from datetime import datetime + + class TestClass(SerializationMixin): + def __init__(self, metadata: dict): + self.metadata = metadata + + now = datetime(2025, 1, 27, 12, 0, 0) + obj = TestClass(metadata={"created_at": now}) + data = obj.to_dict() + + assert isinstance(data["metadata"]["created_at"], str) + + def test_to_dict_skips_non_serializable_in_dict(self, caplog): + """Test to_dict skips non-serializable values in dicts with debug logging.""" + + class NonSerializable: + pass + + class TestClass(SerializationMixin): + def __init__(self, metadata: dict): + self.metadata = metadata + + obj = TestClass(metadata={"valid": "value", "invalid": NonSerializable()}) + + with caplog.at_level(logging.DEBUG): + data = obj.to_dict() + + assert data["metadata"]["valid"] == "value" + assert "invalid" not in data["metadata"] + + def test_to_dict_skips_non_serializable_attributes(self, caplog): + """Test to_dict skips non-serializable top-level attributes.""" + + class TestClass(SerializationMixin): + def __init__(self, value: str, func: Any = None): + self.value = value + self.func = func + + obj = TestClass(value="test", func=lambda x: x) + + with caplog.at_level(logging.DEBUG): + data = obj.to_dict() + + assert data["value"] == "test" + assert "func" not in data + + def test_from_dict_without_type_in_data(self): + """Test from_dict uses class TYPE when no type field in data.""" + + class TestClass(SerializationMixin): + TYPE = "my_custom_type" + + def __init__(self, value: str): + self.value = value + + # Data without 'type' field - class TYPE should be used for type identifier + data = {"value": "test"} + + obj = TestClass.from_dict(data) + assert obj.value == "test" + + # Verify to_dict includes the type + out = obj.to_dict() + assert out["type"] == "my_custom_type" + + def test_from_json(self): + """Test from_json deserializes JSON string.""" + + class TestClass(SerializationMixin): + def __init__(self, value: str): + self.value = value + + json_str = '{"type": "test_class", "value": "test_value"}' + obj = TestClass.from_json(json_str) + + assert obj.value == "test_value" + + def test_get_type_identifier_with_instance_type(self): + """Test _get_type_identifier uses instance 'type' attribute.""" + + class TestClass(SerializationMixin): + def __init__(self, value: str): + self.value = value + self.type = "custom_type" + + obj = TestClass(value="test") + data = obj.to_dict() + + assert data["type"] == "custom_type" + + def test_get_type_identifier_with_class_TYPE(self): + """Test _get_type_identifier uses class TYPE constant.""" + + class TestClass(SerializationMixin): + TYPE = "class_level_type" + + def __init__(self, value: str): + self.value = value + + obj = TestClass(value="test") + data = obj.to_dict() + + assert data["type"] == "class_level_type" + + def test_instance_specific_dependency_injection(self): + """Test instance-specific dependency injection with field:name format.""" + + class TestClass(SerializationMixin): + INJECTABLE = {"config"} + + def __init__(self, name: str, config: Any = None): + self.name = name + self.config = config + + dependencies = { + "test_class": { + "name:special_instance": {"config": "special_config"}, + } + } + + # This should match the instance-specific dependency + obj = TestClass.from_dict({"type": "test_class", "name": "special_instance"}, dependencies=dependencies) + + assert obj.name == "special_instance" + assert obj.config == "special_config" + + def test_dependency_dict_merging(self): + """Test that dict dependencies are merged with existing dict kwargs.""" + + class TestClass(SerializationMixin): + INJECTABLE = {"options"} + + def __init__(self, value: str, options: dict | None = None): + self.value = value + self.options = options or {} + + # Existing options in data + data = {"type": "test_class", "value": "test", "options": {"existing": "value"}} + # Additional options from dependencies + dependencies = {"test_class": {"options": {"injected": "option"}}} + + obj = TestClass.from_dict(data, dependencies=dependencies) + + assert obj.options["existing"] == "value" + assert obj.options["injected"] == "option" diff --git a/python/packages/core/tests/core/test_threads.py b/python/packages/core/tests/core/test_threads.py index 492ed11519..01d5ceb98f 100644 --- a/python/packages/core/tests/core/test_threads.py +++ b/python/packages/core/tests/core/test_threads.py @@ -446,3 +446,155 @@ def test_init_with_chat_message_store_state_no_messages(self) -> None: assert state.service_thread_id is None assert state.chat_message_store_state is not None assert state.chat_message_store_state.messages == [] + + def test_init_with_chat_message_store_state_object(self) -> None: + """Test AgentThreadState initialization with ChatMessageStoreState object.""" + store_state = ChatMessageStoreState(messages=[ChatMessage(role=Role.USER, text="test")]) + state = AgentThreadState(chat_message_store_state=store_state) + + assert state.service_thread_id is None + assert state.chat_message_store_state is store_state + assert len(state.chat_message_store_state.messages) == 1 + + def test_init_with_invalid_chat_message_store_state_type(self) -> None: + """Test AgentThreadState initialization with invalid chat_message_store_state type.""" + with pytest.raises(TypeError, match="Could not parse ChatMessageStoreState"): + AgentThreadState(chat_message_store_state="invalid_type") # type: ignore[arg-type] + + +class TestChatMessageStoreStateEdgeCases: + """Additional edge case tests for ChatMessageStoreState.""" + + def test_init_with_invalid_messages_type(self) -> None: + """Test ChatMessageStoreState initialization with invalid messages type.""" + with pytest.raises(TypeError, match="Messages should be a list"): + ChatMessageStoreState(messages="invalid") # type: ignore[arg-type] + + def test_init_with_dict_messages(self) -> None: + """Test ChatMessageStoreState initialization with dict messages.""" + messages = [ + {"role": "user", "text": "Hello"}, + {"role": "assistant", "text": "Hi there!"}, + ] + state = ChatMessageStoreState(messages=messages) + + assert len(state.messages) == 2 + assert isinstance(state.messages[0], ChatMessage) + assert state.messages[0].text == "Hello" + + +class TestChatMessageStoreEdgeCases: + """Additional edge case tests for ChatMessageStore.""" + + async def test_deserialize_class_method(self) -> None: + """Test ChatMessageStore.deserialize class method.""" + serialized_data = { + "messages": [ + {"role": "user", "text": "Hello", "message_id": "msg1"}, + ] + } + + store = await ChatMessageStore.deserialize(serialized_data) + + assert isinstance(store, ChatMessageStore) + messages = await store.list_messages() + assert len(messages) == 1 + assert messages[0].text == "Hello" + + async def test_deserialize_empty_state(self) -> None: + """Test ChatMessageStore.deserialize with empty state.""" + serialized_data: dict[str, Any] = {"messages": []} + + store = await ChatMessageStore.deserialize(serialized_data) + + assert isinstance(store, ChatMessageStore) + messages = await store.list_messages() + assert len(messages) == 0 + + +class TestAgentThreadEdgeCases: + """Additional edge case tests for AgentThread.""" + + def test_is_initialized_with_service_thread_id(self) -> None: + """Test is_initialized property when service_thread_id is set.""" + thread = AgentThread(service_thread_id="test-123") + assert thread.is_initialized is True + + def test_is_initialized_with_message_store(self) -> None: + """Test is_initialized property when message_store is set.""" + store = ChatMessageStore() + thread = AgentThread(message_store=store) + assert thread.is_initialized is True + + def test_is_initialized_with_nothing(self) -> None: + """Test is_initialized property when nothing is set.""" + thread = AgentThread() + assert thread.is_initialized is False + + async def test_deserialize_with_custom_message_store(self) -> None: + """Test deserialize using a custom message store.""" + serialized_data = { + "service_thread_id": None, + "chat_message_store_state": { + "messages": [{"role": "user", "text": "Hello"}], + }, + } + custom_store = MockChatMessageStore() + + thread = await AgentThread.deserialize(serialized_data, message_store=custom_store) + + assert thread.message_store is custom_store + messages = await custom_store.list_messages() + assert len(messages) == 1 + + async def test_deserialize_with_failing_message_store_raises(self) -> None: + """Test deserialize raises AgentThreadException when message store fails.""" + + class FailingStore: + async def add_messages(self, messages: Sequence[ChatMessage], **kwargs: Any) -> None: + raise RuntimeError("Store failed") + + serialized_data = { + "service_thread_id": None, + "chat_message_store_state": { + "messages": [{"role": "user", "text": "Hello"}], + }, + } + failing_store = FailingStore() + + with pytest.raises(AgentThreadException, match="Failed to deserialize"): + await AgentThread.deserialize(serialized_data, message_store=failing_store) + + async def test_update_from_thread_state_with_service_thread_id(self) -> None: + """Test update_from_thread_state sets service_thread_id.""" + thread = AgentThread() + serialized_data = {"service_thread_id": "new-thread-id"} + + await thread.update_from_thread_state(serialized_data) + + assert thread.service_thread_id == "new-thread-id" + + async def test_update_from_thread_state_with_empty_chat_state(self) -> None: + """Test update_from_thread_state with empty chat_message_store_state.""" + thread = AgentThread() + serialized_data = {"service_thread_id": None, "chat_message_store_state": None} + + await thread.update_from_thread_state(serialized_data) + + assert thread.message_store is None + + async def test_update_from_thread_state_creates_message_store(self) -> None: + """Test update_from_thread_state creates message store if not existing.""" + thread = AgentThread() + serialized_data = { + "service_thread_id": None, + "chat_message_store_state": { + "messages": [{"role": "user", "text": "Hello"}], + }, + } + + await thread.update_from_thread_state(serialized_data) + + assert thread.message_store is not None + messages = await thread.message_store.list_messages() + assert len(messages) == 1