Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 81 additions & 3 deletions src/agents/extensions/memory/async_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(
db_path: str | Path = ":memory:",
sessions_table: str = "agent_sessions",
messages_table: str = "agent_messages",
users_table: str = "agent_users",
user_id: str | None = None,
):
"""Initialize the async SQLite session.

Expand All @@ -39,27 +41,52 @@ def __init__(
sessions_table: Name of the table to store session metadata. Defaults to
'agent_sessions'
messages_table: Name of the table to store message data. Defaults to 'agent_messages'
users_table: Name of the table to store user metadata. Defaults to 'agent_users'
user_id: Optional user identifier to associate this session with a user.
"""
self.session_id = session_id
self.user_id = user_id
self.db_path = db_path
self.sessions_table = sessions_table
self.messages_table = messages_table
self.users_table = users_table
self._connection: aiosqlite.Connection | None = None
self._lock = asyncio.Lock()
self._init_lock = asyncio.Lock()

async def _init_db_for_connection(self, conn: aiosqlite.Connection) -> None:
"""Initialize the database schema for a specific connection."""
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.users_table} (
user_id TEXT PRIMARY KEY,
metadata TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)

await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.sessions_table} (
session_id TEXT PRIMARY KEY,
user_id TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES {self.users_table} (user_id)
ON DELETE SET NULL
)
"""
)

await conn.execute(
f"""
CREATE INDEX IF NOT EXISTS idx_{self.sessions_table}_user_id
ON {self.sessions_table} (user_id)
"""
)

await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.messages_table} (
Expand Down Expand Up @@ -160,11 +187,21 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
return

async with self._locked_connection() as conn:
# Ensure user exists if user_id is provided
if self.user_id is not None:
await conn.execute(
f"""
INSERT OR IGNORE INTO {self.users_table} (user_id) VALUES (?)
""",
(self.user_id,),
)

await conn.execute(
f"""
INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?)
INSERT OR IGNORE INTO {self.sessions_table} (session_id, user_id)
VALUES (?, ?)
""",
(self.session_id,),
(self.session_id, self.user_id),
)

message_data = [(self.session_id, json.dumps(item)) for item in items]
Expand Down Expand Up @@ -233,6 +270,47 @@ async def clear_session(self) -> None:
)
await conn.commit()

async def get_sessions_for_user(
self,
user_id: str,
limit: int | None = None,
offset: int = 0,
) -> list[str]:
"""Retrieve session IDs associated with a given user.

Args:
user_id: The user identifier to look up sessions for.
limit: Maximum number of session IDs to return. If None, returns all sessions.
offset: Number of sessions to skip before returning results. Defaults to 0.

Returns:
List of session IDs belonging to the user, ordered by most recently updated first.
"""
async with self._locked_connection() as conn:
if limit is None:
cursor = await conn.execute(
f"""
SELECT session_id FROM {self.sessions_table}
WHERE user_id = ?
ORDER BY updated_at DESC
LIMIT -1 OFFSET ?
""",
(user_id, offset),
)
else:
cursor = await conn.execute(
f"""
SELECT session_id FROM {self.sessions_table}
WHERE user_id = ?
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
""",
(user_id, limit, offset),
)
rows = await cursor.fetchall()
await cursor.close()
return [row[0] for row in rows]

async def close(self) -> None:
"""Close the database connection."""
if self._connection is None:
Expand Down
82 changes: 81 additions & 1 deletion src/agents/extensions/memory/sqlalchemy_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def __init__(
create_tables: bool = False,
sessions_table: str = "agent_sessions",
messages_table: str = "agent_messages",
users_table: str = "agent_users",
user_id: str | None = None,
session_settings: SessionSettings | None = None,
):
"""Initializes a new SQLAlchemySession.
Expand All @@ -100,9 +102,13 @@ def __init__(
development and testing when migrations aren't used.
sessions_table (str, optional): Override the default table name for sessions if needed.
messages_table (str, optional): Override the default table name for messages if needed.
users_table (str, optional): Override the default table name for users if needed.
user_id (str | None, optional): Optional user identifier to associate this session
with a user in the agent_users table.
session_settings (SessionSettings | None, optional): Session configuration settings
"""
self.session_id = session_id
self.user_id = user_id
self.session_settings = session_settings or SessionSettings()
self._engine = engine
self._init_lock = (
Expand All @@ -112,10 +118,36 @@ def __init__(
)

self._metadata = MetaData()
self._users = Table(
users_table,
self._metadata,
Column("user_id", String, primary_key=True),
Column("metadata", Text, nullable=True),
Column(
"created_at",
TIMESTAMP(timezone=False),
server_default=sql_text("CURRENT_TIMESTAMP"),
nullable=False,
),
Column(
"updated_at",
TIMESTAMP(timezone=False),
server_default=sql_text("CURRENT_TIMESTAMP"),
onupdate=sql_text("CURRENT_TIMESTAMP"),
nullable=False,
),
)

self._sessions = Table(
sessions_table,
self._metadata,
Column("session_id", String, primary_key=True),
Column(
"user_id",
String,
ForeignKey(f"{users_table}.user_id", ondelete="SET NULL"),
nullable=True,
),
Column(
"created_at",
TIMESTAMP(timezone=False),
Expand All @@ -129,6 +161,7 @@ def __init__(
onupdate=sql_text("CURRENT_TIMESTAMP"),
nullable=False,
),
Index(f"idx_{sessions_table}_user_id", "user_id"),
)

self._messages = Table(
Expand Down Expand Up @@ -296,6 +329,22 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:

async with self._session_factory() as sess:
async with sess.begin():
# Ensure user exists if user_id is provided
if self.user_id is not None:
existing_user = await sess.execute(
select(self._users.c.user_id).where(
self._users.c.user_id == self.user_id
)
)
if not existing_user.scalar_one_or_none():
try:
async with sess.begin_nested():
await sess.execute(
insert(self._users).values({"user_id": self.user_id})
)
except IntegrityError:
pass

# Avoid check-then-insert races on the first write while keeping
# the common path free of avoidable integrity exceptions.
existing = await sess.execute(
Expand All @@ -307,7 +356,9 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
try:
async with sess.begin_nested():
await sess.execute(
insert(self._sessions).values({"session_id": self.session_id})
insert(self._sessions).values(
{"session_id": self.session_id, "user_id": self.user_id}
)
)
except IntegrityError:
# Another concurrent writer created the parent row first.
Expand Down Expand Up @@ -372,6 +423,35 @@ async def clear_session(self) -> None:
delete(self._sessions).where(self._sessions.c.session_id == self.session_id)
)

async def get_sessions_for_user(
self,
user_id: str,
limit: int | None = None,
offset: int = 0,
) -> list[str]:
"""Retrieve session IDs associated with a given user.

Args:
user_id: The user identifier to look up sessions for.
limit: Maximum number of session IDs to return. If None, returns all sessions.
offset: Number of sessions to skip before returning results. Defaults to 0.

Returns:
List of session IDs belonging to the user, ordered by most recently updated first.
"""
await self._ensure_tables()
async with self._session_factory() as sess:
stmt = (
select(self._sessions.c.session_id)
.where(self._sessions.c.user_id == user_id)
.order_by(self._sessions.c.updated_at.desc())
.offset(offset)
)
if limit is not None:
stmt = stmt.limit(limit)
result = await sess.execute(stmt)
return [row[0] for row in result.all()]

@property
def engine(self) -> AsyncEngine:
"""Access the underlying SQLAlchemy AsyncEngine.
Expand Down
2 changes: 2 additions & 0 deletions src/agents/memory/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Session(Protocol):
"""

session_id: str
user_id: str | None = None
session_settings: SessionSettings | None = None

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
Expand Down Expand Up @@ -65,6 +66,7 @@ class SessionABC(ABC):
"""

session_id: str
user_id: str | None = None
session_settings: SessionSettings | None = None

@abstractmethod
Expand Down
Loading