From 14974af59d3d25c2b502aa0e33bdbcbc2610dc24 Mon Sep 17 00:00:00 2001 From: Zied Jlassi <6190550+zied-jlassi@users.noreply.github.com> Date: Sat, 20 Jun 2026 22:16:54 +0200 Subject: [PATCH 1/3] fix(crewai-files): block SSRF to non-public addresses in FileUrl fetch FileUrl.read/aread fetched any http(s) URL with follow_redirects=True and no network validation. Because the library auto-coerces "http(s)://..." strings into FileUrl, a URL forwarded as a file input could reach loopback, private, link-local or cloud-metadata addresses (169.254.169.254), or be redirected into them from a public URL -- a Server-Side Request Forgery (CWE-918). Resolve the host and reject non-public addresses before each request, and follow redirects manually so every hop is re-validated. IPv4-mapped IPv6 addresses are normalized so ::ffff:127.0.0.1 cannot bypass the check. The async path resolves DNS in the default executor to avoid blocking the loop. AI-assisted audit (audit by AI and Zied Jlassi). PR labeled llm-generated per CONTRIBUTING.md. --- .../src/crewai_files/core/sources.py | 176 +++++++++++++++++- lib/crewai-files/tests/test_file_url.py | 124 +++++++++++- 2 files changed, 289 insertions(+), 11 deletions(-) diff --git a/lib/crewai-files/src/crewai_files/core/sources.py b/lib/crewai-files/src/crewai_files/core/sources.py index 0a4204d4d2..642a8d75a9 100644 --- a/lib/crewai-files/src/crewai_files/core/sources.py +++ b/lib/crewai-files/src/crewai_files/core/sources.py @@ -2,12 +2,16 @@ from __future__ import annotations -from collections.abc import AsyncIterator, Iterator +import asyncio +from collections.abc import AsyncIterator, Iterator, Sequence import inspect +import ipaddress import json import mimetypes from pathlib import Path +import socket from typing import Annotated, Any, BinaryIO, Protocol, cast, runtime_checkable +from urllib.parse import urljoin, urlparse import aiofiles from pydantic import ( @@ -486,6 +490,118 @@ async def aread_chunks(self, chunk_size: int = 65536) -> AsyncIterator[bytes]: yield chunk +_MAX_REDIRECTS = 5 + + +def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool: + """Return whether an IP address must not be fetched (SSRF protection). + + Normalizes IPv4-mapped IPv6 addresses (e.g. ``::ffff:127.0.0.1``) to their + IPv4 form first, because :class:`ipaddress.IPv6Address` does not flag the + mapped form as loopback or private on its own, a common SSRF guard bypass. + + Args: + ip: The resolved IP address to classify. + + Returns: + True if the address is loopback, private, link-local, reserved, + multicast, or unspecified; False for a routable public address. + """ + if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None: + ip = ip.ipv4_mapped + return ( + ip.is_private + or ip.is_loopback + or ip.is_link_local + or ip.is_reserved + or ip.is_multicast + or ip.is_unspecified + ) + + +def _url_host(url: str) -> str: + """Extract the host of an absolute URL. + + Args: + url: The absolute http(s) URL. + + Returns: + The host component. + + Raises: + ValueError: If the URL has no host. + """ + host = urlparse(url).hostname + if not host: + raise ValueError(f"URL has no host: {url}") + return host + + +def _reject_blocked_addrinfo( + url: str, addrinfo: Sequence[tuple[Any, ...]] +) -> None: + """Raise if any resolved address is non-public (SSRF protection). + + Args: + url: The URL being validated, used for the error message. + addrinfo: The ``getaddrinfo`` result for the URL host. + + Raises: + ValueError: If any resolved address is a blocked (non-public) address. + """ + for *_, sockaddr in addrinfo: + ip = ipaddress.ip_address(sockaddr[0]) + if _is_blocked_ip(ip): + raise ValueError( + f"Refusing to fetch URL resolving to non-public address {ip} " + f"(SSRF protection): {url}" + ) + + +def _assert_url_allowed(url: str) -> None: + """Raise if ``url``'s host resolves to a non-public address (SSRF guard). + + Resolves every address the host maps to and rejects the request if any of + them is non-public, so a multi-record host cannot slip a private address + past the guard. + + Args: + url: The absolute http(s) URL about to be fetched. + + Raises: + ValueError: If the URL has no host, cannot be resolved, or resolves to + a blocked address. + """ + host = _url_host(url) + try: + addrinfo = socket.getaddrinfo(host, None) + except socket.gaierror as exc: + raise ValueError(f"Cannot resolve URL host: {host}") from exc + _reject_blocked_addrinfo(url, addrinfo) + + +async def _aassert_url_allowed(url: str) -> None: + """Async variant of :func:`_assert_url_allowed`. + + Runs the blocking name resolution in the default executor so the DNS lookup + does not block the running event loop. + + Args: + url: The absolute http(s) URL about to be fetched. + + Raises: + ValueError: If the URL has no host, cannot be resolved, or resolves to + a blocked address. + """ + host = _url_host(url) + loop = asyncio.get_running_loop() + try: + addrinfo = await loop.run_in_executor(None, socket.getaddrinfo, host, None) + except socket.gaierror as exc: + raise ValueError(f"Cannot resolve URL host: {host}") from exc + _reject_blocked_addrinfo(url, addrinfo) + + class FileUrl(BaseModel): """File referenced by URL. @@ -526,11 +642,32 @@ def _guess_content_type(self) -> str: return guessed or "application/octet-stream" def read(self) -> bytes: - """Fetch content from URL (for providers that don't support URL references).""" + """Fetch content from URL, blocking SSRF to non-public addresses. + + Redirects are followed manually so every hop is re-validated by + :func:`_assert_url_allowed`; a public URL therefore cannot redirect into + an internal or cloud-metadata address. + + Returns: + The fetched file content as bytes. + + Raises: + ValueError: If the URL (or a redirect target) resolves to a blocked + address, or if there are too many redirects. + """ if self._content is None: import httpx - response = httpx.get(self.url, follow_redirects=True) + current = self.url + for _ in range(_MAX_REDIRECTS + 1): + _assert_url_allowed(current) + response = httpx.get(current, follow_redirects=False) + if response.is_redirect and "location" in response.headers: + current = urljoin(current, response.headers["location"]) + continue + break + else: + raise ValueError(f"Too many redirects while fetching URL: {self.url}") response.raise_for_status() self._content = response.content if "content-type" in response.headers: @@ -538,16 +675,35 @@ def read(self) -> bytes: return self._content async def aread(self) -> bytes: - """Async fetch content from URL.""" + """Async fetch with the same SSRF protection as :meth:`read`. + + Returns: + The fetched file content as bytes. + + Raises: + ValueError: If the URL (or a redirect target) resolves to a blocked + address, or if there are too many redirects. + """ if self._content is None: import httpx - async with httpx.AsyncClient() as client: - response = await client.get(self.url, follow_redirects=True) - response.raise_for_status() - self._content = response.content - if "content-type" in response.headers: - self._content_type = response.headers["content-type"].split(";")[0] + current = self.url + async with httpx.AsyncClient(follow_redirects=False) as client: + for _ in range(_MAX_REDIRECTS + 1): + await _aassert_url_allowed(current) + response = await client.get(current) + if response.is_redirect and "location" in response.headers: + current = urljoin(current, response.headers["location"]) + continue + break + else: + raise ValueError( + f"Too many redirects while fetching URL: {self.url}" + ) + response.raise_for_status() + self._content = response.content + if "content-type" in response.headers: + self._content_type = response.headers["content-type"].split(";")[0] return self._content diff --git a/lib/crewai-files/tests/test_file_url.py b/lib/crewai-files/tests/test_file_url.py index 7885723e61..bee4efd173 100644 --- a/lib/crewai-files/tests/test_file_url.py +++ b/lib/crewai-files/tests/test_file_url.py @@ -9,6 +9,33 @@ import pytest +def _addrinfo(ip: str) -> list[tuple[int, int, int, str, tuple[str, int]]]: + """Build a minimal ``socket.getaddrinfo`` result for a single IP. + + Args: + ip: The IP address the host should resolve to. + + Returns: + A one-entry list shaped like ``socket.getaddrinfo`` output. + """ + return [(0, 0, 0, "", (ip, 0))] + + +@pytest.fixture(autouse=True) +def mock_public_dns(): + """Resolve every host to a public IP so fetch tests stay offline. + + The SSRF guard resolves the URL host; without this fixture the read tests + would perform real DNS lookups. Individual tests can still override + ``socket.getaddrinfo`` to exercise blocked addresses. + + Yields: + The patched ``getaddrinfo`` mock. + """ + with patch("socket.getaddrinfo", return_value=_addrinfo("93.184.216.34")) as m: + yield m + + class TestFileUrl: """Tests for FileUrl source type.""" @@ -90,12 +117,13 @@ def test_read_fetches_content(self): mock_response = MagicMock() mock_response.content = b"fake image content" mock_response.headers = {"content-type": "image/png"} + mock_response.is_redirect = False with patch("httpx.get", return_value=mock_response) as mock_get: content = url.read() mock_get.assert_called_once_with( - "https://example.com/image.png", follow_redirects=True + "https://example.com/image.png", follow_redirects=False ) assert content == b"fake image content" @@ -309,3 +337,97 @@ def test_image_file_from_file_url(self): assert file.source is url assert file.content_type == "image/jpeg" + + +class TestFileUrlSSRF: + """SSRF protection for FileUrl.read / aread (CWE-918).""" + + @pytest.mark.parametrize( + "blocked_ip", + [ + "127.0.0.1", # loopback + "169.254.169.254", # cloud metadata + "10.0.0.5", # RFC1918 + "192.168.1.10", # RFC1918 + "::1", # IPv6 loopback + "::ffff:127.0.0.1", # IPv4-mapped loopback (naive-guard bypass) + "0.0.0.0", # unspecified + ], + ) + def test_read_blocks_non_public_addresses(self, blocked_ip): + """read() must refuse URLs resolving to a non-public address.""" + url = FileUrl(url="http://internal.example/secret") + with patch("socket.getaddrinfo", return_value=_addrinfo(blocked_ip)): + with patch("httpx.get") as mock_get: + with pytest.raises(ValueError, match="SSRF protection"): + url.read() + mock_get.assert_not_called() + + def test_read_allows_public_address(self): + """read() must still fetch a normal public URL (no false positive).""" + url = FileUrl(url="https://example.com/image.png") + mock_response = MagicMock() + mock_response.content = b"ok" + mock_response.headers = {"content-type": "image/png"} + mock_response.is_redirect = False + + with patch("socket.getaddrinfo", return_value=_addrinfo("93.184.216.34")): + with patch("httpx.get", return_value=mock_response) as mock_get: + assert url.read() == b"ok" + mock_get.assert_called_once_with( + "https://example.com/image.png", follow_redirects=False + ) + + def test_read_blocks_redirect_to_internal(self): + """A public URL redirecting to an internal address must be blocked.""" + url = FileUrl(url="https://example.com/start") + redirect = MagicMock() + redirect.is_redirect = True + redirect.headers = {"location": "http://169.254.169.254/latest/meta-data/"} + + def fake_getaddrinfo(host, *_args, **_kwargs): + """Resolve the public start host and the internal redirect host. + + Args: + host: The host being resolved. + + Returns: + A ``getaddrinfo``-shaped result for the requested host. + """ + mapping = { + "example.com": "93.184.216.34", + "169.254.169.254": "169.254.169.254", + } + return _addrinfo(mapping[host]) + + with patch("socket.getaddrinfo", side_effect=fake_getaddrinfo): + with patch("httpx.get", return_value=redirect): + with pytest.raises(ValueError, match="SSRF protection"): + url.read() + + def test_read_blocks_redirect_bomb(self): + """Endless redirects must raise rather than loop forever.""" + url = FileUrl(url="https://example.com/a") + loop_response = MagicMock() + loop_response.is_redirect = True + loop_response.headers = {"location": "https://example.com/a"} + + with patch("socket.getaddrinfo", return_value=_addrinfo("93.184.216.34")): + with patch("httpx.get", return_value=loop_response): + with pytest.raises(ValueError, match="Too many redirects"): + url.read() + + @pytest.mark.asyncio + async def test_aread_blocks_non_public_address(self): + """aread() must apply the same SSRF guard as read().""" + url = FileUrl(url="http://internal.example/secret") + mock_client = MagicMock() + mock_client.get = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch("socket.getaddrinfo", return_value=_addrinfo("127.0.0.1")): + with patch("httpx.AsyncClient", return_value=mock_client): + with pytest.raises(ValueError, match="SSRF protection"): + await url.aread() + mock_client.get.assert_not_called() From b2f5ad978988cd7f91d9ea865b4ab082800b06c8 Mon Sep 17 00:00:00 2001 From: Zied Jlassi <6190550+zied-jlassi@users.noreply.github.com> Date: Sat, 20 Jun 2026 23:52:53 +0200 Subject: [PATCH 2/3] fix(crewai-files): pin connection to validated IP to close DNS rebinding The previous guard validated the host with a separate DNS lookup and then let httpx resolve the hostname again when connecting, leaving a DNS-rebinding (TOCTOU) window: an attacker-controlled host could answer public for the check and private/metadata for the connection. Resolve the host once, validate every returned address, and connect to the validated IP directly (the URL host is replaced with the IP, while the original Host header and TLS SNI hostname are preserved). The hostname is never resolved a second time, so rebinding cannot occur. Every redirect hop is resolved and validated the same way. Adds tests for IP pinning, Host/SNI preservation, single-resolution (rebinding), mixed public/private records, explicit ports, IPv6 bracketing, and redirect re-validation. AI-assisted audit (audit by AI and Zied Jlassi). Addresses the corridor-security DNS-rebinding finding on this PR. --- .../src/crewai_files/core/sources.py | 129 ++++++--- lib/crewai-files/tests/test_file_url.py | 253 +++++++++++++----- 2 files changed, 271 insertions(+), 111 deletions(-) diff --git a/lib/crewai-files/src/crewai_files/core/sources.py b/lib/crewai-files/src/crewai_files/core/sources.py index 642a8d75a9..014c39e588 100644 --- a/lib/crewai-files/src/crewai_files/core/sources.py +++ b/lib/crewai-files/src/crewai_files/core/sources.py @@ -537,69 +537,105 @@ def _url_host(url: str) -> str: return host -def _reject_blocked_addrinfo( - url: str, addrinfo: Sequence[tuple[Any, ...]] -) -> None: - """Raise if any resolved address is non-public (SSRF protection). +def _select_validated_ip(host: str, addrinfo: Sequence[tuple[Any, ...]]) -> str: + """Validate every resolved address and return one public IP to connect to. + + Rejects the host if *any* resolved address is non-public, so a multi-record + host cannot slip a private address past the guard, then returns the first + address for the caller to connect to directly. Args: - url: The URL being validated, used for the error message. - addrinfo: The ``getaddrinfo`` result for the URL host. + host: The host being resolved, used for the error message. + addrinfo: The ``getaddrinfo`` result for ``host``. + + Returns: + A validated, public IP address to connect to. Raises: - ValueError: If any resolved address is a blocked (non-public) address. + ValueError: If the host resolves to no address or to a blocked one. """ + selected: str | None = None for *_, sockaddr in addrinfo: ip = ipaddress.ip_address(sockaddr[0]) if _is_blocked_ip(ip): raise ValueError( f"Refusing to fetch URL resolving to non-public address {ip} " - f"(SSRF protection): {url}" + f"(SSRF protection): host {host}" ) + if selected is None: + selected = sockaddr[0] + if selected is None: + raise ValueError(f"Cannot resolve URL host: {host}") + return selected -def _assert_url_allowed(url: str) -> None: - """Raise if ``url``'s host resolves to a non-public address (SSRF guard). - - Resolves every address the host maps to and rejects the request if any of - them is non-public, so a multi-record host cannot slip a private address - past the guard. +def _resolve_validated_ip(host: str) -> str: + """Resolve ``host`` and return a validated public IP (SSRF guard). Args: - url: The absolute http(s) URL about to be fetched. + host: The host to resolve. + + Returns: + A validated, public IP address. Raises: - ValueError: If the URL has no host, cannot be resolved, or resolves to - a blocked address. + ValueError: If the host cannot be resolved or is non-public. """ - host = _url_host(url) try: addrinfo = socket.getaddrinfo(host, None) except socket.gaierror as exc: raise ValueError(f"Cannot resolve URL host: {host}") from exc - _reject_blocked_addrinfo(url, addrinfo) + return _select_validated_ip(host, addrinfo) -async def _aassert_url_allowed(url: str) -> None: - """Async variant of :func:`_assert_url_allowed`. +async def _aresolve_validated_ip(host: str) -> str: + """Async variant of :func:`_resolve_validated_ip`. Runs the blocking name resolution in the default executor so the DNS lookup does not block the running event loop. Args: - url: The absolute http(s) URL about to be fetched. + host: The host to resolve. + + Returns: + A validated, public IP address. Raises: - ValueError: If the URL has no host, cannot be resolved, or resolves to - a blocked address. + ValueError: If the host cannot be resolved or is non-public. """ - host = _url_host(url) loop = asyncio.get_running_loop() try: addrinfo = await loop.run_in_executor(None, socket.getaddrinfo, host, None) except socket.gaierror as exc: raise ValueError(f"Cannot resolve URL host: {host}") from exc - _reject_blocked_addrinfo(url, addrinfo) + return _select_validated_ip(host, addrinfo) + + +def _pin_request_kwargs( + url: str, ip: str +) -> tuple[str, dict[str, str], dict[str, Any]]: + """Build request arguments that pin the connection to a validated IP. + + The URL host is replaced with the validated IP so the HTTP client connects + to that exact address and never resolves the hostname again — closing the + DNS-rebinding (TOCTOU) window — while the original ``Host`` header and TLS + SNI hostname are preserved so virtual hosting and certificate verification + keep working. + + Args: + url: The absolute http(s) URL being fetched. + ip: The validated public IP to connect to. + + Returns: + A ``(pinned_url, headers, extensions)`` tuple for ``build_request``. + """ + parsed = urlparse(url) + host = parsed.hostname or "" + host_header = parsed.netloc.rsplit("@", 1)[-1] + ip_host = f"[{ip}]" if ":" in ip else ip + netloc = f"{ip_host}:{parsed.port}" if parsed.port is not None else ip_host + pinned_url = parsed._replace(netloc=netloc).geturl() + return pinned_url, {"Host": host_header}, {"sni_hostname": host} class FileUrl(BaseModel): @@ -644,9 +680,11 @@ def _guess_content_type(self) -> str: def read(self) -> bytes: """Fetch content from URL, blocking SSRF to non-public addresses. - Redirects are followed manually so every hop is re-validated by - :func:`_assert_url_allowed`; a public URL therefore cannot redirect into - an internal or cloud-metadata address. + Each request connects to a validated public IP (the hostname is never + re-resolved by the HTTP client, closing the DNS-rebinding window), and + redirects are followed manually so every hop is re-validated. A public + URL therefore cannot reach — or redirect into — an internal or + cloud-metadata address. Returns: The fetched file content as bytes. @@ -659,15 +697,22 @@ def read(self) -> bytes: import httpx current = self.url - for _ in range(_MAX_REDIRECTS + 1): - _assert_url_allowed(current) - response = httpx.get(current, follow_redirects=False) - if response.is_redirect and "location" in response.headers: - current = urljoin(current, response.headers["location"]) - continue - break - else: - raise ValueError(f"Too many redirects while fetching URL: {self.url}") + with httpx.Client(follow_redirects=False) as client: + for _ in range(_MAX_REDIRECTS + 1): + ip = _resolve_validated_ip(_url_host(current)) + pinned_url, headers, extensions = _pin_request_kwargs(current, ip) + request = client.build_request( + "GET", pinned_url, headers=headers, extensions=extensions + ) + response = client.send(request) + if response.is_redirect and "location" in response.headers: + current = urljoin(current, response.headers["location"]) + continue + break + else: + raise ValueError( + f"Too many redirects while fetching URL: {self.url}" + ) response.raise_for_status() self._content = response.content if "content-type" in response.headers: @@ -690,8 +735,12 @@ async def aread(self) -> bytes: current = self.url async with httpx.AsyncClient(follow_redirects=False) as client: for _ in range(_MAX_REDIRECTS + 1): - await _aassert_url_allowed(current) - response = await client.get(current) + ip = await _aresolve_validated_ip(_url_host(current)) + pinned_url, headers, extensions = _pin_request_kwargs(current, ip) + request = client.build_request( + "GET", pinned_url, headers=headers, extensions=extensions + ) + response = await client.send(request) if response.is_redirect and "location" in response.headers: current = urljoin(current, response.headers["location"]) continue diff --git a/lib/crewai-files/tests/test_file_url.py b/lib/crewai-files/tests/test_file_url.py index bee4efd173..0a319192d1 100644 --- a/lib/crewai-files/tests/test_file_url.py +++ b/lib/crewai-files/tests/test_file_url.py @@ -4,7 +4,7 @@ from crewai_files import FileBytes, FileUrl, ImageFile from crewai_files.core.resolved import InlineBase64, UrlReference -from crewai_files.core.sources import FilePath, _normalize_source +from crewai_files.core.sources import _MAX_REDIRECTS, FilePath, _normalize_source from crewai_files.resolution.resolver import FileResolver import pytest @@ -36,6 +36,68 @@ def mock_public_dns(): yield m +def _response( + content: bytes = b"", + headers: dict[str, str] | None = None, + *, + is_redirect: bool = False, +) -> MagicMock: + """Build a mock httpx response. + + Args: + content: The response body. + headers: The response headers. + is_redirect: Whether the response is a redirect. + + Returns: + A configured mock response. + """ + response = MagicMock() + response.content = content + response.headers = headers if headers is not None else {} + response.is_redirect = is_redirect + response.raise_for_status = MagicMock() + return response + + +def _sync_client(responses: list[MagicMock]) -> MagicMock: + """Build a mock ``httpx.Client`` yielding ``responses`` in order. + + Args: + responses: Responses returned by successive ``send`` calls. + + Returns: + A mock client supporting the context-manager and request API. + """ + client = MagicMock() + client.__enter__ = MagicMock(return_value=client) + client.__exit__ = MagicMock(return_value=False) + client.build_request = MagicMock( + side_effect=lambda method, url, **kwargs: {"method": method, "url": url, **kwargs} + ) + client.send = MagicMock(side_effect=list(responses)) + return client + + +def _async_client(responses: list[MagicMock]) -> MagicMock: + """Build a mock ``httpx.AsyncClient`` yielding ``responses`` in order. + + Args: + responses: Responses returned by successive ``send`` calls. + + Returns: + A mock async client supporting the async-context-manager and request API. + """ + client = MagicMock() + client.__aenter__ = AsyncMock(return_value=client) + client.__aexit__ = AsyncMock(return_value=None) + client.build_request = MagicMock( + side_effect=lambda method, url, **kwargs: {"method": method, "url": url, **kwargs} + ) + client.send = AsyncMock(side_effect=list(responses)) + return client + + class TestFileUrl: """Tests for FileUrl source type.""" @@ -114,41 +176,34 @@ def test_content_type_no_extension(self): def test_read_fetches_content(self): """Test that read() fetches content from URL.""" url = FileUrl(url="https://example.com/image.png") - mock_response = MagicMock() - mock_response.content = b"fake image content" - mock_response.headers = {"content-type": "image/png"} - mock_response.is_redirect = False + client = _sync_client([_response(b"fake image content", {"content-type": "image/png"})]) - with patch("httpx.get", return_value=mock_response) as mock_get: + with patch("httpx.Client", return_value=client): content = url.read() - mock_get.assert_called_once_with( - "https://example.com/image.png", follow_redirects=False - ) + client.send.assert_called_once() assert content == b"fake image content" def test_read_caches_content(self): """Test that read() caches content.""" url = FileUrl(url="https://example.com/image.png") - mock_response = MagicMock() - mock_response.content = b"fake content" - mock_response.headers = {} + client = _sync_client([_response(b"fake content")]) - with patch("httpx.get", return_value=mock_response) as mock_get: + with patch("httpx.Client", return_value=client): content1 = url.read() content2 = url.read() - mock_get.assert_called_once() + client.send.assert_called_once() assert content1 == content2 def test_read_updates_content_type_from_response(self): """Test that read() updates content type from response headers.""" url = FileUrl(url="https://example.com/file") - mock_response = MagicMock() - mock_response.content = b"fake content" - mock_response.headers = {"content-type": "image/webp; charset=utf-8"} + client = _sync_client( + [_response(b"fake content", {"content-type": "image/webp; charset=utf-8"})] + ) - with patch("httpx.get", return_value=mock_response): + with patch("httpx.Client", return_value=client): url.read() assert url.content_type == "image/webp" @@ -157,17 +212,9 @@ def test_read_updates_content_type_from_response(self): async def test_aread_fetches_content(self): """Test that aread() fetches content from URL asynchronously.""" url = FileUrl(url="https://example.com/image.png") - mock_response = MagicMock() - mock_response.content = b"async fake content" - mock_response.headers = {"content-type": "image/png"} - mock_response.raise_for_status = MagicMock() + client = _async_client([_response(b"async fake content", {"content-type": "image/png"})]) - mock_client = MagicMock() - mock_client.get = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - - with patch("httpx.AsyncClient", return_value=mock_client): + with patch("httpx.AsyncClient", return_value=client): content = await url.aread() assert content == b"async fake content" @@ -176,21 +223,13 @@ async def test_aread_fetches_content(self): async def test_aread_caches_content(self): """Test that aread() caches content.""" url = FileUrl(url="https://example.com/image.png") - mock_response = MagicMock() - mock_response.content = b"cached content" - mock_response.headers = {} - mock_response.raise_for_status = MagicMock() - - mock_client = MagicMock() - mock_client.get = AsyncMock(return_value=mock_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) + client = _async_client([_response(b"cached content")]) - with patch("httpx.AsyncClient", return_value=mock_client): + with patch("httpx.AsyncClient", return_value=client): content1 = await url.aread() content2 = await url.aread() - mock_client.get.assert_called_once() + client.send.assert_called_once() assert content1 == content2 @@ -285,11 +324,10 @@ def test_resolve_url_source_bedrock_fetches_content(self): file_url = FileUrl(url="https://example.com/image.png") file = ImageFile(source=file_url) - mock_response = MagicMock() - mock_response.content = b"\x89PNG\r\n\x1a\n" + b"\x00" * 50 - mock_response.headers = {"content-type": "image/png"} + png_bytes = b"\x89PNG\r\n\x1a\n" + b"\x00" * 50 + client = _sync_client([_response(png_bytes, {"content-type": "image/png"})]) - with patch("httpx.get", return_value=mock_response): + with patch("httpx.Client", return_value=client): resolved = resolver.resolve(file, "bedrock") assert not isinstance(resolved, UrlReference) @@ -340,7 +378,7 @@ def test_image_file_from_file_url(self): class TestFileUrlSSRF: - """SSRF protection for FileUrl.read / aread (CWE-918).""" + """SSRF protection for FileUrl.read / aread (CWE-918), incl. DNS rebinding.""" @pytest.mark.parametrize( "blocked_ip", @@ -349,41 +387,103 @@ class TestFileUrlSSRF: "169.254.169.254", # cloud metadata "10.0.0.5", # RFC1918 "192.168.1.10", # RFC1918 + "172.16.0.1", # RFC1918 "::1", # IPv6 loopback "::ffff:127.0.0.1", # IPv4-mapped loopback (naive-guard bypass) + "::ffff:169.254.169.254", # IPv4-mapped metadata + "fc00::1", # IPv6 ULA (private) + "fe80::1", # IPv6 link-local "0.0.0.0", # unspecified + "224.0.0.1", # multicast ], ) def test_read_blocks_non_public_addresses(self, blocked_ip): """read() must refuse URLs resolving to a non-public address.""" url = FileUrl(url="http://internal.example/secret") + client = _sync_client([]) with patch("socket.getaddrinfo", return_value=_addrinfo(blocked_ip)): - with patch("httpx.get") as mock_get: + with patch("httpx.Client", return_value=client): with pytest.raises(ValueError, match="SSRF protection"): url.read() - mock_get.assert_not_called() + client.send.assert_not_called() + + def test_read_blocks_when_any_record_is_private(self): + """A host with mixed public/private records must be rejected.""" + url = FileUrl(url="http://split.example/x") + addrinfo = _addrinfo("93.184.216.34") + _addrinfo("127.0.0.1") + client = _sync_client([]) + with patch("socket.getaddrinfo", return_value=addrinfo): + with patch("httpx.Client", return_value=client): + with pytest.raises(ValueError, match="SSRF protection"): + url.read() + client.send.assert_not_called() + + def test_read_pins_connection_to_validated_ip(self): + """read() connects to the validated IP, preserving Host and SNI.""" + url = FileUrl(url="https://example.com/image.png") + client = _sync_client([_response(b"ok", {"content-type": "image/png"})]) + with patch("socket.getaddrinfo", return_value=_addrinfo("93.184.216.34")): + with patch("httpx.Client", return_value=client): + assert url.read() == b"ok" + method, sent_url = client.build_request.call_args.args + kwargs = client.build_request.call_args.kwargs + assert method == "GET" + assert sent_url == "https://93.184.216.34/image.png" + assert kwargs["headers"]["Host"] == "example.com" + assert kwargs["extensions"]["sni_hostname"] == "example.com" + + def test_read_does_not_re_resolve_hostname(self): + """DNS rebinding cannot bypass the guard (host resolved once, IP pinned).""" + url = FileUrl(url="http://rebind.example/x") + getaddrinfo = MagicMock(return_value=_addrinfo("93.184.216.34")) + client = _sync_client([_response(b"x")]) + with patch("socket.getaddrinfo", getaddrinfo): + with patch("httpx.Client", return_value=client): + url.read() + assert getaddrinfo.call_count == 1 + _, sent_url = client.build_request.call_args.args + assert "93.184.216.34" in sent_url + assert "rebind.example" not in sent_url + + def test_read_preserves_explicit_port(self): + """The validated-IP URL and Host header keep the explicit port.""" + url = FileUrl(url="https://example.com:8443/f") + client = _sync_client([_response(b"ok")]) + with patch("socket.getaddrinfo", return_value=_addrinfo("93.184.216.34")): + with patch("httpx.Client", return_value=client): + url.read() + _, sent_url = client.build_request.call_args.args + kwargs = client.build_request.call_args.kwargs + assert sent_url == "https://93.184.216.34:8443/f" + assert kwargs["headers"]["Host"] == "example.com:8443" + + def test_read_brackets_ipv6_target(self): + """An IPv6 connection target must be bracketed in the pinned URL.""" + url = FileUrl(url="https://v6.example/f") + client = _sync_client([_response(b"ok")]) + with patch("socket.getaddrinfo", return_value=_addrinfo("2606:2800:220:1::1")): + with patch("httpx.Client", return_value=client): + url.read() + _, sent_url = client.build_request.call_args.args + assert sent_url == "https://[2606:2800:220:1::1]/f" def test_read_allows_public_address(self): """read() must still fetch a normal public URL (no false positive).""" url = FileUrl(url="https://example.com/image.png") - mock_response = MagicMock() - mock_response.content = b"ok" - mock_response.headers = {"content-type": "image/png"} - mock_response.is_redirect = False - + client = _sync_client([_response(b"ok", {"content-type": "image/png"})]) with patch("socket.getaddrinfo", return_value=_addrinfo("93.184.216.34")): - with patch("httpx.get", return_value=mock_response) as mock_get: + with patch("httpx.Client", return_value=client): assert url.read() == b"ok" - mock_get.assert_called_once_with( - "https://example.com/image.png", follow_redirects=False - ) + client.send.assert_called_once() def test_read_blocks_redirect_to_internal(self): """A public URL redirecting to an internal address must be blocked.""" url = FileUrl(url="https://example.com/start") - redirect = MagicMock() - redirect.is_redirect = True - redirect.headers = {"location": "http://169.254.169.254/latest/meta-data/"} + redirect = _response( + headers={"location": "http://169.254.169.254/latest/meta-data/"}, + is_redirect=True, + ) + client = _sync_client([redirect]) def fake_getaddrinfo(host, *_args, **_kwargs): """Resolve the public start host and the internal redirect host. @@ -401,19 +501,20 @@ def fake_getaddrinfo(host, *_args, **_kwargs): return _addrinfo(mapping[host]) with patch("socket.getaddrinfo", side_effect=fake_getaddrinfo): - with patch("httpx.get", return_value=redirect): + with patch("httpx.Client", return_value=client): with pytest.raises(ValueError, match="SSRF protection"): url.read() def test_read_blocks_redirect_bomb(self): """Endless redirects must raise rather than loop forever.""" url = FileUrl(url="https://example.com/a") - loop_response = MagicMock() - loop_response.is_redirect = True - loop_response.headers = {"location": "https://example.com/a"} - + redirects = [ + _response(headers={"location": "https://example.com/a"}, is_redirect=True) + for _ in range(_MAX_REDIRECTS + 2) + ] + client = _sync_client(redirects) with patch("socket.getaddrinfo", return_value=_addrinfo("93.184.216.34")): - with patch("httpx.get", return_value=loop_response): + with patch("httpx.Client", return_value=client): with pytest.raises(ValueError, match="Too many redirects"): url.read() @@ -421,13 +522,23 @@ def test_read_blocks_redirect_bomb(self): async def test_aread_blocks_non_public_address(self): """aread() must apply the same SSRF guard as read().""" url = FileUrl(url="http://internal.example/secret") - mock_client = MagicMock() - mock_client.get = AsyncMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=None) - + client = _async_client([]) with patch("socket.getaddrinfo", return_value=_addrinfo("127.0.0.1")): - with patch("httpx.AsyncClient", return_value=mock_client): + with patch("httpx.AsyncClient", return_value=client): with pytest.raises(ValueError, match="SSRF protection"): await url.aread() - mock_client.get.assert_not_called() + client.send.assert_not_called() + + @pytest.mark.asyncio + async def test_aread_pins_connection_to_validated_ip(self): + """aread() connects to the validated IP, preserving Host and SNI.""" + url = FileUrl(url="https://example.com/image.png") + client = _async_client([_response(b"ok", {"content-type": "image/png"})]) + with patch("socket.getaddrinfo", return_value=_addrinfo("93.184.216.34")): + with patch("httpx.AsyncClient", return_value=client): + assert await url.aread() == b"ok" + _, sent_url = client.build_request.call_args.args + kwargs = client.build_request.call_args.kwargs + assert sent_url == "https://93.184.216.34/image.png" + assert kwargs["headers"]["Host"] == "example.com" + assert kwargs["extensions"]["sni_hostname"] == "example.com" From 7952193e105e8f5bedd71bb0bbe5ff961d0d5862 Mon Sep 17 00:00:00 2001 From: Zied Jlassi <6190550+zied-jlassi@users.noreply.github.com> Date: Sun, 21 Jun 2026 08:58:54 +0200 Subject: [PATCH 3/3] test(crewai-files): cover async SSRF redirect paths and lock follow_redirects Mirror the sync redirect-revalidation and redirect-bomb tests for aread(), and assert both read()/aread() construct the httpx client with follow_redirects disabled so a regression cannot silently re-enable auto-follow and skip per-hop SSRF checks. Ignore S104 on the bind-all literal used as an SSRF test fixture. AI-assisted audit (audit by AI and Zied Jlassi, Architect AI). --- lib/crewai-files/tests/test_file_url.py | 88 +++++++++++++++++++++++-- pyproject.toml | 2 +- 2 files changed, 85 insertions(+), 5 deletions(-) diff --git a/lib/crewai-files/tests/test_file_url.py b/lib/crewai-files/tests/test_file_url.py index 0a319192d1..1eda6211d5 100644 --- a/lib/crewai-files/tests/test_file_url.py +++ b/lib/crewai-files/tests/test_file_url.py @@ -73,7 +73,11 @@ def _sync_client(responses: list[MagicMock]) -> MagicMock: client.__enter__ = MagicMock(return_value=client) client.__exit__ = MagicMock(return_value=False) client.build_request = MagicMock( - side_effect=lambda method, url, **kwargs: {"method": method, "url": url, **kwargs} + side_effect=lambda method, url, **kwargs: { + "method": method, + "url": url, + **kwargs, + } ) client.send = MagicMock(side_effect=list(responses)) return client @@ -92,7 +96,11 @@ def _async_client(responses: list[MagicMock]) -> MagicMock: client.__aenter__ = AsyncMock(return_value=client) client.__aexit__ = AsyncMock(return_value=None) client.build_request = MagicMock( - side_effect=lambda method, url, **kwargs: {"method": method, "url": url, **kwargs} + side_effect=lambda method, url, **kwargs: { + "method": method, + "url": url, + **kwargs, + } ) client.send = AsyncMock(side_effect=list(responses)) return client @@ -176,7 +184,9 @@ def test_content_type_no_extension(self): def test_read_fetches_content(self): """Test that read() fetches content from URL.""" url = FileUrl(url="https://example.com/image.png") - client = _sync_client([_response(b"fake image content", {"content-type": "image/png"})]) + client = _sync_client( + [_response(b"fake image content", {"content-type": "image/png"})] + ) with patch("httpx.Client", return_value=client): content = url.read() @@ -212,7 +222,9 @@ def test_read_updates_content_type_from_response(self): async def test_aread_fetches_content(self): """Test that aread() fetches content from URL asynchronously.""" url = FileUrl(url="https://example.com/image.png") - client = _async_client([_response(b"async fake content", {"content-type": "image/png"})]) + client = _async_client( + [_response(b"async fake content", {"content-type": "image/png"})] + ) with patch("httpx.AsyncClient", return_value=client): content = await url.aread() @@ -542,3 +554,71 @@ async def test_aread_pins_connection_to_validated_ip(self): assert sent_url == "https://93.184.216.34/image.png" assert kwargs["headers"]["Host"] == "example.com" assert kwargs["extensions"]["sni_hostname"] == "example.com" + + @pytest.mark.asyncio + async def test_aread_blocks_redirect_to_internal(self): + """A public URL redirecting to an internal address must be blocked (async).""" + url = FileUrl(url="https://example.com/start") + redirect = _response( + headers={"location": "http://169.254.169.254/latest/meta-data/"}, + is_redirect=True, + ) + client = _async_client([redirect]) + + def fake_getaddrinfo(host, *_args, **_kwargs): + """Resolve the public start host and the internal redirect host. + + Args: + host: The host being resolved. + + Returns: + A ``getaddrinfo``-shaped result for the requested host. + """ + mapping = { + "example.com": "93.184.216.34", + "169.254.169.254": "169.254.169.254", + } + return _addrinfo(mapping[host]) + + with patch("socket.getaddrinfo", side_effect=fake_getaddrinfo): + with patch("httpx.AsyncClient", return_value=client): + with pytest.raises(ValueError, match="SSRF protection"): + await url.aread() + + @pytest.mark.asyncio + async def test_aread_blocks_redirect_bomb(self): + """Endless redirects must raise rather than loop forever (async).""" + url = FileUrl(url="https://example.com/a") + redirects = [ + _response(headers={"location": "https://example.com/a"}, is_redirect=True) + for _ in range(_MAX_REDIRECTS + 2) + ] + client = _async_client(redirects) + with patch("socket.getaddrinfo", return_value=_addrinfo("93.184.216.34")): + with patch("httpx.AsyncClient", return_value=client): + with pytest.raises(ValueError, match="Too many redirects"): + await url.aread() + + def test_read_constructs_client_without_following_redirects(self): + """read() must build the client with follow_redirects disabled. + + Redirects are followed and re-validated manually; letting httpx auto-follow + would skip the per-hop SSRF check, so a regression to + ``follow_redirects=True`` must fail this test. + """ + url = FileUrl(url="https://example.com/image.png") + client = _sync_client([_response(b"ok")]) + with patch("socket.getaddrinfo", return_value=_addrinfo("93.184.216.34")): + with patch("httpx.Client", return_value=client) as client_cls: + url.read() + assert client_cls.call_args.kwargs.get("follow_redirects") is False + + @pytest.mark.asyncio + async def test_aread_constructs_client_without_following_redirects(self): + """aread() must build the async client with follow_redirects disabled.""" + url = FileUrl(url="https://example.com/image.png") + client = _async_client([_response(b"ok")]) + with patch("socket.getaddrinfo", return_value=_addrinfo("93.184.216.34")): + with patch("httpx.AsyncClient", return_value=client) as client_cls: + await url.aread() + assert client_cls.call_args.kwargs.get("follow_redirects") is False diff --git a/pyproject.toml b/pyproject.toml index d36586f4a5..827f431eb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,7 +111,7 @@ ignore-decorators = ["typing.overload"] [tool.ruff.lint.per-file-ignores] "lib/crewai/tests/**/*.py" = ["S101", "RET504", "S105", "S106"] # Allow assert statements, unnecessary assignments, and hardcoded passwords in tests "lib/crewai-tools/tests/**/*.py" = ["S101", "RET504", "S105", "S106", "RUF012", "N818", "E402", "RUF043", "S110", "B017"] # Allow various test-specific patterns -"lib/crewai-files/tests/**/*.py" = ["S101", "RET504", "S105", "S106", "B017", "F841"] # Allow assert statements and blind exception assertions in tests +"lib/crewai-files/tests/**/*.py" = ["S101", "RET504", "S105", "S106", "B017", "F841", "S104"] # Allow assert statements, blind exception assertions, and bind-all-interfaces literals (SSRF test fixtures) in tests "lib/cli/tests/**/*.py" = ["S101", "RET504", "S105", "S106"] # Allow assert statements in tests "lib/crewai-core/tests/**/*.py" = ["S101", "RET504", "S105", "S106"] # Allow assert statements in tests "lib/devtools/tests/**/*.py" = ["S101"]