From bda6baa6691d35f52440144a79f2abdeb6bf9ee2 Mon Sep 17 00:00:00 2001 From: I745505 Date: Tue, 18 Nov 2025 15:12:30 +0100 Subject: [PATCH 01/32] :card_file_box: IO and Split utils --- topobench/data/utils/io_utils.py | 256 ++++++++++++++++++++++++++++ topobench/data/utils/split_utils.py | 123 ++++++++++++- 2 files changed, 373 insertions(+), 6 deletions(-) diff --git a/topobench/data/utils/io_utils.py b/topobench/data/utils/io_utils.py index 372db85e6..9e2ae8344 100644 --- a/topobench/data/utils/io_utils.py +++ b/topobench/data/utils/io_utils.py @@ -1,15 +1,21 @@ """Data IO utilities.""" import json +import os import os.path as osp import pickle +import tempfile +import warnings +import zipfile from urllib.parse import parse_qs, urlparse +import gdown import numpy as np import pandas as pd import requests import torch import torch_geometric +from pybiomart import Dataset as BioMartDataset from toponetx.classes import SimplicialComplex from torch_geometric.data import Data from torch_sparse import coalesce @@ -50,6 +56,218 @@ def get_file_id_from_url(url): return file_id +def get_folder_id_from_url(url): + """Extract the folder ID from a Google Drive folder URL or return ID if already provided. + + Parameters + ---------- + url : str + The Google Drive folder URL or folder ID. + + Returns + ------- + str + The folder ID extracted from the URL, or the ID itself if already an ID. + + Raises + ------ + ValueError + If the provided string is not a valid Google Drive folder URL or ID. 
+ """ + # If it doesn't look like a URL (no scheme), assume it's already an ID + if "://" not in url and "/" not in url: + return url + + parsed_url = urlparse(url) + query_params = parse_qs(parsed_url.query) + + if "id" in query_params: # Case 1: URL format contains '?id=' + folder_id = query_params["id"][0] + elif ( + "folders/" in parsed_url.path + ): # Case 2: URL format contains '/folders/' + folder_id = parsed_url.path.split("/folders/")[1].split("/")[0] + else: + raise ValueError( + "The provided string is not a valid Google Drive folder URL or ID." + ) + return folder_id + + +def download_file(url, output_path, timeout=30, verify_ssl=True): + """Download a file from URL and save it. + + Parameters + ---------- + url : str + URL of the file to download. + output_path : str + Path where the file will be saved. + timeout : int, optional + Request timeout in seconds. Defaults to 30. + verify_ssl : bool, optional + Whether to verify SSL certificates. Defaults to True. + + Returns + ------- + bool + True if download succeeded, False otherwise. + + Raises + ------ + RuntimeError + If download fails. + """ + if not verify_ssl: + warnings.warn( + "SSL certificate verification is disabled", + UserWarning, + stacklevel=2, + ) + + try: + response = requests.get(url, timeout=timeout, verify=verify_ssl) + response.raise_for_status() + + with open(output_path, "wb") as f: + f.write(response.content) + + return True + + except Exception as e: + raise RuntimeError(f"Failed to download from {url}: {e}") from e + + +def download_and_extract_zip( + url, output_dir, filename_to_extract=None, timeout=30, verify_ssl=True +): + """Download a zip file from URL and extract it. + + Parameters + ---------- + url : str + URL of the zip file to download. + output_dir : str + Directory where files will be extracted. + filename_to_extract : str, optional + If provided, only extract this specific file from the zip. + If None, extract all files. 
+ timeout : int, optional + Request timeout in seconds. Defaults to 30. + verify_ssl : bool, optional + Whether to verify SSL certificates. Defaults to True. + + Returns + ------- + bool + True if download and extraction succeeded, False otherwise. + + Raises + ------ + RuntimeError + If download or extraction fails. + """ + if not verify_ssl: + warnings.warn( + "SSL certificate verification is disabled", + UserWarning, + stacklevel=2, + ) + + try: + response = requests.get(url, timeout=timeout, verify=verify_ssl) + response.raise_for_status() + + # Save to temporary file + with tempfile.NamedTemporaryFile( + delete=False, suffix=".zip" + ) as tmp_file: + tmp_file.write(response.content) + zip_path = tmp_file.name + + # Extract + with zipfile.ZipFile(zip_path, "r") as zip_ref: + if filename_to_extract: + zip_ref.extract(filename_to_extract, output_dir) + else: + zip_ref.extractall(output_dir) + + # Clean up temp file + os.remove(zip_path) + return True + + except Exception as e: + raise RuntimeError( + f"Failed to download and extract from {url}: {e}" + ) from e + + +def download_ensembl_biomart_mapping( + output_path, + dataset="hsapiens_gene_ensembl", + attributes=None, + id_prefix="9606.", + timeout=120, +): + """Download ID mappings from Ensembl BioMart using pybiomart library. + + Note: Requires 'pybiomart' package. Install with: pip install pybiomart + + Parameters + ---------- + output_path : str + Path where the mapping file will be saved. + dataset : str, optional + BioMart dataset name. Defaults to "hsapiens_gene_ensembl". + attributes : list of str, optional + Attributes to retrieve. Defaults to ["ensembl_peptide_id", "uniprotswissprot"]. + id_prefix : str, optional + Prefix to add to IDs (e.g., "9606." for taxon). Defaults to "9606.". + timeout : int, optional + Request timeout in seconds. Defaults to 120. + + Returns + ------- + bool + True if download succeeded, False otherwise. + + Raises + ------ + RuntimeError + If download fails. 
+ """ + + if attributes is None: + attributes = ["ensembl_peptide_id", "uniprotswissprot"] + + try: + # Query BioMart using the library + biomart_dataset = BioMartDataset( + name=dataset, host="http://www.ensembl.org" + ) + result_df = biomart_dataset.query(attributes=attributes) + + # Save to file with optional prefix + with open(output_path, "w") as f: + for _, row in result_df.iterrows(): + # Skip rows with missing values + if row.isnull().any(): + continue + + values = row.tolist() + # Add prefix to first column (ID) if specified + if id_prefix: + values[0] = f"{id_prefix}{values[0]}" + f.write("\t".join(str(v) for v in values) + "\n") + + return True + + except Exception as e: + raise RuntimeError( + f"Failed to download from Ensembl BioMart: {e}" + ) from e + + def download_file_from_drive( file_link, path_to_save, dataset_name, file_format="tar.gz" ): @@ -84,6 +302,44 @@ def download_file_from_drive( print("Failed to download the file.") +def download_folder_from_drive(folder_link, output_dir, quiet=False): + """Download an entire folder from Google Drive using gdown. + + Parameters + ---------- + folder_link : str + The Google Drive folder URL or folder ID. + output_dir : str + The directory where the folder contents will be saved. + quiet : bool, optional + If True, suppress download progress messages. Defaults to False. + + Returns + ------- + bool + True if download succeeded, False otherwise. + + Raises + ------ + ValueError + If the provided link is not a valid Google Drive folder URL. 
+ """ + # Extract folder ID from URL if needed + folder_id = get_folder_id_from_url(folder_link) + + try: + gdown.download_folder( + id=folder_id, + output=output_dir, + quiet=quiet, + use_cookies=False, + ) + return True + except Exception as e: + print(f"Failed to download folder from Google Drive: {e}") + return False + + def download_file_from_link( file_link, path_to_save, dataset_name, file_format="tar.gz" ): diff --git a/topobench/data/utils/split_utils.py b/topobench/data/utils/split_utils.py index f78994222..cf63b50a4 100644 --- a/topobench/data/utils/split_utils.py +++ b/topobench/data/utils/split_utils.py @@ -248,16 +248,19 @@ def load_transductive_splits(dataset, parameters): ) data = dataset.data_list[0] + + # Check if this is multi-rank cell prediction + target_ranks = getattr(dataset, "target_ranks", None) + if target_ranks is not None and len(target_ranks) > 1: + return load_multirank_transductive_splits(dataset, parameters) + + # Single rank or node/graph prediction labels = data.y.numpy() # Ensure labels are one dimensional array assert len(labels.shape) == 1, "Labels should be one dimensional array" - root = ( - dataset.dataset.get_data_dir() - if hasattr(dataset.dataset, "get_data_dir") - else None - ) + root = dataset.get_data_dir() if hasattr(dataset, "get_data_dir") else None if parameters.split_type == "random": splits = random_splitting(labels, parameters, root=root) @@ -265,9 +268,18 @@ def load_transductive_splits(dataset, parameters): elif parameters.split_type == "k-fold": splits = k_fold_split(labels, parameters, root=root) + elif parameters.split_type == "fixed" and hasattr(dataset, "split_idx"): + splits = dataset.split_idx + if splits is None: + raise ValueError( + "Dataset has split_type='fixed' but split_idx property returned None. " + "Either the dataset doesn't support fixed splits or they failed to load." + ) + else: raise NotImplementedError( - f"split_type {parameters.split_type} not valid. 
Choose either 'random' or 'k-fold'" + f"split_type {parameters.split_type} not valid. Choose 'random', 'k-fold', or 'fixed'.\n" + f"If 'fixed' is chosen, the dataset must have a split_idx property." ) # Assign train val test masks to the graph @@ -287,6 +299,105 @@ def load_transductive_splits(dataset, parameters): return DataloadDataset([data]), None, None +def load_multirank_transductive_splits(dataset, parameters): + r"""Load dataset with multi-rank cell-level splits. + + For datasets with cell-level predictions across multiple ranks (e.g., edges, + triangles, tetrahedra simultaneously), this function creates independent + train/val/test splits for each rank. + + Parameters + ---------- + dataset : torch_geometric.data.Dataset + Dataset with multi-rank cell labels. + parameters : DictConfig + Configuration parameters containing split_type and train_prop. + + Returns + ------- + list: + List containing the train dataset (validation and test are None for transductive). + + Notes + ----- + Expects dataset to have: + - target_ranks: list of ranks to split + - data.cell_labels_{rank}: labels for each rank + + Creates per-rank masks: + - data.train_mask_{rank}: training indices for rank + - data.val_mask_{rank}: validation indices for rank + - data.test_mask_{rank}: test indices for rank + """ + assert len(dataset) == 1, ( + "Dataset should have only one graph/complex in a transductive setting." + ) + + data = dataset.data_list[0] + target_ranks = dataset.target_ranks + + root = dataset.get_data_dir() if hasattr(dataset, "get_data_dir") else None + + # Split each rank independently + for rank in target_ranks: + label_attr = f"cell_labels_{rank}" + + if not hasattr(data, label_attr): + raise ValueError( + f"Data object missing {label_attr} for rank {rank}. 
" + f"Available attributes: {list(data.keys())}" + ) + + labels = getattr(data, label_attr).numpy() + + # Handle multi-dimensional labels (e.g., multi-label classification) + if len(labels.shape) > 1: + # Use first column for stratification (common practice) + stratify_labels = ( + labels[:, 0] if labels.shape[1] > 0 else labels.flatten() + ) + else: + stratify_labels = labels + + # Create rank-specific root directory for splits + # This ensures each rank gets independent splits + rank_root = os.path.join(root, f"rank_{rank}") if root else None + + # Perform splitting + if parameters.split_type == "random": + splits = random_splitting( + stratify_labels, parameters, root=rank_root + ) + elif parameters.split_type == "k-fold": + splits = k_fold_split(stratify_labels, parameters, root=root) + elif parameters.split_type == "fixed" and hasattr( + dataset, "split_idx" + ): + splits = dataset.split_idx + if splits is None: + raise ValueError( + "Dataset has split_type='fixed' but split_idx property returned None. " + "Either the dataset doesn't support fixed splits or they failed to load." + ) + else: + raise NotImplementedError( + f"split_type {parameters.split_type} not valid. " + f"Choose 'random', 'k-fold', or 'fixed'.\n" + f"If 'fixed' is chosen, the dataset must have a split_idx property." + ) + + # Store per-rank masks + train_mask = torch.from_numpy(splits["train"]) + val_mask = torch.from_numpy(splits["valid"]) + test_mask = torch.from_numpy(splits["test"]) + + setattr(data, f"train_mask_{rank}", train_mask) + setattr(data, f"val_mask_{rank}", val_mask) + setattr(data, f"test_mask_{rank}", test_mask) + + return DataloadDataset([data]), None, None + + def load_inductive_splits(dataset, parameters): r"""Load multiple-graph datasets with the specified split. 
From 7108ab3ec5180a26a033d04121789f43a03ce5d0 Mon Sep 17 00:00:00 2001 From: I745505 Date: Tue, 18 Nov 2025 15:19:22 +0100 Subject: [PATCH 02/32] :sparkles: PPI Dataset: HIGH-PPI + CORUM --- configs/dataset/simplicial/ppi_highppi.yaml | 100 ++++ .../data/datasets/ppi_highppi_dataset.py | 415 +++++++++++++ .../loaders/simplicial/ppi_highppi_loader.py | 71 +++ .../utils/datasets/simplicial/ppi_utils.py | 558 ++++++++++++++++++ 4 files changed, 1144 insertions(+) create mode 100644 configs/dataset/simplicial/ppi_highppi.yaml create mode 100644 topobench/data/datasets/ppi_highppi_dataset.py create mode 100644 topobench/data/loaders/simplicial/ppi_highppi_loader.py create mode 100644 topobench/data/utils/datasets/simplicial/ppi_utils.py diff --git a/configs/dataset/simplicial/ppi_highppi.yaml b/configs/dataset/simplicial/ppi_highppi.yaml new file mode 100644 index 000000000..a0f36793e --- /dev/null +++ b/configs/dataset/simplicial/ppi_highppi.yaml @@ -0,0 +1,100 @@ +################################################################################ +# HIGH-PPI + CORUM: Protein Interaction Prediction via Simplicial Complexes +################################################################################ +# +# Data Structure: +# - Proteins (rank 0): ~1,553 proteins +# - Edges (rank 1): ~6,660 HIGH-PPI edges with: +# * Features: 8-dim (7 interaction types + 1 STRING confidence score) +# - Interaction types: reaction, binding, ptmod, activation, inhibition, catalysis, expression +# - Confidence score [0, 1] measuring interaction probability (mapped to [-1, 1]) +# - Higher-order cells: CORUM protein complexes (2+ proteins) +# * Features: 1-dim (binary existence: 1=real, -1=fake) +# +# Note: Features at any rank can also serve as prediction targets (labels). +# Models should mask features of the rank being predicted to avoid data leakage. 
+# +# Prediction Tasks: +# - Edge (rank 1): Regression (confidence scores) or multi-label (interaction types) +# - Cell (ranks 2+): Binary classification (complex existence) +# +################################################################################ + +# Data loading configuration +loader: + _target_: topobench.data.loaders.PPIHighPPIDatasetLoader + parameters: + data_domain: simplicial + model_domain: simplicial + data_name: ppi_highppi + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain} + + # CORUM Complex Configuration + min_complex_size: 2 # Minimum proteins per CORUM complex (2+ allowed) + # Edge features for edges in CORUM: + # - In HIGH-PPI: Interaction types + confidence boosted to 1.0 + # - Not in HIGH-PPI: [0,0,0,0,0,0,0, 1.0] (unknown types, high confidence) + max_complex_size: 6 # Maximum proteins per CORUM complex + + # Negative Sampling (for classification tasks) + neg_ratio: 1.0 # Ratio of negative to positive samples (1.0 = balanced) + + # Multi-Rank Prediction + target_ranks: [2, 3, 4, 5] # Which ranks to predict (train/test on) + # Max target rank must be <= max_complex_size - 1 + + # Edge Task Type (only applied when rank 1 in target_ranks) + edge_task: score # "score": Regression - predict confidence of interaction [0-1] + # "interaction_type": Multi-label - predict 7 interaction types + +# Model training configuration +parameters: + # Feature dimensions: [rank-0, rank-1, ..., rank-max] + # rank-0: One-hot encoded proteins TODO: Replace with richer embedding + # rank-1: 8-dim edge features (7 interaction types + 1 confidence score) + # rank-2+: 1-dim features (binary existence) + num_features: [1553, 8, 1, 1, 1, 1] + + num_classes: 2 # Depends on task: + # - Higher-order (ranks 2+): 2 (exists/doesn't exist) + # - Edge regression (rank 1, score): 1 (continuous output) + # - Edge multi-label (rank 1, interaction_type): 7 (7 types) + task: classification # Depends on target_ranks and edge_task: + # - Higher-order (ranks 
2+): classification + # - Edge regression (rank 1, score): regression + # - Edge multi-label (rank 1, interaction_type): classification + task_level: cell # Predict on cells (edges/triangles/etc), not nodes or graphs + + # Multi-Rank Prediction + target_ranks: ${dataset.loader.parameters.target_ranks} + + loss_type: cross_entropy # Depends on task: + # - Higher-order binary: cross_entropy + # - Edge regression: mse or mae + # - Edge multi-label: bce_with_logits + monitor_metric: auroc # Depends on task: + # - Higher-order binary: auroc, f1, accuracy + # - Edge regression: mae, rmse + # - Edge multi-label: f1, auroc + +# Splits Configuration +split_params: + learning_setting: transductive # Single complex, split labeled cells + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 42 # Random seed for reproducible splits + + # Split Type Options: + # - "random": Random splitting with train_prop ratio + # - "k-fold": K-fold cross-validation + # - "fixed": Use HIGH-PPI's official train/val split (if available in raw data) + split_type: random + + train_prop: 0.8 # For random/k-fold: 80% train, 10% val, 10% test + # Ignored when split_type: fixed + +# Dataloader +dataloader_params: + batch_size: 1 + num_workers: 0 + pin_memory: False + persistent_workers: False diff --git a/topobench/data/datasets/ppi_highppi_dataset.py b/topobench/data/datasets/ppi_highppi_dataset.py new file mode 100644 index 000000000..bc439b340 --- /dev/null +++ b/topobench/data/datasets/ppi_highppi_dataset.py @@ -0,0 +1,415 @@ +"""PPI dataset integrating HIGH-PPI network data with CORUM human protein complexes. 
+ +Combines: +- HIGH-PPI SHS27k: PPI network with 7 interaction type features + confidence scores +- CORUM: ~470 experimentally validated human protein complexes as native higher-order structures +- TODO: Add data for node features (embeddings) + +Simplicial complex structure: +- 0-cells: 1,553 proteins +- 1-cells: 6,660 PPI edges + CORUM complexes of size 2 +- 2+ cells: CORUM complexes of size 3+ +""" + +import json +import os +import os.path as osp +from typing import ClassVar + +import numpy as np +import torch +from omegaconf import DictConfig +from torch_geometric.data import Data, InMemoryDataset +from torch_geometric.io import fs + +from topobench.data.utils import get_complex_connectivity +from topobench.data.utils.datasets.simplicial.ppi_utils import ( + build_data_features_and_labels, + build_simplicial_complex_with_features, + generate_negative_samples, + load_corum_complexes, + load_highppi_network, + load_id_mapping, +) +from topobench.data.utils.io_utils import ( + download_ensembl_biomart_mapping, + download_file, + download_folder_from_drive, +) + + +class PPIHighPPIDataset(InMemoryDataset): + """HIGH-PPI network integrated with CORUM human protein complexes. + + Combines 6,660 protein-protein interactions from HIGH-PPI SHS27k with ~470 + experimentally validated human protein complexes from CORUM database. + + Parameters + ---------- + root : str + Root directory. + name : str, optional + Dataset name, default "ppi_highppi". + parameters : DictConfig, optional + Config with min_complex_size, max_complex_size, target_ranks, neg_ratio, + edge_task ("score" or "interaction_type"). + **kwargs : dict + Additional keyword arguments passed to InMemoryDataset. 
+ """ + + INTERACTION_TYPES: ClassVar[list[str]] = [ + "reaction", + "binding", + "ptmod", + "activation", + "inhibition", + "catalysis", + "expression", + ] + + # Data source URLs + HIGHPPI_GDRIVE_FOLDER: ClassVar[str] = ( + "https://drive.google.com/drive/folders/1Yb-fdWJ5vTe0ePAGNfrUluzO9tz1lHIF?usp=sharing" + ) + CORUM_URL: ClassVar[str] = ( + "https://mips.helmholtz-muenchen.de/fastapi-corum/public/file/download_current_file?file_id=human&file_format=txt" + ) + + # Required raw data filenames + HIGHPPI_NETWORK_FILE: ClassVar[str] = ( + "protein.actions.SHS27k.STRING.pro2.txt" + ) + ID_MAPPING_FILE: ClassVar[str] = "ensp_uniprot.txt" + CORUM_COMPLEXES_FILE: ClassVar[str] = "allComplexes.txt" + + def __init__( + self, + root: str, + name: str = "ppi_highppi", + parameters: DictConfig = None, + **kwargs, + ): + self.name = name + self.parameters = parameters or DictConfig({}) + self.min_complex_size = self.parameters.get("min_complex_size", 2) + self.max_complex_size = self.parameters.get("max_complex_size", 6) + self.max_rank = self.max_complex_size - 1 + self.neg_ratio = self.parameters.get("neg_ratio", 1.0) + self.target_ranks = self.parameters.get("target_ranks", [2, 3, 4, 5]) + self.edge_task = self.parameters.get("edge_task", "score") + + self.highppi_edges = [] # List of (p1, p2, interaction_type_vector, confidence_score) + self.corum_complexes = [] # List of sets of proteins in a complex + self.all_proteins = set() + self.ensembl_to_uniprot = {} + self.uniprot_to_ensembl = {} + self.official_splits = {} + + super().__init__(root, **kwargs) + + out = fs.torch_load(self.processed_paths[0]) + if len(out) == 3: + data, self.slices, self.sizes = out + data_cls = Data + else: + data, self.slices, self.sizes, data_cls = out + + if not isinstance(data, dict): + self.data = data + else: + self.data = data_cls.from_dict(data) + + # Ensure data.y is set for single-rank compatibility + # TODO: Change for B2 submission which will introduce a unified training loop + if 
len(self.target_ranks) == 1: + label_attr = f"cell_labels_{self.target_ranks[0]}" + if hasattr(self._data, label_attr): + self._data.y = getattr(self._data, label_attr) + + @property + def raw_dir(self) -> str: + """Return the path to the raw directory. + + Returns + ------- + str + Path to the raw directory. + """ + return osp.join(self.root, "raw") + + @property + def processed_dir(self) -> str: + """Return the path to the processed directory. + + Returns + ------- + str + Path to the processed directory. + """ + return osp.join(self.root, "processed") + + @property + def raw_file_names(self) -> list[str]: + """Return list of required raw file names. + + Returns + ------- + List[str] + Required raw data files. + """ + return [ + self.HIGHPPI_NETWORK_FILE, + self.ID_MAPPING_FILE, + self.CORUM_COMPLEXES_FILE, + ] + + @property + def processed_file_names(self) -> list[str]: + """Return the name of the processed file. + + Filename includes target_ranks to avoid cache conflicts when + different ranks are requested. + + Returns + ------- + List[str] + List containing the name of the processed file. 
+ """ + # Include target_ranks in filename to prevent cache conflicts + ranks_str = "_".join(map(str, self.target_ranks)) + return [f"data_ranks_{ranks_str}.pt"] + + def download(self) -> None: + """Download HIGH-PPI and CORUM data files.""" + + # Check if files already exist + all_exist = all( + osp.exists(osp.join(self.raw_dir, fname)) + for fname in self.raw_file_names + ) + if all_exist: + print("All required files already present") + return + + print("Downloading HIGH-PPI SHS27k dataset and CORUM complexes...") + os.makedirs(self.raw_dir, exist_ok=True) + + if not osp.exists(osp.join(self.raw_dir, self.CORUM_COMPLEXES_FILE)): + print("Downloading CORUM human protein complexes...") + download_file( + self.CORUM_URL, + osp.join(self.raw_dir, self.CORUM_COMPLEXES_FILE), + verify_ssl=False, + ) + print("CORUM download complete") + + if not osp.exists(osp.join(self.raw_dir, self.ID_MAPPING_FILE)): + print("Downloading Ensembl-UniProt ID mapping...") + download_ensembl_biomart_mapping( + osp.join(self.raw_dir, self.ID_MAPPING_FILE) + ) + print("ID mapping download complete") + + if not osp.exists(osp.join(self.raw_dir, self.HIGHPPI_NETWORK_FILE)): + print("Downloading HIGH-PPI network data from Google Drive...") + success = download_folder_from_drive( + self.HIGHPPI_GDRIVE_FOLDER, self.raw_dir, quiet=False + ) + if not success: + raise RuntimeError( + "Failed to download HIGH-PPI data from Google Drive" + ) + print("HIGH-PPI download complete") + + # Final verification + missing_files = [ + fname + for fname in self.raw_file_names + if not osp.exists(osp.join(self.raw_dir, fname)) + ] + + if missing_files: + raise FileNotFoundError( + f"Failed to download required files: {missing_files}. 
" + ) + + def process(self): + """Build simplicial complex: HIGH-PPI edges + CORUM complexes.""" + print("\n" + "=" * 70) + print( + "Building PPI simplicial complex from HIGH-PPI and CORUM datasets" + ) + print("=" * 70) + + # Load Ensembl <-> UniProt ID mapping + mapping_path = osp.join(self.raw_dir, self.ID_MAPPING_FILE) + self.ensembl_to_uniprot, self.uniprot_to_ensembl = load_id_mapping( + mapping_path + ) + + # Load HIGH-PPI network with interaction types and confidence scores + highppi_path = osp.join(self.raw_dir, self.HIGHPPI_NETWORK_FILE) + self.highppi_edges, self.all_proteins = load_highppi_network( + highppi_path, self.INTERACTION_TYPES + ) + + # Load CORUM complexes, filter to SHS27k proteins + corum_path = osp.join(self.raw_dir, self.CORUM_COMPLEXES_FILE) + self.corum_complexes = load_corum_complexes( + corum_path, + self.all_proteins, + self.ensembl_to_uniprot, + self.uniprot_to_ensembl, + self.min_complex_size, + self.max_complex_size, + ) + + self._load_splits() + + print("Building simplicial complex...") + sc, edge_data, cell_data = build_simplicial_complex_with_features( + self.all_proteins, + self.highppi_edges, + self.corum_complexes, + self.min_complex_size, + self.max_rank, + ) + + print("Generating negative samples...") + edge_data, cell_data = generate_negative_samples( + sc, edge_data, cell_data, self.all_proteins, self.neg_ratio + ) + + print("Extracting features and connectivity...") + x_dict, labels_dict = build_data_features_and_labels( + sc, + edge_data, + cell_data, + self.target_ranks, + self.max_rank, + edge_task=self.edge_task, + ) + + # Get connectivity + connectivity = get_complex_connectivity( + sc, self.max_rank, signed=False + ) + + # Build Data object + protein_list = sorted(list(sc.nodes)) + protein_to_idx = {p: i for i, p in enumerate(protein_list)} + n_edges = len(list(sc.skeleton(1))) + + data = Data( + **x_dict, + **connectivity, + **labels_dict, + num_proteins=len(protein_list), + num_edges=n_edges, + 
num_complexes=len(self.corum_complexes), + protein_to_idx=protein_to_idx, + ) + + # Add x and y for compatibility with generic tests + # x_0 uses one-hot encoding, so dimension equals number of proteins + data.x = x_dict.get( + "x_0", torch.zeros(0, len(protein_list)) + ) # TODO: This data will not be used for node-level prediction + if ( + self.target_ranks + and f"cell_labels_{self.target_ranks[0]}" in labels_dict + ): + data.y = labels_dict[f"cell_labels_{self.target_ranks[0]}"] + else: + data.y = torch.zeros(len(protein_list), dtype=torch.long) + + # Add official splits if available + if self.official_splits: + data.train_mask = torch.tensor( + self.official_splits.get("train_index", []), dtype=torch.long + ) + data.val_mask = torch.tensor( + self.official_splits.get("valid_index", []), dtype=torch.long + ) + + # Save processed data + print("Saving processed data...") + self.data, self.slices = self.collate([data]) + fs.torch_save( + (self._data.to_dict(), self.slices, {}, self._data.__class__), + self.processed_paths[0], + ) + + print("\n" + "=" * 70) + print("✅ PROCESSING COMPLETE!") + print("📊 Dataset statistics:") + print(f" - Proteins (0-cells): {len(self.all_proteins)}") + print(f" - Labeled edges (1-cells): {len(self.highppi_edges)}") + print(f" - CORUM complexes: {len(self.corum_complexes)}") + print(f"📁 Saved to: {self.processed_paths[0]}") + print(f"💾 Size: {osp.getsize(self.processed_paths[0]) / 1e6:.1f} MB") + print("=" * 70 + "\n") + + @property + def data_list(self): + """Return list of data objects for TopoBench compatibility. + + Returns + ------- + list + List containing single data object (transductive setting). + """ + return [self._data] + + def get_data_dir(self): + """Return data directory for split file storage. + + Returns + ------- + str + Path to data directory. + """ + return self.root + + @property + def split_idx(self): + """Return train/val/test split indices for split_type='fixed'. + + Used when config has split_type='fixed'. 
Returns HIGH-PPI's official + train/val split if it was successfully loaded, otherwise None. + + Returns + ------- + dict or None + Dictionary with 'train', 'valid', 'test' keys containing indices, + or None (triggers random/k-fold splitting based on split_type). + """ + if hasattr(self, "official_splits") and self.official_splits: + return { + "train": np.array(self.official_splits.get("train_index", [])), + "valid": np.array(self.official_splits.get("val_index", [])), + "test": np.array(self.official_splits.get("val_index", [])), + } + return None + + # TODO: This is not working yet + def _load_splits(self): + """Load official train/val split indices from HIGH-PPI. + + Loads splits into self.official_splits which will be used if split_type='fixed'. + Fails silently if splits are not available (random/k-fold will be used instead). + """ + split_path = osp.join(self.raw_dir, "train_val_split_1.json") + if not osp.exists(split_path): + return + + try: + with open(split_path) as f: + content = f.read().strip() + if len(content) >= 10: # Basic validation + self.official_splits = json.loads(content) + print( + "Official train/val splits available (use split_type='fixed' to use them)" + ) + except (json.JSONDecodeError, Exception): + pass # Silently ignore - will use random/k-fold splitting diff --git a/topobench/data/loaders/simplicial/ppi_highppi_loader.py b/topobench/data/loaders/simplicial/ppi_highppi_loader.py new file mode 100644 index 000000000..9cd737589 --- /dev/null +++ b/topobench/data/loaders/simplicial/ppi_highppi_loader.py @@ -0,0 +1,71 @@ +"""Loader for PPI dataset (HIGH-PPI variant) with CORUM complexes.""" + +from omegaconf import DictConfig + +from topobench.data.datasets.ppi_highppi_dataset import PPIHighPPIDataset +from topobench.data.loaders.base import AbstractLoader + + +class PPIHighPPIDatasetLoader(AbstractLoader): + """Load HIGH-PPI SHS27k dataset with CORUM topological enrichment. 
+ + This loader creates a hybrid simplicial complex from: + - HIGH-PPI's SHS27k PPI network (labeled edges) + - CORUM protein complexes (unlabeled higher-order cells) + + Task: Edge-level multi-label classification (7 interaction types) + + Parameters + ---------- + parameters : DictConfig + Configuration parameters containing: + - data_dir: Root directory for data + - data_name: Name of the dataset + - min_complex_size: Minimum CORUM complex size + - max_complex_size: Maximum CORUM complex size + - max_rank: Maximum simplicial rank + - use_official_split: Use HIGH-PPI's train/val split + **kwargs : dict + Additional keyword arguments. + """ + + def __init__(self, parameters: DictConfig, **kwargs) -> None: + super().__init__(parameters, **kwargs) + + def load_dataset(self, **kwargs) -> PPIHighPPIDataset: + """Load the dataset. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed to dataset initialization. + + Returns + ------- + PPIHighPPIDataset + Dataset with HIGH-PPI network and CORUM complexes. + """ + dataset = self._initialize_dataset(**kwargs) + self.data_dir = self.get_data_dir() + return dataset + + def _initialize_dataset(self, **kwargs) -> PPIHighPPIDataset: + """Initialize the HIGH-PPI SHS27k dataset. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments for dataset initialization. + + Returns + ------- + PPIHighPPIDataset + The initialized dataset instance. + """ + self.dataset = PPIHighPPIDataset( + root=str(self.root_data_dir), + name=self.parameters.get("data_name", "highppi_shs27k"), + parameters=self.parameters, + **kwargs, + ) + return self.dataset diff --git a/topobench/data/utils/datasets/simplicial/ppi_utils.py b/topobench/data/utils/datasets/simplicial/ppi_utils.py new file mode 100644 index 000000000..3eb85234b --- /dev/null +++ b/topobench/data/utils/datasets/simplicial/ppi_utils.py @@ -0,0 +1,558 @@ +"""Refactored PPI utilities with cleaner separation of concerns. 
+ +Key improvements: +1. Separate topology building from feature/label assignment +2. Consistent data types (no mixed list/float/int in cell_labels) +3. Single-pass iterations (no redundant loops) +4. Clear data flow: topology → features → labels → tensors +""" + +import os +import random +from itertools import combinations + +import pandas as pd +import torch +from toponetx.classes import SimplicialComplex + + +def load_id_mapping( + mapping_path: str, +) -> tuple[dict[str, str], dict[str, list[str]]]: + """Load Ensembl ↔ UniProt ID mapping. + + Parameters + ---------- + mapping_path : str + Path to ensp_uniprot.txt mapping file. + + Returns + ------- + ensembl_to_uniprot : dict + Mapping from Ensembl IDs to UniProt IDs. + uniprot_to_ensembl : dict + Reverse mapping (UniProt to list of Ensembl IDs). + + Raises + ------ + FileNotFoundError + If mapping file does not exist. + """ + if not os.path.exists(mapping_path): + raise FileNotFoundError( + f"ID mapping file not found: {mapping_path}. " + "This file is required to map between Ensembl and UniProt IDs for CORUM complexes." + ) + + ensembl_to_uniprot = {} + uniprot_to_ensembl = {} + + with open(mapping_path) as f: + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + + parts = line.split("\t") + if len(parts) >= 2: + ensembl_id = parts[0].strip() + uniprot_id = parts[1].strip() + + if uniprot_id and uniprot_id not in ("Noneid", "None"): + ensembl_to_uniprot[ensembl_id] = uniprot_id + + if uniprot_id not in uniprot_to_ensembl: + uniprot_to_ensembl[uniprot_id] = [] + uniprot_to_ensembl[uniprot_id].append(ensembl_id) + + return ensembl_to_uniprot, uniprot_to_ensembl + + +def load_highppi_network( + file_path: str, interaction_types: list[str] +) -> tuple[list[tuple], set[str]]: + """Load HIGH-PPI network with interaction types and confidence scores. + + Parameters + ---------- + file_path : str + Path to HIGH-PPI SHS27k file. 
+ interaction_types : list + List of valid interaction type names. + + Returns + ------- + highppi_edges : list + List of (p1, p2, interaction_vector, score) tuples. + all_proteins : set + Set of all protein IDs in the network. + + Raises + ------ + FileNotFoundError + If HIGH-PPI network file does not exist. + """ + + if not os.path.exists(file_path): + raise FileNotFoundError( + f"HIGH-PPI network file not found: {file_path}. " + "This file contains the protein-protein interaction network." + ) + + df = pd.read_csv(file_path, sep="\t") + + # Rename columns to more intuitive names + df = df.rename( + columns={ + "item_id_a": "protein_1", + "item_id_b": "protein_2", + "mode": "interaction_type", + "score": "confidence_score", + } + ) + + edge_labels = {} + edge_scores = {} + all_proteins = set() + + for _, row in df.iterrows(): + p1 = str(row["protein_1"]).strip() + p2 = str(row["protein_2"]).strip() + + all_proteins.add(p1) + all_proteins.add(p2) + + edge_key = tuple(sorted([p1, p2])) + + score = float(row["confidence_score"]) / 1000.0 + if edge_key not in edge_scores: + edge_scores[edge_key] = 0.0 + edge_scores[edge_key] = max(edge_scores[edge_key], score) + + interaction_type = str(row["interaction_type"]).strip() + if edge_key not in edge_labels: + edge_labels[edge_key] = [0] * 7 + if interaction_type in interaction_types: + idx_type = interaction_types.index(interaction_type) + edge_labels[edge_key][idx_type] = 1 + + highppi_edges = [ + (p1, p2, labels, edge_scores[(p1, p2)]) + for (p1, p2), labels in edge_labels.items() + ] + + return highppi_edges, all_proteins + + +def load_corum_complexes( + file_path: str, + all_proteins: set[str], + ensembl_to_uniprot: dict[str, str], + uniprot_to_ensembl: dict[str, list[str]], + min_size: int, + max_size: int, +) -> list[set[str]]: + """Load and filter CORUM protein complexes. + + Parameters + ---------- + file_path : str + Path to CORUM allComplexes.txt file. 
+ all_proteins : set + Set of proteins in the network (for filtering). + ensembl_to_uniprot : dict + Ensembl to UniProt ID mapping. + uniprot_to_ensembl : dict + UniProt to Ensembl ID mapping. + min_size : int + Minimum complex size. + max_size : int + Maximum complex size. + + Returns + ------- + list[set[str]] + List of sets, each containing Ensembl protein IDs. + + Raises + ------ + FileNotFoundError + If CORUM file does not exist. + """ + + if not os.path.exists(file_path): + raise FileNotFoundError( + f"CORUM complexes file not found: {file_path}. " + "This file is required to load experimentally validated protein complexes." + ) + + df = pd.read_csv(file_path, sep="\t", low_memory=False) + + # CORUM uses 'subunits_uniprot_id' column for UniProt IDs + if "subunits_uniprot_id" not in df.columns: + raise ValueError( + f"Expected column 'subunits_uniprot_id' not found in CORUM file. Available columns: {df.columns.tolist()}" + ) + + # Map proteins to UniProt + shs27k_uniprot = { + ensembl_to_uniprot[eid] + for eid in all_proteins + if eid in ensembl_to_uniprot + } + + corum_complexes = [] + for _, row in df.iterrows(): + subunits_str = row["subunits_uniprot_id"] + if pd.isna(subunits_str): + continue + + proteins_uniprot = { + p.strip() for p in subunits_str.split(";") if p.strip() + } + proteins_in_network = proteins_uniprot & shs27k_uniprot + + if not (min_size <= len(proteins_in_network) <= max_size): + continue + + # Convert to Ensembl IDs + ensembl_complex = set() + for uniprot_id in proteins_in_network: + if uniprot_id in uniprot_to_ensembl: + for ensembl_id in uniprot_to_ensembl[uniprot_id]: + if ensembl_id in all_proteins: + ensembl_complex.add(ensembl_id) + break + + corum_complexes.append(ensembl_complex) + + return corum_complexes + + +def build_simplicial_complex_with_features( + all_proteins: set[str], + highppi_edges: list[tuple], + corum_complexes: list[set[str]], + min_complex_size: int, + max_rank: int, +) -> tuple[SimplicialComplex, dict, dict]: 
+ """Build simplicial complex with topology and metadata from PPI data. + + Constructs the complex structure and tracks cell data. + + Parameters + ---------- + all_proteins : set + Set of all protein IDs. + highppi_edges : list + List of (p1, p2, interaction_vector, score) tuples. + corum_complexes : list + List of protein complexes as sets. + min_complex_size : int + Minimum complex size to include. + max_rank : int + Maximum rank to consider. + + Returns + ------- + sc : SimplicialComplex + The constructed simplicial complex. + edge_data : dict + Edge features {edge_tuple: tensor([7 interaction types, 1 confidence])}. + cell_data : dict + Binary labels per rank {rank: {cell_tuple: {-1, 1}}}. + """ + sc = SimplicialComplex() + edge_data = {} # {edge_tuple: tensor([7 types + 1 score])} + cell_data = {} # {rank: {cell_tuple: -1 or 1}} + + # Add 0-cells (proteins) + for protein in sorted(all_proteins): + sc.add_simplex([protein]) + + # Add 1-cells (HIGH-PPI edges) with features + for p1, p2, interaction_vector, score in highppi_edges: + edge_tuple = tuple(sorted([p1, p2])) + sc.add_simplex([p1, p2]) + + # Store 8-dim feature vector (7 interaction types + 1 confidence) + edge_data[edge_tuple] = torch.tensor( + interaction_vector + [2 * score - 1], + dtype=torch.float, # Convert confidence score affinely: [0, 1] -> [-1, 1] + ) + + # Process CORUM complexes top-down (largest first) + # This ensures lower-rank CORUM complexes can override negative labels + sorted_complexes = sorted(corum_complexes, key=len, reverse=True) + + for complex_proteins in sorted_complexes: + # Filter by size and protein membership + if len(complex_proteins) < min_complex_size: + continue + if not complex_proteins.issubset(all_proteins): + continue + + complex_tuple = tuple(sorted(complex_proteins)) + rank = len(complex_tuple) - 1 + + if rank > max_rank: + continue + + # Add complex to simplicial complex (automatically adds all faces) + sc.add_simplex(list(complex_tuple)) + + if rank == 1: + 
edge_data[complex_tuple] = torch.tensor( + [0, 0, 0, 0, 0, 0, 0, 1.0], dtype=torch.float + ) + continue + + # Mark this complex as positive + if rank not in cell_data: + cell_data[rank] = {} + cell_data[rank][complex_tuple] = 1 + + # Mark all proper sub-faces as negative (not real complexes themselves) + # Only mark if not already labeled (top-down iteration handles overlaps) + for sub_rank in range(1, rank): + if sub_rank not in cell_data: + cell_data[sub_rank] = {} + + for sub_face in combinations(complex_tuple, sub_rank + 1): + sub_tuple = tuple(sorted(sub_face)) + if sub_rank == 1 and sub_tuple not in edge_data: + edge_data[sub_tuple] = torch.tensor( + [0, 0, 0, 0, 0, 0, 0, -1.0], dtype=torch.float + ) + elif sub_tuple not in cell_data[sub_rank]: + cell_data[sub_rank][sub_tuple] = -1 + + return sc, edge_data, cell_data + + +def generate_negative_samples( + sc: SimplicialComplex, + edge_data: dict[tuple, torch.Tensor], + cell_data: dict[int, dict[tuple, int]], + all_proteins: set[str], + neg_ratio: float, +) -> tuple[dict[tuple, torch.Tensor], dict[int, dict[tuple, int]]]: + """Generate negative samples proportionally across ranks. + + Parameters + ---------- + sc : SimplicialComplex + Current simplicial complex. + edge_data : dict + Edge features to update with negative edges. + cell_data : dict + Existing binary data per rank {rank: {cell_tuple: {-1, 1}}}. + all_proteins : set + Set of all protein IDs. + neg_ratio : float + Ratio of negative to positive samples. + + Returns + ------- + edge_data : dict + Updated with negative edge features. + cell_data : dict + Updated data with negative samples added (value=-1). 
+ """ + random.seed(42) + + # Count positive samples per rank + positive_counts = {} + for rank in range(2, sc.dim + 1): + if rank in cell_data: + positive_counts[rank] = sum( + 1 for label in cell_data[rank].values() if label == 1 + ) + else: + positive_counts[rank] = 0 + + # Generate negatives per rank + all_proteins_list = list(all_proteins) + + for rank, n_positive in positive_counts.items(): + if n_positive == 0: + continue + + n_negative_needed = int(n_positive * neg_ratio) + if n_negative_needed == 0: + continue + + # Get existing cells at this rank + existing_cells = set(cell_data.get(rank, {}).keys()) + existing_cells.update( + tuple(sorted(cell)) for cell in sc.skeleton(rank) + ) + + # Generate random cells until we have enough negatives + negatives_added = 0 + max_attempts = n_negative_needed * 100 + + for _ in range(max_attempts): + if negatives_added >= n_negative_needed: + break + + # Sample random proteins for this rank + sampled = random.sample(all_proteins_list, rank + 1) + cell_tuple = tuple(sorted(sampled)) + + # Only add if it doesn't exist yet + if cell_tuple not in existing_cells: + # Add to complex and label as negative + sc.add_simplex(list(cell_tuple)) + + if rank not in cell_data: + cell_data[rank] = {} + cell_data[rank][cell_tuple] = -1 + + # For edges (rank 1), also create feature vector + if rank == 1 and cell_tuple not in edge_data: + edge_data[cell_tuple] = torch.tensor( + [0, 0, 0, 0, 0, 0, 0, 0.0], dtype=torch.float + ) + + existing_cells.add(cell_tuple) + negatives_added += 1 + + print( + f" Rank {rank}: {n_positive} positive, {negatives_added} negative samples" + ) + + return edge_data, cell_data + + +def build_data_features_and_labels( + sc: SimplicialComplex, + edge_data: dict[tuple, torch.Tensor], + cell_data: dict[int, dict[tuple, int]], + target_ranks: list[int], + max_rank: int, + edge_task: str = None, +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + """Create feature and label tensors for all ranks. 
+ + Parameters + ---------- + sc : SimplicialComplex + The constructed simplicial complex. + edge_data : dict + Edge features {edge_tuple: tensor([7 interaction types, 1 confidence])}. + cell_data : dict + Binary labels {rank: {cell_tuple: {-1, 1}}}. + target_ranks : list + Ranks to predict on. + max_rank : int + Maximum rank in configuration. + edge_task : str, optional + Edge prediction task: "interaction_type" or "score". + Only used if rank 1 is in target_ranks. + + Returns + ------- + x_dict : dict + Features per rank {f"x_{rank}": tensor}. + labels_dict : dict + Labels per target rank {f"cell_labels_{rank}": tensor}. + """ + actual_max_rank = sc.dim + x_dict = {} + labels_dict = {} + + # For each rank: build features and labels + for rank in range(min(max_rank, actual_max_rank) + 1): + cells = list(sc.skeleton(rank)) + n_cells = len(cells) + + if n_cells == 0: + dim = 1 if rank == 0 else (8 if rank == 1 else 1) + x_dict[f"x_{rank}"] = torch.zeros(0, dim) + continue + + is_target = rank in target_ranks + + match rank: + case 0: + # Nodes: one-hot encoding + # TODO: Use richer embeddings (ESM, structure, GO annotations) + x_dict["x_0"] = torch.eye(n_cells) + + case 1: + # Edges: 8-dim features (7 interaction types + 1 confidence) + features = [] + labels = [] if is_target else None + + for edge in cells: + edge_tuple = tuple(sorted(edge)) + feat_vec = edge_data[edge_tuple] + + if is_target: + # Split features/labels based on edge_task + if edge_task == "interaction_type": + labels.append( + feat_vec[:7] + ) # First 7 dims = labels + features.append( + feat_vec[7:8] + ) # Last dim = feature + elif edge_task == "score": + labels.append(feat_vec[7:8]) # Last dim = label + features.append( + feat_vec[:7] + ) # First 7 dims = features + else: + # Not a target rank: use all 8 dims as features + features.append(feat_vec) + + x_dict["x_1"] = torch.stack(features) + + if is_target: + labels_dict["cell_labels_1"] = torch.stack(labels) + + case _: + # Higher-order cells + 
features = [] + labels = [] if is_target else None + + for cell in cells: + cell_tuple = tuple(sorted(cell)) + binary_existence_val = cell_data[rank][cell_tuple] + + # Features: 0 for target, {-1,+1} for non-target TODO: Bit unsure about this. Non-interacting edges also get 0 and 0 means it will not influence neighbors + # TODO: Maybe we should pass some labels as features for true transductivity/semi-supervision? + if is_target: + # Target rank: features are 0, labels are in {-1, 1} + features.append(torch.zeros(1, dtype=torch.float)) + labels.append( + torch.tensor( + [binary_existence_val], dtype=torch.float + ) + ) + else: + # Non-target rank: use labels as features {-1, +1} + features.append( + torch.tensor( + [binary_existence_val], dtype=torch.float + ) + ) + + x_dict[f"x_{rank}"] = torch.stack(features) + + if is_target: + labels_tensor = torch.tensor(labels, dtype=torch.long) + labels_dict[f"cell_labels_{rank}"] = labels_tensor + n_pos = (labels_tensor == 1).sum().item() + n_neg = (labels_tensor == 0).sum().item() + print( + f" Rank {rank}: {n_pos} positive, {n_neg} negative labels" + ) + + # Add empty features for non-existent ranks + for rank in range(actual_max_rank + 1, max_rank + 1): + dim = 1 if rank == 0 else (8 if rank == 1 else 1) + x_dict[f"x_{rank}"] = torch.zeros(0, dim) + + return x_dict, labels_dict From 086ef7b8c76bba7e4d2dcdd53eb3a29d3152d120 Mon Sep 17 00:00:00 2001 From: I745505 Date: Tue, 18 Nov 2025 23:22:08 +0100 Subject: [PATCH 03/32] :bug: Creation of unlabeled cells that were meant to be negative + Labels should be {0,1} or [0,1] respectively for CrossEntropyLoss and interpretability --- .../utils/datasets/simplicial/ppi_utils.py | 74 ++++++++++++------- 1 file changed, 48 insertions(+), 26 deletions(-) diff --git a/topobench/data/utils/datasets/simplicial/ppi_utils.py b/topobench/data/utils/datasets/simplicial/ppi_utils.py index 3eb85234b..8271b14f3 100644 --- a/topobench/data/utils/datasets/simplicial/ppi_utils.py +++ 
b/topobench/data/utils/datasets/simplicial/ppi_utils.py @@ -361,25 +361,29 @@ def generate_negative_samples( """ random.seed(42) - # Count positive samples per rank - positive_counts = {} - for rank in range(2, sc.dim + 1): + all_proteins_list = list(all_proteins) + + # Generate negatives from highest to rank 2 (top-down) + # Note that edges already have score in [-1, 1] so no need to add more negatives + for rank in range(sc.dim, 1, -1): + # Count positive and existing negative samples at this rank + n_positive = 0 + n_existing_negative = 0 if rank in cell_data: - positive_counts[rank] = sum( + n_positive = sum( 1 for label in cell_data[rank].values() if label == 1 ) - else: - positive_counts[rank] = 0 - - # Generate negatives per rank - all_proteins_list = list(all_proteins) - - for rank, n_positive in positive_counts.items(): + n_existing_negative = sum( + 1 for label in cell_data[rank].values() if label == -1 + ) if n_positive == 0: continue - n_negative_needed = int(n_positive * neg_ratio) - if n_negative_needed == 0: + # Calculate how many more negatives we need (accounting for existing ones) + n_negative_target = int(n_positive * neg_ratio) + n_negative_needed = n_negative_target - n_existing_negative + if n_negative_needed <= 0: + # Enough negatives continue # Get existing cells at this rank @@ -402,24 +406,37 @@ generate_negative_samples( # Only add if it doesn't exist yet if cell_tuple not in existing_cells: - # Add to complex and label as negative + # Add to complex (automatically adds faces) sc.add_simplex(list(cell_tuple)) if rank not in cell_data: cell_data[rank] = {} + # Label as negative cell_data[rank][cell_tuple] = -1 - # For edges (rank 1), also create feature vector - if rank == 1 and cell_tuple not in edge_data: - edge_data[cell_tuple] = torch.tensor( - [0, 0, 0, 0, 0, 0, 0, 0.0], dtype=torch.float - ) + # Mark all proper sub-faces as negative (if they don't exist yet) + for sub_rank in range(1, rank): + if sub_rank not in cell_data: + 
cell_data[sub_rank] = {} + + for sub_face in combinations(cell_tuple, sub_rank + 1): + sub_tuple = tuple(sorted(sub_face)) + if sub_rank == 1 and sub_tuple not in edge_data: + # Edge: create feature vector + edge_data[sub_tuple] = torch.tensor( + [0, 0, 0, 0, 0, 0, 0, -1.0], dtype=torch.float + ) + elif sub_tuple not in cell_data[sub_rank]: + # Higher-order sub-face: mark as negative + cell_data[sub_rank][sub_tuple] = -1 existing_cells.add(cell_tuple) negatives_added += 1 + # Calculate total negatives (existing + newly added) + n_total_negative = n_existing_negative + negatives_added print( - f" Rank {rank}: {n_positive} positive, {negatives_added} negative samples" + f" Rank {rank}: {n_positive} positive, {n_total_negative} negative samples" ) return edge_data, cell_data @@ -499,7 +516,9 @@ def build_data_features_and_labels( feat_vec[7:8] ) # Last dim = feature elif edge_task == "score": - labels.append(feat_vec[7:8]) # Last dim = label + # Convert score back from [-1, 1] to [0, 1] for standard regression + score_normalized = (feat_vec[7:8] + 1) / 2 + labels.append(score_normalized) features.append( feat_vec[:7] ) # First 7 dims = features @@ -521,7 +540,8 @@ def build_data_features_and_labels( cell_tuple = tuple(sorted(cell)) binary_existence_val = cell_data[rank][cell_tuple] - # Features: 0 for target, {-1,+1} for non-target TODO: Bit unsure about this. Non-interacting edges also get 0 and 0 means it will not influence neighbors + # Features: 0 for target, {-1,+1} for non-target TODO: Bit unsure about this. + # Labels (only target): {0, 1} for PyTorch CrossEntropyLoss # TODO: Maybe we should pass some labels as features for true transductivity/semi-supervision? 
if is_target: # Target rank: features are 0, labels are in {-1, 1} @@ -542,10 +562,12 @@ def build_data_features_and_labels( x_dict[f"x_{rank}"] = torch.stack(features) if is_target: - labels_tensor = torch.tensor(labels, dtype=torch.long) - labels_dict[f"cell_labels_{rank}"] = labels_tensor - n_pos = (labels_tensor == 1).sum().item() - n_neg = (labels_tensor == 0).sum().item() + labels_tensor = torch.stack(labels).squeeze() + # Convert {-1, +1} → {0, 1} for PyTorch CrossEntropyLoss + labels_mapped = ((labels_tensor + 1) / 2).long() + labels_dict[f"cell_labels_{rank}"] = labels_mapped + n_pos = (labels_mapped == 1).sum().item() + n_neg = (labels_mapped == 0).sum().item() print( f" Rank {rank}: {n_pos} positive, {n_neg} negative labels" ) From 2487ee739860617a48bf5ac045e20c629633e2ee Mon Sep 17 00:00:00 2001 From: I745505 Date: Wed, 19 Nov 2025 00:01:59 +0100 Subject: [PATCH 04/32] :waste_basket: Remove unused imports --- test/data/load/test_datasetloaders.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/data/load/test_datasetloaders.py b/test/data/load/test_datasetloaders.py index cb21fd421..db838e99b 100644 --- a/test/data/load/test_datasetloaders.py +++ b/test/data/load/test_datasetloaders.py @@ -1,12 +1,9 @@ """Comprehensive test suite for all dataset loaders.""" -import os import pytest -import torch import hydra from pathlib import Path from typing import List, Tuple, Dict, Any -from omegaconf import DictConfig -from topobench.data.preprocessor.preprocessor import PreProcessor + class TestLoaders: """Comprehensive test suite for all dataset loaders.""" From c99fb659584fe073720d6e4ec802c68a8d3376e0 Mon Sep 17 00:00:00 2001 From: I745505 Date: Wed, 19 Nov 2025 11:03:49 +0100 Subject: [PATCH 05/32] :bug: 1D tensor label for edge-level regression --- .../data/utils/datasets/simplicial/ppi_utils.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/topobench/data/utils/datasets/simplicial/ppi_utils.py 
b/topobench/data/utils/datasets/simplicial/ppi_utils.py index 8271b14f3..d9539c9c2 100644 --- a/topobench/data/utils/datasets/simplicial/ppi_utils.py +++ b/topobench/data/utils/datasets/simplicial/ppi_utils.py @@ -517,7 +517,9 @@ def build_data_features_and_labels( ) # Last dim = feature elif edge_task == "score": # Convert score back from [-1, 1] to [0, 1] for standard regression - score_normalized = (feat_vec[7:8] + 1) / 2 + score_normalized = ( + (feat_vec[7] + 1) / 2 + ).item() # Scalar value labels.append(score_normalized) features.append( feat_vec[:7] @@ -529,7 +531,14 @@ def build_data_features_and_labels( x_dict["x_1"] = torch.stack(features) if is_target: - labels_dict["cell_labels_1"] = torch.stack(labels) + if edge_task == "score": + # For regression: 1D tensor of scalar values + labels_dict["cell_labels_1"] = torch.tensor( + labels, dtype=torch.float + ) + else: + # For multi-label classification: 2D tensor + labels_dict["cell_labels_1"] = torch.stack(labels) case _: # Higher-order cells From 57728a0e3d7682a648c7ea657a7239a8e262f952 Mon Sep 17 00:00:00 2001 From: I745505 Date: Wed, 19 Nov 2025 11:22:11 +0100 Subject: [PATCH 06/32] :construction: Prepare cell-level prediction --- topobench/model/model.py | 4 ++-- topobench/nn/readouts/base.py | 9 ++++++--- topobench/run.py | 6 ++++-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/topobench/model/model.py b/topobench/model/model.py index a7c688b47..b708d96e8 100755 --- a/topobench/model/model.py +++ b/topobench/model/model.py @@ -242,8 +242,8 @@ def process_outputs(self, model_out: dict, batch: Data) -> dict: else: raise ValueError("Invalid state_str") - if self.task_level == "node": - # Keep only train data points + if self.task_level in ["node", "cell"]: + # Keep only train data points (for node-level or cell-level tasks) for key, val in model_out.items(): if key in ["logits", "labels"]: model_out[key] = val[mask] diff --git a/topobench/nn/readouts/base.py b/topobench/nn/readouts/base.py 
index 6fdd8412f..45c5ac69a 100755 --- a/topobench/nn/readouts/base.py +++ b/topobench/nn/readouts/base.py @@ -42,7 +42,9 @@ def __init__( if hidden_dim != out_channels or logits_linear_layer else torch.nn.Identity() ) - assert task_level in ["graph", "node"], "Invalid task_level" + assert task_level in ["graph", "node", "cell"], ( + "Invalid task_level. Must be 'graph', 'node', or 'cell'." + ) self.task_level = task_level self.logits_linear_layer = logits_linear_layer @@ -84,7 +86,7 @@ def compute_logits(self, x, batch): Parameters ---------- x : torch.Tensor - Node embeddings. + Cell embeddings. batch : torch.Tensor Batch index tensor. @@ -94,8 +96,9 @@ def compute_logits(self, x, batch): Logits tensor. """ if self.task_level == "graph": + # Graph-level: pool across batch (one prediction per graph) x = scatter(x, batch, dim=0, reduce=self.pooling_type) - + # Cell-level: no pooling (one prediction per Cell) return self.linear(x) @abstractmethod diff --git a/topobench/run.py b/topobench/run.py index ab6f8602a..a3e749411 100755 --- a/topobench/run.py +++ b/topobench/run.py @@ -171,7 +171,7 @@ def run(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: ) # Prepare datamodule log.info("Instantiating datamodule...") - if cfg.dataset.parameters.task_level in ["node", "graph"]: + if cfg.dataset.parameters.task_level in ["node", "graph", "cell"]: datamodule = TBDataloader( dataset_train=dataset_train, dataset_val=dataset_val, @@ -179,7 +179,9 @@ def run(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: **cfg.dataset.get("dataloader_params", {}), ) else: - raise ValueError("Invalid task_level") + raise ValueError( + f"Invalid task_level: {cfg.dataset.parameters.task_level}. Must be 'node', 'graph', or 'cell'." 
+ ) # Model for us is Network + logic: inputs backbone, readout, losses log.info(f"Instantiating model <{cfg.model._target_}>") From 92a60c268fbc89eee3d4f64572af19a5f081ab42 Mon Sep 17 00:00:00 2001 From: I745505 Date: Wed, 19 Nov 2025 18:24:59 +0100 Subject: [PATCH 07/32] :children_crossing: Infer num features --- configs/dataset/simplicial/ppi_highppi.yaml | 8 ++-- topobench/run.py | 10 ++++- topobench/utils/config_resolvers.py | 48 +++++++++++++++++++++ 3 files changed, 59 insertions(+), 7 deletions(-) diff --git a/configs/dataset/simplicial/ppi_highppi.yaml b/configs/dataset/simplicial/ppi_highppi.yaml index a0f36793e..9960618f2 100644 --- a/configs/dataset/simplicial/ppi_highppi.yaml +++ b/configs/dataset/simplicial/ppi_highppi.yaml @@ -49,11 +49,9 @@ loader: # Model training configuration parameters: - # Feature dimensions: [rank-0, rank-1, ..., rank-max] - # rank-0: One-hot encoded proteins TODO: Replace with richer embedding - # rank-1: 8-dim edge features (7 interaction types + 1 confidence score) - # rank-2+: 1-dim features (binary existence) - num_features: [1553, 8, 1, 1, 1, 1] + _num_proteins: 1553 # HIGH-PPI has 1,553 proteins + + num_features: ${infer_ppi_num_features:${dataset.parameters._num_proteins},${dataset.loader.parameters.edge_task},${dataset.loader.parameters.max_complex_size}} num_classes: 2 # Depends on task: # - Higher-order (ranks 2+): 2 (exists/doesn't exist) diff --git a/topobench/run.py b/topobench/run.py index ab6f8602a..72693f9f9 100755 --- a/topobench/run.py +++ b/topobench/run.py @@ -34,6 +34,7 @@ get_required_lifting, infer_in_channels, infer_num_cell_dimensions, + infer_ppi_num_features, infer_topotune_num_cell_dimensions, ) @@ -95,6 +96,9 @@ infer_topotune_num_cell_dimensions, replace=True, ) +OmegaConf.register_new_resolver( + "infer_ppi_num_features", infer_ppi_num_features, replace=True +) OmegaConf.register_new_resolver( "parameter_multiplication", lambda x, y: int(int(x) * int(y)), replace=True ) @@ -171,7 +175,7 @@ def 
run(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: ) # Prepare datamodule log.info("Instantiating datamodule...") - if cfg.dataset.parameters.task_level in ["node", "graph"]: + if cfg.dataset.parameters.task_level in ["node", "graph", "cell"]: datamodule = TBDataloader( dataset_train=dataset_train, dataset_val=dataset_val, @@ -179,7 +183,9 @@ def run(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: **cfg.dataset.get("dataloader_params", {}), ) else: - raise ValueError("Invalid task_level") + raise ValueError( + f"Invalid task_level: {cfg.dataset.parameters.task_level}. Must be 'node', 'graph', or 'cell'." + ) # Model for us is Network + logic: inputs backbone, readout, losses log.info(f"Instantiating model <{cfg.model._target_}>") diff --git a/topobench/utils/config_resolvers.py b/topobench/utils/config_resolvers.py index 65ab69667..bafb6f6dd 100644 --- a/topobench/utils/config_resolvers.py +++ b/topobench/utils/config_resolvers.py @@ -497,3 +497,51 @@ def get_default_metrics(task, metrics=None): return ["mse", "mae"] else: raise ValueError(f"Invalid task {task}") + + +def infer_ppi_num_features(num_proteins, edge_task, max_complex_size): + r"""Infer feature dimensions for HIGH-PPI dataset. + + For simplicial complexes from HIGH-PPI: + - Rank 0 (proteins): One-hot encoding (num_proteins features) + - Rank 1 (edges): 7 or 8 features depending on edge_task + - If edge_task="score": 7 features (8th is label) + - If edge_task="type": 8 features (all features) + - Rank 2+: 1 feature (binary existence) + + Parameters + ---------- + num_proteins : int + Number of proteins (for one-hot encoding dimension). + edge_task : str + Edge task type: "score" (regression) or "type" (classification). + max_complex_size : int + Maximum number of proteins per complex (determines number of ranks). + + Returns + ------- + list + List of feature dimensions per rank: [rank_0, rank_1, rank_2, ...]. 
+ + Examples + -------- + >>> infer_ppi_num_features(1553, "score", 6) + [1553, 7, 1, 1, 1, 1] # 7 edge features (score is label) + + >>> infer_ppi_num_features(1553, "type", 6) + [1553, 1, 1, 1, 1, 1] # 1 edge feature (confidence; the 7 interaction types become labels) + """ + # Rank 0: protein features (one-hot) + features = [num_proteins] + + # Rank 1: edge features (depends on task) + if edge_task == "score": + features.append(7) # 7 features, 8th (score) becomes label + else: + features.append(1) + + # Rank 2+: cell features (binary existence) + num_higher_ranks = max_complex_size - 2 # Subtract rank 0 and rank 1 + features.extend([1] * num_higher_ranks) + + return features From b705e4b01a61d02b0839467f2d9b63a4d9dfed2e Mon Sep 17 00:00:00 2001 From: I745505 Date: Wed, 19 Nov 2025 18:48:36 +0100 Subject: [PATCH 08/32] :technologist: Filter labels to high-ppi edges + infrastructure for general masking which can be used for semi-supervision --- .../data/datasets/ppi_highppi_dataset.py | 6 + .../utils/datasets/simplicial/ppi_utils.py | 17 ++ topobench/data/utils/split_utils.py | 165 ++++++++++++++---- topobench/model/model.py | 24 ++- 4 files changed, 177 insertions(+), 35 deletions(-) diff --git a/topobench/data/datasets/ppi_highppi_dataset.py b/topobench/data/datasets/ppi_highppi_dataset.py index bc439b340..e8558941f 100644 --- a/topobench/data/datasets/ppi_highppi_dataset.py +++ b/topobench/data/datasets/ppi_highppi_dataset.py @@ -253,6 +253,11 @@ def process(self): highppi_path, self.INTERACTION_TYPES ) + # Create set of HIGH-PPI edges to use as prediction-target for edge level tasks + highppi_edge_set = { + tuple(sorted([p1, p2])) for p1, p2, _, _ in self.highppi_edges + } + # Load CORUM complexes, filter to SHS27k proteins corum_path = osp.join(self.raw_dir, self.CORUM_COMPLEXES_FILE) self.corum_complexes = load_corum_complexes( @@ -288,6 +293,7 @@ def process(self): self.target_ranks, self.max_rank, edge_task=self.edge_task, + highppi_edge_set=highppi_edge_set, ) # Get connectivity diff --git
a/topobench/data/utils/datasets/simplicial/ppi_utils.py b/topobench/data/utils/datasets/simplicial/ppi_utils.py index d9539c9c2..6218d4742 100644 --- a/topobench/data/utils/datasets/simplicial/ppi_utils.py +++ b/topobench/data/utils/datasets/simplicial/ppi_utils.py @@ -449,6 +449,7 @@ def build_data_features_and_labels( target_ranks: list[int], max_rank: int, edge_task: str = None, + highppi_edge_set: set[tuple] = None, ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: """Create feature and label tensors for all ranks. @@ -467,6 +468,9 @@ def build_data_features_and_labels( edge_task : str, optional Edge prediction task: "interaction_type" or "score". Only used if rank 1 is in target_ranks. + highppi_edge_set : set, optional + Set of edge tuples from HIGH-PPI. + Used to filter edges for edge-level tasks. Returns ------- @@ -501,12 +505,18 @@ def build_data_features_and_labels( # Edges: 8-dim features (7 interaction types + 1 confidence) features = [] labels = [] if is_target else None + # Track HIGH-PPI edges for loss filtering (all edge tasks) + highppi_mask = [] if (is_target and highppi_edge_set) else None for edge in cells: edge_tuple = tuple(sorted(edge)) feat_vec = edge_data[edge_tuple] if is_target: + # Track if this edge is from HIGH-PPI (not CORUM-generated) + if highppi_mask is not None: + highppi_mask.append(edge_tuple in highppi_edge_set) + # Split features/labels based on edge_task if edge_task == "interaction_type": labels.append( @@ -540,6 +550,13 @@ def build_data_features_and_labels( # For multi-label classification: 2D tensor labels_dict["cell_labels_1"] = torch.stack(labels) + # Store mask to filter out CORUM-generated edges during training + # This applies to ALL edge tasks (score regression, interaction type, etc.) 
+ if highppi_mask is not None: + labels_dict["mask_1"] = torch.tensor( + highppi_mask, dtype=torch.bool + ) + case _: # Higher-order cells features = [] diff --git a/topobench/data/utils/split_utils.py b/topobench/data/utils/split_utils.py index cf63b50a4..e1722c94a 100644 --- a/topobench/data/utils/split_utils.py +++ b/topobench/data/utils/split_utils.py @@ -1,8 +1,10 @@ """Split utilities.""" import os +from typing import Any import numpy as np +import pandas as pd import torch from sklearn.model_selection import StratifiedKFold @@ -155,9 +157,15 @@ def random_splitting(labels, parameters, root=None, global_data_seed=42): val_indices = perm[train_num : train_num + valid_num] test_indices = perm[train_num + valid_num :] split_idx = { - "train": train_indices, - "valid": val_indices, - "test": test_indices, + "train": train_indices.numpy() + if hasattr(train_indices, "numpy") + else np.array(train_indices), + "valid": val_indices.numpy() + if hasattr(val_indices, "numpy") + else np.array(val_indices), + "test": test_indices.numpy() + if hasattr(test_indices, "numpy") + else np.array(test_indices), } # Save generated split @@ -169,13 +177,16 @@ def random_splitting(labels, parameters, root=None, global_data_seed=42): split_idx = np.load(split_path) # Check that all nodes/graph have been assigned to some split - assert np.unique( - np.array( - split_idx["train"].tolist() - + split_idx["valid"].tolist() - + split_idx["test"].tolist() - ) - ).shape[0] == len(labels), "Not all nodes within splits" + train_arr = split_idx["train"] + val_arr = split_idx["valid"] + test_arr = split_idx["test"] + + all_indices = np.concatenate([train_arr, val_arr, test_arr]) + unique_indices = np.unique(all_indices) + + assert unique_indices.shape[0] == len(labels), ( + f"Not all nodes within splits: {unique_indices.shape[0]} != {len(labels)}" + ) return split_idx @@ -257,16 +268,40 @@ def load_transductive_splits(dataset, parameters): # Single rank or node/graph prediction labels = 
data.y.numpy() - # Ensure labels are one dimensional array - assert len(labels.shape) == 1, "Labels should be one dimensional array" + # Check for rank-specific mask (e.g., mask_1 for edges) + # If present, split only on filtered/valid entities for honest ratios + rank_mask = None + valid_indices = None + target_ranks = getattr( + dataset, "target_ranks", [1] + ) # Default to rank 1 for edges + + if target_ranks: + rank = target_ranks[0] # Single rank case + mask_attr = f"mask_{rank}" + if hasattr(data, mask_attr): + rank_mask = getattr(data, mask_attr) # Boolean mask + valid_indices = torch.where(rank_mask)[ + 0 + ] # Original indices of valid entities + labels = labels[rank_mask.numpy()] # Filter to valid entities only + + # Handle multi-dimensional labels (e.g., multi-label classification) + if len(labels.shape) > 1: + # Use first column for stratification (common practice) + stratify_labels = ( + labels[:, 0] if labels.shape[1] > 0 else labels.flatten() + ) + else: + stratify_labels = labels root = dataset.get_data_dir() if hasattr(dataset, "get_data_dir") else None if parameters.split_type == "random": - splits = random_splitting(labels, parameters, root=root) + splits = random_splitting(stratify_labels, parameters, root=root) elif parameters.split_type == "k-fold": - splits = k_fold_split(labels, parameters, root=root) + splits = k_fold_split(stratify_labels, parameters, root=root) elif parameters.split_type == "fixed" and hasattr(dataset, "split_idx"): splits = dataset.split_idx @@ -283,9 +318,21 @@ def load_transductive_splits(dataset, parameters): ) # Assign train val test masks to the graph - data.train_mask = torch.from_numpy(splits["train"]) - data.val_mask = torch.from_numpy(splits["valid"]) - data.test_mask = torch.from_numpy(splits["test"]) + # If we filtered by rank_mask, map indices back to original positions + if valid_indices is not None: + # Splits are indices into filtered data, map back to original + train_mask = 
valid_indices[torch.from_numpy(splits["train"])] + val_mask = valid_indices[torch.from_numpy(splits["valid"])] + test_mask = valid_indices[torch.from_numpy(splits["test"])] + else: + # No filtering: use indices directly + train_mask = torch.from_numpy(splits["train"]) + val_mask = torch.from_numpy(splits["valid"]) + test_mask = torch.from_numpy(splits["test"]) + + data.train_mask = train_mask + data.val_mask = val_mask + data.test_mask = test_mask if parameters.get("standardize", False): # Standardize the node features respecting train mask @@ -299,12 +346,55 @@ def load_transductive_splits(dataset, parameters): return DataloadDataset([data]), None, None -def load_multirank_transductive_splits(dataset, parameters): +def get_multilabel_stratification_targets( + labels: np.ndarray | pd.DataFrame, +) -> np.ndarray: + """Generate a single stratification target vector for multi-label data. + + For multi-label classification, uses the index of the most frequent label + per sample (argmax). This is simpler and more robust than Label Powerset, + avoiding issues with rare label combinations. + + Parameters + ---------- + labels : np.ndarray or pd.DataFrame + The multi-label target array (2D) or vector (1D). + Can be a NumPy array or Pandas DataFrame. + + Returns + ------- + np.ndarray + A 1D array suitable for the 'stratify' parameter in sklearn. + For 1D input: returns as-is. + For 2D input: returns argmax (most frequent label index). 
+ """ + # Standardize input to NumPy array + if isinstance(labels, pd.DataFrame): + labels = labels.values + + # Handle 1D arrays (standard classification) + if labels.ndim == 1: + return labels + + # Handle 2D arrays (multi-label classification) + if labels.ndim == 2 and labels.shape[1] > 1: + # Use argmax: index of most frequent label (or first '1' for binary) + # This ensures stratification works even with rare label combinations + return labels.argmax(axis=1) + + # Fallback for 2D arrays with single column + return labels.flatten() + + +def load_multirank_transductive_splits( + dataset, parameters +) -> tuple[list[Any], None, None]: r"""Load dataset with multi-rank cell-level splits. For datasets with cell-level predictions across multiple ranks (e.g., edges, triangles, tetrahedra simultaneously), this function creates independent - train/val/test splits for each rank. + train/val/test splits for each rank on valid entities (filtered by masks) + using multi-label stratification. Parameters ---------- @@ -350,14 +440,20 @@ def load_multirank_transductive_splits(dataset, parameters): labels = getattr(data, label_attr).numpy() - # Handle multi-dimensional labels (e.g., multi-label classification) - if len(labels.shape) > 1: - # Use first column for stratification (common practice) - stratify_labels = ( - labels[:, 0] if labels.shape[1] > 0 else labels.flatten() - ) - else: - stratify_labels = labels + # Check for rank-specific mask + # If present, split only on filtered entities for honest ratios + rank_mask = None + valid_indices = None + mask_attr = f"mask_{rank}" + + if hasattr(data, mask_attr): + rank_mask = getattr(data, mask_attr) # Boolean mask + valid_indices = torch.where(rank_mask)[ + 0 + ] # Original indices of valid entities + labels = labels[rank_mask.numpy()] # Filter to valid entities only + + stratify_labels = get_multilabel_stratification_targets(labels) # Create rank-specific root directory for splits # This ensures each rank gets independent 
splits @@ -387,14 +483,23 @@ def load_multirank_transductive_splits(dataset, parameters): ) # Store per-rank masks - train_mask = torch.from_numpy(splits["train"]) - val_mask = torch.from_numpy(splits["valid"]) - test_mask = torch.from_numpy(splits["test"]) + # If we filtered by rank_mask, map indices back to original positions + if valid_indices is not None: + # Splits are indices into filtered data, map back to original + train_mask = valid_indices[torch.from_numpy(splits["train"])] + val_mask = valid_indices[torch.from_numpy(splits["valid"])] + test_mask = valid_indices[torch.from_numpy(splits["test"])] + else: + # No filtering: use indices directly + train_mask = torch.from_numpy(splits["train"]) + val_mask = torch.from_numpy(splits["valid"]) + test_mask = torch.from_numpy(splits["test"]) setattr(data, f"train_mask_{rank}", train_mask) setattr(data, f"val_mask_{rank}", val_mask) setattr(data, f"test_mask_{rank}", test_mask) + # Assumes DataloadDataset is available in scope return DataloadDataset([data]), None, None diff --git a/topobench/model/model.py b/topobench/model/model.py index a7c688b47..b5f3313c0 100755 --- a/topobench/model/model.py +++ b/topobench/model/model.py @@ -242,11 +242,25 @@ def process_outputs(self, model_out: dict, batch: Data) -> dict: else: raise ValueError("Invalid state_str") - if self.task_level == "node": - # Keep only train data points - for key, val in model_out.items(): - if key in ["logits", "labels"]: - model_out[key] = val[mask] + if self.task_level in ["node", "cell"]: + # Keep only train data points (for node-level or cell-level tasks) + # Note: Rank-specific masks are applied in readout + # The readout stores which indices it kept in cell_indices + if "cell_indices" in model_out: + # Find intersection: which readout outputs are in this split? 
+ # Note: The split respects the masks applied in the readout + cell_indices = model_out["cell_indices"] + keep_mask = torch.isin(cell_indices, mask) + + # Filter logits and labels + for key, val in model_out.items(): + if key in ["logits", "labels"]: + model_out[key] = val[keep_mask] + else: + # No cell_indices: standard filtering (for non-masked tasks) + for key, val in model_out.items(): + if key in ["logits", "labels"]: + model_out[key] = val[mask] return model_out From 88efcb30ba4359ecc657f948ddb3dae66fff6cd3 Mon Sep 17 00:00:00 2001 From: I745505 Date: Wed, 19 Nov 2025 19:44:51 +0100 Subject: [PATCH 09/32] :sparkles: Extend SCCNN to work for arbitrary ranks --- topobench/nn/backbones/simplicial/sccnn.py | 523 +++++++++++---------- 1 file changed, 279 insertions(+), 244 deletions(-) diff --git a/topobench/nn/backbones/simplicial/sccnn.py b/topobench/nn/backbones/simplicial/sccnn.py index 72bac2902..cff5830dd 100644 --- a/topobench/nn/backbones/simplicial/sccnn.py +++ b/topobench/nn/backbones/simplicial/sccnn.py @@ -13,13 +13,13 @@ class SCCNNCustom(torch.nn.Module): Parameters ---------- in_channels_all : tuple of int - Dimension of input features on (nodes, edges, faces). + Dimension of input features on each rank (nodes, edges, faces, ...). hidden_channels_all : tuple of int - Dimension of features of hidden layers on (nodes, edges, faces). + Dimension of features of hidden layers on each rank. conv_order : int Order of convolutions, we consider the same order for all convolutions. sc_order : int - Order of simplicial complex. + Order of simplicial complex (max_rank + 1). aggr_norm : bool, optional Whether to normalize the aggregation (default: False). 
update_func : str, optional @@ -39,16 +39,15 @@ def __init__( n_layers=2, ): super().__init__() - # first layer - # we use an MLP to map the features on simplices of different dimensions to the same dimension - self.in_linear_0 = torch.nn.Linear( - in_channels_all[0], hidden_channels_all[0] - ) - self.in_linear_1 = torch.nn.Linear( - in_channels_all[1], hidden_channels_all[1] - ) - self.in_linear_2 = torch.nn.Linear( - in_channels_all[2], hidden_channels_all[2] + + self.max_rank = len(in_channels_all) - 1 + + # Create input linear layers for each rank dynamically + self.in_linears = torch.nn.ModuleList( + [ + torch.nn.Linear(in_channels_all[i], hidden_channels_all[i]) + for i in range(len(in_channels_all)) + ] ) self.layers = torch.nn.ModuleList( @@ -69,28 +68,29 @@ def forward(self, x_all, laplacian_all, incidence_all): Parameters ---------- x_all : tuple(tensors) - Tuple of feature tensors (node, edge, face). + Tuple of feature tensors for each rank (x_0, x_1, ..., x_k). laplacian_all : tuple(tensors) - Tuple of Laplacian tensors (graph laplacian L0, down edge laplacian L1_d, upper edge laplacian L1_u, face laplacian L2). + Tuple of Laplacian tensors. incidence_all : tuple(tensors) - Tuple of order 1 and 2 incidence matrices. + Tuple of incidence matrices. Returns ------- tuple(tensors) - Tuple of final hidden state tensors (node, edge, face). + Tuple of final hidden state tensors for each rank. 
""" - x_0, x_1, x_2 = x_all - in_x_0 = self.in_linear_0(x_0) - in_x_1 = self.in_linear_1(x_1) - in_x_2 = self.in_linear_2(x_2) + # Apply input linear transformations to each rank + x_all_transformed = tuple( + self.in_linears[i](x_all[i]) for i in range(len(x_all)) + ) - # Forward through SCCNN - x_all = (in_x_0, in_x_1, in_x_2) + # Forward through SCCNN layers for layer in self.layers: - x_all = layer(x_all, laplacian_all, incidence_all) + x_all_transformed = layer( + x_all_transformed, laplacian_all, incidence_all + ) - return x_all + return x_all_transformed class SCCNNLayer(torch.nn.Module): @@ -99,13 +99,13 @@ class SCCNNLayer(torch.nn.Module): Parameters ---------- in_channels : tuple of int - Dimensions of input features on nodes, edges, and faces. + Dimensions of input features for each rank. out_channels : tuple of int - Dimensions of output features on nodes, edges, and faces. + Dimensions of output features for each rank. conv_order : int Convolution order of the simplicial filters. sc_order : int - SC order. + SC order (max_rank + 1). aggr_norm : bool, optional Whether to normalize the aggregated message by the neighborhood size (default: False). 
update_func : str, optional @@ -126,15 +126,9 @@ def __init__( ) -> None: super().__init__() - in_channels_0, in_channels_1, in_channels_2 = in_channels - out_channels_0, out_channels_1, out_channels_2 = out_channels - - self.in_channels_0 = in_channels_0 - self.in_channels_1 = in_channels_1 - self.in_channels_2 = in_channels_2 - self.out_channels_0 = out_channels_0 - self.out_channels_1 = out_channels_1 - self.out_channels_2 = out_channels_2 + self.in_channels = tuple(in_channels) + self.out_channels = tuple(out_channels) + self.max_rank = len(in_channels) - 1 self.conv_order = conv_order self.sc_order = sc_order @@ -146,47 +140,65 @@ def __init__( assert initialization in ["xavier_uniform", "xavier_normal"] assert self.conv_order > 0 - self.weight_0 = Parameter( - torch.Tensor( - self.in_channels_0, - self.out_channels_0, - 1 + conv_order + 1 + conv_order, - ) - ) - - self.weight_1 = Parameter( - torch.Tensor( - self.in_channels_1, - self.out_channels_1, - 6 * conv_order + 3, - ) - ) + # Create weight parameters for each rank + self.weights = torch.nn.ParameterList() - # determine the third dimensions of the weights - # because when SC order is larger than 2, there are lower and upper - # parts for L_2; otherwise, L_2 contains only the lower part + for rank in range(self.max_rank + 1): + # Calculate weight tensor dimensions based on message types + # For rank k, we have: + # - Identity: 1 + # - Self convolutions: conv_order (down) + conv_order (up) for k>0, or just conv_order for k=0 + # - Lower messages (from k-1): 1 + conv_order (identity + convolutions) + # - Upper messages (from k+1): 1 + conv_order (identity + convolutions) - if sc_order > 2: - self.weight_2 = Parameter( - torch.Tensor( - self.in_channels_2, - self.out_channels_2, - 4 * conv_order - + 2, # in the future for arbitrary sc_order we should have this 6*conv_order + 3, - ) - ) + num_message_types = self._compute_message_types(rank) - elif sc_order == 2: - self.weight_2 = Parameter( + weight = 
Parameter( torch.Tensor( - self.in_channels_2, - self.out_channels_2, - 4 * conv_order + 2, + self.in_channels[rank], + self.out_channels[rank], + num_message_types, ) ) + self.weights.append(weight) self.reset_parameters() + def _compute_message_types(self, rank): + """Compute the maximum number of message types for a given rank. + + Parameters + ---------- + rank : int + Rank to consider. + + Returns + ------- + int + Number of message types for the given rank. + """ + count = 0 + + # Self messages + if rank == 0: + # Rank 0: identity + Hodge Laplacian convolutions + count += 1 + self.conv_order + else: + # Rank k>0: identity + down Laplacian + up Laplacian convolutions + count += 1 + self.conv_order + self.conv_order + + # Lower messages (from rank-1 projected to rank) + if rank > 0: + # Identity + convolutions with down/up Laplacians at current rank + count += 1 + self.conv_order + + # Upper messages (from rank+1 projected to rank) + if rank < self.max_rank: + # Identity + convolutions with down/up Laplacians at current rank + count += 1 + self.conv_order + + return count + def reset_parameters(self, gain: float = 1.414): r"""Reset learnable parameters. @@ -196,13 +208,11 @@ def reset_parameters(self, gain: float = 1.414): Gain for the weight initialization. """ if self.initialization == "xavier_uniform": - torch.nn.init.xavier_uniform_(self.weight_0, gain=gain) - torch.nn.init.xavier_uniform_(self.weight_1, gain=gain) - torch.nn.init.xavier_uniform_(self.weight_2, gain=gain) + for weight in self.weights: + torch.nn.init.xavier_uniform_(weight, gain=gain) elif self.initialization == "xavier_normal": - torch.nn.init.xavier_normal_(self.weight_0, gain=gain) - torch.nn.init.xavier_normal_(self.weight_1, gain=gain) - torch.nn.init.xavier_normal_(self.weight_2, gain=gain) + for weight in self.weights: + torch.nn.init.xavier_normal_(weight, gain=gain) else: raise RuntimeError( "Initialization method not recognized. 
" @@ -286,201 +296,226 @@ def chebyshev_conv(self, conv_operator, conv_order, x): return X def forward(self, x_all, laplacian_all, incidence_all): - r"""Forward computation. + r"""Forward computation for arbitrary ranks. Parameters ---------- x_all : tuple of tensors - Tuple of input feature tensors (node, edge, face). + Tuple of input feature tensors for each rank. laplacian_all : tuple of tensors - Tuple of Laplacian tensors (graph laplacian L0, down edge laplacian L1_d, upper edge laplacian L1_u, face laplacian L2). + Tuple of Laplacian tensors organized as: + (L_0, L_down_1, L_up_1, L_down_2, L_up_2, ...). incidence_all : tuple of tensors - Tuple of order 1 and 2 incidence matrices. + Tuple of incidence matrices (B_1, B_2, ..., B_k). Returns ------- - torch.Tensor - Output tensor for each 0-cell. - torch.Tensor - Output tensor for each 1-cell. - torch.Tensor - Output tensor for each 2-cell. - """ - x_0, x_1, x_2 = x_all - - if self.sc_order == 2: - laplacian_0, laplacian_down_1, laplacian_up_1, laplacian_2 = ( - laplacian_all - ) - elif self.sc_order > 2: - ( - laplacian_0, - laplacian_down_1, - laplacian_up_1, - laplacian_down_2, - laplacian_up_2, - ) = laplacian_all - - # num_nodes, num_edges, num_triangles = x_0.shape[0], x_1.shape[0], x_2.shape[0] - - b1, b2 = incidence_all - - # identity_0, identity_1, identity_2 = ( - # torch.eye(num_nodes).to(x_0.device), - # torch.eye(num_edges).to(x_0.device), - # torch.eye(num_triangles).to(x_0.device), - # ) - """ - Convolution in the node space + tuple of tensors + Output tensors for each rank after message passing. 
""" - # -----------Logic to obtain update for 0-cells -------- - # x_identity_0 = torch.unsqueeze(identity_0 @ x_0, 2) - # x_0_to_0 = self.chebyshev_conv(laplacian_0, self.conv_order, x_0) - # x_0_to_0 = torch.cat((x_identity_0, x_0_to_0), 2) - - x_0_laplacian = self.chebyshev_conv(laplacian_0, self.conv_order, x_0) - x_0_to_0 = torch.cat([x_0.unsqueeze(2), x_0_laplacian], dim=2) - # ------------------- - - # x_1_to_0 = torch.mm(b1, x_1) - # x_1_to_0_identity = torch.unsqueeze(identity_0 @ x_1_to_0, 2) - # x_1_to_0 = self.chebyshev_conv(laplacian_0, self.conv_order, x_1_to_0) - # x_1_to_0 = torch.cat((x_1_to_0_identity, x_1_to_0), 2) - - x_1_to_0_upper = torch.mm(b1, x_1) - x_1_to_0_laplacian = self.chebyshev_conv( - laplacian_0, self.conv_order, x_1_to_0_upper - ) - x_1_to_0 = torch.cat( - [x_1_to_0_upper.unsqueeze(2), x_1_to_0_laplacian], dim=2 - ) - # ------------------- - - x_0_all = torch.cat((x_0_to_0, x_1_to_0), 2) - - # ------------------- - """ - Convolution in the edge space - """ - - # -----------Logic to obtain update for 1-cells -------- - # x_identity_1 = torch.unsqueeze(identity_1 @ x_1, 2) - # x_1_down = self.chebyshev_conv(laplacian_down_1, self.conv_order, x_1) - # x_1_up = self.chebyshev_conv(laplacian_up_1, self.conv_order, x_1) - # x_1_to_1 = torch.cat((x_identity_1, x_1_down, x_1_up), 2) - - x_1_down = self.chebyshev_conv(laplacian_down_1, self.conv_order, x_1) - x_1_up = self.chebyshev_conv(laplacian_down_1, self.conv_order, x_1) - x_1_to_1 = torch.cat((x_1.unsqueeze(2), x_1_down, x_1_up), 2) - - # ------------------- - - # x_0_to_1 = torch.mm(b1.T, x_0) - # x_0_to_1_identity = torch.unsqueeze(identity_1 @ x_0_to_1, 2) - # x_0_to_1 = self.chebyshev_conv(laplacian_down_1, self.conv_order, x_0_to_1) - # x_0_to_1 = torch.cat((x_0_to_1_identity, x_0_to_1), 2) + outputs = [] + + for rank in range(self.max_rank + 1): + x_rank = x_all[rank] + + # Skip empty ranks (no cells at this dimension) + if x_rank.shape[0] == 0: + # Create empty output tensor 
for this rank + outputs.append( + torch.zeros( + 0, self.out_channels[rank], device=x_rank.device + ) + ) + continue - # Lower projection - x_0_1_lower = torch.mm(b1.T, x_0) + # Get Laplacians for this rank + laplacians = self._get_laplacians_for_rank(rank, laplacian_all) - # Calculate lowwer chebyshev_conv - x_0_1_down = self.chebyshev_conv( - laplacian_down_1, self.conv_order, x_0_1_lower - ) + # Get incidence matrices + incidence_lower = ( + incidence_all[rank - 1] + if rank > 0 and rank - 1 < len(incidence_all) + else None + ) + incidence_upper = ( + incidence_all[rank] if rank < len(incidence_all) else None + ) - # Calculate upper chebyshev_conv (Note: in case of signed incidence should be always zero) - x_0_1_up = self.chebyshev_conv( - laplacian_up_1, self.conv_order, x_0_1_lower - ) + # Compute all messages for this rank + messages = self._compute_messages_for_rank( + rank, + x_rank, + x_all, + laplacians, + incidence_lower, + incidence_upper, + ) - # Concatenate output of filters - x_0_to_1 = torch.cat( - [x_0_1_lower.unsqueeze(2), x_0_1_down, x_0_1_up], dim=2 - ) - # ------------------- + # Apply weight and aggregate + # Use only the first k dimensions of weights that match the number of messages + num_messages = messages.shape[2] + weight_slice = self.weights[rank][:, :, :num_messages] + y_rank = torch.einsum("nik,iok->no", messages, weight_slice) - # x_2_to_1 = torch.mm(b2, x_2) - # x_2_to_1_identity = torch.unsqueeze(identity_1 @ x_2_to_1, 2) - # x_2_to_1 = self.chebyshev_conv(laplacian_up_1, self.conv_order, x_2_to_1) - # x_2_to_1 = torch.cat((x_2_to_1_identity, x_2_to_1), 2) + # Apply activation if specified + if self.update_func is not None: + y_rank = self.update(y_rank) - x_2_1_upper = torch.mm(b2, x_2) + outputs.append(y_rank) - # Calculate lowwer chebyshev_conv (Note: In case of signed incidence should be always zero) - x_2_1_down = self.chebyshev_conv( - laplacian_down_1, self.conv_order, x_2_1_upper - ) + return tuple(outputs) - # Calculate 
upper chebyshev_conv - x_2_1_up = self.chebyshev_conv( - laplacian_up_1, self.conv_order, x_2_1_upper - ) + def _get_laplacians_for_rank(self, rank, laplacian_all): + """Extract Laplacians for a given rank from laplacian_all. - x_2_to_1 = torch.cat( - [x_2_1_upper.unsqueeze(2), x_2_1_down, x_2_1_up], dim=2 - ) + Parameters + ---------- + rank : int + The rank to extract Laplacians for. + laplacian_all : tuple + All Laplacians organized as (L_0, L_down_1, L_up_1, L_down_2, L_up_2, ...). - # ------------------- - x_1_all = torch.cat((x_0_to_1, x_1_to_1, x_2_to_1), 2) - """Convolution in the face (triangle) space, depending on the SC order, - the exact form maybe a little different.""" - # -------------------Logic to obtain update for 2-cells -------- - # x_identity_2 = torch.unsqueeze(identity_2 @ x_2, 2) - - # if self.sc_order == 2: - # x_2 = self.chebyshev_conv(laplacian_2, self.conv_order, x_2) - # x_2_to_2 = torch.cat((x_identity_2, x_2), 2) - # elif self.sc_order > 2: - # x_2_down = self.chebyshev_conv(laplacian_down_2, self.conv_order, x_2) - # x_2_up = self.chebyshev_conv(laplacian_up_2, self.conv_order, x_2) - # x_2_to_2 = torch.cat((x_identity_2, x_2_down, x_2_up), 2) - x_2_down = self.chebyshev_conv(laplacian_down_2, self.conv_order, x_2) - x_2_up = self.chebyshev_conv(laplacian_up_2, self.conv_order, x_2) - x_2_to_2 = torch.cat((x_2.unsqueeze(2), x_2_down, x_2_up), 2) - - # ------------------- - - # x_1_to_2 = torch.mm(b2.T, x_1) - # x_1_to_2_identity = torch.unsqueeze(identity_2 @ x_1_to_2, 2) - # if self.sc_order == 2: - # x_1_to_2 = self.chebyshev_conv(laplacian_2, self.conv_order, x_1_to_2) - # elif self.sc_order > 2: - # x_1_to_2 = self.chebyshev_conv(laplacian_down_2, self.conv_order, x_1_to_2) - # x_1_to_2 = torch.cat((x_1_to_2_identity, x_1_to_2), 2) - - x_1_2_lower = torch.mm(b2.T, x_1) - x_1_2_down = self.chebyshev_conv( - laplacian_down_2, self.conv_order, x_1_2_lower - ) - x_1_2_down = self.chebyshev_conv( - laplacian_up_2, self.conv_order, 
x_1_2_lower - ) + Returns + ------- + dict + Dictionary with keys 'hodge', 'down', 'up' containing the relevant Laplacians. + """ + laplacians = {} - x_1_to_2 = torch.cat( - [x_1_2_lower.unsqueeze(2), x_1_2_down, x_1_2_down], dim=2 - ) + if rank == 0: + # Rank 0 only has Hodge Laplacian + laplacians["hodge"] = ( + laplacian_all[0] if len(laplacian_all) > 0 else None + ) + laplacians["down"] = None + laplacians["up"] = None + else: + # For rank k > 0: index is 1 + 2*(k-1) for down, 1 + 2*(k-1) + 1 for up + idx_down = 1 + 2 * (rank - 1) + idx_up = idx_down + 1 + + laplacians["hodge"] = None + laplacians["down"] = ( + laplacian_all[idx_down] + if idx_down < len(laplacian_all) + else None + ) + laplacians["up"] = ( + laplacian_all[idx_up] if idx_up < len(laplacian_all) else None + ) - # That is my code, but to execute this part we need to have simplices order of k+1 in this case order of 3 - # x_3_2_upper = x_1_to_2 = torch.mm(b2, x_3) - # x_3_2_down = self.chebyshev_conv(laplacian_down_2, self.conv_order, x_3_2_upper) - # x_3_2_up = self.chebyshev_conv(laplacian_up_2, self.conv_order, x_3_2_upper) + return laplacians - # x_3_to_2 = torch.cat([x_3_2_upper.unsueeze(2), x_3_2_down, x_3_2_up], dim=2) + def _compute_messages_for_rank( + self, rank, x_rank, x_all, laplacians, incidence_lower, incidence_upper + ): + """Compute all messages for a given rank. - # ------------------- + Parameters + ---------- + rank : int + The rank to compute messages for. + x_rank : tensor + Features of cells at this rank. + x_all : tuple + Features of all ranks. + laplacians : dict + Dictionary of Laplacians for this rank. + incidence_lower : tensor or None + Incidence matrix from rank-1 to rank. + incidence_upper : tensor or None + Incidence matrix from rank to rank+1. 
- x_2_all = torch.cat([x_1_to_2, x_2_to_2], dim=2) - # The final version of this model should have the following line - # x_2_all = torch.cat([x_1_to_2, x_2_to_2, x_3_to_2], dim=2) + Returns + ------- + tensor + Concatenated messages of shape (num_cells, num_channels, num_message_types). + """ + message_list = [] - # ------------------- + # 1. Self messages (identity + convolutions) + if rank == 0: + # Identity message + message_list.append(x_rank.unsqueeze(2)) - # Need to check that this einsums are correct - y_0 = torch.einsum("nik,iok->no", x_0_all, self.weight_0) - y_1 = torch.einsum("nik,iok->no", x_1_all, self.weight_1) - y_2 = torch.einsum("nik,iok->no", x_2_all, self.weight_2) + # Hodge Laplacian convolutions + if laplacians["hodge"] is not None: + x_conv = self.chebyshev_conv( + laplacians["hodge"], self.conv_order, x_rank + ) + message_list.append(x_conv) + else: + # Identity message + message_list.append(x_rank.unsqueeze(2)) - if self.update_func is None: - return y_0, y_1, y_2 + # Down Laplacian convolutions + if laplacians["down"] is not None: + x_down = self.chebyshev_conv( + laplacians["down"], self.conv_order, x_rank + ) + message_list.append(x_down) - return self.update(y_0), self.update(y_1), self.update(y_2) + # Up Laplacian convolutions + if laplacians["up"] is not None: + x_up = self.chebyshev_conv( + laplacians["up"], self.conv_order, x_rank + ) + message_list.append(x_up) + + # 2. 
Lower messages (from rank-1) + if rank > 0 and incidence_lower is not None and rank - 1 < len(x_all): + x_lower = x_all[rank - 1] + # Only process if lower rank is not empty + if x_lower.shape[0] > 0: + # Project features from rank-1 to rank + x_lower_proj = torch.mm(incidence_lower.T, x_lower) + + # Identity + message_list.append(x_lower_proj.unsqueeze(2)) + + # Apply Laplacian convolutions at the current rank + # Use the appropriate Laplacian (down for rank 0, down for rank > 0) + if rank == 0: + if laplacians["hodge"] is not None: + x_lower_conv = self.chebyshev_conv( + laplacians["hodge"], + self.conv_order, + x_lower_proj, + ) + message_list.append(x_lower_conv) + else: + if laplacians["down"] is not None: + x_lower_conv = self.chebyshev_conv( + laplacians["down"], + self.conv_order, + x_lower_proj, + ) + message_list.append(x_lower_conv) + + # 3. Upper messages (from rank+1) + if ( + rank < self.max_rank + and incidence_upper is not None + and rank + 1 < len(x_all) + ): + x_upper = x_all[rank + 1] + # Only process if upper rank is not empty + if x_upper.shape[0] > 0: + # Project features from rank+1 to rank + x_upper_proj = torch.mm(incidence_upper, x_upper) + + # Identity + message_list.append(x_upper_proj.unsqueeze(2)) + + # Apply Laplacian convolutions at the current rank + # Use up Laplacian for rank > 0 + if laplacians["up"] is not None: + x_upper_conv = self.chebyshev_conv( + laplacians["up"], self.conv_order, x_upper_proj + ) + message_list.append(x_upper_conv) + + # Concatenate all messages + messages = torch.cat(message_list, dim=2) + + return messages From 4036d397b537a702d5743b0a883011672a661573 Mon Sep 17 00:00:00 2001 From: I745505 Date: Wed, 19 Nov 2025 21:32:23 +0100 Subject: [PATCH 10/32] :memo: Include edge_task in name to prevent cache conflicts --- topobench/data/datasets/ppi_highppi_dataset.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/topobench/data/datasets/ppi_highppi_dataset.py 
b/topobench/data/datasets/ppi_highppi_dataset.py index e8558941f..8d1e91631 100644 --- a/topobench/data/datasets/ppi_highppi_dataset.py +++ b/topobench/data/datasets/ppi_highppi_dataset.py @@ -167,17 +167,18 @@ def raw_file_names(self) -> list[str]: def processed_file_names(self) -> list[str]: """Return the name of the processed file. - Filename includes target_ranks to avoid cache conflicts when - different ranks are requested. + Filename includes target_ranks and edge_task to avoid cache conflicts + when different ranks or tasks are requested. Returns ------- List[str] List containing the name of the processed file. """ - # Include target_ranks in filename to prevent cache conflicts + # Include target_ranks and edge_task in filename to prevent cache conflicts ranks_str = "_".join(map(str, self.target_ranks)) - return [f"data_ranks_{ranks_str}.pt"] + task_str = self.edge_task if self.edge_task else "none" + return [f"data_ranks_{ranks_str}_task_{task_str}.pt"] def download(self) -> None: """Download HIGH-PPI and CORUM data files.""" From 586e8e0f8dacf481a98c494344cf6170cb79a44b Mon Sep 17 00:00:00 2001 From: I745505 Date: Wed, 19 Nov 2025 21:38:30 +0100 Subject: [PATCH 11/32] :bug: Message passing --- topobench/nn/backbones/simplicial/sccnn.py | 107 ++++++++++----------- 1 file changed, 52 insertions(+), 55 deletions(-) diff --git a/topobench/nn/backbones/simplicial/sccnn.py b/topobench/nn/backbones/simplicial/sccnn.py index cff5830dd..3033ea944 100644 --- a/topobench/nn/backbones/simplicial/sccnn.py +++ b/topobench/nn/backbones/simplicial/sccnn.py @@ -181,7 +181,7 @@ def _compute_message_types(self, rank): # Self messages if rank == 0: - # Rank 0: identity + Hodge Laplacian convolutions + # Rank 0: identity + up Laplacian convolutions count += 1 + self.conv_order else: # Rank k>0: identity + down Laplacian + up Laplacian convolutions @@ -189,13 +189,17 @@ def _compute_message_types(self, rank): # Lower messages (from rank-1 projected to rank) if rank > 0: - # 
Identity + convolutions with down/up Laplacians at current rank - count += 1 + self.conv_order + # Identity + convolutions with down and up Laplacians at current rank + count += 1 + self.conv_order + self.conv_order # Upper messages (from rank+1 projected to rank) if rank < self.max_rank: - # Identity + convolutions with down/up Laplacians at current rank - count += 1 + self.conv_order + # Identity + convolutions with down and up Laplacians at current rank + # Special case: rank 0 only has up Laplacian + if rank == 0: + count += 1 + self.conv_order + else: + count += 1 + self.conv_order + self.conv_order return count @@ -433,65 +437,50 @@ def _compute_messages_for_rank( """ message_list = [] - # 1. Self messages (identity + convolutions) + # 1. Lower messages (from rank-1) + if rank > 0 and incidence_lower is not None and rank - 1 < len(x_all): + x_lower = x_all[rank - 1] + # Only process if lower rank is not empty + if x_lower.shape[0] > 0: + # Project features from rank-1 to rank + x_lower_proj = torch.mm(incidence_lower.T, x_lower) + + message_list.append(x_lower_proj.unsqueeze(2)) + + # Apply down and up Laplacians + if laplacians["down"] is not None: + x_lower_down = self.chebyshev_conv( + laplacians["down"], self.conv_order, x_lower_proj + ) + message_list.append(x_lower_down) + + if laplacians["up"] is not None: + x_lower_up = self.chebyshev_conv( + laplacians["up"], self.conv_order, x_lower_proj + ) + message_list.append(x_lower_up) + + # 2. 
Self messages (identity + convolutions) if rank == 0: - # Identity message message_list.append(x_rank.unsqueeze(2)) - - # Hodge Laplacian convolutions if laplacians["hodge"] is not None: x_conv = self.chebyshev_conv( laplacians["hodge"], self.conv_order, x_rank ) message_list.append(x_conv) else: - # Identity message message_list.append(x_rank.unsqueeze(2)) - - # Down Laplacian convolutions if laplacians["down"] is not None: x_down = self.chebyshev_conv( laplacians["down"], self.conv_order, x_rank ) message_list.append(x_down) - - # Up Laplacian convolutions if laplacians["up"] is not None: x_up = self.chebyshev_conv( laplacians["up"], self.conv_order, x_rank ) message_list.append(x_up) - # 2. Lower messages (from rank-1) - if rank > 0 and incidence_lower is not None and rank - 1 < len(x_all): - x_lower = x_all[rank - 1] - # Only process if lower rank is not empty - if x_lower.shape[0] > 0: - # Project features from rank-1 to rank - x_lower_proj = torch.mm(incidence_lower.T, x_lower) - - # Identity - message_list.append(x_lower_proj.unsqueeze(2)) - - # Apply Laplacian convolutions at the current rank - # Use the appropriate Laplacian (down for rank 0, down for rank > 0) - if rank == 0: - if laplacians["hodge"] is not None: - x_lower_conv = self.chebyshev_conv( - laplacians["hodge"], - self.conv_order, - x_lower_proj, - ) - message_list.append(x_lower_conv) - else: - if laplacians["down"] is not None: - x_lower_conv = self.chebyshev_conv( - laplacians["down"], - self.conv_order, - x_lower_proj, - ) - message_list.append(x_lower_conv) - # 3. 
Upper messages (from rank+1) if ( rank < self.max_rank @@ -499,21 +488,29 @@ def _compute_messages_for_rank( and rank + 1 < len(x_all) ): x_upper = x_all[rank + 1] - # Only process if upper rank is not empty if x_upper.shape[0] > 0: - # Project features from rank+1 to rank x_upper_proj = torch.mm(incidence_upper, x_upper) - - # Identity message_list.append(x_upper_proj.unsqueeze(2)) - # Apply Laplacian convolutions at the current rank - # Use up Laplacian for rank > 0 - if laplacians["up"] is not None: - x_upper_conv = self.chebyshev_conv( - laplacians["up"], self.conv_order, x_upper_proj - ) - message_list.append(x_upper_conv) + # Apply Laplacians (Hodge for rank 0, both down/up for rank > 0) + if rank == 0: + if laplacians["hodge"] is not None: + x_upper_hodge = self.chebyshev_conv( + laplacians["hodge"], self.conv_order, x_upper_proj + ) + message_list.append(x_upper_hodge) + else: + if laplacians["down"] is not None: + x_upper_down = self.chebyshev_conv( + laplacians["down"], self.conv_order, x_upper_proj + ) + message_list.append(x_upper_down) + + if laplacians["up"] is not None: + x_upper_up = self.chebyshev_conv( + laplacians["up"], self.conv_order, x_upper_proj + ) + message_list.append(x_upper_up) # Concatenate all messages messages = torch.cat(message_list, dim=2) From 95a9e06beb046d771d422605ea335d8a69c3627e Mon Sep 17 00:00:00 2001 From: I745505 Date: Wed, 19 Nov 2025 22:51:02 +0100 Subject: [PATCH 12/32] :sparkles: SCCNN cell wrapper using higher cell features and allowing for higher level cell-prediction --- .../wrappers/simplicial/sccnn_cell_wrapper.py | 167 ++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 topobench/nn/wrappers/simplicial/sccnn_cell_wrapper.py diff --git a/topobench/nn/wrappers/simplicial/sccnn_cell_wrapper.py b/topobench/nn/wrappers/simplicial/sccnn_cell_wrapper.py new file mode 100644 index 000000000..552180045 --- /dev/null +++ b/topobench/nn/wrappers/simplicial/sccnn_cell_wrapper.py @@ -0,0 +1,167 @@ 
+"""Wrapper for SCCNN with cell-level predictions. + +This wrapper is designed for transductive learning where predictions are made +on specific cells (simplices) rather than entire graphs or individual nodes. +""" + +import torch +from torch_geometric.data import Data + +from topobench.nn.wrappers.base import AbstractWrapper + + +class SCCNNCellWrapper(AbstractWrapper): + """Wrapper for SCCNN backbone with cell-level outputs. + + Unlike standard wrappers that focus on node features (x_0), this wrapper + preserves features at ALL ranks for cell-level prediction. + + Parameters + ---------- + backbone : nn.Module + The SCCNN backbone model. + num_cell_dimensions : int + Rank +1 of the simplicial complex. + target_ranks : list[int] + Which ranks have labels to predict. + **kwargs : dict + Additional arguments. + """ + + def __init__( + self, + backbone: torch.nn.Module, + num_cell_dimensions: int, + target_ranks: list, + **kwargs, + ): + # Ensure required parameters for base class + if "out_channels" not in kwargs: + kwargs["out_channels"] = 32 # Default value + kwargs["num_cell_dimensions"] = num_cell_dimensions + # Disable residual connections for cell-level prediction + # (we just pass through features) + kwargs["residual_connections"] = kwargs.get( + "residual_connections", False + ) + + super().__init__(backbone, **kwargs) + self.target_ranks = target_ranks + self.num_cell_dimensions = num_cell_dimensions + + def __repr__(self): + return f"{self.__class__.__name__}(target_ranks={self.target_ranks})" + + def forward(self, batch: Data) -> dict: + """Forward pass preserving all rank features. + + Parameters + ---------- + batch : Data + Batch object containing features x_0, x_1, ..., x_k, Laplacians, and incidences. + + Returns + ------- + dict + The model_out containing updated features x_0, x_1, ..., x_k. 
+ """ + # Extract features for all ranks from 0 to num_cell_dimensions-1 = rank + x_all = [] + for i in range(self.num_cell_dimensions): + x_key = f"x_{i}" + if hasattr(batch, x_key): + x_all.append(getattr(batch, x_key)) + else: + # If rank doesn't exist, add empty tensor + x_all.append(torch.zeros(0, 1, device=batch.x_0.device)) + x_all = tuple(x_all) + + # Extract Laplacians + laplacian_all = self._extract_laplacians(batch) + + # Extract incidences + incidence_all = self._extract_incidences(batch) + + # Forward through SCCNN backbone + x_all_out = self.backbone(x_all, laplacian_all, incidence_all) + + # Build output dictionary with features at ALL ranks + model_out = {} + for i, x_rank in enumerate(x_all_out): + model_out[f"x_{i}"] = x_rank + + return model_out + + def _extract_laplacians(self, batch: Data) -> tuple: + """Extract Laplacian matrices for all ranks. + + Expected format: + - hodge_laplacian_0 + - down_laplacian_1, up_laplacian_1 + - down_laplacian_2, up_laplacian_2 + - ... + + Parameters + ---------- + batch : Data + Batch object containing features x_0, x_1, ..., x_k, Laplacians, and incidences. + + Returns + ------- + tuple + Tuple of Laplacian matrices for all ranks. + """ + laplacian_all = [] + + # Rank 0: Hodge Laplacian + if hasattr(batch, "hodge_laplacian_0"): + laplacian_all.append(batch.hodge_laplacian_0) + else: + laplacian_all.append(None) + + # Store down and up Laplacians for each rank + for rank in range(1, self.num_cell_dimensions): + down_key = f"down_laplacian_{rank}" + up_key = f"up_laplacian_{rank}" + + if hasattr(batch, down_key): + laplacian_all.append(getattr(batch, down_key)) + else: + laplacian_all.append(None) + + if hasattr(batch, up_key): + laplacian_all.append(getattr(batch, up_key)) + else: + laplacian_all.append(None) + + return tuple(laplacian_all) + + def _extract_incidences(self, batch: Data) -> tuple: + """Extract incidence matrices. 
+ + Expected format: + - incidence_1: From 0-cells to 1-cells + - incidence_2: From 1-cells to 2-cells + - ... + + Parameters + ---------- + batch : Data + Batch object containing features x_0, x_1, ..., x_k, Laplacians, and incidences. + + Returns + ------- + tuple + Tuple of incidence matrices for all ranks. + """ + incidence_all = [] + + # Incidences map from rank k-1 to rank k, so we go from 1 to num_cell_dimensions + for rank in range(1, self.num_cell_dimensions + 1): + inc_key = f"incidence_{rank}" + if hasattr(batch, inc_key): + incidence_all.append(getattr(batch, inc_key)) + else: + incidence_all.append(None) + + return tuple(incidence_all) From b4a6c75c7a5caa417fc3fcca83ceaa7ee13c26da Mon Sep 17 00:00:00 2001 From: I745505 Date: Wed, 19 Nov 2025 23:25:50 +0100 Subject: [PATCH 13/32] :sparkles: Cell Readout layer for cell level prediction --- topobench/model/model.py | 10 +- topobench/nn/readouts/cell_readout.py | 179 ++++++++++++++++++++++++++ 2 files changed, 184 insertions(+), 5 deletions(-) create mode 100644 topobench/nn/readouts/cell_readout.py diff --git a/topobench/model/model.py b/topobench/model/model.py index b5f3313c0..695a3f79a 100755 --- a/topobench/model/model.py +++ b/topobench/model/model.py @@ -245,19 +245,19 @@ def process_outputs(self, model_out: dict, batch: Data) -> dict: if self.task_level in ["node", "cell"]: # Keep only train data points (for node-level or cell-level tasks) # Note: Rank-specific masks are applied in readout - # The readout stores which indices it kept in cell_indices - if "cell_indices" in model_out: + # The readout stores which indices it kept in valid_indices + if "valid_indices" in model_out: # Find intersection: which readout outputs are in this split? 
# Note: The split respects the masks applied in the readout - cell_indices = model_out["cell_indices"] - keep_mask = torch.isin(cell_indices, mask) + valid_indices = model_out["valid_indices"] + keep_mask = torch.isin(valid_indices, mask) # Filter logits and labels for key, val in model_out.items(): if key in ["logits", "labels"]: model_out[key] = val[keep_mask] else: - # No cell_indices: standard filtering (for non-masked tasks) + # No valid_indices: standard filtering (for non-masked tasks) for key, val in model_out.items(): if key in ["logits", "labels"]: model_out[key] = val[mask] diff --git a/topobench/nn/readouts/cell_readout.py b/topobench/nn/readouts/cell_readout.py new file mode 100644 index 000000000..fc6feaa13 --- /dev/null +++ b/topobench/nn/readouts/cell_readout.py @@ -0,0 +1,179 @@ +"""Cell-level readout for simplicial complexes. + +This readout layer predicts labels for valid cells of rank in target_ranks. +""" + +import torch +import torch.nn as nn +from torch_geometric.data import Data + + +class SimplicialCellLevelReadout(nn.Module): + """Readout for cell-level predictions on simplicial complexes. + + Takes features at each rank and predicts labels for valid cells + at specified target ranks. + + Parameters + ---------- + hidden_dim : int + Hidden dimension of input features on all ranks. + out_channels : int + Number of output classes. + num_cell_dimensions : int + Rank + 1 of simplicial complex. + target_ranks : List[int] + Which ranks have labels to predict (e.g., [2, 3, 4] for simplices with 3-5 nodes). 
+ """ + + def __init__( + self, + hidden_dim: int, + out_channels: int, + num_cell_dimensions: int, + target_ranks: list[int], + ): + super().__init__() + self.hidden_dim = hidden_dim + self.out_channels = out_channels + self.num_cell_dimensions = num_cell_dimensions + self.target_ranks = target_ranks + self.task_level = ( + "cell" # For compatibility with TBModel need this attribute + ) + + # Create prediction head for each target rank + # Each rank might have different hidden dims in the future, so use a dict + self.predictors = nn.ModuleDict( + { + str(rank): nn.Linear(hidden_dim, out_channels) + for rank in target_ranks + } + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"target_ranks={self.target_ranks}, " + f"out_channels={self.out_channels})" + ) + + def forward(self, model_out: dict, batch: Data) -> dict: + """Compute cell-level predictions. + + Parameters + ---------- + model_out : dict + Dictionary containing x_0, x_1, ..., x_k features per rank. + batch : Data + Batch object containing cell_labels for each target rank and mask for valid cells. 
+ + Returns + ------- + dict + Updated model_out with: + - logits: [num_labeled_cells, out_channels] + - cell_ranks: [num_labeled_cells] - which rank each prediction is for + - valid_indices: [num_labeled_cells] - which cell indices are valid + """ + all_logits = [] + all_labels = [] + all_ranks = [] + all_indices = [] + + for rank in self.target_ranks: + # Get features for this rank + x_key = f"x_{rank}" + if x_key not in model_out: + continue + + x_rank = model_out[x_key] # [num_cells_at_rank, hidden_dim] + + # Get labels for this rank + label_key = f"cell_labels_{rank}" + if not hasattr(batch, label_key): + continue + + labels = getattr(batch, label_key) # [num_cells_at_rank] + + # Filter valid cells with rank-specific mask + valid_mask = torch.ones( + len(labels), dtype=torch.bool, device=labels.device + ) + + mask_key = f"mask_{rank}" + if hasattr(batch, mask_key): + rank_mask = getattr(batch, mask_key) # Boolean mask + valid_mask &= rank_mask + + # Get final valid indices + valid_indices = torch.where(valid_mask)[0] + + if len(valid_indices) == 0: + continue + + # Get features and labels for valid cells + x_labeled = x_rank[valid_indices] # [num_labeled, hidden_dim] + y_labeled = labels[valid_indices] # [num_labeled] + + # Predict + logits = self.predictors[str(rank)]( + x_labeled + ) # [num_labeled, out_channels] + + all_logits.append(logits) + all_labels.append(y_labeled) + all_ranks.extend([rank] * len(valid_indices)) + all_indices.extend(valid_indices.tolist()) + + # Concatenate all predictions and labels + if len(all_logits) > 0: + model_out["logits"] = torch.cat(all_logits, dim=0) + model_out["labels"] = torch.cat(all_labels, dim=0) + model_out["cell_ranks"] = torch.tensor( + all_ranks, device=all_logits[0].device + ) + model_out["valid_indices"] = torch.tensor( + all_indices, device=all_logits[0].device + ) + else: + # No labeled cells found - use any available tensor for device + device = None + for key in model_out: + if isinstance(model_out[key], 
torch.Tensor): + device = model_out[key].device + break + if device is None: + device = torch.device("cpu") + + model_out["logits"] = torch.zeros( + 0, self.out_channels, device=device + ) + model_out["cell_ranks"] = torch.zeros( + 0, dtype=torch.long, device=device + ) + model_out["valid_indices"] = torch.zeros( + 0, dtype=torch.long, device=device + ) + + return model_out + + def __call__(self, model_out: dict, batch: Data) -> dict: + """Wrapper for forward to match AbstractZeroCellReadOut interface. + + Parameters + ---------- + model_out : dict + Dictionary containing features per rank. + batch : Data + Batch object containing cell labels and valid cell masks. + + Returns + ------- + dict + Updated model_out with: + - logits: [num_labeled_cells, out_channels] + - cell_ranks: [num_labeled_cells] - which rank each prediction is for + - valid_indices: [num_labeled_cells] - which cell index within rank + """ + return self.forward(model_out, batch) From ba34cd1ba602c78eb98bd052127dcbfd77dd30ee Mon Sep 17 00:00:00 2001 From: I745505 Date: Thu, 20 Nov 2025 10:33:44 +0100 Subject: [PATCH 14/32] :technologist: Modify All Cell Encoder for transductive learning --- topobench/nn/encoders/all_cell_encoder.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/topobench/nn/encoders/all_cell_encoder.py b/topobench/nn/encoders/all_cell_encoder.py index 4615b95c0..456535306 100644 --- a/topobench/nn/encoders/all_cell_encoder.py +++ b/topobench/nn/encoders/all_cell_encoder.py @@ -83,8 +83,20 @@ def forward( data.x_0 = data.x for i in self.dimensions: + # Get batch assignment if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"): - batch = getattr(data, f"batch_{i}") + # Inductive case: batch_{i} maps each cell to its graph + if hasattr(data, f"batch_{i}"): + batch = getattr(data, f"batch_{i}") + else: + # Transductive case: all cells in same graph (batch index = 0) + batch = torch.zeros( + data[f"x_{i}"].shape[0], + dtype=torch.long, + 
device=data[f"x_{i}"].device, + ) + + # Apply encoder data[f"x_{i}"] = getattr(self, f"encoder_{i}")( data[f"x_{i}"], batch ) From 241096642ff43d2a6d91caeff5460f1018382dd4 Mon Sep 17 00:00:00 2001 From: I745505 Date: Thu, 20 Nov 2025 10:44:30 +0100 Subject: [PATCH 15/32] :lipstick: More direct dataset passing --- topobench/data/preprocessor/preprocessor.py | 4 ++-- topobench/data/utils/split_utils.py | 6 +----- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/topobench/data/preprocessor/preprocessor.py b/topobench/data/preprocessor/preprocessor.py index e5c4a913a..a50129c5a 100644 --- a/topobench/data/preprocessor/preprocessor.py +++ b/topobench/data/preprocessor/preprocessor.py @@ -244,9 +244,9 @@ def load_dataset_splits( raise ValueError("No learning setting specified in split_params") if split_params.learning_setting == "inductive": - return load_inductive_splits(self, split_params) + return load_inductive_splits(self.dataset, split_params) elif split_params.learning_setting == "transductive": - return load_transductive_splits(self, split_params) + return load_transductive_splits(self.dataset, split_params) else: raise ValueError( f"Invalid '{split_params.learning_setting}' learning setting.\ diff --git a/topobench/data/utils/split_utils.py b/topobench/data/utils/split_utils.py index e1722c94a..71e5c6932 100644 --- a/topobench/data/utils/split_utils.py +++ b/topobench/data/utils/split_utils.py @@ -532,11 +532,7 @@ def load_inductive_splits(dataset, parameters): else np.array(label_list) ) - root = ( - dataset.dataset.get_data_dir() - if hasattr(dataset.dataset, "get_data_dir") - else None - ) + root = dataset.get_data_dir() if hasattr(dataset, "get_data_dir") else None if parameters.split_type == "random": split_idx = random_splitting(labels, parameters, root=root) From 3e0b32416421ac21e6681818e5b8735ea9f5c1ec Mon Sep 17 00:00:00 2001 From: grapentt Date: Tue, 18 Nov 2025 15:12:30 +0100 Subject: [PATCH 16/32] :card_file_box: IO and Split utils --- 
topobench/data/utils/io_utils.py | 256 ++++++++++++++++++++++++++++ topobench/data/utils/split_utils.py | 123 ++++++++++++- 2 files changed, 373 insertions(+), 6 deletions(-) diff --git a/topobench/data/utils/io_utils.py b/topobench/data/utils/io_utils.py index 372db85e6..9e2ae8344 100644 --- a/topobench/data/utils/io_utils.py +++ b/topobench/data/utils/io_utils.py @@ -1,15 +1,21 @@ """Data IO utilities.""" import json +import os import os.path as osp import pickle +import tempfile +import warnings +import zipfile from urllib.parse import parse_qs, urlparse +import gdown import numpy as np import pandas as pd import requests import torch import torch_geometric +from pybiomart import Dataset as BioMartDataset from toponetx.classes import SimplicialComplex from torch_geometric.data import Data from torch_sparse import coalesce @@ -50,6 +56,218 @@ def get_file_id_from_url(url): return file_id +def get_folder_id_from_url(url): + """Extract the folder ID from a Google Drive folder URL or return ID if already provided. + + Parameters + ---------- + url : str + The Google Drive folder URL or folder ID. + + Returns + ------- + str + The folder ID extracted from the URL, or the ID itself if already an ID. + + Raises + ------ + ValueError + If the provided string is not a valid Google Drive folder URL or ID. + """ + # If it doesn't look like a URL (no scheme), assume it's already an ID + if "://" not in url and "/" not in url: + return url + + parsed_url = urlparse(url) + query_params = parse_qs(parsed_url.query) + + if "id" in query_params: # Case 1: URL format contains '?id=' + folder_id = query_params["id"][0] + elif ( + "folders/" in parsed_url.path + ): # Case 2: URL format contains '/folders/' + folder_id = parsed_url.path.split("/folders/")[1].split("/")[0] + else: + raise ValueError( + "The provided string is not a valid Google Drive folder URL or ID." 
+ ) + return folder_id + + +def download_file(url, output_path, timeout=30, verify_ssl=True): + """Download a file from URL and save it. + + Parameters + ---------- + url : str + URL of the file to download. + output_path : str + Path where the file will be saved. + timeout : int, optional + Request timeout in seconds. Defaults to 30. + verify_ssl : bool, optional + Whether to verify SSL certificates. Defaults to True. + + Returns + ------- + bool + True if download succeeded, False otherwise. + + Raises + ------ + RuntimeError + If download fails. + """ + if not verify_ssl: + warnings.warn( + "SSL certificate verification is disabled", + UserWarning, + stacklevel=2, + ) + + try: + response = requests.get(url, timeout=timeout, verify=verify_ssl) + response.raise_for_status() + + with open(output_path, "wb") as f: + f.write(response.content) + + return True + + except Exception as e: + raise RuntimeError(f"Failed to download from {url}: {e}") from e + + +def download_and_extract_zip( + url, output_dir, filename_to_extract=None, timeout=30, verify_ssl=True +): + """Download a zip file from URL and extract it. + + Parameters + ---------- + url : str + URL of the zip file to download. + output_dir : str + Directory where files will be extracted. + filename_to_extract : str, optional + If provided, only extract this specific file from the zip. + If None, extract all files. + timeout : int, optional + Request timeout in seconds. Defaults to 30. + verify_ssl : bool, optional + Whether to verify SSL certificates. Defaults to True. + + Returns + ------- + bool + True if download and extraction succeeded, False otherwise. + + Raises + ------ + RuntimeError + If download or extraction fails. 
+ """ + if not verify_ssl: + warnings.warn( + "SSL certificate verification is disabled", + UserWarning, + stacklevel=2, + ) + + try: + response = requests.get(url, timeout=timeout, verify=verify_ssl) + response.raise_for_status() + + # Save to temporary file + with tempfile.NamedTemporaryFile( + delete=False, suffix=".zip" + ) as tmp_file: + tmp_file.write(response.content) + zip_path = tmp_file.name + + # Extract + with zipfile.ZipFile(zip_path, "r") as zip_ref: + if filename_to_extract: + zip_ref.extract(filename_to_extract, output_dir) + else: + zip_ref.extractall(output_dir) + + # Clean up temp file + os.remove(zip_path) + return True + + except Exception as e: + raise RuntimeError( + f"Failed to download and extract from {url}: {e}" + ) from e + + +def download_ensembl_biomart_mapping( + output_path, + dataset="hsapiens_gene_ensembl", + attributes=None, + id_prefix="9606.", + timeout=120, +): + """Download ID mappings from Ensembl BioMart using pybiomart library. + + Note: Requires 'pybiomart' package. Install with: pip install pybiomart + + Parameters + ---------- + output_path : str + Path where the mapping file will be saved. + dataset : str, optional + BioMart dataset name. Defaults to "hsapiens_gene_ensembl". + attributes : list of str, optional + Attributes to retrieve. Defaults to ["ensembl_peptide_id", "uniprotswissprot"]. + id_prefix : str, optional + Prefix to add to IDs (e.g., "9606." for taxon). Defaults to "9606.". + timeout : int, optional + Request timeout in seconds. Defaults to 120. + + Returns + ------- + bool + True if download succeeded, False otherwise. + + Raises + ------ + RuntimeError + If download fails. 
+ """ + + if attributes is None: + attributes = ["ensembl_peptide_id", "uniprotswissprot"] + + try: + # Query BioMart using the library + biomart_dataset = BioMartDataset( + name=dataset, host="http://www.ensembl.org" + ) + result_df = biomart_dataset.query(attributes=attributes) + + # Save to file with optional prefix + with open(output_path, "w") as f: + for _, row in result_df.iterrows(): + # Skip rows with missing values + if row.isnull().any(): + continue + + values = row.tolist() + # Add prefix to first column (ID) if specified + if id_prefix: + values[0] = f"{id_prefix}{values[0]}" + f.write("\t".join(str(v) for v in values) + "\n") + + return True + + except Exception as e: + raise RuntimeError( + f"Failed to download from Ensembl BioMart: {e}" + ) from e + + def download_file_from_drive( file_link, path_to_save, dataset_name, file_format="tar.gz" ): @@ -84,6 +302,44 @@ def download_file_from_drive( print("Failed to download the file.") +def download_folder_from_drive(folder_link, output_dir, quiet=False): + """Download an entire folder from Google Drive using gdown. + + Parameters + ---------- + folder_link : str + The Google Drive folder URL or folder ID. + output_dir : str + The directory where the folder contents will be saved. + quiet : bool, optional + If True, suppress download progress messages. Defaults to False. + + Returns + ------- + bool + True if download succeeded, False otherwise. + + Raises + ------ + ValueError + If the provided link is not a valid Google Drive folder URL. 
+ """ + # Extract folder ID from URL if needed + folder_id = get_folder_id_from_url(folder_link) + + try: + gdown.download_folder( + id=folder_id, + output=output_dir, + quiet=quiet, + use_cookies=False, + ) + return True + except Exception as e: + print(f"Failed to download folder from Google Drive: {e}") + return False + + def download_file_from_link( file_link, path_to_save, dataset_name, file_format="tar.gz" ): diff --git a/topobench/data/utils/split_utils.py b/topobench/data/utils/split_utils.py index f78994222..cf63b50a4 100644 --- a/topobench/data/utils/split_utils.py +++ b/topobench/data/utils/split_utils.py @@ -248,16 +248,19 @@ def load_transductive_splits(dataset, parameters): ) data = dataset.data_list[0] + + # Check if this is multi-rank cell prediction + target_ranks = getattr(dataset, "target_ranks", None) + if target_ranks is not None and len(target_ranks) > 1: + return load_multirank_transductive_splits(dataset, parameters) + + # Single rank or node/graph prediction labels = data.y.numpy() # Ensure labels are one dimensional array assert len(labels.shape) == 1, "Labels should be one dimensional array" - root = ( - dataset.dataset.get_data_dir() - if hasattr(dataset.dataset, "get_data_dir") - else None - ) + root = dataset.get_data_dir() if hasattr(dataset, "get_data_dir") else None if parameters.split_type == "random": splits = random_splitting(labels, parameters, root=root) @@ -265,9 +268,18 @@ def load_transductive_splits(dataset, parameters): elif parameters.split_type == "k-fold": splits = k_fold_split(labels, parameters, root=root) + elif parameters.split_type == "fixed" and hasattr(dataset, "split_idx"): + splits = dataset.split_idx + if splits is None: + raise ValueError( + "Dataset has split_type='fixed' but split_idx property returned None. " + "Either the dataset doesn't support fixed splits or they failed to load." + ) + else: raise NotImplementedError( - f"split_type {parameters.split_type} not valid. 
Choose either 'random' or 'k-fold'" + f"split_type {parameters.split_type} not valid. Choose 'random', 'k-fold', or 'fixed'.\n" + f"If 'fixed' is chosen, the dataset must have a split_idx property." ) # Assign train val test masks to the graph @@ -287,6 +299,105 @@ def load_transductive_splits(dataset, parameters): return DataloadDataset([data]), None, None +def load_multirank_transductive_splits(dataset, parameters): + r"""Load dataset with multi-rank cell-level splits. + + For datasets with cell-level predictions across multiple ranks (e.g., edges, + triangles, tetrahedra simultaneously), this function creates independent + train/val/test splits for each rank. + + Parameters + ---------- + dataset : torch_geometric.data.Dataset + Dataset with multi-rank cell labels. + parameters : DictConfig + Configuration parameters containing split_type and train_prop. + + Returns + ------- + list: + List containing the train dataset (validation and test are None for transductive). + + Notes + ----- + Expects dataset to have: + - target_ranks: list of ranks to split + - data.cell_labels_{rank}: labels for each rank + + Creates per-rank masks: + - data.train_mask_{rank}: training indices for rank + - data.val_mask_{rank}: validation indices for rank + - data.test_mask_{rank}: test indices for rank + """ + assert len(dataset) == 1, ( + "Dataset should have only one graph/complex in a transductive setting." + ) + + data = dataset.data_list[0] + target_ranks = dataset.target_ranks + + root = dataset.get_data_dir() if hasattr(dataset, "get_data_dir") else None + + # Split each rank independently + for rank in target_ranks: + label_attr = f"cell_labels_{rank}" + + if not hasattr(data, label_attr): + raise ValueError( + f"Data object missing {label_attr} for rank {rank}. 
" + f"Available attributes: {list(data.keys())}" + ) + + labels = getattr(data, label_attr).numpy() + + # Handle multi-dimensional labels (e.g., multi-label classification) + if len(labels.shape) > 1: + # Use first column for stratification (common practice) + stratify_labels = ( + labels[:, 0] if labels.shape[1] > 0 else labels.flatten() + ) + else: + stratify_labels = labels + + # Create rank-specific root directory for splits + # This ensures each rank gets independent splits + rank_root = os.path.join(root, f"rank_{rank}") if root else None + + # Perform splitting + if parameters.split_type == "random": + splits = random_splitting( + stratify_labels, parameters, root=rank_root + ) + elif parameters.split_type == "k-fold": + splits = k_fold_split(stratify_labels, parameters, root=root) + elif parameters.split_type == "fixed" and hasattr( + dataset, "split_idx" + ): + splits = dataset.split_idx + if splits is None: + raise ValueError( + "Dataset has split_type='fixed' but split_idx property returned None. " + "Either the dataset doesn't support fixed splits or they failed to load." + ) + else: + raise NotImplementedError( + f"split_type {parameters.split_type} not valid. " + f"Choose 'random', 'k-fold', or 'fixed'.\n" + f"If 'fixed' is chosen, the dataset must have a split_idx property." + ) + + # Store per-rank masks + train_mask = torch.from_numpy(splits["train"]) + val_mask = torch.from_numpy(splits["valid"]) + test_mask = torch.from_numpy(splits["test"]) + + setattr(data, f"train_mask_{rank}", train_mask) + setattr(data, f"val_mask_{rank}", val_mask) + setattr(data, f"test_mask_{rank}", test_mask) + + return DataloadDataset([data]), None, None + + def load_inductive_splits(dataset, parameters): r"""Load multiple-graph datasets with the specified split. 
From 8f0ed342c84fda2e5334f592bf1c15d019b73af6 Mon Sep 17 00:00:00 2001 From: grapentt Date: Tue, 18 Nov 2025 15:19:22 +0100 Subject: [PATCH 17/32] :sparkles: PPI Dataset: HIGH-PPI + CORUM --- configs/dataset/simplicial/ppi_highppi.yaml | 100 ++++ .../data/datasets/ppi_highppi_dataset.py | 415 +++++++++++++ .../loaders/simplicial/ppi_highppi_loader.py | 71 +++ .../utils/datasets/simplicial/ppi_utils.py | 558 ++++++++++++++++++ 4 files changed, 1144 insertions(+) create mode 100644 configs/dataset/simplicial/ppi_highppi.yaml create mode 100644 topobench/data/datasets/ppi_highppi_dataset.py create mode 100644 topobench/data/loaders/simplicial/ppi_highppi_loader.py create mode 100644 topobench/data/utils/datasets/simplicial/ppi_utils.py diff --git a/configs/dataset/simplicial/ppi_highppi.yaml b/configs/dataset/simplicial/ppi_highppi.yaml new file mode 100644 index 000000000..a0f36793e --- /dev/null +++ b/configs/dataset/simplicial/ppi_highppi.yaml @@ -0,0 +1,100 @@ +################################################################################ +# HIGH-PPI + CORUM: Protein Interaction Prediction via Simplicial Complexes +################################################################################ +# +# Data Structure: +# - Proteins (rank 0): ~1,553 proteins +# - Edges (rank 1): ~6,660 HIGH-PPI edges with: +# * Features: 8-dim (7 interaction types + 1 STRING confidence score) +# - Interaction types: reaction, binding, ptmod, activation, inhibition, catalysis, expression +# - Confidence score [0, 1] measuring interaction probability (mapped to [-1, 1]) +# - Higher-order cells: CORUM protein complexes (2+ proteins) +# * Features: 1-dim (binary existence: 1=real, -1=fake) +# +# Note: Features at any rank can also serve as prediction targets (labels). +# Models should mask features of the rank being predicted to avoid data leakage. 
+# +# Prediction Tasks: +# - Edge (rank 1): Regression (confidence scores) or multi-label (interaction types) +# - Cell (ranks 2+): Binary classification (complex existence) +# +################################################################################ + +# Data loading configuration +loader: + _target_: topobench.data.loaders.PPIHighPPIDatasetLoader + parameters: + data_domain: simplicial + model_domain: simplicial + data_name: ppi_highppi + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain} + + # CORUM Complex Configuration + min_complex_size: 2 # Minimum proteins per CORUM complex (2+ allowed) + # Edge features for edges in CORUM: + # - In HIGH-PPI: Interaction types + confidence boosted to 1.0 + # - Not in HIGH-PPI: [0,0,0,0,0,0,0, 1.0] (unknown types, high confidence) + max_complex_size: 6 # Maximum proteins per CORUM complex + + # Negative Sampling (for classification tasks) + neg_ratio: 1.0 # Ratio of negative to positive samples (1.0 = balanced) + + # Multi-Rank Prediction + target_ranks: [2, 3, 4, 5] # Which ranks to predict (train/test on) + # Max target rank must be <= max_complex_size - 1 + + # Edge Task Type (only applied when rank 1 in target_ranks) + edge_task: score # "score": Regression - predict confidence of interaction [0-1] + # "interaction_type": Multi-label - predict 7 interaction types + +# Model training configuration +parameters: + # Feature dimensions: [rank-0, rank-1, ..., rank-max] + # rank-0: One-hot encoded proteins TODO: Replace with richer embedding + # rank-1: 8-dim edge features (7 interaction types + 1 confidence score) + # rank-2+: 1-dim features (binary existence) + num_features: [1553, 8, 1, 1, 1, 1] + + num_classes: 2 # Depends on task: + # - Higher-order (ranks 2+): 2 (exists/doesn't exist) + # - Edge regression (rank 1, score): 1 (continuous output) + # - Edge multi-label (rank 1, interaction_type): 7 (7 types) + task: classification # Depends on target_ranks and edge_task: + # - Higher-order (ranks 
2+): classification + # - Edge regression (rank 1, score): regression + # - Edge multi-label (rank 1, interaction_type): classification + task_level: cell # Predict on cells (edges/triangles/etc), not nodes or graphs + + # Multi-Rank Prediction + target_ranks: ${dataset.loader.parameters.target_ranks} + + loss_type: cross_entropy # Depends on task: + # - Higher-order binary: cross_entropy + # - Edge regression: mse or mae + # - Edge multi-label: bce_with_logits + monitor_metric: auroc # Depends on task: + # - Higher-order binary: auroc, f1, accuracy + # - Edge regression: mae, rmse + # - Edge multi-label: f1, auroc + +# Splits Configuration +split_params: + learning_setting: transductive # Single complex, split labeled cells + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 42 # Random seed for reproducible splits + + # Split Type Options: + # - "random": Random splitting with train_prop ratio + # - "k-fold": K-fold cross-validation + # - "fixed": Use HIGH-PPI's official train/val split (if available in raw data) + split_type: random + + train_prop: 0.8 # For random/k-fold: 80% train, 10% val, 10% test + # Ignored when split_type: fixed + +# Dataloader +dataloader_params: + batch_size: 1 + num_workers: 0 + pin_memory: False + persistent_workers: False diff --git a/topobench/data/datasets/ppi_highppi_dataset.py b/topobench/data/datasets/ppi_highppi_dataset.py new file mode 100644 index 000000000..bc439b340 --- /dev/null +++ b/topobench/data/datasets/ppi_highppi_dataset.py @@ -0,0 +1,415 @@ +"""PPI dataset integrating HIGH-PPI network data with CORUM human protein complexes. 
+ +Combines: +- HIGH-PPI SHS27k: PPI network with 7 interaction type features + confidence scores +- CORUM: ~470 experimentally validated human protein complexes as native higher-order structures +- TODO: Add data for node features (embeddings) + +Simplicial complex structure: +- 0-cells: 1,553 proteins +- 1-cells: 6,660 PPI edges + CORUM complexes of size 2 +- 2+ cells: CORUM complexes of size 3+ +""" + +import json +import os +import os.path as osp +from typing import ClassVar + +import numpy as np +import torch +from omegaconf import DictConfig +from torch_geometric.data import Data, InMemoryDataset +from torch_geometric.io import fs + +from topobench.data.utils import get_complex_connectivity +from topobench.data.utils.datasets.simplicial.ppi_utils import ( + build_data_features_and_labels, + build_simplicial_complex_with_features, + generate_negative_samples, + load_corum_complexes, + load_highppi_network, + load_id_mapping, +) +from topobench.data.utils.io_utils import ( + download_ensembl_biomart_mapping, + download_file, + download_folder_from_drive, +) + + +class PPIHighPPIDataset(InMemoryDataset): + """HIGH-PPI network integrated with CORUM human protein complexes. + + Combines 6,660 protein-protein interactions from HIGH-PPI SHS27k with ~470 + experimentally validated human protein complexes from CORUM database. + + Parameters + ---------- + root : str + Root directory. + name : str, optional + Dataset name, default "ppi_highppi". + parameters : DictConfig, optional + Config with min_complex_size, max_complex_size, target_ranks, neg_ratio, + edge_task ("score" or "interaction_type"). + **kwargs : dict + Additional keyword arguments passed to InMemoryDataset. 
+ """ + + INTERACTION_TYPES: ClassVar[list[str]] = [ + "reaction", + "binding", + "ptmod", + "activation", + "inhibition", + "catalysis", + "expression", + ] + + # Data source URLs + HIGHPPI_GDRIVE_FOLDER: ClassVar[str] = ( + "https://drive.google.com/drive/folders/1Yb-fdWJ5vTe0ePAGNfrUluzO9tz1lHIF?usp=sharing" + ) + CORUM_URL: ClassVar[str] = ( + "https://mips.helmholtz-muenchen.de/fastapi-corum/public/file/download_current_file?file_id=human&file_format=txt" + ) + + # Required raw data filenames + HIGHPPI_NETWORK_FILE: ClassVar[str] = ( + "protein.actions.SHS27k.STRING.pro2.txt" + ) + ID_MAPPING_FILE: ClassVar[str] = "ensp_uniprot.txt" + CORUM_COMPLEXES_FILE: ClassVar[str] = "allComplexes.txt" + + def __init__( + self, + root: str, + name: str = "ppi_highppi", + parameters: DictConfig = None, + **kwargs, + ): + self.name = name + self.parameters = parameters or DictConfig({}) + self.min_complex_size = self.parameters.get("min_complex_size", 2) + self.max_complex_size = self.parameters.get("max_complex_size", 6) + self.max_rank = self.max_complex_size - 1 + self.neg_ratio = self.parameters.get("neg_ratio", 1.0) + self.target_ranks = self.parameters.get("target_ranks", [2, 3, 4, 5]) + self.edge_task = self.parameters.get("edge_task", "score") + + self.highppi_edges = [] # List of (p1, p2, interaction_type_vector, confidence_score) + self.corum_complexes = [] # List of sets of proteins in a complex + self.all_proteins = set() + self.ensembl_to_uniprot = {} + self.uniprot_to_ensembl = {} + self.official_splits = {} + + super().__init__(root, **kwargs) + + out = fs.torch_load(self.processed_paths[0]) + if len(out) == 3: + data, self.slices, self.sizes = out + data_cls = Data + else: + data, self.slices, self.sizes, data_cls = out + + if not isinstance(data, dict): + self.data = data + else: + self.data = data_cls.from_dict(data) + + # Ensure data.y is set for single-rank compatibility + # TODO: Change for B2 submission which will introduce a unified training loop + if 
len(self.target_ranks) == 1: + label_attr = f"cell_labels_{self.target_ranks[0]}" + if hasattr(self._data, label_attr): + self._data.y = getattr(self._data, label_attr) + + @property + def raw_dir(self) -> str: + """Return the path to the raw directory. + + Returns + ------- + str + Path to the raw directory. + """ + return osp.join(self.root, "raw") + + @property + def processed_dir(self) -> str: + """Return the path to the processed directory. + + Returns + ------- + str + Path to the processed directory. + """ + return osp.join(self.root, "processed") + + @property + def raw_file_names(self) -> list[str]: + """Return list of required raw file names. + + Returns + ------- + List[str] + Required raw data files. + """ + return [ + self.HIGHPPI_NETWORK_FILE, + self.ID_MAPPING_FILE, + self.CORUM_COMPLEXES_FILE, + ] + + @property + def processed_file_names(self) -> list[str]: + """Return the name of the processed file. + + Filename includes target_ranks to avoid cache conflicts when + different ranks are requested. + + Returns + ------- + List[str] + List containing the name of the processed file. 
+ """ + # Include target_ranks in filename to prevent cache conflicts + ranks_str = "_".join(map(str, self.target_ranks)) + return [f"data_ranks_{ranks_str}.pt"] + + def download(self) -> None: + """Download HIGH-PPI and CORUM data files.""" + + # Check if files already exist + all_exist = all( + osp.exists(osp.join(self.raw_dir, fname)) + for fname in self.raw_file_names + ) + if all_exist: + print("All required files already present") + return + + print("Downloading HIGH-PPI SHS27k dataset and CORUM complexes...") + os.makedirs(self.raw_dir, exist_ok=True) + + if not osp.exists(osp.join(self.raw_dir, self.CORUM_COMPLEXES_FILE)): + print("Downloading CORUM human protein complexes...") + download_file( + self.CORUM_URL, + osp.join(self.raw_dir, self.CORUM_COMPLEXES_FILE), + verify_ssl=False, + ) + print("CORUM download complete") + + if not osp.exists(osp.join(self.raw_dir, self.ID_MAPPING_FILE)): + print("Downloading Ensembl-UniProt ID mapping...") + download_ensembl_biomart_mapping( + osp.join(self.raw_dir, self.ID_MAPPING_FILE) + ) + print("ID mapping download complete") + + if not osp.exists(osp.join(self.raw_dir, self.HIGHPPI_NETWORK_FILE)): + print("Downloading HIGH-PPI network data from Google Drive...") + success = download_folder_from_drive( + self.HIGHPPI_GDRIVE_FOLDER, self.raw_dir, quiet=False + ) + if not success: + raise RuntimeError( + "Failed to download HIGH-PPI data from Google Drive" + ) + print("HIGH-PPI download complete") + + # Final verification + missing_files = [ + fname + for fname in self.raw_file_names + if not osp.exists(osp.join(self.raw_dir, fname)) + ] + + if missing_files: + raise FileNotFoundError( + f"Failed to download required files: {missing_files}. 
" + ) + + def process(self): + """Build simplicial complex: HIGH-PPI edges + CORUM complexes.""" + print("\n" + "=" * 70) + print( + "Building PPI simplicial complex from HIGH-PPI and CORUM datasets" + ) + print("=" * 70) + + # Load Ensembl <-> UniProt ID mapping + mapping_path = osp.join(self.raw_dir, self.ID_MAPPING_FILE) + self.ensembl_to_uniprot, self.uniprot_to_ensembl = load_id_mapping( + mapping_path + ) + + # Load HIGH-PPI network with interaction types and confidence scores + highppi_path = osp.join(self.raw_dir, self.HIGHPPI_NETWORK_FILE) + self.highppi_edges, self.all_proteins = load_highppi_network( + highppi_path, self.INTERACTION_TYPES + ) + + # Load CORUM complexes, filter to SHS27k proteins + corum_path = osp.join(self.raw_dir, self.CORUM_COMPLEXES_FILE) + self.corum_complexes = load_corum_complexes( + corum_path, + self.all_proteins, + self.ensembl_to_uniprot, + self.uniprot_to_ensembl, + self.min_complex_size, + self.max_complex_size, + ) + + self._load_splits() + + print("Building simplicial complex...") + sc, edge_data, cell_data = build_simplicial_complex_with_features( + self.all_proteins, + self.highppi_edges, + self.corum_complexes, + self.min_complex_size, + self.max_rank, + ) + + print("Generating negative samples...") + edge_data, cell_data = generate_negative_samples( + sc, edge_data, cell_data, self.all_proteins, self.neg_ratio + ) + + print("Extracting features and connectivity...") + x_dict, labels_dict = build_data_features_and_labels( + sc, + edge_data, + cell_data, + self.target_ranks, + self.max_rank, + edge_task=self.edge_task, + ) + + # Get connectivity + connectivity = get_complex_connectivity( + sc, self.max_rank, signed=False + ) + + # Build Data object + protein_list = sorted(list(sc.nodes)) + protein_to_idx = {p: i for i, p in enumerate(protein_list)} + n_edges = len(list(sc.skeleton(1))) + + data = Data( + **x_dict, + **connectivity, + **labels_dict, + num_proteins=len(protein_list), + num_edges=n_edges, + 
num_complexes=len(self.corum_complexes), + protein_to_idx=protein_to_idx, + ) + + # Add x and y for compatibility with generic tests + # x_0 uses one-hot encoding, so dimension equals number of proteins + data.x = x_dict.get( + "x_0", torch.zeros(0, len(protein_list)) + ) # TODO: This data will not be used for node-level prediction + if ( + self.target_ranks + and f"cell_labels_{self.target_ranks[0]}" in labels_dict + ): + data.y = labels_dict[f"cell_labels_{self.target_ranks[0]}"] + else: + data.y = torch.zeros(len(protein_list), dtype=torch.long) + + # Add official splits if available + if self.official_splits: + data.train_mask = torch.tensor( + self.official_splits.get("train_index", []), dtype=torch.long + ) + data.val_mask = torch.tensor( + self.official_splits.get("valid_index", []), dtype=torch.long + ) + + # Save processed data + print("Saving processed data...") + self.data, self.slices = self.collate([data]) + fs.torch_save( + (self._data.to_dict(), self.slices, {}, self._data.__class__), + self.processed_paths[0], + ) + + print("\n" + "=" * 70) + print("✅ PROCESSING COMPLETE!") + print("📊 Dataset statistics:") + print(f" - Proteins (0-cells): {len(self.all_proteins)}") + print(f" - Labeled edges (1-cells): {len(self.highppi_edges)}") + print(f" - CORUM complexes: {len(self.corum_complexes)}") + print(f"📁 Saved to: {self.processed_paths[0]}") + print(f"💾 Size: {osp.getsize(self.processed_paths[0]) / 1e6:.1f} MB") + print("=" * 70 + "\n") + + @property + def data_list(self): + """Return list of data objects for TopoBench compatibility. + + Returns + ------- + list + List containing single data object (transductive setting). + """ + return [self._data] + + def get_data_dir(self): + """Return data directory for split file storage. + + Returns + ------- + str + Path to data directory. + """ + return self.root + + @property + def split_idx(self): + """Return train/val/test split indices for split_type='fixed'. + + Used when config has split_type='fixed'. 
Returns HIGH-PPI's official + train/val split if it was successfully loaded, otherwise None. + + Returns + ------- + dict or None + Dictionary with 'train', 'valid', 'test' keys containing indices, + or None (triggers random/k-fold splitting based on split_type). + """ + if hasattr(self, "official_splits") and self.official_splits: + return { + "train": np.array(self.official_splits.get("train_index", [])), + "valid": np.array(self.official_splits.get("val_index", [])), + "test": np.array(self.official_splits.get("val_index", [])), + } + return None + + # TODO: This is not working yet + def _load_splits(self): + """Load official train/val split indices from HIGH-PPI. + + Loads splits into self.official_splits which will be used if split_type='fixed'. + Fails silently if splits are not available (random/k-fold will be used instead). + """ + split_path = osp.join(self.raw_dir, "train_val_split_1.json") + if not osp.exists(split_path): + return + + try: + with open(split_path) as f: + content = f.read().strip() + if len(content) >= 10: # Basic validation + self.official_splits = json.loads(content) + print( + "Official train/val splits available (use split_type='fixed' to use them)" + ) + except (json.JSONDecodeError, Exception): + pass # Silently ignore - will use random/k-fold splitting diff --git a/topobench/data/loaders/simplicial/ppi_highppi_loader.py b/topobench/data/loaders/simplicial/ppi_highppi_loader.py new file mode 100644 index 000000000..9cd737589 --- /dev/null +++ b/topobench/data/loaders/simplicial/ppi_highppi_loader.py @@ -0,0 +1,71 @@ +"""Loader for PPI dataset (HIGH-PPI variant) with CORUM complexes.""" + +from omegaconf import DictConfig + +from topobench.data.datasets.ppi_highppi_dataset import PPIHighPPIDataset +from topobench.data.loaders.base import AbstractLoader + + +class PPIHighPPIDatasetLoader(AbstractLoader): + """Load HIGH-PPI SHS27k dataset with CORUM topological enrichment. 
+ + This loader creates a hybrid simplicial complex from: + - HIGH-PPI's SHS27k PPI network (labeled edges) + - CORUM protein complexes (unlabeled higher-order cells) + + Task: Edge-level multi-label classification (7 interaction types) + + Parameters + ---------- + parameters : DictConfig + Configuration parameters containing: + - data_dir: Root directory for data + - data_name: Name of the dataset + - min_complex_size: Minimum CORUM complex size + - max_complex_size: Maximum CORUM complex size + - max_rank: Maximum simplicial rank + - use_official_split: Use HIGH-PPI's train/val split + **kwargs : dict + Additional keyword arguments. + """ + + def __init__(self, parameters: DictConfig, **kwargs) -> None: + super().__init__(parameters, **kwargs) + + def load_dataset(self, **kwargs) -> PPIHighPPIDataset: + """Load the dataset. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed to dataset initialization. + + Returns + ------- + PPIHighPPIDataset + Dataset with HIGH-PPI network and CORUM complexes. + """ + dataset = self._initialize_dataset(**kwargs) + self.data_dir = self.get_data_dir() + return dataset + + def _initialize_dataset(self, **kwargs) -> PPIHighPPIDataset: + """Initialize the HIGH-PPI SHS27k dataset. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments for dataset initialization. + + Returns + ------- + PPIHighPPIDataset + The initialized dataset instance. + """ + self.dataset = PPIHighPPIDataset( + root=str(self.root_data_dir), + name=self.parameters.get("data_name", "highppi_shs27k"), + parameters=self.parameters, + **kwargs, + ) + return self.dataset diff --git a/topobench/data/utils/datasets/simplicial/ppi_utils.py b/topobench/data/utils/datasets/simplicial/ppi_utils.py new file mode 100644 index 000000000..3eb85234b --- /dev/null +++ b/topobench/data/utils/datasets/simplicial/ppi_utils.py @@ -0,0 +1,558 @@ +"""Refactored PPI utilities with cleaner separation of concerns. 
+ +Key improvements: +1. Separate topology building from feature/label assignment +2. Consistent data types (no mixed list/float/int in cell_labels) +3. Single-pass iterations (no redundant loops) +4. Clear data flow: topology → features → labels → tensors +""" + +import os +import random +from itertools import combinations + +import pandas as pd +import torch +from toponetx.classes import SimplicialComplex + + +def load_id_mapping( + mapping_path: str, +) -> tuple[dict[str, str], dict[str, list[str]]]: + """Load Ensembl ↔ UniProt ID mapping. + + Parameters + ---------- + mapping_path : str + Path to ensp_uniprot.txt mapping file. + + Returns + ------- + ensembl_to_uniprot : dict + Mapping from Ensembl IDs to UniProt IDs. + uniprot_to_ensembl : dict + Reverse mapping (UniProt to list of Ensembl IDs). + + Raises + ------ + FileNotFoundError + If mapping file does not exist. + """ + if not os.path.exists(mapping_path): + raise FileNotFoundError( + f"ID mapping file not found: {mapping_path}. " + "This file is required to map between Ensembl and UniProt IDs for CORUM complexes." + ) + + ensembl_to_uniprot = {} + uniprot_to_ensembl = {} + + with open(mapping_path) as f: + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + + parts = line.split("\t") + if len(parts) >= 2: + ensembl_id = parts[0].strip() + uniprot_id = parts[1].strip() + + if uniprot_id and uniprot_id not in ("Noneid", "None"): + ensembl_to_uniprot[ensembl_id] = uniprot_id + + if uniprot_id not in uniprot_to_ensembl: + uniprot_to_ensembl[uniprot_id] = [] + uniprot_to_ensembl[uniprot_id].append(ensembl_id) + + return ensembl_to_uniprot, uniprot_to_ensembl + + +def load_highppi_network( + file_path: str, interaction_types: list[str] +) -> tuple[list[tuple], set[str]]: + """Load HIGH-PPI network with interaction types and confidence scores. + + Parameters + ---------- + file_path : str + Path to HIGH-PPI SHS27k file. 
+ interaction_types : list + List of valid interaction type names. + + Returns + ------- + highppi_edges : list + List of (p1, p2, interaction_vector, score) tuples. + all_proteins : set + Set of all protein IDs in the network. + + Raises + ------ + FileNotFoundError + If HIGH-PPI network file does not exist. + """ + + if not os.path.exists(file_path): + raise FileNotFoundError( + f"HIGH-PPI network file not found: {file_path}. " + "This file contains the protein-protein interaction network." + ) + + df = pd.read_csv(file_path, sep="\t") + + # Rename columns to more intuitive names + df = df.rename( + columns={ + "item_id_a": "protein_1", + "item_id_b": "protein_2", + "mode": "interaction_type", + "score": "confidence_score", + } + ) + + edge_labels = {} + edge_scores = {} + all_proteins = set() + + for _, row in df.iterrows(): + p1 = str(row["protein_1"]).strip() + p2 = str(row["protein_2"]).strip() + + all_proteins.add(p1) + all_proteins.add(p2) + + edge_key = tuple(sorted([p1, p2])) + + score = float(row["confidence_score"]) / 1000.0 + if edge_key not in edge_scores: + edge_scores[edge_key] = 0.0 + edge_scores[edge_key] = max(edge_scores[edge_key], score) + + interaction_type = str(row["interaction_type"]).strip() + if edge_key not in edge_labels: + edge_labels[edge_key] = [0] * 7 + if interaction_type in interaction_types: + idx_type = interaction_types.index(interaction_type) + edge_labels[edge_key][idx_type] = 1 + + highppi_edges = [ + (p1, p2, labels, edge_scores[(p1, p2)]) + for (p1, p2), labels in edge_labels.items() + ] + + return highppi_edges, all_proteins + + +def load_corum_complexes( + file_path: str, + all_proteins: set[str], + ensembl_to_uniprot: dict[str, str], + uniprot_to_ensembl: dict[str, list[str]], + min_size: int, + max_size: int, +) -> list[set[str]]: + """Load and filter CORUM protein complexes. + + Parameters + ---------- + file_path : str + Path to CORUM allComplexes.txt file. 
+ all_proteins : set + Set of proteins in the network (for filtering). + ensembl_to_uniprot : dict + Ensembl to UniProt ID mapping. + uniprot_to_ensembl : dict + UniProt to Ensembl ID mapping. + min_size : int + Minimum complex size. + max_size : int + Maximum complex size. + + Returns + ------- + list[set[str]] + List of sets, each containing Ensembl protein IDs. + + Raises + ------ + FileNotFoundError + If CORUM file does not exist. + """ + + if not os.path.exists(file_path): + raise FileNotFoundError( + f"CORUM complexes file not found: {file_path}. " + "This file is required to load experimentally validated protein complexes." + ) + + df = pd.read_csv(file_path, sep="\t", low_memory=False) + + # CORUM uses 'subunits_uniprot_id' column for UniProt IDs + if "subunits_uniprot_id" not in df.columns: + raise ValueError( + f"Expected column 'subunits_uniprot_id' not found in CORUM file. Available columns: {df.columns.tolist()}" + ) + + # Map proteins to UniProt + shs27k_uniprot = { + ensembl_to_uniprot[eid] + for eid in all_proteins + if eid in ensembl_to_uniprot + } + + corum_complexes = [] + for _, row in df.iterrows(): + subunits_str = row["subunits_uniprot_id"] + if pd.isna(subunits_str): + continue + + proteins_uniprot = { + p.strip() for p in subunits_str.split(";") if p.strip() + } + proteins_in_network = proteins_uniprot & shs27k_uniprot + + if not (min_size <= len(proteins_in_network) <= max_size): + continue + + # Convert to Ensembl IDs + ensembl_complex = set() + for uniprot_id in proteins_in_network: + if uniprot_id in uniprot_to_ensembl: + for ensembl_id in uniprot_to_ensembl[uniprot_id]: + if ensembl_id in all_proteins: + ensembl_complex.add(ensembl_id) + break + + corum_complexes.append(ensembl_complex) + + return corum_complexes + + +def build_simplicial_complex_with_features( + all_proteins: set[str], + highppi_edges: list[tuple], + corum_complexes: list[set[str]], + min_complex_size: int, + max_rank: int, +) -> tuple[SimplicialComplex, dict, dict]: 
+ """Build simplicial complex with topology and metadata from PPI data. + + Constructs the complex structure and tracks cell data. + + Parameters + ---------- + all_proteins : set + Set of all protein IDs. + highppi_edges : list + List of (p1, p2, interaction_vector, score) tuples. + corum_complexes : list + List of protein complexes as sets. + min_complex_size : int + Minimum complex size to include. + max_rank : int + Maximum rank to consider. + + Returns + ------- + sc : SimplicialComplex + The constructed simplicial complex. + edge_data : dict + Edge features {edge_tuple: tensor([7 interaction types, 1 confidence])}. + cell_data : dict + Binary labels per rank {rank: {cell_tuple: {-1, 1}}}. + """ + sc = SimplicialComplex() + edge_data = {} # {edge_tuple: tensor([7 types + 1 score])} + cell_data = {} # {rank: {cell_tuple: -1 or 1}} + + # Add 0-cells (proteins) + for protein in sorted(all_proteins): + sc.add_simplex([protein]) + + # Add 1-cells (HIGH-PPI edges) with features + for p1, p2, interaction_vector, score in highppi_edges: + edge_tuple = tuple(sorted([p1, p2])) + sc.add_simplex([p1, p2]) + + # Store 8-dim feature vector (7 interaction types + 1 confidence) + edge_data[edge_tuple] = torch.tensor( + interaction_vector + [2 * score - 1], + dtype=torch.float, # Convert confidence score affinely: [0, 1] -> [-1, 1] + ) + + # Process CORUM complexes top-down (largest first) + # This ensures lower-rank CORUM complexes can override negative labels + sorted_complexes = sorted(corum_complexes, key=len, reverse=True) + + for complex_proteins in sorted_complexes: + # Filter by size and protein membership + if len(complex_proteins) < min_complex_size: + continue + if not complex_proteins.issubset(all_proteins): + continue + + complex_tuple = tuple(sorted(complex_proteins)) + rank = len(complex_tuple) - 1 + + if rank > max_rank: + continue + + # Add complex to simplicial complex (automatically adds all faces) + sc.add_simplex(list(complex_tuple)) + + if rank == 1: + 
edge_data[complex_tuple] = torch.tensor( + [0, 0, 0, 0, 0, 0, 0, 1.0], dtype=torch.float + ) + continue + + # Mark this complex as positive + if rank not in cell_data: + cell_data[rank] = {} + cell_data[rank][complex_tuple] = 1 + + # Mark all proper sub-faces as negative (not real complexes themselves) + # Only mark if not already labeled (top-down iteration handles overlaps) + for sub_rank in range(1, rank): + if sub_rank not in cell_data: + cell_data[sub_rank] = {} + + for sub_face in combinations(complex_tuple, sub_rank + 1): + sub_tuple = tuple(sorted(sub_face)) + if sub_rank == 1 and sub_tuple not in edge_data: + edge_data[sub_tuple] = torch.tensor( + [0, 0, 0, 0, 0, 0, 0, -1.0], dtype=torch.float + ) + elif sub_tuple not in cell_data[sub_rank]: + cell_data[sub_rank][sub_tuple] = -1 + + return sc, edge_data, cell_data + + +def generate_negative_samples( + sc: SimplicialComplex, + edge_data: dict[tuple, torch.Tensor], + cell_data: dict[int, dict[tuple, int]], + all_proteins: set[str], + neg_ratio: float, +) -> tuple[dict[tuple, torch.Tensor], dict[int, dict[tuple, int]]]: + """Generate negative samples proportionally across ranks. + + Parameters + ---------- + sc : SimplicialComplex + Current simplicial complex. + edge_data : dict + Edge features to update with negative edges. + cell_data : dict + Existing binary data per rank {rank: {cell_tuple: {-1, 1}}}. + all_proteins : set + Set of all protein IDs. + neg_ratio : float + Ratio of negative to positive samples. + + Returns + ------- + edge_data : dict + Updated with negative edge features. + cell_data : dict + Updated data with negative samples added (value=-1). 
+ """ + random.seed(42) + + # Count positive samples per rank + positive_counts = {} + for rank in range(2, sc.dim + 1): + if rank in cell_data: + positive_counts[rank] = sum( + 1 for label in cell_data[rank].values() if label == 1 + ) + else: + positive_counts[rank] = 0 + + # Generate negatives per rank + all_proteins_list = list(all_proteins) + + for rank, n_positive in positive_counts.items(): + if n_positive == 0: + continue + + n_negative_needed = int(n_positive * neg_ratio) + if n_negative_needed == 0: + continue + + # Get existing cells at this rank + existing_cells = set(cell_data.get(rank, {}).keys()) + existing_cells.update( + tuple(sorted(cell)) for cell in sc.skeleton(rank) + ) + + # Generate random cells until we have enough negatives + negatives_added = 0 + max_attempts = n_negative_needed * 100 + + for _ in range(max_attempts): + if negatives_added >= n_negative_needed: + break + + # Sample random proteins for this rank + sampled = random.sample(all_proteins_list, rank + 1) + cell_tuple = tuple(sorted(sampled)) + + # Only add if it doesn't exist yet + if cell_tuple not in existing_cells: + # Add to complex and label as negative + sc.add_simplex(list(cell_tuple)) + + if rank not in cell_data: + cell_data[rank] = {} + cell_data[rank][cell_tuple] = -1 + + # For edges (rank 1), also create feature vector + if rank == 1 and cell_tuple not in edge_data: + edge_data[cell_tuple] = torch.tensor( + [0, 0, 0, 0, 0, 0, 0, 0.0], dtype=torch.float + ) + + existing_cells.add(cell_tuple) + negatives_added += 1 + + print( + f" Rank {rank}: {n_positive} positive, {negatives_added} negative samples" + ) + + return edge_data, cell_data + + +def build_data_features_and_labels( + sc: SimplicialComplex, + edge_data: dict[tuple, torch.Tensor], + cell_data: dict[int, dict[tuple, int]], + target_ranks: list[int], + max_rank: int, + edge_task: str = None, +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + """Create feature and label tensors for all ranks. 
+ + Parameters + ---------- + sc : SimplicialComplex + The constructed simplicial complex. + edge_data : dict + Edge features {edge_tuple: tensor([7 interaction types, 1 confidence])}. + cell_data : dict + Binary labels {rank: {cell_tuple: {-1, 1}}}. + target_ranks : list + Ranks to predict on. + max_rank : int + Maximum rank in configuration. + edge_task : str, optional + Edge prediction task: "interaction_type" or "score". + Only used if rank 1 is in target_ranks. + + Returns + ------- + x_dict : dict + Features per rank {f"x_{rank}": tensor}. + labels_dict : dict + Labels per target rank {f"cell_labels_{rank}": tensor}. + """ + actual_max_rank = sc.dim + x_dict = {} + labels_dict = {} + + # For each rank: build features and labels + for rank in range(min(max_rank, actual_max_rank) + 1): + cells = list(sc.skeleton(rank)) + n_cells = len(cells) + + if n_cells == 0: + dim = 1 if rank == 0 else (8 if rank == 1 else 1) + x_dict[f"x_{rank}"] = torch.zeros(0, dim) + continue + + is_target = rank in target_ranks + + match rank: + case 0: + # Nodes: one-hot encoding + # TODO: Use richer embeddings (ESM, structure, GO annotations) + x_dict["x_0"] = torch.eye(n_cells) + + case 1: + # Edges: 8-dim features (7 interaction types + 1 confidence) + features = [] + labels = [] if is_target else None + + for edge in cells: + edge_tuple = tuple(sorted(edge)) + feat_vec = edge_data[edge_tuple] + + if is_target: + # Split features/labels based on edge_task + if edge_task == "interaction_type": + labels.append( + feat_vec[:7] + ) # First 7 dims = labels + features.append( + feat_vec[7:8] + ) # Last dim = feature + elif edge_task == "score": + labels.append(feat_vec[7:8]) # Last dim = label + features.append( + feat_vec[:7] + ) # First 7 dims = features + else: + # Not a target rank: use all 8 dims as features + features.append(feat_vec) + + x_dict["x_1"] = torch.stack(features) + + if is_target: + labels_dict["cell_labels_1"] = torch.stack(labels) + + case _: + # Higher-order cells + 
features = [] + labels = [] if is_target else None + + for cell in cells: + cell_tuple = tuple(sorted(cell)) + binary_existence_val = cell_data[rank][cell_tuple] + + # Features: 0 for target, {-1,+1} for non-target TODO: Bit unsure about this. Non-interacting edges also get 0 and 0 means it will not influence neighbors + # TODO: Maybe we should pass some labels as features for true transductivity/semi-supervision? + if is_target: + # Target rank: features are 0, labels are in {-1, 1} + features.append(torch.zeros(1, dtype=torch.float)) + labels.append( + torch.tensor( + [binary_existence_val], dtype=torch.float + ) + ) + else: + # Non-target rank: use labels as features {-1, +1} + features.append( + torch.tensor( + [binary_existence_val], dtype=torch.float + ) + ) + + x_dict[f"x_{rank}"] = torch.stack(features) + + if is_target: + labels_tensor = torch.tensor(labels, dtype=torch.long) + labels_dict[f"cell_labels_{rank}"] = labels_tensor + n_pos = (labels_tensor == 1).sum().item() + n_neg = (labels_tensor == 0).sum().item() + print( + f" Rank {rank}: {n_pos} positive, {n_neg} negative labels" + ) + + # Add empty features for non-existent ranks + for rank in range(actual_max_rank + 1, max_rank + 1): + dim = 1 if rank == 0 else (8 if rank == 1 else 1) + x_dict[f"x_{rank}"] = torch.zeros(0, dim) + + return x_dict, labels_dict From 9057e8c9a8208918530cba0221017732cb7df7fc Mon Sep 17 00:00:00 2001 From: grapentt Date: Tue, 18 Nov 2025 23:22:08 +0100 Subject: [PATCH 18/32] :bug: Creation of unlabeled cells that were meant to be negative + Labels should be {0,1} or [0,1] respectively for CrossEntropyLoss and interpretability --- .../utils/datasets/simplicial/ppi_utils.py | 74 ++++++++++++------- 1 file changed, 48 insertions(+), 26 deletions(-) diff --git a/topobench/data/utils/datasets/simplicial/ppi_utils.py b/topobench/data/utils/datasets/simplicial/ppi_utils.py index 3eb85234b..8271b14f3 100644 --- a/topobench/data/utils/datasets/simplicial/ppi_utils.py +++ 
b/topobench/data/utils/datasets/simplicial/ppi_utils.py @@ -361,25 +361,29 @@ def generate_negative_samples( """ random.seed(42) - # Count positive samples per rank - positive_counts = {} - for rank in range(2, sc.dim + 1): + all_proteins_list = list(all_proteins) + + # Generate negatives from highest to rank 2 (top-down) + # Note that edges already have score in [-1, 1] so no need to add more negatives + for rank in range(sc.dim, 1, -1): + # Count positive and existing negative samples at this rank + n_positive = 0 + n_existing_negative = 0 if rank in cell_data: - positive_counts[rank] = sum( + n_positive = sum( 1 for label in cell_data[rank].values() if label == 1 ) - else: - positive_counts[rank] = 0 - - # Generate negatives per rank - all_proteins_list = list(all_proteins) - - for rank, n_positive in positive_counts.items(): + n_existing_negative = sum( + 1 for label in cell_data[rank].values() if label == -1 + ) if n_positive == 0: continue - n_negative_needed = int(n_positive * neg_ratio) - if n_negative_needed == 0: + # Calculate how many more negatives we need (accounting for existing ones) + n_negative_target = int(n_positive * neg_ratio) + n_negative_needed = n_negative_target - n_existing_negative + if n_negative_needed <= 0: + # Enough negatives continue # Get existing cells at this rank @@ -402,24 +406,37 @@ # Only add if it doesn't exist yet if cell_tuple not in existing_cells: - # Add to complex and label as negative + # Add to complex (automatically adds faces) sc.add_simplex(list(cell_tuple)) if rank not in cell_data: cell_data[rank] = {} + # Label as negative cell_data[rank][cell_tuple] = -1 - # For edges (rank 1), also create feature vector - if rank == 1 and cell_tuple not in edge_data: - edge_data[cell_tuple] = torch.tensor( - [0, 0, 0, 0, 0, 0, 0, 0.0], dtype=torch.float - ) + # Mark all proper sub-faces as negative (if they don't exist yet) + for sub_rank in range(1, rank): + if sub_rank not in cell_data: + 
cell_data[sub_rank] = {} + + for sub_face in combinations(cell_tuple, sub_rank + 1): + sub_tuple = tuple(sorted(sub_face)) + if sub_rank == 1 and sub_tuple not in edge_data: + # Edge: create feature vector + edge_data[sub_tuple] = torch.tensor( + [0, 0, 0, 0, 0, 0, 0, -1.0], dtype=torch.float + ) + elif sub_tuple not in cell_data[sub_rank]: + # Higher-order sub-face: mark as negative + cell_data[sub_rank][sub_tuple] = -1 existing_cells.add(cell_tuple) negatives_added += 1 + # Calculate total negatives (existing + newly added) + n_total_negative = n_existing_negative + negatives_added print( - f" Rank {rank}: {n_positive} positive, {negatives_added} negative samples" + f" Rank {rank}: {n_positive} positive, {n_total_negative} negative samples" ) return edge_data, cell_data @@ -499,7 +516,9 @@ def build_data_features_and_labels( feat_vec[7:8] ) # Last dim = feature elif edge_task == "score": - labels.append(feat_vec[7:8]) # Last dim = label + # Convert score back from [-1, 1] to [0, 1] for standard regression + score_normalized = (feat_vec[7:8] + 1) / 2 + labels.append(score_normalized) features.append( feat_vec[:7] ) # First 7 dims = features @@ -521,7 +540,8 @@ def build_data_features_and_labels( cell_tuple = tuple(sorted(cell)) binary_existence_val = cell_data[rank][cell_tuple] - # Features: 0 for target, {-1,+1} for non-target TODO: Bit unsure about this. Non-interacting edges also get 0 and 0 means it will not influence neighbors + # Features: 0 for target, {-1,+1} for non-target TODO: Bit unsure about this. + # Labels (only target): {0, 1} for PyTorch CrossEntropyLoss # TODO: Maybe we should pass some labels as features for true transductivity/semi-supervision? 
if is_target: # Target rank: features are 0, labels are in {-1, 1} @@ -542,10 +562,12 @@ def build_data_features_and_labels( x_dict[f"x_{rank}"] = torch.stack(features) if is_target: - labels_tensor = torch.tensor(labels, dtype=torch.long) - labels_dict[f"cell_labels_{rank}"] = labels_tensor - n_pos = (labels_tensor == 1).sum().item() - n_neg = (labels_tensor == 0).sum().item() + labels_tensor = torch.stack(labels).squeeze() + # Convert {-1, +1} → {0, 1} for PyTorch CrossEntropyLoss + labels_mapped = ((labels_tensor + 1) / 2).long() + labels_dict[f"cell_labels_{rank}"] = labels_mapped + n_pos = (labels_mapped == 1).sum().item() + n_neg = (labels_mapped == 0).sum().item() print( f" Rank {rank}: {n_pos} positive, {n_neg} negative labels" ) From 9c6047da5722a88b7191acec12b0abb7d3903a93 Mon Sep 17 00:00:00 2001 From: grapentt Date: Wed, 19 Nov 2025 00:01:59 +0100 Subject: [PATCH 19/32] :waste_basket: Remove unused imports --- test/data/load/test_datasetloaders.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/data/load/test_datasetloaders.py b/test/data/load/test_datasetloaders.py index cb21fd421..db838e99b 100644 --- a/test/data/load/test_datasetloaders.py +++ b/test/data/load/test_datasetloaders.py @@ -1,12 +1,9 @@ """Comprehensive test suite for all dataset loaders.""" -import os import pytest -import torch import hydra from pathlib import Path from typing import List, Tuple, Dict, Any -from omegaconf import DictConfig -from topobench.data.preprocessor.preprocessor import PreProcessor + class TestLoaders: """Comprehensive test suite for all dataset loaders.""" From 12fbb1a5a215a8bf1c729699208ea50fc9f765cc Mon Sep 17 00:00:00 2001 From: grapentt Date: Wed, 19 Nov 2025 11:03:49 +0100 Subject: [PATCH 20/32] :bug: 1D tensor label for edge-level regression --- .../data/utils/datasets/simplicial/ppi_utils.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/topobench/data/utils/datasets/simplicial/ppi_utils.py 
b/topobench/data/utils/datasets/simplicial/ppi_utils.py index 8271b14f3..d9539c9c2 100644 --- a/topobench/data/utils/datasets/simplicial/ppi_utils.py +++ b/topobench/data/utils/datasets/simplicial/ppi_utils.py @@ -517,7 +517,9 @@ def build_data_features_and_labels( ) # Last dim = feature elif edge_task == "score": # Convert score back from [-1, 1] to [0, 1] for standard regression - score_normalized = (feat_vec[7:8] + 1) / 2 + score_normalized = ( + (feat_vec[7] + 1) / 2 + ).item() # Scalar value labels.append(score_normalized) features.append( feat_vec[:7] @@ -529,7 +531,14 @@ def build_data_features_and_labels( x_dict["x_1"] = torch.stack(features) if is_target: - labels_dict["cell_labels_1"] = torch.stack(labels) + if edge_task == "score": + # For regression: 1D tensor of scalar values + labels_dict["cell_labels_1"] = torch.tensor( + labels, dtype=torch.float + ) + else: + # For multi-label classification: 2D tensor + labels_dict["cell_labels_1"] = torch.stack(labels) case _: # Higher-order cells From 95084b3c2ce7f1107fab9b1bfc93a6375ae03dbf Mon Sep 17 00:00:00 2001 From: grapentt Date: Wed, 19 Nov 2025 11:22:11 +0100 Subject: [PATCH 21/32] :construction: Prepare cell-level prediction --- topobench/model/model.py | 4 ++-- topobench/nn/readouts/base.py | 9 ++++++--- topobench/run.py | 6 ++++-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/topobench/model/model.py b/topobench/model/model.py index a7c688b47..b708d96e8 100755 --- a/topobench/model/model.py +++ b/topobench/model/model.py @@ -242,8 +242,8 @@ def process_outputs(self, model_out: dict, batch: Data) -> dict: else: raise ValueError("Invalid state_str") - if self.task_level == "node": - # Keep only train data points + if self.task_level in ["node", "cell"]: + # Keep only train data points (for node-level or cell-level tasks) for key, val in model_out.items(): if key in ["logits", "labels"]: model_out[key] = val[mask] diff --git a/topobench/nn/readouts/base.py 
b/topobench/nn/readouts/base.py index 6fdd8412f..45c5ac69a 100755 --- a/topobench/nn/readouts/base.py +++ b/topobench/nn/readouts/base.py @@ -42,7 +42,9 @@ def __init__( if hidden_dim != out_channels or logits_linear_layer else torch.nn.Identity() ) - assert task_level in ["graph", "node"], "Invalid task_level" + assert task_level in ["graph", "node", "cell"], ( + "Invalid task_level. Must be 'graph', 'node', or 'cell'." + ) self.task_level = task_level self.logits_linear_layer = logits_linear_layer @@ -84,7 +86,7 @@ def compute_logits(self, x, batch): Parameters ---------- x : torch.Tensor - Node embeddings. + Cell embeddings. batch : torch.Tensor Batch index tensor. @@ -94,8 +96,9 @@ def compute_logits(self, x, batch): Logits tensor. """ if self.task_level == "graph": + # Graph-level: pool across batch (one prediction per graph) x = scatter(x, batch, dim=0, reduce=self.pooling_type) - + # Cell-level: no pooling (one prediction per Cell) return self.linear(x) @abstractmethod diff --git a/topobench/run.py b/topobench/run.py index ab6f8602a..a3e749411 100755 --- a/topobench/run.py +++ b/topobench/run.py @@ -171,7 +171,7 @@ def run(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: ) # Prepare datamodule log.info("Instantiating datamodule...") - if cfg.dataset.parameters.task_level in ["node", "graph"]: + if cfg.dataset.parameters.task_level in ["node", "graph", "cell"]: datamodule = TBDataloader( dataset_train=dataset_train, dataset_val=dataset_val, @@ -179,7 +179,9 @@ def run(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: **cfg.dataset.get("dataloader_params", {}), ) else: - raise ValueError("Invalid task_level") + raise ValueError( + f"Invalid task_level: {cfg.dataset.parameters.task_level}. Must be 'node', 'graph', or 'cell'." 
+ ) # Model for us is Network + logic: inputs backbone, readout, losses log.info(f"Instantiating model <{cfg.model._target_}>") From e3bf31bc7d09c50a1b4f324d792988d95bf602a3 Mon Sep 17 00:00:00 2001 From: grapentt Date: Wed, 19 Nov 2025 18:24:59 +0100 Subject: [PATCH 22/32] :children_crossing: Infer num features --- configs/dataset/simplicial/ppi_highppi.yaml | 8 ++-- topobench/run.py | 4 ++ topobench/utils/config_resolvers.py | 48 +++++++++++++++++++++ 3 files changed, 55 insertions(+), 5 deletions(-) diff --git a/configs/dataset/simplicial/ppi_highppi.yaml b/configs/dataset/simplicial/ppi_highppi.yaml index a0f36793e..9960618f2 100644 --- a/configs/dataset/simplicial/ppi_highppi.yaml +++ b/configs/dataset/simplicial/ppi_highppi.yaml @@ -49,11 +49,9 @@ loader: # Model training configuration parameters: - # Feature dimensions: [rank-0, rank-1, ..., rank-max] - # rank-0: One-hot encoded proteins TODO: Replace with richer embedding - # rank-1: 8-dim edge features (7 interaction types + 1 confidence score) - # rank-2+: 1-dim features (binary existence) - num_features: [1553, 8, 1, 1, 1, 1] + _num_proteins: 1553 # HIGH-PPI has 1,553 proteins + + num_features: ${infer_ppi_num_features:${dataset.parameters._num_proteins},${dataset.loader.parameters.edge_task},${dataset.loader.parameters.max_complex_size}} num_classes: 2 # Depends on task: # - Higher-order (ranks 2+): 2 (exists/doesn't exist) diff --git a/topobench/run.py b/topobench/run.py index a3e749411..72693f9f9 100755 --- a/topobench/run.py +++ b/topobench/run.py @@ -34,6 +34,7 @@ get_required_lifting, infer_in_channels, infer_num_cell_dimensions, + infer_ppi_num_features, infer_topotune_num_cell_dimensions, ) @@ -95,6 +96,9 @@ infer_topotune_num_cell_dimensions, replace=True, ) +OmegaConf.register_new_resolver( + "infer_ppi_num_features", infer_ppi_num_features, replace=True +) OmegaConf.register_new_resolver( "parameter_multiplication", lambda x, y: int(int(x) * int(y)), replace=True ) diff --git 
a/topobench/utils/config_resolvers.py b/topobench/utils/config_resolvers.py index 65ab69667..bafb6f6dd 100644 --- a/topobench/utils/config_resolvers.py +++ b/topobench/utils/config_resolvers.py @@ -497,3 +497,51 @@ def get_default_metrics(task, metrics=None): return ["mse", "mae"] else: raise ValueError(f"Invalid task {task}") + + +def infer_ppi_num_features(num_proteins, edge_task, max_complex_size): + r"""Infer feature dimensions for HIGH-PPI dataset. + + For simplicial complexes from HIGH-PPI: + - Rank 0 (proteins): One-hot encoding (num_proteins features) + - Rank 1 (edges): 7 or 8 features depending on edge_task + - If edge_task="score": 7 features (8th is label) + - If edge_task="type": 8 features (all features) + - Rank 2+: 1 feature (binary existence) + + Parameters + ---------- + num_proteins : int + Number of proteins (for one-hot encoding dimension). + edge_task : str + Edge task type: "score" (regression) or "type" (classification). + max_complex_size : int + Maximum number of proteins per complex (determines number of ranks). + + Returns + ------- + list + List of feature dimensions per rank: [rank_0, rank_1, rank_2, ...]. 
+ + Examples + -------- + >>> infer_ppi_num_features(1553, "score", 6) + [1553, 7, 1, 1, 1, 1] # 7 edge features (score is label) + + >>> infer_ppi_num_features(1553, "type", 6) + [1553, 8, 1, 1, 1, 1] # 8 edge features (all features) + """ + # Rank 0: protein features (one-hot) + features = [num_proteins] + + # Rank 1: edge features (depends on task) + if edge_task == "score": + features.append(7) # 7 features, 8th (score) becomes label + else: + features.append(1) + + # Rank 2+: cell features (binary existence) + num_higher_ranks = max_complex_size - 2 # Subtract rank 0 and rank 1 + features.extend([1] * num_higher_ranks) + + return features From e2e63bafc4264b91d211c3a83c253b8a61584e6b Mon Sep 17 00:00:00 2001 From: grapentt Date: Wed, 19 Nov 2025 18:48:36 +0100 Subject: [PATCH 23/32] :technologist: Filter labels to high-ppi edges + infrastructure for general masking which can be used for semi-supervision --- .../data/datasets/ppi_highppi_dataset.py | 6 + .../utils/datasets/simplicial/ppi_utils.py | 17 ++ topobench/data/utils/split_utils.py | 165 ++++++++++++++---- topobench/model/model.py | 17 ++ 4 files changed, 175 insertions(+), 30 deletions(-) diff --git a/topobench/data/datasets/ppi_highppi_dataset.py b/topobench/data/datasets/ppi_highppi_dataset.py index bc439b340..e8558941f 100644 --- a/topobench/data/datasets/ppi_highppi_dataset.py +++ b/topobench/data/datasets/ppi_highppi_dataset.py @@ -253,6 +253,11 @@ def process(self): highppi_path, self.INTERACTION_TYPES ) + # Create set of HIGH-PPI edges to use as prediction-target for edge level tasks + highppi_edge_set = { + tuple(sorted([p1, p2])) for p1, p2, _, _ in self.highppi_edges + } + # Load CORUM complexes, filter to SHS27k proteins corum_path = osp.join(self.raw_dir, self.CORUM_COMPLEXES_FILE) self.corum_complexes = load_corum_complexes( @@ -288,6 +293,7 @@ def process(self): self.target_ranks, self.max_rank, edge_task=self.edge_task, + highppi_edge_set=highppi_edge_set, ) # Get connectivity diff --git 
a/topobench/data/utils/datasets/simplicial/ppi_utils.py b/topobench/data/utils/datasets/simplicial/ppi_utils.py index d9539c9c2..6218d4742 100644 --- a/topobench/data/utils/datasets/simplicial/ppi_utils.py +++ b/topobench/data/utils/datasets/simplicial/ppi_utils.py @@ -449,6 +449,7 @@ def build_data_features_and_labels( target_ranks: list[int], max_rank: int, edge_task: str = None, + highppi_edge_set: set[tuple] = None, ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: """Create feature and label tensors for all ranks. @@ -467,6 +468,9 @@ def build_data_features_and_labels( edge_task : str, optional Edge prediction task: "interaction_type" or "score". Only used if rank 1 is in target_ranks. + highppi_edge_set : set, optional + Set of edge tuples from HIGH-PPI. + Used to filter edges for edge-level tasks. Returns ------- @@ -501,12 +505,18 @@ def build_data_features_and_labels( # Edges: 8-dim features (7 interaction types + 1 confidence) features = [] labels = [] if is_target else None + # Track HIGH-PPI edges for loss filtering (all edge tasks) + highppi_mask = [] if (is_target and highppi_edge_set) else None for edge in cells: edge_tuple = tuple(sorted(edge)) feat_vec = edge_data[edge_tuple] if is_target: + # Track if this edge is from HIGH-PPI (not CORUM-generated) + if highppi_mask is not None: + highppi_mask.append(edge_tuple in highppi_edge_set) + # Split features/labels based on edge_task if edge_task == "interaction_type": labels.append( @@ -540,6 +550,13 @@ def build_data_features_and_labels( # For multi-label classification: 2D tensor labels_dict["cell_labels_1"] = torch.stack(labels) + # Store mask to filter out CORUM-generated edges during training + # This applies to ALL edge tasks (score regression, interaction type, etc.) 
+ if highppi_mask is not None: + labels_dict["mask_1"] = torch.tensor( + highppi_mask, dtype=torch.bool + ) + case _: # Higher-order cells features = [] diff --git a/topobench/data/utils/split_utils.py b/topobench/data/utils/split_utils.py index cf63b50a4..e1722c94a 100644 --- a/topobench/data/utils/split_utils.py +++ b/topobench/data/utils/split_utils.py @@ -1,8 +1,10 @@ """Split utilities.""" import os +from typing import Any import numpy as np +import pandas as pd import torch from sklearn.model_selection import StratifiedKFold @@ -155,9 +157,15 @@ def random_splitting(labels, parameters, root=None, global_data_seed=42): val_indices = perm[train_num : train_num + valid_num] test_indices = perm[train_num + valid_num :] split_idx = { - "train": train_indices, - "valid": val_indices, - "test": test_indices, + "train": train_indices.numpy() + if hasattr(train_indices, "numpy") + else np.array(train_indices), + "valid": val_indices.numpy() + if hasattr(val_indices, "numpy") + else np.array(val_indices), + "test": test_indices.numpy() + if hasattr(test_indices, "numpy") + else np.array(test_indices), } # Save generated split @@ -169,13 +177,16 @@ def random_splitting(labels, parameters, root=None, global_data_seed=42): split_idx = np.load(split_path) # Check that all nodes/graph have been assigned to some split - assert np.unique( - np.array( - split_idx["train"].tolist() - + split_idx["valid"].tolist() - + split_idx["test"].tolist() - ) - ).shape[0] == len(labels), "Not all nodes within splits" + train_arr = split_idx["train"] + val_arr = split_idx["valid"] + test_arr = split_idx["test"] + + all_indices = np.concatenate([train_arr, val_arr, test_arr]) + unique_indices = np.unique(all_indices) + + assert unique_indices.shape[0] == len(labels), ( + f"Not all nodes within splits: {unique_indices.shape[0]} != {len(labels)}" + ) return split_idx @@ -257,16 +268,40 @@ def load_transductive_splits(dataset, parameters): # Single rank or node/graph prediction labels = 
data.y.numpy() - # Ensure labels are one dimensional array - assert len(labels.shape) == 1, "Labels should be one dimensional array" + # Check for rank-specific mask (e.g., mask_1 for edges) + # If present, split only on filtered/valid entities for honest ratios + rank_mask = None + valid_indices = None + target_ranks = getattr( + dataset, "target_ranks", [1] + ) # Default to rank 1 for edges + + if target_ranks: + rank = target_ranks[0] # Single rank case + mask_attr = f"mask_{rank}" + if hasattr(data, mask_attr): + rank_mask = getattr(data, mask_attr) # Boolean mask + valid_indices = torch.where(rank_mask)[ + 0 + ] # Original indices of valid entities + labels = labels[rank_mask.numpy()] # Filter to valid entities only + + # Handle multi-dimensional labels (e.g., multi-label classification) + if len(labels.shape) > 1: + # Use first column for stratification (common practice) + stratify_labels = ( + labels[:, 0] if labels.shape[1] > 0 else labels.flatten() + ) + else: + stratify_labels = labels root = dataset.get_data_dir() if hasattr(dataset, "get_data_dir") else None if parameters.split_type == "random": - splits = random_splitting(labels, parameters, root=root) + splits = random_splitting(stratify_labels, parameters, root=root) elif parameters.split_type == "k-fold": - splits = k_fold_split(labels, parameters, root=root) + splits = k_fold_split(stratify_labels, parameters, root=root) elif parameters.split_type == "fixed" and hasattr(dataset, "split_idx"): splits = dataset.split_idx @@ -283,9 +318,21 @@ def load_transductive_splits(dataset, parameters): ) # Assign train val test masks to the graph - data.train_mask = torch.from_numpy(splits["train"]) - data.val_mask = torch.from_numpy(splits["valid"]) - data.test_mask = torch.from_numpy(splits["test"]) + # If we filtered by rank_mask, map indices back to original positions + if valid_indices is not None: + # Splits are indices into filtered data, map back to original + train_mask = 
valid_indices[torch.from_numpy(splits["train"])] + val_mask = valid_indices[torch.from_numpy(splits["valid"])] + test_mask = valid_indices[torch.from_numpy(splits["test"])] + else: + # No filtering: use indices directly + train_mask = torch.from_numpy(splits["train"]) + val_mask = torch.from_numpy(splits["valid"]) + test_mask = torch.from_numpy(splits["test"]) + + data.train_mask = train_mask + data.val_mask = val_mask + data.test_mask = test_mask if parameters.get("standardize", False): # Standardize the node features respecting train mask @@ -299,12 +346,55 @@ def load_transductive_splits(dataset, parameters): return DataloadDataset([data]), None, None -def load_multirank_transductive_splits(dataset, parameters): +def get_multilabel_stratification_targets( + labels: np.ndarray | pd.DataFrame, +) -> np.ndarray: + """Generate a single stratification target vector for multi-label data. + + For multi-label classification, uses the index of the most frequent label + per sample (argmax). This is simpler and more robust than Label Powerset, + avoiding issues with rare label combinations. + + Parameters + ---------- + labels : np.ndarray or pd.DataFrame + The multi-label target array (2D) or vector (1D). + Can be a NumPy array or Pandas DataFrame. + + Returns + ------- + np.ndarray + A 1D array suitable for the 'stratify' parameter in sklearn. + For 1D input: returns as-is. + For 2D input: returns argmax (most frequent label index). 
+ """ + # Standardize input to NumPy array + if isinstance(labels, pd.DataFrame): + labels = labels.values + + # Handle 1D arrays (standard classification) + if labels.ndim == 1: + return labels + + # Handle 2D arrays (multi-label classification) + if labels.ndim == 2 and labels.shape[1] > 1: + # Use argmax: index of most frequent label (or first '1' for binary) + # This ensures stratification works even with rare label combinations + return labels.argmax(axis=1) + + # Fallback for 2D arrays with single column + return labels.flatten() + + +def load_multirank_transductive_splits( + dataset, parameters +) -> tuple[list[Any], None, None]: r"""Load dataset with multi-rank cell-level splits. For datasets with cell-level predictions across multiple ranks (e.g., edges, triangles, tetrahedra simultaneously), this function creates independent - train/val/test splits for each rank. + train/val/test splits for each rank on valid entities (filtered by masks) + using multi-label stratification. Parameters ---------- @@ -350,14 +440,20 @@ def load_multirank_transductive_splits(dataset, parameters): labels = getattr(data, label_attr).numpy() - # Handle multi-dimensional labels (e.g., multi-label classification) - if len(labels.shape) > 1: - # Use first column for stratification (common practice) - stratify_labels = ( - labels[:, 0] if labels.shape[1] > 0 else labels.flatten() - ) - else: - stratify_labels = labels + # Check for rank-specific mask + # If present, split only on filtered entities for honest ratios + rank_mask = None + valid_indices = None + mask_attr = f"mask_{rank}" + + if hasattr(data, mask_attr): + rank_mask = getattr(data, mask_attr) # Boolean mask + valid_indices = torch.where(rank_mask)[ + 0 + ] # Original indices of valid entities + labels = labels[rank_mask.numpy()] # Filter to valid entities only + + stratify_labels = get_multilabel_stratification_targets(labels) # Create rank-specific root directory for splits # This ensures each rank gets independent 
splits @@ -387,14 +483,23 @@ def load_multirank_transductive_splits(dataset, parameters): ) # Store per-rank masks - train_mask = torch.from_numpy(splits["train"]) - val_mask = torch.from_numpy(splits["valid"]) - test_mask = torch.from_numpy(splits["test"]) + # If we filtered by rank_mask, map indices back to original positions + if valid_indices is not None: + # Splits are indices into filtered data, map back to original + train_mask = valid_indices[torch.from_numpy(splits["train"])] + val_mask = valid_indices[torch.from_numpy(splits["valid"])] + test_mask = valid_indices[torch.from_numpy(splits["test"])] + else: + # No filtering: use indices directly + train_mask = torch.from_numpy(splits["train"]) + val_mask = torch.from_numpy(splits["valid"]) + test_mask = torch.from_numpy(splits["test"]) setattr(data, f"train_mask_{rank}", train_mask) setattr(data, f"val_mask_{rank}", val_mask) setattr(data, f"test_mask_{rank}", test_mask) + # Assumes DataloadDataset is available in scope return DataloadDataset([data]), None, None diff --git a/topobench/model/model.py b/topobench/model/model.py index b708d96e8..776487463 100755 --- a/topobench/model/model.py +++ b/topobench/model/model.py @@ -247,6 +247,23 @@ def process_outputs(self, model_out: dict, batch: Data) -> dict: for key, val in model_out.items(): if key in ["logits", "labels"]: model_out[key] = val[mask] + # Note: Rank-specific masks are applied in readout + # The readout stores which indices it kept in cell_indices + if "cell_indices" in model_out: + # Find intersection: which readout outputs are in this split? 
+ # Note: The split respects the masks applied in the readout + cell_indices = model_out["cell_indices"] + keep_mask = torch.isin(cell_indices, mask) + + # Filter logits and labels + for key, val in model_out.items(): + if key in ["logits", "labels"]: + model_out[key] = val[keep_mask] + else: + # No cell_indices: standard filtering (for non-masked tasks) + for key, val in model_out.items(): + if key in ["logits", "labels"]: + model_out[key] = val[mask] return model_out From 51157e3501833eb48d35b58380c05a8be4e7e658 Mon Sep 17 00:00:00 2001 From: grapentt Date: Wed, 19 Nov 2025 19:44:51 +0100 Subject: [PATCH 24/32] :sparkles: Extend SCCNN to work for arbitrary ranks --- topobench/nn/backbones/simplicial/sccnn.py | 523 +++++++++++---------- 1 file changed, 279 insertions(+), 244 deletions(-) diff --git a/topobench/nn/backbones/simplicial/sccnn.py b/topobench/nn/backbones/simplicial/sccnn.py index 72bac2902..cff5830dd 100644 --- a/topobench/nn/backbones/simplicial/sccnn.py +++ b/topobench/nn/backbones/simplicial/sccnn.py @@ -13,13 +13,13 @@ class SCCNNCustom(torch.nn.Module): Parameters ---------- in_channels_all : tuple of int - Dimension of input features on (nodes, edges, faces). + Dimension of input features on each rank (nodes, edges, faces, ...). hidden_channels_all : tuple of int - Dimension of features of hidden layers on (nodes, edges, faces). + Dimension of features of hidden layers on each rank. conv_order : int Order of convolutions, we consider the same order for all convolutions. sc_order : int - Order of simplicial complex. + Order of simplicial complex (max_rank + 1). aggr_norm : bool, optional Whether to normalize the aggregation (default: False). 
update_func : str, optional @@ -39,16 +39,15 @@ def __init__( n_layers=2, ): super().__init__() - # first layer - # we use an MLP to map the features on simplices of different dimensions to the same dimension - self.in_linear_0 = torch.nn.Linear( - in_channels_all[0], hidden_channels_all[0] - ) - self.in_linear_1 = torch.nn.Linear( - in_channels_all[1], hidden_channels_all[1] - ) - self.in_linear_2 = torch.nn.Linear( - in_channels_all[2], hidden_channels_all[2] + + self.max_rank = len(in_channels_all) - 1 + + # Create input linear layers for each rank dynamically + self.in_linears = torch.nn.ModuleList( + [ + torch.nn.Linear(in_channels_all[i], hidden_channels_all[i]) + for i in range(len(in_channels_all)) + ] ) self.layers = torch.nn.ModuleList( @@ -69,28 +68,29 @@ def forward(self, x_all, laplacian_all, incidence_all): Parameters ---------- x_all : tuple(tensors) - Tuple of feature tensors (node, edge, face). + Tuple of feature tensors for each rank (x_0, x_1, ..., x_k). laplacian_all : tuple(tensors) - Tuple of Laplacian tensors (graph laplacian L0, down edge laplacian L1_d, upper edge laplacian L1_u, face laplacian L2). + Tuple of Laplacian tensors. incidence_all : tuple(tensors) - Tuple of order 1 and 2 incidence matrices. + Tuple of incidence matrices. Returns ------- tuple(tensors) - Tuple of final hidden state tensors (node, edge, face). + Tuple of final hidden state tensors for each rank. 
""" - x_0, x_1, x_2 = x_all - in_x_0 = self.in_linear_0(x_0) - in_x_1 = self.in_linear_1(x_1) - in_x_2 = self.in_linear_2(x_2) + # Apply input linear transformations to each rank + x_all_transformed = tuple( + self.in_linears[i](x_all[i]) for i in range(len(x_all)) + ) - # Forward through SCCNN - x_all = (in_x_0, in_x_1, in_x_2) + # Forward through SCCNN layers for layer in self.layers: - x_all = layer(x_all, laplacian_all, incidence_all) + x_all_transformed = layer( + x_all_transformed, laplacian_all, incidence_all + ) - return x_all + return x_all_transformed class SCCNNLayer(torch.nn.Module): @@ -99,13 +99,13 @@ class SCCNNLayer(torch.nn.Module): Parameters ---------- in_channels : tuple of int - Dimensions of input features on nodes, edges, and faces. + Dimensions of input features for each rank. out_channels : tuple of int - Dimensions of output features on nodes, edges, and faces. + Dimensions of output features for each rank. conv_order : int Convolution order of the simplicial filters. sc_order : int - SC order. + SC order (max_rank + 1). aggr_norm : bool, optional Whether to normalize the aggregated message by the neighborhood size (default: False). 
update_func : str, optional @@ -126,15 +126,9 @@ def __init__( ) -> None: super().__init__() - in_channels_0, in_channels_1, in_channels_2 = in_channels - out_channels_0, out_channels_1, out_channels_2 = out_channels - - self.in_channels_0 = in_channels_0 - self.in_channels_1 = in_channels_1 - self.in_channels_2 = in_channels_2 - self.out_channels_0 = out_channels_0 - self.out_channels_1 = out_channels_1 - self.out_channels_2 = out_channels_2 + self.in_channels = tuple(in_channels) + self.out_channels = tuple(out_channels) + self.max_rank = len(in_channels) - 1 self.conv_order = conv_order self.sc_order = sc_order @@ -146,47 +140,65 @@ def __init__( assert initialization in ["xavier_uniform", "xavier_normal"] assert self.conv_order > 0 - self.weight_0 = Parameter( - torch.Tensor( - self.in_channels_0, - self.out_channels_0, - 1 + conv_order + 1 + conv_order, - ) - ) - - self.weight_1 = Parameter( - torch.Tensor( - self.in_channels_1, - self.out_channels_1, - 6 * conv_order + 3, - ) - ) + # Create weight parameters for each rank + self.weights = torch.nn.ParameterList() - # determine the third dimensions of the weights - # because when SC order is larger than 2, there are lower and upper - # parts for L_2; otherwise, L_2 contains only the lower part + for rank in range(self.max_rank + 1): + # Calculate weight tensor dimensions based on message types + # For rank k, we have: + # - Identity: 1 + # - Self convolutions: conv_order (down) + conv_order (up) for k>0, or just conv_order for k=0 + # - Lower messages (from k-1): 1 + conv_order (identity + convolutions) + # - Upper messages (from k+1): 1 + conv_order (identity + convolutions) - if sc_order > 2: - self.weight_2 = Parameter( - torch.Tensor( - self.in_channels_2, - self.out_channels_2, - 4 * conv_order - + 2, # in the future for arbitrary sc_order we should have this 6*conv_order + 3, - ) - ) + num_message_types = self._compute_message_types(rank) - elif sc_order == 2: - self.weight_2 = Parameter( + weight = 
Parameter( torch.Tensor( - self.in_channels_2, - self.out_channels_2, - 4 * conv_order + 2, + self.in_channels[rank], + self.out_channels[rank], + num_message_types, ) ) + self.weights.append(weight) self.reset_parameters() + def _compute_message_types(self, rank): + """Compute the maximum number of message types for a given rank. + + Parameters + ---------- + rank : int + Rank to consider. + + Returns + ------- + int + Number of message types for the given rank. + """ + count = 0 + + # Self messages + if rank == 0: + # Rank 0: identity + Hodge Laplacian convolutions + count += 1 + self.conv_order + else: + # Rank k>0: identity + down Laplacian + up Laplacian convolutions + count += 1 + self.conv_order + self.conv_order + + # Lower messages (from rank-1 projected to rank) + if rank > 0: + # Identity + convolutions with down/up Laplacians at current rank + count += 1 + self.conv_order + + # Upper messages (from rank+1 projected to rank) + if rank < self.max_rank: + # Identity + convolutions with down/up Laplacians at current rank + count += 1 + self.conv_order + + return count + def reset_parameters(self, gain: float = 1.414): r"""Reset learnable parameters. @@ -196,13 +208,11 @@ def reset_parameters(self, gain: float = 1.414): Gain for the weight initialization. """ if self.initialization == "xavier_uniform": - torch.nn.init.xavier_uniform_(self.weight_0, gain=gain) - torch.nn.init.xavier_uniform_(self.weight_1, gain=gain) - torch.nn.init.xavier_uniform_(self.weight_2, gain=gain) + for weight in self.weights: + torch.nn.init.xavier_uniform_(weight, gain=gain) elif self.initialization == "xavier_normal": - torch.nn.init.xavier_normal_(self.weight_0, gain=gain) - torch.nn.init.xavier_normal_(self.weight_1, gain=gain) - torch.nn.init.xavier_normal_(self.weight_2, gain=gain) + for weight in self.weights: + torch.nn.init.xavier_normal_(weight, gain=gain) else: raise RuntimeError( "Initialization method not recognized. 
" @@ -286,201 +296,226 @@ def chebyshev_conv(self, conv_operator, conv_order, x): return X def forward(self, x_all, laplacian_all, incidence_all): - r"""Forward computation. + r"""Forward computation for arbitrary ranks. Parameters ---------- x_all : tuple of tensors - Tuple of input feature tensors (node, edge, face). + Tuple of input feature tensors for each rank. laplacian_all : tuple of tensors - Tuple of Laplacian tensors (graph laplacian L0, down edge laplacian L1_d, upper edge laplacian L1_u, face laplacian L2). + Tuple of Laplacian tensors organized as: + (L_0, L_down_1, L_up_1, L_down_2, L_up_2, ...). incidence_all : tuple of tensors - Tuple of order 1 and 2 incidence matrices. + Tuple of incidence matrices (B_1, B_2, ..., B_k). Returns ------- - torch.Tensor - Output tensor for each 0-cell. - torch.Tensor - Output tensor for each 1-cell. - torch.Tensor - Output tensor for each 2-cell. - """ - x_0, x_1, x_2 = x_all - - if self.sc_order == 2: - laplacian_0, laplacian_down_1, laplacian_up_1, laplacian_2 = ( - laplacian_all - ) - elif self.sc_order > 2: - ( - laplacian_0, - laplacian_down_1, - laplacian_up_1, - laplacian_down_2, - laplacian_up_2, - ) = laplacian_all - - # num_nodes, num_edges, num_triangles = x_0.shape[0], x_1.shape[0], x_2.shape[0] - - b1, b2 = incidence_all - - # identity_0, identity_1, identity_2 = ( - # torch.eye(num_nodes).to(x_0.device), - # torch.eye(num_edges).to(x_0.device), - # torch.eye(num_triangles).to(x_0.device), - # ) - """ - Convolution in the node space + tuple of tensors + Output tensors for each rank after message passing. 
""" - # -----------Logic to obtain update for 0-cells -------- - # x_identity_0 = torch.unsqueeze(identity_0 @ x_0, 2) - # x_0_to_0 = self.chebyshev_conv(laplacian_0, self.conv_order, x_0) - # x_0_to_0 = torch.cat((x_identity_0, x_0_to_0), 2) - - x_0_laplacian = self.chebyshev_conv(laplacian_0, self.conv_order, x_0) - x_0_to_0 = torch.cat([x_0.unsqueeze(2), x_0_laplacian], dim=2) - # ------------------- - - # x_1_to_0 = torch.mm(b1, x_1) - # x_1_to_0_identity = torch.unsqueeze(identity_0 @ x_1_to_0, 2) - # x_1_to_0 = self.chebyshev_conv(laplacian_0, self.conv_order, x_1_to_0) - # x_1_to_0 = torch.cat((x_1_to_0_identity, x_1_to_0), 2) - - x_1_to_0_upper = torch.mm(b1, x_1) - x_1_to_0_laplacian = self.chebyshev_conv( - laplacian_0, self.conv_order, x_1_to_0_upper - ) - x_1_to_0 = torch.cat( - [x_1_to_0_upper.unsqueeze(2), x_1_to_0_laplacian], dim=2 - ) - # ------------------- - - x_0_all = torch.cat((x_0_to_0, x_1_to_0), 2) - - # ------------------- - """ - Convolution in the edge space - """ - - # -----------Logic to obtain update for 1-cells -------- - # x_identity_1 = torch.unsqueeze(identity_1 @ x_1, 2) - # x_1_down = self.chebyshev_conv(laplacian_down_1, self.conv_order, x_1) - # x_1_up = self.chebyshev_conv(laplacian_up_1, self.conv_order, x_1) - # x_1_to_1 = torch.cat((x_identity_1, x_1_down, x_1_up), 2) - - x_1_down = self.chebyshev_conv(laplacian_down_1, self.conv_order, x_1) - x_1_up = self.chebyshev_conv(laplacian_down_1, self.conv_order, x_1) - x_1_to_1 = torch.cat((x_1.unsqueeze(2), x_1_down, x_1_up), 2) - - # ------------------- - - # x_0_to_1 = torch.mm(b1.T, x_0) - # x_0_to_1_identity = torch.unsqueeze(identity_1 @ x_0_to_1, 2) - # x_0_to_1 = self.chebyshev_conv(laplacian_down_1, self.conv_order, x_0_to_1) - # x_0_to_1 = torch.cat((x_0_to_1_identity, x_0_to_1), 2) + outputs = [] + + for rank in range(self.max_rank + 1): + x_rank = x_all[rank] + + # Skip empty ranks (no cells at this dimension) + if x_rank.shape[0] == 0: + # Create empty output tensor 
for this rank + outputs.append( + torch.zeros( + 0, self.out_channels[rank], device=x_rank.device + ) + ) + continue - # Lower projection - x_0_1_lower = torch.mm(b1.T, x_0) + # Get Laplacians for this rank + laplacians = self._get_laplacians_for_rank(rank, laplacian_all) - # Calculate lowwer chebyshev_conv - x_0_1_down = self.chebyshev_conv( - laplacian_down_1, self.conv_order, x_0_1_lower - ) + # Get incidence matrices + incidence_lower = ( + incidence_all[rank - 1] + if rank > 0 and rank - 1 < len(incidence_all) + else None + ) + incidence_upper = ( + incidence_all[rank] if rank < len(incidence_all) else None + ) - # Calculate upper chebyshev_conv (Note: in case of signed incidence should be always zero) - x_0_1_up = self.chebyshev_conv( - laplacian_up_1, self.conv_order, x_0_1_lower - ) + # Compute all messages for this rank + messages = self._compute_messages_for_rank( + rank, + x_rank, + x_all, + laplacians, + incidence_lower, + incidence_upper, + ) - # Concatenate output of filters - x_0_to_1 = torch.cat( - [x_0_1_lower.unsqueeze(2), x_0_1_down, x_0_1_up], dim=2 - ) - # ------------------- + # Apply weight and aggregate + # Use only the first k dimensions of weights that match the number of messages + num_messages = messages.shape[2] + weight_slice = self.weights[rank][:, :, :num_messages] + y_rank = torch.einsum("nik,iok->no", messages, weight_slice) - # x_2_to_1 = torch.mm(b2, x_2) - # x_2_to_1_identity = torch.unsqueeze(identity_1 @ x_2_to_1, 2) - # x_2_to_1 = self.chebyshev_conv(laplacian_up_1, self.conv_order, x_2_to_1) - # x_2_to_1 = torch.cat((x_2_to_1_identity, x_2_to_1), 2) + # Apply activation if specified + if self.update_func is not None: + y_rank = self.update(y_rank) - x_2_1_upper = torch.mm(b2, x_2) + outputs.append(y_rank) - # Calculate lowwer chebyshev_conv (Note: In case of signed incidence should be always zero) - x_2_1_down = self.chebyshev_conv( - laplacian_down_1, self.conv_order, x_2_1_upper - ) + return tuple(outputs) - # Calculate 
upper chebyshev_conv - x_2_1_up = self.chebyshev_conv( - laplacian_up_1, self.conv_order, x_2_1_upper - ) + def _get_laplacians_for_rank(self, rank, laplacian_all): + """Extract Laplacians for a given rank from laplacian_all. - x_2_to_1 = torch.cat( - [x_2_1_upper.unsqueeze(2), x_2_1_down, x_2_1_up], dim=2 - ) + Parameters + ---------- + rank : int + The rank to extract Laplacians for. + laplacian_all : tuple + All Laplacians organized as (L_0, L_down_1, L_up_1, L_down_2, L_up_2, ...). - # ------------------- - x_1_all = torch.cat((x_0_to_1, x_1_to_1, x_2_to_1), 2) - """Convolution in the face (triangle) space, depending on the SC order, - the exact form maybe a little different.""" - # -------------------Logic to obtain update for 2-cells -------- - # x_identity_2 = torch.unsqueeze(identity_2 @ x_2, 2) - - # if self.sc_order == 2: - # x_2 = self.chebyshev_conv(laplacian_2, self.conv_order, x_2) - # x_2_to_2 = torch.cat((x_identity_2, x_2), 2) - # elif self.sc_order > 2: - # x_2_down = self.chebyshev_conv(laplacian_down_2, self.conv_order, x_2) - # x_2_up = self.chebyshev_conv(laplacian_up_2, self.conv_order, x_2) - # x_2_to_2 = torch.cat((x_identity_2, x_2_down, x_2_up), 2) - x_2_down = self.chebyshev_conv(laplacian_down_2, self.conv_order, x_2) - x_2_up = self.chebyshev_conv(laplacian_up_2, self.conv_order, x_2) - x_2_to_2 = torch.cat((x_2.unsqueeze(2), x_2_down, x_2_up), 2) - - # ------------------- - - # x_1_to_2 = torch.mm(b2.T, x_1) - # x_1_to_2_identity = torch.unsqueeze(identity_2 @ x_1_to_2, 2) - # if self.sc_order == 2: - # x_1_to_2 = self.chebyshev_conv(laplacian_2, self.conv_order, x_1_to_2) - # elif self.sc_order > 2: - # x_1_to_2 = self.chebyshev_conv(laplacian_down_2, self.conv_order, x_1_to_2) - # x_1_to_2 = torch.cat((x_1_to_2_identity, x_1_to_2), 2) - - x_1_2_lower = torch.mm(b2.T, x_1) - x_1_2_down = self.chebyshev_conv( - laplacian_down_2, self.conv_order, x_1_2_lower - ) - x_1_2_down = self.chebyshev_conv( - laplacian_up_2, self.conv_order, 
x_1_2_lower - ) + Returns + ------- + dict + Dictionary with keys 'hodge', 'down', 'up' containing the relevant Laplacians. + """ + laplacians = {} - x_1_to_2 = torch.cat( - [x_1_2_lower.unsqueeze(2), x_1_2_down, x_1_2_down], dim=2 - ) + if rank == 0: + # Rank 0 only has Hodge Laplacian + laplacians["hodge"] = ( + laplacian_all[0] if len(laplacian_all) > 0 else None + ) + laplacians["down"] = None + laplacians["up"] = None + else: + # For rank k > 0: index is 1 + 2*(k-1) for down, 1 + 2*(k-1) + 1 for up + idx_down = 1 + 2 * (rank - 1) + idx_up = idx_down + 1 + + laplacians["hodge"] = None + laplacians["down"] = ( + laplacian_all[idx_down] + if idx_down < len(laplacian_all) + else None + ) + laplacians["up"] = ( + laplacian_all[idx_up] if idx_up < len(laplacian_all) else None + ) - # That is my code, but to execute this part we need to have simplices order of k+1 in this case order of 3 - # x_3_2_upper = x_1_to_2 = torch.mm(b2, x_3) - # x_3_2_down = self.chebyshev_conv(laplacian_down_2, self.conv_order, x_3_2_upper) - # x_3_2_up = self.chebyshev_conv(laplacian_up_2, self.conv_order, x_3_2_upper) + return laplacians - # x_3_to_2 = torch.cat([x_3_2_upper.unsueeze(2), x_3_2_down, x_3_2_up], dim=2) + def _compute_messages_for_rank( + self, rank, x_rank, x_all, laplacians, incidence_lower, incidence_upper + ): + """Compute all messages for a given rank. - # ------------------- + Parameters + ---------- + rank : int + The rank to compute messages for. + x_rank : tensor + Features of cells at this rank. + x_all : tuple + Features of all ranks. + laplacians : dict + Dictionary of Laplacians for this rank. + incidence_lower : tensor or None + Incidence matrix from rank-1 to rank. + incidence_upper : tensor or None + Incidence matrix from rank to rank+1. 
- x_2_all = torch.cat([x_1_to_2, x_2_to_2], dim=2) - # The final version of this model should have the following line - # x_2_all = torch.cat([x_1_to_2, x_2_to_2, x_3_to_2], dim=2) + Returns + ------- + tensor + Concatenated messages of shape (num_cells, num_channels, num_message_types). + """ + message_list = [] - # ------------------- + # 1. Self messages (identity + convolutions) + if rank == 0: + # Identity message + message_list.append(x_rank.unsqueeze(2)) - # Need to check that this einsums are correct - y_0 = torch.einsum("nik,iok->no", x_0_all, self.weight_0) - y_1 = torch.einsum("nik,iok->no", x_1_all, self.weight_1) - y_2 = torch.einsum("nik,iok->no", x_2_all, self.weight_2) + # Hodge Laplacian convolutions + if laplacians["hodge"] is not None: + x_conv = self.chebyshev_conv( + laplacians["hodge"], self.conv_order, x_rank + ) + message_list.append(x_conv) + else: + # Identity message + message_list.append(x_rank.unsqueeze(2)) - if self.update_func is None: - return y_0, y_1, y_2 + # Down Laplacian convolutions + if laplacians["down"] is not None: + x_down = self.chebyshev_conv( + laplacians["down"], self.conv_order, x_rank + ) + message_list.append(x_down) - return self.update(y_0), self.update(y_1), self.update(y_2) + # Up Laplacian convolutions + if laplacians["up"] is not None: + x_up = self.chebyshev_conv( + laplacians["up"], self.conv_order, x_rank + ) + message_list.append(x_up) + + # 2. 
Lower messages (from rank-1) + if rank > 0 and incidence_lower is not None and rank - 1 < len(x_all): + x_lower = x_all[rank - 1] + # Only process if lower rank is not empty + if x_lower.shape[0] > 0: + # Project features from rank-1 to rank + x_lower_proj = torch.mm(incidence_lower.T, x_lower) + + # Identity + message_list.append(x_lower_proj.unsqueeze(2)) + + # Apply Laplacian convolutions at the current rank + # Use the appropriate Laplacian (down for rank 0, down for rank > 0) + if rank == 0: + if laplacians["hodge"] is not None: + x_lower_conv = self.chebyshev_conv( + laplacians["hodge"], + self.conv_order, + x_lower_proj, + ) + message_list.append(x_lower_conv) + else: + if laplacians["down"] is not None: + x_lower_conv = self.chebyshev_conv( + laplacians["down"], + self.conv_order, + x_lower_proj, + ) + message_list.append(x_lower_conv) + + # 3. Upper messages (from rank+1) + if ( + rank < self.max_rank + and incidence_upper is not None + and rank + 1 < len(x_all) + ): + x_upper = x_all[rank + 1] + # Only process if upper rank is not empty + if x_upper.shape[0] > 0: + # Project features from rank+1 to rank + x_upper_proj = torch.mm(incidence_upper, x_upper) + + # Identity + message_list.append(x_upper_proj.unsqueeze(2)) + + # Apply Laplacian convolutions at the current rank + # Use up Laplacian for rank > 0 + if laplacians["up"] is not None: + x_upper_conv = self.chebyshev_conv( + laplacians["up"], self.conv_order, x_upper_proj + ) + message_list.append(x_upper_conv) + + # Concatenate all messages + messages = torch.cat(message_list, dim=2) + + return messages From 82171c996446f73cacb2f5380ad084460c522387 Mon Sep 17 00:00:00 2001 From: grapentt Date: Wed, 19 Nov 2025 21:32:23 +0100 Subject: [PATCH 25/32] :memo: Include edge_task in name to prevent cache conflicts --- topobench/data/datasets/ppi_highppi_dataset.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/topobench/data/datasets/ppi_highppi_dataset.py 
b/topobench/data/datasets/ppi_highppi_dataset.py index e8558941f..8d1e91631 100644 --- a/topobench/data/datasets/ppi_highppi_dataset.py +++ b/topobench/data/datasets/ppi_highppi_dataset.py @@ -167,17 +167,18 @@ def raw_file_names(self) -> list[str]: def processed_file_names(self) -> list[str]: """Return the name of the processed file. - Filename includes target_ranks to avoid cache conflicts when - different ranks are requested. + Filename includes target_ranks and edge_task to avoid cache conflicts + when different ranks or tasks are requested. Returns ------- List[str] List containing the name of the processed file. """ - # Include target_ranks in filename to prevent cache conflicts + # Include target_ranks and edge_task in filename to prevent cache conflicts ranks_str = "_".join(map(str, self.target_ranks)) - return [f"data_ranks_{ranks_str}.pt"] + task_str = self.edge_task if self.edge_task else "none" + return [f"data_ranks_{ranks_str}_task_{task_str}.pt"] def download(self) -> None: """Download HIGH-PPI and CORUM data files.""" From 0646111763cbdf182e66fe96bfb3cde58901e3b2 Mon Sep 17 00:00:00 2001 From: grapentt Date: Wed, 19 Nov 2025 21:38:30 +0100 Subject: [PATCH 26/32] :bug: Message passing --- topobench/nn/backbones/simplicial/sccnn.py | 107 ++++++++++----------- 1 file changed, 52 insertions(+), 55 deletions(-) diff --git a/topobench/nn/backbones/simplicial/sccnn.py b/topobench/nn/backbones/simplicial/sccnn.py index cff5830dd..3033ea944 100644 --- a/topobench/nn/backbones/simplicial/sccnn.py +++ b/topobench/nn/backbones/simplicial/sccnn.py @@ -181,7 +181,7 @@ def _compute_message_types(self, rank): # Self messages if rank == 0: - # Rank 0: identity + Hodge Laplacian convolutions + # Rank 0: identity + up Laplacian convolutions count += 1 + self.conv_order else: # Rank k>0: identity + down Laplacian + up Laplacian convolutions @@ -189,13 +189,17 @@ def _compute_message_types(self, rank): # Lower messages (from rank-1 projected to rank) if rank > 0: - # 
Identity + convolutions with down/up Laplacians at current rank - count += 1 + self.conv_order + # Identity + convolutions with down and up Laplacians at current rank + count += 1 + self.conv_order + self.conv_order # Upper messages (from rank+1 projected to rank) if rank < self.max_rank: - # Identity + convolutions with down/up Laplacians at current rank - count += 1 + self.conv_order + # Identity + convolutions with down and up Laplacians at current rank + # Special case: rank 0 only has up Laplacian + if rank == 0: + count += 1 + self.conv_order + else: + count += 1 + self.conv_order + self.conv_order return count @@ -433,65 +437,50 @@ def _compute_messages_for_rank( """ message_list = [] - # 1. Self messages (identity + convolutions) + # 1. Lower messages (from rank-1) + if rank > 0 and incidence_lower is not None and rank - 1 < len(x_all): + x_lower = x_all[rank - 1] + # Only process if lower rank is not empty + if x_lower.shape[0] > 0: + # Project features from rank-1 to rank + x_lower_proj = torch.mm(incidence_lower.T, x_lower) + + message_list.append(x_lower_proj.unsqueeze(2)) + + # Apply down and up Laplacians + if laplacians["down"] is not None: + x_lower_down = self.chebyshev_conv( + laplacians["down"], self.conv_order, x_lower_proj + ) + message_list.append(x_lower_down) + + if laplacians["up"] is not None: + x_lower_up = self.chebyshev_conv( + laplacians["up"], self.conv_order, x_lower_proj + ) + message_list.append(x_lower_up) + + # 2. 
Self messages (identity + convolutions) if rank == 0: - # Identity message message_list.append(x_rank.unsqueeze(2)) - - # Hodge Laplacian convolutions if laplacians["hodge"] is not None: x_conv = self.chebyshev_conv( laplacians["hodge"], self.conv_order, x_rank ) message_list.append(x_conv) else: - # Identity message message_list.append(x_rank.unsqueeze(2)) - - # Down Laplacian convolutions if laplacians["down"] is not None: x_down = self.chebyshev_conv( laplacians["down"], self.conv_order, x_rank ) message_list.append(x_down) - - # Up Laplacian convolutions if laplacians["up"] is not None: x_up = self.chebyshev_conv( laplacians["up"], self.conv_order, x_rank ) message_list.append(x_up) - # 2. Lower messages (from rank-1) - if rank > 0 and incidence_lower is not None and rank - 1 < len(x_all): - x_lower = x_all[rank - 1] - # Only process if lower rank is not empty - if x_lower.shape[0] > 0: - # Project features from rank-1 to rank - x_lower_proj = torch.mm(incidence_lower.T, x_lower) - - # Identity - message_list.append(x_lower_proj.unsqueeze(2)) - - # Apply Laplacian convolutions at the current rank - # Use the appropriate Laplacian (down for rank 0, down for rank > 0) - if rank == 0: - if laplacians["hodge"] is not None: - x_lower_conv = self.chebyshev_conv( - laplacians["hodge"], - self.conv_order, - x_lower_proj, - ) - message_list.append(x_lower_conv) - else: - if laplacians["down"] is not None: - x_lower_conv = self.chebyshev_conv( - laplacians["down"], - self.conv_order, - x_lower_proj, - ) - message_list.append(x_lower_conv) - # 3. 
Upper messages (from rank+1) if ( rank < self.max_rank @@ -499,21 +488,29 @@ def _compute_messages_for_rank( and rank + 1 < len(x_all) ): x_upper = x_all[rank + 1] - # Only process if upper rank is not empty if x_upper.shape[0] > 0: - # Project features from rank+1 to rank x_upper_proj = torch.mm(incidence_upper, x_upper) - - # Identity message_list.append(x_upper_proj.unsqueeze(2)) - # Apply Laplacian convolutions at the current rank - # Use up Laplacian for rank > 0 - if laplacians["up"] is not None: - x_upper_conv = self.chebyshev_conv( - laplacians["up"], self.conv_order, x_upper_proj - ) - message_list.append(x_upper_conv) + # Apply Laplacians (Hodge for rank 0, both down/up for rank > 0) + if rank == 0: + if laplacians["hodge"] is not None: + x_upper_hodge = self.chebyshev_conv( + laplacians["hodge"], self.conv_order, x_upper_proj + ) + message_list.append(x_upper_hodge) + else: + if laplacians["down"] is not None: + x_upper_down = self.chebyshev_conv( + laplacians["down"], self.conv_order, x_upper_proj + ) + message_list.append(x_upper_down) + + if laplacians["up"] is not None: + x_upper_up = self.chebyshev_conv( + laplacians["up"], self.conv_order, x_upper_proj + ) + message_list.append(x_upper_up) # Concatenate all messages messages = torch.cat(message_list, dim=2) From ec900d8e4d911c5d888e3fccb45d5d10d054bf5b Mon Sep 17 00:00:00 2001 From: grapentt Date: Wed, 19 Nov 2025 22:51:02 +0100 Subject: [PATCH 27/32] :sparkles: SCCNN cell wrapper using higher cell features and allowing for higher level cell-prediction --- .../wrappers/simplicial/sccnn_cell_wrapper.py | 167 ++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 topobench/nn/wrappers/simplicial/sccnn_cell_wrapper.py diff --git a/topobench/nn/wrappers/simplicial/sccnn_cell_wrapper.py b/topobench/nn/wrappers/simplicial/sccnn_cell_wrapper.py new file mode 100644 index 000000000..552180045 --- /dev/null +++ b/topobench/nn/wrappers/simplicial/sccnn_cell_wrapper.py @@ -0,0 +1,167 @@ 
+"""Wrapper for SCCNN with cell-level predictions. + +This wrapper is designed for transductive learning where predictions are made +on specific cells (simplices) rather than entire graphs or individual nodes. +""" + +import torch +from torch_geometric.data import Data + +from topobench.nn.wrappers.base import AbstractWrapper + + +class SCCNNCellWrapper(AbstractWrapper): + """Wrapper for SCCNN backbone with cell-level outputs. + + Unlike standard wrappers that focus on node features (x_0), this wrapper + preserves features at ALL ranks for cell-level prediction. + + Parameters + ---------- + backbone : nn.Module + The SCCNN backbone model. + num_cell_dimensions : int + Rank +1 of the simplicial complex. + target_ranks : list[int] + Which ranks have labels to predict. + **kwargs : dict + Additional arguments. + """ + + def __init__( + self, + backbone: torch.nn.Module, + num_cell_dimensions: int, + target_ranks: list, + **kwargs, + ): + # Ensure required parameters for base class + if "out_channels" not in kwargs: + kwargs["out_channels"] = 32 # Default value + kwargs["num_cell_dimensions"] = num_cell_dimensions + # Disable residual connections for cell-level prediction + # (we just pass through features) + kwargs["residual_connections"] = kwargs.get( + "residual_connections", False + ) + + super().__init__(backbone, **kwargs) + self.target_ranks = target_ranks + self.num_cell_dimensions = num_cell_dimensions + + def __repr__(self): + return f"{self.__class__.__name__}(target_ranks={self.target_ranks})" + + def forward(self, batch: Data) -> dict: + """Forward pass preserving all rank features. + + Parameters + ---------- + batch : Data + Batch object containing features x_0, x_1, ..., x_k, Laplacians, and incidences. + + Returns + ------- + dict + The model_out containing updated features x_0, x_1, ..., x_k. 
+ """ + # Extract features for all ranks from 0 to num_cell_dimensions-1 = rank + x_all = [] + for i in range(self.num_cell_dimensions): + x_key = f"x_{i}" + if hasattr(batch, x_key): + x_all.append(getattr(batch, x_key)) + else: + # If rank doesn't exist, add empty tensor + x_all.append(torch.zeros(0, 1, device=batch.x_0.device)) + x_all = tuple(x_all) + + # Extract Laplacians + laplacian_all = self._extract_laplacians(batch) + + # Extract incidences + incidence_all = self._extract_incidences(batch) + + # Forward through SCCNN backbone + x_all_out = self.backbone(x_all, laplacian_all, incidence_all) + + # Build output dictionary with features at ALL ranks + model_out = {} + for i, x_rank in enumerate(x_all_out): + model_out[f"x_{i}"] = x_rank + + return model_out + + def _extract_laplacians(self, batch: Data) -> tuple: + """Extract Laplacian matrices for all ranks. + + Expected format: + - hodge_laplacian_0 + - down_laplacian_1, up_laplacian_1 + - down_laplacian_2, up_laplacian_2 + - ... + + Parameters + ---------- + batch : Data + Batch object containing features x_0, x_1, ..., x_k, Laplacians, and incidences. + + Returns + ------- + tuple + Tuple of Laplacian matrices for all ranks. + """ + laplacian_all = [] + + # Rank 0: Hodge Laplacian + if hasattr(batch, "hodge_laplacian_0"): + laplacian_all.append(batch.hodge_laplacian_0) + else: + laplacian_all.append(None) + + # Store down and up Laplacians for each rank + for rank in range(1, self.num_cell_dimensions): + down_key = f"down_laplacian_{rank}" + up_key = f"up_laplacian_{rank}" + + if hasattr(batch, down_key): + laplacian_all.append(getattr(batch, down_key)) + else: + laplacian_all.append(None) + + if hasattr(batch, up_key): + laplacian_all.append(getattr(batch, up_key)) + else: + laplacian_all.append(None) + + return tuple(laplacian_all) + + def _extract_incidences(self, batch: Data) -> tuple: + """Extract incidence matrices. 
+ + Expected format: + - incidence_1: From 0-cells to 1-cells + - incidence_2: From 1-cells to 2-cells + - ... + + Parameters + ---------- + batch : Data + Batch object containing features x_0, x_1, ..., x_k, Laplacians, and incidences. + + Returns + ------- + tuple + Tuple of incidence matrices for all ranks. + """ + incidence_all = [] + + # Incidences map from rank k-1 to rank k, so we go from 1 to num_cell_dimensions + for rank in range(1, self.num_cell_dimensions + 1): + inc_key = f"incidence_{rank}" + if hasattr(batch, inc_key): + incidence_all.append(getattr(batch, inc_key)) + else: + incidence_all.append(None) + + return tuple(incidence_all) From 2a340482641719a602af265820bb189e185f2872 Mon Sep 17 00:00:00 2001 From: grapentt Date: Wed, 19 Nov 2025 23:25:50 +0100 Subject: [PATCH 28/32] :sparkles: Cell Readout layer for cell level prediction --- topobench/model/model.py | 10 +- topobench/nn/readouts/cell_readout.py | 179 ++++++++++++++++++++++++++ 2 files changed, 184 insertions(+), 5 deletions(-) create mode 100644 topobench/nn/readouts/cell_readout.py diff --git a/topobench/model/model.py b/topobench/model/model.py index 776487463..73437f634 100755 --- a/topobench/model/model.py +++ b/topobench/model/model.py @@ -248,19 +248,19 @@ def process_outputs(self, model_out: dict, batch: Data) -> dict: if key in ["logits", "labels"]: model_out[key] = val[mask] # Note: Rank-specific masks are applied in readout - # The readout stores which indices it kept in cell_indices - if "cell_indices" in model_out: + # The readout stores which indices it kept in valid_indices + if "valid_indices" in model_out: # Find intersection: which readout outputs are in this split? 
# Note: The split respects the masks applied in the readout - cell_indices = model_out["cell_indices"] - keep_mask = torch.isin(cell_indices, mask) + valid_indices = model_out["valid_indices"] + keep_mask = torch.isin(valid_indices, mask) # Filter logits and labels for key, val in model_out.items(): if key in ["logits", "labels"]: model_out[key] = val[keep_mask] else: - # No cell_indices: standard filtering (for non-masked tasks) + # No valid_indices: standard filtering (for non-masked tasks) for key, val in model_out.items(): if key in ["logits", "labels"]: model_out[key] = val[mask] diff --git a/topobench/nn/readouts/cell_readout.py b/topobench/nn/readouts/cell_readout.py new file mode 100644 index 000000000..fc6feaa13 --- /dev/null +++ b/topobench/nn/readouts/cell_readout.py @@ -0,0 +1,179 @@ +"""Cell-level readout for simplicial complexes. + +This readout layer predicts labels for valid cells of rank in target_ranks. +""" + +import torch +import torch.nn as nn +from torch_geometric.data import Data + + +class SimplicialCellLevelReadout(nn.Module): + """Readout for cell-level predictions on simplicial complexes. + + Takes features at each rank and predicts labels for valid cells + at specified target ranks. + + Parameters + ---------- + hidden_dim : int + Hidden dimension of input features on all ranks. + out_channels : int + Number of output classes. + num_cell_dimensions : int + Rank + 1 of simplicial complex. + target_ranks : List[int] + Which ranks have labels to predict (e.g., [2, 3, 4] for simplices with 3-5 nodes). 
+ """ + + def __init__( + self, + hidden_dim: int, + out_channels: int, + num_cell_dimensions: int, + target_ranks: list[int], + ): + super().__init__() + self.hidden_dim = hidden_dim + self.out_channels = out_channels + self.num_cell_dimensions = num_cell_dimensions + self.target_ranks = target_ranks + self.task_level = ( + "cell" # For compatibility with TBModel need this attribute + ) + + # Create prediction head for each target rank + # Each rank might have different hidden dims in the future, so use a dict + self.predictors = nn.ModuleDict( + { + str(rank): nn.Linear(hidden_dim, out_channels) + for rank in target_ranks + } + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"target_ranks={self.target_ranks}, " + f"out_channels={self.out_channels})" + ) + + def forward(self, model_out: dict, batch: Data) -> dict: + """Compute cell-level predictions. + + Parameters + ---------- + model_out : dict + Dictionary containing x_0, x_1, ..., x_k features per rank. + batch : Data + Batch object containing cell_labels for each target rank and mask for valid cells. 
+ + Returns + ------- + dict + Updated model_out with: + - logits: [num_labeled_cells, out_channels] + - cell_ranks: [num_labeled_cells] - which rank each prediction is for + - valid_indices: [num_labeled_cells] - which cell indices are valid + """ + all_logits = [] + all_labels = [] + all_ranks = [] + all_indices = [] + + for rank in self.target_ranks: + # Get features for this rank + x_key = f"x_{rank}" + if x_key not in model_out: + continue + + x_rank = model_out[x_key] # [num_cells_at_rank, hidden_dim] + + # Get labels for this rank + label_key = f"cell_labels_{rank}" + if not hasattr(batch, label_key): + continue + + labels = getattr(batch, label_key) # [num_cells_at_rank] + + # Filter valid cells with rank-specific mask + valid_mask = torch.ones( + len(labels), dtype=torch.bool, device=labels.device + ) + + mask_key = f"mask_{rank}" + if hasattr(batch, mask_key): + rank_mask = getattr(batch, mask_key) # Boolean mask + valid_mask &= rank_mask + + # Get final valid indices + valid_indices = torch.where(valid_mask)[0] + + if len(valid_indices) == 0: + continue + + # Get features and labels for valid cells + x_labeled = x_rank[valid_indices] # [num_labeled, hidden_dim] + y_labeled = labels[valid_indices] # [num_labeled] + + # Predict + logits = self.predictors[str(rank)]( + x_labeled + ) # [num_labeled, out_channels] + + all_logits.append(logits) + all_labels.append(y_labeled) + all_ranks.extend([rank] * len(valid_indices)) + all_indices.extend(valid_indices.tolist()) + + # Concatenate all predictions and labels + if len(all_logits) > 0: + model_out["logits"] = torch.cat(all_logits, dim=0) + model_out["labels"] = torch.cat(all_labels, dim=0) + model_out["cell_ranks"] = torch.tensor( + all_ranks, device=all_logits[0].device + ) + model_out["valid_indices"] = torch.tensor( + all_indices, device=all_logits[0].device + ) + else: + # No labeled cells found - use any available tensor for device + device = None + for key in model_out: + if isinstance(model_out[key], 
torch.Tensor): + device = model_out[key].device + break + if device is None: + device = torch.device("cpu") + + model_out["logits"] = torch.zeros( + 0, self.out_channels, device=device + ) + model_out["cell_ranks"] = torch.zeros( + 0, dtype=torch.long, device=device + ) + model_out["valid_indices"] = torch.zeros( + 0, dtype=torch.long, device=device + ) + + return model_out + + def __call__(self, model_out: dict, batch: Data) -> dict: + """Wrapper for forward to match AbstractZeroCellReadOut interface. + + Parameters + ---------- + model_out : dict + Dictionary containing features per rank. + batch : Data + Batch object containing cell labels and valid cell masks. + + Returns + ------- + dict + Updated model_out with: + - logits: [num_labeled_cells, out_channels] + - cell_ranks: [num_labeled_cells] - which rank each prediction is for + - valid_indices: [num_labeled_cells] - which cell index within rank + """ + return self.forward(model_out, batch) From d85e69b1db35bf76b31fe0bdf938d9b5b6b87998 Mon Sep 17 00:00:00 2001 From: I745505 Date: Wed, 26 Nov 2025 14:49:04 +0100 Subject: [PATCH 29/32] :heavy_plus: add dependency --- configs/model/simplicial/sccnn_cell.yaml | 45 ++++++++++++++++++++++++ pyproject.toml | 2 ++ 2 files changed, 47 insertions(+) create mode 100644 configs/model/simplicial/sccnn_cell.yaml diff --git a/configs/model/simplicial/sccnn_cell.yaml b/configs/model/simplicial/sccnn_cell.yaml new file mode 100644 index 000000000..c28a1fbb4 --- /dev/null +++ b/configs/model/simplicial/sccnn_cell.yaml @@ -0,0 +1,45 @@ +_target_: topobench.model.TBModel + +model_name: sccnn_cell +model_domain: simplicial + +_hidden_dim: 32 # Hidden dimension for all ranks + +_in_channels: ${infer_in_channels:${dataset},${oc.select:transforms,null}} +_num_ranks: ${infer_num_cell_dimensions:null,${model._in_channels}} +_channel_list: ${infer_channel_list:${model._hidden_dim},${model._num_ranks}} + +feature_encoder: + _target_: topobench.nn.encoders.AllCellFeatureEncoder + 
encoder_name: AllCellFeatureEncoder + in_channels: ${model._in_channels} + out_channels: ${model._hidden_dim} + proj_dropout: 0.0 + +backbone: + _target_: topobench.nn.backbones.simplicial.sccnn.SCCNNCustom + in_channels_all: ${model._channel_list} + hidden_channels_all: ${model._channel_list} + conv_order: 1 + sc_order: ${model._num_ranks} + aggr_norm: false + update_func: sigmoid + n_layers: 2 + +backbone_wrapper: + _target_: topobench.nn.wrappers.SCCNNCellWrapper + _partial_: true + wrapper_name: SCCNNCellWrapper + num_cell_dimensions: ${model._num_ranks} + target_ranks: ${dataset.parameters.target_ranks} + out_channels: ${model._hidden_dim} + +readout: + _target_: topobench.nn.readouts.LinearCellLevelReadout + hidden_dim: ${model._hidden_dim} + out_channels: ${dataset.parameters.num_classes} + num_cell_dimensions: ${model._num_ranks} + target_ranks: ${dataset.parameters.target_ranks} + +# Compile model for faster training (pytorch 2.0+) +compile: false diff --git a/pyproject.toml b/pyproject.toml index 3234ea9e6..2c2d6ef5d 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,8 @@ dependencies=[ "topomodelx @ git+https://github.com/pyt-team/TopoModelX.git", "toponetx @ git+https://github.com/pyt-team/TopoNetX.git", "lightning==2.4.0", + "gdown", + "pybiomart", ] [project.optional-dependencies] From ae8d011ba5c40dbfd4598e7346cd9975681fdad6 Mon Sep 17 00:00:00 2001 From: I745505 Date: Wed, 26 Nov 2025 15:13:18 +0100 Subject: [PATCH 30/32] :bug: non existing attribute --- topobench/data/utils/split_utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/topobench/data/utils/split_utils.py b/topobench/data/utils/split_utils.py index 71e5c6932..3bb82f034 100644 --- a/topobench/data/utils/split_utils.py +++ b/topobench/data/utils/split_utils.py @@ -258,7 +258,12 @@ def load_transductive_splits(dataset, parameters): "Dataset should have only one graph in a transductive setting." 
) - data = dataset.data_list[0] + # Get the single graph - handle PreProcessor (data_list) or raw dataset + # PreProcessor.data_list contains the fully transformed data with all simplicial features + if hasattr(dataset, "data_list") and dataset.data_list: + data = dataset.data_list[0] + else: + data = dataset[0] # Check if this is multi-rank cell prediction target_ranks = getattr(dataset, "target_ranks", None) @@ -423,7 +428,12 @@ def load_multirank_transductive_splits( "Dataset should have only one graph/complex in a transductive setting." ) - data = dataset.data_list[0] + # Get the single graph - handle PreProcessor (data_list) or raw dataset + # PreProcessor.data_list contains the fully transformed data with all simplicial features + if hasattr(dataset, "data_list") and dataset.data_list: + data = dataset.data_list[0] + else: + data = dataset[0] target_ranks = dataset.target_ranks root = dataset.get_data_dir() if hasattr(dataset, "get_data_dir") else None From 289c6b3ec0a1726c92dc2edf9fca349e8e499820 Mon Sep 17 00:00:00 2001 From: I745505 Date: Wed, 26 Nov 2025 16:45:25 +0100 Subject: [PATCH 31/32] :lipstick: More test fixes --- test/pipeline/test_pipeline.py | 4 ++ topobench/dataloader/dataloader.py | 2 +- topobench/dataloader/utils.py | 50 +++++++++++++++++++ .../wrappers/hypergraph/hypergraph_wrapper.py | 7 +++ 4 files changed, 62 insertions(+), 1 deletion(-) diff --git a/test/pipeline/test_pipeline.py b/test/pipeline/test_pipeline.py index 785987159..58026899b 100644 --- a/test/pipeline/test_pipeline.py +++ b/test/pipeline/test_pipeline.py @@ -7,6 +7,10 @@ DATASET = "graph/MUTAG" # ADD YOUR DATASET HERE MODELS = ["graph/gcn", "cell/topotune", "simplicial/topotune"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE +# HIGH-PPI B2 integration (optional - uncomment to test) +# DATASET = "simplicial/ppi_highppi" +# MODELS = ["simplicial/sccnn_cell"] + class TestPipeline: """Test pipeline for a particular dataset and model.""" diff --git 
a/topobench/dataloader/dataloader.py b/topobench/dataloader/dataloader.py index 1f293b10c..5f4b659a0 100755 --- a/topobench/dataloader/dataloader.py +++ b/topobench/dataloader/dataloader.py @@ -2,7 +2,7 @@ from typing import Any -from lightning import LightningDataModule +from lightning.pytorch import LightningDataModule from torch.utils.data import DataLoader from topobench.dataloader.dataload_dataset import DataloadDataset diff --git a/topobench/dataloader/utils.py b/topobench/dataloader/utils.py index 4670e41a0..14bfc3435 100644 --- a/topobench/dataloader/utils.py +++ b/topobench/dataloader/utils.py @@ -56,6 +56,56 @@ def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any: else: return 0 + def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any: + r"""Overwrite the `__inc__` method to handle proper index incrementing for sparse matrices. + + This method tells PyTorch Geometric how to increment indices when batching. + For incidence matrices, column indices need to be incremented by the running count + of the corresponding entity (e.g., hyperedges, edges, cells). + + Parameters + ---------- + key : str + Key of the data. + value : Any + Value of the data. + *args : Any + Additional arguments. + **kwargs : Any + Additional keyword arguments. + + Returns + ------- + Any + The increment value for indices. + """ + # Handle incidence matrices with hyperedges + if key == "incidence_hyperedges": + if hasattr(self, "num_hyperedges"): + return torch.tensor([[self.num_nodes], [self.num_hyperedges]]) + else: + # Fall back to using the shape of the incidence matrix + return torch.tensor([[self.num_nodes], [value.size(1)]]) + + # Handle incidence matrices for different ranks + # incidence_0, incidence_1, incidence_2, etc. 
+ if key.startswith("incidence_") and key[10:].isdigit(): + rank = int(key.split("_")[1]) + num_cells_attr = f"num_cells_{rank}" + if hasattr(self, num_cells_attr): + return torch.tensor( + [[self.num_nodes], [getattr(self, num_cells_attr)]] + ) + else: + # Fall back to using the shape of the incidence matrix + return torch.tensor([[self.num_nodes], [value.size(1)]]) + + # Default PyG behavior for edge_index and similar + if "index" in key or key == "face": + return self.num_nodes + + return 0 + def to_data_list(batch): """Workaround needed since `torch_geometric` doesn't work when using `torch.sparse` instead of `torch_sparse`. diff --git a/topobench/nn/wrappers/hypergraph/hypergraph_wrapper.py b/topobench/nn/wrappers/hypergraph/hypergraph_wrapper.py index f5349f35c..442057945 100644 --- a/topobench/nn/wrappers/hypergraph/hypergraph_wrapper.py +++ b/topobench/nn/wrappers/hypergraph/hypergraph_wrapper.py @@ -23,6 +23,13 @@ def forward(self, batch): dict Dictionary containing the updated model output. """ + if not hasattr(batch, "incidence_hyperedges"): + raise AttributeError( + f"Batch object is missing 'incidence_hyperedges' attribute. " + f"Available attributes: {list(batch.keys())}. " + f"Make sure the hypergraph lifting transformation has been applied to your dataset." 
+ ) + x_0, x_1 = self.backbone(batch.x_0, batch.incidence_hyperedges) model_out = {"labels": batch.y, "batch_0": batch.batch_0} model_out["x_0"] = x_0 From 4a86a7c7b61428eff384b3d9cf56b59267723976 Mon Sep 17 00:00:00 2001 From: grapentt Date: Wed, 26 Nov 2025 17:34:30 +0100 Subject: [PATCH 32/32] :bug: Fix tests --- test/nn/backbones/simplicial/test_sccnn.py | 5 ++--- test/test_tutorials.py | 12 +++++++++++- topobench/data/preprocessor/preprocessor.py | 4 ++-- topobench/model/model.py | 3 --- topobench/nn/backbones/combinatorial/gccn.py | 16 +++++++++++++--- 5 files changed, 28 insertions(+), 12 deletions(-) diff --git a/test/nn/backbones/simplicial/test_sccnn.py b/test/nn/backbones/simplicial/test_sccnn.py index e0608bea8..ded346e1c 100644 --- a/test/nn/backbones/simplicial/test_sccnn.py +++ b/test/nn/backbones/simplicial/test_sccnn.py @@ -94,9 +94,8 @@ def test_sccnn_basic_initialization(): # Verify layer structure assert len(model.layers) == 2 # Default n_layers is 2 - assert hasattr(model, 'in_linear_0') - assert hasattr(model, 'in_linear_1') - assert hasattr(model, 'in_linear_2') + assert hasattr(model, 'in_linears') + assert len(model.in_linears) == 3 # Should have 3 input linear layers for ranks 0, 1, 2 def test_update_functions(): """Test different update functions in the SCCNN.""" diff --git a/test/test_tutorials.py b/test/test_tutorials.py index 1c2b63a7a..7e30a19be 100644 --- a/test/test_tutorials.py +++ b/test/test_tutorials.py @@ -1,6 +1,7 @@ """Unit tests for the tutorials.""" import glob +import os import subprocess import tempfile @@ -28,7 +29,16 @@ def _exec_tutorial(path): file_name, path, ] - subprocess.check_call(args) + + # Set PYTHONPATH to include the project root so notebooks can import topobench + env = os.environ.copy() + project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + if 'PYTHONPATH' in env: + env['PYTHONPATH'] = f"{project_root}:{env['PYTHONPATH']}" + else: + env['PYTHONPATH'] = project_root + + 
subprocess.check_call(args, env=env) paths = sorted(glob.glob("tutorials/*.ipynb")) diff --git a/topobench/data/preprocessor/preprocessor.py b/topobench/data/preprocessor/preprocessor.py index a50129c5a..e5c4a913a 100644 --- a/topobench/data/preprocessor/preprocessor.py +++ b/topobench/data/preprocessor/preprocessor.py @@ -244,9 +244,9 @@ def load_dataset_splits( raise ValueError("No learning setting specified in split_params") if split_params.learning_setting == "inductive": - return load_inductive_splits(self.dataset, split_params) + return load_inductive_splits(self, split_params) elif split_params.learning_setting == "transductive": - return load_transductive_splits(self.dataset, split_params) + return load_transductive_splits(self, split_params) else: raise ValueError( f"Invalid '{split_params.learning_setting}' learning setting.\ diff --git a/topobench/model/model.py b/topobench/model/model.py index 73437f634..695a3f79a 100755 --- a/topobench/model/model.py +++ b/topobench/model/model.py @@ -244,9 +244,6 @@ def process_outputs(self, model_out: dict, batch: Data) -> dict: if self.task_level in ["node", "cell"]: # Keep only train data points (for node-level or cell-level tasks) - for key, val in model_out.items(): - if key in ["logits", "labels"]: - model_out[key] = val[mask] # Note: Rank-specific masks are applied in readout # The readout stores which indices it kept in valid_indices if "valid_indices" in model_out: diff --git a/topobench/nn/backbones/combinatorial/gccn.py b/topobench/nn/backbones/combinatorial/gccn.py index 90e85e131..6de687761 100644 --- a/topobench/nn/backbones/combinatorial/gccn.py +++ b/topobench/nn/backbones/combinatorial/gccn.py @@ -76,21 +76,31 @@ def get_nbhd_cache(self, params): ): src_rank, dst_rank = route if src_rank != dst_rank and (src_rank, dst_rank) not in nbhd_cache: - n_dst_nodes = getattr(params, f"x_{dst_rank}").shape[0] + # Check if the required attributes exist before accessing them + src_attr = f"x_{src_rank}" + 
dst_attr = f"x_{dst_rank}" + if not hasattr(params, src_attr) or not hasattr(params, dst_attr): + continue # Skip this route if the required features don't exist + + n_dst_nodes = getattr(params, dst_attr).shape[0] if src_rank > dst_rank: + if not hasattr(params, neighborhood): + continue # Skip if boundary matrix doesn't exist boundary = getattr(params, neighborhood).coalesce() nbhd_cache[(src_rank, dst_rank)] = ( interrank_boundary_index( - getattr(params, f"x_{src_rank}"), + getattr(params, src_attr), boundary.indices(), n_dst_nodes, ) ) elif src_rank < dst_rank: + if not hasattr(params, neighborhood): + continue # Skip if coboundary matrix doesn't exist coboundary = getattr(params, neighborhood).coalesce() nbhd_cache[(src_rank, dst_rank)] = ( interrank_boundary_index( - getattr(params, f"x_{src_rank}"), + getattr(params, src_attr), coboundary.indices(), n_dst_nodes, )