Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 215 additions & 10 deletions lib/crewai-files/src/crewai_files/core/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

from __future__ import annotations

from collections.abc import AsyncIterator, Iterator
import asyncio
from collections.abc import AsyncIterator, Iterator, Sequence
import inspect
import ipaddress
import json
import mimetypes
from pathlib import Path
import socket
from typing import Annotated, Any, BinaryIO, Protocol, cast, runtime_checkable
from urllib.parse import urljoin, urlparse

import aiofiles
from pydantic import (
Expand Down Expand Up @@ -486,6 +490,154 @@ async def aread_chunks(self, chunk_size: int = 65536) -> AsyncIterator[bytes]:
yield chunk


_MAX_REDIRECTS = 5


def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
"""Return whether an IP address must not be fetched (SSRF protection).

Normalizes IPv4-mapped IPv6 addresses (e.g. ``::ffff:127.0.0.1``) to their
IPv4 form first, because :class:`ipaddress.IPv6Address` does not flag the
mapped form as loopback or private on its own, a common SSRF guard bypass.

Args:
ip: The resolved IP address to classify.

Returns:
True if the address is loopback, private, link-local, reserved,
multicast, or unspecified; False for a routable public address.
"""
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
ip = ip.ipv4_mapped
return (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_reserved
or ip.is_multicast
or ip.is_unspecified
)


def _url_host(url: str) -> str:
"""Extract the host of an absolute URL.

Args:
url: The absolute http(s) URL.

Returns:
The host component.

Raises:
ValueError: If the URL has no host.
"""
host = urlparse(url).hostname
if not host:
raise ValueError(f"URL has no host: {url}")
return host


def _select_validated_ip(host: str, addrinfo: Sequence[tuple[Any, ...]]) -> str:
"""Validate every resolved address and return one public IP to connect to.

Rejects the host if *any* resolved address is non-public, so a multi-record
host cannot slip a private address past the guard, then returns the first
address for the caller to connect to directly.

Args:
host: The host being resolved, used for the error message.
addrinfo: The ``getaddrinfo`` result for ``host``.

Returns:
A validated, public IP address to connect to.

Raises:
ValueError: If the host resolves to no address or to a blocked one.
"""
selected: str | None = None
for *_, sockaddr in addrinfo:
ip = ipaddress.ip_address(sockaddr[0])
if _is_blocked_ip(ip):
raise ValueError(
f"Refusing to fetch URL resolving to non-public address {ip} "
f"(SSRF protection): host {host}"
)
if selected is None:
selected = sockaddr[0]
if selected is None:
raise ValueError(f"Cannot resolve URL host: {host}")
return selected


def _resolve_validated_ip(host: str) -> str:
"""Resolve ``host`` and return a validated public IP (SSRF guard).

Args:
host: The host to resolve.

Returns:
A validated, public IP address.

Raises:
ValueError: If the host cannot be resolved or is non-public.
"""
try:
addrinfo = socket.getaddrinfo(host, None)
except socket.gaierror as exc:
raise ValueError(f"Cannot resolve URL host: {host}") from exc
return _select_validated_ip(host, addrinfo)


async def _aresolve_validated_ip(host: str) -> str:
"""Async variant of :func:`_resolve_validated_ip`.

Runs the blocking name resolution in the default executor so the DNS lookup
does not block the running event loop.

Args:
host: The host to resolve.

Returns:
A validated, public IP address.

Raises:
ValueError: If the host cannot be resolved or is non-public.
"""
loop = asyncio.get_running_loop()
try:
addrinfo = await loop.run_in_executor(None, socket.getaddrinfo, host, None)
except socket.gaierror as exc:
raise ValueError(f"Cannot resolve URL host: {host}") from exc
return _select_validated_ip(host, addrinfo)


def _pin_request_kwargs(
url: str, ip: str
) -> tuple[str, dict[str, str], dict[str, Any]]:
"""Build request arguments that pin the connection to a validated IP.

The URL host is replaced with the validated IP so the HTTP client connects
to that exact address and never resolves the hostname again — closing the
DNS-rebinding (TOCTOU) window — while the original ``Host`` header and TLS
SNI hostname are preserved so virtual hosting and certificate verification
keep working.

Args:
url: The absolute http(s) URL being fetched.
ip: The validated public IP to connect to.

Returns:
A ``(pinned_url, headers, extensions)`` tuple for ``build_request``.
"""
parsed = urlparse(url)
host = parsed.hostname or ""
host_header = parsed.netloc.rsplit("@", 1)[-1]
ip_host = f"[{ip}]" if ":" in ip else ip
netloc = f"{ip_host}:{parsed.port}" if parsed.port is not None else ip_host
pinned_url = parsed._replace(netloc=netloc).geturl()
return pinned_url, {"Host": host_header}, {"sni_hostname": host}


class FileUrl(BaseModel):
"""File referenced by URL.

Expand Down Expand Up @@ -526,28 +678,81 @@ def _guess_content_type(self) -> str:
return guessed or "application/octet-stream"

def read(self) -> bytes:
"""Fetch content from URL (for providers that don't support URL references)."""
"""Fetch content from URL, blocking SSRF to non-public addresses.

Each request connects to a validated public IP (the hostname is never
re-resolved by the HTTP client, closing the DNS-rebinding window), and
redirects are followed manually so every hop is re-validated. A public
URL therefore cannot reach — or redirect into — an internal or
cloud-metadata address.

Returns:
The fetched file content as bytes.

Raises:
ValueError: If the URL (or a redirect target) resolves to a blocked
address, or if there are too many redirects.
"""
if self._content is None:
import httpx

response = httpx.get(self.url, follow_redirects=True)
current = self.url
with httpx.Client(follow_redirects=False) as client:
for _ in range(_MAX_REDIRECTS + 1):
ip = _resolve_validated_ip(_url_host(current))
pinned_url, headers, extensions = _pin_request_kwargs(current, ip)
request = client.build_request(
"GET", pinned_url, headers=headers, extensions=extensions
)
response = client.send(request)
if response.is_redirect and "location" in response.headers:
current = urljoin(current, response.headers["location"])
continue
break
else:
raise ValueError(
f"Too many redirects while fetching URL: {self.url}"
)
response.raise_for_status()
self._content = response.content
if "content-type" in response.headers:
self._content_type = response.headers["content-type"].split(";")[0]
return self._content

async def aread(self) -> bytes:
"""Async fetch content from URL."""
"""Async fetch with the same SSRF protection as :meth:`read`.

Returns:
The fetched file content as bytes.

Raises:
ValueError: If the URL (or a redirect target) resolves to a blocked
address, or if there are too many redirects.
"""
if self._content is None:
import httpx

async with httpx.AsyncClient() as client:
response = await client.get(self.url, follow_redirects=True)
response.raise_for_status()
self._content = response.content
if "content-type" in response.headers:
self._content_type = response.headers["content-type"].split(";")[0]
current = self.url
async with httpx.AsyncClient(follow_redirects=False) as client:
for _ in range(_MAX_REDIRECTS + 1):
ip = await _aresolve_validated_ip(_url_host(current))
pinned_url, headers, extensions = _pin_request_kwargs(current, ip)
request = client.build_request(
"GET", pinned_url, headers=headers, extensions=extensions
)
response = await client.send(request)
if response.is_redirect and "location" in response.headers:
current = urljoin(current, response.headers["location"])
continue
break
else:
raise ValueError(
f"Too many redirects while fetching URL: {self.url}"
)
response.raise_for_status()
self._content = response.content
if "content-type" in response.headers:
self._content_type = response.headers["content-type"].split(";")[0]
return self._content


Expand Down
Loading