Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion contributing/samples/adk_team/adk_documentation/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Optional

from adk_documentation.settings import GITHUB_BASE_URL
from adk_documentation.settings import LOCAL_REPOS_DIR_PATH
from adk_documentation.utils import error_response
from adk_documentation.utils import get_paginated_request
from adk_documentation.utils import get_request
Expand All @@ -30,6 +31,26 @@
import requests


def _resolve_within_repos_dir(path: str) -> Optional[str]:
"""Resolves `path` and verifies it stays within LOCAL_REPOS_DIR_PATH.

The file/search tools are only meant to operate on the repositories cloned
under LOCAL_REPOS_DIR_PATH. Because these tools are driven by an LLM that
processes untrusted input (issue bodies, release diffs), an `os.path.isabs`
check alone lets a crafted instruction read or write arbitrary paths (e.g.
the credentials file pointed to by GOOGLE_APPLICATION_CREDENTIALS, or
/proc/self/environ). Resolve symlinks and `..` segments first, then require
the result to stay inside the managed directory.

Returns the resolved absolute path if it is inside the sandbox, else None.
"""
allowed_root = os.path.realpath(LOCAL_REPOS_DIR_PATH)
resolved = os.path.realpath(path)
if resolved == allowed_root or resolved.startswith(allowed_root + os.sep):
return resolved
return None


def list_releases(repo_owner: str, repo_name: str) -> Dict[str, Any]:
"""Lists all releases for a repository.

Expand Down Expand Up @@ -188,6 +209,14 @@ def read_local_git_repo_file_content(file_path: str) -> Dict[str, Any]:
f"file_path must be an absolute path, got: {file_path}"
)

safe_path = _resolve_within_repos_dir(file_path)
if safe_path is None:
return error_response(
"Access denied: file_path is outside the managed repositories "
f"directory ({LOCAL_REPOS_DIR_PATH}): {file_path}"
)
file_path = safe_path

try:
dir_path = os.path.dirname(file_path)
head_commit_sha = _find_head_commit_sha(dir_path)
Expand Down Expand Up @@ -229,6 +258,13 @@ def list_directory_contents(directory_path: str) -> Dict[str, Any]:
print(
f"Attempting to recursively list contents of directory: {directory_path}"
)
safe_path = _resolve_within_repos_dir(directory_path)
if safe_path is None:
return error_response(
"Access denied: directory_path is outside the managed repositories "
f"directory ({LOCAL_REPOS_DIR_PATH}): {directory_path}"
)
directory_path = safe_path
if not os.path.isdir(directory_path):
return error_response(f"Error: Directory not found at {directory_path}")

Expand Down Expand Up @@ -276,6 +312,13 @@ def search_local_git_repo(
f"Attempting to search for pattern: {pattern} in directory:"
f" {directory_path}, with extensions: {extensions}"
)
safe_path = _resolve_within_repos_dir(directory_path)
if safe_path is None:
return error_response(
"Access denied: directory_path is outside the managed repositories "
f"directory ({LOCAL_REPOS_DIR_PATH}): {directory_path}"
)
directory_path = safe_path
try:
grep_process = _git_grep(directory_path, pattern, extensions, ignored_dirs)
if grep_process.returncode > 1:
Expand Down Expand Up @@ -350,8 +393,15 @@ def create_pull_request_from_changes(
if not changes:
return error_response("No changes provided to apply.")

repo_root = os.path.realpath(local_path)
for relative_path, new_content in changes.items():
full_path = os.path.join(local_path, relative_path)
full_path = os.path.realpath(os.path.join(local_path, relative_path))
# Confine writes to the repository: a crafted relative_path such as
# "../../etc/x" would otherwise escape `local_path`.
if not (full_path == repo_root or full_path.startswith(repo_root + os.sep)):
return error_response(
f"Access denied: change path escapes the repository: {relative_path}"
)
os.makedirs(os.path.dirname(full_path), exist_ok=True)
with open(full_path, "w", encoding="utf-8") as f:
f.write(new_content)
Expand Down
91 changes: 91 additions & 0 deletions contributing/samples/adk_team/adk_documentation/tools_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests that the file/search tools stay confined to LOCAL_REPOS_DIR_PATH.

These tools are driven by an LLM that processes untrusted input (issue bodies,
release diffs), so a path that escapes the managed repositories directory must
be rejected even though it is absolute.
"""

import os
import tempfile
import unittest
from unittest import mock

# The tools must operate only under this sandbox; set it before importing so the
# module-level constant resolves to the test directory.
_SANDBOX = tempfile.mkdtemp(prefix="adk_repos_")
os.environ.setdefault("GITHUB_TOKEN", "test-token")
os.environ["LOCAL_REPOS_DIR_PATH"] = _SANDBOX

from adk_documentation import tools # noqa: E402


class PathConfinementTest(unittest.TestCase):

def setUp(self):
self.repo = os.path.join(_SANDBOX, "adk-docs")
os.makedirs(self.repo, exist_ok=True)
self.inside_file = os.path.join(self.repo, "index.md")
with open(self.inside_file, "w", encoding="utf-8") as f:
f.write("hello")

def test_read_inside_sandbox_succeeds(self):
res = tools.read_local_git_repo_file_content(self.inside_file)
self.assertEqual(res["status"], "success")

def test_read_outside_sandbox_is_denied(self):
for path in ("/etc/passwd", "/proc/self/environ"):
res = tools.read_local_git_repo_file_content(path)
self.assertEqual(res["status"], "error", path)
self.assertIn("Access denied", res["error_message"])

def test_read_symlink_escape_is_denied(self):
link = os.path.join(self.repo, "sneaky")
if not os.path.lexists(link):
os.symlink("/etc/passwd", link)
res = tools.read_local_git_repo_file_content(link)
self.assertEqual(res["status"], "error")
self.assertIn("Access denied", res["error_message"])

def test_list_and_search_outside_sandbox_are_denied(self):
self.assertEqual(tools.list_directory_contents("/etc")["status"], "error")
self.assertEqual(
tools.search_local_git_repo("/etc", "root")["status"], "error"
)

def test_create_pr_rejects_path_traversal_in_changes(self):
# Reach the file-writing step with the git/network calls stubbed out, then
# assert a traversing change key is rejected before any write.
with mock.patch.object(tools, "_run_git_command"), mock.patch.object(
tools, "post_request"
):
res = tools.create_pull_request_from_changes(
repo_owner="google",
repo_name="adk-docs",
local_path=self.repo,
base_branch="main",
changes={"../../../../tmp/evil.txt": "owned"},
commit_message="m",
pr_title="t",
pr_body="b",
)
self.assertEqual(res["status"], "error")
self.assertIn("escapes the repository", res["error_message"])
self.assertFalse(os.path.exists("/tmp/evil.txt"))


if __name__ == "__main__":
unittest.main()