Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added local.db
Binary file not shown.
24 changes: 21 additions & 3 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,18 +2071,36 @@ def migrate():
default="INFO",
help="Optional. Set the logging level",
)
@click.option(
"--force-untrusted-source",
is_flag=True,
default=False,
help=(
"Optional. Force migration from untrusted or remote database sources "
"(e.g., SMB shares or external IPs). Use with CAUTION as it poses RCE "
"risks if the source is malicious."
),
)
@click.pass_context
def cli_migrate_session(
*, source_db_url: str, dest_db_url: str, log_level: str
ctx,
*,
source_db_url: str,
dest_db_url: str,
log_level: str,
force_untrusted_source: bool,
):
"""Migrates a session database to the latest schema version."""
logs.setup_adk_logger(getattr(logging, log_level.upper()))
try:
from ..sessions.migration import migration_runner

migration_runner.upgrade(source_db_url, dest_db_url)
migration_runner.upgrade(
source_db_url, dest_db_url, force_untrusted_source
)
click.secho("Migration check and upgrade process finished.", fg="green")
except Exception as e:
click.secho(f"Migration failed: {e}", fg="red", err=True)
ctx.exit(1)


@deploy.command("agent_engine")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from datetime import timezone
import json
import logging
import pickle
import sys
from typing import Any

Expand All @@ -29,6 +28,7 @@
from google.adk.sessions import _session_util
from google.adk.sessions.migration import _schema_check_utils
from google.adk.sessions.schemas import v1
from google.adk.utils import serialization_utils
from google.genai import types
import sqlalchemy
from sqlalchemy import create_engine
Expand Down Expand Up @@ -59,7 +59,7 @@ def _row_to_event(row: dict) -> Event:
if actions_val is not None:
try:
if isinstance(actions_val, bytes):
actions = pickle.loads(actions_val)
actions = serialization_utils.secure_loads(actions_val)
else: # for spanner - it might return object directly
actions = actions_val
except Exception as e:
Expand Down
67 changes: 66 additions & 1 deletion src/google/adk/sessions/migration/migration_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

from __future__ import annotations

import ipaddress
import logging
import os
import tempfile
from urllib.parse import urlparse

from google.adk.sessions.migration import _schema_check_utils
from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle
Expand All @@ -39,10 +41,64 @@
}
# The most recent schema version. The migration process stops once this version
# is reached.
LATEST_VERSION = _schema_check_utils.LATEST_SCHEMA_VERSION


def upgrade(source_db_url: str, dest_db_url: str):
class SecurityError(Exception):
  """Error raised when a migration source violates the security policy."""


def _is_trusted_url(db_url: str) -> bool:
  r"""Returns True when ``db_url`` points to a trusted local database source.

  Trusted sources:
    - Loopback hosts (``localhost``, ``127.0.0.1``, ``::1``, any loopback IP).
    - Host-less local file URLs such as ``sqlite:///path/to/db``.

  Untrusted sources (all return False):
    - External IP addresses and remote hostnames.
    - Windows UNC paths (``\\host\share``) smuggled into the path component.

  Any parsing failure fails closed (returns False).

  NOTE(review): a four-slash absolute SQLite URL (``sqlite:////abs/path``)
  parses to a path starting with ``//`` and is therefore flagged untrusted
  as well — presumably intentional fail-closed behavior; confirm.
  """
  try:
    parsed = urlparse(db_url)
    host = parsed.hostname

    if not host:
      # Host-less URLs are local files — unless the path smuggles in a UNC
      # share: sqlite:///\\host\share yields a path beginning with /\\,
      # and //host/share-style paths begin with //.
      return not parsed.path.startswith(("/\\\\", "//"))

    # Well-known loopback spellings are trusted outright.
    if host.lower() in ("localhost", "127.0.0.1", "::1"):
      return True

    try:
      # Numeric hosts are trusted only when they are loopback addresses.
      return ipaddress.ip_address(host).is_loopback
    except ValueError:
      # A non-IP hostname other than "localhost" is treated as untrusted.
      return False
  except Exception:
    # Fail closed on any parsing problem.
    return False


def upgrade(
source_db_url: str,
dest_db_url: str,
force_untrusted_source: bool = False,
):
"""Migrates a database from its current version to the latest version.

If the source database schema is older than the latest version, this
Expand Down Expand Up @@ -72,6 +128,15 @@ def upgrade(source_db_url: str, dest_db_url: str):
"Please provide a different URL for dest_db_url."
)

if not _is_trusted_url(source_db_url) and not force_untrusted_source:
raise SecurityError(
f"Untrusted source database URL detected: {source_db_url}\n"
"Migrating from remote or untrusted sources (e.g., SMB shares or "
"external IPs) poses a SIGNIFICANT Remote Code Execution (RCE) "
"risk if the source data is malicious.\n"
"To proceed anyway, use the --force-untrusted-source flag."
)

current_version = _schema_check_utils.get_db_schema_version(source_db_url)
if current_version == LATEST_VERSION:
logger.info(
Expand Down
37 changes: 37 additions & 0 deletions src/google/adk/sessions/schemas/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from sqlalchemy.types import DateTime
from sqlalchemy.types import TypeDecorator

from google.adk.utils import serialization_utils

DEFAULT_MAX_KEY_LENGTH = 128
DEFAULT_MAX_VARCHAR_LENGTH = 256

Expand Down Expand Up @@ -55,6 +57,41 @@ def process_result_value(self, value, dialect: Dialect):
return value


class JsonEncodedType(DynamicJSON):
  """A JSON-encoded type with hybrid support for secure legacy pickles.

  New data is always stored as JSON. When reading, it first attempts to
  decode JSON. If that fails and the value is binary, it attempts to
  deserialize using serialization_utils.secure_loads (HMAC-verified).
  """

  def process_result_value(self, value, dialect: Dialect):
    """Decodes a DB value: JSON first, then HMAC-verified legacy pickle.

    Args:
      value: Raw value returned by the database driver (None, str, bytes,
        or an already-decoded object on PostgreSQL).
      dialect: The SQLAlchemy dialect the value came from.

    Returns:
      The decoded Python object, or None for NULL columns.

    Raises:
      serialization_utils.SecurityError: If a binary legacy payload fails
        HMAC verification.
    """
    if value is None:
      return None

    # PostgreSQL JSON/JSONB values arrive already deserialized.
    if dialect.name == "postgresql":
      return value

    if isinstance(value, str):
      try:
        return json.loads(value)
      except json.JSONDecodeError:
        # Not valid JSON; fall through to the parent handler below.
        pass

    # Binary values are legacy pickled payloads. secure_loads verifies the
    # HMAC signature and raises serialization_utils.SecurityError on tampered
    # or unsigned data — let that propagate (the original wrapped this in a
    # redundant catch-and-reraise, removed here).
    if isinstance(value, bytes):
      return serialization_utils.secure_loads(value)

    return super().process_result_value(value, dialect)


class PreciseTimestamp(TypeDecorator):
"""Represents a timestamp precise to the microsecond."""

Expand Down
33 changes: 3 additions & 30 deletions src/google/adk/sessions/schemas/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
from datetime import timezone
import json
import logging
import pickle
from typing import Any
from typing import Optional

from google.adk.platform import uuid as platform_uuid
from google.adk.utils import serialization_utils
from google.genai import types
from sqlalchemy import Boolean
from sqlalchemy import desc
Expand All @@ -57,9 +57,9 @@
from ...events.event import Event
from ...events.event_actions import EventActions
from ..session import Session
from .shared import DEFAULT_MAX_KEY_LENGTH
from .shared import DEFAULT_MAX_VARCHAR_LENGTH
from .shared import DynamicJSON
from .shared import JsonEncodedType
from .shared import PreciseTimestamp

logger = logging.getLogger("google_adk." + __name__)
Expand Down Expand Up @@ -89,33 +89,6 @@ def _truncate_str(value: Optional[str], max_length: int) -> Optional[str]:
return value


class DynamicPickleType(TypeDecorator):
  """Represents a type that can be pickled.

  Dialect-aware pickle column: MySQL uses a LONGBLOB column and Spanner its
  SpannerPickleType; every other dialect falls back to the default
  PickleType implementation.

  NOTE(review): pickle.loads on rows from an untrusted database is a remote
  code execution vector; this type must only be used with trusted sources.
  """

  # Default SQLAlchemy implementation used when no dialect override applies.
  impl = PickleType

  def load_dialect_impl(self, dialect):
    # Pick a dialect-specific column type; see class docstring.
    if dialect.name == "mysql":
      return dialect.type_descriptor(mysql.LONGBLOB)
    if dialect.name == "spanner+spanner":
      # Imported lazily so the Spanner dependency stays optional.
      from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType

      return dialect.type_descriptor(SpannerPickleType)
    return self.impl

  def process_bind_param(self, value, dialect):
    """Ensures the pickled value is a bytes object before passing it to the database dialect."""
    if value is not None:
      # Spanner and MySQL receive raw pickled bytes here; other dialects let
      # the underlying PickleType perform the serialization itself.
      if dialect.name in ("spanner+spanner", "mysql"):
        return pickle.dumps(value)
    return value

  def process_result_value(self, value, dialect):
    """Ensures the raw bytes from the database are unpickled back into a Python object."""
    if value is not None:
      # Mirror of process_bind_param: only these dialects hand back raw bytes.
      if dialect.name in ("spanner+spanner", "mysql"):
        return pickle.loads(value)
    return value


class Base(DeclarativeBase):
Expand Down Expand Up @@ -234,7 +207,7 @@ class StorageEvent(Base):

invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType)
actions: Mapped[MutableDict[str, Any]] = mapped_column(JsonEncodedType)
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
Text, nullable=True
)
Expand Down
87 changes: 87 additions & 0 deletions src/google/adk/utils/serialization_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for secure serialization and deserialization."""

import hashlib
import hmac
import os
import pickle
import warnings
from typing import Any


class SecurityError(Exception):
  """Error signaling a failed security validation during deserialization."""


def _get_secret_key() -> bytes:
  """Retrieves the secret key used for HMAC signing.

  Reads the ``ADK_SECURITY_SECRET`` environment variable. In a production
  environment, this should be fetched from a secure secret manager or KMS.

  Returns:
    The UTF-8 encoded secret, or a hard-coded development fallback when the
    environment variable is unset or empty.
  """
  secret = os.environ.get("ADK_SECURITY_SECRET")
  if not secret:
    # Fallback for demonstration/local development only. A signature made
    # with a publicly known key offers no integrity protection, so surface
    # a loud warning instead of failing silently.
    # WARNING: This should be replaced with mandatory secret fetching in prod.
    warnings.warn(
        "ADK_SECURITY_SECRET is not set; falling back to an insecure"
        " built-in development key. Do not use this in production.",
        RuntimeWarning,
        stacklevel=2,
    )
    return b"default_insecure_development_secret"
  return secret.encode("utf-8")


def secure_dumps(obj: Any) -> bytes:
  """Pickles an object and prefixes the payload with an HMAC-SHA256 signature.

  Args:
    obj: The Python object to serialize.

  Returns:
    The signed binary blob: a 32-byte signature followed by the pickled
    payload, verifiable by secure_loads.
  """
  payload = pickle.dumps(obj)
  mac = hmac.new(_get_secret_key(), payload, hashlib.sha256)
  return mac.digest() + payload


def secure_loads(data: bytes) -> Any:
  """Verifies the HMAC-SHA256 prefix of a blob and unpickles its payload.

  Args:
    data: A signed blob as produced by secure_dumps (32-byte signature
      followed by the pickled payload).

  Returns:
    The deserialized Python object.

  Raises:
    SecurityError: If the blob is too short or the signature does not match.
  """
  sig_size = hashlib.sha256().digest_size  # 32 bytes for SHA-256
  if len(data) < sig_size:
    raise SecurityError("Data too short to contain a valid signature.")

  provided, payload = data[:sig_size], data[sig_size:]
  expected = hmac.new(_get_secret_key(), payload, hashlib.sha256).digest()

  # Constant-time comparison prevents timing side channels.
  if not hmac.compare_digest(provided, expected):
    raise SecurityError(
        "Invalid signature detected during deserialization. "
        "The data may have been tampered with or originated from an "
        "untrusted source."
    )

  return pickle.loads(payload)
Loading