From 66870eccc43f5a806403b0c6c4fe87f91116c172 Mon Sep 17 00:00:00 2001 From: marindigen Date: Wed, 19 Nov 2025 12:15:34 +0100 Subject: [PATCH 1/6] Added the Auditory Mouse Cortex dataset, loader and python notebook to check and test the training. To be able to download dataset in the function 'download_file_from_link' in requests.get() verify parameter should be specified as False. Note also that currently the run script on the data doesn't run as it fails to download data even if verify parameter set to False --- configs/dataset/graph/a123.yaml | 45 + topobench/data/datasets/a123.py | 361 ++++ topobench/data/loaders/graph/a123_loader.py | 87 + tutorials/tutorial_train_brain_model.ipynb | 1923 +++++++++++++++++++ 4 files changed, 2416 insertions(+) create mode 100644 configs/dataset/graph/a123.yaml create mode 100644 topobench/data/datasets/a123.py create mode 100644 topobench/data/loaders/graph/a123_loader.py create mode 100644 tutorials/tutorial_train_brain_model.ipynb diff --git a/configs/dataset/graph/a123.yaml b/configs/dataset/graph/a123.yaml new file mode 100644 index 000000000..f2673c7fe --- /dev/null +++ b/configs/dataset/graph/a123.yaml @@ -0,0 +1,45 @@ +# Dataset loader config for A123 Cortex M +loader: + _target_: topobench.data.loaders.A123DatasetLoader + parameters: + data_domain: graph + data_type: A123CortexM + data_dir: ${dataset.parameters.data_dir} # Use data_dir from dataset config + data_name: ${dataset.parameters.data_name} # Use data_name from dataset config + num_graphs: 10 + is_undirected: True + num_channels: ${dataset.parameters.num_features} # Use num_features for node feature dim + num_classes: ${dataset.parameters.num_classes} # Use num_classes for output dim + task: ${dataset.parameters.task} # Use task type from dataset config + +# Dataset-specific parameters +parameters: + num_features: 3 + num_classes: 9 + task: classification + loss_type: cross_entropy + monitor_metric: accuracy + task_level: graph + data_name: a123_cortex_m + 
data_dir: ${paths.data_dir}/graph/a123 + hodge_k: 10 + min_neurons: 3 + corr_threshold: 0.2 + +# Splits +split_params: + learning_setting: inductive + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 0 + split_type: random # 'k-fold' or 'random' strategies + k: 10 # for "k-fold" Cross-Validation + train_prop: 0.7 # for "random" strategy splitting + val_prop: 0.15 # for "random" strategy splitting + test_prop: 0.15 # for "random" strategy splitting + +# Dataloader parameters +dataloader_params: + batch_size: 32 + num_workers: 0 + pin_memory: False + diff --git a/topobench/data/datasets/a123.py b/topobench/data/datasets/a123.py new file mode 100644 index 000000000..cd78c1bea --- /dev/null +++ b/topobench/data/datasets/a123.py @@ -0,0 +1,361 @@ +"""Dataset class for Auditory Cortex Mouse dataset.""" + +import os +import os.path as osp +import shutil +from typing import ClassVar + +import numpy as np +import pandas as pd +import scipy.io +import torch +from omegaconf import DictConfig +from torch_geometric.data import Data, InMemoryDataset, extract_zip +from torch_geometric.io import fs +from torch_geometric.utils import to_undirected + +from topobench.data.utils import download_file_from_link +from topobench.data.utils.io_utils import collect_mat_files, process_mat + + +class A123CortexMDataset(InMemoryDataset): + """A1 and A2/3 mouse auditory cortex dataset. + + Loads neural correlation data from mouse auditory cortex regions. + + Parameters + ---------- + root : str + Root directory where the dataset will be saved. + name : str + Name of the dataset. + parameters : DictConfig + Configuration parameters for the dataset including corr_threshold, + n_bins, min_neurons, and hodge_k. + + Attributes + ---------- + URLS : dict + Dictionary containing the URLs for downloading the dataset. + FILE_FORMAT : dict + Dictionary containing the file formats for the dataset. 
+ RAW_FILE_NAMES : dict + Dictionary containing the raw file names for the dataset. + """ + + URLS: ClassVar = { + "a123_cortex_m": "https://gcell.umd.edu/data/Auditory_cortex_data.zip", + } + + FILE_FORMAT: ClassVar = { + "a123_cortex_m": "zip", + } + + RAW_FILE_NAMES: ClassVar = {} + + def __init__( + self, + root: str, + name: str, + parameters: DictConfig, + ) -> None: + self.name = name + self.parameters = parameters + # defensive parameter access with sensible defaults + try: + self.corr_threshold = float(parameters.get("corr_threshold", 0.2)) + except Exception: + self.corr_threshold = float( + getattr(parameters, "corr_threshold", 0.2) + ) + + try: + self.n_bins = int(parameters.get("n_bins", 9)) + except Exception: + self.n_bins = int(getattr(parameters, "n_bins", 9)) + + try: + self.min_neurons = int(parameters.get("min_neurons", 8)) + except Exception: + self.min_neurons = int(getattr(parameters, "min_neurons", 8)) + + # optional parameter controlling how many eigenvalues to compute for Hodge L1 + try: + self.hodge_k = int(parameters.get("hodge_k", 6)) + except Exception: + self.hodge_k = int(getattr(parameters, "hodge_k", 6)) + + self.session_map = {} + super().__init__( + root, + ) + + out = fs.torch_load(self.processed_paths[0]) + assert len(out) == 3 or len(out) == 4 + if len(out) == 3: # Backward compatibility. + data, self.slices, self.sizes = out + data_cls = Data + else: + data, self.slices, self.sizes, data_cls = out + + if not isinstance(data, dict): # Backward compatibility. + self.data = data + else: + self.data = data_cls.from_dict(data) + + # For this dataset we don't assume the internal _data is a torch_geometric Data + # (this dataset exposes helper methods to construct subgraphs on demand). 
+ + def __repr__(self) -> str: + return f"{self.name}(self.root={self.root}, self.name={self.name}, self.parameters={self.parameters}, self.force_reload={self.force_reload})" + + @property + def raw_dir(self) -> str: + """Path to the raw directory of the dataset. + + Returns + ------- + str + Path to the raw directory. + """ + return osp.join(self.root, self.name, "raw") + + @property + def processed_dir(self) -> str: + """Path to the processed directory of the dataset. + + Returns + ------- + str + Path to the processed directory. + """ + return osp.join(self.root, self.name, "processed") + + @property + def raw_file_names(self) -> list[str]: + """Return the raw file names for the dataset. + + Returns + ------- + list[str] + List of raw file names. + """ + return ["Auditory cortex data/"] + + @property + def processed_file_names(self) -> str: + """Return the processed file name for the dataset. + + Returns + ------- + str + Processed file name. + """ + return "data.pt" + + def download(self) -> None: + """Download the dataset from a URL and extract to the raw directory.""" + # Download data from the source + self.url = self.URLS[self.name] + self.file_format = self.FILE_FORMAT[self.name] + + # Use self.name as the downloadable dataset name + download_file_from_link( + file_link=self.url, + path_to_save=self.raw_dir, + dataset_name=self.name, + file_format=self.file_format, + ) + + # Extract zip file + folder = self.raw_dir + filename = f"{self.name}.{self.file_format}" + path = osp.join(folder, filename) + extract_zip(path, folder) + # Delete zip file + os.unlink(path) + + # Move files from osp.join(folder, name_download) to folder + downloaded_dir = osp.join(folder, self.name) + if osp.exists(downloaded_dir): + for file in os.listdir(downloaded_dir): + shutil.move(osp.join(downloaded_dir, file), folder) + # Delete the extracted top-level directory + shutil.rmtree(downloaded_dir) + self.data_dir = folder + + @staticmethod + def extract_samples(data_dir: str, 
n_bins: int, min_neurons: int = 8): + """Extract subgraph samples from raw .mat files. + + Parameters + ---------- + data_dir : str + Directory containing the raw .mat files. + n_bins : int + Number of frequency bins to use for binning. + min_neurons : int, optional + Minimum number of neurons required per sample. Defaults to 8. + + Returns + ------- + pd.DataFrame + DataFrame containing extracted samples with columns for + session_file, session_id, layer, bf_bin, neuron_indices, + corr, and noise_corr. + """ + mat_files = collect_mat_files(data_dir) + + samples = [] + session_id = 0 + for f in mat_files: + print(f"Processing session {session_id}: {os.path.basename(f)}") + mt = process_mat(scipy.io.loadmat(f)) + for layer in range(1, 6): + scorrs = np.array(mt["selectZCorrInfo"]["SigCorrs"]) + ncorrs = np.array(mt["selectZCorrInfo"]["NoiseCorrsTrial"]) + bfvals = np.array(mt["BFInfo"][layer]["BFval"]).ravel() + if scorrs.size == 0 or bfvals.size == 0: + continue + + bin_ids = bfvals.astype(int) + + for bin_idx in range(n_bins): + sel = np.where(bin_ids == bin_idx)[0] + if len(sel) < min_neurons: + continue + subcorr = scorrs[np.ix_(sel, sel)] + samples.append( + { + "session_file": f, + "session_id": session_id, + "layer": layer, + "bf_bin": int(bin_idx), + "neuron_indices": sel.tolist(), + "corr": subcorr.astype(float), + "noise_corr": ncorrs[np.ix_(sel, sel)].astype( + float + ), + } + ) + session_id += 1 + + samples = pd.DataFrame(samples) + return samples + + def _sample_to_pyg_data( + self, sample: dict, threshold: float = 0.2 + ) -> Data: + """Convert a sample dictionary to a PyTorch Geometric Data object. + + Converts correlation matrices to graph representation with node features + and edges for graph-level classification tasks. + + Parameters + ---------- + sample : dict + Sample dictionary containing 'corr', 'noise_corr', 'session_id', + 'layer', and 'bf_bin' keys. + threshold : float, optional + Correlation threshold for creating edges. Defaults to 0.2. 
+ + Returns + ------- + torch_geometric.data.Data + Data object with node features [mean_corr, std_corr, noise_diag], + edges from thresholded correlation, and label y as integer bf_bin. + """ + corr = np.asarray(sample.get("corr")) + if corr.ndim != 2 or corr.size == 0: + # empty placeholder graph + x = torch.zeros((0, 3), dtype=torch.float) + edge_index = torch.empty((2, 0), dtype=torch.long) + edge_attr = torch.empty((0, 1), dtype=torch.float) + else: + n = corr.shape[0] + # sanitize + corr = np.nan_to_num(corr) + + mean_corr = corr.mean(axis=1) + std_corr = corr.std(axis=1) + noise_diag = np.zeros(n) + if "noise_corr" in sample and sample["noise_corr"] is not None: + nc = np.asarray(sample["noise_corr"]) + if nc.shape == corr.shape: + noise_diag = np.diag(nc) + + x_np = np.vstack([mean_corr, std_corr, noise_diag]).T + x = torch.tensor(x_np, dtype=torch.float) + + # build edges from thresholded correlation (upper triangle) + adj = (corr >= threshold).astype(int) + iu = np.triu_indices(n, k=1) + sel = np.where(adj[iu] == 1)[0] + if sel.size == 0: + edge_index = torch.empty((2, 0), dtype=torch.long) + edge_attr = torch.empty((0, 1), dtype=torch.float) + else: + rows = iu[0][sel] + cols = iu[1][sel] + edge_index_np = np.vstack([rows, cols]) + edge_index = torch.tensor(edge_index_np, dtype=torch.long) + # make undirected + edge_index = to_undirected(edge_index) + # edge_attr: corresponding corr weights (for both directions, if made undirected) + weights = corr[rows, cols] + weights = ( + np.repeat(weights, 2) + if edge_index.size(1) == weights.size * 2 + else weights + ) + edge_attr = torch.tensor( + weights.reshape(-1, 1), dtype=torch.float + ) + + y = torch.tensor([int(sample.get("bf_bin", -1))], dtype=torch.long) + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y) + # attach metadata + data.session_id = int(sample.get("session_id", -1)) + data.layer = int(sample.get("layer", -1)) + return data + + def process(self) -> None: + """Generate raw files 
into collated PyG dataset and save to disk. + + This implementation mirrors other datasets in the repo: it calls the + static helper `extract_samples()` to enumerate subgraphs, converts each + to a `torch_geometric.data.Data` object via `_sample_to_pyg_data()`, + optionally computes/attaches topology vectors, collates and saves. + """ + data_dir = self.raw_dir + + print(f"[A123] Processing dataset from: {data_dir}") + print(f"[A123] Files in raw_dir: {os.listdir(data_dir)}") + + # extract sample descriptions + print("[A123] Starting extract_samples()...") + samples = A123CortexMDataset.extract_samples( + data_dir, self.n_bins, self.min_neurons + ) + + print(f"[A123] Extracted {len(samples)} samples") + + data_list = [] + for idx, (_, s) in enumerate(samples.iterrows()): + if idx % 100 == 0: + print( + f"[A123] Converting sample {idx}/{len(samples)} to PyG Data..." + ) + d = self._sample_to_pyg_data(s, threshold=self.corr_threshold) + data_list.append(d) + + # collate and save processed dataset + print(f"[A123] Collating {len(data_list)} samples...") + self.data, self.slices = self.collate(data_list) + self._data_list = None + print(f"[A123] Saving processed data to {self.processed_paths[0]}...") + fs.torch_save( + (self._data.to_dict(), self.slices, {}, self._data.__class__), + self.processed_paths[0], + ) + print("[A123] Processing complete!") diff --git a/topobench/data/loaders/graph/a123_loader.py b/topobench/data/loaders/graph/a123_loader.py new file mode 100644 index 000000000..aaeb83993 --- /dev/null +++ b/topobench/data/loaders/graph/a123_loader.py @@ -0,0 +1,87 @@ +"""A123 dataset loader module.""" + +import torch +from omegaconf import DictConfig + +from topobench.data.datasets.a123 import A123CortexMDataset +from topobench.data.loaders.base import AbstractLoader + + +class A123DatasetLoader(AbstractLoader): + """Loader for A123 mouse auditory cortex dataset. 
+ + Implements the AbstractLoader interface: accepts a DictConfig `parameters` + and implements `load_dataset()` which returns a dataset object. + + Parameters + ---------- + parameters : DictConfig + Configuration parameters for the dataset. + **overrides + Additional keyword arguments to override parameters. + """ + + def __init__(self, parameters: DictConfig, **overrides): + """Initialize the A123 dataset loader. + + Parameters + ---------- + parameters : DictConfig + Configuration parameters for the dataset. + **overrides + Additional keyword arguments to override parameters. + """ + # Initialize AbstractLoader (sets self.parameters and self.root_data_dir) + super().__init__(parameters) + + # hyperparameters can come from the DictConfig or be passed as overrides + params = parameters if parameters is not None else {} + + def _get(k, default): + """Get parameter value from DictConfig or overrides. + + Parameters + ---------- + k : str + Parameter key. + default : Any + Default value if key not found. + + Returns + ------- + Any + Parameter value from DictConfig or overrides, or default. + """ + try: + return params.get(k, overrides.get(k, default)) + except Exception: + # DictConfig may use attribute access + return getattr(params, k, overrides.get(k, default)) + + self.batch_size = int(_get("batch_size", 32)) + # dataset will be created when load_dataset() is called + self.dataset = None + + def load_dataset(self) -> torch.utils.data.Dataset: + """Instantiate and return the underlying dataset. + + Returns a `A123CortexMDataset` instance constructed from the loader's + parameters and root data directory. + + Returns + ------- + torch.utils.data.Dataset + A123CortexMDataset instance. 
+ """ + # determine dataset name from parameters, fallback to expected id + name = self.parameters.data_name + + # root path for dataset: use the root_data_dir (Path) as string + root = str(self.root_data_dir) + + # Construct dataset; A123CortexMDataset expects (root, name, parameters) + self.dataset = A123CortexMDataset( + root=root, name=name, parameters=self.parameters + ) + + return self.dataset diff --git a/tutorials/tutorial_train_brain_model.ipynb b/tutorials/tutorial_train_brain_model.ipynb new file mode 100644 index 000000000..6eb393479 --- /dev/null +++ b/tutorials/tutorial_train_brain_model.ipynb @@ -0,0 +1,1923 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "af53c476", + "metadata": {}, + "source": [ + "# Training TBModel on Auditory Cortex Data for 1 and 2/3 regions.\n", + "\n", + "This notebook demonstrates loading the A123 mouse auditory cortex dataset, applying a simple lifting, defining a small backbone, and training a `TBModel` using `TBLoss` and `TBOptimizer`.\n", + "\n", + "Requirements: the project installed in PYTHONPATH and optional dependencies (torch_geometric, networkx, ripser/persim) if you want advanced features." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "98d0adae", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.chdir('..')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f9ed7f5f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Imports OK\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/mariayuffa/anaconda3/envs/tb/lib/python3.11/site-packages/outdated/__init__.py:36: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30.
Refrain from using this package or pin to Setuptools<81.\n", + " from pkg_resources import parse_version\n" + ] + } + ], + "source": [ + "# 1) Imports\n", + "import torch\n", + "import lightning as pl\n", + "from omegaconf import OmegaConf\n", + "\n", + "# Data loading / preprocessing utilities from the repo\n", + "from topobench.data.loaders.graph.a123_loader import A123DatasetLoader\n", + "from topobench.dataloader.dataloader import TBDataloader\n", + "from topobench.data.preprocessor import PreProcessor\n", + "\n", + "# Model / training building blocks\n", + "from topobench.model.model import TBModel\n", + "# example backbone building block (SCN2 is optional; we provide a tiny custom backbone below)\n", + "# from topomodelx.nn.simplicial.scn2 import SCN2\n", + "from topobench.nn.wrappers.simplicial import SCNWrapper\n", + "from topobench.nn.encoders import AllCellFeatureEncoder\n", + "from topobench.nn.readouts import PropagateSignalDown\n", + "\n", + "# Optimization / evaluation\n", + "from topobench.loss.loss import TBLoss\n", + "from topobench.optimizer import TBOptimizer\n", + "from topobench.evaluator.evaluator import TBEvaluator\n", + "\n", + "print('Imports OK')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "03042d76", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Configs created\n" + ] + } + ], + "source": [ + "# 2) Configurations and utilities\n", + "loader_config = {\n", + " 'data_domain': 'graph',\n", + " 'data_type': 'A123',\n", + " # the loader/dataset expects the dataset name key used in the dataset class\n", + " 'data_name': 'a123_cortex_m',\n", + " 'data_dir': './data/a123/'\n", + "}\n", + "\n", + "# Transform config: single transform with transform_name and transform_type\n", + "# PreProcessor expects either {\"transform_name\": ...} (single) or {\"key1\": {...}, \"key2\": {...}} (multiple)\n", + "transform_config = {\n", + " 'transform_type': 'lifting',\n", + " 
'transform_name': 'HypergraphKHopLifting',\n", + " 'k_value': 1,\n", + "}\n", + "\n", + "split_config = {\n", + " 'learning_setting': 'inductive',\n", + " 'split_type': 'random',\n", + " 'data_seed': 0,\n", + " 'data_split_dir': './data/a123/splits/',\n", + " 'train_prop': 0.5,\n", + "}\n", + "\n", + "# model / task hyperparameters\n", + "# A123 sample node features are: [mean_corr, std_corr, noise_diag] => 3 channels\n", + "in_channels = 3\n", + "# Multiclass classification: 9 frequency bins (bf_bin 0-8)\n", + "out_channels = 9\n", + "dim_hidden = 16\n", + "n_bins = 9 # default binning from extract_samples\n", + "\n", + "readout_config = {\n", + " 'readout_name': 'PropagateSignalDown',\n", + " 'num_cell_dimensions': 1,\n", + " 'hidden_dim': dim_hidden,\n", + " 'out_channels': out_channels,\n", + " 'task_level': 'graph',\n", + " 'pooling_type': 'sum',\n", + "}\n", + "\n", + "loss_config = {\n", + " 'dataset_loss': {\n", + " 'task': 'classification',\n", + " 'loss_type': 'cross_entropy',\n", + " }\n", + "}\n", + "\n", + "evaluator_config = {\n", + " 'task': 'classification',\n", + " 'num_classes': out_channels,\n", + " 'metrics': ['accuracy', 'precision', 'recall'],\n", + "}\n", + "\n", + "optimizer_config = {\n", + " 'optimizer_id': 'Adam',\n", + " 'parameters': {'lr': 0.001, 'weight_decay': 0.0005},\n", + "}\n", + "\n", + "# convert to OmegaConf (the project often expects DictConfig)\n", + "loader_config = OmegaConf.create(loader_config)\n", + "transform_config = OmegaConf.create(transform_config)\n", + "split_config = OmegaConf.create(split_config)\n", + "readout_config = OmegaConf.create(readout_config)\n", + "loss_config = OmegaConf.create(loss_config)\n", + "evaluator_config = OmegaConf.create(evaluator_config)\n", + "optimizer_config = OmegaConf.create(optimizer_config)\n", + "\n", + "print('Configs created')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "06a33ac7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", +
"output_type": "stream", + "text": [ + "Dataset loaded\n", + "Transform parameters are the same, using existing data_dir: data/a123/a123_cortex_m/transform_type_transform_name_k_value/563224662\n", + "Dataset splits created\n", + "Datasets and datamodule ready\n" + ] + } + ], + "source": [ + "# 3) Loading the data\n", + "\n", + "# Use the A123-specific loader (A123DatasetLoader) to construct the dataset\n", + "graph_loader = A123DatasetLoader(loader_config)\n", + "\n", + "dataset, dataset_dir = graph_loader.load()\n", + "print('Dataset loaded')\n", + "\n", + "preprocessor = PreProcessor(dataset, dataset_dir, transform_config)\n", + "dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)\n", + "print('Dataset splits created')\n", + "\n", + "# create the TopoBench datamodule / dataloader wrappers\n", + "datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)\n", + "\n", + "print('Datasets and datamodule ready')" + ] + }, + { + "cell_type": "markdown", + "id": "3b7bc4a8", + "metadata": {}, + "source": [ + "## 4) Backbone definition\n", + "\n", + "We implement a tiny backbone as a `pl.LightningModule` which computes node and hyperedge features: $X_1 = B_1 \\cdot X_0$ and applies two linear layers with ReLU."
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9275c748", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Backbone defined\n" + ] + } + ], + "source": [ + "class MyBackbone(pl.LightningModule):\n", + " def __init__(self, dim_hidden):\n", + " super().__init__()\n", + " self.linear_0 = torch.nn.Linear(dim_hidden, dim_hidden)\n", + " self.linear_1 = torch.nn.Linear(dim_hidden, dim_hidden)\n", + "\n", + " def forward(self, batch):\n", + " # batch.x_0: node features (dense tensor of shape [N, dim_hidden])\n", + " # batch.incidence_hyperedges: sparse incidence matrix with shape [m, n] or [n, m] depending on preprocessor convention\n", + " x_0 = batch.x_0\n", + " incidence_hyperedges = getattr(batch, 'incidence_hyperedges', None)\n", + " if incidence_hyperedges is None:\n", + " # fallback: try incidence as batch.incidence if available\n", + " incidence_hyperedges = getattr(batch, 'incidence', None)\n", + "\n", + " # compute hyperedge features X_1 = B_1 dot X_0 (we assume B_1 is sparse and transposed appropriately)\n", + " x_1 = None\n", + " if incidence_hyperedges is not None:\n", + " try:\n", + " x_1 = torch.sparse.mm(incidence_hyperedges, x_0)\n", + " except Exception:\n", + " # if orientation differs, try transpose\n", + " x_1 = torch.sparse.mm(incidence_hyperedges.T, x_0)\n", + " else:\n", + " # no incidence available: create a zero hyperedge feature placeholder\n", + " x_1 = torch.zeros_like(x_0)\n", + "\n", + " x_0 = self.linear_0(x_0)\n", + " x_0 = torch.relu(x_0)\n", + "\n", + " x_1 = self.linear_1(x_1)\n", + " x_1 = torch.relu(x_1)\n", + "\n", + " model_out = {'labels': batch.y, 'batch_0': getattr(batch, 'batch_0', None)}\n", + " model_out['x_0'] = x_0\n", + " model_out['hyperedge'] = x_1\n", + " return model_out\n", + "\n", + "print('Backbone defined')" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "489bea60", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "Components instantiated\n" + ] + } + ], + "source": [ + "# 5) Model initialization (components)\n", + "backbone = MyBackbone(dim_hidden)\n", + "readout = PropagateSignalDown(**readout_config)\n", + "loss = TBLoss(**loss_config)\n", + "feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels], out_channels=dim_hidden)\n", + "evaluator = TBEvaluator(**evaluator_config)\n", + "optimizer = TBOptimizer(**optimizer_config)\n", + "\n", + "print('Components instantiated')" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "366a4200", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TBModel(backbone=MyBackbone(\n", + " (linear_0): Linear(in_features=16, out_features=16, bias=True)\n", + " (linear_1): Linear(in_features=16, out_features=16, bias=True)\n", + "), readout=PropagateSignalDown(num_cell_dimensions=0, self.hidden_dim=16, readout_name=PropagateSignalDown, loss=TBLoss(losses=[DatasetLoss(task=classification, loss_type=cross_entropy)]), feature_encoder=AllCellFeatureEncoder(in_channels=[3], out_channels=16, dimensions=range(0, 1)))\n" + ] + } + ], + "source": [ + "# 6) Instantiate TBModel\n", + "model = TBModel(backbone=backbone,\n", + " backbone_wrapper=None,\n", + " readout=readout,\n", + " loss=loss,\n", + " feature_encoder=feature_encoder,\n", + " evaluator=evaluator,\n", + " optimizer=optimizer,\n", + " compile=False)\n", + "\n", + "# Print a short summary (repr) to verify construction\n", + "print(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "a81da250", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "/Users/mariayuffa/anaconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. 
You can set it by doing `Trainer(accelerator='gpu')`.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ab180005f81b4c84b0c4f6c3f0d2eb53", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test/accuracy 0.1269841343164444 │\n", + "│ test/loss 2.125241279602051 │\n", + "│ test/precision 0.05082417652010918 │\n", + "│ test/recall 0.125 │\n", + "└───────────────────────────┴───────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test/accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.1269841343164444 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 2.125241279602051 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/precision \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.05082417652010918 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/recall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.125 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Test metrics:\n", + "test/loss 2.1252\n", + "test/accuracy 0.1270\n", + "test/precision 0.0508\n", + "test/recall 0.1250\n" + ] + } + ], + "source": [ + "# 8) Testing and printing 
metrics\n", + "trainer.test(model, datamodule)\n", + "test_metrics = trainer.callback_metrics\n", + "print('\\nTest metrics:')\n", + "for key, val in test_metrics.items():\n", + " try:\n", + " print(f'{key:25s} {float(val):.4f}')\n", + " except Exception:\n", + " print(key, val)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tb", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 71a86f040f454325490b3dee5355aec7fe4b0fbf Mon Sep 17 00:00:00 2001 From: marindigen Date: Wed, 19 Nov 2025 14:22:52 +0100 Subject: [PATCH 2/6] Added functions collect_mat_files, mat_cell_to_dict, planewise_mat_cell_to_dict and process_mat. I have also modified download_file_from_link by specifying verify=False in requests.get() --- topobench/data/utils/io_utils.py | 114 ++++++++++++++++++++++++++++++- 1 file changed, 113 insertions(+), 1 deletion(-) diff --git a/topobench/data/utils/io_utils.py b/topobench/data/utils/io_utils.py index 372db85e6..4aca1ddc5 100644 --- a/topobench/data/utils/io_utils.py +++ b/topobench/data/utils/io_utils.py @@ -1,6 +1,8 @@ """Data IO utilities.""" +import glob import json +import os import os.path as osp import pickle from urllib.parse import parse_qs, urlparse @@ -104,7 +106,7 @@ def download_file_from_link( ------ None """ - response = requests.get(file_link) + response = requests.get(file_link, verify=False) output_path = f"{path_to_save}/{dataset_name}.{file_format}" if response.status_code == 200: @@ -580,3 +582,113 @@ def load_hypergraph_content_dataset(data_dir, data_name): print("Final num_class", data.num_class) return data, data_dir + + +def collect_mat_files(data_dir: str) -> list: + """Collect all .mat files from a directory 
recursively. + + Excludes files containing "diffxy" in their names. + + Parameters + ---------- + data_dir : str + Root directory to search for .mat files. + + Returns + ------- + list + Sorted list of .mat file paths. + """ + patterns = [os.path.join(data_dir, "**", "*.mat")] + files = [] + for p in patterns: + files.extend(glob.glob(p, recursive=True)) + files = [f for f in files if "diffxy" not in f] + files.sort() + return files + + +def mat_cell_to_dict(mt) -> dict: + """Convert MATLAB cell array to dictionary. + + Parameters + ---------- + mt : np.ndarray + MATLAB cell array (structured array). + + Returns + ------- + dict + Dictionary with keys from cell array field names and squeezed values. + """ + clean_data = {} + keys = mt.dtype.names + for key_idx, key in enumerate(keys): + clean_data[key] = ( + np.squeeze(mt[key_idx]) + if isinstance(mt[key_idx], np.ndarray) + else mt[key_idx] + ) + return clean_data + + +def planewise_mat_cell_to_dict(mt) -> dict: + """Convert plane-wise MATLAB cell array to nested dictionary. + + Parameters + ---------- + mt : np.ndarray + MATLAB cell array with plane dimension. + + Returns + ------- + dict + Nested dictionary with plane IDs as keys. + """ + clean_data = {} + for plane_id in range(len(mt[0])): + keys = mt[0, plane_id].dtype.names + clean_data[plane_id] = {} + for key_idx, key in enumerate(keys): + clean_data[plane_id][key] = ( + np.squeeze(mt[0, plane_id][key_idx]) + if isinstance(mt[0, plane_id][key_idx], np.ndarray) + else mt[0, plane_id][key_idx] + ) + return clean_data + + +def process_mat(mat_data) -> dict: + """Generate MATLAB data structure into organized dictionary. + + Converts MATLAB cell arrays for BFInfo, CellInfo, CorrInfo, and other + experimental metadata into nested Python dictionaries. + + Parameters + ---------- + mat_data : dict + Dictionary loaded from MATLAB .mat file via scipy.io.loadmat. 
+ + Returns + ------- + dict + Processed data structure with organized BFInfo, CellInfo, CorrInfo, + coordinate arrays, and experimental variables. + """ + mt = {} + mt["BFInfo"] = planewise_mat_cell_to_dict(mat_data["BFinfo"]) + mt["CellInfo"] = planewise_mat_cell_to_dict(mat_data["CellInfo"]) + mt["CorrInfo"] = planewise_mat_cell_to_dict(mat_data["CorrInfo"]) + mt["allZCorrInfo"] = mat_cell_to_dict(mat_data["allZCorrInfo"][0, 0]) + + for cord_key in ["allxc", "allyc", "allzc", "zDFF"]: + mt[cord_key] = {} + for p in range(mat_data[cord_key].shape[0]): + mt[cord_key][p] = mat_data[cord_key][p, 0] + + mt["exptVars"] = mat_cell_to_dict(mat_data["exptVars"][0, 0]) + mt["selectZCorrInfo"] = mat_cell_to_dict(mat_data["selectZCorrInfo"][0, 0]) + mt["stimInfo"] = planewise_mat_cell_to_dict(mat_data["stimInfo"]) + mt["zStuff"] = planewise_mat_cell_to_dict(mat_data["zStuff"]) + + return mt From 31d8fb4e95a65cff50995ae52ae2cb6f6026b99c Mon Sep 17 00:00:00 2001 From: marindigen Date: Thu, 20 Nov 2025 14:20:59 +0100 Subject: [PATCH 3/6] Added verify flag to the download_from_link_function, and included this flag in the config file --- configs/dataset/graph/a123.yaml | 1 + topobench/data/datasets/a123.py | 1 + 2 files changed, 2 insertions(+) diff --git a/configs/dataset/graph/a123.yaml b/configs/dataset/graph/a123.yaml index f2673c7fe..5a855231b 100644 --- a/configs/dataset/graph/a123.yaml +++ b/configs/dataset/graph/a123.yaml @@ -25,6 +25,7 @@ parameters: hodge_k: 10 min_neurons: 3 corr_threshold: 0.2 + use_stream_download: True # Splits split_params: diff --git a/topobench/data/datasets/a123.py b/topobench/data/datasets/a123.py index cd78c1bea..5f82811a5 100644 --- a/topobench/data/datasets/a123.py +++ b/topobench/data/datasets/a123.py @@ -165,6 +165,7 @@ def download(self) -> None: path_to_save=self.raw_dir, dataset_name=self.name, file_format=self.file_format, + verify=False, ) # Extract zip file From 495c253790d3e9c6858c96802c58a68ecf2e0a26 Mon Sep 17 00:00:00 2001 From: 
marindigen Date: Sun, 23 Nov 2025 15:00:56 +0100 Subject: [PATCH 4/6] Added test for download_file_from_link function --- configs/dataset/graph/a123.yaml | 17 +- test/utils/test_io_utils.py | 342 ++++++++++++++++++++ topobench/data/datasets/a123.py | 35 +- topobench/data/loaders/graph/a123_loader.py | 20 +- topobench/data/utils/io_utils.py | 140 +++++++- 5 files changed, 532 insertions(+), 22 deletions(-) create mode 100644 test/utils/test_io_utils.py diff --git a/configs/dataset/graph/a123.yaml b/configs/dataset/graph/a123.yaml index 5a855231b..9d0d4f3b5 100644 --- a/configs/dataset/graph/a123.yaml +++ b/configs/dataset/graph/a123.yaml @@ -1,10 +1,23 @@ +# Config file for loading Bowen et al. mouse auditory cortex calcium imaging dataset. + +# This script downloads and processes the original dataset introduced in: + +# [Citation] Bowen et al. (2024), "Fractured columnar small-world functional network +# organization in volumes of L2/3 of mouse auditory cortex," PNAS Nexus, 3(2): pgae074. +# https://doi.org/10.1093/pnasnexus/pgae074 + +# We apply the preprocessing and graph-construction steps defined in this module to obtain +# a representation of neuronal activity suitable for our experiments. + +# Please cite the original paper when using this dataset or any derivatives. 
+ # Dataset loader config for A123 Cortex M loader: _target_: topobench.data.loaders.A123DatasetLoader parameters: data_domain: graph data_type: A123CortexM - data_dir: ${dataset.parameters.data_dir} # Use data_dir from dataset config + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} # Use data_dir from dataset config data_name: ${dataset.parameters.data_name} # Use data_name from dataset config num_graphs: 10 is_undirected: True @@ -21,11 +34,9 @@ parameters: monitor_metric: accuracy task_level: graph data_name: a123_cortex_m - data_dir: ${paths.data_dir}/graph/a123 hodge_k: 10 min_neurons: 3 corr_threshold: 0.2 - use_stream_download: True # Splits split_params: diff --git a/test/utils/test_io_utils.py b/test/utils/test_io_utils.py new file mode 100644 index 000000000..39594f201 --- /dev/null +++ b/test/utils/test_io_utils.py @@ -0,0 +1,342 @@ +"""Tests for data IO utilities.""" + +import os +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from topobench.data.utils.io_utils import download_file_from_link + + +class TestDownloadFileFromLink: + """Test suite for download_file_from_link function.""" + + @pytest.fixture + def temp_dir(self): + """Create temporary directory for test outputs. + + Returns + ------- + str + Path to temporary directory. + """ + with tempfile.TemporaryDirectory() as tmpdir: + yield tmpdir + + @pytest.fixture + def mock_response(self): + """Create mock response object. + + Returns + ------- + MagicMock + Mock response object with status code and headers. + """ + response = MagicMock() + response.status_code = 200 + response.headers = {"content-length": "5242880"} # 5 MB + response.elapsed.total_seconds.return_value = 1.0 + return response + + def test_download_success_with_progress(self, temp_dir, mock_response): + """Test successful download with progress reporting. 
+ + Parameters + ---------- + temp_dir : str + Temporary directory path. + mock_response : MagicMock + Mock response object. + """ + # Setup mock chunks (5MB total in 1MB chunks) + chunk_data = [b"x" * (1024 * 1024) for _ in range(5)] + mock_response.iter_content.return_value = chunk_data + + with patch("requests.get", return_value=mock_response): + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=1, + ) + + # Verify file was created and has correct size + output_file = os.path.join(temp_dir, "test_dataset.tar.gz") + assert os.path.exists(output_file) + assert os.path.getsize(output_file) == 5 * 1024 * 1024 + + def test_download_creates_directory_if_not_exists(self, temp_dir): + """Test that download creates directory structure. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + """ + nested_dir = os.path.join(temp_dir, "nested", "path") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"content-length": "1024"} + mock_response.elapsed.total_seconds.return_value = 0.5 + mock_response.iter_content.return_value = [b"x" * 1024] + + with patch("requests.get", return_value=mock_response): + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=nested_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=1, + ) + + output_file = os.path.join(nested_dir, "test_dataset.tar.gz") + assert os.path.exists(output_file) + assert os.path.isdir(nested_dir) + + def test_download_http_error(self, temp_dir): + """Test handling of HTTP error responses. + + Parameters + ---------- + temp_dir : str + Temporary directory path. 
+ """ + mock_response = MagicMock() + mock_response.status_code = 404 + + with patch("requests.get", return_value=mock_response): + download_file_from_link( + file_link="http://example.com/nonexistent.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=1, + ) + + # File should not be created on HTTP error + output_file = os.path.join(temp_dir, "test_dataset.tar.gz") + assert not os.path.exists(output_file) + + def test_download_timeout_retry(self, temp_dir): + """Test retry logic on timeout. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + """ + import requests + + with patch("requests.get") as mock_get: + # First call times out, second succeeds + mock_response_success = MagicMock() + mock_response_success.status_code = 200 + mock_response_success.headers = {"content-length": "1024"} + mock_response_success.elapsed.total_seconds.return_value = 0.5 + mock_response_success.iter_content.return_value = [b"x" * 1024] + + mock_get.side_effect = [ + requests.exceptions.Timeout("Connection timed out"), + mock_response_success, + ] + + with patch("time.sleep"): # Mock sleep to speed up test + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=3, + ) + + # File should be created on successful retry + output_file = os.path.join(temp_dir, "test_dataset.tar.gz") + assert os.path.exists(output_file) + assert mock_get.call_count == 2 + + def test_download_exhausts_retries(self, temp_dir): + """Test that exception is raised after all retries exhausted. + + Parameters + ---------- + temp_dir : str + Temporary directory path. 
+ """ + import requests + + with patch("requests.get") as mock_get: + mock_get.side_effect = requests.exceptions.Timeout( + "Connection timed out" + ) + + with patch("time.sleep"): + with pytest.raises(requests.exceptions.Timeout): + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=2, + ) + + # Verify retries were attempted + assert mock_get.call_count == 2 + + def test_download_with_different_formats(self, temp_dir, mock_response): + """Test download with different file formats. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + mock_response : MagicMock + Mock response object. + """ + mock_response.iter_content.return_value = [b"test content"] + + formats = ["zip", "tar", "tar.gz"] + + with patch("requests.get", return_value=mock_response): + for fmt in formats: + download_file_from_link( + file_link="http://example.com/dataset", + path_to_save=temp_dir, + dataset_name=f"test_dataset_{fmt.replace('.', '_')}", + file_format=fmt, + timeout=60, + retries=1, + ) + + # Verify all files were created with correct extensions + for fmt in formats: + output_file = os.path.join( + temp_dir, f"test_dataset_{fmt.replace('.', '_')}.{fmt}" + ) + assert os.path.exists(output_file) + + def test_download_empty_chunks(self, temp_dir): + """Test handling of empty chunks in response. + + Parameters + ---------- + temp_dir : str + Temporary directory path. 
+ """ + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"content-length": "1024"} + mock_response.elapsed.total_seconds.return_value = 1.0 + # Include empty chunks (should be skipped) + mock_response.iter_content.return_value = [ + b"x" * 512, + b"", # Empty chunk + b"y" * 512, + b"", # Another empty chunk + ] + + with patch("requests.get", return_value=mock_response): + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=1, + ) + + # File should contain only non-empty chunks + output_file = os.path.join(temp_dir, "test_dataset.tar.gz") + assert os.path.exists(output_file) + assert os.path.getsize(output_file) == 1024 + + def test_download_unknown_size(self, temp_dir): + """Test download when content-length header is missing. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + """ + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {} # No content-length header + mock_response.elapsed.total_seconds.return_value = 0.5 + mock_response.iter_content.return_value = [b"x" * 1024] + + with patch("requests.get", return_value=mock_response): + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=1, + ) + + output_file = os.path.join(temp_dir, "test_dataset.tar.gz") + assert os.path.exists(output_file) + assert os.path.getsize(output_file) == 1024 + + def test_download_ssl_verification_disabled(self, temp_dir, mock_response): + """Test that SSL verification can be disabled. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + mock_response : MagicMock + Mock response object. 
+ """ + mock_response.iter_content.return_value = [b"test content"] + + with patch("requests.get", return_value=mock_response) as mock_get: + download_file_from_link( + file_link="https://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + verify=False, + timeout=60, + retries=1, + ) + + # Verify requests.get was called with verify=False + mock_get.assert_called_once() + assert mock_get.call_args[1]["verify"] is False + + def test_download_custom_timeout(self, temp_dir, mock_response): + """Test that custom timeout is used. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + mock_response : MagicMock + Mock response object. + """ + mock_response.iter_content.return_value = [b"test content"] + + with patch("requests.get", return_value=mock_response) as mock_get: + custom_timeout = 120 # 2 minutes per chunk + download_file_from_link( + file_link="https://github.com/aidos-lab/mantra/releases/download/{version}/2_manifolds.json.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=custom_timeout, + retries=1, + ) + + # Verify requests.get was called with correct timeout + mock_get.assert_called_once() + assert mock_get.call_args[1]["timeout"] == (30, custom_timeout) + diff --git a/topobench/data/datasets/a123.py b/topobench/data/datasets/a123.py index 5f82811a5..7d307761d 100644 --- a/topobench/data/datasets/a123.py +++ b/topobench/data/datasets/a123.py @@ -1,4 +1,17 @@ -"""Dataset class for Auditory Cortex Mouse dataset.""" +""" +Dataset class for the Bowen et al. mouse auditory cortex calcium imaging dataset. + +This script downloads and processes the original dataset introduced in: + +[Citation] Bowen et al. (2024), "Fractured columnar small-world functional network +organization in volumes of L2/3 of mouse auditory cortex," PNAS Nexus, 3(2): pgae074. 
+https://doi.org/10.1093/pnasnexus/pgae074 + +We apply the preprocessing and graph-construction steps defined in this module to obtain +a representation of neuronal activity suitable for our experiments. + +Please cite the original paper when using this dataset or any derivatives. +""" import os import os.path as osp @@ -44,11 +57,11 @@ class A123CortexMDataset(InMemoryDataset): """ URLS: ClassVar = { - "a123_cortex_m": "https://gcell.umd.edu/data/Auditory_cortex_data.zip", + "Auditory cortex data": "https://gcell.umd.edu/data/Auditory_cortex_data.zip", } FILE_FORMAT: ClassVar = { - "a123_cortex_m": "zip", + "Auditory cortex data": "zip", } RAW_FILE_NAMES: ClassVar = {} @@ -156,8 +169,9 @@ def processed_file_names(self) -> str: def download(self) -> None: """Download the dataset from a URL and extract to the raw directory.""" # Download data from the source - self.url = self.URLS[self.name] - self.file_format = self.FILE_FORMAT[self.name] + dataset_key = "Auditory cortex data" + self.url = self.URLS[dataset_key] + self.file_format = self.FILE_FORMAT[dataset_key] # Use self.name as the downloadable dataset name download_file_from_link( @@ -166,6 +180,8 @@ def download(self) -> None: dataset_name=self.name, file_format=self.file_format, verify=False, + timeout=60, # 60 seconds per chunk read timeout + retries=3, # Retry up to 3 times ) # Extract zip file @@ -176,11 +192,16 @@ def download(self) -> None: # Delete zip file os.unlink(path) - # Move files from osp.join(folder, name_download) to folder + # Move files from extracted "Auditory cortex data/" directory to raw_dir downloaded_dir = osp.join(folder, self.name) if osp.exists(downloaded_dir): for file in os.listdir(downloaded_dir): - shutil.move(osp.join(downloaded_dir, file), folder) + src = osp.join(downloaded_dir, file) + dst = osp.join(folder, file) + if osp.isdir(src): + shutil.copytree(src, dst, dirs_exist_ok=True) + else: + shutil.move(src, dst) # Delete the extracted top-level directory 
shutil.rmtree(downloaded_dir) self.data_dir = folder diff --git a/topobench/data/loaders/graph/a123_loader.py b/topobench/data/loaders/graph/a123_loader.py index aaeb83993..ab01d1484 100644 --- a/topobench/data/loaders/graph/a123_loader.py +++ b/topobench/data/loaders/graph/a123_loader.py @@ -1,4 +1,17 @@ -"""A123 dataset loader module.""" +""" +Data loader for the Bowen et al. mouse auditory cortex calcium imaging dataset. + +This script downloads and processes the original dataset introduced in: + +[Citation] Bowen et al. (2024), "Fractured columnar small-world functional network +organization in volumes of L2/3 of mouse auditory cortex," PNAS Nexus, 3(2): pgae074. +https://doi.org/10.1093/pnasnexus/pgae074 + +We apply the preprocessing and graph-construction steps defined in this module to obtain +a representation of neuronal activity suitable for our experiments. + +Please cite the original paper when using this dataset or any derivatives. +""" import torch from omegaconf import DictConfig @@ -76,8 +89,9 @@ def load_dataset(self) -> torch.utils.data.Dataset: # determine dataset name from parameters, fallback to expected id name = self.parameters.data_name - # root path for dataset: use the root_data_dir (Path) as string - root = str(self.root_data_dir) + # root path for dataset: use the parent of root_data_dir since the dataset + # constructs its own subdirectory based on name + root = str(self.root_data_dir.parent) # Construct dataset; A123CortexMDataset expects (root, name, parameters) self.dataset = A123CortexMDataset( diff --git a/topobench/data/utils/io_utils.py b/topobench/data/utils/io_utils.py index 4aca1ddc5..5839b1db9 100644 --- a/topobench/data/utils/io_utils.py +++ b/topobench/data/utils/io_utils.py @@ -5,6 +5,7 @@ import os import os.path as osp import pickle +import time from urllib.parse import parse_qs, urlparse import numpy as np @@ -87,10 +88,19 @@ def download_file_from_drive( def download_file_from_link( - file_link, path_to_save, 
dataset_name, file_format="tar.gz" + file_link, + path_to_save, + dataset_name, + file_format="tar.gz", + verify=True, + timeout=None, + retries=3, ): """Download a file from a link and saves it to the specified path. + Uses streaming with chunked download and includes retry logic for + resilience against network interruptions. + Parameters ---------- file_link : str @@ -101,20 +111,132 @@ def download_file_from_link( The name of the dataset. file_format : str, optional The format of the downloaded file. Defaults to "tar.gz". + verify : bool, optional + Whether to verify SSL certificates. Defaults to True. + timeout : float, optional + Timeout in seconds per chunk read (not for entire download). For very slow + servers, increase this value. Default: 60 seconds per chunk. + retries : int, optional + Number of retry attempts if download fails. Defaults to 3. Raises ------ None """ - response = requests.get(file_link, verify=False) - + # Ensure output directory exists + os.makedirs(path_to_save, exist_ok=True) output_path = f"{path_to_save}/{dataset_name}.{file_format}" - if response.status_code == 200: - with open(output_path, "wb") as f: - f.write(response.content) - print("Download complete.") - else: - print("Failed to download the file.") + + # Default timeout: 60 seconds per chunk read (for very slow servers) + if timeout is None: + timeout = 60 + + for attempt in range(retries): + try: + print( + f"[Download] Starting download from: {file_link} (attempt {attempt + 1}/{retries})" + ) + + # Use tuple (connect_timeout, read_timeout) for proper streaming + response = requests.get( + file_link, + verify=verify, + stream=True, # Force streaming for chunked download + timeout=( + 30, + timeout, + ), # (connect timeout, read timeout per chunk) + ) + + if response.status_code != 200: + print( + f"[Download] Failed to download the file. 
HTTP {response.status_code}" + ) + return + + # Streaming download with progress reporting + total_size = int(response.headers.get("content-length", 0)) + downloaded = 0 + start_time = time.time() + + if total_size > 0: + print( + f"[Download] Total file size: {total_size / (1024**3):.2f} GB" + ) + else: + print("[Download] Total file size: unknown") + + # Stream download in chunks + chunk_size = 5 * 1024 * 1024 # 5MB chunks for faster throughput + progress_interval = ( + 10 * 1024 * 1024 + ) # Report progress every 10MB (for slow connections) + last_reported = 0 + + with open(output_path, "wb") as f: + for chunk in response.iter_content( + chunk_size=chunk_size, decode_unicode=False + ): + if chunk: + f.write(chunk) + f.flush() # Ensure data is written to disk + downloaded += len(chunk) + + # Print progress every 10MB + if ( + total_size > 0 + and (downloaded - last_reported) + >= progress_interval + ): + percent = (downloaded / total_size) * 100 + remaining = total_size - downloaded + elapsed_time = time.time() - start_time + speed_mbps = (downloaded / (1024**2)) / ( + elapsed_time + 0.001 + ) + + # Calculate ETA + if speed_mbps > 0: + eta_seconds = ( + remaining / (1024**2) / speed_mbps + ) + eta_hours = eta_seconds / 3600 + eta_minutes = (eta_seconds % 3600) / 60 + eta_str = ( + f"{eta_hours:.0f}h {eta_minutes:.0f}m" + ) + else: + eta_str = "calculating..." + + print( + f"[Download] {downloaded / (1024**3):.2f} / {total_size / (1024**3):.2f} GB ({percent:.1f}%) | Speed: {speed_mbps:.2f} MB/s | ETA: {eta_str}" + ) + last_reported = downloaded + + print(f"[Download] Download complete! 
Saved to: {output_path}") + break + + except ( + requests.exceptions.Timeout, + requests.exceptions.ConnectionError, + Exception, + ) as e: + print( + f"[Download] Download failed with error: {type(e).__name__}: {str(e)}" + ) + if attempt < retries - 1: + wait_time = 5 * ( + attempt + 1 + ) # Exponential backoff: 5s, 10s, 15s + print( + f"[Download] Retrying in {wait_time} seconds... (attempt {attempt + 2}/{retries})" + ) + time.sleep(wait_time) + else: + print( + f"[Download] Failed after {retries} attempts. Please check your connection and try again." + ) + raise e def read_ndim_manifolds( From 202924600d46516b77cbb337c1df9e06e01792e1 Mon Sep 17 00:00:00 2001 From: marindigen Date: Mon, 24 Nov 2025 16:07:28 +0100 Subject: [PATCH 5/6] Deleted changes related to the challenge. --- configs/dataset/graph/a123.yaml | 57 - topobench/data/datasets/a123.py | 383 ---- topobench/data/loaders/graph/a123_loader.py | 101 - tutorials/tutorial_train_brain_model.ipynb | 1923 ------------------- 4 files changed, 2464 deletions(-) delete mode 100644 configs/dataset/graph/a123.yaml delete mode 100644 topobench/data/datasets/a123.py delete mode 100644 topobench/data/loaders/graph/a123_loader.py delete mode 100644 tutorials/tutorial_train_brain_model.ipynb diff --git a/configs/dataset/graph/a123.yaml b/configs/dataset/graph/a123.yaml deleted file mode 100644 index 9d0d4f3b5..000000000 --- a/configs/dataset/graph/a123.yaml +++ /dev/null @@ -1,57 +0,0 @@ -# Config file for loading Bowen et al. mouse auditory cortex calcium imaging dataset. - -# This script downloads and processes the original dataset introduced in: - -# [Citation] Bowen et al. (2024), "Fractured columnar small-world functional network -# organization in volumes of L2/3 of mouse auditory cortex," PNAS Nexus, 3(2): pgae074. 
-# https://doi.org/10.1093/pnasnexus/pgae074 - -# We apply the preprocessing and graph-construction steps defined in this module to obtain -# a representation of neuronal activity suitable for our experiments. - -# Please cite the original paper when using this dataset or any derivatives. - -# Dataset loader config for A123 Cortex M -loader: - _target_: topobench.data.loaders.A123DatasetLoader - parameters: - data_domain: graph - data_type: A123CortexM - data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} # Use data_dir from dataset config - data_name: ${dataset.parameters.data_name} # Use data_name from dataset config - num_graphs: 10 - is_undirected: True - num_channels: ${dataset.parameters.num_features} # Use num_features for node feature dim - num_classes: ${dataset.parameters.num_classes} # Use num_classes for output dim - task: ${dataset.parameters.task} # Use task type from dataset config - -# Dataset-specific parameters -parameters: - num_features: 3 - num_classes: 9 - task: classification - loss_type: cross_entropy - monitor_metric: accuracy - task_level: graph - data_name: a123_cortex_m - hodge_k: 10 - min_neurons: 3 - corr_threshold: 0.2 - -# Splits -split_params: - learning_setting: inductive - data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} - data_seed: 0 - split_type: random # 'k-fold' or 'random' strategies - k: 10 # for "k-fold" Cross-Validation - train_prop: 0.7 # for "random" strategy splitting - val_prop: 0.15 # for "random" strategy splitting - test_prop: 0.15 # for "random" strategy splitting - -# Dataloader parameters -dataloader_params: - batch_size: 32 - num_workers: 0 - pin_memory: False - diff --git a/topobench/data/datasets/a123.py b/topobench/data/datasets/a123.py deleted file mode 100644 index 7d307761d..000000000 --- a/topobench/data/datasets/a123.py +++ /dev/null @@ -1,383 +0,0 @@ -""" -Dataset class for the Bowen et al. 
mouse auditory cortex calcium imaging dataset. - -This script downloads and processes the original dataset introduced in: - -[Citation] Bowen et al. (2024), "Fractured columnar small-world functional network -organization in volumes of L2/3 of mouse auditory cortex," PNAS Nexus, 3(2): pgae074. -https://doi.org/10.1093/pnasnexus/pgae074 - -We apply the preprocessing and graph-construction steps defined in this module to obtain -a representation of neuronal activity suitable for our experiments. - -Please cite the original paper when using this dataset or any derivatives. -""" - -import os -import os.path as osp -import shutil -from typing import ClassVar - -import numpy as np -import pandas as pd -import scipy.io -import torch -from omegaconf import DictConfig -from torch_geometric.data import Data, InMemoryDataset, extract_zip -from torch_geometric.io import fs -from torch_geometric.utils import to_undirected - -from topobench.data.utils import download_file_from_link -from topobench.data.utils.io_utils import collect_mat_files, process_mat - - -class A123CortexMDataset(InMemoryDataset): - """A1 and A2/3 mouse auditory cortex dataset. - - Loads neural correlation data from mouse auditory cortex regions. - - Parameters - ---------- - root : str - Root directory where the dataset will be saved. - name : str - Name of the dataset. - parameters : DictConfig - Configuration parameters for the dataset including corr_threshold, - n_bins, min_neurons, and hodge_k. - - Attributes - ---------- - URLS : dict - Dictionary containing the URLs for downloading the dataset. - FILE_FORMAT : dict - Dictionary containing the file formats for the dataset. - RAW_FILE_NAMES : dict - Dictionary containing the raw file names for the dataset. 
- """ - - URLS: ClassVar = { - "Auditory cortex data": "https://gcell.umd.edu/data/Auditory_cortex_data.zip", - } - - FILE_FORMAT: ClassVar = { - "Auditory cortex data": "zip", - } - - RAW_FILE_NAMES: ClassVar = {} - - def __init__( - self, - root: str, - name: str, - parameters: DictConfig, - ) -> None: - self.name = name - self.parameters = parameters - # defensive parameter access with sensible defaults - try: - self.corr_threshold = float(parameters.get("corr_threshold", 0.2)) - except Exception: - self.corr_threshold = float( - getattr(parameters, "corr_threshold", 0.2) - ) - - try: - self.n_bins = int(parameters.get("n_bins", 9)) - except Exception: - self.n_bins = int(getattr(parameters, "n_bins", 9)) - - try: - self.min_neurons = int(parameters.get("min_neurons", 8)) - except Exception: - self.min_neurons = int(getattr(parameters, "min_neurons", 8)) - - # optional parameter controlling how many eigenvalues to compute for Hodge L1 - try: - self.hodge_k = int(parameters.get("hodge_k", 6)) - except Exception: - self.hodge_k = int(getattr(parameters, "hodge_k", 6)) - - self.session_map = {} - super().__init__( - root, - ) - - out = fs.torch_load(self.processed_paths[0]) - assert len(out) == 3 or len(out) == 4 - if len(out) == 3: # Backward compatibility. - data, self.slices, self.sizes = out - data_cls = Data - else: - data, self.slices, self.sizes, data_cls = out - - if not isinstance(data, dict): # Backward compatibility. - self.data = data - else: - self.data = data_cls.from_dict(data) - - # For this dataset we don't assume the internal _data is a torch_geometric Data - # (this dataset exposes helper methods to construct subgraphs on demand). - - def __repr__(self) -> str: - return f"{self.name}(self.root={self.root}, self.name={self.name}, self.parameters={self.parameters}, self.force_reload={self.force_reload})" - - @property - def raw_dir(self) -> str: - """Path to the raw directory of the dataset. - - Returns - ------- - str - Path to the raw directory. 
- """ - return osp.join(self.root, self.name, "raw") - - @property - def processed_dir(self) -> str: - """Path to the processed directory of the dataset. - - Returns - ------- - str - Path to the processed directory. - """ - return osp.join(self.root, self.name, "processed") - - @property - def raw_file_names(self) -> list[str]: - """Return the raw file names for the dataset. - - Returns - ------- - list[str] - List of raw file names. - """ - return ["Auditory cortex data/"] - - @property - def processed_file_names(self) -> str: - """Return the processed file name for the dataset. - - Returns - ------- - str - Processed file name. - """ - return "data.pt" - - def download(self) -> None: - """Download the dataset from a URL and extract to the raw directory.""" - # Download data from the source - dataset_key = "Auditory cortex data" - self.url = self.URLS[dataset_key] - self.file_format = self.FILE_FORMAT[dataset_key] - - # Use self.name as the downloadable dataset name - download_file_from_link( - file_link=self.url, - path_to_save=self.raw_dir, - dataset_name=self.name, - file_format=self.file_format, - verify=False, - timeout=60, # 60 seconds per chunk read timeout - retries=3, # Retry up to 3 times - ) - - # Extract zip file - folder = self.raw_dir - filename = f"{self.name}.{self.file_format}" - path = osp.join(folder, filename) - extract_zip(path, folder) - # Delete zip file - os.unlink(path) - - # Move files from extracted "Auditory cortex data/" directory to raw_dir - downloaded_dir = osp.join(folder, self.name) - if osp.exists(downloaded_dir): - for file in os.listdir(downloaded_dir): - src = osp.join(downloaded_dir, file) - dst = osp.join(folder, file) - if osp.isdir(src): - shutil.copytree(src, dst, dirs_exist_ok=True) - else: - shutil.move(src, dst) - # Delete the extracted top-level directory - shutil.rmtree(downloaded_dir) - self.data_dir = folder - - @staticmethod - def extract_samples(data_dir: str, n_bins: int, min_neurons: int = 8): - """Extract 
subgraph samples from raw .mat files. - - Parameters - ---------- - data_dir : str - Directory containing the raw .mat files. - n_bins : int - Number of frequency bins to use for binning. - min_neurons : int, optional - Minimum number of neurons required per sample. Defaults to 8. - - Returns - ------- - pd.DataFrame - DataFrame containing extracted samples with columns for - session_file, session_id, layer, bf_bin, neuron_indices, - corr, and noise_corr. - """ - mat_files = collect_mat_files(data_dir) - - samples = [] - session_id = 0 - for f in mat_files: - print(f"Processing session {session_id}: {os.path.basename(f)}") - mt = process_mat(scipy.io.loadmat(f)) - for layer in range(1, 6): - scorrs = np.array(mt["selectZCorrInfo"]["SigCorrs"]) - ncorrs = np.array(mt["selectZCorrInfo"]["NoiseCorrsTrial"]) - bfvals = np.array(mt["BFInfo"][layer]["BFval"]).ravel() - if scorrs.size == 0 or bfvals.size == 0: - continue - - bin_ids = bfvals.astype(int) - - for bin_idx in range(n_bins): - sel = np.where(bin_ids == bin_idx)[0] - if len(sel) < min_neurons: - continue - subcorr = scorrs[np.ix_(sel, sel)] - samples.append( - { - "session_file": f, - "session_id": session_id, - "layer": layer, - "bf_bin": int(bin_idx), - "neuron_indices": sel.tolist(), - "corr": subcorr.astype(float), - "noise_corr": ncorrs[np.ix_(sel, sel)].astype( - float - ), - } - ) - session_id += 1 - - samples = pd.DataFrame(samples) - return samples - - def _sample_to_pyg_data( - self, sample: dict, threshold: float = 0.2 - ) -> Data: - """Convert a sample dictionary to a PyTorch Geometric Data object. - - Converts correlation matrices to graph representation with node features - and edges for graph-level classification tasks. - - Parameters - ---------- - sample : dict - Sample dictionary containing 'corr', 'noise_corr', 'session_id', - 'layer', and 'bf_bin' keys. - threshold : float, optional - Correlation threshold for creating edges. Defaults to 0.2. 
- - Returns - ------- - torch_geometric.data.Data - Data object with node features [mean_corr, std_corr, noise_diag], - edges from thresholded correlation, and label y as integer bf_bin. - """ - corr = np.asarray(sample.get("corr")) - if corr.ndim != 2 or corr.size == 0: - # empty placeholder graph - x = torch.zeros((0, 3), dtype=torch.float) - edge_index = torch.empty((2, 0), dtype=torch.long) - edge_attr = torch.empty((0, 1), dtype=torch.float) - else: - n = corr.shape[0] - # sanitize - corr = np.nan_to_num(corr) - - mean_corr = corr.mean(axis=1) - std_corr = corr.std(axis=1) - noise_diag = np.zeros(n) - if "noise_corr" in sample and sample["noise_corr"] is not None: - nc = np.asarray(sample["noise_corr"]) - if nc.shape == corr.shape: - noise_diag = np.diag(nc) - - x_np = np.vstack([mean_corr, std_corr, noise_diag]).T - x = torch.tensor(x_np, dtype=torch.float) - - # build edges from thresholded correlation (upper triangle) - adj = (corr >= threshold).astype(int) - iu = np.triu_indices(n, k=1) - sel = np.where(adj[iu] == 1)[0] - if sel.size == 0: - edge_index = torch.empty((2, 0), dtype=torch.long) - edge_attr = torch.empty((0, 1), dtype=torch.float) - else: - rows = iu[0][sel] - cols = iu[1][sel] - edge_index_np = np.vstack([rows, cols]) - edge_index = torch.tensor(edge_index_np, dtype=torch.long) - # make undirected - edge_index = to_undirected(edge_index) - # edge_attr: corresponding corr weights (for both directions, if made undirected) - weights = corr[rows, cols] - weights = ( - np.repeat(weights, 2) - if edge_index.size(1) == weights.size * 2 - else weights - ) - edge_attr = torch.tensor( - weights.reshape(-1, 1), dtype=torch.float - ) - - y = torch.tensor([int(sample.get("bf_bin", -1))], dtype=torch.long) - data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y) - # attach metadata - data.session_id = int(sample.get("session_id", -1)) - data.layer = int(sample.get("layer", -1)) - return data - - def process(self) -> None: - """Generate raw files 
into collated PyG dataset and save to disk. - - This implementation mirrors other datasets in the repo: it calls the - static helper `extract_samples()` to enumerate subgraphs, converts each - to a `torch_geometric.data.Data` object via `_sample_to_pyg_data()`, - optionally computes/attaches topology vectors, collates and saves. - """ - data_dir = self.raw_dir - - print(f"[A123] Processing dataset from: {data_dir}") - print(f"[A123] Files in raw_dir: {os.listdir(data_dir)}") - - # extract sample descriptions - print("[A123] Starting extract_samples()...") - samples = A123CortexMDataset.extract_samples( - data_dir, self.n_bins, self.min_neurons - ) - - print(f"[A123] Extracted {len(samples)} samples") - - data_list = [] - for idx, (_, s) in enumerate(samples.iterrows()): - if idx % 100 == 0: - print( - f"[A123] Converting sample {idx}/{len(samples)} to PyG Data..." - ) - d = self._sample_to_pyg_data(s, threshold=self.corr_threshold) - data_list.append(d) - - # collate and save processed dataset - print(f"[A123] Collating {len(data_list)} samples...") - self.data, self.slices = self.collate(data_list) - self._data_list = None - print(f"[A123] Saving processed data to {self.processed_paths[0]}...") - fs.torch_save( - (self._data.to_dict(), self.slices, {}, self._data.__class__), - self.processed_paths[0], - ) - print("[A123] Processing complete!") diff --git a/topobench/data/loaders/graph/a123_loader.py b/topobench/data/loaders/graph/a123_loader.py deleted file mode 100644 index ab01d1484..000000000 --- a/topobench/data/loaders/graph/a123_loader.py +++ /dev/null @@ -1,101 +0,0 @@ -""" -Data loader for the Bowen et al. mouse auditory cortex calcium imaging dataset. - -This script downloads and processes the original dataset introduced in: - -[Citation] Bowen et al. (2024), "Fractured columnar small-world functional network -organization in volumes of L2/3 of mouse auditory cortex," PNAS Nexus, 3(2): pgae074. 
-https://doi.org/10.1093/pnasnexus/pgae074 - -We apply the preprocessing and graph-construction steps defined in this module to obtain -a representation of neuronal activity suitable for our experiments. - -Please cite the original paper when using this dataset or any derivatives. -""" - -import torch -from omegaconf import DictConfig - -from topobench.data.datasets.a123 import A123CortexMDataset -from topobench.data.loaders.base import AbstractLoader - - -class A123DatasetLoader(AbstractLoader): - """Loader for A123 mouse auditory cortex dataset. - - Implements the AbstractLoader interface: accepts a DictConfig `parameters` - and implements `load_dataset()` which returns a dataset object. - - Parameters - ---------- - parameters : DictConfig - Configuration parameters for the dataset. - **overrides - Additional keyword arguments to override parameters. - """ - - def __init__(self, parameters: DictConfig, **overrides): - """Initialize the A123 dataset loader. - - Parameters - ---------- - parameters : DictConfig - Configuration parameters for the dataset. - **overrides - Additional keyword arguments to override parameters. - """ - # Initialize AbstractLoader (sets self.parameters and self.root_data_dir) - super().__init__(parameters) - - # hyperparameters can come from the DictConfig or be passed as overrides - params = parameters if parameters is not None else {} - - def _get(k, default): - """Get parameter value from DictConfig or overrides. - - Parameters - ---------- - k : str - Parameter key. - default : Any - Default value if key not found. - - Returns - ------- - Any - Parameter value from DictConfig or overrides, or default. 
- """ - try: - return params.get(k, overrides.get(k, default)) - except Exception: - # DictConfig may use attribute access - return getattr(params, k, overrides.get(k, default)) - - self.batch_size = int(_get("batch_size", 32)) - # dataset will be created when load_dataset() is called - self.dataset = None - - def load_dataset(self) -> torch.utils.data.Dataset: - """Instantiate and return the underlying dataset. - - Returns a `A123CortexMDataset` instance constructed from the loader's - parameters and root data directory. - - Returns - ------- - torch.utils.data.Dataset - A123CortexMDataset instance. - """ - # determine dataset name from parameters, fallback to expected id - name = self.parameters.data_name - - # root path for dataset: use the parent of root_data_dir since the dataset - # constructs its own subdirectory based on name - root = str(self.root_data_dir.parent) - - # Construct dataset; A123CortexMDataset expects (root, name, parameters) - self.dataset = A123CortexMDataset( - root=root, name=name, parameters=self.parameters - ) - - return self.dataset diff --git a/tutorials/tutorial_train_brain_model.ipynb b/tutorials/tutorial_train_brain_model.ipynb deleted file mode 100644 index 6eb393479..000000000 --- a/tutorials/tutorial_train_brain_model.ipynb +++ /dev/null @@ -1,1923 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "af53c476", - "metadata": {}, - "source": [ - "# Training TBModel on Auditory Cortex Data for 1 and 2/3 regions.\n", - "\n", - "This notebook demonstrates loading the MUTAG dataset, applying a simple lifting, defining a small backbone, and training a `TBModel` using `TBLoss` and `TBOptimizer`.\n", - "\n", - "Requirements: the project installed in PYTHONPATH and optional dependencies (torch_geometric, networkx, ripser/persim) if you want advanced features." 
- ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "98d0adae", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "os.chdir('..')" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "f9ed7f5f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Imports OK\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/mariayuffa/anaconda3/envs/tb/lib/python3.11/site-packages/outdated/__init__.py:36: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", - " from pkg_resources import parse_version\n" - ] - } - ], - "source": [ - "# 1) Imports\n", - "import torch\n", - "import lightning as pl\n", - "from omegaconf import OmegaConf\n", - "\n", - "# Data loading / preprocessing utilities from the repo\n", - "from topobench.data.loaders.graph.a123_loader import A123DatasetLoader\n", - "from topobench.dataloader.dataloader import TBDataloader\n", - "from topobench.data.preprocessor import PreProcessor\n", - "\n", - "# Model / training building blocks\n", - "from topobench.model.model import TBModel\n", - "# example backbone building block (SCN2 is optional; we provide a tiny custom backbone below)\n", - "# from topomodelx.nn.simplicial.scn2 import SCN2\n", - "from topobench.nn.wrappers.simplicial import SCNWrapper\n", - "from topobench.nn.encoders import AllCellFeatureEncoder\n", - "from topobench.nn.readouts import PropagateSignalDown\n", - "\n", - "# Optimization / evaluation\n", - "from topobench.loss.loss import TBLoss\n", - "from topobench.optimizer import TBOptimizer\n", - "from topobench.evaluator.evaluator import TBEvaluator\n", - "\n", - "print('Imports OK')" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "03042d76", - "metadata": {}, - 
"outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Configs created\n" - ] - } - ], - "source": [ - "# 2) Configurations and utilities\n", - "loader_config = {\n", - " 'data_domain': 'graph',\n", - " 'data_type': 'A123',\n", - " # the loader/dataset expects the dataset name key used in the dataset class\n", - " 'data_name': 'a123_cortex_m',\n", - " 'data_dir': './data/a123/'\n", - "}\n", - "\n", - "# Transform config: single transform with transform_name and transform_type\n", - "# PreProcessor expects either {\"transform_name\": ...} (single) or {\"key1\": {...}, \"key2\": {...}} (multiple)\n", - "transform_config = {\n", - " 'transform_type': 'lifting',\n", - " 'transform_name': 'HypergraphKHopLifting',\n", - " 'k_value': 1,\n", - "}\n", - "\n", - "split_config = {\n", - " 'learning_setting': 'inductive',\n", - " 'split_type': 'random',\n", - " 'data_seed': 0,\n", - " 'data_split_dir': './data/a123/splits/',\n", - " 'train_prop': 0.5,\n", - "}\n", - "\n", - "# model / task hyperparameters\n", - "# A123 sample node features are: [mean_corr, std_corr, noise_diag] => 3 channels\n", - "in_channels = 3\n", - "# Multiclass classification: 9 frequency bins (bf_bin 0-8)\n", - "out_channels = 9\n", - "dim_hidden = 16\n", - "n_bins = 9 # default binning from extract_samples\n", - "\n", - "readout_config = {\n", - " 'readout_name': 'PropagateSignalDown',\n", - " 'num_cell_dimensions': 1,\n", - " 'hidden_dim': dim_hidden,\n", - " 'out_channels': out_channels,\n", - " 'task_level': 'graph',\n", - " 'pooling_type': 'sum',\n", - "}\n", - "\n", - "loss_config = {\n", - " 'dataset_loss': {\n", - " 'task': 'classification',\n", - " 'loss_type': 'cross_entropy',\n", - " }\n", - "}\n", - "\n", - "evaluator_config = {\n", - " 'task': 'classification',\n", - " 'num_classes': out_channels,\n", - " 'metrics': ['accuracy', 'precision', 'recall'],\n", - "}\n", - "\n", - "optimizer_config = {\n", - " 'optimizer_id': 'Adam',\n", - " 'parameters': {'lr': 0.001, 
'weight_decay': 0.0005},\n", - "}\n", - "\n", - "# convert to OmegaConf (the project often expects DictConfig)\n", - "loader_config = OmegaConf.create(loader_config)\n", - "transform_config = OmegaConf.create(transform_config)\n", - "split_config = OmegaConf.create(split_config)\n", - "readout_config = OmegaConf.create(readout_config)\n", - "loss_config = OmegaConf.create(loss_config)\n", - "evaluator_config = OmegaConf.create(evaluator_config)\n", - "optimizer_config = OmegaConf.create(optimizer_config)\n", - "\n", - "print('Configs created')" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "06a33ac7", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Dataset loaded\n", - "Transform parameters are the same, using existing data_dir: data/a123/a123_cortex_m/transform_type_transform_name_k_value/563224662\n", - "Dataset splits created\n", - "Datasets and datamodule ready\n" - ] - } - ], - "source": [ - "# 3) Loading the data\n", - "\n", - "# Use the A123-specific loader (A123DatasetLoader) to construct the dataset\n", - "graph_loader = A123DatasetLoader(loader_config)\n", - "\n", - "dataset, dataset_dir = graph_loader.load()\n", - "print('Dataset loaded')\n", - "\n", - "preprocessor = PreProcessor(dataset, dataset_dir, transform_config)\n", - "dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)\n", - "print('Dataset splits created')\n", - "\n", - "# create the TopoBench datamodule / dataloader wrappers\n", - "datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)\n", - "\n", - "print('Datasets and datamodule ready')" - ] - }, - { - "cell_type": "markdown", - "id": "3b7bc4a8", - "metadata": {}, - "source": [ - "## 4) Backbone definition\n", - "\n", - "We implement a tiny backbone as a `pl.LightningModule` which computes node and hyperedge features: $X_1 = B_1 dot X_0$ and applies two linear layers with ReLU." 
- ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "9275c748", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Backbone defined\n" - ] - } - ], - "source": [ - "class MyBackbone(pl.LightningModule):\n", - " def __init__(self, dim_hidden):\n", - " super().__init__()\n", - " self.linear_0 = torch.nn.Linear(dim_hidden, dim_hidden)\n", - " self.linear_1 = torch.nn.Linear(dim_hidden, dim_hidden)\n", - "\n", - " def forward(self, batch):\n", - " # batch.x_0: node features (dense tensor of shape [N, dim_hidden])\n", - " # batch.incidence_hyperedges: sparse incidence matrix with shape [m, n] or [n, m] depending on preprocessor convention\n", - " x_0 = batch.x_0\n", - " incidence_hyperedges = getattr(batch, 'incidence_hyperedges', None)\n", - " if incidence_hyperedges is None:\n", - " # fallback: try incidence as batch.incidence if available\n", - " incidence_hyperedges = getattr(batch, 'incidence', None)\n", - "\n", - " # compute hyperedge features X_1 = B_1 dot X_0 (we assume B_1 is sparse and transposed appropriately)\n", - " x_1 = None\n", - " if incidence_hyperedges is not None:\n", - " try:\n", - " x_1 = torch.sparse.mm(incidence_hyperedges, x_0)\n", - " except Exception:\n", - " # if orientation differs, try transpose\n", - " x_1 = torch.sparse.mm(incidence_hyperedges.T, x_0)\n", - " else:\n", - " # no incidence available: create a zero hyperedge feature placeholder\n", - " x_1 = torch.zeros_like(x_0)\n", - "\n", - " x_0 = self.linear_0(x_0)\n", - " x_0 = torch.relu(x_0)\n", - "\n", - " x_1 = self.linear_1(x_1)\n", - " x_1 = torch.relu(x_1)\n", - "\n", - " model_out = {'labels': batch.y, 'batch_0': getattr(batch, 'batch_0', None)}\n", - " model_out['x_0'] = x_0\n", - " model_out['hyperedge'] = x_1\n", - " return model_out\n", - "\n", - "print('Backbone defined')" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "489bea60", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - 
"output_type": "stream", - "text": [ - "Components instantiated\n" - ] - } - ], - "source": [ - "# 5) Model initialization (components)\n", - "backbone = MyBackbone(dim_hidden)\n", - "readout = PropagateSignalDown(**readout_config)\n", - "loss = TBLoss(**loss_config)\n", - "feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels], out_channels=dim_hidden)\n", - "evaluator = TBEvaluator(**evaluator_config)\n", - "optimizer = TBOptimizer(**optimizer_config)\n", - "\n", - "print('Components instantiated')" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "366a4200", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TBModel(backbone=MyBackbone(\n", - " (linear_0): Linear(in_features=16, out_features=16, bias=True)\n", - " (linear_1): Linear(in_features=16, out_features=16, bias=True)\n", - "), readout=PropagateSignalDown(num_cell_dimensions=0, self.hidden_dim=16, readout_name=PropagateSignalDown, loss=TBLoss(losses=[DatasetLoss(task=classification, loss_type=cross_entropy)]), feature_encoder=AllCellFeatureEncoder(in_channels=[3], out_channels=16, dimensions=range(0, 1)))\n" - ] - } - ], - "source": [ - "# 6) Instantiate TBModel\n", - "model = TBModel(backbone=backbone,\n", - " backbone_wrapper=None,\n", - " readout=readout,\n", - " loss=loss,\n", - " feature_encoder=feature_encoder,\n", - " evaluator=evaluator,\n", - " optimizer=optimizer,\n", - " compile=False)\n", - "\n", - "# Print a short summary (repr) to verify construction\n", - "print(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "a81da250", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "GPU available: True (mps), used: False\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", - "/Users/mariayuffa/anaconda3/envs/tb/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. 
You can set it by doing `Trainer(accelerator='gpu')`.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ab180005f81b4c84b0c4f6c3f0d2eb53", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Sanity Checking: | | 0/? [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃ Test metric DataLoader 0 ┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ test/accuracy 0.1269841343164444 │\n", - "│ test/loss 2.125241279602051 │\n", - "│ test/precision 0.05082417652010918 │\n", - "│ test/recall 0.125 │\n", - "└───────────────────────────┴───────────────────────────┘\n", - "\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│\u001b[36m \u001b[0m\u001b[36m test/accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.1269841343164444 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 2.125241279602051 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/precision \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.05082417652010918 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/recall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.125 \u001b[0m\u001b[35m \u001b[0m│\n", - "└───────────────────────────┴───────────────────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Test metrics:\n", - "test/loss 2.1252\n", - "test/accuracy 0.1270\n", - "test/precision 0.0508\n", - "test/recall 0.1250\n" - ] - } - ], - "source": [ - "# 8) Testing and printing 
metrics\n", - "trainer.test(model, datamodule)\n", - "test_metrics = trainer.callback_metrics\n", - "print('\\nTest metrics:')\n", - "for key, val in test_metrics.items():\n", - " try:\n", - " print(f'{key:25s} {float(val):.4f}')\n", - " except Exception:\n", - " print(key, val)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "tb", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From aa3b2b198ef3dd3e892daf617629a2f8556f0c22 Mon Sep 17 00:00:00 2001 From: marindigen Date: Mon, 24 Nov 2025 22:54:43 +0100 Subject: [PATCH 6/6] Moved the test class to the appropriate file and removed the redundant one --- test/data/utils/test_io_utils.py | 335 ++++++++++++++++++++++++++++++ test/utils/test_io_utils.py | 342 ------------------------------- topobench/data/utils/io_utils.py | 41 +++- 3 files changed, 375 insertions(+), 343 deletions(-) delete mode 100644 test/utils/test_io_utils.py diff --git a/test/data/utils/test_io_utils.py b/test/data/utils/test_io_utils.py index 85c09e9c8..8c6781ef1 100644 --- a/test/data/utils/test_io_utils.py +++ b/test/data/utils/test_io_utils.py @@ -1,5 +1,9 @@ """Tests for the io_utils module.""" +import os +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch import pytest from topobench.data.utils.io_utils import * @@ -20,3 +24,334 @@ def test_get_file_id_from_url(): with pytest.raises(ValueError): get_file_id_from_url(url_wrong) + + +class TestDownloadFileFromLink: + """Test suite for download_file_from_link function.""" + + @pytest.fixture + def temp_dir(self): + """Create temporary directory for test outputs. + + Returns + ------- + str + Path to temporary directory. 
+ """ + with tempfile.TemporaryDirectory() as tmpdir: + yield tmpdir + + @pytest.fixture + def mock_response(self): + """Create mock response object. + + Returns + ------- + MagicMock + Mock response object with status code and headers. + """ + response = MagicMock() + response.status_code = 200 + response.headers = {"content-length": "5242880"} # 5 MB + response.elapsed.total_seconds.return_value = 1.0 + return response + + def test_download_success_with_progress(self, temp_dir, mock_response): + """Test successful download with progress reporting. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + mock_response : MagicMock + Mock response object. + """ + # Setup mock chunks (5MB total in 1MB chunks) + chunk_data = [b"x" * (1024 * 1024) for _ in range(5)] + mock_response.iter_content.return_value = chunk_data + + with patch("requests.get", return_value=mock_response): + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=1, + ) + + # Verify file was created and has correct size + output_file = os.path.join(temp_dir, "test_dataset.tar.gz") + assert os.path.exists(output_file) + assert os.path.getsize(output_file) == 5 * 1024 * 1024 + + def test_download_creates_directory_if_not_exists(self, temp_dir): + """Test that download creates directory structure. + + Parameters + ---------- + temp_dir : str + Temporary directory path. 
+ """ + nested_dir = os.path.join(temp_dir, "nested", "path") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"content-length": "1024"} + mock_response.elapsed.total_seconds.return_value = 0.5 + mock_response.iter_content.return_value = [b"x" * 1024] + + with patch("requests.get", return_value=mock_response): + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=nested_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=1, + ) + + output_file = os.path.join(nested_dir, "test_dataset.tar.gz") + assert os.path.exists(output_file) + assert os.path.isdir(nested_dir) + + def test_download_http_error(self, temp_dir): + """Test handling of HTTP error responses. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + """ + mock_response = MagicMock() + mock_response.status_code = 404 + + with patch("requests.get", return_value=mock_response): + download_file_from_link( + file_link="http://example.com/nonexistent.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=1, + ) + + # File should not be created on HTTP error + output_file = os.path.join(temp_dir, "test_dataset.tar.gz") + assert not os.path.exists(output_file) + + def test_download_timeout_retry(self, temp_dir): + """Test retry logic on timeout. + + Parameters + ---------- + temp_dir : str + Temporary directory path. 
+ """ + import requests + + with patch("requests.get") as mock_get: + # First call times out, second succeeds + mock_response_success = MagicMock() + mock_response_success.status_code = 200 + mock_response_success.headers = {"content-length": "1024"} + mock_response_success.elapsed.total_seconds.return_value = 0.5 + mock_response_success.iter_content.return_value = [b"x" * 1024] + + mock_get.side_effect = [ + requests.exceptions.Timeout("Connection timed out"), + mock_response_success, + ] + + with patch("time.sleep"): # Mock sleep to speed up test + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=3, + ) + + # File should be created on successful retry + output_file = os.path.join(temp_dir, "test_dataset.tar.gz") + assert os.path.exists(output_file) + assert mock_get.call_count == 2 + + def test_download_exhausts_retries(self, temp_dir): + """Test that exception is raised after all retries exhausted. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + """ + import requests + + with patch("requests.get") as mock_get: + mock_get.side_effect = requests.exceptions.Timeout( + "Connection timed out" + ) + + with patch("time.sleep"): + with pytest.raises(requests.exceptions.Timeout): + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=2, + ) + + # Verify retries were attempted + assert mock_get.call_count == 2 + + def test_download_with_different_formats(self, temp_dir, mock_response): + """Test download with different file formats. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + mock_response : MagicMock + Mock response object. 
+ """ + mock_response.iter_content.return_value = [b"test content"] + + formats = ["zip", "tar", "tar.gz"] + + with patch("requests.get", return_value=mock_response): + for fmt in formats: + download_file_from_link( + file_link="http://example.com/dataset", + path_to_save=temp_dir, + dataset_name=f"test_dataset_{fmt.replace('.', '_')}", + file_format=fmt, + timeout=60, + retries=1, + ) + + # Verify all files were created with correct extensions + for fmt in formats: + output_file = os.path.join( + temp_dir, f"test_dataset_{fmt.replace('.', '_')}.{fmt}" + ) + assert os.path.exists(output_file) + + def test_download_empty_chunks(self, temp_dir): + """Test handling of empty chunks in response. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + """ + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"content-length": "1024"} + mock_response.elapsed.total_seconds.return_value = 1.0 + # Include empty chunks (should be skipped) + mock_response.iter_content.return_value = [ + b"x" * 512, + b"", # Empty chunk + b"y" * 512, + b"", # Another empty chunk + ] + + with patch("requests.get", return_value=mock_response): + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=1, + ) + + # File should contain only non-empty chunks + output_file = os.path.join(temp_dir, "test_dataset.tar.gz") + assert os.path.exists(output_file) + assert os.path.getsize(output_file) == 1024 + + def test_download_unknown_size(self, temp_dir): + """Test download when content-length header is missing. + + Parameters + ---------- + temp_dir : str + Temporary directory path. 
+ """ + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {} # No content-length header + mock_response.elapsed.total_seconds.return_value = 0.5 + mock_response.iter_content.return_value = [b"x" * 1024] + + with patch("requests.get", return_value=mock_response): + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=1, + ) + + output_file = os.path.join(temp_dir, "test_dataset.tar.gz") + assert os.path.exists(output_file) + assert os.path.getsize(output_file) == 1024 + + def test_download_ssl_verification_disabled(self, temp_dir, mock_response): + """Test that SSL verification can be disabled. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + mock_response : MagicMock + Mock response object. + """ + mock_response.iter_content.return_value = [b"test content"] + + with patch("requests.get", return_value=mock_response) as mock_get: + download_file_from_link( + file_link="https://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + verify=False, + timeout=60, + retries=1, + ) + + # Verify requests.get was called with verify=False + mock_get.assert_called_once() + assert mock_get.call_args[1]["verify"] is False + + def test_download_custom_timeout(self, temp_dir, mock_response): + """Test that custom timeout is used. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + mock_response : MagicMock + Mock response object. 
+ """ + mock_response.iter_content.return_value = [b"test content"] + + with patch("requests.get", return_value=mock_response) as mock_get: + custom_timeout = 120 # 2 minutes per chunk + download_file_from_link( + file_link="https://github.com/aidos-lab/mantra/releases/download/{version}/2_manifolds.json.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=custom_timeout, + retries=1, + ) + + # Verify requests.get was called with correct timeout + mock_get.assert_called_once() + assert mock_get.call_args[1]["timeout"] == (30, custom_timeout) diff --git a/test/utils/test_io_utils.py b/test/utils/test_io_utils.py deleted file mode 100644 index 39594f201..000000000 --- a/test/utils/test_io_utils.py +++ /dev/null @@ -1,342 +0,0 @@ -"""Tests for data IO utilities.""" - -import os -import tempfile -from pathlib import Path -from unittest.mock import MagicMock, patch - -import pytest - -from topobench.data.utils.io_utils import download_file_from_link - - -class TestDownloadFileFromLink: - """Test suite for download_file_from_link function.""" - - @pytest.fixture - def temp_dir(self): - """Create temporary directory for test outputs. - - Returns - ------- - str - Path to temporary directory. - """ - with tempfile.TemporaryDirectory() as tmpdir: - yield tmpdir - - @pytest.fixture - def mock_response(self): - """Create mock response object. - - Returns - ------- - MagicMock - Mock response object with status code and headers. - """ - response = MagicMock() - response.status_code = 200 - response.headers = {"content-length": "5242880"} # 5 MB - response.elapsed.total_seconds.return_value = 1.0 - return response - - def test_download_success_with_progress(self, temp_dir, mock_response): - """Test successful download with progress reporting. - - Parameters - ---------- - temp_dir : str - Temporary directory path. - mock_response : MagicMock - Mock response object. 
- """ - # Setup mock chunks (5MB total in 1MB chunks) - chunk_data = [b"x" * (1024 * 1024) for _ in range(5)] - mock_response.iter_content.return_value = chunk_data - - with patch("requests.get", return_value=mock_response): - download_file_from_link( - file_link="http://example.com/dataset.tar.gz", - path_to_save=temp_dir, - dataset_name="test_dataset", - file_format="tar.gz", - timeout=60, - retries=1, - ) - - # Verify file was created and has correct size - output_file = os.path.join(temp_dir, "test_dataset.tar.gz") - assert os.path.exists(output_file) - assert os.path.getsize(output_file) == 5 * 1024 * 1024 - - def test_download_creates_directory_if_not_exists(self, temp_dir): - """Test that download creates directory structure. - - Parameters - ---------- - temp_dir : str - Temporary directory path. - """ - nested_dir = os.path.join(temp_dir, "nested", "path") - - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.headers = {"content-length": "1024"} - mock_response.elapsed.total_seconds.return_value = 0.5 - mock_response.iter_content.return_value = [b"x" * 1024] - - with patch("requests.get", return_value=mock_response): - download_file_from_link( - file_link="http://example.com/dataset.tar.gz", - path_to_save=nested_dir, - dataset_name="test_dataset", - file_format="tar.gz", - timeout=60, - retries=1, - ) - - output_file = os.path.join(nested_dir, "test_dataset.tar.gz") - assert os.path.exists(output_file) - assert os.path.isdir(nested_dir) - - def test_download_http_error(self, temp_dir): - """Test handling of HTTP error responses. - - Parameters - ---------- - temp_dir : str - Temporary directory path. 
- """ - mock_response = MagicMock() - mock_response.status_code = 404 - - with patch("requests.get", return_value=mock_response): - download_file_from_link( - file_link="http://example.com/nonexistent.tar.gz", - path_to_save=temp_dir, - dataset_name="test_dataset", - file_format="tar.gz", - timeout=60, - retries=1, - ) - - # File should not be created on HTTP error - output_file = os.path.join(temp_dir, "test_dataset.tar.gz") - assert not os.path.exists(output_file) - - def test_download_timeout_retry(self, temp_dir): - """Test retry logic on timeout. - - Parameters - ---------- - temp_dir : str - Temporary directory path. - """ - import requests - - with patch("requests.get") as mock_get: - # First call times out, second succeeds - mock_response_success = MagicMock() - mock_response_success.status_code = 200 - mock_response_success.headers = {"content-length": "1024"} - mock_response_success.elapsed.total_seconds.return_value = 0.5 - mock_response_success.iter_content.return_value = [b"x" * 1024] - - mock_get.side_effect = [ - requests.exceptions.Timeout("Connection timed out"), - mock_response_success, - ] - - with patch("time.sleep"): # Mock sleep to speed up test - download_file_from_link( - file_link="http://example.com/dataset.tar.gz", - path_to_save=temp_dir, - dataset_name="test_dataset", - file_format="tar.gz", - timeout=60, - retries=3, - ) - - # File should be created on successful retry - output_file = os.path.join(temp_dir, "test_dataset.tar.gz") - assert os.path.exists(output_file) - assert mock_get.call_count == 2 - - def test_download_exhausts_retries(self, temp_dir): - """Test that exception is raised after all retries exhausted. - - Parameters - ---------- - temp_dir : str - Temporary directory path. 
- """ - import requests - - with patch("requests.get") as mock_get: - mock_get.side_effect = requests.exceptions.Timeout( - "Connection timed out" - ) - - with patch("time.sleep"): - with pytest.raises(requests.exceptions.Timeout): - download_file_from_link( - file_link="http://example.com/dataset.tar.gz", - path_to_save=temp_dir, - dataset_name="test_dataset", - file_format="tar.gz", - timeout=60, - retries=2, - ) - - # Verify retries were attempted - assert mock_get.call_count == 2 - - def test_download_with_different_formats(self, temp_dir, mock_response): - """Test download with different file formats. - - Parameters - ---------- - temp_dir : str - Temporary directory path. - mock_response : MagicMock - Mock response object. - """ - mock_response.iter_content.return_value = [b"test content"] - - formats = ["zip", "tar", "tar.gz"] - - with patch("requests.get", return_value=mock_response): - for fmt in formats: - download_file_from_link( - file_link="http://example.com/dataset", - path_to_save=temp_dir, - dataset_name=f"test_dataset_{fmt.replace('.', '_')}", - file_format=fmt, - timeout=60, - retries=1, - ) - - # Verify all files were created with correct extensions - for fmt in formats: - output_file = os.path.join( - temp_dir, f"test_dataset_{fmt.replace('.', '_')}.{fmt}" - ) - assert os.path.exists(output_file) - - def test_download_empty_chunks(self, temp_dir): - """Test handling of empty chunks in response. - - Parameters - ---------- - temp_dir : str - Temporary directory path. 
- """ - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.headers = {"content-length": "1024"} - mock_response.elapsed.total_seconds.return_value = 1.0 - # Include empty chunks (should be skipped) - mock_response.iter_content.return_value = [ - b"x" * 512, - b"", # Empty chunk - b"y" * 512, - b"", # Another empty chunk - ] - - with patch("requests.get", return_value=mock_response): - download_file_from_link( - file_link="http://example.com/dataset.tar.gz", - path_to_save=temp_dir, - dataset_name="test_dataset", - file_format="tar.gz", - timeout=60, - retries=1, - ) - - # File should contain only non-empty chunks - output_file = os.path.join(temp_dir, "test_dataset.tar.gz") - assert os.path.exists(output_file) - assert os.path.getsize(output_file) == 1024 - - def test_download_unknown_size(self, temp_dir): - """Test download when content-length header is missing. - - Parameters - ---------- - temp_dir : str - Temporary directory path. - """ - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.headers = {} # No content-length header - mock_response.elapsed.total_seconds.return_value = 0.5 - mock_response.iter_content.return_value = [b"x" * 1024] - - with patch("requests.get", return_value=mock_response): - download_file_from_link( - file_link="http://example.com/dataset.tar.gz", - path_to_save=temp_dir, - dataset_name="test_dataset", - file_format="tar.gz", - timeout=60, - retries=1, - ) - - output_file = os.path.join(temp_dir, "test_dataset.tar.gz") - assert os.path.exists(output_file) - assert os.path.getsize(output_file) == 1024 - - def test_download_ssl_verification_disabled(self, temp_dir, mock_response): - """Test that SSL verification can be disabled. - - Parameters - ---------- - temp_dir : str - Temporary directory path. - mock_response : MagicMock - Mock response object. 
- """ - mock_response.iter_content.return_value = [b"test content"] - - with patch("requests.get", return_value=mock_response) as mock_get: - download_file_from_link( - file_link="https://example.com/dataset.tar.gz", - path_to_save=temp_dir, - dataset_name="test_dataset", - file_format="tar.gz", - verify=False, - timeout=60, - retries=1, - ) - - # Verify requests.get was called with verify=False - mock_get.assert_called_once() - assert mock_get.call_args[1]["verify"] is False - - def test_download_custom_timeout(self, temp_dir, mock_response): - """Test that custom timeout is used. - - Parameters - ---------- - temp_dir : str - Temporary directory path. - mock_response : MagicMock - Mock response object. - """ - mock_response.iter_content.return_value = [b"test content"] - - with patch("requests.get", return_value=mock_response) as mock_get: - custom_timeout = 120 # 2 minutes per chunk - download_file_from_link( - file_link="https://github.com/aidos-lab/mantra/releases/download/{version}/2_manifolds.json.gz", - path_to_save=temp_dir, - dataset_name="test_dataset", - file_format="tar.gz", - timeout=custom_timeout, - retries=1, - ) - - # Verify requests.get was called with correct timeout - mock_get.assert_called_once() - assert mock_get.call_args[1]["timeout"] == (30, custom_timeout) - diff --git a/topobench/data/utils/io_utils.py b/topobench/data/utils/io_utils.py index 5839b1db9..a20913e8f 100644 --- a/topobench/data/utils/io_utils.py +++ b/topobench/data/utils/io_utils.py @@ -119,9 +119,48 @@ def download_file_from_link( retries : int, optional Number of retry attempts if download fails. Defaults to 3. + Notes + ----- + This function downloads files in 5MB chunks for memory efficiency. Progress is + reported every 10MB. Timeouts apply per chunk, not to the entire download, + making it suitable for very large files and slow connections. + + If a download fails, it retries with linear backoff (5s, 10s, 15s).
+ + Examples + -------- + Basic download: + + >>> from topobench.data.utils import download_file_from_link + >>> download_file_from_link( + ... file_link="https://example.com/dataset.tar.gz", + ... path_to_save="./data/", + ... dataset_name="my_dataset" + ... ) + + Download with custom timeout for slow servers: + + >>> download_file_from_link( + ... file_link="https://slow-server.com/dataset.zip", + ... path_to_save="./data/", + ... dataset_name="my_dataset", + ... file_format="zip", + ... timeout=300 # 5 minutes per chunk + ... ) + + Download with increased retries for unreliable connections: + + >>> download_file_from_link( + ... file_link="https://example.com/dataset.tar.gz", + ... path_to_save="./data/", + ... dataset_name="my_dataset", + ... retries=5 # Try up to 5 times + ... ) + Raises ------ - None + Exception + If download fails after all retry attempts. """ # Ensure output directory exists os.makedirs(path_to_save, exist_ok=True)