class TestDownloadFileFromLink:
    """Test suite for download_file_from_link function."""

    @staticmethod
    def _make_response(chunks, content_length=None, status_code=200):
        """Build a mock ``requests`` response that streams ``chunks``.

        Parameters
        ----------
        chunks : list of bytes
            Chunks returned by ``iter_content``.
        content_length : int, optional
            Value of the ``content-length`` header. The header is omitted
            when this is None.
        status_code : int, optional
            HTTP status code of the response. Defaults to 200.

        Returns
        -------
        MagicMock
            Mock response object.
        """
        response = MagicMock()
        response.status_code = status_code
        response.headers = (
            {}
            if content_length is None
            else {"content-length": str(content_length)}
        )
        response.iter_content.return_value = chunks
        return response

    @staticmethod
    def _download(path_to_save, **overrides):
        """Call download_file_from_link with the suite's default arguments.

        Parameters
        ----------
        path_to_save : str
            Directory the file is saved to.
        **overrides : dict
            Keyword arguments that replace or extend the defaults.
        """
        kwargs = {
            "file_link": "http://example.com/dataset.tar.gz",
            "path_to_save": path_to_save,
            "dataset_name": "test_dataset",
            "file_format": "tar.gz",
            "timeout": 60,
            "retries": 1,
        }
        kwargs.update(overrides)
        download_file_from_link(**kwargs)

    @pytest.fixture
    def temp_dir(self):
        """Create temporary directory for test outputs.

        Returns
        -------
        str
            Path to temporary directory.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            yield tmpdir

    @pytest.fixture
    def mock_response(self):
        """Create mock response object.

        Returns
        -------
        MagicMock
            Mock response object with status code and headers.
        """
        # 5 MB advertised size; each test sets the chunks it needs.
        return self._make_response([], content_length=5 * 1024 * 1024)

    def test_download_success_with_progress(self, temp_dir, mock_response):
        """Test successful download with progress reporting.

        Parameters
        ----------
        temp_dir : str
            Temporary directory path.
        mock_response : MagicMock
            Mock response object.
        """
        # 5MB total in 1MB chunks
        mock_response.iter_content.return_value = [
            b"x" * (1024 * 1024) for _ in range(5)
        ]

        with patch("requests.get", return_value=mock_response):
            self._download(temp_dir)

        # Verify file was created and has correct size
        output_file = os.path.join(temp_dir, "test_dataset.tar.gz")
        assert os.path.exists(output_file)
        assert os.path.getsize(output_file) == 5 * 1024 * 1024

    def test_download_creates_directory_if_not_exists(self, temp_dir):
        """Test that download creates directory structure.

        Parameters
        ----------
        temp_dir : str
            Temporary directory path.
        """
        nested_dir = os.path.join(temp_dir, "nested", "path")
        response = self._make_response([b"x" * 1024], content_length=1024)

        with patch("requests.get", return_value=response):
            self._download(nested_dir)

        output_file = os.path.join(nested_dir, "test_dataset.tar.gz")
        assert os.path.isdir(nested_dir)
        assert os.path.exists(output_file)

    def test_download_http_error(self, temp_dir):
        """Test handling of HTTP error responses.

        Parameters
        ----------
        temp_dir : str
            Temporary directory path.
        """
        response = self._make_response([], status_code=404)

        with patch("requests.get", return_value=response):
            self._download(
                temp_dir,
                file_link="http://example.com/nonexistent.tar.gz",
            )

        # File should not be created on HTTP error
        output_file = os.path.join(temp_dir, "test_dataset.tar.gz")
        assert not os.path.exists(output_file)

    def test_download_timeout_retry(self, temp_dir):
        """Test retry logic on timeout.

        Parameters
        ----------
        temp_dir : str
            Temporary directory path.
        """
        import requests

        response = self._make_response([b"x" * 1024], content_length=1024)

        with patch("requests.get") as mock_get:
            # First call times out, second succeeds
            mock_get.side_effect = [
                requests.exceptions.Timeout("Connection timed out"),
                response,
            ]

            # Mock sleep to speed up test
            with patch("time.sleep") as mock_sleep:
                self._download(temp_dir, retries=3)

        # File should be created on successful retry
        output_file = os.path.join(temp_dir, "test_dataset.tar.gz")
        assert os.path.exists(output_file)
        assert mock_get.call_count == 2
        # Exactly one back-off pause (5s after the first failure)
        mock_sleep.assert_called_once_with(5)

    def test_download_exhausts_retries(self, temp_dir):
        """Test that exception is raised after all retries exhausted.

        Parameters
        ----------
        temp_dir : str
            Temporary directory path.
        """
        import requests

        with patch("requests.get") as mock_get:
            mock_get.side_effect = requests.exceptions.Timeout(
                "Connection timed out"
            )

            with patch("time.sleep"):
                with pytest.raises(requests.exceptions.Timeout):
                    self._download(temp_dir, retries=2)

        # Verify retries were attempted and nothing was written
        assert mock_get.call_count == 2
        output_file = os.path.join(temp_dir, "test_dataset.tar.gz")
        assert not os.path.exists(output_file)

    def test_download_with_different_formats(self, temp_dir, mock_response):
        """Test download with different file formats.

        Parameters
        ----------
        temp_dir : str
            Temporary directory path.
        mock_response : MagicMock
            Mock response object.
        """
        mock_response.iter_content.return_value = [b"test content"]
        names_by_format = {
            fmt: f"test_dataset_{fmt.replace('.', '_')}"
            for fmt in ("zip", "tar", "tar.gz")
        }

        with patch("requests.get", return_value=mock_response):
            for fmt, name in names_by_format.items():
                self._download(
                    temp_dir,
                    file_link="http://example.com/dataset",
                    dataset_name=name,
                    file_format=fmt,
                )

        # Verify all files were created with correct extensions
        for fmt, name in names_by_format.items():
            assert os.path.exists(os.path.join(temp_dir, f"{name}.{fmt}"))

    def test_download_empty_chunks(self, temp_dir):
        """Test handling of empty chunks in response.

        Parameters
        ----------
        temp_dir : str
            Temporary directory path.
        """
        # Include empty chunks (should be skipped)
        response = self._make_response(
            [b"x" * 512, b"", b"y" * 512, b""],
            content_length=1024,
        )

        with patch("requests.get", return_value=response):
            self._download(temp_dir)

        # File should contain only the non-empty chunks, in order
        output_file = os.path.join(temp_dir, "test_dataset.tar.gz")
        assert os.path.exists(output_file)
        with open(output_file, "rb") as f:
            assert f.read() == b"x" * 512 + b"y" * 512

    def test_download_unknown_size(self, temp_dir):
        """Test download when content-length header is missing.

        Parameters
        ----------
        temp_dir : str
            Temporary directory path.
        """
        response = self._make_response([b"x" * 1024])  # no content-length

        with patch("requests.get", return_value=response):
            self._download(temp_dir)

        output_file = os.path.join(temp_dir, "test_dataset.tar.gz")
        assert os.path.exists(output_file)
        assert os.path.getsize(output_file) == 1024

    def test_download_ssl_verification_disabled(self, temp_dir, mock_response):
        """Test that SSL verification can be disabled.

        Parameters
        ----------
        temp_dir : str
            Temporary directory path.
        mock_response : MagicMock
            Mock response object.
        """
        mock_response.iter_content.return_value = [b"test content"]

        with patch("requests.get", return_value=mock_response) as mock_get:
            self._download(
                temp_dir,
                file_link="https://example.com/dataset.tar.gz",
                verify=False,
            )

        # Verify requests.get was called with verify=False
        mock_get.assert_called_once()
        assert mock_get.call_args.kwargs["verify"] is False

    def test_download_custom_timeout(self, temp_dir, mock_response):
        """Test that custom timeout is used.

        Parameters
        ----------
        temp_dir : str
            Temporary directory path.
        mock_response : MagicMock
            Mock response object.
        """
        mock_response.iter_content.return_value = [b"test content"]
        custom_timeout = 120  # 2 minutes per chunk

        with patch("requests.get", return_value=mock_response) as mock_get:
            self._download(
                temp_dir,
                file_link="https://example.com/dataset.tar.gz",
                timeout=custom_timeout,
            )

        # Verify requests.get was called with (connect, read) timeouts
        mock_get.assert_called_once()
        assert mock_get.call_args.kwargs["timeout"] == (30, custom_timeout)
def _format_eta(seconds):
    """Format a duration in seconds as whole hours and minutes.

    Parameters
    ----------
    seconds : float
        Estimated remaining time in seconds.

    Returns
    -------
    str
        Duration formatted as ``"<hours>h <minutes>m"``.
    """
    # divmod keeps the hour count truncated; formatting fractional hours
    # with ``:.0f`` would round 1.5h up to "2h".
    hours, remainder = divmod(int(seconds), 3600)
    return f"{hours}h {remainder // 60}m"


def _stream_response_to_file(
    response, destination, chunk_size, progress_interval
):
    """Write a streamed response body to a file, printing progress.

    Parameters
    ----------
    response : requests.Response
        Response obtained with ``stream=True``.
    destination : str
        Path of the file to write.
    chunk_size : int
        Number of bytes requested per chunk.
    progress_interval : int
        Minimum number of bytes between two progress messages.

    Returns
    -------
    int
        Number of bytes written.
    """
    total_size = int(response.headers.get("content-length", 0))
    if total_size > 0:
        print(f"[Download] Total file size: {total_size / (1024**3):.2f} GB")
    else:
        print("[Download] Total file size: unknown")

    downloaded = 0
    last_reported = 0
    start_time = time.monotonic()

    with open(destination, "wb") as f:
        for chunk in response.iter_content(
            chunk_size=chunk_size, decode_unicode=False
        ):
            if not chunk:  # skip empty chunks
                continue
            f.write(chunk)
            downloaded += len(chunk)

            # Progress can only be reported when the total size is known.
            if (
                total_size <= 0
                or downloaded - last_reported < progress_interval
            ):
                continue

            percent = (downloaded / total_size) * 100
            elapsed = max(time.monotonic() - start_time, 1e-3)
            speed_mb_s = downloaded / (1024**2) / elapsed
            remaining_mb = max(total_size - downloaded, 0) / (1024**2)
            eta_str = _format_eta(remaining_mb / speed_mb_s)
            print(
                f"[Download] {downloaded / (1024**3):.2f} / {total_size / (1024**3):.2f} GB ({percent:.1f}%) | Speed: {speed_mb_s:.2f} MB/s | ETA: {eta_str}"
            )
            last_reported = downloaded

    return downloaded


def download_file_from_link(
    file_link,
    path_to_save,
    dataset_name,
    file_format="tar.gz",
    verify=True,
    timeout=None,
    retries=3,
):
    """Download a file from a link and save it to the specified path.

    Uses streaming with chunked download and includes retry logic for
    resilience against network interruptions.

    Parameters
    ----------
    file_link : str
        The link of the file to download.
    path_to_save : str
        The directory the file is saved to. It is created if missing.
    dataset_name : str
        The name of the dataset, used as the file stem.
    file_format : str, optional
        The format of the downloaded file. Defaults to "tar.gz".
    verify : bool, optional
        Whether to verify SSL certificates. Defaults to True.
    timeout : float, optional
        Read timeout in seconds per chunk (not for the entire download).
        For very slow servers, increase this value. Defaults to 60 seconds.
        The connect timeout is fixed at 30 seconds.
    retries : int, optional
        Number of download attempts. Values below 1 are treated as 1.
        Defaults to 3.

    Raises
    ------
    Exception
        The error of the last attempt, re-raised once all attempts failed.

    Notes
    -----
    The file is downloaded in 5MB chunks for memory efficiency and progress
    is reported every 10MB when the server announces the file size.

    Data is first written to ``<output>.part`` and moved to the final path
    only after a complete download, so a failed attempt never leaves a
    truncated file at the final location.

    Failed attempts are retried with a linearly increasing pause
    (5s, 10s, 15s, ...). A non-200 HTTP status is reported and the function
    returns without retrying and without writing a file.

    Examples
    --------
    Basic download:

    >>> from topobench.data.utils import download_file_from_link
    >>> download_file_from_link(
    ...     file_link="https://example.com/dataset.tar.gz",
    ...     path_to_save="./data/",
    ...     dataset_name="my_dataset"
    ... )

    Download with custom timeout for slow servers:

    >>> download_file_from_link(
    ...     file_link="https://slow-server.com/dataset.zip",
    ...     path_to_save="./data/",
    ...     dataset_name="my_dataset",
    ...     file_format="zip",
    ...     timeout=300  # 5 minutes per chunk
    ... )

    Download with increased retries for unreliable connections:

    >>> download_file_from_link(
    ...     file_link="https://example.com/dataset.tar.gz",
    ...     path_to_save="./data/",
    ...     dataset_name="my_dataset",
    ...     retries=5  # Try up to 5 times
    ... )
    """
    connect_timeout = 30  # seconds
    chunk_size = 5 * 1024 * 1024  # 5MB chunks
    progress_interval = 10 * 1024 * 1024  # report every 10MB
    backoff_step = 5  # seconds added to the pause after each failure

    # Default timeout: 60 seconds per chunk read (for very slow servers)
    if timeout is None:
        timeout = 60
    # Always try at least once; retries <= 0 used to skip the download.
    attempts = max(1, retries)

    # Ensure output directory exists
    os.makedirs(path_to_save, exist_ok=True)
    output_path = f"{path_to_save}/{dataset_name}.{file_format}"
    partial_path = f"{output_path}.part"

    for attempt in range(1, attempts + 1):
        print(
            f"[Download] Starting download from: {file_link} (attempt {attempt}/{attempts})"
        )
        try:
            # Tuple timeout: (connect timeout, read timeout per chunk)
            response = requests.get(
                file_link,
                verify=verify,
                stream=True,
                timeout=(connect_timeout, timeout),
            )
            try:
                if response.status_code != 200:
                    print(
                        f"[Download] Failed to download the file. HTTP {response.status_code}"
                    )
                    return
                _stream_response_to_file(
                    response, partial_path, chunk_size, progress_interval
                )
            finally:
                # A streamed response holds its connection until closed.
                response.close()

            os.replace(partial_path, output_path)
            print(f"[Download] Download complete! Saved to: {output_path}")
            return

        except Exception as e:
            # Deliberately broad: any failure (network, disk, ...) counts as
            # a failed attempt and the last one is re-raised unchanged.
            print(
                f"[Download] Download failed with error: {type(e).__name__}: {e}"
            )
            try:
                os.remove(partial_path)
            except OSError:
                # Best-effort cleanup; the partial file may not exist.
                pass

            if attempt == attempts:
                print(
                    f"[Download] Failed after {attempts} attempts. Please check your connection and try again."
                )
                raise

            wait_time = backoff_step * attempt  # linear backoff: 5s, 10s, ...
            print(
                f"[Download] Retrying in {wait_time} seconds... (attempt {attempt + 1}/{attempts})"
            )
            time.sleep(wait_time)
def collect_mat_files(data_dir: str) -> list:
    """Collect all .mat files from a directory recursively.

    Excludes files containing "diffxy" in their file names. Only the file
    name is inspected, so a "diffxy" substring in ``data_dir`` or in a parent
    directory does not exclude the files below it.

    Parameters
    ----------
    data_dir : str
        Root directory to search for .mat files.

    Returns
    -------
    list
        Sorted list of .mat file paths. Empty if nothing matches or the
        directory does not exist.
    """
    # Escape the root so glob metacharacters in the directory name
    # (e.g. "[" or "*") are matched literally.
    pattern = os.path.join(glob.escape(data_dir), "**", "*.mat")
    return sorted(
        path
        for path in glob.glob(pattern, recursive=True)
        if "diffxy" not in os.path.basename(path)
    )
def mat_cell_to_dict(mt) -> dict:
    """Convert a MATLAB struct record to a dictionary.

    Parameters
    ----------
    mt : np.void or np.ndarray
        Single record of a structured array, i.e. an object whose
        ``dtype.names`` lists the field names.

    Returns
    -------
    dict
        Dictionary with the field names as keys. Values that are numpy
        arrays are squeezed, other values are returned unchanged.

    Raises
    ------
    TypeError
        If ``mt`` has no named fields.
    """
    field_names = mt.dtype.names
    if field_names is None:
        raise TypeError(
            "Expected a structured record with named fields, "
            f"got dtype {mt.dtype}."
        )

    clean_data = {}
    for name in field_names:
        # Read by field name rather than by position so the lookup cannot
        # drift from ``dtype.names``.
        value = mt[name]
        clean_data[name] = (
            np.squeeze(value) if isinstance(value, np.ndarray) else value
        )
    return clean_data


def planewise_mat_cell_to_dict(mt) -> dict:
    """Convert a plane-wise MATLAB struct array to a nested dictionary.

    Parameters
    ----------
    mt : np.ndarray
        Structured array whose first row holds one record per plane.

    Returns
    -------
    dict
        Nested dictionary with plane IDs (0-based position in the first
        row) as keys and the converted record of each plane as values.
    """
    return {
        plane_id: mat_cell_to_dict(record)
        for plane_id, record in enumerate(mt[0])
    }


def process_mat(mat_data) -> dict:
    """Convert MATLAB data structure into organized dictionary.

    Converts the MATLAB structures for BFInfo, CellInfo, CorrInfo, and other
    experimental metadata into nested Python dictionaries.

    Parameters
    ----------
    mat_data : dict
        Dictionary loaded from MATLAB .mat file via scipy.io.loadmat.

    Returns
    -------
    dict
        Processed data structure with organized BFInfo, CellInfo, CorrInfo,
        coordinate arrays, and experimental variables.
    """
    mt = {}
    # The source file spells this key "BFinfo"; the output uses "BFInfo".
    mt["BFInfo"] = planewise_mat_cell_to_dict(mat_data["BFinfo"])
    mt["CellInfo"] = planewise_mat_cell_to_dict(mat_data["CellInfo"])
    mt["CorrInfo"] = planewise_mat_cell_to_dict(mat_data["CorrInfo"])
    mt["allZCorrInfo"] = mat_cell_to_dict(mat_data["allZCorrInfo"][0, 0])

    # Each of these arrays stores one entry per row in column 0.
    for cord_key in ("allxc", "allyc", "allzc", "zDFF"):
        column = mat_data[cord_key]
        mt[cord_key] = {p: column[p, 0] for p in range(column.shape[0])}

    mt["exptVars"] = mat_cell_to_dict(mat_data["exptVars"][0, 0])
    mt["selectZCorrInfo"] = mat_cell_to_dict(mat_data["selectZCorrInfo"][0, 0])
    mt["stimInfo"] = planewise_mat_cell_to_dict(mat_data["stimInfo"])
    mt["zStuff"] = planewise_mat_cell_to_dict(mat_data["zStuff"])

    return mt