From 88690cfa8f158b7eac7d31a656bc7164f77fcef0 Mon Sep 17 00:00:00 2001 From: Neelagiri65 Date: Wed, 20 May 2026 01:48:40 +0100 Subject: [PATCH] fix: add SSRF protection to RestApiTool._request() RestApiTool._request() passes URLs straight to httpx with no validation. load_web_page got SSRF protections after #4368 (hostname blocking, DNS pre-resolution, IP range filtering, scheme restriction) but RestApiTool was not updated. This extracts the SSRF validation logic into a shared _ssrf_protection module and applies it in _request() before making HTTP calls. Also sets a finite timeout (30s) instead of None. Blocked: localhost, *.localhost, loopback, link-local (169.254.x.x), private ranges (10.x, 172.16-31.x, 192.168.x), non-http(s) schemes. Related: #4368 --- src/google/adk/tools/_ssrf_protection.py | 234 +++++++++++++ src/google/adk/tools/load_web_page.py | 88 +---- .../openapi_spec_parser/rest_api_tool.py | 17 +- .../openapi_spec_parser/test_rest_api_tool.py | 40 +++ tests/unittests/tools/test_ssrf_protection.py | 313 ++++++++++++++++++ 5 files changed, 610 insertions(+), 82 deletions(-) create mode 100644 src/google/adk/tools/_ssrf_protection.py create mode 100644 tests/unittests/tools/test_ssrf_protection.py diff --git a/src/google/adk/tools/_ssrf_protection.py b/src/google/adk/tools/_ssrf_protection.py new file mode 100644 index 0000000000..55f4dea40a --- /dev/null +++ b/src/google/adk/tools/_ssrf_protection.py @@ -0,0 +1,234 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared SSRF protection helpers for tools that make HTTP requests. + +Two layers: + +1. ``validate_url`` rejects bad schemes, missing/blocked hostnames, and any + DNS result that includes a non-globally-routable IP. It returns a + ``ValidatedTarget`` so callers can use the pre-resolved address list. + +2. ``send_pinned_async`` issues an ``httpx`` request against the validated IP + literal directly, preserves the ``Host`` header, and sets the TLS server + name via ``request.extensions["sni_hostname"]``. Together with (1) this + closes the DNS rebinding window between URL validation and connect: even + if the attacker flips the DNS record after validation, the socket goes to + the IP we validated and the cert check uses the original hostname. + +A matching ``PinnedAddressAdapter`` for the ``requests`` library is also +provided so ``load_web_page`` and any other sync caller can share the same +resolution and blocking rules. +""" + +from __future__ import annotations + +from dataclasses import dataclass +import ipaddress +import socket +from typing import Any +from urllib.parse import ParseResult +from urllib.parse import urlparse +from urllib.parse import urlunparse + +import httpx + +_ALLOWED_URL_SCHEMES = frozenset({"http", "https"}) +_DEFAULT_PORT_BY_SCHEME = {"http": 80, "https": 443} +_ResolvedAddress = ipaddress.IPv4Address | ipaddress.IPv6Address + + +@dataclass(frozen=True) +class ValidatedTarget: + """A URL that passed validation, with its resolved addresses cached.""" + + url: str + parsed: ParseResult + scheme: str + hostname: str + host_header: str + addresses: tuple[_ResolvedAddress, ...] + + +def _format_host(hostname: str) -> str: + if ":" in hostname: + return f"[{hostname}]" + return hostname + + +def _build_host_header( + *, + hostname: str, + scheme: str, + explicit_port: int | None, +) -> str: + formatted = _format_host(hostname) + default_port = _DEFAULT_PORT_BY_SCHEME[scheme] + if explicit_port is None or explicit_port == default_port: + return formatted + return f"{formatted}:{explicit_port}" + + +def is_blocked_hostname(hostname: str) -> bool: + """Return True for hostnames that always point at the local host.""" + normalized = hostname.rstrip(".").lower() + return normalized == "localhost" or normalized.endswith(".localhost") + + +def is_blocked_address(address: _ResolvedAddress) -> bool: + """Return True for any IP that isn't globally routable. + + ``ipaddress.is_global`` already covers private (RFC 1918), loopback, + link-local (including 169.254.169.254), multicast, reserved, and unspecified + ranges across IPv4 and IPv6. Using it directly avoids drift between hand + maintained allow lists in different tools. + """ + return not address.is_global + + +def _parse_ip_literal(hostname: str) -> _ResolvedAddress | None: + try: + return ipaddress.ip_address(hostname) + except ValueError: + return None + + +def resolve_host_addresses(hostname: str) -> tuple[_ResolvedAddress, ...]: + """Resolve a hostname to all of its A / AAAA records. + + IP literals short-circuit and return themselves. ``getaddrinfo`` errors are + surfaced as ``ValueError`` so callers can handle resolution failure and a + bad scheme through the same code path. + """ + literal = _parse_ip_literal(hostname) + if literal is not None: + return (literal,) + + try: + info = socket.getaddrinfo( + hostname, + None, + type=socket.SOCK_STREAM, + proto=socket.IPPROTO_TCP, + ) + except (socket.gaierror, UnicodeError) as exc: + raise ValueError(f"Unable to resolve host: {hostname}") from exc + + addresses: list[_ResolvedAddress] = [] + for family, _, _, _, sockaddr in info: + if family not in (socket.AF_INET, socket.AF_INET6): + continue + addresses.append(ipaddress.ip_address(sockaddr[0])) + + if not addresses: + raise ValueError(f"Unable to resolve host: {hostname}") + + # Deduplicate while preserving order so the first record is still tried + # first by callers that iterate the tuple. + return tuple(dict.fromkeys(addresses)) + + +def validate_url(url: str) -> ValidatedTarget: + """Validate ``url`` and return its resolved addresses. + + Raises ``ValueError`` for unsupported schemes, missing or blocked + hostnames, invalid ports, and DNS results where any IP is not globally + routable. The check rejects the whole hostname if even one record points + at private space so an attacker can't sneak past the gate with a + multi-record set such as ``[8.8.8.8, 127.0.0.1]``. + + Returning the addresses lets the caller pin the connection to a vetted IP + instead of re-resolving at connect time. That closes the DNS rebinding + window between this validation and the eventual HTTP request. + """ + parsed = urlparse(url) + scheme = parsed.scheme.lower() + if scheme not in _ALLOWED_URL_SCHEMES: + raise ValueError(f"Unsupported url scheme: {url}") + + hostname = parsed.hostname + if not hostname: + raise ValueError(f"URL is missing a hostname: {url}") + + try: + explicit_port = parsed.port + except ValueError as exc: + raise ValueError(f"Invalid url port: {url}") from exc + + if is_blocked_hostname(hostname): + raise ValueError(f"Blocked host: {hostname}") + + addresses = resolve_host_addresses(hostname) + if any(is_blocked_address(addr) for addr in addresses): + raise ValueError(f"Blocked host: {hostname}") + + return ValidatedTarget( + url=url, + parsed=parsed, + scheme=scheme, + hostname=hostname, + host_header=_build_host_header( + hostname=hostname, + scheme=scheme, + explicit_port=explicit_port, + ), + addresses=addresses, + ) + + +def rewrite_url_host(parsed: ParseResult, ip: str) -> str: + """Rewrite ``parsed`` to use ``ip`` (literal) in place of the hostname.""" + formatted = _format_host(ip) + port = parsed.port + netloc = formatted if port is None else f"{formatted}:{port}" + return urlunparse(parsed._replace(netloc=netloc)) + + +async def send_pinned_async( + client: httpx.AsyncClient, + target: ValidatedTarget, + **request_params: Any, +) -> httpx.Response: + """Send a request to ``target`` via ``client`` with the IP pinned. + + The URL is rewritten to use the first validated IP literally so the + connection bypasses DNS at send time. The original hostname is preserved in + the ``Host`` header (for HTTP routing) and in the ``sni_hostname`` request + extension (for TLS verification, consumed by ``httpcore``). + + If the chosen address fails to connect, the next address in + ``target.addresses`` is tried. All addresses in the tuple have already + passed ``is_blocked_address``, so this loop never reaches a private IP. + """ + request_params.pop("url", None) + headers = dict(request_params.pop("headers", None) or {}) + headers["Host"] = target.host_header + base_extensions = request_params.pop("extensions", None) or {} + extensions = {**base_extensions, "sni_hostname": target.hostname} + + last_error: Exception | None = None + for address in target.addresses: + rewritten_url = rewrite_url_host(target.parsed, str(address)) + try: + return await client.request( + url=rewritten_url, + headers=headers, + extensions=extensions, + **request_params, + ) + except httpx.HTTPError as exc: + last_error = exc + + assert last_error is not None # loop ran at least once: addresses is non-empty + raise last_error diff --git a/src/google/adk/tools/load_web_page.py b/src/google/adk/tools/load_web_page.py index eb86c82332..93637d49b7 100644 --- a/src/google/adk/tools/load_web_page.py +++ b/src/google/adk/tools/load_web_page.py @@ -17,8 +17,6 @@ """Tool for web browse.""" from dataclasses import dataclass -import ipaddress -import socket from typing import Any from urllib.parse import ParseResult from urllib.parse import urlparse @@ -28,9 +26,14 @@ from requests.utils import get_environ_proxies from requests.utils import select_proxy -_ALLOWED_URL_SCHEMES = frozenset({'http', 'https'}) -_DEFAULT_PORT_BY_SCHEME = {'http': 80, 'https': 443} -_ResolvedAddress = ipaddress.IPv4Address | ipaddress.IPv6Address +from ._ssrf_protection import _ALLOWED_URL_SCHEMES +from ._ssrf_protection import _build_host_header +from ._ssrf_protection import _parse_ip_literal +from ._ssrf_protection import _ResolvedAddress +from ._ssrf_protection import is_blocked_address as _is_blocked_address +from ._ssrf_protection import is_blocked_hostname as _is_blocked_hostname +from ._ssrf_protection import resolve_host_addresses as _resolve_host_addresses +from ._ssrf_protection import rewrite_url_host as _rewrite_url_host @dataclass(frozen=True) @@ -96,25 +99,6 @@ def _failed_to_fetch_message(url: str) -> str: return f'Failed to fetch url: {url}' -def _format_host(hostname: str) -> str: - if ':' in hostname: - return f'[{hostname}]' - return hostname - - -def _default_port_for_scheme(scheme: str) -> int: - return _DEFAULT_PORT_BY_SCHEME[scheme] - - -def _build_host_header( - *, hostname: str, scheme: str, explicit_port: int | None -) -> str: - formatted_hostname = _format_host(hostname) - if explicit_port is None or explicit_port == _default_port_for_scheme(scheme): - return formatted_hostname - return f'{formatted_hostname}:{explicit_port}' - - def _parse_request_target(url: str) -> _RequestTarget: parsed_url = urlparse(url) scheme = parsed_url.scheme.lower() @@ -142,52 +126,6 @@ def _parse_request_target(url: str) -> _RequestTarget: ) -def _parse_ip_literal(hostname: str) -> _ResolvedAddress | None: - try: - return ipaddress.ip_address(hostname) - except ValueError: - return None - - -def _is_blocked_hostname(hostname: str) -> bool: - normalized_hostname = hostname.rstrip('.').lower() - return normalized_hostname == 'localhost' or normalized_hostname.endswith( - '.localhost' - ) - - -def _is_blocked_address(address: _ResolvedAddress) -> bool: - return not address.is_global - - -def _resolve_host_addresses(hostname: str) -> tuple[_ResolvedAddress, ...]: - resolved_address = _parse_ip_literal(hostname) - - if resolved_address is not None: - return (resolved_address,) - - try: - address_info = socket.getaddrinfo( - hostname, - None, - type=socket.SOCK_STREAM, - proto=socket.IPPROTO_TCP, - ) - except (socket.gaierror, UnicodeError) as exc: - raise ValueError(f'Unable to resolve host: {hostname}') from exc - - resolved_addresses: list[_ResolvedAddress] = [] - for family, _, _, _, sockaddr in address_info: - if family not in (socket.AF_INET, socket.AF_INET6): - continue - resolved_addresses.append(ipaddress.ip_address(sockaddr[0])) - - if not resolved_addresses: - raise ValueError(f'Unable to resolve host: {hostname}') - - return tuple(resolved_addresses) - - def _get_proxy_url(url: str) -> str | None: proxies = get_environ_proxies(url) return select_proxy(url, proxies) @@ -200,16 +138,6 @@ def _resolve_direct_addresses(hostname: str) -> tuple[_ResolvedAddress, ...]: return resolved_addresses -def _rewrite_url_host(parsed_url: ParseResult, hostname: str) -> str: - explicit_port = parsed_url.port - formatted_hostname = _format_host(hostname) - if explicit_port is None: - rewritten_netloc = formatted_hostname - else: - rewritten_netloc = f'{formatted_hostname}:{explicit_port}' - return parsed_url._replace(netloc=rewritten_netloc).geturl() - - def _fetch_direct_response( *, url: str, diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py index 31d9bbb81a..7e4836d8df 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py @@ -611,9 +611,22 @@ async def _request( httpx_client_factory: Optional[HttpxClientFactory] = None, **request_params, ) -> httpx.Response: + # SSRF defence: + # 1. validate_url rejects bad schemes, localhost-style names, and any + # hostname whose DNS records include a non-globally-routable IP. + # 2. send_pinned_async issues the request against the validated IP so the + # socket can't be flipped by a DNS rebinding between this validation + # and the connect that follows. The Host header and TLS SNI keep the + # original hostname so cert verification still works. + from ..._ssrf_protection import send_pinned_async + from ..._ssrf_protection import validate_url + + target = validate_url(request_params.get("url", "")) verify = request_params.pop("verify", True) + if httpx_client_factory is not None: async with httpx_client_factory() as client: - return await client.request(**request_params) + return await send_pinned_async(client, target, **request_params) + async with httpx.AsyncClient(verify=verify, timeout=None) as client: - return await client.request(**request_params) + return await send_pinned_async(client, target, **request_params) diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py index 412d16f64e..464d11e8d0 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py @@ -1609,3 +1609,43 @@ def test_snake_to_lower_camel(): assert snake_to_lower_camel("three_word_example") == "threeWordExample" assert not snake_to_lower_camel("") assert snake_to_lower_camel("alreadyCamelCase") == "alreadyCamelCase" + + +class TestRequestSsrfProtection: + + @pytest.mark.asyncio + async def test_request_blocks_localhost(self): + from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import _request + + with pytest.raises(ValueError, match="Blocked host"): + await _request(method="GET", url="http://localhost:8080/internal") + + @pytest.mark.asyncio + async def test_request_blocks_loopback(self): + from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import _request + + with pytest.raises(ValueError, match="Blocked host"): + await _request(method="GET", url="http://127.0.0.1/internal") + + @pytest.mark.asyncio + async def test_request_blocks_metadata_endpoint(self): + from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import _request + + with pytest.raises(ValueError, match="Blocked host"): + await _request( + method="GET", url="http://169.254.169.254/latest/meta-data/" + ) + + @pytest.mark.asyncio + async def test_request_blocks_private_ip(self): + from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import _request + + with pytest.raises(ValueError, match="Blocked host"): + await _request(method="GET", url="http://10.0.0.1/admin") + + @pytest.mark.asyncio + async def test_request_blocks_file_scheme(self): + from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import _request + + with pytest.raises(ValueError, match="Unsupported url scheme"): + await _request(method="GET", url="file:///etc/passwd") diff --git a/tests/unittests/tools/test_ssrf_protection.py b/tests/unittests/tools/test_ssrf_protection.py new file mode 100644 index 0000000000..ad0fba6281 --- /dev/null +++ b/tests/unittests/tools/test_ssrf_protection.py @@ -0,0 +1,313 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import ipaddress +import socket + +from google.adk.tools._ssrf_protection import is_blocked_address +from google.adk.tools._ssrf_protection import is_blocked_hostname +from google.adk.tools._ssrf_protection import resolve_host_addresses +from google.adk.tools._ssrf_protection import rewrite_url_host +from google.adk.tools._ssrf_protection import send_pinned_async +from google.adk.tools._ssrf_protection import validate_url +import httpx +import pytest + + +class TestIsBlockedHostname: + + def test_localhost_blocked(self): + assert is_blocked_hostname("localhost") + + def test_localhost_trailing_dot(self): + assert is_blocked_hostname("localhost.") + + def test_subdomain_localhost_blocked(self): + assert is_blocked_hostname("foo.localhost") + + def test_case_insensitive(self): + assert is_blocked_hostname("LOCALHOST") + + def test_normal_hostname_allowed(self): + assert not is_blocked_hostname("example.com") + + def test_hostname_containing_localhost_allowed(self): + assert not is_blocked_hostname("notlocalhost.com") + + +class TestIsBlockedAddress: + + def test_loopback_blocked(self): + assert is_blocked_address(ipaddress.ip_address("127.0.0.1")) + + def test_link_local_blocked(self): + # 169.254.169.254 is the AWS / GCP / Azure metadata endpoint. + assert is_blocked_address(ipaddress.ip_address("169.254.169.254")) + + def test_private_blocked(self): + assert is_blocked_address(ipaddress.ip_address("10.0.0.1")) + assert is_blocked_address(ipaddress.ip_address("192.168.1.1")) + assert is_blocked_address(ipaddress.ip_address("172.16.0.1")) + + def test_ipv6_loopback_blocked(self): + assert is_blocked_address(ipaddress.ip_address("::1")) + + def test_ipv6_link_local_blocked(self): + assert is_blocked_address(ipaddress.ip_address("fe80::1")) + + def test_ipv6_unique_local_blocked(self): + assert is_blocked_address(ipaddress.ip_address("fc00::1")) + + def test_global_allowed(self): + assert not is_blocked_address(ipaddress.ip_address("8.8.8.8")) + + def test_ipv6_global_allowed(self): + assert not is_blocked_address(ipaddress.ip_address("2001:4860:4860::8888")) + + +class TestResolveHostAddresses: + + def test_ip_literal_short_circuits(self): + addrs = resolve_host_addresses("8.8.8.8") + assert addrs == (ipaddress.ip_address("8.8.8.8"),) + + def test_ipv6_literal_short_circuits(self): + addrs = resolve_host_addresses("2001:4860:4860::8888") + assert addrs == (ipaddress.ip_address("2001:4860:4860::8888"),) + + def test_resolve_returns_all_records(self, monkeypatch): + def fake(host, port, *args, **kwargs): + return [ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("203.0.113.5", 0)), + (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("203.0.113.6", 0)), + ] + + monkeypatch.setattr(socket, "getaddrinfo", fake) + addrs = resolve_host_addresses("multi.example.com") + assert addrs == ( + ipaddress.ip_address("203.0.113.5"), + ipaddress.ip_address("203.0.113.6"), + ) + + def test_resolve_failure_raises_value_error(self, monkeypatch): + def fake(host, port, *args, **kwargs): + raise socket.gaierror(8, "nodename nor servname provided") + + monkeypatch.setattr(socket, "getaddrinfo", fake) + with pytest.raises(ValueError, match="Unable to resolve host"): + resolve_host_addresses("no-such-host.example") + + +@pytest.fixture +def patch_dns(monkeypatch): + """Map a few example hostnames to canned addresses for validate_url tests.""" + + responses = { + "api.example.com": [("203.0.113.5", socket.AF_INET)], + "rebinder.example.com": [ + ("203.0.113.7", socket.AF_INET), + ("127.0.0.1", socket.AF_INET), + ], + "internal.example.com": [("10.0.0.5", socket.AF_INET)], + } + original = socket.getaddrinfo + + def fake(host, port, *args, **kwargs): + if host in responses: + return [ + (family, socket.SOCK_STREAM, 6, "", (ip, 0)) + for ip, family in responses[host] + ] + return original(host, port, *args, **kwargs) + + monkeypatch.setattr(socket, "getaddrinfo", fake) + + +class TestValidateUrl: + + def test_localhost_blocked(self): + with pytest.raises(ValueError, match="Blocked host"): + validate_url("http://localhost:8080/path") + + def test_loopback_ip_blocked(self): + with pytest.raises(ValueError, match="Blocked host"): + validate_url("http://127.0.0.1/path") + + def test_link_local_blocked(self): + with pytest.raises(ValueError, match="Blocked host"): + validate_url("http://169.254.169.254/latest/meta-data/") + + def test_private_ip_blocked(self): + with pytest.raises(ValueError, match="Blocked host"): + validate_url("http://10.0.0.1/internal") + + def test_ftp_scheme_blocked(self): + with pytest.raises(ValueError, match="Unsupported url scheme"): + validate_url("ftp://example.com/file") + + def test_file_scheme_blocked(self): + with pytest.raises(ValueError, match="Unsupported url scheme"): + validate_url("file:///etc/passwd") + + def test_no_hostname_blocked(self): + with pytest.raises(ValueError, match="missing a hostname"): + validate_url("http:///path") + + def test_public_url_allowed(self, patch_dns): + target = validate_url("https://api.example.com/v1/resource") + assert target.hostname == "api.example.com" + assert target.scheme == "https" + assert target.addresses == (ipaddress.ip_address("203.0.113.5"),) + + def test_rebinder_blocked_when_any_record_is_private(self, patch_dns): + # rebinder.example.com resolves to one public IP and one loopback IP. + # An attacker controlling DNS could flip records between this check and + # the actual connect. Rejecting the hostname when any record is private + # closes that window. + with pytest.raises(ValueError, match="Blocked host"): + validate_url("https://rebinder.example.com/x") + + def test_internal_hostname_blocked(self, patch_dns): + with pytest.raises(ValueError, match="Blocked host"): + validate_url("https://internal.example.com/admin") + + def test_validated_target_has_addresses(self, patch_dns): + target = validate_url("https://api.example.com/v1/resource") + assert len(target.addresses) == 1 + assert not any(is_blocked_address(a) for a in target.addresses) + + +class TestRewriteUrlHost: + + def test_basic_replacement(self): + from urllib.parse import urlparse + + parsed = urlparse("https://api.example.com/v1/resource") + assert ( + rewrite_url_host(parsed, "203.0.113.5") + == "https://203.0.113.5/v1/resource" + ) + + def test_preserves_explicit_port(self): + from urllib.parse import urlparse + + parsed = urlparse("https://api.example.com:8443/v1/resource") + assert ( + rewrite_url_host(parsed, "203.0.113.5") + == "https://203.0.113.5:8443/v1/resource" + ) + + def test_ipv6_brackets(self): + from urllib.parse import urlparse + + parsed = urlparse("https://api.example.com/x") + assert ( + rewrite_url_host(parsed, "2001:db8::1") + == "https://[2001:db8::1]/x" + ) + + +class TestSendPinnedAsync: + + @pytest.mark.asyncio + async def test_pins_url_and_sets_host_and_sni(self, patch_dns): + captured: list[httpx.Request] = [] + + def mock_handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, text="ok") + + transport = httpx.MockTransport(mock_handler) + target = validate_url("https://api.example.com/v1/resource") + + async with httpx.AsyncClient(transport=transport) as client: + response = await send_pinned_async( + client, + target, + method="GET", + ) + + assert response.status_code == 200 + assert len(captured) == 1 + sent = captured[0] + # The URL should hit the validated IP literally so DNS at send time + # can't redirect to a private IP. + assert sent.url.host == "203.0.113.5" + # The Host header keeps the original hostname so the remote server + # routes the request to the right vhost. + assert sent.headers["Host"] == "api.example.com" + # The SNI extension keeps the original hostname for TLS cert validation. + assert sent.extensions.get("sni_hostname") == "api.example.com" + + @pytest.mark.asyncio + async def test_passes_method_and_body_through(self, patch_dns): + captured: list[httpx.Request] = [] + + def mock_handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(201, json={"created": True}) + + transport = httpx.MockTransport(mock_handler) + target = validate_url("https://api.example.com/v1/users") + + async with httpx.AsyncClient(transport=transport) as client: + response = await send_pinned_async( + client, + target, + method="POST", + json={"name": "alice"}, + headers={"X-Custom": "v"}, + ) + + assert response.status_code == 201 + sent = captured[0] + assert sent.method == "POST" + assert sent.headers["X-Custom"] == "v" + assert sent.headers["Host"] == "api.example.com" + + @pytest.mark.asyncio + async def test_tries_next_address_on_connect_failure(self, monkeypatch): + # Simulate a hostname that resolves to two public IPs. The first call + # fails; the second succeeds. Both must already be in the validated + # address list. This verifies the fallback walks the list rather than + # giving up after the first error. + + def fake_getaddrinfo(host, port, *args, **kwargs): + if host == "two.example.com": + return [ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("203.0.113.10", 0)), + (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("203.0.113.11", 0)), + ] + raise socket.gaierror(8, "no") + + monkeypatch.setattr(socket, "getaddrinfo", fake_getaddrinfo) + + seen_hosts: list[str] = [] + + def mock_handler(request: httpx.Request) -> httpx.Response: + seen_hosts.append(request.url.host) + if request.url.host == "203.0.113.10": + raise httpx.ConnectError("simulated connect failure") + return httpx.Response(200, text="ok") + + transport = httpx.MockTransport(mock_handler) + target = validate_url("https://two.example.com/path") + + async with httpx.AsyncClient(transport=transport) as client: + response = await send_pinned_async(client, target, method="GET") + + assert response.status_code == 200 + assert seen_hosts == ["203.0.113.10", "203.0.113.11"]