From e688fe56bfebae2130402e9ca6a3236d1de99c65 Mon Sep 17 00:00:00 2001 From: Andrea Cavallo Date: Mon, 24 Nov 2025 18:41:29 +0100 Subject: [PATCH 1/2] changes --- configs/dataset/graph/DAC.yaml | 37 + test/pipeline/test_pipeline.py | 5 +- topobench/data/datasets/dac_dataset.py | 150 ++++ .../data/loaders/combinatorial/__init__.py | 98 +++ .../combinatorial/dac_dataset_loader.py | 70 ++ .../data/loaders/graph/dac_dataset_loader.py | 72 ++ tutorials/my_dataset.ipynb | 758 ++++++++++++++++++ 7 files changed, 1187 insertions(+), 3 deletions(-) create mode 100644 configs/dataset/graph/DAC.yaml create mode 100644 topobench/data/datasets/dac_dataset.py create mode 100644 topobench/data/loaders/combinatorial/__init__.py create mode 100644 topobench/data/loaders/combinatorial/dac_dataset_loader.py create mode 100644 topobench/data/loaders/graph/dac_dataset_loader.py create mode 100644 tutorials/my_dataset.ipynb diff --git a/configs/dataset/graph/DAC.yaml b/configs/dataset/graph/DAC.yaml new file mode 100644 index 000000000..8f146629a --- /dev/null +++ b/configs/dataset/graph/DAC.yaml @@ -0,0 +1,37 @@ +# Dataset loader config +loader: + _target_: topobench.data.loaders.DACDatasetLoader + parameters: + data_domain: graph + data_type: 4-325-1 + data_name: 4-325-1 + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_type} + split_num: 2 + +# Dataset parameters +parameters: + num_features: 2 + num_classes: 8 + num_nodes: 3224 + task: classification + loss_type: cross_entropy + monitor_metric: accuracy + task_level: graph + + +#splits +split_params: + learning_setting: inductive + # data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_split_dir: ${paths.data_dir}/${dataset.loader.parameters.data_name}/processed + data_seed: 0 + split_type: fixed #'k-fold' # either "k-fold" or "random" strategies + k: 10 # for "k-fold" Cross-Validation + train_prop: 0.5 # for "random" strategy splitting + standardize: False + +# Dataloader 
parameters +dataloader_params: + batch_size: 4495 # Fixed + num_workers: 0 + pin_memory: False diff --git a/test/pipeline/test_pipeline.py b/test/pipeline/test_pipeline.py index 785987159..324288e94 100644 --- a/test/pipeline/test_pipeline.py +++ b/test/pipeline/test_pipeline.py @@ -1,11 +1,10 @@ """Test pipeline for a particular dataset and model.""" - import hydra from test._utils.simplified_pipeline import run -DATASET = "graph/MUTAG" # ADD YOUR DATASET HERE -MODELS = ["graph/gcn", "cell/topotune", "simplicial/topotune"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE +DATASET = "graph/DAC" # ADD YOUR DATASET HERE +MODELS = ["graph/gcn"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE class TestPipeline: diff --git a/topobench/data/datasets/dac_dataset.py b/topobench/data/datasets/dac_dataset.py new file mode 100644 index 000000000..ceefbcf6a --- /dev/null +++ b/topobench/data/datasets/dac_dataset.py @@ -0,0 +1,150 @@ +"""Dataset class for Dynamic Activity Complex (DAC) dataset.""" + +import os +import os.path as osp +import shutil +from typing import ClassVar + +import torch +from omegaconf import DictConfig +from torch_geometric.data import Data, InMemoryDataset, extract_zip + +from topobench.data.utils import download_file_from_link + + +class DACDataset(InMemoryDataset): + r"""Dataset class for the Dynamic Activity Complexes (DAC) dataset. + + Parameters + ---------- + root : str + Root directory where the dataset will be saved. + name : str + Name of the dataset. + parameters : DictConfig + Configuration parameters for the dataset. + + Attributes + ---------- + URLS (dict): Dictionary containing the URLs for downloading the dataset. 
+ """ + + URLS: ClassVar = { + "4-325-1": "https://zenodo.org/records/17700425/files/4_325_1.zip", + "4-325-3": "https://zenodo.org/records/17700425/files/4_325_3.zip", + "4-325-5": "https://zenodo.org/records/17700425/files/4_325_5.zip", + } + + def __init__( + self, + root: str, + name: str, + parameters: DictConfig, + ): + # Load processed data (created in process()) + self.name = name + super().__init__(root) + self.data, self.slices, self.splits = torch.load( + self.processed_paths[0] + ) + + split_num = parameters.split_num + self.split_idx = self.splits[split_num] + + @property + def raw_file_names(self): + """Return the raw file names for the dataset. + + Returns + ------- + list[str] + List of raw file names. + """ + return ["all_edges.pt", "all_x.pt", "y.pt"] + + @property + def processed_file_names(self): + """Return the processed file name for the dataset. + + Returns + ------- + str + Processed file name. + """ + return ["data.pt"] + + @property + def processed_dir(self) -> str: + """Return the path to the processed directory of the dataset. + + Returns + ------- + str + Path to the processed directory. + """ + self.processed_root = osp.join(self.root) + return osp.join(self.processed_root, "processed") + + def download(self): + r"""Download the dataset from a URL and saves it to the raw directory. + + Raises: + FileNotFoundError: If the dataset URL is not found. 
+ """ + # Step 1: Download data from the source + self.url = self.URLS[self.name] + download_file_from_link( + file_link=self.url, + path_to_save=self.raw_dir, + dataset_name=self.name, + file_format="zip", + ) + + # Step 2: extract zip file + folder = self.raw_dir + filename = f"{self.name}.zip" + path = osp.join(folder, filename) + extract_zip(path, folder) + # Delete zip file + os.unlink(path) + + # Step 3: organize files + # Move files from osp.join(folder, name_download) to folder + folder_name = "4_325_" + self.name.split("-")[2] + for file in os.listdir(osp.join(folder, folder_name)): + shutil.move(osp.join(folder, folder_name, file), folder) + # Delete osp.join(folder, self.name) dir + shutil.rmtree(osp.join(folder, folder_name)) + + def process(self): + r"""Handle the data for the dataset. + + This method loads the DAC raw data, creates one object for + each graph, and saves the processed data + to the appropriate location. + """ + # Load raw tensors + relations = torch.load(os.path.join(self.raw_dir, "all_edges.pt")) + all_x = torch.load(os.path.join(self.raw_dir, "all_x.pt")) + y = torch.load(os.path.join(self.raw_dir, "y.pt")) + + data_list = [] + for i in range(len(all_x)): + # Create PyG Data object + data = Data( + x=all_x[i], + edge_index=relations[i], + y=y[i].unsqueeze(0) if y[i].ndim == 0 else y[i], + ) + + data_list.append(data) + + # Save to processed dir using slicing format + data, slices = self.collate(data_list) + + splits = [] + for s in range(5): + split = torch.load(os.path.join(self.raw_dir, f"split_{s}.pt")) + splits.append(split) + + torch.save((data, slices, splits), self.processed_paths[0]) diff --git a/topobench/data/loaders/combinatorial/__init__.py b/topobench/data/loaders/combinatorial/__init__.py new file mode 100644 index 000000000..672e4c3cb --- /dev/null +++ b/topobench/data/loaders/combinatorial/__init__.py @@ -0,0 +1,98 @@ +"""Init file for combinatorial dataset load module with automated loader discovery.""" + +import 
inspect +from importlib import util +from pathlib import Path +from typing import Any, ClassVar + + +class CombinatorialLoaderManager: + """Manages automatic discovery and registration of combinatorial dataset loader classes.""" + + # Base class that all combinatorial loaders should inherit from (adjust based on your actual base class) + BASE_LOADER_CLASS: ClassVar[type] = object + + @staticmethod + def is_loader_class(obj: Any) -> bool: + """Check if an object is a valid combinatorial dataset loader class. + + Parameters + ---------- + obj : Any + The object to check if it's a valid combinatorial dataset loader class. + + Returns + ------- + bool + True if the object is a valid combinatorial dataset loader class (non-private class + with 'DatasetLoader' in name), False otherwise. + """ + return ( + inspect.isclass(obj) + and not obj.__name__.startswith("_") + and "DatasetLoader" in obj.__name__ + ) + + @classmethod + def discover_loaders(cls, package_path: str) -> dict[str, type[Any]]: + """Dynamically discover all combinatorial dataset loader classes in the package. + + Parameters + ---------- + package_path : str + Path to the package's __init__.py file. + + Returns + ------- + Dict[str, Type[Any]] + Dictionary mapping loader class names to their corresponding class objects. 
+ """ + loaders = {} + + # Get the directory containing the loader modules + package_dir = Path(package_path).parent + + # Iterate through all .py files in the directory + for file_path in package_dir.glob("*.py"): + if file_path.stem == "__init__": + continue + + # Import the module + module_name = f"{Path(package_path).stem}.{file_path.stem}" + spec = util.spec_from_file_location(module_name, file_path) + if spec and spec.loader: + module = util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Find all combinatorial dataset loader classes in the module + new_loaders = { + name: obj + for name, obj in inspect.getmembers(module) + if ( + cls.is_loader_class(obj) + and obj.__module__ == module.__name__ + ) + } + loaders.update(new_loaders) + return loaders + + +# Create the loader manager +manager = CombinatorialLoaderManager() + +# Automatically discover and populate loaders +COMBINATORIAL_LOADERS = manager.discover_loaders(__file__) + +COMBINATORIAL_LOADERS_list = list(COMBINATORIAL_LOADERS.keys()) + +# Automatically generate __all__ +__all__ = [ + # Loader collections + "COMBINATORIAL_LOADERS", + "COMBINATORIAL_LOADERS_list", + # Individual loader classes + *COMBINATORIAL_LOADERS.keys(), +] + +# For backwards compatibility, create individual imports +locals().update(**COMBINATORIAL_LOADERS) diff --git a/topobench/data/loaders/combinatorial/dac_dataset_loader.py b/topobench/data/loaders/combinatorial/dac_dataset_loader.py new file mode 100644 index 000000000..37ed2a9aa --- /dev/null +++ b/topobench/data/loaders/combinatorial/dac_dataset_loader.py @@ -0,0 +1,70 @@ +"""Loaders for Mantra dataset as simplicial.""" + +from omegaconf import DictConfig + +from topobench.data.datasets import DACDataset +from topobench.data.loaders.base import AbstractLoader + + +class DACCombinatorialDatasetLoader(AbstractLoader): + """Load Mantra dataset with configurable parameters. 
+
+    Note: for the combinatorial datasets it is necessary to include DatasetLoader into the name of the class!
+
+    Parameters
+    ----------
+    parameters : DictConfig
+        Configuration parameters containing:
+            - data_dir: Root directory for data
+            - data_name: Name of the dataset
+            - other relevant parameters
+
+    **kwargs : dict
+        Additional keyword arguments.
+    """
+
+    def __init__(self, parameters: DictConfig, **kwargs) -> None:
+        super().__init__(parameters, **kwargs)
+
+    def load_dataset(self, **kwargs) -> DACDataset:
+        """Load the DAC Combinatorial dataset.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Additional keyword arguments for dataset initialization.
+
+        Returns
+        -------
+        DACDataset
+            The loaded DAC Combinatorial dataset with the appropriate `data_dir`.
+
+        Raises
+        ------
+        RuntimeError
+            If dataset loading fails.
+        """
+
+        dataset = self._initialize_dataset(**kwargs)
+        self.data_dir = self.get_data_dir()
+        return dataset
+
+    def _initialize_dataset(self, **kwargs) -> DACDataset:
+        """Initialize the DAC dataset.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Additional keyword arguments for dataset initialization.
+
+        Returns
+        -------
+        DACDataset
+            The initialized dataset instance.
+        """
+        return DACDataset(
+            root=str(self.root_data_dir),
+            name=self.parameters.data_name,
+            parameters=self.parameters,
+            **kwargs,
+        )
diff --git a/topobench/data/loaders/graph/dac_dataset_loader.py b/topobench/data/loaders/graph/dac_dataset_loader.py
new file mode 100644
index 000000000..0c65f42b7
--- /dev/null
+++ b/topobench/data/loaders/graph/dac_dataset_loader.py
@@ -0,0 +1,72 @@
+"""Loaders for DAC dataset as graph."""
+
+from pathlib import Path
+
+from omegaconf import DictConfig
+
+from topobench.data.datasets import DACDataset
+from topobench.data.loaders.base import AbstractLoader
+
+
+class DACDatasetLoader(AbstractLoader):
+    """Load DAC dataset with configurable parameters.
+
+    Parameters
+    ----------
+    parameters : DictConfig
+        Configuration parameters containing:
+            - data_dir: Root directory for data
+            - data_name: Name of the dataset
+            - split_num: Index of the fixed data split to use
+            - other relevant parameters
+    """
+
+    def __init__(self, parameters: DictConfig) -> None:
+        super().__init__(parameters)
+
+    def load_dataset(self) -> DACDataset:
+        """Load the DAC dataset.
+
+        Returns
+        -------
+        DACDataset
+            The loaded DAC dataset with the appropriate `data_dir`.
+
+        Raises
+        ------
+        RuntimeError
+            If dataset loading fails.
+        """
+
+        dataset = self._initialize_dataset()
+        self.data_dir = self._redefine_data_dir(dataset)
+        return dataset
+
+    def _initialize_dataset(self) -> DACDataset:
+        """Initialize the DAC dataset.
+
+        Returns
+        -------
+        DACDataset
+            The initialized dataset instance.
+        """
+        return DACDataset(
+            root=str(self.root_data_dir),
+            name=self.parameters.data_name,
+            parameters=self.parameters,
+        )
+
+    def _redefine_data_dir(self, dataset: DACDataset) -> Path:
+        """Redefine the data directory based on the dataset's processed root.
+
+        Parameters
+        ----------
+        dataset : DACDataset
+            The dataset instance.
+
+        Returns
+        -------
+        Path
+            The redefined data directory path.
+        """
+        return dataset.processed_root
diff --git a/tutorials/my_dataset.ipynb b/tutorials/my_dataset.ipynb
new file mode 100644
index 000000000..0ee78de47
--- /dev/null
+++ b/tutorials/my_dataset.ipynb
@@ -0,0 +1,758 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Using a new dataset"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In this tutorial we show how you can use a dataset not present in the library.\n",
+    "\n",
+    "This particular example uses the ENZYMES dataset, uses a simplicial lifting to create simplicial complexes, and trains the SCN2 model.
We train the model using the appropriate training and validation datasets, and finally test it on the test dataset." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Table of contents\n", + " [1. Imports](##sec1)\n", + "\n", + " [2. Configurations and utilities](##sec2)\n", + "\n", + " [3. Loading the data](##sec3)\n", + "\n", + " [4. Model initialization](##sec4)\n", + "\n", + " [5. Training](##sec5)\n", + "\n", + " [6. Testing the model](##sec6)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\acavallo\\AppData\\Local\\Temp\\ipykernel_8152\\2256932753.py:21: UserWarning: \n", + "The version_base parameter is not specified.\n", + "Please specify a compatability version level, or None.\n", + "Will assume defaults for version 1.1\n", + " with initialize(config_path=\"../configs\", job_name=\"job\"):\n", + "c:\\Users\\acavallo\\AppData\\Local\\miniconda3\\envs\\tb\\Lib\\site-packages\\outdated\\__init__.py:36: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. 
Refrain from using this package or pin to Setuptools<81.\n", + " from pkg_resources import parse_version\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using split number 2\n" + ] + } + ], + "source": [ + "import hydra\n", + "from hydra import compose, initialize\n", + "from hydra.utils import instantiate\n", + "\n", + "\n", + "\n", + "from topobench.utils.config_resolvers import (\n", + " get_default_metrics,\n", + " get_default_trainer,\n", + " get_default_transform,\n", + " get_flattened_channels,\n", + " get_monitor_metric,\n", + " get_monitor_mode,\n", + " get_non_relational_out_channels,\n", + " get_required_lifting,\n", + " infer_in_channels,\n", + " infer_num_cell_dimensions,\n", + ")\n", + "\n", + "\n", + "with initialize(config_path=\"../configs\", job_name=\"job\"):\n", + "\n", + " cfg = compose(\n", + " config_name=\"run.yaml\",\n", + " overrides=[\n", + " \"model=hypergraph/unignn2\",\n", + " \"dataset=graph/DAC\",\n", + " ], \n", + " return_hydra_config=True\n", + " )\n", + "loader = instantiate(cfg.dataset.loader)\n", + "\n", + "\n", + "dataset, dataset_dir = loader.load()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"Test pipeline for a particular dataset and model.\"\"\"\n", + "import sys \n", + "sys.path.append(\"../\")\n", + "sys.path.append(\"../../\")\n", + "import hydra\n", + "from test._utils.simplified_pipeline import run\n", + "\n", + "\n", + "DATASET = \"graph/DAC\" # ADD YOUR DATASET HERE\n", + "MODELS = [\"graph/gcn\"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE\n", + "\n", + "\n", + "class TestPipeline:\n", + " \"\"\"Test pipeline for a particular dataset and model.\"\"\"\n", + "\n", + " def setup_method(self):\n", + " \"\"\"Setup method.\"\"\"\n", + " hydra.core.global_hydra.GlobalHydra.instance().clear()\n", + " \n", + " def test_pipeline(self):\n", + " \"\"\"Test pipeline.\"\"\"\n", + " with 
hydra.initialize(config_path=\"../configs\", job_name=\"job\"):\n", + " for MODEL in MODELS:\n", + " cfg = hydra.compose(\n", + " config_name=\"run.yaml\",\n", + " overrides=[\n", + " f\"model={MODEL}\",\n", + " f\"dataset={DATASET}\", # IF YOU IMPLEMENT A LARGE DATASET WITH AN OPTION TO USE A SLICE OF IT, ADD BELOW THE CORRESPONDING OPTION\n", + " \"trainer.max_epochs=2\",\n", + " \"trainer.min_epochs=1\",\n", + " \"trainer.check_val_every_n_epoch=1\",\n", + " \"paths=test\",\n", + " \"callbacks=model_checkpoint\",\n", + " ],\n", + " return_hydra_config=True\n", + " )\n", + " run(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\acavallo\\AppData\\Local\\Temp\\ipykernel_8152\\661924356.py:22: UserWarning: \n", + "The version_base parameter is not specified.\n", + "Please specify a compatability version level, or None.\n", + "Will assume defaults for version 1.1\n", + " with hydra.initialize(config_path=\"../configs\", job_name=\"job\"):\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using split number 2\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: False, used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "c:\\Users\\acavallo\\AppData\\Local\\miniconda3\\envs\\tb\\Lib\\site-packages\\lightning\\pytorch\\utilities\\parsing.py:44: Attribute 'backbone_wrapper' removed from hparams because it cannot be pickled. 
You can suppress this warning by setting `self.save_hyperparameters(ignore=['backbone_wrapper'])`.\n", + "c:\\Users\\acavallo\\AppData\\Local\\miniconda3\\envs\\tb\\Lib\\site-packages\\lightning\\pytorch\\callbacks\\model_checkpoint.py:654: Checkpoint directory C:\\Users\\acavallo\\Code\\TopoBench\\outputs\\checkpoints exists and is not empty.\n", + "\n", + " | Name | Type | Params | Mode \n", + "------------------------------------------------------------------\n", + "0 | feature_encoder | AllCellFeatureEncoder | 198 | train\n", + "1 | backbone | GNNWrapper | 8.4 K | train\n", + "2 | readout | NoReadOut | 520 | train\n", + "3 | val_acc_best | MeanMetric | 0 | train\n", + "------------------------------------------------------------------\n", + "9.2 K Trainable params\n", + "0 Non-trainable params\n", + "9.2 K Total params\n", + "0.037 Total estimated model params size (MB)\n", + "25 Modules in train mode\n", + "0 Modules in eval mode\n", + "c:\\Users\\acavallo\\AppData\\Local\\miniconda3\\envs\\tb\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n", + "c:\\Users\\acavallo\\AppData\\Local\\miniconda3\\envs\\tb\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "63413d186bc7431eaf04154c3edc354d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: | | 0/? 
[00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test/accuracy 0.11457174271345139 │\n", + "│ test/auroc 0.482075035572052 │\n", + "│ test/loss 333.13909912109375 │\n", + "│ test/precision 0.015568318776786327 │\n", + "│ test/recall 0.1149553582072258 │\n", + "└───────────────────────────┴───────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test/accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.11457174271345139 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/auroc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.482075035572052 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 333.13909912109375 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/precision \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.015568318776786327 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/recall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.1149553582072258 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "test = TestPipeline()\n", + "test.setup_method()\n", + "test.test_pipeline()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 
Imports " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\acavallo\\AppData\\Local\\miniconda3\\envs\\tb\\Lib\\site-packages\\outdated\\__init__.py:36: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", + " from pkg_resources import parse_version\n" + ] + } + ], + "source": [ + "import lightning as pl\n", + "# Hydra related imports\n", + "from omegaconf import OmegaConf\n", + "# Dataset related imports\n", + "from torch_geometric.datasets import TUDataset\n", + "# from topobench.dataloader.dataloader import TBDataloader\n", + "from topobench.data.loaders.graph.dac_dataset_loader import DACDatasetLoader\n", + "from topobench.data.preprocessor import PreProcessor\n", + "# Model related imports\n", + "from topobench.model.model import TBModel\n", + "from topomodelx.nn.simplicial.scn2 import SCN2\n", + "from topobench.nn.wrappers.simplicial import SCNWrapper\n", + "from topobench.nn.encoders import AllCellFeatureEncoder\n", + "from topobench.nn.readouts import PropagateSignalDown\n", + "# Optimization related imports\n", + "from topobench.loss.loss import TBLoss\n", + "from topobench.optimizer import TBOptimizer\n", + "from topobench.evaluator.evaluator import TBEvaluator" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "DACDataset.__init__() missing 1 required positional argument: 'parameters'", + "output_type": "error", + "traceback": [ + 
"\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtopobench\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mdata\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mdatasets\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m DACDataset\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m data = \u001b[43mDACDataset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mroot\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43m./data/4_125/\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43m4_125\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n", + "\u001b[31mTypeError\u001b[39m: DACDataset.__init__() missing 1 required positional argument: 'parameters'" + ] + } + ], + "source": [ + "from topobench.data.datasets import DACDataset\n", + "\n", + "data = DACDataset(root='./data/4_125/', name='4_125')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Configurations and utilities " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Configurations can be specified using yaml files or directly specified in your code like in this example." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "transform_config = { \"clique_lifting\":\n", + " {\"transform_type\": \"lifting\",\n", + " \"transform_name\": \"SimplicialCliqueLifting\",\n", + " \"complex_dim\": 3,}\n", + "}\n", + "\n", + "split_config = {\n", + " \"learning_setting\": \"inductive\",\n", + " \"split_type\": \"random\",\n", + " \"data_seed\": 0,\n", + " \"data_split_dir\": \"./data/ENZYMES/splits/\",\n", + " \"train_prop\": 0.5,\n", + "}\n", + "\n", + "in_channels = 3\n", + "out_channels = 6\n", + "dim_hidden = 16\n", + "\n", + "wrapper_config = {\n", + " \"out_channels\": dim_hidden,\n", + " \"num_cell_dimensions\": 3,\n", + "}\n", + "\n", + "readout_config = {\n", + " \"readout_name\": \"PropagateSignalDown\",\n", + " \"num_cell_dimensions\": 1,\n", + " \"hidden_dim\": dim_hidden,\n", + " \"out_channels\": out_channels,\n", + " \"task_level\": \"graph\",\n", + " \"pooling_type\": \"sum\",\n", + "}\n", + "\n", + "loss_config = {\n", + " \"dataset_loss\": \n", + " {\n", + " \"task\": \"classification\", \n", + " \"loss_type\": \"cross_entropy\"\n", + " }\n", + "}\n", + "\n", + "evaluator_config = {\"task\": \"classification\",\n", + " \"num_classes\": out_channels,\n", + " \"metrics\": [\"accuracy\", \"precision\", \"recall\"]}\n", + "\n", + "optimizer_config = {\"optimizer_id\": \"Adam\",\n", + " \"parameters\":\n", + " {\"lr\": 0.001,\"weight_decay\": 0.0005}\n", + " }\n", + "\n", + "transform_config = OmegaConf.create(transform_config)\n", + "split_config = OmegaConf.create(split_config)\n", + "readout_config = OmegaConf.create(readout_config)\n", + "loss_config = OmegaConf.create(loss_config)\n", + "evaluator_config = OmegaConf.create(evaluator_config)\n", + "optimizer_config = OmegaConf.create(optimizer_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def wrapper(**factory_kwargs):\n", + " def 
factory(backbone):\n", + " return SCNWrapper(backbone, **factory_kwargs)\n", + " return factory" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Loading the data " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example we use the ENZYMES dataset. It is a graph dataset and we use the clique lifting to transform the graphs into simplicial complexes. We invite you to check out the README of the [repository](https://github.com/pyt-team/TopoBenchX) to learn more about the various liftings offered." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transform parameters are the same, using existing data_dir: ./data/ENZYMES/clique_lifting/3206123057\n" + ] + } + ], + "source": [ + "dataset_dir = \"./data/ENZYMES/\"\n", + "dataset = TUDataset(root=dataset_dir, name=\"ENZYMES\")\n", + "\n", + "preprocessor = PreProcessor(dataset, dataset_dir, transform_config)\n", + "dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)\n", + "datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Model initialization " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can create the backbone by instantiating the SCN2 model from TopoModelX. Then the `SCNWrapper` and the `TBModel` take care of the rest." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "backbone = SCN2(in_channels_0=dim_hidden, in_channels_1=dim_hidden, in_channels_2=dim_hidden)\n", + "wrapper = wrapper(**wrapper_config)\n", + "\n", + "readout = PropagateSignalDown(**readout_config)\n", + "loss = TBLoss(**loss_config)\n", + "feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels, in_channels, in_channels], out_channels=dim_hidden)\n", + "\n", + "evaluator = TBEvaluator(**evaluator_config)\n", + "optimizer = TBOptimizer(**optimizer_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "model = TBModel(backbone=backbone,\n", + " backbone_wrapper=wrapper,\n", + " readout=readout,\n", + " loss=loss,\n", + " feature_encoder=feature_encoder,\n", + " evaluator=evaluator,\n", + " optimizer=optimizer,\n", + " compile=False,)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Training " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can use the `lightning` trainer to train the model." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n", + "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. 
For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", + "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:44: Attribute 'backbone_wrapper' removed from hparams because it cannot be pickled. You can suppress this warning by setting `self.save_hyperparameters(ignore=['backbone_wrapper'])`.\n", + "\n", + " | Name | Type | Params | Mode \n", + "------------------------------------------------------------------\n", + "0 | feature_encoder | AllCellFeatureEncoder | 1.2 K | train\n", + "1 | backbone | SCNWrapper | 1.6 K | train\n", + "2 | readout | PropagateSignalDown | 102 | train\n", + "3 | val_acc_best | MeanMetric | 0 | train\n", + "------------------------------------------------------------------\n", + "2.9 K Trainable params\n", + "0 Non-trainable params\n", + "2.9 K Total params\n", + "0.012 Total estimated model params size (MB)\n", + "36 Modules in train mode\n", + "0 Modules in eval mode\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n", + "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassAccuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n", + " warnings.warn(*args, **kwargs)\n", + "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassPrecision was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n", + " warnings.warn(*args, **kwargs)\n", + "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassRecall was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n", + " warnings.warn(*args, **kwargs)\n", + "/home/levtel/projects/dev/TopoBench/topobench/nn/wrappers/simplicial/scn_wrapper.py:75: UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. (Triggered internally at ../aten/src/ATen/SparseCsrTensorImpl.cpp:53.)\n", + " normalized_matrix = diag_matrix @ (matrix @ diag_matrix)\n", + "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n", + "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", + "`Trainer.fit` stopped: `max_epochs=5` reached.\n" + ] + } + ], + "source": [ + "#%%capture\n", + "# Increase the number of epochs to get better results\n", + "trainer = pl.Trainer(max_epochs=5, accelerator=\"cpu\", enable_progress_bar=False)\n", + "\n", + "trainer.fit(model, datamodule)\n", + "train_metrics = trainer.callback_metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Training metrics\n", + " --------------------------\n", + "train/accuracy: 0.1867\n", + "train/precision: 0.1796\n", + "train/recall: 0.1835\n", + "val/loss: 3.2280\n", + "val/accuracy: 0.1600\n", + "val/precision: 0.1735\n", + "val/recall: 0.1554\n", + "train/loss: 3.2763\n" + ] + } + ], + "source": [ + "print(' Training metrics\\n', '-'*26)\n", + "for key in train_metrics:\n", + " print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Testing the model " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we can test the model and obtain the results." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n" + ] + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃        Test metric               DataLoader 0        ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│       test/accuracy           0.18000000715255737    │\n",
+       "│         test/loss              3.625048875808716     │\n",
+       "│      test/precision           0.1994038224220276     │\n",
+       "│        test/recall            0.1821674406528473     │\n",
+       "└───────────────────────────┴───────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test/accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.18000000715255737 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 3.625048875808716 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/precision \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.1994038224220276 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/recall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.1821674406528473 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.test(model, datamodule)\n", + "test_metrics = trainer.callback_metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Testing metrics\n", + " -------------------------\n", + "test/loss: 3.6250\n", + "test/accuracy: 0.1800\n", + "test/precision: 0.1994\n", + "test/recall: 0.1822\n" + ] + } + ], + "source": [ + "print(' Testing metrics\\n', '-'*25)\n", + "for key in test_metrics:\n", + " print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tb", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": 
"python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From f2fab1ed74a60b9b71cbd89ba32c1dcca4163533 Mon Sep 17 00:00:00 2001 From: Andrea Cavallo Date: Mon, 24 Nov 2025 19:24:09 +0100 Subject: [PATCH 2/2] change --- test/pipeline/test_pipeline.py | 1 + topobench/data/datasets/dac_dataset.py | 11 +- tutorials/my_dataset.ipynb | 758 ------------------------- 3 files changed, 11 insertions(+), 759 deletions(-) delete mode 100644 tutorials/my_dataset.ipynb diff --git a/test/pipeline/test_pipeline.py b/test/pipeline/test_pipeline.py index 324288e94..f40de75b9 100644 --- a/test/pipeline/test_pipeline.py +++ b/test/pipeline/test_pipeline.py @@ -3,6 +3,7 @@ from test._utils.simplified_pipeline import run + DATASET = "graph/DAC" # ADD YOUR DATASET HERE MODELS = ["graph/gcn"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE diff --git a/topobench/data/datasets/dac_dataset.py b/topobench/data/datasets/dac_dataset.py index ceefbcf6a..699642aa7 100644 --- a/topobench/data/datasets/dac_dataset.py +++ b/topobench/data/datasets/dac_dataset.py @@ -60,7 +60,16 @@ def raw_file_names(self): list[str] List of raw file names. 
""" - return ["all_edges.pt", "all_x.pt", "y.pt"] + return [ + "all_edges.pt", + "all_x.pt", + "y.pt", + "split_0.pt", + "split_1.pt", + "split_2.pt", + "split_3.pt", + "split_4.pt", + ] @property def processed_file_names(self): diff --git a/tutorials/my_dataset.ipynb b/tutorials/my_dataset.ipynb deleted file mode 100644 index 0ee78de47..000000000 --- a/tutorials/my_dataset.ipynb +++ /dev/null @@ -1,758 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Using a new dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this tutorial we show how you can use a dataset not present in the library.\n", - "\n", - "This particular example uses the ENZIMES dataset, uses a simplicial lifting to create simplicial complexes, and trains the SCN2 model. We train the model using the appropriate training and validation datasets, and finally test it on the test dataset." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Table of contents\n", - " [1. Imports](##sec1)\n", - "\n", - " [2. Configurations and utilities](##sec2)\n", - "\n", - " [3. Loading the data](##sec3)\n", - "\n", - " [4. Model initialization](##sec4)\n", - "\n", - " [5. Training](##sec5)\n", - "\n", - " [6. Testing the model](##sec6)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\Users\\acavallo\\AppData\\Local\\Temp\\ipykernel_8152\\2256932753.py:21: UserWarning: \n", - "The version_base parameter is not specified.\n", - "Please specify a compatability version level, or None.\n", - "Will assume defaults for version 1.1\n", - " with initialize(config_path=\"../configs\", job_name=\"job\"):\n", - "c:\\Users\\acavallo\\AppData\\Local\\miniconda3\\envs\\tb\\Lib\\site-packages\\outdated\\__init__.py:36: UserWarning: pkg_resources is deprecated as an API. 
See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", - " from pkg_resources import parse_version\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using split number 2\n" - ] - } - ], - "source": [ - "import hydra\n", - "from hydra import compose, initialize\n", - "from hydra.utils import instantiate\n", - "\n", - "\n", - "\n", - "from topobench.utils.config_resolvers import (\n", - " get_default_metrics,\n", - " get_default_trainer,\n", - " get_default_transform,\n", - " get_flattened_channels,\n", - " get_monitor_metric,\n", - " get_monitor_mode,\n", - " get_non_relational_out_channels,\n", - " get_required_lifting,\n", - " infer_in_channels,\n", - " infer_num_cell_dimensions,\n", - ")\n", - "\n", - "\n", - "with initialize(config_path=\"../configs\", job_name=\"job\"):\n", - "\n", - " cfg = compose(\n", - " config_name=\"run.yaml\",\n", - " overrides=[\n", - " \"model=hypergraph/unignn2\",\n", - " \"dataset=graph/DAC\",\n", - " ], \n", - " return_hydra_config=True\n", - " )\n", - "loader = instantiate(cfg.dataset.loader)\n", - "\n", - "\n", - "dataset, dataset_dir = loader.load()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "\"\"\"Test pipeline for a particular dataset and model.\"\"\"\n", - "import sys \n", - "sys.path.append(\"../\")\n", - "sys.path.append(\"../../\")\n", - "import hydra\n", - "from test._utils.simplified_pipeline import run\n", - "\n", - "\n", - "DATASET = \"graph/DAC\" # ADD YOUR DATASET HERE\n", - "MODELS = [\"graph/gcn\"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE\n", - "\n", - "\n", - "class TestPipeline:\n", - " \"\"\"Test pipeline for a particular dataset and model.\"\"\"\n", - "\n", - " def setup_method(self):\n", - " \"\"\"Setup method.\"\"\"\n", - " 
hydra.core.global_hydra.GlobalHydra.instance().clear()\n", - " \n", - " def test_pipeline(self):\n", - " \"\"\"Test pipeline.\"\"\"\n", - " with hydra.initialize(config_path=\"../configs\", job_name=\"job\"):\n", - " for MODEL in MODELS:\n", - " cfg = hydra.compose(\n", - " config_name=\"run.yaml\",\n", - " overrides=[\n", - " f\"model={MODEL}\",\n", - " f\"dataset={DATASET}\", # IF YOU IMPLEMENT A LARGE DATASET WITH AN OPTION TO USE A SLICE OF IT, ADD BELOW THE CORRESPONDING OPTION\n", - " \"trainer.max_epochs=2\",\n", - " \"trainer.min_epochs=1\",\n", - " \"trainer.check_val_every_n_epoch=1\",\n", - " \"paths=test\",\n", - " \"callbacks=model_checkpoint\",\n", - " ],\n", - " return_hydra_config=True\n", - " )\n", - " run(cfg)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "C:\\Users\\acavallo\\AppData\\Local\\Temp\\ipykernel_8152\\661924356.py:22: UserWarning: \n", - "The version_base parameter is not specified.\n", - "Please specify a compatability version level, or None.\n", - "Will assume defaults for version 1.1\n", - " with hydra.initialize(config_path=\"../configs\", job_name=\"job\"):\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using split number 2\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "GPU available: False, used: False\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", - "c:\\Users\\acavallo\\AppData\\Local\\miniconda3\\envs\\tb\\Lib\\site-packages\\lightning\\pytorch\\utilities\\parsing.py:44: Attribute 'backbone_wrapper' removed from hparams because it cannot be pickled. 
You can suppress this warning by setting `self.save_hyperparameters(ignore=['backbone_wrapper'])`.\n", - "c:\\Users\\acavallo\\AppData\\Local\\miniconda3\\envs\\tb\\Lib\\site-packages\\lightning\\pytorch\\callbacks\\model_checkpoint.py:654: Checkpoint directory C:\\Users\\acavallo\\Code\\TopoBench\\outputs\\checkpoints exists and is not empty.\n", - "\n", - " | Name | Type | Params | Mode \n", - "------------------------------------------------------------------\n", - "0 | feature_encoder | AllCellFeatureEncoder | 198 | train\n", - "1 | backbone | GNNWrapper | 8.4 K | train\n", - "2 | readout | NoReadOut | 520 | train\n", - "3 | val_acc_best | MeanMetric | 0 | train\n", - "------------------------------------------------------------------\n", - "9.2 K Trainable params\n", - "0 Non-trainable params\n", - "9.2 K Total params\n", - "0.037 Total estimated model params size (MB)\n", - "25 Modules in train mode\n", - "0 Modules in eval mode\n", - "c:\\Users\\acavallo\\AppData\\Local\\miniconda3\\envs\\tb\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n", - "c:\\Users\\acavallo\\AppData\\Local\\miniconda3\\envs\\tb\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "63413d186bc7431eaf04154c3edc354d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: | | 0/? 
[00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃ Test metric DataLoader 0 ┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ test/accuracy 0.11457174271345139 │\n", - "│ test/auroc 0.482075035572052 │\n", - "│ test/loss 333.13909912109375 │\n", - "│ test/precision 0.015568318776786327 │\n", - "│ test/recall 0.1149553582072258 │\n", - "└───────────────────────────┴───────────────────────────┘\n", - "\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│\u001b[36m \u001b[0m\u001b[36m test/accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.11457174271345139 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/auroc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.482075035572052 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 333.13909912109375 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/precision \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.015568318776786327 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/recall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.1149553582072258 \u001b[0m\u001b[35m \u001b[0m│\n", - "└───────────────────────────┴───────────────────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "test = TestPipeline()\n", - "test.setup_method()\n", - "test.test_pipeline()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. 
Imports " - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\Users\\acavallo\\AppData\\Local\\miniconda3\\envs\\tb\\Lib\\site-packages\\outdated\\__init__.py:36: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", - " from pkg_resources import parse_version\n" - ] - } - ], - "source": [ - "import lightning as pl\n", - "# Hydra related imports\n", - "from omegaconf import OmegaConf\n", - "# Dataset related imports\n", - "from torch_geometric.datasets import TUDataset\n", - "# from topobench.dataloader.dataloader import TBDataloader\n", - "from topobench.data.loaders.graph.dac_dataset_loader import DACDatasetLoader\n", - "from topobench.data.preprocessor import PreProcessor\n", - "# Model related imports\n", - "from topobench.model.model import TBModel\n", - "from topomodelx.nn.simplicial.scn2 import SCN2\n", - "from topobench.nn.wrappers.simplicial import SCNWrapper\n", - "from topobench.nn.encoders import AllCellFeatureEncoder\n", - "from topobench.nn.readouts import PropagateSignalDown\n", - "# Optimization related imports\n", - "from topobench.loss.loss import TBLoss\n", - "from topobench.optimizer import TBOptimizer\n", - "from topobench.evaluator.evaluator import TBEvaluator" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "DACDataset.__init__() missing 1 required positional argument: 'parameters'", - "output_type": "error", - "traceback": [ - 
"\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtopobench\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mdata\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mdatasets\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m DACDataset\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m data = \u001b[43mDACDataset\u001b[49m\u001b[43m(\u001b[49m\u001b[43mroot\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43m./data/4_125/\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43m4_125\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n", - "\u001b[31mTypeError\u001b[39m: DACDataset.__init__() missing 1 required positional argument: 'parameters'" - ] - } - ], - "source": [ - "from topobench.data.datasets import DACDataset\n", - "\n", - "data = DACDataset(root='./data/4_125/', name='4_125')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Configurations and utilities " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Configurations can be specified using yaml files or directly specified in your code like in this example." 
- ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "transform_config = { \"clique_lifting\":\n", - " {\"transform_type\": \"lifting\",\n", - " \"transform_name\": \"SimplicialCliqueLifting\",\n", - " \"complex_dim\": 3,}\n", - "}\n", - "\n", - "split_config = {\n", - " \"learning_setting\": \"inductive\",\n", - " \"split_type\": \"random\",\n", - " \"data_seed\": 0,\n", - " \"data_split_dir\": \"./data/ENZYMES/splits/\",\n", - " \"train_prop\": 0.5,\n", - "}\n", - "\n", - "in_channels = 3\n", - "out_channels = 6\n", - "dim_hidden = 16\n", - "\n", - "wrapper_config = {\n", - " \"out_channels\": dim_hidden,\n", - " \"num_cell_dimensions\": 3,\n", - "}\n", - "\n", - "readout_config = {\n", - " \"readout_name\": \"PropagateSignalDown\",\n", - " \"num_cell_dimensions\": 1,\n", - " \"hidden_dim\": dim_hidden,\n", - " \"out_channels\": out_channels,\n", - " \"task_level\": \"graph\",\n", - " \"pooling_type\": \"sum\",\n", - "}\n", - "\n", - "loss_config = {\n", - " \"dataset_loss\": \n", - " {\n", - " \"task\": \"classification\", \n", - " \"loss_type\": \"cross_entropy\"\n", - " }\n", - "}\n", - "\n", - "evaluator_config = {\"task\": \"classification\",\n", - " \"num_classes\": out_channels,\n", - " \"metrics\": [\"accuracy\", \"precision\", \"recall\"]}\n", - "\n", - "optimizer_config = {\"optimizer_id\": \"Adam\",\n", - " \"parameters\":\n", - " {\"lr\": 0.001,\"weight_decay\": 0.0005}\n", - " }\n", - "\n", - "transform_config = OmegaConf.create(transform_config)\n", - "split_config = OmegaConf.create(split_config)\n", - "readout_config = OmegaConf.create(readout_config)\n", - "loss_config = OmegaConf.create(loss_config)\n", - "evaluator_config = OmegaConf.create(evaluator_config)\n", - "optimizer_config = OmegaConf.create(optimizer_config)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "def wrapper(**factory_kwargs):\n", - " def 
factory(backbone):\n", - " return SCNWrapper(backbone, **factory_kwargs)\n", - " return factory" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 3. Loading the data " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this example we use the ENZYMES dataset. It is a graph dataset and we use the clique lifting to transform the graphs into simplicial complexes. We invite you to check out the README of the [repository](https://github.com/pyt-team/TopoBenchX) to learn more about the various liftings offered." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Transform parameters are the same, using existing data_dir: ./data/ENZYMES/clique_lifting/3206123057\n" - ] - } - ], - "source": [ - "dataset_dir = \"./data/ENZYMES/\"\n", - "dataset = TUDataset(root=dataset_dir, name=\"ENZYMES\")\n", - "\n", - "preprocessor = PreProcessor(dataset, dataset_dir, transform_config)\n", - "dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)\n", - "datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 4. Model initialization " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can create the backbone by instantiating the SCN2 model from TopoModelX. Then the `SCNWrapper` and the `TBModel` take care of the rest." 
- ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "backbone = SCN2(in_channels_0=dim_hidden, in_channels_1=dim_hidden, in_channels_2=dim_hidden)\n", - "wrapper = wrapper(**wrapper_config)\n", - "\n", - "readout = PropagateSignalDown(**readout_config)\n", - "loss = TBLoss(**loss_config)\n", - "feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels, in_channels, in_channels], out_channels=dim_hidden)\n", - "\n", - "evaluator = TBEvaluator(**evaluator_config)\n", - "optimizer = TBOptimizer(**optimizer_config)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "model = TBModel(backbone=backbone,\n", - " backbone_wrapper=wrapper,\n", - " readout=readout,\n", - " loss=loss,\n", - " feature_encoder=feature_encoder,\n", - " evaluator=evaluator,\n", - " optimizer=optimizer,\n", - " compile=False,)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5. Training " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now we can use the `lightning` trainer to train the model." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "GPU available: True (cuda), used: False\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", - "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n", - "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. 
For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", - "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:44: Attribute 'backbone_wrapper' removed from hparams because it cannot be pickled. You can suppress this warning by setting `self.save_hyperparameters(ignore=['backbone_wrapper'])`.\n", - "\n", - " | Name | Type | Params | Mode \n", - "------------------------------------------------------------------\n", - "0 | feature_encoder | AllCellFeatureEncoder | 1.2 K | train\n", - "1 | backbone | SCNWrapper | 1.6 K | train\n", - "2 | readout | PropagateSignalDown | 102 | train\n", - "3 | val_acc_best | MeanMetric | 0 | train\n", - "------------------------------------------------------------------\n", - "2.9 K Trainable params\n", - "0 Non-trainable params\n", - "2.9 K Total params\n", - "0.012 Total estimated model params size (MB)\n", - "36 Modules in train mode\n", - "0 Modules in eval mode\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n", - "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassAccuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n", - " warnings.warn(*args, **kwargs)\n", - "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassPrecision was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n", - " warnings.warn(*args, **kwargs)\n", - "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassRecall was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n", - " warnings.warn(*args, **kwargs)\n", - "/home/levtel/projects/dev/TopoBench/topobench/nn/wrappers/simplicial/scn_wrapper.py:75: UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. (Triggered internally at ../aten/src/ATen/SparseCsrTensorImpl.cpp:53.)\n", - " normalized_matrix = diag_matrix @ (matrix @ diag_matrix)\n", - "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n", - "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", - "`Trainer.fit` stopped: `max_epochs=5` reached.\n" - ] - } - ], - "source": [ - "#%%capture\n", - "# Increase the number of epochs to get better results\n", - "trainer = pl.Trainer(max_epochs=5, accelerator=\"cpu\", enable_progress_bar=False)\n", - "\n", - "trainer.fit(model, datamodule)\n", - "train_metrics = trainer.callback_metrics" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " Training metrics\n", - " --------------------------\n", - "train/accuracy: 0.1867\n", - "train/precision: 0.1796\n", - "train/recall: 0.1835\n", - "val/loss: 3.2280\n", - "val/accuracy: 0.1600\n", - "val/precision: 0.1735\n", - "val/recall: 0.1554\n", - "train/loss: 3.2763\n" - ] - } - ], - "source": [ - "print(' Training metrics\\n', '-'*26)\n", - "for key in train_metrics:\n", - " print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 6. Testing the model " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Finally, we can test the model and obtain the results." 
- ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n" - ] - }, - { - "data": { - "text/html": [ - "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
-       "┃        Test metric               DataLoader 0        ┃\n",
-       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
-       "│       test/accuracy           0.18000000715255737    │\n",
-       "│         test/loss              3.625048875808716     │\n",
-       "│      test/precision           0.1994038224220276     │\n",
-       "│        test/recall            0.1821674406528473     │\n",
-       "└───────────────────────────┴───────────────────────────┘\n",
-       "
\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│\u001b[36m \u001b[0m\u001b[36m test/accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.18000000715255737 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 3.625048875808716 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/precision \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.1994038224220276 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test/recall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.1821674406528473 \u001b[0m\u001b[35m \u001b[0m│\n", - "└───────────────────────────┴───────────────────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "trainer.test(model, datamodule)\n", - "test_metrics = trainer.callback_metrics" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " Testing metrics\n", - " -------------------------\n", - "test/loss: 3.6250\n", - "test/accuracy: 0.1800\n", - "test/precision: 0.1994\n", - "test/recall: 0.1822\n" - ] - } - ], - "source": [ - "print(' Testing metrics\\n', '-'*25)\n", - "for key in test_metrics:\n", - " print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "tb", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": 
"python", - "pygments_lexer": "ipython3", - "version": "3.11.3" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -}