diff --git a/Makefile b/Makefile index 93bc25332..9509d0dbe 100644 --- a/Makefile +++ b/Makefile @@ -40,7 +40,7 @@ snapshots-create: .PHONY: old_version_tests old_version_tests: - UV_PROJECT_ENVIRONMENT=.venv_39 uv run --python 3.9 -m pytest + UV_PROJECT_ENVIRONMENT=.venv_39 uv run --python 3.9 --all-extras -m pytest .PHONY: build-docs build-docs: diff --git a/docs/ref/extensions/memory.md b/docs/ref/extensions/memory.md new file mode 100644 index 000000000..e394b10b8 --- /dev/null +++ b/docs/ref/extensions/memory.md @@ -0,0 +1,7 @@ +# `Memory Extensions` + +::: agents.extensions.memory + + options: + members: + - PostgreSQLSession diff --git a/docs/sessions.md b/docs/sessions.md index c66cb85ae..a24d3d860 100644 --- a/docs/sessions.md +++ b/docs/sessions.md @@ -141,6 +141,27 @@ result = await Runner.run( ) ``` +### PostgreSQL memory + +```python +from agents.extensions.memory import PostgreSQLSession +from psycopg_pool import AsyncConnectionPool + +# From a connection string (creates a new connection pool) +session = await PostgreSQLSession.from_connection_string("user_123", "postgresql://user:pass@host/db") + +# From existing connection pool +pool = AsyncConnectionPool(connection_string) +session = PostgreSQLSession("user_123", pool) + +# Use the session +result = await Runner.run( + agent, + "Hello", + session=session +) +``` + ### Multiple sessions ```python @@ -222,7 +243,7 @@ Use meaningful session IDs that help you organize conversations: - Use in-memory SQLite (`SQLiteSession("session_id")`) for temporary conversations - Use file-based SQLite (`SQLiteSession("session_id", "path/to/db.sqlite")`) for persistent conversations -- Consider implementing custom session backends for production systems (Redis, PostgreSQL, etc.) +- Consider implementing custom session backends for production systems (Redis, etc.) ### Session management @@ -318,3 +339,4 @@ For detailed API documentation, see: - [`Session`][agents.memory.Session] - Protocol interface - [`SQLiteSession`][agents.memory.SQLiteSession] - SQLite implementation +- [`PostgreSQLSession`][agents.extensions.memory.PostgrSQLSession] - PostgreSQL implementation diff --git a/mkdocs.yml b/mkdocs.yml index be4976be4..0bca4e356 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -143,6 +143,7 @@ plugins: - ref/extensions/handoff_filters.md - ref/extensions/handoff_prompt.md - ref/extensions/litellm.md + - ref/extensions/memory.md - locale: ja name: 日本語 diff --git a/pyproject.toml b/pyproject.toml index 1cd4d683d..22e3162f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"] viz = ["graphviz>=0.17"] litellm = ["litellm>=1.67.4.post1, <2"] realtime = ["websockets>=15.0, <16"] +psycopg = ["psycopg[pool]>=3.2.9,<4"] [dependency-groups] dev = [ diff --git a/src/agents/extensions/memory/__init__.py b/src/agents/extensions/memory/__init__.py new file mode 100644 index 000000000..83ed6fecd --- /dev/null +++ b/src/agents/extensions/memory/__init__.py @@ -0,0 +1,3 @@ +from .postgres_session import PostgreSQLSession + +__all__ = ["PostgreSQLSession"] diff --git a/src/agents/extensions/memory/postgres_session.py b/src/agents/extensions/memory/postgres_session.py new file mode 100644 index 000000000..f3891e28f --- /dev/null +++ b/src/agents/extensions/memory/postgres_session.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import TYPE_CHECKING + +try: + import psycopg + from psycopg import sql + from psycopg.rows import class_row + from psycopg_pool import AsyncConnectionPool +except ImportError as _e: + raise ImportError( + "`psycopg` is required to use the PostgreSQLSession. You can install it via the optional " + "dependency group: `pip install 'openai-agents[psycopg]'`." + ) from _e + +if TYPE_CHECKING: + from agents.items import TResponseInputItem + +from agents.memory.session import Session + + +@dataclass +class MessageRow: + """Typed row for message queries.""" + + message_data: TResponseInputItem + + +@dataclass +class MessageWithIdRow: + """Typed row for message queries that include ID.""" + + id: int + message_data: TResponseInputItem + + +class PostgreSQLSession(Session): + """PostgreSQL-based implementation of session storage. + + This implementation stores conversation history in a PostgreSQL database. + Requires psycopg to be installed. + """ + + pool: AsyncConnectionPool + + def __init__( + self, + session_id: str, + pool: AsyncConnectionPool, + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + ): + """Initialize the PostgreSQL session. + + Args: + session_id: Unique identifier for the conversation session + pool: PostgreSQL connection pool instance. + This should be opened before passing to this class. + sessions_table: Name of the table to store session metadata. Defaults to + 'agent_sessions' + messages_table: Name of the table to store message data. Defaults to 'agent_messages' + """ + if psycopg is None: + raise ImportError( + "psycopg is required for PostgreSQL session storage. " + "Install with: pip install psycopg" + ) + + self.session_id = session_id + self.pool = pool + self.sessions_table = sessions_table + self.messages_table = messages_table + self._initialized = False + + @classmethod + async def from_connection_string( + cls, + session_id: str, + connection_string: str, + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + ) -> PostgreSQLSession: + """Create a PostgreSQL session from a connection string. + + Args: + session_id: Unique identifier for the conversation session + connection_string: PostgreSQL connection string (e.g., "postgresql://user:pass@host/db") + sessions_table: Name of the table to store session metadata. Defaults to + 'agent_sessions' + messages_table: Name of the table to store message data. Defaults to 'agent_messages' + + Returns: + PostgreSQLSession instance with a connection pool created from the connection string + """ + pool: AsyncConnectionPool = AsyncConnectionPool(connection_string, open=False) + await pool.open() + return cls(session_id, pool, sessions_table, messages_table) + + async def _ensure_initialized(self) -> None: + """Ensure the database schema is initialized.""" + if not self._initialized: + await self._init_db() + + async def _init_db(self) -> None: + """Initialize the database schema.""" + async with self.pool.connection() as conn: + async with conn.cursor() as cur: + # Create sessions table + query = sql.SQL(""" + CREATE TABLE IF NOT EXISTS {sessions_table} ( + session_id TEXT PRIMARY KEY, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """).format(sessions_table=sql.Identifier(self.sessions_table)) + await cur.execute(query) + + # Create messages table + query = sql.SQL(""" + CREATE TABLE IF NOT EXISTS {messages_table} ( + id SERIAL PRIMARY KEY, + session_id TEXT NOT NULL, + message_data JSONB NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) REFERENCES {sessions_table} (session_id) + ON DELETE CASCADE + ) + """).format( + messages_table=sql.Identifier(self.messages_table), + sessions_table=sql.Identifier(self.sessions_table), + ) + await cur.execute(query) + + # Create index for better performance + query = sql.SQL(""" + CREATE INDEX IF NOT EXISTS {index_name} + ON {messages_table} (session_id, created_at) + """).format( + index_name=sql.Identifier(f"idx_{self.messages_table}_session_id"), + messages_table=sql.Identifier(self.messages_table), + ) + await cur.execute(query) + + self._initialized = True + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + """Retrieve the conversation history for this session. + + Args: + limit: Maximum number of items to retrieve. If None, retrieves all items. + When specified, returns the latest N items in chronological order. + + Returns: + List of input items representing the conversation history + """ + await self._ensure_initialized() + + async with self.pool.connection() as conn: + async with conn.cursor(row_factory=class_row(MessageRow)) as cur: + if limit is None: + # Fetch all items in chronological order + query = sql.SQL(""" + SELECT message_data FROM {messages_table} + WHERE session_id = %s + ORDER BY created_at ASC + """).format(messages_table=sql.Identifier(self.messages_table)) + await cur.execute(query, (self.session_id,)) + else: + # Fetch the latest N items in chronological order + query = sql.SQL(""" + SELECT message_data FROM {messages_table} + WHERE session_id = %s + ORDER BY created_at DESC + LIMIT %s + """).format(messages_table=sql.Identifier(self.messages_table)) + await cur.execute(query, (self.session_id, limit)) + + rows = await cur.fetchall() + + items = [] + for row in rows: + try: + # PostgreSQL JSONB automatically handles deserialization + item = row.message_data + items.append(item) + except (AttributeError, TypeError): + # Skip invalid entries + continue + + # If we used LIMIT, reverse the items to get chronological order + if limit is not None: + items.reverse() + + return items + + async def add_items(self, items: list[TResponseInputItem]) -> None: + """Add new items to the conversation history. + + Args: + items: List of input items to add to the history + """ + if not items: + return + + await self._ensure_initialized() + + async with self.pool.connection() as conn: + async with conn.transaction(): + async with conn.cursor() as cur: + # Ensure session exists + query = sql.SQL(""" + INSERT INTO {sessions_table} (session_id) + VALUES (%s) + ON CONFLICT (session_id) DO NOTHING + """).format(sessions_table=sql.Identifier(self.sessions_table)) + await cur.execute(query, (self.session_id,)) + + # Add items + message_data = [(self.session_id, json.dumps(item)) for item in items] + query = sql.SQL(""" + INSERT INTO {messages_table} (session_id, message_data) + VALUES (%s, %s) + """).format(messages_table=sql.Identifier(self.messages_table)) + await cur.executemany(query, message_data) + + # Update session timestamp + query = sql.SQL(""" + UPDATE {sessions_table} + SET updated_at = CURRENT_TIMESTAMP + WHERE session_id = %s + """).format(sessions_table=sql.Identifier(self.sessions_table)) + await cur.execute(query, (self.session_id,)) + + async def pop_item(self) -> TResponseInputItem | None: + """Remove and return the most recent item from the session. + + Returns: + The most recent item if it exists, None if the session is empty + """ + await self._ensure_initialized() + + async with self.pool.connection() as conn: + async with conn.transaction(): + async with conn.cursor(row_factory=class_row(MessageRow)) as cur: + # Delete and return the most recent item in one query + query = sql.SQL(""" + DELETE FROM {messages_table} + WHERE id = ( + SELECT id FROM {messages_table} + WHERE session_id = %s + ORDER BY created_at DESC + LIMIT 1 + ) + RETURNING message_data + """).format(messages_table=sql.Identifier(self.messages_table)) + await cur.execute(query, (self.session_id,)) + + row = await cur.fetchone() + + if row is None: + return None + + try: + # PostgreSQL JSONB automatically handles deserialization + item = row.message_data + return item + except (AttributeError, TypeError): + # Return None for corrupted entries (already deleted) + return None + + async def clear_session(self) -> None: + """Clear all items for this session.""" + await self._ensure_initialized() + + async with self.pool.connection() as conn: + async with conn.transaction(): + async with conn.cursor() as cur: + query = sql.SQL(""" + DELETE FROM {messages_table} WHERE session_id = %s + """).format(messages_table=sql.Identifier(self.messages_table)) + await cur.execute(query, (self.session_id,)) + + query = sql.SQL(""" + DELETE FROM {sessions_table} WHERE session_id = %s + """).format(sessions_table=sql.Identifier(self.sessions_table)) + await cur.execute(query, (self.session_id,)) + + async def close(self) -> None: + """Close the database connection pool.""" + await self.pool.close() + self._initialized = False diff --git a/tests/test_postgresql_session.py b/tests/test_postgresql_session.py new file mode 100644 index 000000000..d756b6bcc --- /dev/null +++ b/tests/test_postgresql_session.py @@ -0,0 +1,559 @@ +import json +import unittest +from typing import Any, cast +from unittest.mock import AsyncMock, patch + +import pytest +from psycopg import AsyncConnection +from psycopg.rows import TupleRow +from psycopg_pool import AsyncConnectionPool + +from agents.extensions.memory import ( + PostgreSQLSession, +) +from agents.extensions.memory.postgres_session import MessageRow +from agents.items import TResponseInputItem + + +class AsyncContextManagerMock: + """Helper class to mock async context managers.""" + + def __init__(self, return_value): + self.return_value = return_value + + async def __aenter__(self): + return self.return_value + + async def __aexit__(self, _exc_type, _exc_val, _exc_tb): + return None + + +class TestPostgreSQLSession(unittest.IsolatedAsyncioTestCase): + """Test suite for PostgreSQLSession class.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.mock_pool = AsyncMock() + + # Make connection method return the async context manager directly, not a coroutine + def mock_connection() -> AsyncContextManagerMock: + return AsyncContextManagerMock(AsyncMock()) + + self.mock_pool.connection = mock_connection + self.session = PostgreSQLSession( + session_id="test_session_123", + pool=self.mock_pool, + sessions_table="test_sessions", + messages_table="test_messages", + ) + + def setup_connection_mock(self, mock_conn: AsyncMock, mock_cursor: AsyncMock) -> None: + """Helper to set up connection and cursor mocks properly.""" + + def mock_connection() -> AsyncContextManagerMock: + return AsyncContextManagerMock(mock_conn) + + self.mock_pool.connection = mock_connection + + def mock_cursor_method(*_args: Any, **_kwargs: Any) -> AsyncContextManagerMock: + return AsyncContextManagerMock(mock_cursor) + + mock_conn.cursor = mock_cursor_method + + def mock_transaction() -> AsyncContextManagerMock: + return AsyncContextManagerMock(None) + + mock_conn.transaction = mock_transaction + + def test_init_with_defaults(self): + """Test initialization with default table names.""" + mock_pool: AsyncConnectionPool[AsyncConnection[TupleRow]] = AsyncMock() + session = PostgreSQLSession("test", mock_pool) + self.assertEqual(session.session_id, "test") + self.assertEqual(session.pool, mock_pool) + self.assertEqual(session.sessions_table, "agent_sessions") + self.assertEqual(session.messages_table, "agent_messages") + self.assertFalse(session._initialized) + + async def test_ensure_initialized_once(self): + """Test that database initialization happens only once.""" + + async def mock_init_db(): + self.session._initialized = True + + with patch.object(self.session, "_init_db", side_effect=mock_init_db) as mock_init: + await self.session._ensure_initialized() + await self.session._ensure_initialized() + + # Should only be called once due to the _initialized flag + mock_init.assert_called_once() + + async def test_init_db_creates_tables(self): + """Test that database initialization creates necessary tables.""" + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + + # Set up context managers + self.setup_connection_mock(mock_conn, mock_cursor) + + await self.session._init_db() + + # Check that execute was called for sessions table, messages table, and index + self.assertEqual(mock_cursor.execute.call_count, 3) + + # Verify sessions table creation + sessions_call = mock_cursor.execute.call_args_list[0][0][0] + sessions_call_str = str(sessions_call) + self.assertIn("CREATE TABLE IF NOT EXISTS", sessions_call_str) + self.assertIn("session_id TEXT PRIMARY KEY", sessions_call_str) + + # Verify messages table creation + messages_call = mock_cursor.execute.call_args_list[1][0][0] + messages_call_str = str(messages_call) + self.assertIn("CREATE TABLE IF NOT EXISTS", messages_call_str) + self.assertIn("message_data JSONB NOT NULL", messages_call_str) + self.assertIn("FOREIGN KEY (session_id) REFERENCES", messages_call_str) + + # Verify index creation + index_call = mock_cursor.execute.call_args_list[2][0][0] + index_call_str = str(index_call) + self.assertIn("CREATE INDEX IF NOT EXISTS", index_call_str) + + self.assertTrue(self.session._initialized) + + async def test_get_items_no_limit(self): + """Test getting all items without limit.""" + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + + # Set up context managers + self.setup_connection_mock(mock_conn, mock_cursor) + + # Mock fetchall to return test data + test_data = [ + MessageRow(message_data={"role": "user", "content": "Hello"}), + MessageRow(message_data={"role": "assistant", "content": "Hi there"}), + ] + mock_cursor.fetchall.return_value = test_data + + with patch.object(self.session, "_ensure_initialized", new_callable=AsyncMock): + result = await self.session.get_items() + + self.assertEqual(len(result), 2) + self.assertEqual(result[0], {"role": "user", "content": "Hello"}) + self.assertEqual(result[1], {"role": "assistant", "content": "Hi there"}) + + # Verify query was called correctly + mock_cursor.execute.assert_called_once() + query_call = mock_cursor.execute.call_args[0][0] + query_call_str = str(query_call) + self.assertIn("SELECT message_data FROM", query_call_str) + self.assertIn("ORDER BY created_at ASC", query_call_str) + self.assertNotIn("LIMIT", query_call_str) + + async def test_get_items_with_limit(self): + """Test getting items with a limit.""" + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + + # Set up context managers + self.setup_connection_mock(mock_conn, mock_cursor) + + test_data = [MessageRow(message_data={"role": "user", "content": "Hello"})] + mock_cursor.fetchall.return_value = test_data + + with patch.object(self.session, "_ensure_initialized", new_callable=AsyncMock): + result = await self.session.get_items(limit=5) + + self.assertEqual(len(result), 1) + + # Verify query includes limit and uses subquery + query_call = mock_cursor.execute.call_args[0][0] + query_call_str = str(query_call) + self.assertIn("LIMIT %s", query_call_str) + self.assertIn("ORDER BY created_at DESC", query_call_str) + self.assertEqual(mock_cursor.execute.call_args[0][1], ("test_session_123", 5)) + + async def test_get_items_handles_invalid_data(self): + """Test that get_items handles invalid data gracefully.""" + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + + # Set up context managers + self.setup_connection_mock(mock_conn, mock_cursor) + + # Mix of valid and invalid data + test_data = [ + MessageRow(message_data={"role": "user", "content": "Hello"}), + None, # This should be skipped due to AttributeError + MessageRow(message_data={"role": "assistant", "content": "Hi"}), + ] + mock_cursor.fetchall.return_value = test_data + + with patch.object(self.session, "_ensure_initialized", new_callable=AsyncMock): + result = await self.session.get_items() + + # Should only return valid items + self.assertEqual(len(result), 2) + self.assertEqual(result[0], {"role": "user", "content": "Hello"}) + self.assertEqual(result[1], {"role": "assistant", "content": "Hi"}) + + async def test_add_items_empty_list(self): + """Test adding empty list of items.""" + with patch.object(self.session, "_ensure_initialized", new_callable=AsyncMock) as mock_init: + await self.session.add_items([]) + mock_init.assert_not_called() + + async def test_add_items_success(self): + """Test successfully adding items.""" + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + + # Set up context managers + self.setup_connection_mock(mock_conn, mock_cursor) + + test_items = cast( + list[TResponseInputItem], + [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there"}], + ) + + with patch.object(self.session, "_ensure_initialized", new_callable=AsyncMock): + await self.session.add_items(test_items) + + # Verify session creation, item insertion, and timestamp update + self.assertEqual(mock_cursor.execute.call_count, 2) # session insert + timestamp update + self.assertEqual(mock_cursor.executemany.call_count, 1) # items insert + + # Check session insert + session_call = mock_cursor.execute.call_args_list[0] + session_call_str = str(session_call[0][0]) + self.assertIn("INSERT INTO", session_call_str) + self.assertIn("ON CONFLICT (session_id) DO NOTHING", session_call_str) + + # Check items insert + items_call = mock_cursor.executemany.call_args + items_call_str = str(items_call[0][0]) + self.assertIn("INSERT INTO", items_call_str) + expected_data = [ + ("test_session_123", json.dumps(test_items[0])), + ("test_session_123", json.dumps(test_items[1])), + ] + self.assertEqual(items_call[0][1], expected_data) + + # Check timestamp update + timestamp_call = mock_cursor.execute.call_args_list[1] + timestamp_call_str = str(timestamp_call[0][0]) + self.assertIn("UPDATE", timestamp_call_str) + self.assertIn("SET updated_at = CURRENT_TIMESTAMP", timestamp_call_str) + + async def test_pop_item_success(self): + """Test successfully popping an item.""" + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + + # Set up context managers + self.setup_connection_mock(mock_conn, mock_cursor) + + test_item = {"role": "user", "content": "Hello", "type": "message"} + mock_cursor.fetchone.return_value = MessageRow( + message_data={"role": "user", "content": "Hello", "type": "message"} + ) + + with patch.object(self.session, "_ensure_initialized", new_callable=AsyncMock): + result = await self.session.pop_item() + + self.assertEqual(result, test_item) + + # Verify single DELETE ... RETURNING query + self.assertEqual(mock_cursor.execute.call_count, 1) + + # Check delete query with RETURNING + delete_call = mock_cursor.execute.call_args_list[0] + delete_call_str = str(delete_call[0][0]) + self.assertIn("DELETE FROM", delete_call_str) + self.assertIn("RETURNING message_data", delete_call_str) + self.assertIn("ORDER BY created_at DESC", delete_call_str) + self.assertIn("LIMIT 1", delete_call_str) + self.assertEqual(delete_call[0][1], ("test_session_123",)) + + async def test_pop_item_empty_session(self): + """Test popping from an empty session.""" + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + + # Set up context managers + self.setup_connection_mock(mock_conn, mock_cursor) + + mock_cursor.fetchone.return_value = None + + with patch.object(self.session, "_ensure_initialized", new_callable=AsyncMock): + result = await self.session.pop_item() + + self.assertIsNone(result) + + # Should only call select, not delete + self.assertEqual(mock_cursor.execute.call_count, 1) + + async def test_pop_item_handles_invalid_data(self): + """Test that pop_item handles invalid data gracefully.""" + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + + # Set up context managers + self.setup_connection_mock(mock_conn, mock_cursor) + + # Invalid data structure - mock object without message_data attribute + class InvalidRow: + def __init__(self): + self.id = 123 + # No message_data attribute + + mock_cursor.fetchone.return_value = InvalidRow() + + with patch.object(self.session, "_ensure_initialized", new_callable=AsyncMock): + result = await self.session.pop_item() + + self.assertIsNone(result) + + # Should execute the DELETE ... RETURNING query once + self.assertEqual(mock_cursor.execute.call_count, 1) + + async def test_clear_session(self): + """Test clearing session.""" + mock_conn = AsyncMock() + mock_cursor = AsyncMock() + + # Set up context managers + self.setup_connection_mock(mock_conn, mock_cursor) + + with patch.object(self.session, "_ensure_initialized", new_callable=AsyncMock): + await self.session.clear_session() + + # Should delete from both messages and sessions tables + self.assertEqual(mock_cursor.execute.call_count, 2) + + # Check messages deletion + messages_call = mock_cursor.execute.call_args_list[0] + messages_call_str = str(messages_call[0][0]) + self.assertIn("DELETE FROM", messages_call_str) + self.assertIn("WHERE session_id = %s", messages_call_str) + self.assertEqual(messages_call[0][1], ("test_session_123",)) + + # Check sessions deletion + sessions_call = mock_cursor.execute.call_args_list[1] + sessions_call_str = str(sessions_call[0][0]) + self.assertIn("DELETE FROM", sessions_call_str) + self.assertIn("WHERE session_id = %s", sessions_call_str) + self.assertEqual(sessions_call[0][1], ("test_session_123",)) + + async def test_close(self): + """Test closing the session.""" + self.session._initialized = True + + await self.session.close() + + self.mock_pool.close.assert_called_once() + self.assertFalse(self.session._initialized) + + @patch("agents.extensions.memory.postgres_session.AsyncConnectionPool") + async def test_from_connection_string_success(self, mock_pool_class): + """Test creating a session from connection string.""" + mock_pool = AsyncMock() + mock_pool_class.return_value = mock_pool + + connection_string = "postgresql://user:pass@host/db" + session_id = "test_session_123" + + session = await PostgreSQLSession.from_connection_string(session_id, connection_string) + + # Verify pool was created with the connection string + mock_pool_class.assert_called_once_with(connection_string, open=False) + mock_pool.open.assert_called_once() + + # Verify session was created with correct parameters + self.assertEqual(session.session_id, session_id) + self.assertEqual(session.pool, mock_pool) + self.assertEqual(session.sessions_table, "agent_sessions") + self.assertEqual(session.messages_table, "agent_messages") + + @patch("agents.extensions.memory.postgres_session.AsyncConnectionPool") + async def test_from_connection_string_custom_tables(self, mock_pool_class): + """Test creating a session from connection string with custom table names.""" + mock_pool = AsyncMock() + mock_pool_class.return_value = mock_pool + + connection_string = "postgresql://user:pass@host/db" + session_id = "test_session_123" + custom_sessions_table = "custom_sessions" + custom_messages_table = "custom_messages" + + session = await PostgreSQLSession.from_connection_string( + session_id, + connection_string, + sessions_table=custom_sessions_table, + messages_table=custom_messages_table, + ) + + # Verify pool was created with the connection string + mock_pool_class.assert_called_once_with(connection_string, open=False) + mock_pool.open.assert_called_once() + + # Verify session was created with correct parameters + self.assertEqual(session.session_id, session_id) + self.assertEqual(session.pool, mock_pool) + self.assertEqual(session.sessions_table, custom_sessions_table) + self.assertEqual(session.messages_table, custom_messages_table) + + +@pytest.mark.skip(reason="Integration tests require a running PostgreSQL instance") +class TestPostgreSQLSessionIntegration(unittest.IsolatedAsyncioTestCase): + """Integration tests for PostgreSQL session that require a running database.""" + + # Test connection string - modify as needed for your test database + TEST_CONNECTION_STRING = "postgresql://postgres:password@localhost:5432/test_db" + + async def asyncSetUp(self): + """Set up test session.""" + self.session_id = "test_integration_session" + self.session = await PostgreSQLSession.from_connection_string( + self.session_id, + self.TEST_CONNECTION_STRING, + sessions_table="test_sessions", + messages_table="test_messages", + ) + + # Clean up any existing test data + await self.session.clear_session() + + async def asyncTearDown(self): + """Clean up after tests.""" + if hasattr(self, "session"): + await self.session.clear_session() + await self.session.close() + + async def test_integration_full_workflow(self): + """Test complete workflow: add items, get items, pop item, clear session.""" + # Initially empty + items = await self.session.get_items() + self.assertEqual(len(items), 0) + + # Add some test items + test_items = cast( + list[TResponseInputItem], + [ + {"role": "user", "content": "Hello", "type": "message"}, + {"role": "assistant", "content": "Hi there!", "type": "message"}, + {"role": "user", "content": "How are you?", "type": "message"}, + {"role": "assistant", "content": "I'm doing well, thank you!", "type": "message"}, + ], + ) + + for item in test_items: + await self.session.add_items([item]) + + # Verify items were added + stored_items = await self.session.get_items() + self.assertEqual(len(stored_items), 4) + self.assertEqual(stored_items[0], test_items[0]) + self.assertEqual(stored_items[-1], test_items[-1]) + + # Test with limit + limited_items = await self.session.get_items(limit=2) + self.assertEqual(len(limited_items), 2) + # Should get the last 2 items in chronological order + self.assertEqual(limited_items[0], test_items[2]) + self.assertEqual(limited_items[1], test_items[3]) + + # Test pop_item + popped_item = await self.session.pop_item() + self.assertEqual(popped_item, test_items[3]) # Last item + + # Verify item was removed + remaining_items = await self.session.get_items() + self.assertEqual(len(remaining_items), 3) + self.assertEqual(remaining_items[-1], test_items[2]) + + # Test clear_session + await self.session.clear_session() + final_items = await self.session.get_items() + self.assertEqual(len(final_items), 0) + + async def test_integration_multiple_sessions(self): + """Test that different sessions maintain separate data.""" + # Create a second session + session2 = await PostgreSQLSession.from_connection_string( + "test_integration_session_2", + self.TEST_CONNECTION_STRING, + sessions_table="test_sessions", + messages_table="test_messages", + ) + + try: + # Add different items to each session + items1 = cast( + list[TResponseInputItem], + [{"role": "user", "content": "Session 1 message", "type": "message"}], + ) + items2 = cast( + list[TResponseInputItem], + [{"role": "user", "content": "Session 2 message", "type": "message"}], + ) + + await self.session.add_items(items1) + await session2.add_items(items2) + + # Verify sessions have different data + session1_items = await self.session.get_items() + session2_items = await session2.get_items() + + self.assertEqual(len(session1_items), 1) + self.assertEqual(len(session2_items), 1) + self.assertEqual(session1_items[0]["content"], "Session 1 message") # type: ignore + self.assertEqual(session2_items[0]["content"], "Session 2 message") # type: ignore + + finally: + await session2.clear_session() + await session2.close() + + async def test_integration_empty_session_operations(self): + """Test operations on empty session.""" + # Pop from empty session + popped = await self.session.pop_item() + self.assertIsNone(popped) + + # Get items from empty session + items = await self.session.get_items() + self.assertEqual(len(items), 0) + + # Get items with limit from empty session + limited_items = await self.session.get_items(limit=5) + self.assertEqual(len(limited_items), 0) + + # Clear empty session (should not error) + await self.session.clear_session() + + async def test_integration_connection_string_with_custom_tables(self): + """Test creating session with custom table names.""" + custom_session = await PostgreSQLSession.from_connection_string( + "custom_table_test", + self.TEST_CONNECTION_STRING, + sessions_table="custom_sessions_table", + messages_table="custom_messages_table", + ) + + try: + # Test basic functionality with custom tables + test_items = cast( + list[TResponseInputItem], + [{"role": "user", "content": "Custom table test", "type": "message"}], + ) + + await custom_session.add_items(test_items) + stored_items = await custom_session.get_items() + + self.assertEqual(len(stored_items), 1) + self.assertEqual(stored_items[0]["content"], "Custom table test") # type: ignore + + finally: + await custom_session.clear_session() + await custom_session.close() diff --git a/uv.lock b/uv.lock index 5d19a321d..ce67e85bd 100644 --- a/uv.lock +++ b/uv.lock @@ -1498,6 +1498,9 @@ dependencies = [ litellm = [ { name = "litellm" }, ] +psycopg = [ + { name = "psycopg", extra = ["pool"] }, +] realtime = [ { name = "websockets" }, ] @@ -1542,6 +1545,7 @@ requires-dist = [ { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.11.0,<2" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, { name = "openai", specifier = ">=1.97.1,<2" }, + { name = "psycopg", extras = ["pool"], marker = "extra == 'psycopg'", specifier = ">=3.2.9,<4" }, { name = "pydantic", specifier = ">=2.10,<3" }, { name = "requests", specifier = ">=2.0,<3" }, { name = "types-requests", specifier = ">=2.0,<3" }, @@ -1549,7 +1553,7 @@ requires-dist = [ { name = "websockets", marker = "extra == 'realtime'", specifier = ">=15.0,<16" }, { name = "websockets", marker = "extra == 'voice'", specifier = ">=15.0,<16" }, ] -provides-extras = ["voice", "viz", "litellm", "realtime"] +provides-extras = ["voice", "viz", "litellm", "realtime", "psycopg"] [package.metadata.requires-dev] dev = [ @@ -1745,6 +1749,36 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b8/d3/c3cb8f1d6ae3b37f83e1de806713a9b3642c5895f0215a62e1a4bd6e5e34/propcache-0.3.1-py3-none-any.whl", hash = "sha256:9a8ecf38de50a7f518c21568c80f985e776397b902f1ce0b01f799aba1608b40", size = 12376, upload-time = "2025-03-26T03:06:10.5Z" }, ] +[[package]] +name = "psycopg" +version = "3.2.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "tzdata", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/27/4a/93a6ab570a8d1a4ad171a1f4256e205ce48d828781312c0bbaff36380ecb/psycopg-3.2.9.tar.gz", hash = "sha256:2fbb46fcd17bc81f993f28c47f1ebea38d66ae97cc2dbc3cad73b37cefbff700", size = 158122, upload-time = "2025-05-13T16:11:15.533Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/b0/a73c195a56eb6b92e937a5ca58521a5c3346fb233345adc80fd3e2f542e2/psycopg-3.2.9-py3-none-any.whl", hash = "sha256:01a8dadccdaac2123c916208c96e06631641c0566b22005493f09663c7a8d3b6", size = 202705, upload-time = "2025-05-13T16:06:26.584Z" }, +] + +[package.optional-dependencies] +pool = [ + { name = "psycopg-pool" }, +] + +[[package]] +name = "psycopg-pool" +version = "3.2.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cf/13/1e7850bb2c69a63267c3dbf37387d3f71a00fd0e2fa55c5db14d64ba1af4/psycopg_pool-3.2.6.tar.gz", hash = "sha256:0f92a7817719517212fbfe2fd58b8c35c1850cdd2a80d36b581ba2085d9148e5", size = 29770, upload-time = "2025-02-26T12:03:47.129Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/fd/4feb52a55c1a4bd748f2acaed1903ab54a723c47f6d0242780f4d97104d4/psycopg_pool-3.2.6-py3-none-any.whl", hash = "sha256:5887318a9f6af906d041a0b1dc1c60f8f0dda8340c2572b74e10907b51ed5da7", size = 38252, upload-time = "2025-02-26T12:03:45.073Z" }, +] + [[package]] name = "pycparser" version = "2.22" @@ -2711,6 +2745,15 @@ wheels = [ { url = "https://pypi.tuna.tsinghua.edu.cn/packages/31/08/aa4fdfb71f7de5176385bd9e90852eaf6b5d622735020ad600f2bab54385/typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f", size = 14125, upload-time = "2025-02-25T17:27:57.754Z" }, ] +[[package]] +name = "tzdata" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/32/1a225d6164441be760d75c2c42e2780dc0873fe382da3e98a2e1e48361e5/tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9", size = 196380, upload-time = "2025-03-23T13:54:43.652Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839, upload-time = "2025-03-23T13:54:41.845Z" }, +] + [[package]] name = "uc-micro-py" version = "1.0.3"