Skip to content

Commit be28fbc

Browse files
Fix PostgreSQL Session issues
* Fix warning from opening a pool in the constructor: https://www.psycopg.org/psycopg3/docs/news_pool.html#psycopg-pool-3-2-2 * Add integration tests that can be run locally * Fix get_items with limit Signed-off-by: Aidan Jensen <[email protected]>
1 parent 295dbb3 commit be28fbc

File tree

3 files changed

+220
-13
lines changed

3 files changed

+220
-13
lines changed

docs/sessions.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ from agents.extensions.memory import PostgreSQLSession
148148
from psycopg_pool import AsyncConnectionPool
149149

150150
# From a connection string (creates a new connection pool)
151-
session = PostgreSQLSession.from_connection_string("user_123", "postgresql://user:pass@host/db")
151+
session = await PostgreSQLSession.from_connection_string("user_123", "postgresql://user:pass@host/db")
152152

153153
# From existing connection pool
154154
pool = AsyncConnectionPool(connection_string)

src/agents/extensions/memory/postgres_session.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ def __init__(
5656
5757
Args:
5858
session_id: Unique identifier for the conversation session
59-
pool: PostgreSQL connection pool instance
59+
pool: PostgreSQL connection pool instance.
60+
This should be opened before passing to this class.
6061
sessions_table: Name of the table to store session metadata. Defaults to
6162
'agent_sessions'
6263
messages_table: Name of the table to store message data. Defaults to 'agent_messages'
@@ -74,7 +75,7 @@ def __init__(
7475
self._initialized = False
7576

7677
@classmethod
77-
def from_connection_string(
78+
async def from_connection_string(
7879
cls,
7980
session_id: str,
8081
connection_string: str,
@@ -93,7 +94,8 @@ def from_connection_string(
9394
Returns:
9495
PostgreSQLSession instance with a connection pool created from the connection string
9596
"""
96-
pool: AsyncConnectionPool = AsyncConnectionPool(connection_string)
97+
pool: AsyncConnectionPool = AsyncConnectionPool(connection_string, open=False)
98+
await pool.open()
9799
return cls(session_id, pool, sessions_table, messages_table)
98100

99101
async def _ensure_initialized(self) -> None:
@@ -168,13 +170,10 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
168170
else:
169171
# Fetch the latest N items in chronological order
170172
query = sql.SQL("""
171-
SELECT message_data FROM (
172-
SELECT message_data FROM {messages_table}
173-
WHERE session_id = %s
174-
ORDER BY created_at DESC
175-
LIMIT %s
176-
) t
177-
ORDER BY created_at ASC
173+
SELECT message_data FROM {messages_table}
174+
WHERE session_id = %s
175+
ORDER BY created_at DESC
176+
LIMIT %s
178177
""").format(messages_table=sql.Identifier(self.messages_table))
179178
await cur.execute(query, (self.session_id, limit))
180179

@@ -190,6 +189,10 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
190189
# Skip invalid entries
191190
continue
192191

192+
# If we used LIMIT, reverse the items to get chronological order
193+
if limit is not None:
194+
items.reverse()
195+
193196
return items
194197

195198
async def add_items(self, items: list[TResponseInputItem]) -> None:

tests/test_postgresql_session.py

Lines changed: 206 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, cast
44
from unittest.mock import AsyncMock, patch
55

6+
import pytest
67
from psycopg import AsyncConnection
78
from psycopg.rows import TupleRow
89
from psycopg_pool import AsyncConnectionPool
@@ -23,7 +24,7 @@ def __init__(self, return_value):
2324
async def __aenter__(self):
2425
return self.return_value
2526

26-
async def __aexit__(self, exc_type, exc_val, exc_tb):
27+
async def __aexit__(self, _exc_type, _exc_val, _exc_tb):
2728
return None
2829

2930

@@ -54,7 +55,7 @@ def mock_connection() -> AsyncContextManagerMock:
5455

5556
self.mock_pool.connection = mock_connection
5657

57-
def mock_cursor_method(*args: Any, **kwargs: Any) -> AsyncContextManagerMock:
58+
def mock_cursor_method(*_args: Any, **_kwargs: Any) -> AsyncContextManagerMock:
5859
return AsyncContextManagerMock(mock_cursor)
5960

6061
mock_conn.cursor = mock_cursor_method
@@ -353,3 +354,206 @@ async def test_close(self):
353354

354355
self.mock_pool.close.assert_called_once()
355356
self.assertFalse(self.session._initialized)
357+
358+
@patch("agents.extensions.memory.postgres_session.AsyncConnectionPool")
359+
async def test_from_connection_string_success(self, mock_pool_class):
360+
"""Test creating a session from connection string."""
361+
mock_pool = AsyncMock()
362+
mock_pool_class.return_value = mock_pool
363+
364+
connection_string = "postgresql://user:pass@host/db"
365+
session_id = "test_session_123"
366+
367+
session = await PostgreSQLSession.from_connection_string(session_id, connection_string)
368+
369+
# Verify pool was created with the connection string
370+
mock_pool_class.assert_called_once_with(connection_string)
371+
mock_pool.open.assert_called_once()
372+
373+
# Verify session was created with correct parameters
374+
self.assertEqual(session.session_id, session_id)
375+
self.assertEqual(session.pool, mock_pool)
376+
self.assertEqual(session.sessions_table, "agent_sessions")
377+
self.assertEqual(session.messages_table, "agent_messages")
378+
379+
@patch("agents.extensions.memory.postgres_session.AsyncConnectionPool")
380+
async def test_from_connection_string_custom_tables(self, mock_pool_class):
381+
"""Test creating a session from connection string with custom table names."""
382+
mock_pool = AsyncMock()
383+
mock_pool_class.return_value = mock_pool
384+
385+
connection_string = "postgresql://user:pass@host/db"
386+
session_id = "test_session_123"
387+
custom_sessions_table = "custom_sessions"
388+
custom_messages_table = "custom_messages"
389+
390+
session = await PostgreSQLSession.from_connection_string(
391+
session_id,
392+
connection_string,
393+
sessions_table=custom_sessions_table,
394+
messages_table=custom_messages_table,
395+
)
396+
397+
# Verify pool was created with the connection string
398+
mock_pool_class.assert_called_once_with(connection_string)
399+
mock_pool.open.assert_called_once()
400+
401+
# Verify session was created with correct parameters
402+
self.assertEqual(session.session_id, session_id)
403+
self.assertEqual(session.pool, mock_pool)
404+
self.assertEqual(session.sessions_table, custom_sessions_table)
405+
self.assertEqual(session.messages_table, custom_messages_table)
406+
407+
408+
@pytest.mark.skip(reason="Integration tests require a running PostgreSQL instance")
409+
class TestPostgreSQLSessionIntegration(unittest.IsolatedAsyncioTestCase):
410+
"""Integration tests for PostgreSQL session that require a running database."""
411+
412+
# Test connection string - modify as needed for your test database
413+
TEST_CONNECTION_STRING = "postgresql://postgres:password@localhost:5432/test_db"
414+
415+
async def asyncSetUp(self):
416+
"""Set up test session."""
417+
self.session_id = "test_integration_session"
418+
self.session = await PostgreSQLSession.from_connection_string(
419+
self.session_id,
420+
self.TEST_CONNECTION_STRING,
421+
sessions_table="test_sessions",
422+
messages_table="test_messages",
423+
)
424+
425+
# Clean up any existing test data
426+
await self.session.clear_session()
427+
428+
async def asyncTearDown(self):
429+
"""Clean up after tests."""
430+
if hasattr(self, "session"):
431+
await self.session.clear_session()
432+
await self.session.close()
433+
434+
async def test_integration_full_workflow(self):
435+
"""Test complete workflow: add items, get items, pop item, clear session."""
436+
# Initially empty
437+
items = await self.session.get_items()
438+
self.assertEqual(len(items), 0)
439+
440+
# Add some test items
441+
test_items = cast(
442+
list[TResponseInputItem],
443+
[
444+
{"role": "user", "content": "Hello", "type": "message"},
445+
{"role": "assistant", "content": "Hi there!", "type": "message"},
446+
{"role": "user", "content": "How are you?", "type": "message"},
447+
{"role": "assistant", "content": "I'm doing well, thank you!", "type": "message"},
448+
],
449+
)
450+
451+
for item in test_items:
452+
await self.session.add_items([item])
453+
454+
# Verify items were added
455+
stored_items = await self.session.get_items()
456+
self.assertEqual(len(stored_items), 4)
457+
self.assertEqual(stored_items[0], test_items[0])
458+
self.assertEqual(stored_items[-1], test_items[-1])
459+
460+
# Test with limit
461+
limited_items = await self.session.get_items(limit=2)
462+
self.assertEqual(len(limited_items), 2)
463+
# Should get the last 2 items in chronological order
464+
self.assertEqual(limited_items[0], test_items[2])
465+
self.assertEqual(limited_items[1], test_items[3])
466+
467+
# Test pop_item
468+
popped_item = await self.session.pop_item()
469+
self.assertEqual(popped_item, test_items[3]) # Last item
470+
471+
# Verify item was removed
472+
remaining_items = await self.session.get_items()
473+
self.assertEqual(len(remaining_items), 3)
474+
self.assertEqual(remaining_items[-1], test_items[2])
475+
476+
# Test clear_session
477+
await self.session.clear_session()
478+
final_items = await self.session.get_items()
479+
self.assertEqual(len(final_items), 0)
480+
481+
async def test_integration_multiple_sessions(self):
482+
"""Test that different sessions maintain separate data."""
483+
# Create a second session
484+
session2 = await PostgreSQLSession.from_connection_string(
485+
"test_integration_session_2",
486+
self.TEST_CONNECTION_STRING,
487+
sessions_table="test_sessions",
488+
messages_table="test_messages",
489+
)
490+
491+
try:
492+
# Add different items to each session
493+
items1 = cast(
494+
list[TResponseInputItem],
495+
[{"role": "user", "content": "Session 1 message", "type": "message"}],
496+
)
497+
items2 = cast(
498+
list[TResponseInputItem],
499+
[{"role": "user", "content": "Session 2 message", "type": "message"}],
500+
)
501+
502+
await self.session.add_items(items1)
503+
await session2.add_items(items2)
504+
505+
# Verify sessions have different data
506+
session1_items = await self.session.get_items()
507+
session2_items = await session2.get_items()
508+
509+
self.assertEqual(len(session1_items), 1)
510+
self.assertEqual(len(session2_items), 1)
511+
self.assertEqual(session1_items[0]["content"], "Session 1 message") # type: ignore
512+
self.assertEqual(session2_items[0]["content"], "Session 2 message") # type: ignore
513+
514+
finally:
515+
await session2.clear_session()
516+
await session2.close()
517+
518+
async def test_integration_empty_session_operations(self):
519+
"""Test operations on empty session."""
520+
# Pop from empty session
521+
popped = await self.session.pop_item()
522+
self.assertIsNone(popped)
523+
524+
# Get items from empty session
525+
items = await self.session.get_items()
526+
self.assertEqual(len(items), 0)
527+
528+
# Get items with limit from empty session
529+
limited_items = await self.session.get_items(limit=5)
530+
self.assertEqual(len(limited_items), 0)
531+
532+
# Clear empty session (should not error)
533+
await self.session.clear_session()
534+
535+
async def test_integration_connection_string_with_custom_tables(self):
536+
"""Test creating session with custom table names."""
537+
custom_session = await PostgreSQLSession.from_connection_string(
538+
"custom_table_test",
539+
self.TEST_CONNECTION_STRING,
540+
sessions_table="custom_sessions_table",
541+
messages_table="custom_messages_table",
542+
)
543+
544+
try:
545+
# Test basic functionality with custom tables
546+
test_items = cast(
547+
list[TResponseInputItem],
548+
[{"role": "user", "content": "Custom table test", "type": "message"}],
549+
)
550+
551+
await custom_session.add_items(test_items)
552+
stored_items = await custom_session.get_items()
553+
554+
self.assertEqual(len(stored_items), 1)
555+
self.assertEqual(stored_items[0]["content"], "Custom table test") # type: ignore
556+
557+
finally:
558+
await custom_session.clear_session()
559+
await custom_session.close()

0 commit comments

Comments
 (0)