From 3db0da1520cde3198720570f5f21c78c2322af68 Mon Sep 17 00:00:00 2001 From: biefan <70761325+biefan@users.noreply.github.com> Date: Tue, 17 Mar 2026 03:07:34 +0000 Subject: [PATCH 1/2] Normalize remote dataset file types from URLs --- .../remote/remote_dataset_loader.py | 15 ++++++++- .../datasets/test_remote_dataset_loader.py | 32 +++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index 5cd9212846..7622175ef6 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -10,6 +10,7 @@ from collections.abc import Callable from pathlib import Path from typing import Any, Literal, Optional, TextIO, cast +from urllib.parse import urlparse import requests from datasets import DownloadMode, disable_progress_bars, load_dataset @@ -76,6 +77,18 @@ def _validate_file_type(self, file_type: str) -> None: valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") + def _get_file_type(self, *, source: str) -> str: + """ + Infer the source file type from a URL or local path. + + Query strings and fragments are ignored for URLs, and the result is + normalized to lowercase so `.JSON` and `.json` are treated identically. + """ + parsed = urlparse(source) + source_path = parsed.path if parsed.scheme else source + suffix = Path(source_path).suffix + return suffix.lstrip(".").lower() + def _read_cache(self, *, cache_file: Path, file_type: str) -> list[dict[str, str]]: """ Read data from cache. @@ -188,7 +201,7 @@ def _fetch_from_url( ... source_type='public_url' ... ) """ - file_type = source.split(".")[-1] + file_type = self._get_file_type(source=source) if file_type not in FILE_TYPE_HANDLERS: valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") diff --git a/tests/unit/datasets/test_remote_dataset_loader.py b/tests/unit/datasets/test_remote_dataset_loader.py index d0052a4c78..a1325e3052 100644 --- a/tests/unit/datasets/test_remote_dataset_loader.py +++ b/tests/unit/datasets/test_remote_dataset_loader.py @@ -72,3 +72,35 @@ def test_write_cache_creates_directories(self, tmp_path): loader._write_cache(cache_file=cache_file, examples=data, file_type="json") assert cache_file.exists() + + @patch.object(_RemoteDatasetLoader, "_fetch_from_public_url", return_value=[{"key": "value"}]) + def test_fetch_from_url_supports_query_string_file_type(self, mock_fetch_from_public_url): + loader = ConcreteRemoteLoader() + + result = loader._fetch_from_url( + source="https://example.com/data.json?download=1", + source_type="public_url", + cache=False, + ) + + assert result == [{"key": "value"}] + mock_fetch_from_public_url.assert_called_once_with( + source="https://example.com/data.json?download=1", + file_type="json", + ) + + @patch.object(_RemoteDatasetLoader, "_fetch_from_public_url", return_value=[{"key": "value"}]) + def test_fetch_from_url_supports_uppercase_file_type(self, mock_fetch_from_public_url): + loader = ConcreteRemoteLoader() + + result = loader._fetch_from_url( + source="https://example.com/data.JSON", + source_type="public_url", + cache=False, + ) + + assert result == [{"key": "value"}] + mock_fetch_from_public_url.assert_called_once_with( + source="https://example.com/data.JSON", + file_type="json", + ) From 11a064fc72e78068d552a8d9d9e4eb9f6e1e3961 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Wed, 15 Apr 2026 15:01:00 -0700 Subject: [PATCH 2/2] Add direct unit tests for _get_file_type Cover query strings, fragments, uppercase extensions, local paths, and no-extension edge case. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../datasets/test_remote_dataset_loader.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/unit/datasets/test_remote_dataset_loader.py b/tests/unit/datasets/test_remote_dataset_loader.py index 4fcd42a719..0d6c6ceafc 100644 --- a/tests/unit/datasets/test_remote_dataset_loader.py +++ b/tests/unit/datasets/test_remote_dataset_loader.py @@ -83,6 +83,26 @@ def test_write_cache_csv_allows_empty_examples(self, tmp_path): assert cache_file.read_text(encoding="utf-8") == "" assert loader._read_cache(cache_file=cache_file, file_type="csv") == [] + def test_get_file_type_strips_query_string(self): + loader = ConcreteRemoteLoader() + assert loader._get_file_type(source="https://example.com/data.json?download=1") == "json" + + def test_get_file_type_strips_fragment(self): + loader = ConcreteRemoteLoader() + assert loader._get_file_type(source="https://example.com/data.csv#row5") == "csv" + + def test_get_file_type_lowercases_extension(self): + loader = ConcreteRemoteLoader() + assert loader._get_file_type(source="https://example.com/data.JSONL") == "jsonl" + + def test_get_file_type_local_path(self): + loader = ConcreteRemoteLoader() + assert loader._get_file_type(source="/tmp/data.txt") == "txt" + + def test_get_file_type_returns_empty_for_no_extension(self): + loader = ConcreteRemoteLoader() + assert loader._get_file_type(source="https://example.com/data") == "" + @patch.object(_RemoteDatasetLoader, "_fetch_from_public_url", return_value=[{"key": "value"}]) def test_fetch_from_url_supports_query_string_file_type(self, mock_fetch_from_public_url): loader = ConcreteRemoteLoader()