3
3
from typing import Any , cast
4
4
from unittest .mock import AsyncMock , patch
5
5
6
+ import pytest
6
7
from psycopg import AsyncConnection
7
8
from psycopg .rows import TupleRow
8
9
from psycopg_pool import AsyncConnectionPool
@@ -23,7 +24,7 @@ def __init__(self, return_value):
23
24
async def __aenter__ (self ):
24
25
return self .return_value
25
26
26
- async def __aexit__ (self , exc_type , exc_val , exc_tb ):
27
+ async def __aexit__ (self , _exc_type , _exc_val , _exc_tb ):
27
28
return None
28
29
29
30
@@ -54,7 +55,7 @@ def mock_connection() -> AsyncContextManagerMock:
54
55
55
56
self .mock_pool .connection = mock_connection
56
57
57
- def mock_cursor_method (* args : Any , ** kwargs : Any ) -> AsyncContextManagerMock :
58
+ def mock_cursor_method (* _args : Any , ** _kwargs : Any ) -> AsyncContextManagerMock :
58
59
return AsyncContextManagerMock (mock_cursor )
59
60
60
61
mock_conn .cursor = mock_cursor_method
@@ -353,3 +354,206 @@ async def test_close(self):
353
354
354
355
self .mock_pool .close .assert_called_once ()
355
356
self .assertFalse (self .session ._initialized )
357
+
358
+ @patch ("agents.extensions.memory.postgres_session.AsyncConnectionPool" )
359
+ async def test_from_connection_string_success (self , mock_pool_class ):
360
+ """Test creating a session from connection string."""
361
+ mock_pool = AsyncMock ()
362
+ mock_pool_class .return_value = mock_pool
363
+
364
+ connection_string = "postgresql://user:pass@host/db"
365
+ session_id = "test_session_123"
366
+
367
+ session = await PostgreSQLSession .from_connection_string (session_id , connection_string )
368
+
369
+ # Verify pool was created with the connection string
370
+ mock_pool_class .assert_called_once_with (connection_string )
371
+ mock_pool .open .assert_called_once ()
372
+
373
+ # Verify session was created with correct parameters
374
+ self .assertEqual (session .session_id , session_id )
375
+ self .assertEqual (session .pool , mock_pool )
376
+ self .assertEqual (session .sessions_table , "agent_sessions" )
377
+ self .assertEqual (session .messages_table , "agent_messages" )
378
+
379
+ @patch ("agents.extensions.memory.postgres_session.AsyncConnectionPool" )
380
+ async def test_from_connection_string_custom_tables (self , mock_pool_class ):
381
+ """Test creating a session from connection string with custom table names."""
382
+ mock_pool = AsyncMock ()
383
+ mock_pool_class .return_value = mock_pool
384
+
385
+ connection_string = "postgresql://user:pass@host/db"
386
+ session_id = "test_session_123"
387
+ custom_sessions_table = "custom_sessions"
388
+ custom_messages_table = "custom_messages"
389
+
390
+ session = await PostgreSQLSession .from_connection_string (
391
+ session_id ,
392
+ connection_string ,
393
+ sessions_table = custom_sessions_table ,
394
+ messages_table = custom_messages_table ,
395
+ )
396
+
397
+ # Verify pool was created with the connection string
398
+ mock_pool_class .assert_called_once_with (connection_string )
399
+ mock_pool .open .assert_called_once ()
400
+
401
+ # Verify session was created with correct parameters
402
+ self .assertEqual (session .session_id , session_id )
403
+ self .assertEqual (session .pool , mock_pool )
404
+ self .assertEqual (session .sessions_table , custom_sessions_table )
405
+ self .assertEqual (session .messages_table , custom_messages_table )
406
+
407
+
408
+ @pytest .mark .skip (reason = "Integration tests require a running PostgreSQL instance" )
409
+ class TestPostgreSQLSessionIntegration (unittest .IsolatedAsyncioTestCase ):
410
+ """Integration tests for PostgreSQL session that require a running database."""
411
+
412
+ # Test connection string - modify as needed for your test database
413
+ TEST_CONNECTION_STRING = "postgresql://postgres:password@localhost:5432/test_db"
414
+
415
+ async def asyncSetUp (self ):
416
+ """Set up test session."""
417
+ self .session_id = "test_integration_session"
418
+ self .session = await PostgreSQLSession .from_connection_string (
419
+ self .session_id ,
420
+ self .TEST_CONNECTION_STRING ,
421
+ sessions_table = "test_sessions" ,
422
+ messages_table = "test_messages" ,
423
+ )
424
+
425
+ # Clean up any existing test data
426
+ await self .session .clear_session ()
427
+
428
+ async def asyncTearDown (self ):
429
+ """Clean up after tests."""
430
+ if hasattr (self , "session" ):
431
+ await self .session .clear_session ()
432
+ await self .session .close ()
433
+
434
+ async def test_integration_full_workflow (self ):
435
+ """Test complete workflow: add items, get items, pop item, clear session."""
436
+ # Initially empty
437
+ items = await self .session .get_items ()
438
+ self .assertEqual (len (items ), 0 )
439
+
440
+ # Add some test items
441
+ test_items = cast (
442
+ list [TResponseInputItem ],
443
+ [
444
+ {"role" : "user" , "content" : "Hello" , "type" : "message" },
445
+ {"role" : "assistant" , "content" : "Hi there!" , "type" : "message" },
446
+ {"role" : "user" , "content" : "How are you?" , "type" : "message" },
447
+ {"role" : "assistant" , "content" : "I'm doing well, thank you!" , "type" : "message" },
448
+ ],
449
+ )
450
+
451
+ for item in test_items :
452
+ await self .session .add_items ([item ])
453
+
454
+ # Verify items were added
455
+ stored_items = await self .session .get_items ()
456
+ self .assertEqual (len (stored_items ), 4 )
457
+ self .assertEqual (stored_items [0 ], test_items [0 ])
458
+ self .assertEqual (stored_items [- 1 ], test_items [- 1 ])
459
+
460
+ # Test with limit
461
+ limited_items = await self .session .get_items (limit = 2 )
462
+ self .assertEqual (len (limited_items ), 2 )
463
+ # Should get the last 2 items in chronological order
464
+ self .assertEqual (limited_items [0 ], test_items [2 ])
465
+ self .assertEqual (limited_items [1 ], test_items [3 ])
466
+
467
+ # Test pop_item
468
+ popped_item = await self .session .pop_item ()
469
+ self .assertEqual (popped_item , test_items [3 ]) # Last item
470
+
471
+ # Verify item was removed
472
+ remaining_items = await self .session .get_items ()
473
+ self .assertEqual (len (remaining_items ), 3 )
474
+ self .assertEqual (remaining_items [- 1 ], test_items [2 ])
475
+
476
+ # Test clear_session
477
+ await self .session .clear_session ()
478
+ final_items = await self .session .get_items ()
479
+ self .assertEqual (len (final_items ), 0 )
480
+
481
+ async def test_integration_multiple_sessions (self ):
482
+ """Test that different sessions maintain separate data."""
483
+ # Create a second session
484
+ session2 = await PostgreSQLSession .from_connection_string (
485
+ "test_integration_session_2" ,
486
+ self .TEST_CONNECTION_STRING ,
487
+ sessions_table = "test_sessions" ,
488
+ messages_table = "test_messages" ,
489
+ )
490
+
491
+ try :
492
+ # Add different items to each session
493
+ items1 = cast (
494
+ list [TResponseInputItem ],
495
+ [{"role" : "user" , "content" : "Session 1 message" , "type" : "message" }],
496
+ )
497
+ items2 = cast (
498
+ list [TResponseInputItem ],
499
+ [{"role" : "user" , "content" : "Session 2 message" , "type" : "message" }],
500
+ )
501
+
502
+ await self .session .add_items (items1 )
503
+ await session2 .add_items (items2 )
504
+
505
+ # Verify sessions have different data
506
+ session1_items = await self .session .get_items ()
507
+ session2_items = await session2 .get_items ()
508
+
509
+ self .assertEqual (len (session1_items ), 1 )
510
+ self .assertEqual (len (session2_items ), 1 )
511
+ self .assertEqual (session1_items [0 ]["content" ], "Session 1 message" ) # type: ignore
512
+ self .assertEqual (session2_items [0 ]["content" ], "Session 2 message" ) # type: ignore
513
+
514
+ finally :
515
+ await session2 .clear_session ()
516
+ await session2 .close ()
517
+
518
+ async def test_integration_empty_session_operations (self ):
519
+ """Test operations on empty session."""
520
+ # Pop from empty session
521
+ popped = await self .session .pop_item ()
522
+ self .assertIsNone (popped )
523
+
524
+ # Get items from empty session
525
+ items = await self .session .get_items ()
526
+ self .assertEqual (len (items ), 0 )
527
+
528
+ # Get items with limit from empty session
529
+ limited_items = await self .session .get_items (limit = 5 )
530
+ self .assertEqual (len (limited_items ), 0 )
531
+
532
+ # Clear empty session (should not error)
533
+ await self .session .clear_session ()
534
+
535
+ async def test_integration_connection_string_with_custom_tables (self ):
536
+ """Test creating session with custom table names."""
537
+ custom_session = await PostgreSQLSession .from_connection_string (
538
+ "custom_table_test" ,
539
+ self .TEST_CONNECTION_STRING ,
540
+ sessions_table = "custom_sessions_table" ,
541
+ messages_table = "custom_messages_table" ,
542
+ )
543
+
544
+ try :
545
+ # Test basic functionality with custom tables
546
+ test_items = cast (
547
+ list [TResponseInputItem ],
548
+ [{"role" : "user" , "content" : "Custom table test" , "type" : "message" }],
549
+ )
550
+
551
+ await custom_session .add_items (test_items )
552
+ stored_items = await custom_session .get_items ()
553
+
554
+ self .assertEqual (len (stored_items ), 1 )
555
+ self .assertEqual (stored_items [0 ]["content" ], "Custom table test" ) # type: ignore
556
+
557
+ finally :
558
+ await custom_session .clear_session ()
559
+ await custom_session .close ()
0 commit comments