From e8b2036cb3b1d8cee5e9eabe35e03d85ad1a830b Mon Sep 17 00:00:00 2001 From: ecanlar Date: Tue, 31 Mar 2026 11:00:56 +0200 Subject: [PATCH 1/2] Add agent_users table to associate sessions with users MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces user-session association across all SQL-based session backends (SQLiteSession, AsyncSQLiteSession, SQLAlchemySession), similar to how Google's ADK models the User → Session relationship. Changes: - Add agent_users table with user_id, metadata, and timestamps - Add user_id foreign key to agent_sessions table - Add optional user_id parameter to session constructors - Add get_sessions_for_user() method to query sessions by user - Add user_id attribute to Session protocol and SessionABC - Add tests for user association functionality Closes #2808 Co-Authored-By: Claude Opus 4.6 (1M context) --- .../extensions/memory/async_sqlite_session.py | 65 ++++++++++++++++- .../extensions/memory/sqlalchemy_session.py | 72 ++++++++++++++++++- src/agents/memory/session.py | 2 + src/agents/memory/sqlite_session.py | 69 +++++++++++++++++- tests/test_session.py | 52 ++++++++++++++ 5 files changed, 253 insertions(+), 7 deletions(-) diff --git a/src/agents/extensions/memory/async_sqlite_session.py b/src/agents/extensions/memory/async_sqlite_session.py index 2eef596264..f9e0d2a89a 100644 --- a/src/agents/extensions/memory/async_sqlite_session.py +++ b/src/agents/extensions/memory/async_sqlite_session.py @@ -30,6 +30,8 @@ def __init__( db_path: str | Path = ":memory:", sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", + users_table: str = "agent_users", + user_id: str | None = None, ): """Initialize the async SQLite session. @@ -39,27 +41,52 @@ def __init__( sessions_table: Name of the table to store session metadata. Defaults to 'agent_sessions' messages_table: Name of the table to store message data. Defaults to 'agent_messages' + users_table: Name of the table to store user metadata. Defaults to 'agent_users' + user_id: Optional user identifier to associate this session with a user. """ self.session_id = session_id + self.user_id = user_id self.db_path = db_path self.sessions_table = sessions_table self.messages_table = messages_table + self.users_table = users_table self._connection: aiosqlite.Connection | None = None self._lock = asyncio.Lock() self._init_lock = asyncio.Lock() async def _init_db_for_connection(self, conn: aiosqlite.Connection) -> None: """Initialize the database schema for a specific connection.""" + await conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.users_table} ( + user_id TEXT PRIMARY KEY, + metadata TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + await conn.execute( f""" CREATE TABLE IF NOT EXISTS {self.sessions_table} ( session_id TEXT PRIMARY KEY, + user_id TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES {self.users_table} (user_id) + ON DELETE SET NULL ) """ ) + await conn.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_{self.sessions_table}_user_id + ON {self.sessions_table} (user_id) + """ + ) + await conn.execute( f""" CREATE TABLE IF NOT EXISTS {self.messages_table} ( @@ -160,11 +187,21 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: return async with self._locked_connection() as conn: + # Ensure user exists if user_id is provided + if self.user_id is not None: + await conn.execute( + f""" + INSERT OR IGNORE INTO {self.users_table} (user_id) VALUES (?) + """, + (self.user_id,), + ) + await conn.execute( f""" - INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) + INSERT OR IGNORE INTO {self.sessions_table} (session_id, user_id) + VALUES (?, ?) """, - (self.session_id,), + (self.session_id, self.user_id), ) message_data = [(self.session_id, json.dumps(item)) for item in items] @@ -233,6 +270,28 @@ async def clear_session(self) -> None: ) await conn.commit() + async def get_sessions_for_user(self, user_id: str) -> list[str]: + """Retrieve all session IDs associated with a given user. + + Args: + user_id: The user identifier to look up sessions for. + + Returns: + List of session IDs belonging to the user, ordered by most recently updated first. + """ + async with self._locked_connection() as conn: + cursor = await conn.execute( + f""" + SELECT session_id FROM {self.sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + """, + (user_id,), + ) + rows = await cursor.fetchall() + await cursor.close() + return [row[0] for row in rows] + async def close(self) -> None: """Close the database connection.""" if self._connection is None: diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index 759ddaf5d5..3feb6047d4 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -86,6 +86,8 @@ def __init__( create_tables: bool = False, sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", + users_table: str = "agent_users", + user_id: str | None = None, session_settings: SessionSettings | None = None, ): """Initializes a new SQLAlchemySession. @@ -100,9 +102,13 @@ def __init__( development and testing when migrations aren't used. sessions_table (str, optional): Override the default table name for sessions if needed. messages_table (str, optional): Override the default table name for messages if needed. + users_table (str, optional): Override the default table name for users if needed. + user_id (str | None, optional): Optional user identifier to associate this session + with a user in the agent_users table. session_settings (SessionSettings | None, optional): Session configuration settings """ self.session_id = session_id + self.user_id = user_id self.session_settings = session_settings or SessionSettings() self._engine = engine self._init_lock = ( @@ -112,10 +118,36 @@ def __init__( ) self._metadata = MetaData() + self._users = Table( + users_table, + self._metadata, + Column("user_id", String, primary_key=True), + Column("metadata", Text, nullable=True), + Column( + "created_at", + TIMESTAMP(timezone=False), + server_default=sql_text("CURRENT_TIMESTAMP"), + nullable=False, + ), + Column( + "updated_at", + TIMESTAMP(timezone=False), + server_default=sql_text("CURRENT_TIMESTAMP"), + onupdate=sql_text("CURRENT_TIMESTAMP"), + nullable=False, + ), + ) + self._sessions = Table( sessions_table, self._metadata, Column("session_id", String, primary_key=True), + Column( + "user_id", + String, + ForeignKey(f"{users_table}.user_id", ondelete="SET NULL"), + nullable=True, + ), Column( "created_at", TIMESTAMP(timezone=False), @@ -129,6 +161,7 @@ def __init__( onupdate=sql_text("CURRENT_TIMESTAMP"), nullable=False, ), + Index(f"idx_{sessions_table}_user_id", "user_id"), ) self._messages = Table( @@ -296,6 +329,22 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: async with self._session_factory() as sess: async with sess.begin(): + # Ensure user exists if user_id is provided + if self.user_id is not None: + existing_user = await sess.execute( + select(self._users.c.user_id).where( + self._users.c.user_id == self.user_id + ) + ) + if not existing_user.scalar_one_or_none(): + try: + async with sess.begin_nested(): + await sess.execute( + insert(self._users).values({"user_id": self.user_id}) + ) + except IntegrityError: + pass + # Avoid check-then-insert races on the first write while keeping # the common path free of avoidable integrity exceptions. existing = await sess.execute( @@ -307,7 +356,9 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: try: async with sess.begin_nested(): await sess.execute( - insert(self._sessions).values({"session_id": self.session_id}) + insert(self._sessions).values( + {"session_id": self.session_id, "user_id": self.user_id} + ) ) except IntegrityError: # Another concurrent writer created the parent row first. @@ -372,6 +423,25 @@ async def clear_session(self) -> None: delete(self._sessions).where(self._sessions.c.session_id == self.session_id) ) + async def get_sessions_for_user(self, user_id: str) -> list[str]: + """Retrieve all session IDs associated with a given user. + + Args: + user_id: The user identifier to look up sessions for. + + Returns: + List of session IDs belonging to the user, ordered by most recently updated first. + """ + await self._ensure_tables() + async with self._session_factory() as sess: + stmt = ( + select(self._sessions.c.session_id) + .where(self._sessions.c.user_id == user_id) + .order_by(self._sessions.c.updated_at.desc()) + ) + result = await sess.execute(stmt) + return [row[0] for row in result.all()] + @property def engine(self) -> AsyncEngine: """Access the underlying SQLAlchemy AsyncEngine. diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 85a65a1690..539f658f84 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -19,6 +19,7 @@ class Session(Protocol): """ session_id: str + user_id: str | None = None session_settings: SessionSettings | None = None async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: @@ -65,6 +66,7 @@ class SessionABC(ABC): """ session_id: str + user_id: str | None = None session_settings: SessionSettings | None = None @abstractmethod diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 92c9630c9b..5475f179c1 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -27,6 +27,8 @@ def __init__( db_path: str | Path = ":memory:", sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", + users_table: str = "agent_users", + user_id: str | None = None, session_settings: SessionSettings | None = None, ): """Initialize the SQLite session. @@ -37,14 +39,19 @@ def __init__( sessions_table: Name of the table to store session metadata. Defaults to 'agent_sessions' messages_table: Name of the table to store message data. Defaults to 'agent_messages' + users_table: Name of the table to store user metadata. Defaults to 'agent_users' + user_id: Optional user identifier to associate this session with a user. + When provided, the session will be linked to the user in the agent_users table. session_settings: Session configuration settings including default limit for retrieving items. If None, uses default SessionSettings(). """ self.session_id = session_id + self.user_id = user_id self.session_settings = session_settings or SessionSettings() self.db_path = db_path self.sessions_table = sessions_table self.messages_table = messages_table + self.users_table = users_table self._local = threading.local() self._lock = threading.Lock() @@ -82,16 +89,37 @@ def _get_connection(self) -> sqlite3.Connection: def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: """Initialize the database schema for a specific connection.""" + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.users_table} ( + user_id TEXT PRIMARY KEY, + metadata TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + conn.execute( f""" CREATE TABLE IF NOT EXISTS {self.sessions_table} ( session_id TEXT PRIMARY KEY, + user_id TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES {self.users_table} (user_id) + ON DELETE SET NULL ) """ ) + conn.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_{self.sessions_table}_user_id + ON {self.sessions_table} (user_id) + """ + ) + conn.execute( f""" CREATE TABLE IF NOT EXISTS {self.messages_table} ( @@ -183,12 +211,22 @@ def _add_items_sync(): conn = self._get_connection() with self._lock if self._is_memory_db else threading.Lock(): + # Ensure user exists if user_id is provided + if self.user_id is not None: + conn.execute( + f""" + INSERT OR IGNORE INTO {self.users_table} (user_id) VALUES (?) + """, + (self.user_id,), + ) + # Ensure session exists conn.execute( f""" - INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) + INSERT OR IGNORE INTO {self.sessions_table} (session_id, user_id) + VALUES (?, ?) """, - (self.session_id,), + (self.session_id, self.user_id), ) # Add items @@ -273,6 +311,31 @@ def _clear_session_sync(): await asyncio.to_thread(_clear_session_sync) + async def get_sessions_for_user(self, user_id: str) -> list[str]: + """Retrieve all session IDs associated with a given user. + + Args: + user_id: The user identifier to look up sessions for. + + Returns: + List of session IDs belonging to the user, ordered by most recently updated first. + """ + + def _get_sessions_sync(): + conn = self._get_connection() + with self._lock if self._is_memory_db else threading.Lock(): + cursor = conn.execute( + f""" + SELECT session_id FROM {self.sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + """, + (user_id,), + ) + return [row[0] for row in cursor.fetchall()] + + return await asyncio.to_thread(_get_sessions_sync) + def close(self) -> None: """Close the database connection.""" if self._is_memory_db: diff --git a/tests/test_session.py b/tests/test_session.py index aaa80ec7aa..22dc34883a 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -671,6 +671,58 @@ async def test_session_settings_resolve(): assert final_none.limit == 100 +@pytest.mark.asyncio +async def test_sqlite_session_user_association(): + """Test that sessions can be associated with users via user_id.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_users.db" + + # Create sessions for user_1 + session_a = SQLiteSession("session_a", db_path, user_id="user_1") + session_b = SQLiteSession("session_b", db_path, user_id="user_1") + # Create a session for user_2 + session_c = SQLiteSession("session_c", db_path, user_id="user_2") + # Create a session without a user + session_d = SQLiteSession("session_d", db_path) + + # Add items to trigger session/user creation + items: list[TResponseInputItem] = [{"role": "user", "content": "Hello"}] + await session_a.add_items(items) + await session_b.add_items(items) + await session_c.add_items(items) + await session_d.add_items(items) + + # Query sessions for user_1 + user_1_sessions = await session_a.get_sessions_for_user("user_1") + assert set(user_1_sessions) == {"session_a", "session_b"} + + # Query sessions for user_2 + user_2_sessions = await session_a.get_sessions_for_user("user_2") + assert user_2_sessions == ["session_c"] + + # Query sessions for non-existent user + empty_sessions = await session_a.get_sessions_for_user("user_999") + assert empty_sessions == [] + + session_a.close() + session_b.close() + session_c.close() + session_d.close() + + +@pytest.mark.asyncio +async def test_sqlite_session_user_id_attribute(): + """Test that user_id is correctly stored on the session instance.""" + session_with_user = SQLiteSession("s1", user_id="alice") + assert session_with_user.user_id == "alice" + + session_without_user = SQLiteSession("s2") + assert session_without_user.user_id is None + + session_with_user.close() + session_without_user.close() + + @pytest.mark.asyncio async def test_runner_with_session_settings_override(): """Test that RunConfig can override session's default settings.""" From 826a5fc60e93f9770d0b6afb59883f7b45c506f4 Mon Sep 17 00:00:00 2001 From: ecanlar Date: Tue, 31 Mar 2026 11:05:33 +0200 Subject: [PATCH 2/2] Add limit/offset pagination to get_sessions_for_user Adds limit and offset parameters to get_sessions_for_user() across all three SQL backends, consistent with how get_items() supports limiting retrieved history. This enables paginated retrieval of user sessions. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../extensions/memory/async_sqlite_session.py | 39 +++++++++++----- .../extensions/memory/sqlalchemy_session.py | 14 +++++- src/agents/memory/sqlite_session.py | 39 +++++++++++----- tests/test_session.py | 45 +++++++++++++++++++ 4 files changed, 115 insertions(+), 22 deletions(-) diff --git a/src/agents/extensions/memory/async_sqlite_session.py b/src/agents/extensions/memory/async_sqlite_session.py index f9e0d2a89a..c16e332e45 100644 --- a/src/agents/extensions/memory/async_sqlite_session.py +++ b/src/agents/extensions/memory/async_sqlite_session.py @@ -270,24 +270,43 @@ async def clear_session(self) -> None: ) await conn.commit() - async def get_sessions_for_user(self, user_id: str) -> list[str]: - """Retrieve all session IDs associated with a given user. + async def get_sessions_for_user( + self, + user_id: str, + limit: int | None = None, + offset: int = 0, + ) -> list[str]: + """Retrieve session IDs associated with a given user. Args: user_id: The user identifier to look up sessions for. + limit: Maximum number of session IDs to return. If None, returns all sessions. + offset: Number of sessions to skip before returning results. Defaults to 0. Returns: List of session IDs belonging to the user, ordered by most recently updated first. """ async with self._locked_connection() as conn: - cursor = await conn.execute( - f""" - SELECT session_id FROM {self.sessions_table} - WHERE user_id = ? - ORDER BY updated_at DESC - """, - (user_id,), - ) + if limit is None: + cursor = await conn.execute( + f""" + SELECT session_id FROM {self.sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT -1 OFFSET ? + """, + (user_id, offset), + ) + else: + cursor = await conn.execute( + f""" + SELECT session_id FROM {self.sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT ? OFFSET ? + """, + (user_id, limit, offset), + ) rows = await cursor.fetchall() await cursor.close() return [row[0] for row in rows] diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index 3feb6047d4..62c5c6141f 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -423,11 +423,18 @@ async def clear_session(self) -> None: delete(self._sessions).where(self._sessions.c.session_id == self.session_id) ) - async def get_sessions_for_user(self, user_id: str) -> list[str]: - """Retrieve all session IDs associated with a given user. + async def get_sessions_for_user( + self, + user_id: str, + limit: int | None = None, + offset: int = 0, + ) -> list[str]: + """Retrieve session IDs associated with a given user. Args: user_id: The user identifier to look up sessions for. + limit: Maximum number of session IDs to return. If None, returns all sessions. + offset: Number of sessions to skip before returning results. Defaults to 0. Returns: List of session IDs belonging to the user, ordered by most recently updated first. @@ -438,7 +445,10 @@ async def get_sessions_for_user(self, user_id: str) -> list[str]: select(self._sessions.c.session_id) .where(self._sessions.c.user_id == user_id) .order_by(self._sessions.c.updated_at.desc()) + .offset(offset) ) + if limit is not None: + stmt = stmt.limit(limit) result = await sess.execute(stmt) return [row[0] for row in result.all()] diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 5475f179c1..1b514950be 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -311,11 +311,18 @@ def _clear_session_sync(): await asyncio.to_thread(_clear_session_sync) - async def get_sessions_for_user(self, user_id: str) -> list[str]: - """Retrieve all session IDs associated with a given user. + async def get_sessions_for_user( + self, + user_id: str, + limit: int | None = None, + offset: int = 0, + ) -> list[str]: + """Retrieve session IDs associated with a given user. Args: user_id: The user identifier to look up sessions for. + limit: Maximum number of session IDs to return. If None, returns all sessions. + offset: Number of sessions to skip before returning results. Defaults to 0. Returns: List of session IDs belonging to the user, ordered by most recently updated first. @@ -324,14 +331,26 @@ async def get_sessions_for_user(self, user_id: str) -> list[str]: def _get_sessions_sync(): conn = self._get_connection() with self._lock if self._is_memory_db else threading.Lock(): - cursor = conn.execute( - f""" - SELECT session_id FROM {self.sessions_table} - WHERE user_id = ? - ORDER BY updated_at DESC - """, - (user_id,), - ) + if limit is None: + cursor = conn.execute( + f""" + SELECT session_id FROM {self.sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT -1 OFFSET ? + """, + (user_id, offset), + ) + else: + cursor = conn.execute( + f""" + SELECT session_id FROM {self.sessions_table} + WHERE user_id = ? + ORDER BY updated_at DESC + LIMIT ? OFFSET ? + """, + (user_id, limit, offset), + ) return [row[0] for row in cursor.fetchall()] return await asyncio.to_thread(_get_sessions_sync) diff --git a/tests/test_session.py b/tests/test_session.py index 22dc34883a..a325ad0ed2 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -710,6 +710,51 @@ async def test_sqlite_session_user_association(): session_d.close() +@pytest.mark.asyncio +async def test_sqlite_session_get_sessions_for_user_pagination(): + """Test limit and offset pagination on get_sessions_for_user.""" + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "test_pagination.db" + items: list[TResponseInputItem] = [{"role": "user", "content": "hi"}] + + # Create 5 sessions for the same user, adding items sequentially so + # updated_at ordering is deterministic (most recent last). + session_ids = [f"s{i}" for i in range(5)] + sessions = [] + for sid in session_ids: + s = SQLiteSession(sid, db_path, user_id="paginated_user") + await s.add_items(items) + sessions.append(s) + + ref = sessions[0] # any session instance sharing the same db + + # Without limit — returns all 5 + all_ids = await ref.get_sessions_for_user("paginated_user") + assert len(all_ids) == 5 + + # limit=2 — returns the 2 most recently updated + page1 = await ref.get_sessions_for_user("paginated_user", limit=2) + assert len(page1) == 2 + + # limit=2, offset=2 — next page + page2 = await ref.get_sessions_for_user("paginated_user", limit=2, offset=2) + assert len(page2) == 2 + + # limit=2, offset=4 — last page (only 1 left) + page3 = await ref.get_sessions_for_user("paginated_user", limit=2, offset=4) + assert len(page3) == 1 + + # All pages together should cover all session ids + assert set(page1 + page2 + page3) == set(session_ids) + + # offset beyond total — empty + empty = await ref.get_sessions_for_user("paginated_user", offset=10) + assert empty == [] + + for s in sessions: + s.close() + + @pytest.mark.asyncio async def test_sqlite_session_user_id_attribute(): """Test that user_id is correctly stored on the session instance."""