Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 234 additions & 0 deletions src/google/adk/tools/_ssrf_protection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared SSRF protection helpers for tools that make HTTP requests.

Two layers:

1. ``validate_url`` rejects bad schemes, missing/blocked hostnames, and any
DNS result that includes a non-globally-routable IP. It returns a
``ValidatedTarget`` so callers can use the pre-resolved address list.

2. ``send_pinned_async`` issues an ``httpx`` request against the validated IP
literal directly, preserves the ``Host`` header, and sets the TLS server
name via ``request.extensions["sni_hostname"]``. Together with (1) this
closes the DNS rebinding window between URL validation and connect: even
if the attacker flips the DNS record after validation, the socket goes to
the IP we validated and the cert check uses the original hostname.

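The resolution and blocking helpers (``resolve_host_addresses``,
``is_blocked_address``, ``is_blocked_hostname``, ``rewrite_url_host``) are
shared with sync ``requests``-based callers such as ``load_web_page`` so
every tool applies the same rules. Note that this module does not itself
define a ``PinnedAddressAdapter``; any ``requests``-side pinning lives with
those callers.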
"""

from __future__ import annotations

from dataclasses import dataclass
import ipaddress
import socket
from typing import Any
from urllib.parse import ParseResult
from urllib.parse import urlparse
from urllib.parse import urlunparse

import httpx

_ALLOWED_URL_SCHEMES = frozenset({"http", "https"})
_DEFAULT_PORT_BY_SCHEME = {"http": 80, "https": 443}
_ResolvedAddress = ipaddress.IPv4Address | ipaddress.IPv6Address


@dataclass(frozen=True)
class ValidatedTarget:
  """Immutable record describing a URL that cleared validation.

  The DNS answers gathered during validation travel with the record so a
  caller can connect to one of them instead of resolving the name again.
  """

  # The URL string exactly as it was submitted for validation.
  url: str
  # Parsed form of ``url``.
  parsed: ParseResult
  # URL scheme.
  scheme: str
  # Host component of the URL (no port, no brackets).
  hostname: str
  # Value to send in the ``Host`` header.
  host_header: str
  # IP addresses associated with ``hostname``.
  addresses: tuple[_ResolvedAddress, ...]


def _format_host(hostname: str) -> str:
if ":" in hostname:
return f"[{hostname}]"
return hostname


def _build_host_header(
    *,
    hostname: str,
    scheme: str,
    explicit_port: int | None,
) -> str:
  """Compose the ``Host`` header value for a request.

  The port is appended only when the URL named one explicitly and it differs
  from the default port registered for ``scheme``.

  Raises:
    KeyError: if ``scheme`` has no entry in ``_DEFAULT_PORT_BY_SCHEME``.
  """
  host = _format_host(hostname)
  scheme_default = _DEFAULT_PORT_BY_SCHEME[scheme]
  if explicit_port is not None and explicit_port != scheme_default:
    return f"{host}:{explicit_port}"
  return host


def is_blocked_hostname(hostname: str) -> bool:
  """Report whether ``hostname`` is ``localhost`` or a name beneath it.

  The comparison ignores letter case and any trailing dots, so
  ``LOCALHOST.`` and ``api.localhost`` are both matched.
  """
  labels = hostname.lower().rstrip(".").split(".")
  # "localhost" itself and every "<anything>.localhost" name share the same
  # final label.
  return labels[-1] == "localhost"


def is_blocked_address(address: _ResolvedAddress) -> bool:
  """Return True for any IP that outbound requests must not target.

  An address is blocked when ``ipaddress`` does not report it as globally
  routable. That covers private (RFC 1918), loopback, link-local (including
  169.254.169.254), reserved, and unspecified ranges across IPv4 and IPv6,
  and avoids drift between hand maintained allow lists in different tools.

  Multicast is checked explicitly: ``is_global`` is derived from the IANA
  special-purpose address registries, which do not list the multicast
  blocks, so addresses such as ``224.0.0.1`` or ``ff02::1`` would otherwise
  be treated as allowed.

  Args:
    address: The parsed IPv4 or IPv6 address to classify.

  Returns:
    True when the address must be rejected, False when it may be contacted.
  """
  return address.is_multicast or not address.is_global


def _parse_ip_literal(hostname: str) -> _ResolvedAddress | None:
try:
return ipaddress.ip_address(hostname)
except ValueError:
return None


def resolve_host_addresses(hostname: str) -> tuple[_ResolvedAddress, ...]:
  """Look up every IPv4 / IPv6 address that ``hostname`` resolves to.

  A hostname that is already an IP literal is returned as a one-element
  tuple without calling the resolver.

  Raises:
    ValueError: when ``getaddrinfo`` fails (``socket.gaierror`` or
      ``UnicodeError``) or yields no IPv4/IPv6 entries. Reusing
      ``ValueError`` lets callers treat resolution failures like any other
      validation failure.
  """
  ip_literal = _parse_ip_literal(hostname)
  if ip_literal is not None:
    return (ip_literal,)

  try:
    records = socket.getaddrinfo(
        hostname,
        None,
        type=socket.SOCK_STREAM,
        proto=socket.IPPROTO_TCP,
    )
  except (socket.gaierror, UnicodeError) as exc:
    raise ValueError(f"Unable to resolve host: {hostname}") from exc

  ip_families = (socket.AF_INET, socket.AF_INET6)
  # dict keys act as an insertion-ordered set: duplicates collapse while the
  # resolver's ordering is kept, so the first record is still tried first.
  unique = dict.fromkeys(
      ipaddress.ip_address(sockaddr[0])
      for family, _, _, _, sockaddr in records
      if family in ip_families
  )
  if not unique:
    raise ValueError(f"Unable to resolve host: {hostname}")
  return tuple(unique)


def validate_url(url: str) -> ValidatedTarget:
  """Check ``url`` against the SSRF rules and capture its DNS answers.

  Checks run in this order: scheme, hostname presence, port syntax, blocked
  hostname, DNS resolution, blocked address. A single blocked record is
  enough to reject the host, so a mixed answer such as
  ``[8.8.8.8, 127.0.0.1]`` does not get through.

  The resolved addresses are returned so the caller can connect to a vetted
  IP rather than resolving again at connect time, which is what closes the
  DNS rebinding window between validation and the HTTP request.

  Returns:
    A ``ValidatedTarget`` for ``url``.

  Raises:
    ValueError: for an unsupported scheme, a missing or blocked hostname, an
      invalid port, a resolution failure, or any blocked resolved address.
  """
  parts = urlparse(url)

  scheme = parts.scheme.lower()
  if scheme not in _ALLOWED_URL_SCHEMES:
    raise ValueError(f"Unsupported url scheme: {url}")

  host = parts.hostname
  if not host:
    raise ValueError(f"URL is missing a hostname: {url}")

  # ``ParseResult.port`` parses lazily and raises on a malformed port.
  try:
    port = parts.port
  except ValueError as exc:
    raise ValueError(f"Invalid url port: {url}") from exc

  if is_blocked_hostname(host):
    raise ValueError(f"Blocked host: {host}")

  resolved = resolve_host_addresses(host)
  for candidate in resolved:
    if is_blocked_address(candidate):
      raise ValueError(f"Blocked host: {host}")

  host_header = _build_host_header(
      hostname=host,
      scheme=scheme,
      explicit_port=port,
  )
  return ValidatedTarget(
      url=url,
      parsed=parts,
      scheme=scheme,
      hostname=host,
      host_header=host_header,
      addresses=resolved,
  )


def rewrite_url_host(parsed: ParseResult, ip: str) -> str:
  """Return the URL in ``parsed`` with its authority replaced by ``ip``.

  The new authority is the bracket-formatted ``ip`` plus the original
  explicit port, if there was one. Every other URL component is kept.
  """
  authority = _format_host(ip)
  explicit_port = parsed.port
  if explicit_port is not None:
    authority = f"{authority}:{explicit_port}"
  return urlunparse(parsed._replace(netloc=authority))


async def send_pinned_async(
    client: httpx.AsyncClient,
    target: ValidatedTarget,
    **request_params: Any,
) -> httpx.Response:
  """Send a request to ``target`` via ``client`` with the IP pinned.

  The URL is rewritten to use a validated IP literal so the connection
  bypasses DNS at send time. The original hostname is preserved in the
  ``Host`` header (for HTTP routing) and in the ``sni_hostname`` request
  extension (for TLS verification, consumed by ``httpcore``).

  Addresses are tried in the order stored on ``target``. Only failures to
  establish the connection (``httpx.ConnectError`` / ``httpx.ConnectTimeout``)
  move on to the next address. Any error raised after a connection exists
  (read timeout, protocol error, ...) propagates immediately, so a request
  that may already have reached the server is never replayed against
  another address.

  NOTE(review): redirects are not re-validated here. If ``client`` or
  ``request_params`` enables ``follow_redirects``, the follow-up requests
  are issued by httpx without passing through this pinning logic.

  Args:
    client: The httpx client used to issue the request.
    target: Result of URL validation; supplies the URL, hostname and the
      address list to connect to.
    **request_params: Extra keyword arguments forwarded to
      ``client.request``. A ``url`` entry is ignored, and any ``Host``
      header (in any letter case) is replaced.

  Returns:
    The response from the first address that accepted the connection.

  Raises:
    ValueError: if ``target.addresses`` is empty.
    httpx.ConnectError, httpx.ConnectTimeout: if no address could be
      connected to; the error from the last attempt is raised.
  """
  # The destination always comes from ``target``; a caller-supplied URL
  # would defeat the pinning, so it is discarded.
  request_params.pop("url", None)

  # Header names are case-insensitive, so drop every spelling of "Host"
  # before setting ours; otherwise two conflicting Host headers could be
  # sent.
  headers: dict[Any, Any] = {}
  supplied = dict(request_params.pop("headers", None) or {})
  for name, value in supplied.items():
    if isinstance(name, bytes):
      text = name.decode("ascii", "replace")
    else:
      text = str(name)
    if text.lower() != "host":
      headers[name] = value
  headers["Host"] = target.host_header

  extensions = {
      **(request_params.pop("extensions", None) or {}),
      "sni_hostname": target.hostname,
  }

  connect_error: Exception | None = None
  for address in target.addresses:
    pinned_url = rewrite_url_host(target.parsed, str(address))
    try:
      return await client.request(
          url=pinned_url,
          headers=headers,
          extensions=extensions,
          **request_params,
      )
    except (httpx.ConnectError, httpx.ConnectTimeout) as exc:
      connect_error = exc

  if connect_error is None:
    # The loop never ran. Raise explicitly rather than ``assert`` so the
    # check also holds when Python runs with ``-O``.
    raise ValueError(
        f"No validated addresses to connect to for host: {target.hostname}"
    )
  raise connect_error
88 changes: 8 additions & 80 deletions src/google/adk/tools/load_web_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
"""Tool for web browse."""

from dataclasses import dataclass
import ipaddress
import socket
from typing import Any
from urllib.parse import ParseResult
from urllib.parse import urlparse
Expand All @@ -28,9 +26,14 @@
from requests.utils import get_environ_proxies
from requests.utils import select_proxy

_ALLOWED_URL_SCHEMES = frozenset({'http', 'https'})
_DEFAULT_PORT_BY_SCHEME = {'http': 80, 'https': 443}
_ResolvedAddress = ipaddress.IPv4Address | ipaddress.IPv6Address
from ._ssrf_protection import _ALLOWED_URL_SCHEMES
from ._ssrf_protection import _build_host_header
from ._ssrf_protection import _parse_ip_literal
from ._ssrf_protection import _ResolvedAddress
from ._ssrf_protection import is_blocked_address as _is_blocked_address
from ._ssrf_protection import is_blocked_hostname as _is_blocked_hostname
from ._ssrf_protection import resolve_host_addresses as _resolve_host_addresses
from ._ssrf_protection import rewrite_url_host as _rewrite_url_host


@dataclass(frozen=True)
Expand Down Expand Up @@ -96,25 +99,6 @@ def _failed_to_fetch_message(url: str) -> str:
return f'Failed to fetch url: {url}'


def _format_host(hostname: str) -> str:
if ':' in hostname:
return f'[{hostname}]'
return hostname


def _default_port_for_scheme(scheme: str) -> int:
  """Look up the default port registered for ``scheme``.

  Raises:
    KeyError: if ``scheme`` is not a key of ``_DEFAULT_PORT_BY_SCHEME``.
  """
  default_port = _DEFAULT_PORT_BY_SCHEME[scheme]
  return default_port


def _build_host_header(
    *, hostname: str, scheme: str, explicit_port: int | None
) -> str:
  """Compose the ``Host`` header value for a request.

  The port suffix is added only for an explicit port that differs from the
  default port of ``scheme``.
  """
  host = _format_host(hostname)
  if explicit_port is None:
    return host
  if explicit_port == _default_port_for_scheme(scheme):
    return host
  return f'{host}:{explicit_port}'


def _parse_request_target(url: str) -> _RequestTarget:
parsed_url = urlparse(url)
scheme = parsed_url.scheme.lower()
Expand Down Expand Up @@ -142,52 +126,6 @@ def _parse_request_target(url: str) -> _RequestTarget:
)


def _parse_ip_literal(hostname: str) -> _ResolvedAddress | None:
try:
return ipaddress.ip_address(hostname)
except ValueError:
return None


def _is_blocked_hostname(hostname: str) -> bool:
normalized_hostname = hostname.rstrip('.').lower()
return normalized_hostname == 'localhost' or normalized_hostname.endswith(
'.localhost'
)


def _is_blocked_address(address: _ResolvedAddress) -> bool:
return not address.is_global


def _resolve_host_addresses(hostname: str) -> tuple[_ResolvedAddress, ...]:
  """Look up every IPv4 / IPv6 address that ``hostname`` resolves to.

  An IP literal is returned as a one-element tuple without calling the
  resolver. Results keep the resolver's order and are not de-duplicated.

  Raises:
    ValueError: when ``getaddrinfo`` fails or yields no IPv4/IPv6 entries.
  """
  ip_literal = _parse_ip_literal(hostname)
  if ip_literal is not None:
    return (ip_literal,)

  try:
    records = socket.getaddrinfo(
        hostname,
        None,
        type=socket.SOCK_STREAM,
        proto=socket.IPPROTO_TCP,
    )
  except (socket.gaierror, UnicodeError) as exc:
    raise ValueError(f'Unable to resolve host: {hostname}') from exc

  ip_families = (socket.AF_INET, socket.AF_INET6)
  resolved = tuple(
      ipaddress.ip_address(sockaddr[0])
      for family, _, _, _, sockaddr in records
      if family in ip_families
  )
  if not resolved:
    raise ValueError(f'Unable to resolve host: {hostname}')
  return resolved


def _get_proxy_url(url: str) -> str | None:
  """Return the proxy selected for ``url`` from the environment proxy map.

  Passes the result of ``get_environ_proxies(url)`` to ``select_proxy`` and
  returns whatever it picks.
  """
  return select_proxy(url, get_environ_proxies(url))
Expand All @@ -200,16 +138,6 @@ def _resolve_direct_addresses(hostname: str) -> tuple[_ResolvedAddress, ...]:
return resolved_addresses


def _rewrite_url_host(parsed_url: ParseResult, hostname: str) -> str:
  """Return ``parsed_url`` as a string with its authority set to ``hostname``.

  The original explicit port, if any, is appended to the bracket-formatted
  ``hostname``; all other URL components are kept.
  """
  port = parsed_url.port
  authority = _format_host(hostname)
  if port is not None:
    authority = f'{authority}:{port}'
  return parsed_url._replace(netloc=authority).geturl()


def _fetch_direct_response(
*,
url: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,9 +611,22 @@ async def _request(
httpx_client_factory: Optional[HttpxClientFactory] = None,
**request_params,
) -> httpx.Response:
# SSRF defence:
# 1. validate_url rejects bad schemes, localhost-style names, and any
# hostname whose DNS records include a non-globally-routable IP.
# 2. send_pinned_async issues the request against the validated IP so the
# socket can't be flipped by a DNS rebinding between this validation
# and the connect that follows. The Host header and TLS SNI keep the
# original hostname so cert verification still works.
from ..._ssrf_protection import send_pinned_async
from ..._ssrf_protection import validate_url

target = validate_url(request_params.get("url", ""))
verify = request_params.pop("verify", True)

if httpx_client_factory is not None:
async with httpx_client_factory() as client:
return await client.request(**request_params)
return await send_pinned_async(client, target, **request_params)

async with httpx.AsyncClient(verify=verify, timeout=None) as client:
return await client.request(**request_params)
return await send_pinned_async(client, target, **request_params)
Loading