From e8b2036cb3b1d8cee5e9eabe35e03d85ad1a830b Mon Sep 17 00:00:00 2001 From: ecanlar Date: Tue, 31 Mar 2026 11:00:56 +0200 Subject: [PATCH 1/4] Add agent_users table to associate sessions with users MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces user-session association across all SQL-based session backends (SQLiteSession, AsyncSQLiteSession, SQLAlchemySession), similar to how Google's ADK models the User → Session relationship. Changes: - Add agent_users table with user_id, metadata, and timestamps - Add user_id foreign key to agent_sessions table - Add optional user_id parameter to session constructors - Add get_sessions_for_user() method to query sessions by user - Add user_id attribute to Session protocol and SessionABC - Add tests for user association functionality Closes #2808 Co-Authored-By: Claude Opus 4.6 (1M context) --- .../extensions/memory/async_sqlite_session.py | 65 ++++++++++++++++- .../extensions/memory/sqlalchemy_session.py | 72 ++++++++++++++++++- src/agents/memory/session.py | 2 + src/agents/memory/sqlite_session.py | 69 +++++++++++++++++- tests/test_session.py | 52 ++++++++++++++ 5 files changed, 253 insertions(+), 7 deletions(-) diff --git a/src/agents/extensions/memory/async_sqlite_session.py b/src/agents/extensions/memory/async_sqlite_session.py index 2eef596264..f9e0d2a89a 100644 --- a/src/agents/extensions/memory/async_sqlite_session.py +++ b/src/agents/extensions/memory/async_sqlite_session.py @@ -30,6 +30,8 @@ def __init__( db_path: str | Path = ":memory:", sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", + users_table: str = "agent_users", + user_id: str | None = None, ): """Initialize the async SQLite session. @@ -39,27 +41,52 @@ def __init__( sessions_table: Name of the table to store session metadata. Defaults to 'agent_sessions' messages_table: Name of the table to store message data. Defaults to 'agent_messages' + users_table: Name of the table to store user metadata. Defaults to 'agent_users' + user_id: Optional user identifier to associate this session with a user. """ self.session_id = session_id + self.user_id = user_id self.db_path = db_path self.sessions_table = sessions_table self.messages_table = messages_table + self.users_table = users_table self._connection: aiosqlite.Connection | None = None self._lock = asyncio.Lock() self._init_lock = asyncio.Lock() async def _init_db_for_connection(self, conn: aiosqlite.Connection) -> None: """Initialize the database schema for a specific connection.""" + await conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.users_table} ( + user_id TEXT PRIMARY KEY, + metadata TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + await conn.execute( f""" CREATE TABLE IF NOT EXISTS {self.sessions_table} ( session_id TEXT PRIMARY KEY, + user_id TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES {self.users_table} (user_id) + ON DELETE SET NULL ) """ ) + await conn.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_{self.sessions_table}_user_id + ON {self.sessions_table} (user_id) + """ + ) + await conn.execute( f""" CREATE TABLE IF NOT EXISTS {self.messages_table} ( @@ -160,11 +187,21 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: return async with self._locked_connection() as conn: + # Ensure user exists if user_id is provided + if self.user_id is not None: + await conn.execute( + f""" + INSERT OR IGNORE INTO {self.users_table} (user_id) VALUES (?) + """, + (self.user_id,), + ) + await conn.execute( f""" - INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) + INSERT OR IGNORE INTO {self.sessions_table} (session_id, user_id) + VALUES (?, ?) """, - (self.session_id,), + (self.session_id, self.user_id), ) message_data = [(self.session_id, json.dumps(item)) for item in items] @@ -233,6 +270,28 @@ async def clear_session(self) -> None: ) await conn.commit() + async def get_sessions_for_user(self, user_id: str) -> list[str]: + """Retrieve all session IDs associated with a given user. + + Args: + user_id: The user identifier to look up sessions for. + + Returns: + List of session IDs belonging to the user, ordered by most recently updated first. + """ + async with self._locked_connection() as conn: + cursor = await conn.execute( + f""" + SELECT session_id FROM {self.sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + """, + (user_id,), + ) + rows = await cursor.fetchall() + await cursor.close() + return [row[0] for row in rows] + async def close(self) -> None: """Close the database connection.""" if self._connection is None: diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index 759ddaf5d5..3feb6047d4 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -86,6 +86,8 @@ def __init__( create_tables: bool = False, sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", + users_table: str = "agent_users", + user_id: str | None = None, session_settings: SessionSettings | None = None, ): """Initializes a new SQLAlchemySession. @@ -100,9 +102,13 @@ def __init__( development and testing when migrations aren't used. sessions_table (str, optional): Override the default table name for sessions if needed. messages_table (str, optional): Override the default table name for messages if needed. + users_table (str, optional): Override the default table name for users if needed. + user_id (str | None, optional): Optional user identifier to associate this session + with a user in the agent_users table. session_settings (SessionSettings | None, optional): Session configuration settings """ self.session_id = session_id + self.user_id = user_id self.session_settings = session_settings or SessionSettings() self._engine = engine self._init_lock = ( @@ -112,10 +118,36 @@ def __init__( ) self._metadata = MetaData() + self._users = Table( + users_table, + self._metadata, + Column("user_id", String, primary_key=True), + Column("metadata", Text, nullable=True), + Column( + "created_at", + TIMESTAMP(timezone=False), + server_default=sql_text("CURRENT_TIMESTAMP"), + nullable=False, + ), + Column( + "updated_at", + TIMESTAMP(timezone=False), + server_default=sql_text("CURRENT_TIMESTAMP"), + onupdate=sql_text("CURRENT_TIMESTAMP"), + nullable=False, + ), + ) + self._sessions = Table( sessions_table, self._metadata, Column("session_id", String, primary_key=True), + Column( + "user_id", + String, + ForeignKey(f"{users_table}.user_id", ondelete="SET NULL"), + nullable=True, + ), Column( "created_at", TIMESTAMP(timezone=False), @@ -129,6 +161,7 @@ def __init__( onupdate=sql_text("CURRENT_TIMESTAMP"), nullable=False, ), + Index(f"idx_{sessions_table}_user_id", "user_id"), ) self._messages = Table( @@ -296,6 +329,22 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: async with self._session_factory() as sess: async with sess.begin(): + # Ensure user exists if user_id is provided + if self.user_id is not None: + existing_user = await sess.execute( + select(self._users.c.user_id).where( + self._users.c.user_id == self.user_id + ) + ) + if not existing_user.scalar_one_or_none(): + try: + async with sess.begin_nested(): + await sess.execute( + insert(self._users).values({"user_id": self.user_id}) + ) + except IntegrityError: + pass + # Avoid check-then-insert races on the first write while keeping # the common path free of avoidable integrity exceptions. existing = await sess.execute( @@ -307,7 +356,9 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: try: async with sess.begin_nested(): await sess.execute( - insert(self._sessions).values({"session_id": self.session_id}) + insert(self._sessions).values( + {"session_id": self.session_id, "user_id": self.user_id} + ) ) except IntegrityError: # Another concurrent writer created the parent row first. @@ -372,6 +423,25 @@ async def clear_session(self) -> None: delete(self._sessions).where(self._sessions.c.session_id == self.session_id) ) + async def get_sessions_for_user(self, user_id: str) -> list[str]: + """Retrieve all session IDs associated with a given user. + + Args: + user_id: The user identifier to look up sessions for. + + Returns: + List of session IDs belonging to the user, ordered by most recently updated first. + """ + await self._ensure_tables() + async with self._session_factory() as sess: + stmt = ( + select(self._sessions.c.session_id) + .where(self._sessions.c.user_id == user_id) + .order_by(self._sessions.c.updated_at.desc()) + ) + result = await sess.execute(stmt) + return [row[0] for row in result.all()] + @property def engine(self) -> AsyncEngine: """Access the underlying SQLAlchemy AsyncEngine. diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 85a65a1690..539f658f84 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -19,6 +19,7 @@ class Session(Protocol): """ session_id: str + user_id: str | None = None session_settings: SessionSettings | None = None async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: @@ -65,6 +66,7 @@ class SessionABC(ABC): """ session_id: str + user_id: str | None = None session_settings: SessionSettings | None = None @abstractmethod diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 92c9630c9b..5475f179c1 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -27,6 +27,8 @@ def __init__( db_path: str | Path = ":memory:", sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", + users_table: str = "agent_users", + user_id: str | None = None, session_settings: SessionSettings | None = None, ): """Initialize the SQLite session. @@ -37,14 +39,19 @@ def __init__( sessions_table: Name of the table to store session metadata. Defaults to 'agent_sessions' messages_table: Name of the table to store message data. Defaults to 'agent_messages' + users_table: Name of the table to store user metadata. Defaults to 'agent_users' + user_id: Optional user identifier to associate this session with a user. + When provided, the session will be linked to the user in the agent_users table. session_settings: Session configuration settings including default limit for retrieving items. If None, uses default SessionSettings(). """ self.session_id = session_id + self.user_id = user_id self.session_settings = session_settings or SessionSettings() self.db_path = db_path self.sessions_table = sessions_table self.messages_table = messages_table + self.users_table = users_table self._local = threading.local() self._lock = threading.Lock() @@ -82,16 +89,37 @@ def _get_connection(self) -> sqlite3.Connection: def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: """Initialize the database schema for a specific connection.""" + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.users_table} ( + user_id TEXT PRIMARY KEY, + metadata TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + conn.execute( f""" CREATE TABLE IF NOT EXISTS {self.sessions_table} ( session_id TEXT PRIMARY KEY, + user_id TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES {self.users_table} (user_id) + ON DELETE SET NULL ) """ ) + conn.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_{self.sessions_table}_user_id + ON {self.sessions_table} (user_id) + """ + ) + conn.execute( f""" CREATE TABLE IF NOT EXISTS {self.messages_table} ( @@ -183,12 +211,22 @@ def _add_items_sync(): conn = self._get_connection() with self._lock if self._is_memory_db else threading.Lock(): + # Ensure user exists if user_id is provided + if self.user_id is not None: + conn.execute( + f""" + INSERT OR IGNORE INTO {self.users_table} (user_id) VALUES (?) + """, + (self.user_id,), + ) + # Ensure session exists conn.execute( f""" - INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) + INSERT OR IGNORE INTO {self.sessions_table} (session_id, user_id) + VALUES (?, ?) """, - (self.session_id,), + (self.session_id, self.user_id), ) # Add items @@ -273,6 +311,31 @@ def _clear_session_sync(): await asyncio.to_thread(_clear_session_sync) + async def get_sessions_for_user(self, user_id: str) -> list[str]: + """Retrieve all session IDs associated with a given user. + + Args: + user_id: The user identifier to look up sessions for. + + Returns: + List of session IDs belonging to the user, ordered by most recently updated first. + """ + + def _get_sessions_sync(): + conn = self._get_connection() + with self._lock if self._is_memory_db else threading.Lock(): + cursor = conn.execute( + f""" + SELECT session_id FROM {self.sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + """, + (user_id,), + ) + return [row[0] for row in cursor.fetchall()] + + return await asyncio.to_thread(_get_sessions_sync) + def close(self) -> None: """Close the database connection.""" if self._is_memory_db: diff --git a/tests/test_session.py b/tests/test_session.py index aaa80ec7aa..22dc34883a 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -671,6 +671,58 @@ async def test_session_settings_resolve(): assert final_none.limit == 100 +@pytest.mark.asyncio +async def test_sqlite_session_user_association(): + """Test that sessions can be associated with users via user_id.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_users.db" + + # Create sessions for user_1 + session_a = SQLiteSession("session_a", db_path, user_id="user_1") + session_b = SQLiteSession("session_b", db_path, user_id="user_1") + # Create a session for user_2 + session_c = SQLiteSession("session_c", db_path, user_id="user_2") + # Create a session without a user + session_d = SQLiteSession("session_d", db_path) + + # Add items to trigger session/user creation + items: list[TResponseInputItem] = [{"role": "user", "content": "Hello"}] + await session_a.add_items(items) + await session_b.add_items(items) + await session_c.add_items(items) + await session_d.add_items(items) + + # Query sessions for user_1 + user_1_sessions = await session_a.get_sessions_for_user("user_1") + assert set(user_1_sessions) == {"session_a", "session_b"} + + # Query sessions for user_2 + user_2_sessions = await session_a.get_sessions_for_user("user_2") + assert user_2_sessions == ["session_c"] + + # Query sessions for non-existent user + empty_sessions = await session_a.get_sessions_for_user("user_999") + assert empty_sessions == [] + + session_a.close() + session_b.close() + session_c.close() + session_d.close() + + +@pytest.mark.asyncio +async def test_sqlite_session_user_id_attribute(): + """Test that user_id is correctly stored on the session instance.""" + session_with_user = SQLiteSession("s1", user_id="alice") + assert session_with_user.user_id == "alice" + + session_without_user = SQLiteSession("s2") + assert session_without_user.user_id is None + + session_with_user.close() + session_without_user.close() + + @pytest.mark.asyncio async def test_runner_with_session_settings_override(): """Test that RunConfig can override session's default settings.""" From 826a5fc60e93f9770d0b6afb59883f7b45c506f4 Mon Sep 17 00:00:00 2001 From: ecanlar Date: Tue, 31 Mar 2026 11:05:33 +0200 Subject: [PATCH 2/4] Add limit/offset pagination to get_sessions_for_user Adds limit and offset parameters to get_sessions_for_user() across all three SQL backends, consistent with how get_items() supports limiting retrieved history. This enables paginated retrieval of user sessions. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../extensions/memory/async_sqlite_session.py | 39 +++++++++++----- .../extensions/memory/sqlalchemy_session.py | 14 +++++- src/agents/memory/sqlite_session.py | 39 +++++++++++----- tests/test_session.py | 45 +++++++++++++++++++ 4 files changed, 115 insertions(+), 22 deletions(-) diff --git a/src/agents/extensions/memory/async_sqlite_session.py b/src/agents/extensions/memory/async_sqlite_session.py index f9e0d2a89a..c16e332e45 100644 --- a/src/agents/extensions/memory/async_sqlite_session.py +++ b/src/agents/extensions/memory/async_sqlite_session.py @@ -270,24 +270,43 @@ async def clear_session(self) -> None: ) await conn.commit() - async def get_sessions_for_user(self, user_id: str) -> list[str]: - """Retrieve all session IDs associated with a given user. + async def get_sessions_for_user( + self, + user_id: str, + limit: int | None = None, + offset: int = 0, + ) -> list[str]: + """Retrieve session IDs associated with a given user. Args: user_id: The user identifier to look up sessions for. + limit: Maximum number of session IDs to return. If None, returns all sessions. + offset: Number of sessions to skip before returning results. Defaults to 0. Returns: List of session IDs belonging to the user, ordered by most recently updated first. """ async with self._locked_connection() as conn: - cursor = await conn.execute( - f""" - SELECT session_id FROM {self.sessions_table} - WHERE user_id = ? - ORDER BY updated_at DESC - """, - (user_id,), - ) + if limit is None: + cursor = await conn.execute( + f""" + SELECT session_id FROM {self.sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT -1 OFFSET ? + """, + (user_id, offset), + ) + else: + cursor = await conn.execute( + f""" + SELECT session_id FROM {self.sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT ? OFFSET ? + """, + (user_id, limit, offset), + ) rows = await cursor.fetchall() await cursor.close() return [row[0] for row in rows] diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index 3feb6047d4..62c5c6141f 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -423,11 +423,18 @@ async def clear_session(self) -> None: delete(self._sessions).where(self._sessions.c.session_id == self.session_id) ) - async def get_sessions_for_user(self, user_id: str) -> list[str]: - """Retrieve all session IDs associated with a given user. + async def get_sessions_for_user( + self, + user_id: str, + limit: int | None = None, + offset: int = 0, + ) -> list[str]: + """Retrieve session IDs associated with a given user. Args: user_id: The user identifier to look up sessions for. + limit: Maximum number of session IDs to return. If None, returns all sessions. + offset: Number of sessions to skip before returning results. Defaults to 0. Returns: List of session IDs belonging to the user, ordered by most recently updated first. @@ -438,7 +445,10 @@ async def get_sessions_for_user(self, user_id: str) -> list[str]: select(self._sessions.c.session_id) .where(self._sessions.c.user_id == user_id) .order_by(self._sessions.c.updated_at.desc()) + .offset(offset) ) + if limit is not None: + stmt = stmt.limit(limit) result = await sess.execute(stmt) return [row[0] for row in result.all()] diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 5475f179c1..1b514950be 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -311,11 +311,18 @@ def _clear_session_sync(): await asyncio.to_thread(_clear_session_sync) - async def get_sessions_for_user(self, user_id: str) -> list[str]: - """Retrieve all session IDs associated with a given user. + async def get_sessions_for_user( + self, + user_id: str, + limit: int | None = None, + offset: int = 0, + ) -> list[str]: + """Retrieve session IDs associated with a given user. Args: user_id: The user identifier to look up sessions for. + limit: Maximum number of session IDs to return. If None, returns all sessions. + offset: Number of sessions to skip before returning results. Defaults to 0. Returns: List of session IDs belonging to the user, ordered by most recently updated first. @@ -324,14 +331,26 @@ async def get_sessions_for_user(self, user_id: str) -> list[str]: def _get_sessions_sync(): conn = self._get_connection() with self._lock if self._is_memory_db else threading.Lock(): - cursor = conn.execute( - f""" - SELECT session_id FROM {self.sessions_table} - WHERE user_id = ? - ORDER BY updated_at DESC - """, - (user_id,), - ) + if limit is None: + cursor = conn.execute( + f""" + SELECT session_id FROM {self.sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT -1 OFFSET ? + """, + (user_id, offset), + ) + else: + cursor = conn.execute( + f""" + SELECT session_id FROM {self.sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT ? OFFSET ? + """, + (user_id, limit, offset), + ) return [row[0] for row in cursor.fetchall()] return await asyncio.to_thread(_get_sessions_sync) diff --git a/tests/test_session.py b/tests/test_session.py index 22dc34883a..a325ad0ed2 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -710,6 +710,51 @@ async def test_sqlite_session_user_association(): session_d.close() +@pytest.mark.asyncio +async def test_sqlite_session_get_sessions_for_user_pagination(): + """Test limit and offset pagination on get_sessions_for_user.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_pagination.db" + items: list[TResponseInputItem] = [{"role": "user", "content": "hi"}] + + # Create 5 sessions for the same user, adding items sequentially so + # updated_at ordering is deterministic (most recent last). + session_ids = [f"s{i}" for i in range(5)] + sessions = [] + for sid in session_ids: + s = SQLiteSession(sid, db_path, user_id="paginated_user") + await s.add_items(items) + sessions.append(s) + + ref = sessions[0] # any session instance sharing the same db + + # Without limit — returns all 5 + all_ids = await ref.get_sessions_for_user("paginated_user") + assert len(all_ids) == 5 + + # limit=2 — returns the 2 most recently updated + page1 = await ref.get_sessions_for_user("paginated_user", limit=2) + assert len(page1) == 2 + + # limit=2, offset=2 — next page + page2 = await ref.get_sessions_for_user("paginated_user", limit=2, offset=2) + assert len(page2) == 2 + + # limit=2, offset=4 — last page (only 1 left) + page3 = await ref.get_sessions_for_user("paginated_user", limit=2, offset=4) + assert len(page3) == 1 + + # All pages together should cover all session ids + assert set(page1 + page2 + page3) == set(session_ids) + + # offset beyond total — empty + empty = await ref.get_sessions_for_user("paginated_user", offset=10) + assert empty == [] + + for s in sessions: + s.close() + + @pytest.mark.asyncio async def test_sqlite_session_user_id_attribute(): """Test that user_id is correctly stored on the session instance.""" From 65d97e10e1e9b9f5480c86391e6c9165f8c102c6 Mon Sep 17 00:00:00 2001 From: ecantn Date: Tue, 31 Mar 2026 11:16:36 +0200 Subject: [PATCH 3/4] feat: Add create_session and get_session classmethods to session backends Introduces two new classmethods for all SQL-based session backends: - create_session(user_id, ...): Creates a new session with an auto-generated UUID session_id, persisting user and session rows immediately. - get_session(user_id, session_id, ...): Retrieves an existing session, verifying it belongs to the given user. Returns None if not found. This builds on the agent_users table from #2809 and provides a proper factory pattern where the session_id is generated internally rather than requiring the caller to invent one. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../extensions/memory/async_sqlite_session.py | 94 ++++++++++++++ .../extensions/memory/sqlalchemy_session.py | 116 ++++++++++++++++++ src/agents/memory/session.py | 31 +++++ src/agents/memory/sqlite_session.py | 104 ++++++++++++++++ tests/test_session.py | 83 +++++++++++++ 5 files changed, 428 insertions(+) diff --git a/src/agents/extensions/memory/async_sqlite_session.py b/src/agents/extensions/memory/async_sqlite_session.py index c16e332e45..ff3479c28c 100644 --- a/src/agents/extensions/memory/async_sqlite_session.py +++ b/src/agents/extensions/memory/async_sqlite_session.py @@ -2,6 +2,7 @@ import asyncio import json +import uuid from collections.abc import AsyncIterator from contextlib import asynccontextmanager from pathlib import Path @@ -270,6 +271,99 @@ async def clear_session(self) -> None: ) await conn.commit() + @classmethod + async def create_session( + cls, + user_id: str, + db_path: str | Path = ":memory:", + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + users_table: str = "agent_users", + ) -> AsyncSQLiteSession: + """Create a new session for a user with an auto-generated session ID. + + Args: + user_id: The user identifier to associate with the new session. + db_path: Path to the SQLite database file. Defaults to ':memory:'. + sessions_table: Name of the sessions table. Defaults to 'agent_sessions'. + messages_table: Name of the messages table. Defaults to 'agent_messages'. + users_table: Name of the users table. Defaults to 'agent_users'. + + Returns: + A new AsyncSQLiteSession instance with an auto-generated session_id. + """ + session_id = str(uuid.uuid4()) + session = cls( + session_id=session_id, + db_path=db_path, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + ) + + async with session._locked_connection() as conn: + await conn.execute( + f"INSERT OR IGNORE INTO {users_table} (user_id) VALUES (?)", + (user_id,), + ) + await conn.execute( + f"INSERT INTO {sessions_table} (session_id, user_id) VALUES (?, ?)", + (session_id, user_id), + ) + await conn.commit() + + return session + + @classmethod + async def get_session( + cls, + user_id: str, + session_id: str, + db_path: str | Path = ":memory:", + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + users_table: str = "agent_users", + ) -> AsyncSQLiteSession | None: + """Retrieve an existing session for a user. + + Args: + user_id: The user identifier who owns the session. + session_id: The session identifier to retrieve. + db_path: Path to the SQLite database file. Defaults to ':memory:'. + sessions_table: Name of the sessions table. Defaults to 'agent_sessions'. + messages_table: Name of the messages table. Defaults to 'agent_messages'. + users_table: Name of the users table. Defaults to 'agent_users'. + + Returns: + The AsyncSQLiteSession instance if it exists and belongs to the user, + None otherwise. + """ + session = cls( + session_id=session_id, + db_path=db_path, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + ) + + async with session._locked_connection() as conn: + cursor = await conn.execute( + f""" + SELECT session_id FROM {sessions_table} + WHERE session_id = ? AND user_id = ? + """, + (session_id, user_id), + ) + row = await cursor.fetchone() + await cursor.close() + + if row is None: + await session.close() + return None + return session + async def get_sessions_for_user( self, user_id: str, diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index 62c5c6141f..b838b0dad4 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -26,6 +26,7 @@ import asyncio import json import threading +import uuid from typing import Any, ClassVar from sqlalchemy import ( @@ -423,6 +424,121 @@ async def clear_session(self) -> None: delete(self._sessions).where(self._sessions.c.session_id == self.session_id) ) + @classmethod + async def create_session( + cls, + user_id: str, + *, + engine: AsyncEngine, + create_tables: bool = False, + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + users_table: str = "agent_users", + session_settings: SessionSettings | None = None, + ) -> SQLAlchemySession: + """Create a new session for a user with an auto-generated session ID. + + Args: + user_id: The user identifier to associate with the new session. + engine: A pre-configured SQLAlchemy async engine. + create_tables: Whether to auto-create tables. Defaults to False. + sessions_table: Override the default table name for sessions. + messages_table: Override the default table name for messages. + users_table: Override the default table name for users. + session_settings: Session configuration settings. + + Returns: + A new SQLAlchemySession instance with an auto-generated session_id. + """ + session_id = str(uuid.uuid4()) + session = cls( + session_id=session_id, + engine=engine, + create_tables=create_tables, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + session_settings=session_settings, + ) + + await session._ensure_tables() + async with session._session_factory() as sess: + async with sess.begin(): + existing_user = await sess.execute( + select(session._users.c.user_id).where( + session._users.c.user_id == user_id + ) + ) + if not existing_user.scalar_one_or_none(): + try: + async with sess.begin_nested(): + await sess.execute( + insert(session._users).values({"user_id": user_id}) + ) + except IntegrityError: + pass + await sess.execute( + insert(session._sessions).values( + {"session_id": session_id, "user_id": user_id} + ) + ) + + return session + + @classmethod + async def get_session( + cls, + user_id: str, + session_id: str, + *, + engine: AsyncEngine, + create_tables: bool = False, + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + users_table: str = "agent_users", + session_settings: SessionSettings | None = None, + ) -> SQLAlchemySession | None: + """Retrieve an existing session for a user. + + Args: + user_id: The user identifier who owns the session. + session_id: The session identifier to retrieve. + engine: A pre-configured SQLAlchemy async engine. + create_tables: Whether to auto-create tables. Defaults to False. + sessions_table: Override the default table name for sessions. + messages_table: Override the default table name for messages. + users_table: Override the default table name for users. + session_settings: Session configuration settings. + + Returns: + The SQLAlchemySession instance if it exists and belongs to the user, + None otherwise. + """ + session = cls( + session_id=session_id, + engine=engine, + create_tables=create_tables, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + session_settings=session_settings, + ) + + await session._ensure_tables() + async with session._session_factory() as sess: + result = await sess.execute( + select(session._sessions.c.session_id).where( + (session._sessions.c.session_id == session_id) + & (session._sessions.c.user_id == user_id) + ) + ) + if not result.scalar_one_or_none(): + return None + + return session + async def get_sessions_for_user( self, user_id: str, diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 539f658f84..7dd6578443 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -69,6 +69,37 @@ class SessionABC(ABC): user_id: str | None = None session_settings: SessionSettings | None = None + @classmethod + @abstractmethod + async def create_session(cls, user_id: str, **kwargs: object) -> SessionABC: + """Create a new session for a user with an auto-generated session ID. + + Args: + user_id: The user identifier to associate with the new session. + **kwargs: Backend-specific configuration (e.g., db_path, engine). + + Returns: + A new session instance with an auto-generated session_id. + """ + ... + + @classmethod + @abstractmethod + async def get_session( + cls, user_id: str, session_id: str, **kwargs: object + ) -> SessionABC | None: + """Retrieve an existing session for a user. + + Args: + user_id: The user identifier who owns the session. + session_id: The session identifier to retrieve. + **kwargs: Backend-specific configuration (e.g., db_path, engine). + + Returns: + The session instance if it exists and belongs to the user, None otherwise. + """ + ... + @abstractmethod async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 1b514950be..820521dfc4 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -4,6 +4,7 @@ import json import sqlite3 import threading +import uuid from pathlib import Path from ..items import TResponseInputItem @@ -311,6 +312,109 @@ def _clear_session_sync(): await asyncio.to_thread(_clear_session_sync) + @classmethod + async def create_session( + cls, + user_id: str, + db_path: str | Path = ":memory:", + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + users_table: str = "agent_users", + session_settings: SessionSettings | None = None, + ) -> SQLiteSession: + """Create a new session for a user with an auto-generated session ID. + + Args: + user_id: The user identifier to associate with the new session. + db_path: Path to the SQLite database file. Defaults to ':memory:'. + sessions_table: Name of the sessions table. Defaults to 'agent_sessions'. + messages_table: Name of the messages table. Defaults to 'agent_messages'. + users_table: Name of the users table. Defaults to 'agent_users'. + session_settings: Session configuration settings. + + Returns: + A new SQLiteSession instance with an auto-generated session_id. + """ + session_id = str(uuid.uuid4()) + session = cls( + session_id=session_id, + db_path=db_path, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + session_settings=session_settings, + ) + + def _persist_session(): + conn = session._get_connection() + with session._lock if session._is_memory_db else threading.Lock(): + conn.execute( + f"INSERT OR IGNORE INTO {users_table} (user_id) VALUES (?)", + (user_id,), + ) + conn.execute( + f"INSERT INTO {sessions_table} (session_id, user_id) VALUES (?, ?)", + (session_id, user_id), + ) + conn.commit() + + await asyncio.to_thread(_persist_session) + return session + + @classmethod + async def get_session( + cls, + user_id: str, + session_id: str, + db_path: str | Path = ":memory:", + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + users_table: str = "agent_users", + session_settings: SessionSettings | None = None, + ) -> SQLiteSession | None: + """Retrieve an existing session for a user. + + Args: + user_id: The user identifier who owns the session. + session_id: The session identifier to retrieve. + db_path: Path to the SQLite database file. Defaults to ':memory:'. + sessions_table: Name of the sessions table. Defaults to 'agent_sessions'. + messages_table: Name of the messages table. Defaults to 'agent_messages'. + users_table: Name of the users table. Defaults to 'agent_users'. + session_settings: Session configuration settings. + + Returns: + The SQLiteSession instance if it exists and belongs to the user, None otherwise. + """ + session = cls( + session_id=session_id, + db_path=db_path, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + session_settings=session_settings, + ) + + def _check_session(): + conn = session._get_connection() + with session._lock if session._is_memory_db else threading.Lock(): + cursor = conn.execute( + f""" + SELECT session_id FROM {sessions_table} + WHERE session_id = ? AND user_id = ? + """, + (session_id, user_id), + ) + return cursor.fetchone() is not None + + exists = await asyncio.to_thread(_check_session) + if not exists: + session.close() + return None + return session + async def get_sessions_for_user( self, user_id: str, diff --git a/tests/test_session.py b/tests/test_session.py index a325ad0ed2..a2e16ec606 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -768,6 +768,89 @@ async def test_sqlite_session_user_id_attribute(): session_without_user.close() +@pytest.mark.asyncio +async def test_sqlite_create_session(): + """Test create_session generates a UUID session_id and persists user and session.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_create.db" + + session = await SQLiteSession.create_session("alice", db_path=db_path) + + # session_id should be a valid UUID + import uuid + + uuid.UUID(session.session_id) # raises if not valid + assert session.user_id == "alice" + + # Session should be queryable via get_sessions_for_user + sessions = await session.get_sessions_for_user("alice") + assert session.session_id in sessions + + # Adding items should work normally + items: list[TResponseInputItem] = [{"role": "user", "content": "Hello"}] + await session.add_items(items) + retrieved = await session.get_items() + assert len(retrieved) == 1 + assert retrieved[0]["content"] == "Hello" + + session.close() + + +@pytest.mark.asyncio +async def test_sqlite_get_session(): + """Test get_session retrieves an existing session or returns None.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_get.db" + + # Create a session first + created = await SQLiteSession.create_session("bob", db_path=db_path) + session_id = created.session_id + + # Add some items + items: list[TResponseInputItem] = [{"role": "user", "content": "Hi Bob"}] + await created.add_items(items) + created.close() + + # Retrieve it + retrieved = await SQLiteSession.get_session("bob", session_id, db_path=db_path) + assert retrieved is not None + assert retrieved.session_id == session_id + assert retrieved.user_id == "bob" + + # Verify items are preserved + history = await retrieved.get_items() + assert len(history) == 1 + assert history[0]["content"] == "Hi Bob" + retrieved.close() + + # Wrong user_id should return None + wrong_user = await SQLiteSession.get_session("eve", session_id, db_path=db_path) + assert wrong_user is None + + # Non-existent session_id should return None + missing = await SQLiteSession.get_session("bob", "non-existent-id", db_path=db_path) + assert missing is None + + +@pytest.mark.asyncio +async def test_sqlite_create_multiple_sessions_for_user(): + """Test creating multiple sessions for the same user.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_multi.db" + + s1 = await SQLiteSession.create_session("charlie", db_path=db_path) + s2 = await SQLiteSession.create_session("charlie", db_path=db_path) + + assert s1.session_id != s2.session_id + assert s1.user_id == s2.user_id == "charlie" + + sessions = await s1.get_sessions_for_user("charlie") + assert set(sessions) == {s1.session_id, s2.session_id} + + s1.close() + s2.close() + + @pytest.mark.asyncio async def test_runner_with_session_settings_override(): """Test that RunConfig can override session's default settings.""" From a483fdf28ac9150f2f0190beab44743600c9b597 Mon Sep 17 00:00:00 2001 From: ecantn Date: Tue, 31 Mar 2026 12:03:15 +0200 Subject: [PATCH 4/4] refactor: replace get_session with get_sessions, remove abstract methods from SessionABC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename get_session → get_sessions: returns all sessions for a user as a list of session instances (not just IDs), with limit/offset support - Remove create_session/get_session abstract methods from SessionABC to avoid forcing all implementations to support user_id - Update tests to cover get_sessions, pagination, and multi-user scenarios Co-Authored-By: Claude Opus 4.6 (1M context) --- .../extensions/memory/async_sqlite_session.py | 70 ++++++++++----- .../extensions/memory/sqlalchemy_session.py | 58 ++++++++----- src/agents/memory/session.py | 31 ------- src/agents/memory/sqlite_session.py | 77 +++++++++++------ tests/test_session.py | 85 +++++++++++++------ 5 files changed, 195 insertions(+), 126 deletions(-) diff --git a/src/agents/extensions/memory/async_sqlite_session.py b/src/agents/extensions/memory/async_sqlite_session.py index ff3479c28c..770ca771a1 100644 --- a/src/agents/extensions/memory/async_sqlite_session.py +++ b/src/agents/extensions/memory/async_sqlite_session.py @@ -316,31 +316,33 @@ async def create_session( return session @classmethod - async def get_session( + async def get_sessions( cls, user_id: str, - session_id: str, db_path: str | Path = ":memory:", sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", users_table: str = "agent_users", - ) -> AsyncSQLiteSession | None: - """Retrieve an existing session for a user. + limit: int | None = None, + offset: int = 0, + ) -> list[AsyncSQLiteSession]: + """Retrieve all sessions for a user. Args: - user_id: The user identifier who owns the session. - session_id: The session identifier to retrieve. + user_id: The user identifier to look up sessions for. db_path: Path to the SQLite database file. Defaults to ':memory:'. sessions_table: Name of the sessions table. Defaults to 'agent_sessions'. messages_table: Name of the messages table. Defaults to 'agent_messages'. users_table: Name of the users table. Defaults to 'agent_users'. + limit: Maximum number of sessions to return. If None, returns all sessions. + offset: Number of sessions to skip before returning results. Defaults to 0. Returns: - The AsyncSQLiteSession instance if it exists and belongs to the user, - None otherwise. + List of AsyncSQLiteSession instances belonging to the user, ordered by + most recently updated first. """ - session = cls( - session_id=session_id, + probe = cls( + session_id="__probe__", db_path=db_path, sessions_table=sessions_table, messages_table=messages_table, @@ -348,21 +350,43 @@ async def get_session( user_id=user_id, ) - async with session._locked_connection() as conn: - cursor = await conn.execute( - f""" - SELECT session_id FROM {sessions_table} - WHERE session_id = ? AND user_id = ? - """, - (session_id, user_id), - ) - row = await cursor.fetchone() + async with probe._locked_connection() as conn: + if limit is None: + cursor = await conn.execute( + f""" + SELECT session_id FROM {sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT -1 OFFSET ? + """, + (user_id, offset), + ) + else: + cursor = await conn.execute( + f""" + SELECT session_id FROM {sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT ? OFFSET ? + """, + (user_id, limit, offset), + ) + rows = await cursor.fetchall() await cursor.close() - if row is None: - await session.close() - return None - return session + await probe.close() + + return [ + cls( + session_id=row[0], + db_path=db_path, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + ) + for row in rows + ] async def get_sessions_for_user( self, diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index b838b0dad4..c7eabd5cea 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -487,10 +487,9 @@ async def create_session( return session @classmethod - async def get_session( + async def get_sessions( cls, user_id: str, - session_id: str, *, engine: AsyncEngine, create_tables: bool = False, @@ -498,25 +497,28 @@ async def get_session( messages_table: str = "agent_messages", users_table: str = "agent_users", session_settings: SessionSettings | None = None, - ) -> SQLAlchemySession | None: - """Retrieve an existing session for a user. + limit: int | None = None, + offset: int = 0, + ) -> list[SQLAlchemySession]: + """Retrieve all sessions for a user. Args: - user_id: The user identifier who owns the session. - session_id: The session identifier to retrieve. + user_id: The user identifier to look up sessions for. engine: A pre-configured SQLAlchemy async engine. create_tables: Whether to auto-create tables. Defaults to False. sessions_table: Override the default table name for sessions. messages_table: Override the default table name for messages. users_table: Override the default table name for users. session_settings: Session configuration settings. + limit: Maximum number of sessions to return. If None, returns all sessions. + offset: Number of sessions to skip before returning results. Defaults to 0. Returns: - The SQLAlchemySession instance if it exists and belongs to the user, - None otherwise. + List of SQLAlchemySession instances belonging to the user, ordered by + most recently updated first. """ - session = cls( - session_id=session_id, + probe = cls( + session_id="__probe__", engine=engine, create_tables=create_tables, sessions_table=sessions_table, @@ -526,18 +528,32 @@ async def get_session( session_settings=session_settings, ) - await session._ensure_tables() - async with session._session_factory() as sess: - result = await sess.execute( - select(session._sessions.c.session_id).where( - (session._sessions.c.session_id == session_id) - & (session._sessions.c.user_id == user_id) - ) + await probe._ensure_tables() + async with probe._session_factory() as sess: + stmt = ( + select(probe._sessions.c.session_id) + .where(probe._sessions.c.user_id == user_id) + .order_by(probe._sessions.c.updated_at.desc()) + .offset(offset) ) - if not result.scalar_one_or_none(): - return None - - return session + if limit is not None: + stmt = stmt.limit(limit) + result = await sess.execute(stmt) + session_ids = [row[0] for row in result.all()] + + return [ + cls( + session_id=sid, + engine=engine, + create_tables=False, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + session_settings=session_settings, + ) + for sid in session_ids + ] async def get_sessions_for_user( self, diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 7dd6578443..539f658f84 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -69,37 +69,6 @@ class SessionABC(ABC): user_id: str | None = None session_settings: SessionSettings | None = None - @classmethod - @abstractmethod - async def create_session(cls, user_id: str, **kwargs: object) -> SessionABC: - """Create a new session for a user with an auto-generated session ID. - - Args: - user_id: The user identifier to associate with the new session. - **kwargs: Backend-specific configuration (e.g., db_path, engine). - - Returns: - A new session instance with an auto-generated session_id. - """ - ... - - @classmethod - @abstractmethod - async def get_session( - cls, user_id: str, session_id: str, **kwargs: object - ) -> SessionABC | None: - """Retrieve an existing session for a user. - - Args: - user_id: The user identifier who owns the session. - session_id: The session identifier to retrieve. - **kwargs: Backend-specific configuration (e.g., db_path, engine). - - Returns: - The session instance if it exists and belongs to the user, None otherwise. - """ - ... - @abstractmethod async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 820521dfc4..d8b329ce02 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -363,32 +363,36 @@ def _persist_session(): return session @classmethod - async def get_session( + async def get_sessions( cls, user_id: str, - session_id: str, db_path: str | Path = ":memory:", sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", users_table: str = "agent_users", session_settings: SessionSettings | None = None, - ) -> SQLiteSession | None: - """Retrieve an existing session for a user. + limit: int | None = None, + offset: int = 0, + ) -> list[SQLiteSession]: + """Retrieve all sessions for a user. Args: - user_id: The user identifier who owns the session. - session_id: The session identifier to retrieve. + user_id: The user identifier to look up sessions for. db_path: Path to the SQLite database file. Defaults to ':memory:'. sessions_table: Name of the sessions table. Defaults to 'agent_sessions'. messages_table: Name of the messages table. Defaults to 'agent_messages'. users_table: Name of the users table. Defaults to 'agent_users'. session_settings: Session configuration settings. + limit: Maximum number of sessions to return. If None, returns all sessions. + offset: Number of sessions to skip before returning results. Defaults to 0. Returns: - The SQLiteSession instance if it exists and belongs to the user, None otherwise. + List of SQLiteSession instances belonging to the user, ordered by most + recently updated first. """ - session = cls( - session_id=session_id, + # Use a temporary instance to access the DB and query session IDs + probe = cls( + session_id="__probe__", db_path=db_path, sessions_table=sessions_table, messages_table=messages_table, @@ -397,23 +401,46 @@ async def get_session( session_settings=session_settings, ) - def _check_session(): - conn = session._get_connection() - with session._lock if session._is_memory_db else threading.Lock(): - cursor = conn.execute( - f""" - SELECT session_id FROM {sessions_table} - WHERE session_id = ? AND user_id = ? - """, - (session_id, user_id), - ) - return cursor.fetchone() is not None + def _fetch_ids(): + conn = probe._get_connection() + with probe._lock if probe._is_memory_db else threading.Lock(): + if limit is None: + cursor = conn.execute( + f""" + SELECT session_id FROM {sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT -1 OFFSET ? + """, + (user_id, offset), + ) + else: + cursor = conn.execute( + f""" + SELECT session_id FROM {sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT ? OFFSET ? + """, + (user_id, limit, offset), + ) + return [row[0] for row in cursor.fetchall()] - exists = await asyncio.to_thread(_check_session) - if not exists: - session.close() - return None - return session + session_ids = await asyncio.to_thread(_fetch_ids) + probe.close() + + return [ + cls( + session_id=sid, + db_path=db_path, + sessions_table=sessions_table, + messages_table=messages_table, + users_table=users_table, + user_id=user_id, + session_settings=session_settings, + ) + for sid in session_ids + ] async def get_sessions_for_user( self, diff --git a/tests/test_session.py b/tests/test_session.py index a2e16ec606..8d4c90a9da 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -797,39 +797,70 @@ async def test_sqlite_create_session(): @pytest.mark.asyncio -async def test_sqlite_get_session(): - """Test get_session retrieves an existing session or returns None.""" +async def test_sqlite_get_sessions(): + """Test get_sessions retrieves all sessions for a user.""" with tempfile.TemporaryDirectory() as temp_dir: db_path = Path(temp_dir) / "test_get.db" - # Create a session first - created = await SQLiteSession.create_session("bob", db_path=db_path) - session_id = created.session_id + # Create sessions for bob + s1 = await SQLiteSession.create_session("bob", db_path=db_path) + s2 = await SQLiteSession.create_session("bob", db_path=db_path) + # Create a session for eve + await SQLiteSession.create_session("eve", db_path=db_path) - # Add some items + # Add items to s1 items: list[TResponseInputItem] = [{"role": "user", "content": "Hi Bob"}] - await created.add_items(items) - created.close() + await s1.add_items(items) + s1.close() + s2.close() - # Retrieve it - retrieved = await SQLiteSession.get_session("bob", session_id, db_path=db_path) - assert retrieved is not None - assert retrieved.session_id == session_id - assert retrieved.user_id == "bob" + # Retrieve bob's sessions + bob_sessions = await SQLiteSession.get_sessions("bob", db_path=db_path) + assert len(bob_sessions) == 2 + session_ids = {s.session_id for s in bob_sessions} + assert s1.session_id in session_ids + assert s2.session_id in session_ids + + # Each returned session should be usable + for s in bob_sessions: + assert s.user_id == "bob" + if s.session_id == s1.session_id: + history = await s.get_items() + assert len(history) == 1 + assert history[0]["content"] == "Hi Bob" + s.close() + + # Non-existent user returns empty list + empty = await SQLiteSession.get_sessions("nobody", db_path=db_path) + assert empty == [] - # Verify items are preserved - history = await retrieved.get_items() - assert len(history) == 1 - assert history[0]["content"] == "Hi Bob" - retrieved.close() - # Wrong user_id should return None - wrong_user = await SQLiteSession.get_session("eve", session_id, db_path=db_path) - assert wrong_user is None +@pytest.mark.asyncio +async def test_sqlite_get_sessions_pagination(): + """Test get_sessions supports limit and offset.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_pagination_get.db" - # Non-existent session_id should return None - missing = await SQLiteSession.get_session("bob", "non-existent-id", db_path=db_path) - assert missing is None + created = [] + for _ in range(5): + s = await SQLiteSession.create_session("paguser", db_path=db_path) + created.append(s) + s.close() + + page1 = await SQLiteSession.get_sessions("paguser", db_path=db_path, limit=2) + assert len(page1) == 2 + + page2 = await SQLiteSession.get_sessions("paguser", db_path=db_path, limit=2, offset=2) + assert len(page2) == 2 + + page3 = await SQLiteSession.get_sessions("paguser", db_path=db_path, limit=2, offset=4) + assert len(page3) == 1 + + all_ids = {s.session_id for s in page1 + page2 + page3} + assert all_ids == {s.session_id for s in created} + + for s in page1 + page2 + page3: + s.close() @pytest.mark.asyncio @@ -844,11 +875,13 @@ async def test_sqlite_create_multiple_sessions_for_user(): assert s1.session_id != s2.session_id assert s1.user_id == s2.user_id == "charlie" - sessions = await s1.get_sessions_for_user("charlie") - assert set(sessions) == {s1.session_id, s2.session_id} + sessions = await SQLiteSession.get_sessions("charlie", db_path=db_path) + assert {s.session_id for s in sessions} == {s1.session_id, s2.session_id} s1.close() s2.close() + for s in sessions: + s.close() @pytest.mark.asyncio