From 5844fb570debad07cd1fe815bf66da414e6b1184 Mon Sep 17 00:00:00 2001 From: baobaodawang-creater Date: Sun, 19 Apr 2026 18:08:44 +0900 Subject: [PATCH] Round 2 improvements --- .github/workflows/ci.yml | 124 +++ .github/workflows/isort.yml | 69 -- .github/workflows/python-unit-tests.yml | 53 -- .pre-commit-config.yaml | 48 + CONTRIBUTING.md | 34 +- .../dev/session_state_merge_semantics.md | 235 +++++ pyproject.toml | 116 +++ src/google/adk/__init__.py | 12 +- src/google/adk/agents/base_agent.py | 241 ++++- src/google/adk/errors/__init__.py | 10 + src/google/adk/errors/agent_timeout_error.py | 63 ++ .../adk/errors/version_mismatch_error.py | 51 ++ .../adk/flows/llm_flows/base_llm_flow.py | 83 +- src/google/adk/flows/llm_flows/functions.py | 52 +- src/google/adk/sessions/__init__.py | 2 + .../adk/sessions/base_session_service.py | 116 +++ .../adk/sessions/sqlite_session_service.py | 190 +++- tests/sessions/__init__.py | 13 + tests/sessions/test_session_integration.py | 853 ++++++++++++++++++ tests/unittests/agents/test_agent_timeout.py | 355 ++++++++ .../sessions/test_sqlite_session_service.py | 361 ++++++++ 21 files changed, 2909 insertions(+), 172 deletions(-) create mode 100644 .github/workflows/ci.yml delete mode 100644 .github/workflows/isort.yml delete mode 100644 .github/workflows/python-unit-tests.yml create mode 100644 .pre-commit-config.yaml create mode 100644 contributing/dev/session_state_merge_semantics.md create mode 100644 src/google/adk/errors/agent_timeout_error.py create mode 100644 src/google/adk/errors/version_mismatch_error.py create mode 100644 tests/sessions/__init__.py create mode 100644 tests/sessions/test_session_integration.py create mode 100644 tests/unittests/agents/test_agent_timeout.py create mode 100644 tests/unittests/sessions/test_sqlite_session_service.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000000..07a6414b98 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,124 @@ 
+# Unified CI Workflow for ADK Python +# +# This workflow consolidates and replaces the following legacy workflows: +# - isort.yml (replaced by lint job's isort check) +# - python-unit-tests.yml (replaced by test job) +# +# The following workflows are intentionally kept separate: +# - mypy.yml: Manual-triggered multi-Python-version type checking +# - pyink.yml: Separate formatting check +# - Various release/* workflows: Release automation +# - Triaging/monitoring workflows: Project maintenance + +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + lint: + name: Lint Check + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Set up Python 3.11 + uses: actions/setup-python@v6 + with: + python-version: "3.11" + + - name: Install uv + uses: astral-sh/setup-uv@v7 + + - name: Create virtual environment and install dependencies + run: | + uv venv .venv + source .venv/bin/activate + uv sync --all-extras + uv pip install ruff mypy build + + - name: Run ruff lint check + run: | + source .venv/bin/activate + ruff check src/ tests/ --no-fix + + - name: Run isort check + run: | + source .venv/bin/activate + isort --check src/ tests/ contributing/ + + - name: Run pyink format check + run: | + source .venv/bin/activate + pyink --check --diff --config pyproject.toml src/ tests/ contributing/ + + - name: Run mypy type check (transitional, to be removed by 2026-07-19) + continue-on-error: true + run: | + source .venv/bin/activate + mypy src/ --config-file pyproject.toml + + test: + name: Test (Python ${{ matrix.python-version }}) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] + + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Set up Python ${{ matrix.python-version }} + uses: 
actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Install uv + uses: astral-sh/setup-uv@v7 + + - name: Create virtual environment and install test dependencies + run: | + uv venv .venv + source .venv/bin/activate + uv sync --extra test + + - name: Run unit tests with pytest + run: | + source .venv/bin/activate + pytest tests/unittests \ + --ignore=tests/unittests/artifacts/test_artifact_service.py \ + --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py \ + -v + + build: + name: Build Package + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Set up Python 3.11 + uses: actions/setup-python@v6 + with: + python-version: "3.11" + + - name: Install uv + uses: astral-sh/setup-uv@v7 + + - name: Build the package + run: | + uv build + + - name: Verify built artifacts + run: | + ls -la dist/ diff --git a/.github/workflows/isort.yml b/.github/workflows/isort.yml deleted file mode 100644 index 840d4ea8a7..0000000000 --- a/.github/workflows/isort.yml +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: Check sorting of imports - -on: - pull_request: - paths: - - '**.py' - - 'pyproject.toml' - -jobs: - isort-check: - runs-on: ubuntu-latest - - steps: - - name: Checkout code - uses: actions/checkout@v6 - with: - fetch-depth: 2 - - - name: Set up Python - uses: actions/setup-python@v6 - with: - python-version: '3.11' - - - name: Install isort - run: | - pip install isort - - - name: Run isort on changed files - id: run_isort - run: | - git fetch origin ${GITHUB_BASE_REF} - CHANGED_FILES=$(git diff --diff-filter=ACMR --name-only origin/${GITHUB_BASE_REF}...HEAD | grep -E '\.py$' || true) - if [ -n "$CHANGED_FILES" ]; then - echo "Changed Python files:" - echo "$CHANGED_FILES" - echo "" - FORMATTED_FILES=$(echo "$CHANGED_FILES" | tr '\n' ' ') - - # Run isort --check - set +e - isort --check $CHANGED_FILES - RESULT=$? - set -e - if [ $RESULT -ne 0 ]; then - echo "" - echo "❌ isort check failed!" - echo "👉 To fix import order, run locally:" - echo "" - echo " isort $FORMATTED_FILES" - echo "" - exit $RESULT - fi - else - echo "No Python files changed. Skipping isort check." - fi diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml deleted file mode 100644 index 866ba8b383..0000000000 --- a/.github/workflows/python-unit-tests.yml +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -name: Python Unit Tests - -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - -jobs: - test: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] - - steps: - - name: Checkout code - uses: actions/checkout@v6 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v6 - with: - python-version: ${{ matrix.python-version }} - - - name: Install the latest version of uv - uses: astral-sh/setup-uv@v7 - - - name: Install dependencies - run: | - uv venv .venv - source .venv/bin/activate - uv sync --extra test - - - name: Run unit tests with pytest - run: | - source .venv/bin/activate - pytest tests/unittests \ - --ignore=tests/unittests/artifacts/test_artifact_service.py \ - --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..f0862432bf --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,48 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.11 + hooks: + - id: ruff + name: ruff-lint + args: [--fix] + +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort-imports + +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.15.0 + hooks: + - id: mypy + name: mypy-type-check + args: [--config-file=pyproject.toml] + additional_dependencies: + - pydantic>=2.12.0 + - typing-extensions>=4.5 + exclude: ^(tests/|contributing/samples/) + +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + name: trailing-whitespace-fixer + exclude_types: [markdown] + - id: end-of-file-fixer + name: end-of-file-fixer + - id: check-yaml + name: check-yaml-syntax + - id: check-toml + name: check-toml-syntax + - id: check-added-large-files + name: check-added-large-files + args: [--maxkb=1000] + +- repo: local + 
hooks: + - id: pyink-format + name: pyink-format + entry: pyink --config pyproject.toml + language: system + types: [python] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 78172029a1..cd72e677b2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -209,13 +209,43 @@ part before or alongside your code PR. ./autoformat.sh ``` -7. **Build the wheel file:** +7. **Set up pre-commit hooks (recommended):** + + The project uses pre-commit hooks to automate code quality checks. + Install and enable them with: + + ```shell + uv pip install pre-commit + pre-commit install + ``` + + This will set up the following hooks that run automatically on every commit: + - `ruff`: Lint checking and automatic fixes (using configured rules) + - `isort`: Import organization (Google style) + - `pyink`: Code formatting (Google style, 2-space indent, 80 char line length) + - `mypy`: Type checking (**transitional state: currently allowed to fail, will be enforced by 2026-07-19**) + - `trailing-whitespace`: Removes trailing whitespace + - `end-of-file-fixer`: Ensures files end with a newline + - `check-yaml`: Validates YAML syntax + - `check-toml`: Validates TOML syntax + - `check-added-large-files`: Prevents adding large files (>1MB) + + To manually run all hooks on all files: + + ```shell + pre-commit run --all-files + ``` + + **Hook version management:** Run `pre-commit autoupdate` quarterly to update hooks + to their latest stable versions. + +8. **Build the wheel file:** ```shell uv build ``` -8. **Test the locally built wheel file:** Have a simple testing folder setup as +9. **Test the locally built wheel file:** Have a simple testing folder setup as mentioned in the [quickstart](https://google.github.io/adk-docs/get-started/quickstart/). 
diff --git a/contributing/dev/session_state_merge_semantics.md b/contributing/dev/session_state_merge_semantics.md new file mode 100644 index 0000000000..43c2044cdb --- /dev/null +++ b/contributing/dev/session_state_merge_semantics.md @@ -0,0 +1,235 @@ +# Session State Merge Semantics Bug Analysis + +## Problem Statement + +ADK Python 目前有三种 `BaseSessionService` 实现,但它们对 `state_delta` 的合并语义**不一致**: + +| Implementation | Merge Mechanism | Behavior | +|----------------|-----------------|----------| +| `InMemorySessionService` | `dict.update()` | **Shallow merge** (顶层 key 合并) | +| `DatabaseSessionService` | `dict \| dict` operator | **Shallow merge** (顶层 key 合并) | +| `SqliteSessionService` | SQLite `json_patch()` | **Recursive merge** (RFC 7396) | + +这是一个明显的 **接口实现不一致 Bug**,违反了 `BaseSessionService` 的行为契约。 + +--- + +## Detailed Behavior Comparison + +### Test Case + +假设初始状态: +```python +session.state = { + "top_key": "value1", + "nested": { + "inner_a": 1, + "inner_b": 2 + } +} +``` + +使用以下 `state_delta` 进行更新: +```python +state_delta = { + "nested": { + "inner_a": 100, + "inner_c": 300 + }, + "new_key": "added" +} +``` + +### Expected Results + +#### 1. InMemorySessionService (Shallow Merge) + +**Code Location**: `src/google/adk/sessions/in_memory_session_service.py:362` +```python +if session_state_delta: + storage_session.state.update(session_state_delta) +``` + +**Result**: +```python +{ + "top_key": "value1", # 保留 + "nested": { # 完全替换! + "inner_a": 100, + "inner_c": 300 + }, + "new_key": "added" # 新增 +} +``` + +**注意**: `inner_b` 丢失了,因为 `nested` dict 被整体替换。 + +#### 2. DatabaseSessionService (Shallow Merge) + +**Code Location**: `src/google/adk/sessions/database_session_service.py:727-729` +```python +storage_session.state = ( + storage_session.state | state_deltas["session"] +) +``` + +**Result** (与 InMemory 相同): +```python +{ + "top_key": "value1", + "nested": { + "inner_a": 100, + "inner_c": 300 + }, + "new_key": "added" +} +``` + +#### 3. 
SqliteSessionService (Recursive Merge) + +**Code Location**: `src/google/adk/sessions/sqlite_session_service.py:562-564` +```python +"UPDATE sessions SET state=json_patch(state, ?), update_time=? WHERE" +" app_name=? AND user_id=? AND id=?", +( + json.dumps(delta), # delta = {"nested": {"inner_a": 100, "inner_c": 300}, "new_key": "added"} + now, + app_name, + user_id, + session_id, +), +``` + +SQLite 的 `json_patch()` 实现的是 **[RFC 7396 JSON Merge Patch](https://datatracker.ietf.org/doc/html/rfc7396)**。 + +**RFC 7396 规则**: +1. 如果 patch 值为 `null`,从 target 删除该 key +2. 如果 patch 值是 object **且** target 对应值也是 object → **递归合并** +3. 否则 → 直接替换 + +**Result**: +```python +{ + "top_key": "value1", # 保留 + "nested": { # 递归合并! + "inner_a": 100, # 更新 + "inner_b": 2, # 保留! + "inner_c": 300 # 新增 + }, + "new_key": "added" # 新增 +} +``` + +**关键差异**: `inner_b` 被保留了,因为 `nested` dict 是递归合并而非替换。 + +--- + +## Impact Analysis + +### 1. Functional Impact + +| Scenario | InMemory/Database | Sqlite | +|----------|-------------------|--------| +| 简单值更新 | 一致 | 一致 | +| 新增顶层 key | 一致 | 一致 | +| 嵌套 dict 部分更新 | **丢失其他嵌套 key** | **保留其他嵌套 key** | +| 使用 `null` 删除 key | 不支持 (变为值为 None) | RFC 7396 支持 | + +### 2. Developer Experience + +开发者写的代码在不同存储后端行为不一致: + +```python +# 开发者意图:只更新 nested.inner_a,不影响 nested.inner_b +event = Event( + actions=EventActions( + state_delta={"nested": {"inner_a": "new_value"}} + ) +) +await session_service.append_event(session, event) + +# 实际结果: +# - InMemory/Database: nested = {"inner_a": "new_value"} (inner_b 丢失!) 
+# - Sqlite: nested = {"inner_a": "new_value", "inner_b": 2} (inner_b 保留) +``` + +--- + +## Recommended Solution + +### Option A: Standardize on RFC 7396 (Recursive Merge) + +**推荐方案**。RFC 7396 是业界标准,语义更直观。 + +**需要修改**: +- `InMemorySessionService`: 将 `dict.update()` 改为递归合并 +- `DatabaseSessionService`: 将 `dict | dict` 改为递归合并 + +**Pros**: +- 符合业界标准 (JSON Merge Patch) +- 语义更符合开发者直觉 ("部分更新" 应该只更新提供的字段) +- Sqlite 已经是此行为,改动最小 + +**Cons**: +- 需要修改两个实现 +- 可能影响依赖当前浅合并语义的现有代码 + +### Option B: Standardize on Shallow Merge + +**不推荐**。浅合并语义不直观,且需要修改 Sqlite (可能更复杂)。 + +**需要修改**: +- `SqliteSessionService`: 不再使用 `json_patch()`,改为序列化后再浅合并 + +**Pros**: +- InMemory/Database 已经是此行为 + +**Cons**: +- Sqlite 改动复杂 (需要放弃原生 json_patch) +- 语义不直观 ("部分更新" 变成 "替换整个嵌套结构") + +--- + +## Implementation Plan (Option A) + +### Phase 1: Add Documentation and XFail Tests + +- [x] 创建此文档 (`contributing/dev/session_state_merge_semantics.md`) +- [x] 在 `tests/sessions/test_session_integration.py` 添加 xfail 测试 +- [ ] 更新 `BaseSessionService` docstring 注明当前不一致状态 + +### Phase 2: Implement Recursive Merge + +修改 `BaseSessionService._update_session_state()` 或各实现: + +```python +# 参考实现: RFC 7396 JSON Merge Patch +def rfc7396_merge(target: dict, patch: dict) -> dict: + """Apply RFC 7396 JSON Merge Patch. + + https://datatracker.ietf.org/doc/html/rfc7396 + """ + result = copy.deepcopy(target) + for key, value in patch.items(): + if value is None: + result.pop(key, None) + elif isinstance(value, dict) and isinstance(result.get(key), dict): + result[key] = rfc7396_merge(result[key], value) + else: + result[key] = copy.deepcopy(value) + return result +``` + +### Phase 3: Migration + +1. 更新 CHANGELOG 说明此 Breaking Change +2. 提供迁移指南: 如果依赖浅合并语义,需要修改代码 +3. 
将 xfail 测试改为正常测试 + +--- + +## References + +- [RFC 7396 - JSON Merge Patch](https://datatracker.ietf.org/doc/html/rfc7396) +- [SQLite json_patch() Documentation](https://www.sqlite.org/json1.html#jpatch) +- [Python dict.update() Documentation](https://docs.python.org/3/library/stdtypes.html#dict.update) diff --git a/pyproject.toml b/pyproject.toml index 5815b971f5..b27fad1009 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -237,3 +237,119 @@ plugins = ["pydantic.mypy"] strict = true disable_error_code = ["import-not-found", "import-untyped", "unused-ignore"] follow_imports = "skip" + + +[tool.ruff] +line-length = 100 +target-version = "py310" +lint.select = ["E", "W", "F", "I", "B", "C4", "UP", "SIM"] + +# ruff.lint.ignore - 现有代码豁免清单 +# ======================================== +# 这些规则被豁免是因为现有代码存在大量违反情况。 +# 后续应该逐步收敛,新代码应遵守这些规则。 +# +# 收敛策略建议: +# 1. 新代码必须遵守所有规则(不添加新的 noqa) +# 2. 按优先级逐步修复现有代码: +# - 高优先级:F 系列(逻辑错误)、B 系列(潜在 bug) +# - 中优先级:SIM 系列(代码简化)、C4 系列(理解性简化) +# - 低优先级:UP 系列(Python 版本升级)、E 系列(风格) +# 3. 每修复一类问题,从 ignore 列表中移除对应规则 + +lint.ignore = [ + # UP 系列 (pyupgrade) - Python 版本升级相关 + # 现有代码大量使用 Python 3.10 风格的类型注解, + # 这些规则要求升级到 PEP 585/PEP 604 风格 + "UP006", # non-pep585-annotation: List[str] -> list[str] + "UP007", # non-pep604-annotation-union: Union[str, int] -> str | int + "UP008", # super-call-with-parameters: super() -> super(ClassName, self) + "UP012", # unnecessary-encode-utf8: .encode("utf-8") is default + "UP015", # redundant-open-modes: open(..., "r") is default + "UP024", # os-error-alias: IOError, OSError -> OSError + "UP028", # yield-in-for-loop: yield from iterable + "UP031", # printf-string-formatting: % formatting -> f-strings + "UP032", # f-string: .format() -> f-strings + "UP034", # extraneous-parentheses: 冗余括号 + "UP035", # deprecated-import: typing.Dict etc. 
-> collections.abc + "UP037", # quoted-annotation: 使用 from __future__ import annotations + "UP045", # non-pep604-annotation-optional: Optional[str] -> str | None + + # I 系列 (isort) - 导入排序 + "I001", # unsorted-imports: 项目已使用 isort,有独立配置 + + # F 系列 (pyflakes) - 逻辑错误 + "F401", # unused-import: 存在有意的重导出 (re-exports) + "F402", # import-shadowed-by-loop-var: 循环变量覆盖导入 + "F403", # undefined-local-with-import-star: from module import * + "F541", # f-string-missing-placeholders: f-string 无占位符 + "F601", # multi-value-repeated-key-literal: 字典重复键 + "F811", # redefined-while-unused: 重定义未使用的变量 + "F821", # undefined-name: 未定义的名称(可能在类型注解中) + "F841", # unused-variable: 未使用的变量 + + # E 系列 (pycodestyle) - 风格错误 + "E501", # line-too-long: 行太长(项目用 80 字符,ruff 配置 100) + "E402", # module-import-not-at-top-of-file: 导入不在文件顶部 + "E711", # none-comparison: x == None -> x is None + "E712", # true-false-comparison: x == True -> x is True + "E721", # type-comparison: type(x) == Y -> isinstance(x, Y) + "E731", # lambda-assignment: 赋值 lambda 到变量 + + # B 系列 (flake8-bugbear) - 潜在 bug + "B004", # unreliable-callable-check: 不可靠的可调用检查 + "B005", # strip-with-multi-characters: .strip() 多字符问题 + "B006", # mutable-argument-default: 可变默认参数 + "B007", # unused-loop-control-variable: 未使用的循环控制变量 + "B008", # function-call-in-default-argument: 默认参数中的函数调用 + "B009", # get-attr-with-constant: getattr(x, "foo") -> x.foo + "B010", # set-attr-with-constant: setattr 常量值 + "B011", # assert-false: assert False 总是失败 + "B015", # useless-comparison: 无用比较 + "B017", # assert-raises-exception: assertRaises(Exception) 太宽泛 + "B023", # function-uses-loop-variable: 闭包使用循环变量 + "B024", # abstract-base-class-without-abstract-method: 无抽象方法的 ABC + "B027", # empty-method-without-abstract-decorator: 空方法无 abstractmethod + "B028", # no-explicit-stacklevel: warn() 无 stacklevel + "B904", # raise-without-from-inside-except: except 中 raise 无 from + "B905", # zip-without-explicit-strict: zip() 无 strict= + + # SIM 系列 (flake8-simplify) - 代码简化 + 
"SIM101", # duplicate-isinstance-call: 重复 isinstance 调用 + "SIM102", # collapsible-if: 可合并的 if 嵌套 + "SIM103", # needless-bool: 不必要的 bool() 调用 + "SIM105", # suppressible-exception: 可用 contextlib.suppress + "SIM108", # if-else-block-instead-of-if-exp: 可用三元表达式 + "SIM110", # reimplemented-builtin: 重新实现了内置函数 + "SIM113", # enumerate-for-loop: 可用 enumerate() + "SIM114", # if-with-same-arms: if 分支相同 + "SIM117", # multiple-with-statements: 多个 with 可合并 + "SIM118", # in-dict-keys: x in dict.keys() -> x in dict + "SIM300", # magic-values: 测试中的魔法值 + "SIM401", # if-else-block-instead-of-dict-get: 可用 dict.get() + "SIM910", # dict-get-with-none-default: .get(x, None) 是默认 + + # C4 系列 (flake8-comprehensions) - 推导式简化 + "C401", # unnecessary-generator-set: set(generator) -> {x for x in ...} + "C403", # unnecessary-list-comprehension-set: set([...]) -> {...} + "C405", # unnecessary-literal-set: set((1, 2)) -> {1, 2} + "C408", # unnecessary-collection-call: dict() -> {} + "C410", # unnecessary-literal-within-list-call: list((1, 2)) -> [1, 2] + "C413", # unnecessary-call-around-sorted: list(sorted(...)) -> sorted(...) + "C414", # unnecessary-double-cast-or-process: 冗余转换 + "C416", # unnecessary-comprehension: [x for x in iter] -> list(iter) + "C419", # unnecessary-comprehension-in-call: sum(x for x in ...) 无括号 + "C420", # unnecessary-dict-comprehension-for-iterable: {k:v for k,v in ...} + + # W 系列 (pycodestyle) - 空白相关 + "W291", # trailing-whitespace: 尾随空格 + "W293", # blank-line-with-whitespace: 空白行有空格 +] + +[tool.ruff.lint.per-file-ignores] +"contributing/samples/**/*.py" = ["E", "W", "F", "I", "B", "C4", "UP", "SIM"] + +[tool.ruff.format] +quote-style = "preserve" +indent-style = "space" +line-ending = "lf" diff --git a/src/google/adk/__init__.py b/src/google/adk/__init__.py index d48806bacd..a62246f2cf 100644 --- a/src/google/adk/__init__.py +++ b/src/google/adk/__init__.py @@ -17,7 +17,17 @@ from . 
import version from .agents.context import Context from .agents.llm_agent import Agent +from .errors.agent_timeout_error import AgentTimeoutError +from .errors.agent_timeout_error import TimeoutTrigger +from .errors.agent_timeout_error import TimeoutType from .runners import Runner __version__ = version.__version__ -__all__ = ["Agent", "Context", "Runner"] +__all__ = [ + "Agent", + "AgentTimeoutError", + "Context", + "Runner", + "TimeoutType", + "TimeoutTrigger", +] diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index dec85690b3..e64708eecd 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -14,12 +14,16 @@ from __future__ import annotations +import asyncio +from contextlib import asynccontextmanager import inspect import logging +import time from typing import Any from typing import AsyncGenerator from typing import Awaitable from typing import Callable +from typing import cast from typing import ClassVar from typing import Dict from typing import final @@ -38,6 +42,9 @@ from typing_extensions import override from typing_extensions import TypeAlias +from ..errors.agent_timeout_error import AgentTimeoutError +from ..errors.agent_timeout_error import TimeoutTrigger +from ..errors.agent_timeout_error import TimeoutType from ..events.event import Event from ..events.event_actions import EventActions from ..features import experimental @@ -161,8 +168,29 @@ class MyAgent(BaseAgent): Returns: Optional[types.Content]: The content to return to the user. - When the content is present, an additional event with the provided content - will be appended to event history as an additional agent response. + When present, the actual model response + will be ignored and the provided content will be returned to user. + """ + + single_turn_timeout: Optional[float] = None + """Timeout for a single LLM call in seconds. 
+ + If set, each individual LLM call (including streaming responses) will be + canceled if it takes longer than this value. None means no timeout. + + This timeout is independent of `total_timeout`. If both are set, whichever + triggers first will raise an AgentTimeoutError. + """ + + total_timeout: Optional[float] = None + """Total timeout for the entire agent run in seconds. + + If set, the entire agent execution (including all LLM calls, tool calls, + and sub-agent executions) will be canceled if it takes longer than this + value. None means no timeout. + + This timeout is independent of `single_turn_timeout`. If both are set, + whichever triggers first will raise an AgentTimeoutError. """ def _load_agent_state( @@ -270,6 +298,30 @@ def clone( cloned_agent.parent_agent = None return cloned_agent + # ============================================================================ + # Python 3.10 Compatibility Notes: + # + # The timeout implementation in run_async() and run_live() uses a + # queue + background task pattern instead of Python 3.11's asyncio.timeout() + # context manager for two reasons: + # + # 1. Python 3.10 support: asyncio.timeout() was added in Python 3.11, + # but ADK requires Python 3.10+ compatibility. + # + # 2. AsyncGenerator limitation: asyncio.wait_for() cannot be used directly + # with async generators (functions that use 'yield'). The queue + task + # pattern wraps the async generator execution in a background task, + # allowing timeout control via asyncio.wait_for() on queue.get(). + # + # How it works: + # - A background task (_run_inner) executes the actual agent logic and + # sends events/errors to a queue. + # - The main loop consumes from the queue using asyncio.wait_for() with + # a decreasing timeout. + # - On timeout, the background task is cancelled, which cascades to any + # sub-agents via asyncio.CancelledError. 
+ # ============================================================================ + @final async def run_async( self, @@ -284,24 +336,88 @@ async def run_async( Yields: Event: the events generated by the agent. """ - - with tracer.start_as_current_span(f'invoke_agent {self.name}') as span: - ctx = self._create_invocation_context(parent_context) - tracing.trace_agent_invocation(span, self, ctx) - if event := await self._handle_before_agent_callback(ctx): - yield event - if ctx.end_invocation: - return - - async with Aclosing(self._run_async_impl(ctx)) as agen: - async for event in agen: + start_time = time.time() + queue: asyncio.Queue[ + tuple[str, Event | Exception | None] + ] = asyncio.Queue() + done = asyncio.Event() + + async def _run_inner() -> None: + try: + with tracer.start_as_current_span(f'invoke_agent {self.name}') as span: + ctx = self._create_invocation_context(parent_context) + tracing.trace_agent_invocation(span, self, ctx) + if event := await self._handle_before_agent_callback(ctx): + await queue.put(('event', event)) + if ctx.end_invocation: + await queue.put(('done', None)) + return + + async with Aclosing(self._run_async_impl(ctx)) as agen: + async for event in agen: + await queue.put(('event', event)) + + if ctx.end_invocation: + await queue.put(('done', None)) + return + + if event := await self._handle_after_agent_callback(ctx): + await queue.put(('event', event)) + + await queue.put(('done', None)) + except Exception as e: + await queue.put(('error', e)) + finally: + done.set() + + task = asyncio.create_task(_run_inner()) + + try: + while True: + try: + remaining_timeout = None + if self.total_timeout is not None: + elapsed = time.time() - start_time + remaining_timeout = max(0.0, self.total_timeout - elapsed) + if remaining_timeout <= 0: + raise asyncio.TimeoutError() + + item = await asyncio.wait_for( + queue.get(), + timeout=remaining_timeout, + ) + except asyncio.TimeoutError: + task.cancel() + try: + await task + except 
asyncio.CancelledError: + pass + elapsed = time.time() - start_time + raise AgentTimeoutError( + message='', + timeout_type=TimeoutType.TOTAL, + elapsed_time=elapsed, + trigger=TimeoutTrigger.USER_INPUT, + agent_name=self.name, + ) + + if item[0] == 'event': + event = cast(Event, item[1]) yield event + elif item[0] == 'error': + error = cast(Exception, item[1]) + raise error + elif item[0] == 'done': + break - if ctx.end_invocation: - return - - if event := await self._handle_after_agent_callback(ctx): - yield event + await task + except asyncio.CancelledError: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + raise @final async def run_live( @@ -317,21 +433,84 @@ async def run_live( Yields: Event: the events generated by the agent. """ - - with tracer.start_as_current_span(f'invoke_agent {self.name}') as span: - ctx = self._create_invocation_context(parent_context) - tracing.trace_agent_invocation(span, self, ctx) - if event := await self._handle_before_agent_callback(ctx): - yield event - if ctx.end_invocation: - return - - async with Aclosing(self._run_live_impl(ctx)) as agen: - async for event in agen: + start_time = time.time() + queue: asyncio.Queue[ + tuple[str, Event | Exception | None] + ] = asyncio.Queue() + done = asyncio.Event() + + async def _run_inner() -> None: + try: + with tracer.start_as_current_span(f'invoke_agent {self.name}') as span: + ctx = self._create_invocation_context(parent_context) + tracing.trace_agent_invocation(span, self, ctx) + if event := await self._handle_before_agent_callback(ctx): + await queue.put(('event', event)) + if ctx.end_invocation: + await queue.put(('done', None)) + return + + async with Aclosing(self._run_live_impl(ctx)) as agen: + async for event in agen: + await queue.put(('event', event)) + + if event := await self._handle_after_agent_callback(ctx): + await queue.put(('event', event)) + + await queue.put(('done', None)) + except Exception as e: + await queue.put(('error', e)) + 
finally: + done.set() + + task = asyncio.create_task(_run_inner()) + + try: + while True: + try: + remaining_timeout = None + if self.total_timeout is not None: + elapsed = time.time() - start_time + remaining_timeout = max(0.0, self.total_timeout - elapsed) + if remaining_timeout <= 0: + raise asyncio.TimeoutError() + + item = await asyncio.wait_for( + queue.get(), + timeout=remaining_timeout, + ) + except asyncio.TimeoutError: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + elapsed = time.time() - start_time + raise AgentTimeoutError( + message='', + timeout_type=TimeoutType.TOTAL, + elapsed_time=elapsed, + trigger=TimeoutTrigger.USER_INPUT, + agent_name=self.name, + ) + + if item[0] == 'event': + event = cast(Event, item[1]) yield event + elif item[0] == 'error': + error = cast(Exception, item[1]) + raise error + elif item[0] == 'done': + break - if event := await self._handle_after_agent_callback(ctx): - yield event + await task + except asyncio.CancelledError: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + raise async def _run_async_impl( self, ctx: InvocationContext diff --git a/src/google/adk/errors/__init__.py b/src/google/adk/errors/__init__.py index 58d482ea38..15f3f50847 100644 --- a/src/google/adk/errors/__init__.py +++ b/src/google/adk/errors/__init__.py @@ -11,3 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from .agent_timeout_error import AgentTimeoutError +from .agent_timeout_error import TimeoutTrigger +from .agent_timeout_error import TimeoutType + +__all__ = [ + 'AgentTimeoutError', + 'TimeoutType', + 'TimeoutTrigger', +] diff --git a/src/google/adk/errors/agent_timeout_error.py b/src/google/adk/errors/agent_timeout_error.py new file mode 100644 index 0000000000..19661f1c84 --- /dev/null +++ b/src/google/adk/errors/agent_timeout_error.py @@ -0,0 +1,63 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import enum + + +class TimeoutType(str, enum.Enum): + SINGLE_TURN = 'single_turn' + TOTAL = 'total' + + +class TimeoutTrigger(str, enum.Enum): + LLM_CALL = 'llm_call' + TOOL_CALL = 'tool_call' + USER_INPUT = 'user_input' + + +class AgentTimeoutError(TimeoutError): + def __init__( + self, + message: str, + timeout_type: TimeoutType | str, + elapsed_time: float, + trigger: TimeoutTrigger | str, + agent_name: str | None = None, + ): + self.timeout_type = ( + timeout_type.value + if isinstance(timeout_type, TimeoutType) + else timeout_type + ) + self.elapsed_time = elapsed_time + self.trigger = trigger.value if isinstance(trigger, TimeoutTrigger) else trigger + self.agent_name = agent_name + + timeout_desc = ( + 'Single-turn LLM call' + if self.timeout_type == TimeoutType.SINGLE_TURN + else 'Total agent execution' + ) + agent_info = f' for agent "{agent_name}"' if agent_name else '' + full_message = ( + f'{timeout_desc} timed out{agent_info}. ' + f'Elapsed time: {elapsed_time:.2f}s. ' + f'Triggered during: {self.trigger}.' + ) + if message: + full_message = f'{message} {full_message}' + self.message = full_message + super().__init__(self.message) diff --git a/src/google/adk/errors/version_mismatch_error.py b/src/google/adk/errors/version_mismatch_error.py new file mode 100644 index 0000000000..5b6b1be122 --- /dev/null +++ b/src/google/adk/errors/version_mismatch_error.py @@ -0,0 +1,51 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + + +class VersionMismatchError(Exception): + """Represents an error that occurs when database schema version is incompatible. + + This error is raised when the database schema version stored in the database + does not match the expected version by the current code. This typically + indicates that a migration is needed. + """ + + def __init__( + self, + message: str = "Database schema version is incompatible.", + expected_version: int | None = None, + actual_version: int | None = None, + ): + """Initializes the VersionMismatchError exception. + + Args: + message: An optional custom message to describe the error. + expected_version: The schema version expected by the current code. + actual_version: The actual schema version found in the database. + """ + self.expected_version = expected_version + self.actual_version = actual_version + + if expected_version is not None and actual_version is not None: + self.message = ( + f"Database schema version mismatch: expected version {expected_version}, " + f"but found version {actual_version}. " + "Please run the migration script to update the database schema." 
+ ) + else: + self.message = message + + super().__init__(self.message) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 2e45708d9e..a65ed8b43c 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -18,11 +18,16 @@ import asyncio import inspect import logging +import time from typing import AsyncGenerator +from typing import cast from typing import Optional from typing import TYPE_CHECKING from google.adk.platform import time as platform_time +from ...errors.agent_timeout_error import AgentTimeoutError +from ...errors.agent_timeout_error import TimeoutTrigger +from ...errors.agent_timeout_error import TimeoutType from google.genai import types from opentelemetry import trace from websockets.exceptions import ConnectionClosed @@ -1166,6 +1171,22 @@ async def _call_llm_async( llm_request: LlmRequest, model_response_event: Event, ) -> AsyncGenerator[LlmResponse, None]: + single_turn_timeout = ( + invocation_context.agent.single_turn_timeout + if hasattr(invocation_context.agent, 'single_turn_timeout') + else None + ) + agent_name = ( + invocation_context.agent.name + if hasattr(invocation_context.agent, 'name') + else None + ) + + start_time = time.time() + queue: asyncio.Queue[ + tuple[str, LlmResponse | Exception | None] + ] = asyncio.Queue() + done = asyncio.Event() async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: with tracer.start_as_current_span('call_llm') as span: @@ -1262,9 +1283,65 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: yield llm_response - async with Aclosing(_call_llm_with_tracing()) as agen: - async for event in agen: - yield event + async def _run_inner() -> None: + try: + async with Aclosing(_call_llm_with_tracing()) as agen: + async for response in agen: + await queue.put(('response', response)) + await queue.put(('done', None)) + except Exception as e: + await 
queue.put(('error', e))
+      finally:
+        done.set()
+
+    task = asyncio.create_task(_run_inner())
+
+    try:
+      while True:
+        try:
+          remaining_timeout = None
+          if single_turn_timeout is not None:
+            elapsed = time.time() - start_time
+            remaining_timeout = max(0.0, single_turn_timeout - elapsed)
+            if remaining_timeout <= 0:
+              raise asyncio.TimeoutError()
+
+          item = await asyncio.wait_for(
+              queue.get(),
+              timeout=remaining_timeout,
+          )
+        except asyncio.TimeoutError:
+          task.cancel()
+          try:
+            await task
+          except asyncio.CancelledError:
+            pass
+          elapsed = time.time() - start_time
+          raise AgentTimeoutError(
+              message='',
+              timeout_type=TimeoutType.SINGLE_TURN,
+              elapsed_time=elapsed,
+              trigger=TimeoutTrigger.LLM_CALL,
+              agent_name=agent_name,
+          )
+
+        if item[0] == 'response':
+          response = cast(LlmResponse, item[1])
+          yield response
+        elif item[0] == 'error':
+          error = cast(Exception, item[1])
+          raise error
+        elif item[0] == 'done':
+          break
+
+      await task
+    except asyncio.CancelledError:
+      task.cancel()
+      try:
+        await task
+      except asyncio.CancelledError:
+        pass
+      raise

   def _finalize_model_response_event(
       self,
diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py
index eda8474c01..0961180669 100644
--- a/src/google/adk/flows/llm_flows/functions.py
+++ b/src/google/adk/flows/llm_flows/functions.py
@@ -26,6 +26,7 @@
 import inspect
 import logging
 import threading
+import time
 from typing import Any
 from typing import AsyncGenerator
 from typing import cast
@@ -41,6 +42,9 @@
 from ...agents.live_request_queue import LiveRequestQueue
 from ...auth.auth_tool import AuthConfig
 from ...auth.auth_tool import AuthToolArguments
+from ...errors.agent_timeout_error import AgentTimeoutError
+from ...errors.agent_timeout_error import TimeoutTrigger
+from ...errors.agent_timeout_error import TimeoutType
 from ...events.event import Event
 from ...events.event_actions import EventActions
 from ...telemetry.tracing import 
trace_merged_tool_calls @@ -529,8 +533,17 @@ async def _run_with_trace(): # Step 3: Otherwise, proceed calling the tool normally. if function_response is None: try: + single_turn_timeout = ( + agent.single_turn_timeout + if hasattr(agent, 'single_turn_timeout') + else None + ) function_response = await __call_tool_async( - tool, args=function_args, tool_context=tool_context + tool, + args=function_args, + tool_context=tool_context, + timeout=single_turn_timeout, + agent_name=agent.name, ) except Exception as tool_error: error_response = await _run_on_tool_error_callbacks( @@ -1112,9 +1125,42 @@ async def __call_tool_async( tool: BaseTool, args: dict[str, Any], tool_context: ToolContext, + timeout: Optional[float] = None, + agent_name: Optional[str] = None, ) -> Any: - """Calls the tool.""" - return await tool.run_async(args=args, tool_context=tool_context) + """Calls the tool with optional timeout. + + Args: + tool: The tool to call. + args: The arguments to pass to the tool. + tool_context: The tool context. + timeout: Optional timeout in seconds. If set and the tool takes + longer than this, raises AgentTimeoutError. + agent_name: The name of the agent calling the tool, used in error + messages. + + Returns: + The result from the tool. + + Raises: + AgentTimeoutError: If timeout is set and the tool call takes longer + than the specified timeout. 
+ """ + start_time = time.time() + try: + coro = tool.run_async(args=args, tool_context=tool_context) + if timeout is not None and timeout > 0: + return await asyncio.wait_for(coro, timeout=timeout) + return await coro + except asyncio.TimeoutError: + elapsed = time.time() - start_time + raise AgentTimeoutError( + message='', + timeout_type=TimeoutType.SINGLE_TURN, + elapsed_time=elapsed, + trigger=TimeoutTrigger.TOOL_CALL, + agent_name=agent_name, + ) def __build_response_event( diff --git a/src/google/adk/sessions/__init__.py b/src/google/adk/sessions/__init__.py index 7505eda346..6812ec5238 100644 --- a/src/google/adk/sessions/__init__.py +++ b/src/google/adk/sessions/__init__.py @@ -14,6 +14,7 @@ from .base_session_service import BaseSessionService from .in_memory_session_service import InMemorySessionService from .session import Session +from .sqlite_session_service import SqliteSessionService from .state import State from .vertex_ai_session_service import VertexAiSessionService @@ -22,6 +23,7 @@ 'DatabaseSessionService', 'InMemorySessionService', 'Session', + 'SqliteSessionService', 'State', 'VertexAiSessionService', ] diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index af94bb9eeb..765f81b256 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -55,6 +55,122 @@ class BaseSessionService(abc.ABC): """Base class for session services. The service provides a set of methods for managing sessions and events. + + ## Behavior Contracts + + All implementations of `BaseSessionService` must adhere to the following + behavioral contracts. These contracts ensure consistent behavior across + different storage backends (in-memory, SQLite, PostgreSQL, Redis, etc.). + + ### 1. Session Creation Atomicity + + `create_session` must be atomic with respect to `session_id`. 
If a session + with the same `(app_name, user_id, session_id)` already exists: + - The call **must** raise `AlreadyExistsError` + - No partial state changes should be visible + + This ensures that concurrent session creations with the same ID are properly + serialized, with exactly one succeeding and all others failing. + + ### 2. Event Appending Ordering Guarantees + + `append_event` must maintain the following ordering invariants: + + - **Append Order**: Events are appended to `session.events` in the order + they are received. The list index reflects the order of appending. + + - **Timestamp Monotonicity**: Within a single session, the `timestamp` of + newly appended events must be greater than or equal to the `last_update_time` + of the session. After appending, `session.last_update_time` is updated to + `event.timestamp`. + + - **Concurrency Control**: Subclass implementations **may or may not** + serialize concurrent appends to the same session. + + - **Guaranteed serialization**: `DatabaseSessionService` uses per-session + locks and storage revision markers. When two concurrent operations attempt + to append using stale session objects, exactly one succeeds and the other + raises `ValueError` with a message indicating the session was "modified + in storage". + + - **Implementation-defined behavior**: `InMemorySessionService` and + `SqliteSessionService` do NOT guarantee serialization under concurrent + writes. Lost writes or unexpected behavior may occur. + + **Recommendation**: Use `DatabaseSessionService` for production workloads + that require guaranteed concurrency control. + + ### 3. State Update Merge Semantics + + ⚠️ **CURRENTLY INCONSISTENT ACROSS IMPLEMENTATIONS** + + See `contributing/dev/session_state_merge_semantics.md` for detailed analysis. + + **Current Behavior (Bug)**: + - `InMemorySessionService` and `DatabaseSessionService`: Use **shallow merge** + (`dict.update()` and `dict |` operator). Nested dicts are REPLACED, not merged. 
+ - `SqliteSessionService`: Uses **RFC 7396 JSON Merge Patch** (recursive merge). + Nested dicts are RECURSIVELY merged. + + **Expected Future Behavior**: + All implementations should standardize on **RFC 7396 JSON Merge Patch** semantics, + which is the industry standard and more intuitive for "partial updates". + + **RFC 7396 Rules**: + - If patch value is `null` → delete key from target + - If patch value is dict AND target value is dict → **recursively merge** + - Otherwise → **replace** + + **RFC 7396 Example**: + ```python + # Initial state + session.state = {"a": 1, "nested": {"inner_a": 1, "inner_b": 2}} + + # state_delta + {"nested": {"inner_a": 100, "inner_c": 300}, "new_key": "added"} + + # Expected RFC 7396 Result + { + "a": 1, + "nested": { + "inner_a": 100, # updated + "inner_b": 2, # PRESERVED! (recursive merge) + "inner_c": 300 # added + }, + "new_key": "added" + } + ``` + + ### 4. Session Deletion Semantics + + `delete_session` must be idempotent and safe: + + - Deleting a non-existent session **must not** raise an exception + - After a successful deletion, subsequent `get_session` calls for the same + `(app_name, user_id, session_id)` must return `None` + + ### 5. Session Isolation + + Sessions are isolated by `(app_name, user_id, session_id)` tuple: + + - Two sessions with different `session_id` values have independent `state` + and `events`. Modifying one does not affect the other. + + - Two sessions with the same `session_id` but different `user_id` are + completely separate and do not share any state or events. + + - `list_sessions` only returns sessions matching the provided `app_name` + and optional `user_id` filter. + + ### 6. 
Copy-on-Read Semantics + + Implementations should return copies of session objects from `get_session` + and `list_sessions` to prevent: + - Accidental in-memory modifications from affecting persisted state + - Race conditions between concurrent readers + + This means modifications to a returned `Session` object will not be visible + to other callers unless explicitly persisted via `append_event`. """ @abc.abstractmethod diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index 427bc3e73e..8f0daebb65 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -31,6 +31,7 @@ from . import _session_util from ..errors.already_exists_error import AlreadyExistsError +from ..errors.version_mismatch_error import VersionMismatchError from ..events.event import Event from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig @@ -40,6 +41,10 @@ logger = logging.getLogger("google_adk." + __name__) +SCHEMA_VERSION = 1 + +SCHEMA_VERSION_KEY = "schema_version" + PRAGMA_FOREIGN_KEYS = "PRAGMA foreign_keys = ON" APP_STATES_TABLE_SCHEMA = """ @@ -85,14 +90,53 @@ FOREIGN KEY (app_name, user_id, session_id) REFERENCES sessions(app_name, user_id, id) ON DELETE CASCADE ); """ + +METADATA_TABLE_SCHEMA = """ +CREATE TABLE IF NOT EXISTS metadata ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL +); +""" + CREATE_SCHEMA_SQL = "\n".join([ APP_STATES_TABLE_SCHEMA, USER_STATES_TABLE_SCHEMA, SESSIONS_TABLE_SCHEMA, EVENTS_TABLE_SCHEMA, + METADATA_TABLE_SCHEMA, ]) +def _get_default_db_path() -> str: + """Returns the default database path based on environment variables and XDG Base Directory Specification. + + Priority order (highest to lowest): + 1. ADK_HOME/sessions.db (if ADK_HOME environment variable is set) + 2. XDG_DATA_HOME/adk/sessions.db (if XDG_DATA_HOME environment variable is set) + 3. 
~/.local/share/adk/sessions.db (default XDG data directory) + 4. ~/.adk/sessions.db (legacy fallback for backward compatibility) + + Returns: + The default database path as a string. + """ + if "ADK_HOME" in os.environ: + adk_home = os.path.expanduser(os.environ["ADK_HOME"]) + return os.path.join(adk_home, "sessions.db") + + if "XDG_DATA_HOME" in os.environ: + xdg_data_home = os.path.expanduser(os.environ["XDG_DATA_HOME"]) + return os.path.join(xdg_data_home, "adk", "sessions.db") + + home = os.path.expanduser("~") + xdg_default = os.path.join(home, ".local", "share", "adk", "sessions.db") + + legacy_path = os.path.join(home, ".adk", "sessions.db") + if os.path.exists(legacy_path): + return legacy_path + + return xdg_default + + def _parse_db_path(db_path: str) -> tuple[str, str, bool]: """Normalizes a SQLite db path from a URL or filesystem path. @@ -134,14 +178,54 @@ class SqliteSessionService(BaseSessionService): Event data is stored as JSON to allow for schema flexibility as event fields evolve. + + State Merge Semantics: + This implementation uses SQLite's json_patch function (RFC 7396) for + state updates, which performs a RECURSIVE merge of nested dictionaries. + This differs from DatabaseSessionService which uses a shallow dict.update() + merge. For example: + + - Existing state: {"nested": {"a": 1, "b": 2}} + - State delta: {"nested": {"b": 3, "c": 4}} + - SqliteSessionService result: {"nested": {"a": 1, "b": 3, "c": 4}} + - DatabaseSessionService result: {"nested": {"b": 3, "c": 4}} + + When using nested state dictionaries, be aware of this semantic difference. + + Schema Versioning: + This service tracks schema version in the 'metadata' table. If the + database schema version is older than the expected version, a + VersionMismatchError will be raised, and a migration script should be run. 
""" - def __init__(self, db_path: str): - """Initializes the SQLite session service with a database path.""" + def __init__(self, db_path: str = None): + """Initializes the SQLite session service with a database path. + + Args: + db_path: The path to the SQLite database file. If not provided, + the default path is determined by the following priority order: + 1. ADK_HOME/sessions.db (if ADK_HOME environment variable is set) + 2. XDG_DATA_HOME/adk/sessions.db (if XDG_DATA_HOME environment variable is set) + 3. ~/.local/share/adk/sessions.db (default XDG data directory) + 4. ~/.adk/sessions.db (legacy fallback, only if the file already exists) + The directory will be created if it doesn't exist. + + Raises: + VersionMismatchError: If the database schema version is incompatible + with the current code. + RuntimeError: If the database uses an old schema that requires migration. + """ + if db_path is None: + db_path = _get_default_db_path() + self._db_path, self._db_connect_path, self._db_connect_uri = _parse_db_path( db_path ) + db_dir = os.path.dirname(self._db_path) + if db_dir and not os.path.exists(db_dir): + os.makedirs(db_dir) + if self._is_migration_needed(): raise RuntimeError( f"Database {self._db_path} seems to use an old schema." @@ -153,6 +237,8 @@ def __init__(self, db_path: str): f" {self._db_path}.new to {self._db_path}." ) + self._check_schema_version() + @override async def create_session( self, @@ -460,13 +546,29 @@ async def append_event(self, session: Session, event: Event) -> Event: @asynccontextmanager async def _get_db_connection(self): - """Connects to the db and performs initial setup.""" + """Connects to the db and performs initial setup. + + This method: + 1. Connects to the SQLite database + 2. Enables foreign keys + 3. Creates all tables if they don't exist + 4. 
Initializes schema version in metadata table if not already set + """ async with aiosqlite.connect( self._db_connect_path, uri=self._db_connect_uri ) as db: db.row_factory = aiosqlite.Row await db.execute(PRAGMA_FOREIGN_KEYS) await db.executescript(CREATE_SCHEMA_SQL) + + await db.execute( + """ + INSERT OR IGNORE INTO metadata (key, value) VALUES (?, ?) + """, + (SCHEMA_VERSION_KEY, str(SCHEMA_VERSION)), + ) + await db.commit() + yield db async def _get_state( @@ -566,27 +668,95 @@ def _is_migration_needed(self) -> bool: self._db_connect_path, uri=self._db_connect_uri ) as conn: cursor = conn.cursor() - # Check if events table exists cursor.execute( "SELECT 1 FROM sqlite_master WHERE type='table' and name='events'" ) if not cursor.fetchone(): - return False # No events table, so no migration needed. + return False - # If events table exists, check for event_data column cursor.execute("PRAGMA table_info(events)") columns = [row[1] for row in cursor.fetchall()] if "event_data" in columns: - return False # New schema: event_data column exists. + return False else: - return ( - True # Old schema: events table exists, but no event_data column. - ) + return True except sqlite3.Error as e: raise RuntimeError( f"Error accessing database {self._db_path}: {e}" ) from e + def _check_schema_version(self) -> None: + """Checks if the database schema version is compatible with the current code. + + This method: + 1. If the database doesn't exist yet, does nothing (version will be + initialized on first connection). + 2. If the database exists: + - First checks if metadata table exists and schema_version is set + - If schema_version exists but doesn't match SCHEMA_VERSION, + raises VersionMismatchError. + - If metadata table or schema_version doesn't exist, attempts to + initialize them. If the database is read-only (e.g., opened with + mode=ro), this initialization is skipped and the version check + is deferred until actual database operations are performed. 
+ + Raises: + VersionMismatchError: If the database schema version is incompatible. + """ + if not os.path.exists(self._db_path): + return + + try: + with sqlite3.connect( + self._db_connect_path, uri=self._db_connect_uri + ) as conn: + cursor = conn.cursor() + + cursor.execute( + "SELECT 1 FROM sqlite_master WHERE type='table' AND name='metadata'" + ) + if not cursor.fetchone(): + try: + cursor.execute(METADATA_TABLE_SCHEMA) + cursor.execute( + "INSERT INTO metadata (key, value) VALUES (?, ?)", + (SCHEMA_VERSION_KEY, str(SCHEMA_VERSION)), + ) + conn.commit() + except sqlite3.OperationalError: + return + return + + cursor.execute( + "SELECT value FROM metadata WHERE key = ?", + (SCHEMA_VERSION_KEY,), + ) + result = cursor.fetchone() + if result is None: + try: + cursor.execute( + "INSERT INTO metadata (key, value) VALUES (?, ?)", + (SCHEMA_VERSION_KEY, str(SCHEMA_VERSION)), + ) + conn.commit() + except sqlite3.OperationalError: + return + return + + stored_version = int(result[0]) + if stored_version != SCHEMA_VERSION: + raise VersionMismatchError( + expected_version=SCHEMA_VERSION, + actual_version=stored_version, + ) + + except sqlite3.Error as e: + if "readonly" in str(e).lower(): + return + raise RuntimeError( + f"Error checking schema version for database {self._db_path}: {e}" + ) from e + def _merge_state(app_state, user_state, session_state): """Merges app, user, and session states into a single dictionary.""" diff --git a/tests/sessions/__init__.py b/tests/sessions/__init__.py new file mode 100644 index 0000000000..58d482ea38 --- /dev/null +++ b/tests/sessions/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sessions/test_session_integration.py b/tests/sessions/test_session_integration.py new file mode 100644 index 0000000000..3942618a34 --- /dev/null +++ b/tests/sessions/test_session_integration.py @@ -0,0 +1,853 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Integration tests for SessionService implementations. + +This test suite verifies the behavioral contracts defined in BaseSessionService +across different storage backends. It can be used to validate any new +SessionService implementation (e.g., Redis, PostgreSQL, Spanner). + +To add a new SessionService implementation: +1. Create a new SessionServiceType enum value +2. Add it to the `params` list in the `session_service` fixture +3. Implement `get_session_service` to return an instance of your service + +All tests in this file will automatically run against the new implementation. 
+""" + +from __future__ import annotations + +import asyncio +import enum +from typing import Any + +from google.adk.errors.already_exists_error import AlreadyExistsError +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.sessions.base_session_service import BaseSessionService +from google.adk.sessions.database_session_service import DatabaseSessionService +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.sessions.sqlite_session_service import SqliteSessionService +import pytest + + +APP_NAME = "test_app" +USER_ID = "test_user" + + +class SessionServiceType(enum.Enum): + IN_MEMORY = "IN_MEMORY" + DATABASE = "DATABASE" + SQLITE = "SQLITE" + + +def get_session_service( + service_type: SessionServiceType, + tmp_path, +) -> BaseSessionService: + """Creates a fresh session service instance for testing. + + Args: + service_type: The type of session service to create. + tmp_path: Pytest's temporary path fixture for file-based storage. + + Returns: + A new instance of the specified session service. + """ + if service_type == SessionServiceType.IN_MEMORY: + return InMemorySessionService() + if service_type == SessionServiceType.DATABASE: + return DatabaseSessionService("sqlite+aiosqlite:///:memory:") + if service_type == SessionServiceType.SQLITE: + return SqliteSessionService(str(tmp_path / "sessions.db")) + raise ValueError(f"Unknown service type: {service_type}") + + +@pytest.fixture( + params=[ + SessionServiceType.IN_MEMORY, + SessionServiceType.DATABASE, + SessionServiceType.SQLITE, + ] +) +async def session_service(request, tmp_path): + """Parametrized fixture providing fresh SessionService instances. + + This fixture creates a new session service for each test to ensure + isolation. For database-backed services, it handles proper cleanup. + + Yields: + A fresh BaseSessionService instance ready for testing. 
+ """ + service = get_session_service(request.param, tmp_path) + try: + yield service + finally: + if isinstance(service, DatabaseSessionService): + await service.close() + + +async def _create_event( + invocation_id: str, + author: str, + timestamp: float, + state_delta: dict[str, Any] | None = None, +) -> Event: + """Helper to create an Event with consistent parameters.""" + return Event( + invocation_id=invocation_id, + author=author, + timestamp=timestamp, + actions=EventActions(state_delta=state_delta or {}), + ) + + +@pytest.mark.asyncio +async def test_multiple_sessions_same_user_are_isolated(session_service): + """Sessions with different IDs for the same user must be isolated. + + Each session maintains its own independent state and event history. + Modifying one session must not affect any other session, even if they + belong to the same user and app. + + Contract verification: + - Session isolation by (app_name, user_id, session_id) tuple + - State and events are completely independent between sessions + """ + session1 = await session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id="session_1", + state={"counter": 1, "shared": "initial"}, + ) + session2 = await session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id="session_2", + state={"counter": 100, "shared": "initial"}, + ) + + event1 = await _create_event("inv1", "agent1", 1000.0, {"counter": 2}) + await session_service.append_event(session1, event1) + + event2 = await _create_event("inv2", "agent2", 2000.0, {"counter": 200}) + await session_service.append_event(session2, event2) + + session1_refreshed = await session_service.get_session( + app_name=APP_NAME, user_id=USER_ID, session_id="session_1" + ) + session2_refreshed = await session_service.get_session( + app_name=APP_NAME, user_id=USER_ID, session_id="session_2" + ) + + assert session1_refreshed is not None + assert session2_refreshed is not None + + assert 
session1_refreshed.state.get("counter") == 2 + assert len(session1_refreshed.events) == 1 + assert session1_refreshed.events[0].invocation_id == "inv1" + + assert session2_refreshed.state.get("counter") == 200 + assert len(session2_refreshed.events) == 1 + assert session2_refreshed.events[0].invocation_id == "inv2" + + +@pytest.mark.asyncio +async def test_state_persists_across_turns(session_service): + """Session state modifications must persist across multiple invocations. + + When an event modifies the session state via state_delta, those changes + must be visible in subsequent get_session calls and available for the + next invocation (turn). + + Contract verification: + - State updates are persisted after append_event + - Subsequent get_session calls return the updated state + """ + session = await session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id="persistent_session", + state={"page": 1, "filters": {"category": "all"}}, + ) + + assert session.state.get("page") == 1 + + event1 = await _create_event( + "inv1", "agent", 1000.0, {"page": 2, "filters": {"category": "books"}} + ) + await session_service.append_event(session, event1) + + session_after_turn1 = await session_service.get_session( + app_name=APP_NAME, user_id=USER_ID, session_id="persistent_session" + ) + assert session_after_turn1 is not None + assert session_after_turn1.state.get("page") == 2 + assert session_after_turn1.state.get("filters") == {"category": "books"} + + event2 = await _create_event( + "inv2", "agent", 2000.0, {"page": 3, "view_mode": "list"} + ) + await session_service.append_event(session_after_turn1, event2) + + session_after_turn2 = await session_service.get_session( + app_name=APP_NAME, user_id=USER_ID, session_id="persistent_session" + ) + assert session_after_turn2 is not None + assert session_after_turn2.state.get("page") == 3 + assert session_after_turn2.state.get("view_mode") == "list" + assert session_after_turn2.state.get("filters") == 
{"category": "books"} + + +@pytest.mark.asyncio +async def test_events_append_order_and_timestamp_monotonic(session_service): + """Events must maintain append order and timestamps must be monotonic. + + Events are appended to session.events in the order they are received. + Each new event's timestamp must be greater than or equal to the previous + event's timestamp, and the session's last_update_time is updated to the + event's timestamp after each append. + + Contract verification: + - Events are appended in order (list index reflects append order) + - Timestamps are preserved as provided + - last_update_time is updated to event.timestamp + """ + session = await session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id="ordered_events", + ) + + base_ts = 1000.0 + events_to_append = [] + for i in range(5): + ts = base_ts + float(i) * 100.0 + event = await _create_event(f"inv_{i}", f"agent_{i}", ts, {"seq": i}) + events_to_append.append(event) + await session_service.append_event(session, event) + + refreshed = await session_service.get_session( + app_name=APP_NAME, user_id=USER_ID, session_id="ordered_events" + ) + assert refreshed is not None + + assert len(refreshed.events) == 5 + + for i, event in enumerate(refreshed.events): + assert event.invocation_id == f"inv_{i}" + assert event.timestamp == base_ts + float(i) * 100.0 + state_delta = event.actions.state_delta + assert state_delta.get("seq") == i + + for i in range(1, len(refreshed.events)): + assert refreshed.events[i].timestamp >= refreshed.events[i - 1].timestamp + + assert refreshed.last_update_time == refreshed.events[-1].timestamp + + +@pytest.mark.asyncio +async def test_concurrent_appends_race_condition(session_service): + """Concurrent appends to the same session must detect stale sessions. + + When two concurrent operations try to append events to the same session + using stale session objects (with outdated last_update_time), at most + one should succeed. 
The other should raise ValueError indicating the + session was "modified in storage". + + This prevents lost writes in concurrent scenarios. + + Note: Only DatabaseSessionService provides full concurrency control with + session-level locking. InMemorySessionService is designed for single-threaded + use only, and SqliteSessionService relies on SQLite's transaction isolation + which may not catch all race conditions in this test scenario. + + Contract verification (for implementations that support it): + - Concurrency control via last_update_time or storage marker + - Exactly one success, one failure when concurrent writers use stale sessions + - No lost writes + """ + if not isinstance(session_service, DatabaseSessionService): + pytest.skip( + "This test requires session-level locking for concurrent append " + "detection, which is only provided by DatabaseSessionService" + ) + + session = await session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id="concurrent_session", + ) + + base_ts = session.last_update_time + 100.0 + + stale_session_1 = session.model_copy(deep=True) + stale_session_2 = session.model_copy(deep=True) + + event1 = await _create_event("inv_concurrent_1", "agent_a", base_ts, {"a": 1}) + event2 = await _create_event( + "inv_concurrent_2", "agent_b", base_ts + 50.0, {"b": 2} + ) + + results = await asyncio.gather( + session_service.append_event(stale_session_1, event1), + session_service.append_event(stale_session_2, event2), + return_exceptions=True, + ) + + errors = [r for r in results if isinstance(r, Exception)] + successes = [r for r in results if not isinstance(r, Exception)] + + assert len(successes) == 1, "Expected exactly one successful append" + assert len(errors) == 1, "Expected exactly one failed append" + + assert isinstance(errors[0], ValueError) + error_msg = str(errors[0]).lower() + assert "modified" in error_msg or "stale" in error_msg or "storage" in error_msg + + final_session = await 
session_service.get_session( + app_name=APP_NAME, user_id=USER_ID, session_id="concurrent_session" + ) + assert final_session is not None + assert len(final_session.events) == 1 + + state = final_session.state + has_a = state.get("a") == 1 + has_b = state.get("b") == 2 + assert has_a ^ has_b, "Expected exactly one of the state updates to persist" + + +@pytest.mark.asyncio +async def test_delete_session_is_idempotent(session_service): + """Deleting a session must be safe and idempotent. + + - delete_session must not raise an exception for non-existent sessions + - After deletion, get_session must return None + - Re-deleting an already-deleted session must not raise + + Contract verification: + - Idempotent deletion (no exception on missing session) + - get_session returns None after deletion + """ + session = await session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id="to_delete", + state={"will_be_gone": True}, + ) + + event = await _create_event("inv1", "agent", 1000.0, {"extra": "data"}) + await session_service.append_event(session, event) + + before_delete = await session_service.get_session( + app_name=APP_NAME, user_id=USER_ID, session_id="to_delete" + ) + assert before_delete is not None + + await session_service.delete_session( + app_name=APP_NAME, user_id=USER_ID, session_id="to_delete" + ) + + after_delete = await session_service.get_session( + app_name=APP_NAME, user_id=USER_ID, session_id="to_delete" + ) + assert after_delete is None, "get_session must return None after delete" + + await session_service.delete_session( + app_name=APP_NAME, user_id=USER_ID, session_id="to_delete" + ) + + await session_service.delete_session( + app_name=APP_NAME, user_id=USER_ID, session_id="non_existent" + ) + + +@pytest.mark.asyncio +async def test_state_delta_merge_semantics(session_service): + """State updates use merge semantics, not full replacement. 
+ + When appending an event with state_delta, the delta is merged into the + existing session state. This means: + - New keys are added to the state + - Existing keys have their values updated + - Keys not present in the delta remain unchanged + + This test verifies that state updates do NOT replace the entire state + dictionary, but instead merge in only the changes. + + Note: Different implementations may have different semantics for nested + dictionaries. InMemorySessionService and DatabaseSessionService use + shallow merge (top-level keys only), while SqliteSessionService uses + RFC 7396 JSON Merge Patch (recursive merge for nested dicts). This test + focuses on simple values where all implementations behave consistently. + + Contract verification: + - Keys not in delta are preserved + - Existing keys are updated with new values + - New keys are added to the state + - The entire state dict is NOT replaced + """ + session = await session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id="merge_test", + state={ + "unchanged_key": "original_value", + "key_to_update": "old_value", + "counter": 1, + }, + ) + + event1 = await _create_event( + "inv1", + "agent", + 1000.0, + { + "key_to_update": "new_value", + "new_key": "just_added", + "counter": 2, + }, + ) + await session_service.append_event(session, event1) + + refreshed1 = await session_service.get_session( + app_name=APP_NAME, user_id=USER_ID, session_id="merge_test" + ) + assert refreshed1 is not None + + assert refreshed1.state.get("unchanged_key") == "original_value" + + assert refreshed1.state.get("key_to_update") == "new_value" + + assert refreshed1.state.get("new_key") == "just_added" + + assert refreshed1.state.get("counter") == 2 + + event2 = await _create_event( + "inv2", + "agent", + 2000.0, + { + "another_new_key": "added_later", + "counter": 3, + }, + ) + await session_service.append_event(refreshed1, event2) + + refreshed2 = await session_service.get_session( + 
app_name=APP_NAME, user_id=USER_ID, session_id="merge_test" + ) + assert refreshed2 is not None + + assert refreshed2.state.get("unchanged_key") == "original_value" + assert refreshed2.state.get("key_to_update") == "new_value" + assert refreshed2.state.get("new_key") == "just_added" + assert refreshed2.state.get("another_new_key") == "added_later" + assert refreshed2.state.get("counter") == 3 + + +@pytest.mark.asyncio +async def test_create_session_with_existing_id_raises_already_exists(session_service): + """Creating a session with duplicate ID must raise AlreadyExistsError. + + If create_session is called with a session_id that already exists for + the same (app_name, user_id), it must raise AlreadyExistsError. + + Contract verification: + - Atomic creation with duplicate detection + - AlreadyExistsError raised on duplicate session_id + """ + await session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id="duplicate_id", + state={"first": True}, + ) + + with pytest.raises(AlreadyExistsError, match="already exists"): + await session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id="duplicate_id", + state={"second": True}, + ) + + +@pytest.mark.asyncio +async def test_different_users_same_session_id_are_isolated(session_service): + """Sessions with same ID but different users must be isolated. + + The session identity is (app_name, user_id, session_id). Two sessions + with the same session_id but different user_id are completely separate. 
+ + Contract verification: + - Full tuple identity: (app_name, user_id, session_id) + - User boundaries are respected + """ + await session_service.create_session( + app_name=APP_NAME, + user_id="user_alice", + session_id="shared_session_id", + state={"owner": "alice", "secret": "alice_secret"}, + ) + + await session_service.create_session( + app_name=APP_NAME, + user_id="user_bob", + session_id="shared_session_id", + state={"owner": "bob", "secret": "bob_secret"}, + ) + + alice_session = await session_service.get_session( + app_name=APP_NAME, user_id="user_alice", session_id="shared_session_id" + ) + bob_session = await session_service.get_session( + app_name=APP_NAME, user_id="user_bob", session_id="shared_session_id" + ) + + assert alice_session is not None + assert bob_session is not None + + assert alice_session.state.get("owner") == "alice" + assert alice_session.state.get("secret") == "alice_secret" + + assert bob_session.state.get("owner") == "bob" + assert bob_session.state.get("secret") == "bob_secret" + + +@pytest.mark.asyncio +@pytest.mark.xfail( + reason="State merge semantics are inconsistent across implementations. " + "Sqlite uses RFC 7396 recursive merge; InMemory/Database use shallow merge. " + "See contributing/dev/session_state_merge_semantics.md for details." +) +async def test_nested_state_merge_is_recursive(session_service): + """Nested dict updates should use RFC 7396 recursive merge semantics. + + This test documents the EXPECTED behavior (RFC 7396 JSON Merge Patch), + which is currently only implemented by SqliteSessionService. + + RFC 7396 rules: + - If patch value is null, delete key from target + - If patch value is dict AND target value is dict → RECURSIVELY MERGE + - Otherwise → REPLACE + + Bug: InMemorySessionService and DatabaseSessionService use shallow merge + (dict.update() and dict | operator), which replaces nested dicts entirely. 
+ + Expected behavior (RFC 7396): + - Update nested.inner_a → nested becomes {inner_a: new, inner_b: preserved} + - NOT shallow merge where nested = {inner_a: new} (inner_b lost!) + + Contract verification (once fixed): + - Nested dicts are recursively merged + - Unspecified nested keys are preserved + - See RFC 7396 for the standard + """ + session = await session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id="nested_merge_test", + state={ + "top_level": "unchanged", + "nested": { + "inner_a": 1, + "inner_b": 2, + "inner_c": 3, + }, + }, + ) + + event = await _create_event( + "inv1", + "agent", + 1000.0, + { + "nested": { + "inner_a": 100, + "inner_d": 400, + }, + }, + ) + await session_service.append_event(session, event) + + refreshed = await session_service.get_session( + app_name=APP_NAME, user_id=USER_ID, session_id="nested_merge_test" + ) + assert refreshed is not None + + assert refreshed.state.get("top_level") == "unchanged" + + assert refreshed.state.get("nested") == { + "inner_a": 100, + "inner_b": 2, + "inner_c": 3, + "inner_d": 400, + } + + assert "inner_b" in refreshed.state["nested"] + assert "inner_c" in refreshed.state["nested"] + + +@pytest.mark.asyncio +@pytest.mark.xfail( + reason="Concurrent append behavior is implementation-defined. " + "Only DatabaseSessionService provides guaranteed serialization via locks. " + "InMemory and Sqlite have undefined behavior under concurrent writes." +) +async def test_concurrent_appends_document_undefined_behavior(session_service): + """Document the actual behavior of concurrent appends (for debugging). + + This test is marked xfail because: + - DatabaseSessionService: Uses per-session locks → one success, one error + - InMemorySessionService: No locking → both may succeed (lost writes possible) + - SqliteSessionService: Transaction-based → behavior depends on timing + + This test exists to DOCUMENT the current behavior, not to enforce a contract. 
+ For guaranteed serialization, use DatabaseSessionService. + + Note: See BaseSessionService docstring which states: + "Subclass may or may not serialize concurrent appends; use DatabaseSessionService + for guaranteed serialization." + """ + if isinstance(session_service, DatabaseSessionService): + pytest.skip("DatabaseSessionService has guaranteed serialization") + + session = await session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id="concurrent_behavior", + ) + + base_ts = session.last_update_time + 100.0 + + stale_session_1 = session.model_copy(deep=True) + stale_session_2 = session.model_copy(deep=True) + + event1 = await _create_event("inv_concurrent_1", "agent_a", base_ts, {"a": 1}) + event2 = await _create_event( + "inv_concurrent_2", "agent_b", base_ts + 50.0, {"b": 2} + ) + + results = await asyncio.gather( + session_service.append_event(stale_session_1, event1), + session_service.append_event(stale_session_2, event2), + return_exceptions=True, + ) + + final_session = await session_service.get_session( + app_name=APP_NAME, user_id=USER_ID, session_id="concurrent_behavior" + ) + assert final_session is not None + + print(f"\n=== Concurrent append observation for {type(session_service).__name__} ===") + print(f"Results: {[type(r).__name__ if isinstance(r, Exception) else 'SUCCESS' for r in results]}") + print(f"Final events count: {len(final_session.events)}") + print(f"Final state: {final_session.state}") + + assert False, ( + f"This test documents behavior, not enforces contract. " + f"Use DatabaseSessionService for guaranteed serialization." + ) + + +@pytest.mark.asyncio +async def test_cross_app_same_user_isolation(session_service): + """Sessions must be isolated across different apps for the same user. + + Session identity is (app_name, user_id, session_id). Two sessions with + the same user_id and session_id but different app_name must be completely + isolated. 
+ + This is important for: + - Multi-tenant applications + - Different apps sharing the same user base + - Security boundaries + + Contract verification: + - Full tuple identity: (app_name, user_id, session_id) + - App boundaries are respected + - State and events are isolated + """ + APP_1 = "shopping_app" + APP_2 = "banking_app" + SHARED_USER = "user_123" + SHARED_SESSION_ID = "session_abc" + + await session_service.create_session( + app_name=APP_1, + user_id=SHARED_USER, + session_id=SHARED_SESSION_ID, + state={ + "cart": ["item1", "item2"], + "total": 99.99, + }, + ) + + await session_service.create_session( + app_name=APP_2, + user_id=SHARED_USER, + session_id=SHARED_SESSION_ID, + state={ + "balance": 10000.00, + "account_id": "ACC-001", + }, + ) + + shopping_session = await session_service.get_session( + app_name=APP_1, user_id=SHARED_USER, session_id=SHARED_SESSION_ID + ) + banking_session = await session_service.get_session( + app_name=APP_2, user_id=SHARED_USER, session_id=SHARED_SESSION_ID + ) + + assert shopping_session is not None + assert banking_session is not None + + assert shopping_session.app_name == APP_1 + assert shopping_session.state.get("cart") == ["item1", "item2"] + assert shopping_session.state.get("total") == 99.99 + assert "balance" not in shopping_session.state + assert "account_id" not in shopping_session.state + + assert banking_session.app_name == APP_2 + assert banking_session.state.get("balance") == 10000.00 + assert banking_session.state.get("account_id") == "ACC-001" + assert "cart" not in banking_session.state + assert "total" not in banking_session.state + + +@pytest.mark.asyncio +async def test_cross_user_same_session_id_isolation(session_service): + """Sessions must be isolated across different users with the same session_id. + + This is a stricter version of test_different_users_same_session_id_are_isolated, + also testing that list_sessions respects user boundaries. 
+ + Contract verification: + - list_sessions with user_id filter only returns that user's sessions + - list_sessions without user_id returns all sessions for the app + - State isolation is maintained + """ + SESSION_ID = "shared_session_001" + + await session_service.create_session( + app_name=APP_NAME, + user_id="alice", + session_id=SESSION_ID, + state={"role": "admin", "permissions": ["read", "write", "delete"]}, + ) + + await session_service.create_session( + app_name=APP_NAME, + user_id="bob", + session_id=SESSION_ID, + state={"role": "guest", "permissions": ["read"]}, + ) + + await session_service.create_session( + app_name=APP_NAME, + user_id="charlie", + session_id=SESSION_ID, + state={"role": "editor", "permissions": ["read", "write"]}, + ) + + alice_list = await session_service.list_sessions( + app_name=APP_NAME, user_id="alice" + ) + assert len(alice_list.sessions) == 1 + assert alice_list.sessions[0].user_id == "alice" + assert alice_list.sessions[0].state.get("role") == "admin" + + all_sessions = await session_service.list_sessions( + app_name=APP_NAME, user_id=None + ) + assert len(all_sessions.sessions) == 3 + + users_found = {s.user_id for s in all_sessions.sessions} + assert users_found == {"alice", "bob", "charlie"} + + alice_session = await session_service.get_session( + app_name=APP_NAME, user_id="alice", session_id=SESSION_ID + ) + bob_session = await session_service.get_session( + app_name=APP_NAME, user_id="bob", session_id=SESSION_ID + ) + charlie_session = await session_service.get_session( + app_name=APP_NAME, user_id="charlie", session_id=SESSION_ID + ) + + assert alice_session.state.get("permissions") == ["read", "write", "delete"] + assert bob_session.state.get("permissions") == ["read"] + assert charlie_session.state.get("permissions") == ["read", "write"] + + +@pytest.mark.asyncio +async def test_returned_session_is_deep_copy_not_reference(session_service): + """Sessions returned from get_session must be copies, not references. 
+ + Modifying a returned session object should NOT affect: + 1. Subsequent get_session calls (persisted state) + 2. The in-memory state of the service (if any) + + This is a critical safety feature to prevent: + - Accidental modifications from affecting persisted state + - Race conditions between concurrent readers + - Side effects in caller code + + Contract verification: + - get_session returns a copy (not reference to internal storage) + - Modifying returned session doesn't affect storage + - list_sessions also returns copies + """ + session = await session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id="copy_test", + state={ + "counter": 0, + "config": {"theme": "light", "notifications": True}, + }, + ) + + read_1 = await session_service.get_session( + app_name=APP_NAME, user_id=USER_ID, session_id="copy_test" + ) + assert read_1 is not None + + read_1.state["counter"] = 999 + read_1.state["config"]["theme"] = "dark" + read_1.state["new_field"] = "accidentally_added" + + read_2 = await session_service.get_session( + app_name=APP_NAME, user_id=USER_ID, session_id="copy_test" + ) + assert read_2 is not None + + assert read_2.state.get("counter") == 0 + assert read_2.state.get("config") == {"theme": "light", "notifications": True} + assert "new_field" not in read_2.state + + list_result = await session_service.list_sessions( + app_name=APP_NAME, user_id=USER_ID + ) + assert len(list_result.sessions) == 1 + + listed_session = list_result.sessions[0] + listed_session.state["counter"] = 888 + + read_3 = await session_service.get_session( + app_name=APP_NAME, user_id=USER_ID, session_id="copy_test" + ) + assert read_3 is not None + assert read_3.state.get("counter") == 0 diff --git a/tests/unittests/agents/test_agent_timeout.py b/tests/unittests/agents/test_agent_timeout.py new file mode 100644 index 0000000000..413455358e --- /dev/null +++ b/tests/unittests/agents/test_agent_timeout.py @@ -0,0 +1,355 @@ +# Copyright 2026 Google LLC +# +# 
Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for agent timeout mechanism.""" + +from __future__ import annotations + +import asyncio +from typing import AsyncGenerator +from typing import Optional +from unittest import mock + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.llm_agent import Agent +from google.adk.errors.agent_timeout_error import AgentTimeoutError +from google.adk.errors.agent_timeout_error import TimeoutTrigger +from google.adk.errors.agent_timeout_error import TimeoutType +from google.adk.events.event import Event +from google.adk.plugins.plugin_manager import PluginManager +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types +import pytest +from typing_extensions import override + +from .. 
import testing_utils + + +async def _create_test_invocation_context( + agent: BaseAgent, +) -> InvocationContext: + """Create a test invocation context for timeout tests.""" + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + return InvocationContext( + invocation_id='test_invocation', + agent=agent, + session=session, + session_service=session_service, + plugin_manager=PluginManager(plugins=[]), + ) + + +_sleep_seconds_by_agent: dict[str, float] = {} +_started_events: dict[str, asyncio.Event] = {} +_cancelled_events: dict[str, asyncio.Event] = {} +_sub_agent_by_parent: dict[str, str] = {} + + +class _SlowTestingAgent(BaseAgent): + """A testing agent that simulates slow execution.""" + + def __init__(self, name: str, sleep_seconds: float = 5.0): + super().__init__(name=name) + _sleep_seconds_by_agent[name] = sleep_seconds + _started_events[name] = asyncio.Event() + _cancelled_events[name] = asyncio.Event() + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + _started_events[self.name].set() + sleep_seconds = _sleep_seconds_by_agent.get(self.name, 5.0) + try: + await asyncio.sleep(sleep_seconds) + except asyncio.CancelledError: + _cancelled_events[self.name].set() + raise + yield Event( + author=self.name, + branch=ctx.branch, + invocation_id=ctx.invocation_id, + content=types.Content(parts=[types.Part(text='Done')]), + ) + + @override + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + _started_events[self.name].set() + sleep_seconds = _sleep_seconds_by_agent.get(self.name, 5.0) + try: + await asyncio.sleep(sleep_seconds) + except asyncio.CancelledError: + _cancelled_events[self.name].set() + raise + yield Event( + author=self.name, + invocation_id=ctx.invocation_id, + branch=ctx.branch, + content=types.Content(parts=[types.Part(text='Done')]), + ) + + +class 
_AgentWithSubAgent(BaseAgent): + """An agent that calls a sub-agent.""" + + def __init__(self, name: str, sub_agent_name: str): + super().__init__(name=name) + _sub_agent_by_parent[name] = sub_agent_name + _started_events[name] = asyncio.Event() + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + _started_events[self.name].set() + sub_agent_name = _sub_agent_by_parent.get(self.name) + sub_agent = self.find_sub_agent(sub_agent_name) + if sub_agent: + async for event in sub_agent.run_async(ctx): + yield event + + @override + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + _started_events[self.name].set() + sub_agent_name = _sub_agent_by_parent.get(self.name) + sub_agent = self.find_sub_agent(sub_agent_name) + if sub_agent: + async for event in sub_agent.run_live(ctx): + yield event + + +@pytest.mark.asyncio +async def test_run_async_without_timeout_backward_compatible(): + """Test that agents without timeout work as before (backward compatibility).""" + agent = _SlowTestingAgent(name='test_agent_bc', sleep_seconds=0.1) + parent_ctx = await _create_test_invocation_context(agent) + + events = [e async for e in agent.run_async(parent_ctx)] + + assert len(events) == 1 + assert events[0].content.parts[0].text == 'Done' + + +@pytest.mark.asyncio +async def test_run_async_total_timeout_triggers(): + """Test that total_timeout triggers when execution takes too long.""" + agent = _SlowTestingAgent(name='test_agent_timeout', sleep_seconds=5.0) + agent.total_timeout = 0.1 + parent_ctx = await _create_test_invocation_context(agent) + + with pytest.raises(AgentTimeoutError) as exc_info: + [e async for e in agent.run_async(parent_ctx)] + + error = exc_info.value + assert error.timeout_type == TimeoutType.TOTAL + assert error.trigger == TimeoutTrigger.USER_INPUT + assert error.agent_name == 'test_agent_timeout' + assert error.elapsed_time >= 0.1 + assert 'total' in 
str(error).lower() + + +@pytest.mark.asyncio +async def test_run_live_total_timeout_triggers(): + """Test that total_timeout triggers in run_live when execution takes too long.""" + agent = _SlowTestingAgent(name='test_agent_live_timeout', sleep_seconds=5.0) + agent.total_timeout = 0.1 + parent_ctx = await _create_test_invocation_context(agent) + + with pytest.raises(AgentTimeoutError) as exc_info: + [e async for e in agent.run_live(parent_ctx)] + + error = exc_info.value + assert error.timeout_type == TimeoutType.TOTAL + assert error.trigger == TimeoutTrigger.USER_INPUT + assert error.agent_name == 'test_agent_live_timeout' + + +@pytest.mark.asyncio +async def test_sub_agent_cascade_cancellation(): + """Test that when parent agent times out, sub-agent is also cancelled.""" + sub_agent = _SlowTestingAgent(name='sub_agent_cascade', sleep_seconds=5.0) + parent_agent = _AgentWithSubAgent( + name='parent_agent_cascade', sub_agent_name='sub_agent_cascade' + ) + parent_agent.sub_agents = [sub_agent] + parent_agent.total_timeout = 0.1 + + parent_ctx = await _create_test_invocation_context(parent_agent) + + with pytest.raises(AgentTimeoutError): + [e async for e in parent_agent.run_async(parent_ctx)] + + await asyncio.sleep(0.1) + assert _cancelled_events['sub_agent_cascade'].is_set() + + +@pytest.mark.asyncio +async def test_agent_timeout_error_message(): + """Test that AgentTimeoutError has a descriptive message.""" + agent = _SlowTestingAgent(name='my_agent_msg', sleep_seconds=5.0) + agent.total_timeout = 0.1 + parent_ctx = await _create_test_invocation_context(agent) + + with pytest.raises(AgentTimeoutError) as exc_info: + [e async for e in agent.run_async(parent_ctx)] + + error = exc_info.value + message = str(error) + + assert 'my_agent_msg' in message + assert 'total' in message.lower() or 'TOTAL' in message + assert 'USER_INPUT' in message or 'user_input' in message + + +def test_agent_timeout_fields(): + """Test that AgentTimeoutError has all required fields.""" + 
error = AgentTimeoutError( + message='Test timeout', + timeout_type=TimeoutType.SINGLE_TURN, + elapsed_time=5.5, + trigger=TimeoutTrigger.LLM_CALL, + agent_name='test_agent', + ) + + assert error.timeout_type == TimeoutType.SINGLE_TURN + assert error.elapsed_time == 5.5 + assert error.trigger == TimeoutTrigger.LLM_CALL + assert error.agent_name == 'test_agent' + assert isinstance(error, TimeoutError) + + +def test_agent_timeout_with_str_parameters(): + """Test that AgentTimeoutError accepts string parameters.""" + error = AgentTimeoutError( + message='Test', + timeout_type='single_turn', + elapsed_time=10.0, + trigger='tool_call', + agent_name='agent', + ) + + assert error.timeout_type == 'single_turn' + assert error.trigger == 'tool_call' + + +@pytest.mark.asyncio +async def test_llm_agent_with_timeout_fields(): + """Test that LlmAgent inherits timeout fields from BaseAgent.""" + agent = Agent( + name='test_agent_fields', + model='mock', + single_turn_timeout=30.0, + total_timeout=300.0, + ) + + assert agent.single_turn_timeout == 30.0 + assert agent.total_timeout == 300.0 + + +@pytest.mark.asyncio +async def test_llm_agent_timeout_fields_none_by_default(): + """Test that timeout fields are None by default (backward compatible).""" + agent = Agent( + name='test_agent_default', + model='mock', + ) + + assert agent.single_turn_timeout is None + assert agent.total_timeout is None + + +@pytest.mark.asyncio +async def test_run_async_total_timeout_not_triggered_if_fast_enough(): + """Test that timeout is not triggered if execution finishes within timeout.""" + agent = _SlowTestingAgent(name='test_agent_fast', sleep_seconds=0.1) + agent.total_timeout = 5.0 + parent_ctx = await _create_test_invocation_context(agent) + + events = [e async for e in agent.run_async(parent_ctx)] + + assert len(events) == 1 + assert events[0].content.parts[0].text == 'Done' + + +@pytest.mark.asyncio +async def test_run_async_cancelled_properly(): + """Test that internal task is properly 
cancelled on timeout.""" + agent = _SlowTestingAgent(name='test_agent_cancel', sleep_seconds=5.0) + agent.total_timeout = 0.1 + parent_ctx = await _create_test_invocation_context(agent) + + with pytest.raises(AgentTimeoutError): + [e async for e in agent.run_async(parent_ctx)] + + await asyncio.sleep(0.1) + assert _cancelled_events['test_agent_cancel'].is_set() + + +class _AgentWithMultipleSubAgents(BaseAgent): + """An agent that calls multiple sub-agents.""" + + def __init__(self, name: str, sub_agent_names: list[str]): + super().__init__(name=name) + _sub_agent_by_parent[name] = sub_agent_names[0] if sub_agent_names else None + _started_events[name] = asyncio.Event() + self._sub_agent_names = sub_agent_names + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + _started_events[self.name].set() + for sub_agent_name in self._sub_agent_names: + sub_agent = self.find_sub_agent(sub_agent_name) + if sub_agent: + async for event in sub_agent.run_async(ctx): + yield event + + +@pytest.mark.asyncio +async def test_parent_timeout_cascades_cancel_to_children(): + """Test that when parent agent times out, all sub-agent tasks are cancelled. + + This test verifies that the cascade cancellation works correctly: + 1. Parent agent starts a sub-agent + 2. Parent agent times out + 3. Parent's internal task is cancelled + 4. Sub-agent's run_async generator is exited via Aclosing + 5. 
Sub-agent's internal task receives the cancellation + """ + sub_agent_1 = _SlowTestingAgent(name='child_1', sleep_seconds=5.0) + sub_agent_2 = _SlowTestingAgent(name='child_2', sleep_seconds=5.0) + parent_agent = _AgentWithMultipleSubAgents( + name='parent_multi', sub_agent_names=['child_1', 'child_2'] + ) + parent_agent.sub_agents = [sub_agent_1, sub_agent_2] + parent_agent.total_timeout = 0.1 + + parent_ctx = await _create_test_invocation_context(parent_agent) + + with pytest.raises(AgentTimeoutError): + [e async for e in parent_agent.run_async(parent_ctx)] + + await asyncio.sleep(0.1) + assert _started_events['parent_multi'].is_set() + assert _cancelled_events['child_1'].is_set() diff --git a/tests/unittests/sessions/test_sqlite_session_service.py b/tests/unittests/sessions/test_sqlite_session_service.py new file mode 100644 index 0000000000..7957141080 --- /dev/null +++ b/tests/unittests/sessions/test_sqlite_session_service.py @@ -0,0 +1,361 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Unit tests for SqliteSessionService.

Covers session CRUD, event persistence/restore, default-DB-path resolution
(ADK_HOME / XDG_DATA_HOME / legacy ~/.adk fallbacks), schema versioning,
and the recursive (RFC 7396) state-merge semantics of the SQLite backend.
"""

import os
import sqlite3
import tempfile

from google.adk.errors.version_mismatch_error import VersionMismatchError
from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
from google.adk.sessions import SqliteSessionService
from google.adk.sessions.sqlite_session_service import _get_default_db_path
import pytest


@pytest.mark.asyncio
async def test_create_session():
    """Creating a session persists it and materializes the database file."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        db_path = os.path.join(tmp_dir, "test.db")
        service = SqliteSessionService(db_path)

        app_name = "test_app"
        user_id = "test_user"
        state = {"key": "value"}

        session = await service.create_session(
            app_name=app_name,
            user_id=user_id,
            state=state
        )

        assert session.app_name == app_name
        assert session.user_id == user_id
        assert session.id is not None
        assert session.state == state

        # The DB file must exist on disk after the first write.
        assert os.path.exists(db_path)


@pytest.mark.asyncio
async def test_get_nonexistent_session_returns_none():
    """Fetching a session that was never created returns None, not an error."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        db_path = os.path.join(tmp_dir, "test.db")
        service = SqliteSessionService(db_path)

        session = await service.get_session(
            app_name="nonexistent_app",
            user_id="nonexistent_user",
            session_id="nonexistent_id"
        )

        assert session is None


@pytest.mark.asyncio
async def test_append_event_and_restore():
    """An appended event round-trips through the database on get_session."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        db_path = os.path.join(tmp_dir, "test.db")
        service = SqliteSessionService(db_path)

        session = await service.create_session(
            app_name="test_app",
            user_id="test_user"
        )

        event = Event(
            invocation_id="test_invocation",
            author="user",
            content=None
        )
        await service.append_event(session=session, event=event)

        restored_session = await service.get_session(
            app_name="test_app",
            user_id="test_user",
            session_id=session.id
        )

        assert restored_session is not None
        assert len(restored_session.events) == 1
        assert restored_session.events[0].invocation_id == "test_invocation"
        assert restored_session.events[0].author == "user"


@pytest.mark.asyncio
async def test_list_multiple_sessions():
    """list_sessions returns every session created for a given app/user."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        db_path = os.path.join(tmp_dir, "test.db")
        service = SqliteSessionService(db_path)

        app_name = "test_app"
        user_id = "test_user"

        session_ids = []
        for i in range(5):
            session = await service.create_session(
                app_name=app_name,
                user_id=user_id,
                session_id=f"session_{i}"
            )
            session_ids.append(session.id)

        response = await service.list_sessions(
            app_name=app_name,
            user_id=user_id
        )

        assert len(response.sessions) == 5
        # Order is unspecified; compare as sets.
        assert {s.id for s in response.sessions} == set(session_ids)


@pytest.mark.asyncio
async def test_delete_session():
    """A deleted session is no longer retrievable via get_session."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        db_path = os.path.join(tmp_dir, "test.db")
        service = SqliteSessionService(db_path)

        session = await service.create_session(
            app_name="test_app",
            user_id="test_user"
        )

        assert await service.get_session(
            app_name="test_app",
            user_id="test_user",
            session_id=session.id
        ) is not None

        await service.delete_session(
            app_name="test_app",
            user_id="test_user",
            session_id=session.id
        )

        assert await service.get_session(
            app_name="test_app",
            user_id="test_user",
            session_id=session.id
        ) is None


@pytest.mark.asyncio
async def test_default_db_path():
    """The service creates any missing parent directories for the DB path."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        service = SqliteSessionService(os.path.join(tmp_dir, "test.db"))
        assert service is not None

        # Constructing the service with a path inside directories that do
        # not exist yet must create those directories.
        nested_db_path = os.path.join(tmp_dir, "nested", "dir", "test.db")
        SqliteSessionService(nested_db_path)
        assert os.path.exists(os.path.dirname(nested_db_path))


def test_default_path_env_override(monkeypatch):
    """ADK_HOME, when set, overrides every other default-path rule."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        adk_home = os.path.join(tmp_dir, "custom_adk")
        monkeypatch.setenv("ADK_HOME", adk_home)

        expected_path = os.path.join(adk_home, "sessions.db")
        actual_path = _get_default_db_path()

        assert actual_path == expected_path


def test_default_path_xdg_fallback(monkeypatch):
    """Without ADK_HOME, XDG_DATA_HOME/adk/sessions.db is used."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        monkeypatch.delenv("ADK_HOME", raising=False)

        xdg_data_home = os.path.join(tmp_dir, "xdg_data")
        monkeypatch.setenv("XDG_DATA_HOME", xdg_data_home)

        expected_path = os.path.join(xdg_data_home, "adk", "sessions.db")
        actual_path = _get_default_db_path()

        assert actual_path == expected_path


def test_default_path_legacy_fallback(monkeypatch, tmp_path):
    """An existing legacy ~/.adk/sessions.db takes precedence when no env vars are set."""
    original_expanduser = os.path.expanduser

    def mock_expanduser(path):
        # Redirect only the bare home-dir lookup to the test sandbox.
        if path == "~":
            return str(tmp_path)
        return original_expanduser(path)

    monkeypatch.setattr(os.path, "expanduser", mock_expanduser)

    legacy_dir = tmp_path / ".adk"
    legacy_dir.mkdir()
    legacy_db = legacy_dir / "sessions.db"
    legacy_db.touch()

    monkeypatch.delenv("ADK_HOME", raising=False)
    monkeypatch.delenv("XDG_DATA_HOME", raising=False)

    actual_path = _get_default_db_path()

    assert actual_path == str(legacy_db)


def test_default_path_xdg_default(monkeypatch, tmp_path):
    """With no env vars and no legacy DB, ~/.local/share/adk/sessions.db is the default."""
    original_expanduser = os.path.expanduser

    def mock_expanduser(path):
        # Redirect only the bare home-dir lookup to the test sandbox.
        if path == "~":
            return str(tmp_path)
        return original_expanduser(path)

    monkeypatch.setattr(os.path, "expanduser", mock_expanduser)

    monkeypatch.delenv("ADK_HOME", raising=False)
    monkeypatch.delenv("XDG_DATA_HOME", raising=False)

    actual_path = _get_default_db_path()

    expected_path = str(tmp_path / ".local" / "share" / "adk" / "sessions.db")
    assert actual_path == expected_path


def test_schema_version_mismatch_raises():
    """Opening a DB whose recorded schema version differs raises VersionMismatchError."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        db_path = os.path.join(tmp_dir, "test.db")

        # Hand-build a database with the expected tables but a bogus
        # schema_version, simulating a DB written by a future release.
        with sqlite3.connect(db_path) as conn:
            cursor = conn.cursor()

            cursor.execute("""
                CREATE TABLE IF NOT EXISTS sessions (
                    app_name TEXT NOT NULL,
                    user_id TEXT NOT NULL,
                    id TEXT NOT NULL,
                    state TEXT NOT NULL,
                    create_time REAL NOT NULL,
                    update_time REAL NOT NULL,
                    PRIMARY KEY (app_name, user_id, id)
                )
            """)

            cursor.execute("""
                CREATE TABLE IF NOT EXISTS events (
                    id TEXT NOT NULL,
                    app_name TEXT NOT NULL,
                    user_id TEXT NOT NULL,
                    session_id TEXT NOT NULL,
                    invocation_id TEXT NOT NULL,
                    timestamp REAL NOT NULL,
                    event_data TEXT NOT NULL,
                    PRIMARY KEY (app_name, user_id, session_id, id)
                )
            """)

            cursor.execute("""
                CREATE TABLE IF NOT EXISTS metadata (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL
                )
            """)

            cursor.execute(
                "INSERT INTO metadata (key, value) VALUES (?, ?)",
                ("schema_version", "999"),
            )
            conn.commit()

        with pytest.raises(VersionMismatchError) as exc_info:
            SqliteSessionService(db_path)

        assert "999" in str(exc_info.value)
        assert exc_info.value.expected_version == 1
        assert exc_info.value.actual_version == 999


@pytest.mark.asyncio
async def test_schema_version_initialized_on_new_db():
    """A freshly initialized database records schema_version = 1 in metadata."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        db_path = os.path.join(tmp_dir, "test.db")

        service = SqliteSessionService(db_path)
        await service.create_session(app_name="test_app", user_id="test_user")

        with sqlite3.connect(db_path) as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT value FROM metadata WHERE key = ?",
                ("schema_version",),
            )
            result = cursor.fetchone()

        assert result is not None
        assert int(result[0]) == 1


@pytest.mark.asyncio
async def test_state_merge_shallow_vs_recursive_semantics_documented():
    """SqliteSessionService merges state recursively, unlike DatabaseSessionService.

    This test pins down the state-merge semantics of SqliteSessionService:
    - SqliteSessionService: SQLite json_patch (RFC 7396) - recursive merge
    - DatabaseSessionService: dict.update() - shallow merge

    Example:
    - Existing state: {"nested": {"a": 1, "b": 2}}
    - State delta: {"nested": {"b": 3, "c": 4}}
    - SqliteSessionService result: {"nested": {"a": 1, "b": 3, "c": 4}} (recursive merge)
    - DatabaseSessionService result: {"nested": {"b": 3, "c": 4}} (shallow merge)
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        db_path = os.path.join(tmp_dir, "test.db")
        service = SqliteSessionService(db_path)

        session = await service.create_session(
            app_name="test_app",
            user_id="test_user",
            state={"nested": {"a": 1, "b": 2}}
        )

        assert session.state == {"nested": {"a": 1, "b": 2}}

        event = Event(
            invocation_id="test_invocation",
            author="user",
            actions=EventActions(
                state_delta={"nested": {"b": 3, "c": 4}}
            )
        )
        await service.append_event(session=session, event=event)

        restored_session = await service.get_session(
            app_name="test_app",
            user_id="test_user",
            session_id=session.id
        )

        assert restored_session is not None
        # json_patch keeps "a" from the existing nested dict — recursive merge.
        assert restored_session.state == {"nested": {"a": 1, "b": 3, "c": 4}}