Skip to content

Add SQLAlchemy session backend for conversation history management #1357

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"]
viz = ["graphviz>=0.17"]
litellm = ["litellm>=1.67.4.post1, <2"]
realtime = ["websockets>=15.0, <16"]
sqlalchemy = ["SQLAlchemy>=2.0", "asyncpg>=0.29.0"]

[dependency-groups]
dev = [
Expand All @@ -63,6 +64,7 @@ dev = [
"mkdocs-static-i18n>=1.3.0",
"eval-type-backport>=0.2.2",
"fastapi >= 0.110.0, <1",
"aiosqlite>=0.21.0",
]

[tool.uv.workspace]
Expand Down
15 changes: 15 additions & 0 deletions src/agents/extensions/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

"""Session memory backends living in the extensions namespace.

This package contains optional, production-grade session implementations that
introduce extra third-party dependencies (database drivers, ORMs, etc.). They
conform to the :class:`agents.memory.session.Session` protocol so they can be
used as a drop-in replacement for :class:`agents.memory.session.SQLiteSession`.
"""
from __future__ import annotations

from .sqlalchemy_session import SQLAlchemySession # noqa: F401

# Public API of this package: names listed here are what
# ``from agents.extensions.memory import *`` exports. Listing the class also
# marks the import above as an intentional re-export.
__all__: list[str] = [
"SQLAlchemySession",
]
288 changes: 288 additions & 0 deletions src/agents/extensions/memory/sqlalchemy_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
"""SQLAlchemy-powered Session backend.

Usage::

from agents.extensions.memory import SQLAlchemySession

# Create from SQLAlchemy URL (uses asyncpg driver under the hood for Postgres)
session = SQLAlchemySession.from_url(
session_id="user-123",
url="postgresql+asyncpg://app:secret@db.example.com/agents",
)

# Or pass an existing AsyncEngine that your application already manages
session = SQLAlchemySession(
session_id="user-123",
engine=my_async_engine,
)

await Runner.run(agent, "Hello", session=session)
"""

from __future__ import annotations

import asyncio
import json
from typing import Any

from sqlalchemy import (
TIMESTAMP,
Column,
ForeignKey,
Integer,
MetaData,
String,
Table,
Text,
delete,
insert,
select,
text as sql_text,
update,
)
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine

from ...items import TResponseInputItem
from ...memory.session import SessionABC


class SQLAlchemySession(SessionABC):
    """SQLAlchemy implementation of :class:`agents.memory.session.Session`.

    Conversation items are stored as JSON text in a messages table, one row
    per item, linked to a parent row in a sessions table. Items are returned
    in insertion order: rows are sorted by ``created_at`` and then by the
    autoincrementing ``id`` so that items written within the same timestamp
    tick (for example in one bulk insert) keep a deterministic order.
    """

    _metadata: MetaData
    _sessions: Table
    _messages: Table

    def __init__(
        self,
        session_id: str,
        *,
        engine: AsyncEngine,
        create_tables: bool = True,
        sessions_table: str = "agent_sessions",
        messages_table: str = "agent_messages",
    ):  # noqa: D401 – short description on the class-level docstring
        """Create a new session.

        Parameters
        ----------
        session_id
            Unique identifier for the conversation.
        engine
            A pre-configured SQLAlchemy *async* engine. The engine **must** be
            created with an async driver (``postgresql+asyncpg://``,
            ``mysql+aiomysql://`` or ``sqlite+aiosqlite://``).
        create_tables
            Whether to automatically create the required tables & indexes. Set
            this to *False* if your migrations take care of schema management.
        sessions_table, messages_table
            Override default table names if needed.
        """
        self.session_id = session_id
        self._engine = engine
        # Serializes the one-time schema creation in ``_ensure_tables``.
        self._lock = asyncio.Lock()

        self._metadata = MetaData()
        self._sessions = Table(
            sessions_table,
            self._metadata,
            Column("session_id", String, primary_key=True),
            Column(
                "created_at",
                TIMESTAMP(timezone=False),
                server_default=sql_text("CURRENT_TIMESTAMP"),
                nullable=False,
            ),
            Column(
                "updated_at",
                TIMESTAMP(timezone=False),
                server_default=sql_text("CURRENT_TIMESTAMP"),
                onupdate=sql_text("CURRENT_TIMESTAMP"),
                nullable=False,
            ),
        )

        self._messages = Table(
            messages_table,
            self._metadata,
            Column("id", Integer, primary_key=True, autoincrement=True),
            Column(
                "session_id",
                String,
                ForeignKey(f"{sessions_table}.session_id", ondelete="CASCADE"),
                nullable=False,
            ),
            Column("message_data", Text, nullable=False),
            Column(
                "created_at",
                TIMESTAMP(timezone=False),
                server_default=sql_text("CURRENT_TIMESTAMP"),
                nullable=False,
            ),
            sqlite_autoincrement=True,
        )

        # Index for efficient retrieval of messages per session ordered by time.
        # Constructing the Index from table columns attaches it to the table,
        # so ``metadata.create_all`` emits it together with the table.
        from sqlalchemy import Index

        Index(
            f"idx_{messages_table}_session_time",
            self._messages.c.session_id,
            self._messages.c.created_at,
        )

        # Async session factory
        self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)

        # Flipped to False after the first successful schema creation.
        self._create_tables = create_tables

    # ---------------------------------------------------------------------
    # Convenience constructors
    # ---------------------------------------------------------------------
    @classmethod
    def from_url(
        cls,
        session_id: str,
        *,
        url: str,
        engine_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> SQLAlchemySession:
        """Create a session from a database URL string.

        Parameters
        ----------
        session_id
            Conversation ID.
        url
            Any SQLAlchemy async URL – e.g. ``"postgresql+asyncpg://user:pass@host/db"``.
        engine_kwargs
            Additional kwargs forwarded to :func:`sqlalchemy.ext.asyncio.create_async_engine`.
        kwargs
            Forwarded to the main constructor (``create_tables``, custom table names, …).
        """
        engine = create_async_engine(url, **(engine_kwargs or {}))
        return cls(session_id, engine=engine, **kwargs)

    # ------------------------------------------------------------------
    # Session protocol implementation
    # ------------------------------------------------------------------
    async def _ensure_tables(self) -> None:
        """Create the tables once, before the first database operation.

        Does nothing when ``create_tables`` was *False* or the schema has
        already been created by this instance.
        """
        if not self._create_tables:
            return
        async with self._lock:
            # Re-check after acquiring the lock: another coroutine may have
            # finished creating the schema while this one was waiting.
            if not self._create_tables:
                return
            async with self._engine.begin() as conn:
                await conn.run_sync(self._metadata.create_all)
            self._create_tables = False  # Only create once

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Return the stored items for this session in chronological order.

        Parameters
        ----------
        limit
            When given, only the ``limit`` most recent items are returned
            (still oldest first). ``None`` returns the full history.

        Returns
        -------
        list
            Decoded items. Rows whose payload is not valid JSON are skipped.
        """
        await self._ensure_tables()

        messages = self._messages
        stmt = select(messages.c.message_data).where(messages.c.session_id == self.session_id)
        if limit is None:
            stmt = stmt.order_by(messages.c.created_at.asc(), messages.c.id.asc())
        else:
            # Take the newest rows first, then flip them back below. ``id``
            # breaks ties between rows that share a timestamp.
            stmt = stmt.order_by(messages.c.created_at.desc(), messages.c.id.desc()).limit(limit)

        async with self._session_factory() as sess:
            result = await sess.execute(stmt)
            rows: list[str] = [row[0] for row in result.all()]

        if limit is not None:
            rows.reverse()  # chronological order

        items: list[TResponseInputItem] = []
        for raw in rows:
            try:
                items.append(json.loads(raw))
            except json.JSONDecodeError:
                # Deliberate best effort: skip corrupted rows rather than fail
                # the whole history read.
                continue
        return items

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Append ``items`` to this session's history.

        Creates the parent session row if it does not exist yet, inserts all
        items in one bulk statement and refreshes ``updated_at``; everything
        runs in a single transaction. An empty ``items`` list is a no-op.
        """
        if not items:
            return

        await self._ensure_tables()
        payload = [
            {
                "session_id": self.session_id,
                "message_data": json.dumps(item, separators=(",", ":")),
            }
            for item in items
        ]

        sessions = self._sessions
        async with self._session_factory() as sess:
            async with sess.begin():
                # Ensure the parent session row exists. A portable
                # select-then-insert is used instead of dialect-specific
                # upsert syntax.
                existing = await sess.execute(
                    select(sessions.c.session_id).where(sessions.c.session_id == self.session_id)
                )
                # Compare against None explicitly: a falsy but existing id
                # (e.g. "") must not trigger a duplicate insert.
                if existing.scalar_one_or_none() is None:
                    await sess.execute(insert(sessions).values({"session_id": self.session_id}))

                # Insert messages in bulk
                await sess.execute(insert(self._messages), payload)

                # Touch updated_at column
                await sess.execute(
                    update(sessions)
                    .where(sessions.c.session_id == self.session_id)
                    .values(updated_at=sql_text("CURRENT_TIMESTAMP"))
                )

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recent item of this session.

        Returns
        -------
        The decoded item, or ``None`` when the session has no items or the
        removed row did not contain valid JSON (the row is deleted either way).
        """
        await self._ensure_tables()

        messages = self._messages
        async with self._session_factory() as sess:
            async with sess.begin():
                # Portable across dialects: read the newest row (id and
                # payload in one query), then delete it by primary key.
                newest = (
                    select(messages.c.id, messages.c.message_data)
                    .where(messages.c.session_id == self.session_id)
                    .order_by(messages.c.created_at.desc(), messages.c.id.desc())
                    .limit(1)
                )
                row = (await sess.execute(newest)).first()
                if row is None:
                    return None
                row_id, raw = row
                await sess.execute(delete(messages).where(messages.c.id == row_id))

        try:
            return json.loads(raw)  # type: ignore[no-any-return]
        except json.JSONDecodeError:
            return None

    async def clear_session(self) -> None:  # noqa: D401 – imperative mood is fine
        """Delete all items of this session together with its session row."""
        await self._ensure_tables()
        async with self._session_factory() as sess:
            async with sess.begin():
                # Delete children first so the operation does not depend on
                # the database enforcing ON DELETE CASCADE.
                await sess.execute(
                    delete(self._messages).where(self._messages.c.session_id == self.session_id)
                )
                await sess.execute(
                    delete(self._sessions).where(self._sessions.c.session_id == self.session_id)
                )
Loading