diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index b28d7cc70..626b9febb 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -12,6 +12,7 @@ from enum import Enum from pathlib import Path from typing import Any, Literal, Optional, TextIO, cast +from urllib.parse import urlparse import requests from datasets import DownloadMode, disable_progress_bars, load_dataset @@ -125,6 +126,24 @@ def _validate_file_type(self, file_type: str) -> None: valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") + def _get_file_type(self, *, source: str) -> str: + """ + Infer the source file type from a URL or local path. + + Query strings and fragments are ignored for URLs, and the result is + normalized to lowercase so `.JSON` and `.json` are treated identically. + + Args: + source (str): The URL or local file path to extract the file type from. + + Returns: + str: The lowercase file extension without the leading dot. + """ + parsed = urlparse(source) + source_path = parsed.path if parsed.scheme else source + suffix = Path(source_path).suffix + return suffix.lstrip(".").lower() + def _read_cache(self, *, cache_file: Path, file_type: str) -> list[dict[str, str]]: """ Read data from cache. @@ -237,7 +256,7 @@ def _fetch_from_url( ... source_type='public_url' ... ) """ - file_type = source.split(".")[-1] + file_type = self._get_file_type(source=source) if file_type not in FILE_TYPE_HANDLERS: valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") diff --git a/tests/unit/datasets/test_remote_dataset_loader.py b/tests/unit/datasets/test_remote_dataset_loader.py index d9a2c8acf..0d6c6ceaf 100644 --- a/tests/unit/datasets/test_remote_dataset_loader.py +++ b/tests/unit/datasets/test_remote_dataset_loader.py @@ -82,3 +82,55 @@ def test_write_cache_csv_allows_empty_examples(self, tmp_path): assert cache_file.exists() assert cache_file.read_text(encoding="utf-8") == "" assert loader._read_cache(cache_file=cache_file, file_type="csv") == [] + + def test_get_file_type_strips_query_string(self): + loader = ConcreteRemoteLoader() + assert loader._get_file_type(source="https://example.com/data.json?download=1") == "json" + + def test_get_file_type_strips_fragment(self): + loader = ConcreteRemoteLoader() + assert loader._get_file_type(source="https://example.com/data.csv#row5") == "csv" + + def test_get_file_type_lowercases_extension(self): + loader = ConcreteRemoteLoader() + assert loader._get_file_type(source="https://example.com/data.JSONL") == "jsonl" + + def test_get_file_type_local_path(self): + loader = ConcreteRemoteLoader() + assert loader._get_file_type(source="/tmp/data.txt") == "txt" + + def test_get_file_type_returns_empty_for_no_extension(self): + loader = ConcreteRemoteLoader() + assert loader._get_file_type(source="https://example.com/data") == "" + + @patch.object(_RemoteDatasetLoader, "_fetch_from_public_url", return_value=[{"key": "value"}]) + def test_fetch_from_url_supports_query_string_file_type(self, mock_fetch_from_public_url): + loader = ConcreteRemoteLoader() + + result = loader._fetch_from_url( + source="https://example.com/data.json?download=1", + source_type="public_url", + cache=False, + ) + + assert result == [{"key": "value"}] + mock_fetch_from_public_url.assert_called_once_with( + source="https://example.com/data.json?download=1", + file_type="json", + ) + + @patch.object(_RemoteDatasetLoader, "_fetch_from_public_url", return_value=[{"key": "value"}]) + def test_fetch_from_url_supports_uppercase_file_type(self, mock_fetch_from_public_url): + loader = ConcreteRemoteLoader() + + result = loader._fetch_from_url( + source="https://example.com/data.JSON", + source_type="public_url", + cache=False, + ) + + assert result == [{"key": "value"}] + mock_fetch_from_public_url.assert_called_once_with( + source="https://example.com/data.JSON", + file_type="json", + )