Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions python/packages/core/tests/core/test_memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.

import sys
from collections.abc import MutableSequence
from typing import Any

Expand Down Expand Up @@ -44,6 +45,18 @@ async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], *
return context


class MinimalContextProvider(ContextProvider):
"""Minimal ContextProvider that only implements the required abstract method.

Used to test the base class default implementations of thread_created,
invoked, __aenter__, and __aexit__.
"""

async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context:
"""Return empty context."""
return Context()


class TestContext:
"""Tests for Context class."""

Expand Down Expand Up @@ -91,3 +104,33 @@ async def test_invoking(self) -> None:
assert context.messages is not None
assert len(context.messages) == 1
assert context.messages[0].text == "Context message"

async def test_base_thread_created_does_nothing(self) -> None:
"""Test that base ContextProvider.thread_created does nothing by default."""
provider = MinimalContextProvider()
await provider.thread_created("some-thread-id")
await provider.thread_created(None)

async def test_base_invoked_does_nothing(self) -> None:
"""Test that base ContextProvider.invoked does nothing by default."""
provider = MinimalContextProvider()
message = ChatMessage(role=Role.USER, text="Test")
await provider.invoked(message)
await provider.invoked(message, response_messages=message)
await provider.invoked(message, invoke_exception=Exception("test"))

async def test_base_aenter_returns_self(self) -> None:
"""Test that base ContextProvider.__aenter__ returns self."""
provider = MinimalContextProvider()
async with provider as p:
assert p is provider

async def test_base_aexit_does_nothing(self) -> None:
"""Test that base ContextProvider.__aexit__ handles exceptions gracefully."""
provider = MinimalContextProvider()
await provider.__aexit__(None, None, None)
try:
raise ValueError("test error")
except ValueError:
exc_info = sys.exc_info()
await provider.__aexit__(exc_info[0], exc_info[1], exc_info[2])
237 changes: 237 additions & 0 deletions python/packages/core/tests/core/test_serializable_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,240 @@ def __init__(self, value: str, number: int, client: Any = None):
assert restored.value == "test"
assert restored.number == 42
assert restored.client == mock_client

def test_exclude_none_in_to_dict(self):
"""Test that exclude_none parameter removes None values from to_dict()."""

class TestClass(SerializationMixin):
def __init__(self, value: str, optional: str | None = None):
self.value = value
self.optional = optional

obj = TestClass(value="test", optional=None)
data = obj.to_dict(exclude_none=True)

assert data["value"] == "test"
assert "optional" not in data

def test_to_dict_with_nested_serialization_protocol(self):
"""Test to_dict handles nested SerializationProtocol objects."""

class InnerClass(SerializationMixin):
def __init__(self, inner_value: str):
self.inner_value = inner_value

class OuterClass(SerializationMixin):
def __init__(self, outer_value: str, inner: Any = None):
self.outer_value = outer_value
self.inner = inner

inner = InnerClass(inner_value="inner_test")
outer = OuterClass(outer_value="outer_test", inner=inner)
data = outer.to_dict()

assert data["outer_value"] == "outer_test"
assert data["inner"]["inner_value"] == "inner_test"

def test_to_dict_with_list_of_serialization_protocol(self):
"""Test to_dict handles lists containing SerializationProtocol objects."""

class ItemClass(SerializationMixin):
def __init__(self, name: str):
self.name = name

class ContainerClass(SerializationMixin):
def __init__(self, items: list):
self.items = items

items = [ItemClass(name="item1"), ItemClass(name="item2")]
container = ContainerClass(items=items)
data = container.to_dict()

assert len(data["items"]) == 2
assert data["items"][0]["name"] == "item1"
assert data["items"][1]["name"] == "item2"

def test_to_dict_skips_non_serializable_in_list(self, caplog):
"""Test to_dict skips non-serializable items in lists with debug logging."""

class NonSerializable:
pass

class TestClass(SerializationMixin):
def __init__(self, items: list):
self.items = items

obj = TestClass(items=["serializable", NonSerializable()])

with caplog.at_level(logging.DEBUG):
data = obj.to_dict()

# Should only contain the serializable item
assert len(data["items"]) == 1
assert data["items"][0] == "serializable"

def test_to_dict_with_dict_containing_serialization_protocol(self):
"""Test to_dict handles dicts containing SerializationProtocol values."""

class ItemClass(SerializationMixin):
def __init__(self, name: str):
self.name = name

class ContainerClass(SerializationMixin):
def __init__(self, items_dict: dict):
self.items_dict = items_dict

items = {"a": ItemClass(name="item1"), "b": ItemClass(name="item2")}
container = ContainerClass(items_dict=items)
data = container.to_dict()

assert data["items_dict"]["a"]["name"] == "item1"
assert data["items_dict"]["b"]["name"] == "item2"

def test_to_dict_with_datetime_in_dict(self):
"""Test to_dict converts datetime objects in dicts to strings."""
from datetime import datetime

class TestClass(SerializationMixin):
def __init__(self, metadata: dict):
self.metadata = metadata

now = datetime(2025, 1, 27, 12, 0, 0)
obj = TestClass(metadata={"created_at": now})
data = obj.to_dict()

assert isinstance(data["metadata"]["created_at"], str)

def test_to_dict_skips_non_serializable_in_dict(self, caplog):
"""Test to_dict skips non-serializable values in dicts with debug logging."""

class NonSerializable:
pass

class TestClass(SerializationMixin):
def __init__(self, metadata: dict):
self.metadata = metadata

obj = TestClass(metadata={"valid": "value", "invalid": NonSerializable()})

with caplog.at_level(logging.DEBUG):
data = obj.to_dict()

assert data["metadata"]["valid"] == "value"
assert "invalid" not in data["metadata"]

def test_to_dict_skips_non_serializable_attributes(self, caplog):
"""Test to_dict skips non-serializable top-level attributes."""

class TestClass(SerializationMixin):
def __init__(self, value: str, func: Any = None):
self.value = value
self.func = func

obj = TestClass(value="test", func=lambda x: x)

with caplog.at_level(logging.DEBUG):
data = obj.to_dict()

assert data["value"] == "test"
assert "func" not in data

def test_from_dict_without_type_in_data(self):
"""Test from_dict uses class TYPE when no type field in data."""

class TestClass(SerializationMixin):
TYPE = "my_custom_type"

def __init__(self, value: str):
self.value = value

# Data without 'type' field - class TYPE should be used for type identifier
data = {"value": "test"}

obj = TestClass.from_dict(data)
assert obj.value == "test"

# Verify to_dict includes the type
out = obj.to_dict()
assert out["type"] == "my_custom_type"

def test_from_json(self):
"""Test from_json deserializes JSON string."""

class TestClass(SerializationMixin):
def __init__(self, value: str):
self.value = value

json_str = '{"type": "test_class", "value": "test_value"}'
obj = TestClass.from_json(json_str)

assert obj.value == "test_value"

def test_get_type_identifier_with_instance_type(self):
"""Test _get_type_identifier uses instance 'type' attribute."""

class TestClass(SerializationMixin):
def __init__(self, value: str):
self.value = value
self.type = "custom_type"

obj = TestClass(value="test")
data = obj.to_dict()

assert data["type"] == "custom_type"

def test_get_type_identifier_with_class_TYPE(self):
"""Test _get_type_identifier uses class TYPE constant."""

class TestClass(SerializationMixin):
TYPE = "class_level_type"

def __init__(self, value: str):
self.value = value

obj = TestClass(value="test")
data = obj.to_dict()

assert data["type"] == "class_level_type"

def test_instance_specific_dependency_injection(self):
"""Test instance-specific dependency injection with field:name format."""

class TestClass(SerializationMixin):
INJECTABLE = {"config"}

def __init__(self, name: str, config: Any = None):
self.name = name
self.config = config

dependencies = {
"test_class": {
"name:special_instance": {"config": "special_config"},
}
}

# This should match the instance-specific dependency
obj = TestClass.from_dict({"type": "test_class", "name": "special_instance"}, dependencies=dependencies)

assert obj.name == "special_instance"
assert obj.config == "special_config"

def test_dependency_dict_merging(self):
"""Test that dict dependencies are merged with existing dict kwargs."""

class TestClass(SerializationMixin):
INJECTABLE = {"options"}

def __init__(self, value: str, options: dict | None = None):
self.value = value
self.options = options or {}

# Existing options in data
data = {"type": "test_class", "value": "test", "options": {"existing": "value"}}
# Additional options from dependencies
dependencies = {"test_class": {"options": {"injected": "option"}}}

obj = TestClass.from_dict(data, dependencies=dependencies)

assert obj.options["existing"] == "value"
assert obj.options["injected"] == "option"
Loading
Loading