Skip to content

Commit e71b6c1

Browse files
Initial pass at PostgreSQL session management
* Add postgres_session.py * Implements session management using PostgreSQL * Accepts a pool as a default argument * Additional class method for creation from a connection string * Add optional-dependency for psycopg Signed-off-by: Aidan Jensen <[email protected]>
1 parent 656ee0c commit e71b6c1

File tree

5 files changed

+690
-1
lines changed

5 files changed

+690
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"]
3838
viz = ["graphviz>=0.17"]
3939
litellm = ["litellm>=1.67.4.post1, <2"]
4040
realtime = ["websockets>=15.0, <16"]
41+
psycopg = ["psycopg[pool]>=3.2.9,<4"]
4142

4243
[dependency-groups]
4344
dev = [

src/agents/extensions/memory/__init__.py

Whitespace-only changes.
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
from __future__ import annotations
2+
3+
import json
4+
from dataclasses import dataclass
5+
from typing import TYPE_CHECKING
6+
7+
try:
8+
import psycopg
9+
from psycopg import sql
10+
from psycopg.rows import class_row
11+
from psycopg_pool import AsyncConnectionPool
12+
except ImportError as _e:
13+
raise ImportError(
14+
"`psycopg` is required to use the PostgreSQLSession. You can install it via the optional "
15+
"dependency group: `pip install 'openai-agents[psycopg]'`."
16+
) from _e
17+
18+
if TYPE_CHECKING:
19+
from agents.items import TResponseInputItem
20+
21+
from agents.memory.session import Session
22+
23+
24+
@dataclass
25+
class MessageRow:
26+
"""Typed row for message queries."""
27+
28+
message_data: TResponseInputItem
29+
30+
31+
@dataclass
32+
class MessageWithIdRow:
33+
"""Typed row for message queries that include ID."""
34+
35+
id: int
36+
message_data: TResponseInputItem
37+
38+
39+
class PostgreSQLSession(Session):
40+
"""PostgreSQL-based implementation of session storage.
41+
42+
This implementation stores conversation history in a PostgreSQL database.
43+
Requires psycopg to be installed.
44+
"""
45+
46+
pool: AsyncConnectionPool
47+
48+
def __init__(
49+
self,
50+
session_id: str,
51+
pool: AsyncConnectionPool,
52+
sessions_table: str = "agent_sessions",
53+
messages_table: str = "agent_messages",
54+
):
55+
"""Initialize the PostgreSQL session.
56+
57+
Args:
58+
session_id: Unique identifier for the conversation session
59+
pool: PostgreSQL connection pool instance
60+
sessions_table: Name of the table to store session metadata. Defaults to
61+
'agent_sessions'
62+
messages_table: Name of the table to store message data. Defaults to 'agent_messages'
63+
"""
64+
if psycopg is None:
65+
raise ImportError(
66+
"psycopg is required for PostgreSQL session storage. "
67+
"Install with: pip install psycopg"
68+
)
69+
70+
self.session_id = session_id
71+
self.pool = pool
72+
self.sessions_table = sessions_table
73+
self.messages_table = messages_table
74+
self._initialized = False
75+
76+
@classmethod
77+
def from_connection_string(
78+
cls,
79+
session_id: str,
80+
connection_string: str,
81+
sessions_table: str = "agent_sessions",
82+
messages_table: str = "agent_messages",
83+
) -> PostgreSQLSession:
84+
"""Create a PostgreSQL session from a connection string.
85+
86+
Args:
87+
session_id: Unique identifier for the conversation session
88+
connection_string: PostgreSQL connection string (e.g., "postgresql://user:pass@host/db")
89+
sessions_table: Name of the table to store session metadata. Defaults to
90+
'agent_sessions'
91+
messages_table: Name of the table to store message data. Defaults to 'agent_messages'
92+
93+
Returns:
94+
PostgreSQLSession instance with a connection pool created from the connection string
95+
"""
96+
pool: AsyncConnectionPool = AsyncConnectionPool(connection_string)
97+
return cls(session_id, pool, sessions_table, messages_table)
98+
99+
async def _ensure_initialized(self) -> None:
100+
"""Ensure the database schema is initialized."""
101+
if not self._initialized:
102+
await self._init_db()
103+
104+
async def _init_db(self) -> None:
105+
"""Initialize the database schema."""
106+
async with self.pool.connection() as conn:
107+
async with conn.cursor() as cur:
108+
# Create sessions table
109+
query = sql.SQL("""
110+
CREATE TABLE IF NOT EXISTS {sessions_table} (
111+
session_id TEXT PRIMARY KEY,
112+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
113+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
114+
)
115+
""").format(sessions_table=sql.Identifier(self.sessions_table))
116+
await cur.execute(query)
117+
118+
# Create messages table
119+
query = sql.SQL("""
120+
CREATE TABLE IF NOT EXISTS {messages_table} (
121+
id SERIAL PRIMARY KEY,
122+
session_id TEXT NOT NULL,
123+
message_data JSONB NOT NULL,
124+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
125+
FOREIGN KEY (session_id) REFERENCES {sessions_table} (session_id)
126+
ON DELETE CASCADE
127+
)
128+
""").format(
129+
messages_table=sql.Identifier(self.messages_table),
130+
sessions_table=sql.Identifier(self.sessions_table),
131+
)
132+
await cur.execute(query)
133+
134+
# Create index for better performance
135+
query = sql.SQL("""
136+
CREATE INDEX IF NOT EXISTS {index_name}
137+
ON {messages_table} (session_id, created_at)
138+
""").format(
139+
index_name=sql.Identifier(f"idx_{self.messages_table}_session_id"),
140+
messages_table=sql.Identifier(self.messages_table),
141+
)
142+
await cur.execute(query)
143+
144+
self._initialized = True
145+
146+
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
147+
"""Retrieve the conversation history for this session.
148+
149+
Args:
150+
limit: Maximum number of items to retrieve. If None, retrieves all items.
151+
When specified, returns the latest N items in chronological order.
152+
153+
Returns:
154+
List of input items representing the conversation history
155+
"""
156+
await self._ensure_initialized()
157+
158+
async with self.pool.connection() as conn:
159+
async with conn.cursor(row_factory=class_row(MessageRow)) as cur:
160+
if limit is None:
161+
# Fetch all items in chronological order
162+
query = sql.SQL("""
163+
SELECT message_data FROM {messages_table}
164+
WHERE session_id = %s
165+
ORDER BY created_at ASC
166+
""").format(messages_table=sql.Identifier(self.messages_table))
167+
await cur.execute(query, (self.session_id,))
168+
else:
169+
# Fetch the latest N items in chronological order
170+
query = sql.SQL("""
171+
SELECT message_data FROM (
172+
SELECT message_data FROM {messages_table}
173+
WHERE session_id = %s
174+
ORDER BY created_at DESC
175+
LIMIT %s
176+
) t
177+
ORDER BY created_at ASC
178+
""").format(messages_table=sql.Identifier(self.messages_table))
179+
await cur.execute(query, (self.session_id, limit))
180+
181+
rows = await cur.fetchall()
182+
183+
items = []
184+
for row in rows:
185+
try:
186+
# PostgreSQL JSONB automatically handles deserialization
187+
item = row.message_data
188+
items.append(item)
189+
except (AttributeError, TypeError):
190+
# Skip invalid entries
191+
continue
192+
193+
return items
194+
195+
async def add_items(self, items: list[TResponseInputItem]) -> None:
196+
"""Add new items to the conversation history.
197+
198+
Args:
199+
items: List of input items to add to the history
200+
"""
201+
if not items:
202+
return
203+
204+
await self._ensure_initialized()
205+
206+
async with self.pool.connection() as conn:
207+
async with conn.transaction():
208+
async with conn.cursor() as cur:
209+
# Ensure session exists
210+
query = sql.SQL("""
211+
INSERT INTO {sessions_table} (session_id)
212+
VALUES (%s)
213+
ON CONFLICT (session_id) DO NOTHING
214+
""").format(sessions_table=sql.Identifier(self.sessions_table))
215+
await cur.execute(query, (self.session_id,))
216+
217+
# Add items
218+
message_data = [(self.session_id, json.dumps(item)) for item in items]
219+
query = sql.SQL("""
220+
INSERT INTO {messages_table} (session_id, message_data)
221+
VALUES (%s, %s)
222+
""").format(messages_table=sql.Identifier(self.messages_table))
223+
await cur.executemany(query, message_data)
224+
225+
# Update session timestamp
226+
query = sql.SQL("""
227+
UPDATE {sessions_table}
228+
SET updated_at = CURRENT_TIMESTAMP
229+
WHERE session_id = %s
230+
""").format(sessions_table=sql.Identifier(self.sessions_table))
231+
await cur.execute(query, (self.session_id,))
232+
233+
async def pop_item(self) -> TResponseInputItem | None:
234+
"""Remove and return the most recent item from the session.
235+
236+
Returns:
237+
The most recent item if it exists, None if the session is empty
238+
"""
239+
await self._ensure_initialized()
240+
241+
async with self.pool.connection() as conn:
242+
async with conn.transaction():
243+
async with conn.cursor(row_factory=class_row(MessageRow)) as cur:
244+
# Delete and return the most recent item in one query
245+
query = sql.SQL("""
246+
DELETE FROM {messages_table}
247+
WHERE id = (
248+
SELECT id FROM {messages_table}
249+
WHERE session_id = %s
250+
ORDER BY created_at DESC
251+
LIMIT 1
252+
)
253+
RETURNING message_data
254+
""").format(messages_table=sql.Identifier(self.messages_table))
255+
await cur.execute(query, (self.session_id,))
256+
257+
row = await cur.fetchone()
258+
259+
if row is None:
260+
return None
261+
262+
try:
263+
# PostgreSQL JSONB automatically handles deserialization
264+
item = row.message_data
265+
return item
266+
except (AttributeError, TypeError):
267+
# Return None for corrupted entries (already deleted)
268+
return None
269+
270+
async def clear_session(self) -> None:
271+
"""Clear all items for this session."""
272+
await self._ensure_initialized()
273+
274+
async with self.pool.connection() as conn:
275+
async with conn.transaction():
276+
async with conn.cursor() as cur:
277+
query = sql.SQL("""
278+
DELETE FROM {messages_table} WHERE session_id = %s
279+
""").format(messages_table=sql.Identifier(self.messages_table))
280+
await cur.execute(query, (self.session_id,))
281+
282+
query = sql.SQL("""
283+
DELETE FROM {sessions_table} WHERE session_id = %s
284+
""").format(sessions_table=sql.Identifier(self.sessions_table))
285+
await cur.execute(query, (self.session_id,))
286+
287+
async def close(self) -> None:
288+
"""Close the database connection pool."""
289+
await self.pool.close()
290+
self._initialized = False

0 commit comments

Comments
 (0)