diff --git a/configs/dataset/graph/OC20_IS2RE.yaml b/configs/dataset/graph/OC20_IS2RE.yaml new file mode 100644 index 000000000..540677d28 --- /dev/null +++ b/configs/dataset/graph/OC20_IS2RE.yaml @@ -0,0 +1,35 @@ +# OC20 IS2RE task +# Train/val/test splits are precomputed in the LMDB archive + +loader: + _target_: topobench.data.loaders.graph.oc20_is2re_dataset_loader.IS2REDatasetLoader + parameters: + data_domain: graph + data_type: oc20 + data_name: OC20_IS2RE + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} + task: is2re + download: true + legacy_format: false + dtype: float32 + max_samples: 10 # Set to integer (e.g., 1000) to limit dataset size for fast experiments, or null for full dataset + +parameters: + num_features: 6 # Will be determined by the actual data + num_classes: 1 + task: regression + loss_type: mse + monitor_metric: mae + task_level: graph + +split_params: + learning_setting: inductive + split_type: fixed # splits are precomputed in the dataset + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 0 + +dataloader_params: + batch_size: 32 + num_workers: 0 + pin_memory: true + persistent_workers: false diff --git a/configs/dataset/graph/OC20_S2EF_200K.yaml b/configs/dataset/graph/OC20_S2EF_200K.yaml new file mode 100644 index 000000000..fc8a32d93 --- /dev/null +++ b/configs/dataset/graph/OC20_S2EF_200K.yaml @@ -0,0 +1,42 @@ +# OC20 S2EF Dataset Configuration +# Structure to Energy and Forces prediction for catalyst discovery +# Dataset: 200K training samples with multiple validation splits +# Validation: all 4 validation splits aggregated (val_id, val_ood_ads, val_ood_cat, val_ood_both) +# Test: official test split + +loader: + _target_: topobench.data.loaders.graph.oc20_dataset_loader.OC20DatasetLoader + parameters: + data_domain: graph + data_type: oc20 + data_name: OC20_S2EF_200K + task: s2ef + train_split: "200K" + val_splits: null # 
null means use all 4 validation splits (val_id, val_ood_ads, val_ood_cat, val_ood_both) + include_test: false # S2EF test data is LMDB format (incompatible with .extxyz/ASE DB train/val) + download: true + dtype: float32 + legacy_format: false + max_samples: 10 # Set to integer (e.g., 1000) to limit dataset size for fast experiments, or null for full dataset + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} + +parameters: + num_features: 1 # Number of node features (atomic numbers) + num_classes: 1 # Regression task (energy prediction) + task: regression + loss_type: mse + monitor_metric: mae + task_level: graph # Graph-level prediction + + +split_params: + learning_setting: inductive + split_type: fixed # Splits are provided by the dataset + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 0 + +dataloader_params: + batch_size: 32 + num_workers: 0 + pin_memory: true + persistent_workers: false diff --git a/configs/dataset/graph/OC20_S2EF_200K_mock.yaml b/configs/dataset/graph/OC20_S2EF_200K_mock.yaml new file mode 100644 index 000000000..477e816d8 --- /dev/null +++ b/configs/dataset/graph/OC20_S2EF_200K_mock.yaml @@ -0,0 +1,42 @@ +# OC20 S2EF Mock Dataset Configuration +# Mock configuration for testing purposes using the 200K training samples (350MB) +# This configuration is designed to be used for CI/CD testing without requiring large dataset downloads +# It downloads only the 200K training split and uses it for train/val/test + +loader: + _target_: topobench.data.loaders.graph.oc20_dataset_loader.OC20DatasetLoader + parameters: + data_domain: graph + data_type: oc20 + data_name: OC20_S2EF_200K_mock + task: s2ef + train_split: "200K" + val_splits: [] # Empty list to avoid downloading validation splits + include_test: false # Don't download test data to keep size minimal + download: true + dtype: float32 + legacy_format: false + max_samples: 10 # Limit to 10 
samples for fast testing + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} + +parameters: + num_features: 1 # Number of node features (atomic numbers) + num_classes: 1 # Regression task (energy prediction) + task: regression + loss_type: mse + monitor_metric: mae + task_level: graph # Graph-level prediction + + +split_params: + learning_setting: inductive + split_type: random # Use random splitting since we only download train split + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 0 + train_prop: 0.6 # 60% train, 20% validation, 20% test + +dataloader_params: + batch_size: 8 # Smaller batch size for testing + num_workers: 0 + pin_memory: true + persistent_workers: false diff --git a/configs/dataset/graph/OC20_S2EF_20M.yaml b/configs/dataset/graph/OC20_S2EF_20M.yaml new file mode 100644 index 000000000..815859189 --- /dev/null +++ b/configs/dataset/graph/OC20_S2EF_20M.yaml @@ -0,0 +1,38 @@ +# OC20 S2EF dataset with 20M training samples +# Validation: all 4 validation splits aggregated (val_id, val_ood_ads, val_ood_cat, val_ood_both) +# Test: official test split + +loader: + _target_: topobench.data.loaders.graph.oc20_dataset_loader.OC20DatasetLoader + parameters: + data_domain: graph + data_type: oc20 + data_name: OC20_S2EF_20M + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} + task: s2ef + train_split: "20M" + val_splits: null # null means use all 4 validation splits + test_split: "test" + download: true + legacy_format: false + dtype: float32 + +parameters: + num_features: 6 # Will be determined by the actual data + num_classes: 1 + task: regression + loss_type: mse + monitor_metric: mae + task_level: graph + +split_params: + learning_setting: inductive + split_type: fixed # splits are provided by the dataset + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} 
+ data_seed: 0 + +dataloader_params: + batch_size: 32 + num_workers: 0 + pin_memory: true + persistent_workers: false diff --git a/configs/dataset/graph/OC20_S2EF_2M.yaml b/configs/dataset/graph/OC20_S2EF_2M.yaml new file mode 100644 index 000000000..172292b01 --- /dev/null +++ b/configs/dataset/graph/OC20_S2EF_2M.yaml @@ -0,0 +1,38 @@ +# OC20 S2EF dataset with 2M training samples +# Validation: all 4 validation splits aggregated (val_id, val_ood_ads, val_ood_cat, val_ood_both) +# Test: official test split + +loader: + _target_: topobench.data.loaders.graph.oc20_dataset_loader.OC20DatasetLoader + parameters: + data_domain: graph + data_type: oc20 + data_name: OC20_S2EF_2M + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} + task: s2ef + train_split: "2M" + val_splits: null # null means use all 4 validation splits + test_split: "test" + download: true + legacy_format: false + dtype: float32 + +parameters: + num_features: 6 # Will be determined by the actual data + num_classes: 1 + task: regression + loss_type: mse + monitor_metric: mae + task_level: graph + +split_params: + learning_setting: inductive + split_type: fixed # splits are provided by the dataset + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 0 + +dataloader_params: + batch_size: 32 + num_workers: 0 + pin_memory: true + persistent_workers: false diff --git a/configs/dataset/graph/OC20_S2EF_all.yaml b/configs/dataset/graph/OC20_S2EF_all.yaml new file mode 100644 index 000000000..c50e85854 --- /dev/null +++ b/configs/dataset/graph/OC20_S2EF_all.yaml @@ -0,0 +1,38 @@ +# OC20 S2EF dataset with all training samples (~134M) +# Validation: all 4 validation splits aggregated (val_id, val_ood_ads, val_ood_cat, val_ood_both) +# Test: official test split + +loader: + _target_: topobench.data.loaders.graph.oc20_dataset_loader.OC20DatasetLoader + parameters: + data_domain: graph + data_type: oc20 + 
data_name: OC20_S2EF_all + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} + task: s2ef + train_split: "all" + val_splits: null # null means use all 4 validation splits + test_split: "test" + download: true + legacy_format: false + dtype: float32 + +parameters: + num_features: 6 # Will be determined by the actual data + num_classes: 1 + task: regression + loss_type: mse + monitor_metric: mae + task_level: graph + +split_params: + learning_setting: inductive + split_type: fixed # splits are provided by the dataset + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 0 + +dataloader_params: + batch_size: 32 + num_workers: 0 + pin_memory: true + persistent_workers: false diff --git a/configs/dataset/graph/OC22_IS2RE.yaml b/configs/dataset/graph/OC22_IS2RE.yaml new file mode 100644 index 000000000..54f3b1130 --- /dev/null +++ b/configs/dataset/graph/OC22_IS2RE.yaml @@ -0,0 +1,35 @@ +# OC22 IS2RE task +# Train/val/test splits are precomputed in the LMDB archive + +loader: + _target_: topobench.data.loaders.graph.oc22_is2re_dataset_loader.OC22IS2REDatasetLoader + parameters: + data_domain: graph + data_type: oc20 + data_name: OC22_IS2RE + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} + task: oc22_is2re + download: true + legacy_format: false + dtype: float32 + max_samples: 10 # Set to integer (e.g., 1000) to limit dataset size for fast experiments, or null for full dataset + +parameters: + num_features: 6 # Will be determined by the actual data + num_classes: 1 + task: regression + loss_type: mse + monitor_metric: mae + task_level: graph + +split_params: + learning_setting: inductive + split_type: fixed # splits are precomputed in the dataset + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 0 + +dataloader_params: + batch_size: 32 + num_workers: 0 + 
pin_memory: true + persistent_workers: false diff --git a/pyproject.toml b/pyproject.toml index 3234ea9e6..9dccf8a4a 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ dependencies=[ "topomodelx @ git+https://github.com/pyt-team/TopoModelX.git", "toponetx @ git+https://github.com/pyt-team/TopoNetX.git", "lightning==2.4.0", + "ase", # Required for OC20/OC22 S2EF dataset tests ] [project.optional-dependencies] diff --git a/test/conftest.py b/test/conftest.py index 27de49aed..649857344 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,4 +1,6 @@ """Configuration file for pytest.""" +import os +from pathlib import Path import networkx as nx import pytest import torch @@ -11,6 +13,13 @@ ) +# Set PROJECT_ROOT environment variable if not already set +if "PROJECT_ROOT" not in os.environ: + # Get the project root (parent of test directory) + project_root = Path(__file__).parent.parent.absolute() + os.environ["PROJECT_ROOT"] = str(project_root) + + @pytest.fixture def mocker_fixture(mocker): """Return pytest mocker, used when one want to use mocker in setup_method. 
diff --git a/test/data/load/test_datasetloaders.py b/test/data/load/test_datasetloaders.py index cb21fd421..4751e6c0b 100644 --- a/test/data/load/test_datasetloaders.py +++ b/test/data/load/test_datasetloaders.py @@ -41,7 +41,10 @@ def _gather_config_files(self, base_dir: Path) -> List[str]: # Below the datasets that have some default transforms with we manually overriten with no_transform, # due to lack of default transform for domain2domain "REDDIT-BINARY.yaml", "IMDB-MULTI.yaml", "IMDB-BINARY.yaml", #"ZINC.yaml" - "ogbg-molpcba.yaml", "manual_dataset.yaml" # "ogbg-molhiv.yaml" + "ogbg-molpcba.yaml", "manual_dataset.yaml", # "ogbg-molhiv.yaml" + # OC20/OC22 datasets that require large downloads (excluded from tests) + "OC20_S2EF_200K.yaml", "OC20_S2EF_2M.yaml", "OC20_S2EF_20M.yaml", + "OC20_S2EF_all.yaml", "OC20_IS2RE.yaml", "OC22_IS2RE.yaml" } # Below the datasets that takes quite some time to load and process diff --git a/test/data/load/test_oc20_datasets.py b/test/data/load/test_oc20_datasets.py new file mode 100644 index 000000000..9b61f7720 --- /dev/null +++ b/test/data/load/test_oc20_datasets.py @@ -0,0 +1,589 @@ +"""Unit tests for OC20 and OC22 dataset loaders.""" + +import os +import pytest +import torch +import hydra +from pathlib import Path +from omegaconf import DictConfig, OmegaConf + +# from topobench.data.loaders.graph.oc20_is2re_dataset_loader import IS2REDatasetLoader +# from topobench.data.loaders.graph.oc22_is2re_dataset_loader import OC22IS2REDatasetLoader +from topobench.data.loaders.graph.oc20_dataset_loader import OC20DatasetLoader +from topobench.data.loaders.graph.oc20_asedbs2ef_loader import OC20ASEDBDataset +# from topobench.data.datasets.oc20_is2re_dataset import IS2REDataset +# from topobench.data.datasets.oc22_is2re_dataset import OC22IS2REDataset +from topobench.data.datasets.oc20_dataset import OC20Dataset +from topobench.utils.config_resolvers import ( + get_default_metrics, + get_default_trainer, + get_default_transform, + 
get_flattened_channels, + get_monitor_metric, + get_monitor_mode, + get_non_relational_out_channels, + get_required_lifting, + infer_in_channels, + infer_num_cell_dimensions, + infer_topotune_num_cell_dimensions, +) + + +def register_resolvers(): + """Register OmegaConf resolvers for tests.""" + OmegaConf.register_new_resolver( + "get_default_metrics", get_default_metrics, replace=True + ) + OmegaConf.register_new_resolver( + "get_default_trainer", get_default_trainer, replace=True + ) + OmegaConf.register_new_resolver( + "get_default_transform", get_default_transform, replace=True + ) + OmegaConf.register_new_resolver( + "get_flattened_channels", + get_flattened_channels, + replace=True, + ) + OmegaConf.register_new_resolver( + "get_required_lifting", get_required_lifting, replace=True + ) + OmegaConf.register_new_resolver( + "get_monitor_metric", get_monitor_metric, replace=True + ) + OmegaConf.register_new_resolver( + "get_monitor_mode", get_monitor_mode, replace=True + ) + OmegaConf.register_new_resolver( + "get_non_relational_out_channels", + get_non_relational_out_channels, + replace=True, + ) + OmegaConf.register_new_resolver( + "infer_in_channels", infer_in_channels, replace=True + ) + OmegaConf.register_new_resolver( + "infer_num_cell_dimensions", infer_num_cell_dimensions, replace=True + ) + OmegaConf.register_new_resolver( + "infer_topotune_num_cell_dimensions", + infer_topotune_num_cell_dimensions, + replace=True, + ) + OmegaConf.register_new_resolver( + "parameter_multiplication", lambda x, y: int(int(x) * int(y)), replace=True + ) + + +def setup_project_root(): + """Set up PROJECT_ROOT environment variable for tests.""" + # Get the path to the test file's directory, then go up 3 levels to project root + test_file_dir = Path(__file__).resolve().parent + project_root = test_file_dir.parent.parent.parent + os.environ["PROJECT_ROOT"] = str(project_root) + + +# class TestOC20IS2REDatasetLoader: +# """Test suite for OC20 IS2RE dataset loader.""" + +# 
@pytest.fixture(autouse=True) +# def setup(self): +# """Setup test environment.""" +# hydra.core.global_hydra.GlobalHydra.instance().clear() +# register_resolvers() +# setup_project_root() +# self.relative_config_dir = "../../../configs" + +# def test_loader_initialization(self): +# """Test that the IS2RE loader can be initialized.""" +# with hydra.initialize( +# version_base="1.3", +# config_path=self.relative_config_dir, +# job_name="test_oc20_is2re" +# ): +# cfg = hydra.compose( +# config_name="run.yaml", +# overrides=["dataset=graph/OC20_IS2RE"], +# ) +# print('Test that the OC20 IS2RE loader can be initialized') +# loader = hydra.utils.instantiate(cfg.dataset.loader) +# assert isinstance(loader, IS2REDatasetLoader) +# assert loader.parameters.data_name == "OC20_IS2RE" +# assert loader.parameters.task == "is2re" + +# def test_dataset_loading(self): +# """Test that the IS2RE dataset loads correctly.""" +# with hydra.initialize( +# version_base="1.3", +# config_path=self.relative_config_dir, +# job_name="test_oc20_is2re_load" +# ): +# cfg = hydra.compose( +# config_name="run.yaml", +# overrides=["dataset=graph/OC20_IS2RE"], +# ) +# print('Test that the OC20 IS2RE dataset loads correctly') +# loader = hydra.utils.instantiate(cfg.dataset.loader) +# dataset, data_dir = loader.load() + +# # Check dataset type +# assert isinstance(dataset, IS2REDataset) + +# # Check dataset has required attributes +# assert hasattr(dataset, 'split_idx') +# assert 'train' in dataset.split_idx +# assert 'valid' in dataset.split_idx +# assert 'test' in dataset.split_idx + +# # Check splits are not empty (when max_samples is set) +# assert len(dataset.split_idx['train']) > 0 +# assert len(dataset.split_idx['valid']) > 0 +# assert len(dataset.split_idx['test']) > 0 + +# # Check dataset length +# assert len(dataset) > 0 + +# def test_dataset_item_access(self): +# """Test accessing individual items from the dataset.""" +# with hydra.initialize( +# version_base="1.3", +# 
config_path=self.relative_config_dir, +# job_name="test_oc20_is2re_item" +# ): +# cfg = hydra.compose( +# config_name="run.yaml", +# overrides=["dataset=graph/OC20_IS2RE"], +# ) +# print('Test that the OC20 IS2RE dataset loads correctly') +# loader = hydra.utils.instantiate(cfg.dataset.loader) +# dataset, _ = loader.load() + +# # Get first item +# data = dataset[0] + +# # Check data has required PyG attributes +# assert hasattr(data, 'x') +# assert hasattr(data, 'edge_index') +# assert hasattr(data, 'y') + +# # Check data types +# assert isinstance(data.x, torch.Tensor) +# assert isinstance(data.edge_index, torch.Tensor) +# assert isinstance(data.y, torch.Tensor) + +# # Check shapes +# assert data.x.dim() >= 1 +# assert data.edge_index.dim() == 2 +# assert data.edge_index.size(0) == 2 # [2, num_edges] + +# def test_split_indices_validity(self): +# """Test that split indices are valid and non-overlapping.""" +# with hydra.initialize( +# version_base="1.3", +# config_path=self.relative_config_dir, +# job_name="test_oc20_is2re_splits" +# ): +# cfg = hydra.compose( +# config_name="run.yaml", +# overrides=["dataset=graph/OC20_IS2RE"], +# ) +# print('Test that the OC20 IS2RE dataset loads correctly') +# loader = hydra.utils.instantiate(cfg.dataset.loader) +# dataset, _ = loader.load() + +# train_idx = dataset.split_idx['train'].numpy() +# val_idx = dataset.split_idx['valid'].numpy() +# test_idx = dataset.split_idx['test'].numpy() + +# # Check no overlap between splits +# assert len(set(train_idx) & set(val_idx)) == 0 +# assert len(set(train_idx) & set(test_idx)) == 0 +# # val and test might overlap if test reuses val when test is not available + +# # Check all indices are within dataset bounds +# all_indices = list(train_idx) + list(val_idx) +# assert all(0 <= idx < len(dataset) for idx in all_indices) + + +# class TestOC22IS2REDatasetLoader: +# """Test suite for OC22 IS2RE dataset loader.""" + +# @pytest.fixture(autouse=True) +# def setup(self): +# """Setup test 
environment.""" +# hydra.core.global_hydra.GlobalHydra.instance().clear() +# register_resolvers() +# setup_project_root() +# self.relative_config_dir = "../../../configs" + +# def test_loader_initialization(self): +# """Test that the OC22 IS2RE loader can be initialized.""" +# with hydra.initialize( +# version_base="1.3", +# config_path=self.relative_config_dir, +# job_name="test_oc22_is2re" +# ): +# cfg = hydra.compose( +# config_name="run.yaml", +# overrides=["dataset=graph/OC22_IS2RE"], +# ) +# print('Test that the OC20 IS2RE loader can be initialized') +# loader = hydra.utils.instantiate(cfg.dataset.loader) +# assert isinstance(loader, OC22IS2REDatasetLoader) +# assert loader.parameters.data_name == "OC22_IS2RE" +# assert loader.parameters.task == "oc22_is2re" + +# def test_dataset_loading(self): +# """Test that the OC22 IS2RE dataset loads correctly.""" +# with hydra.initialize( +# version_base="1.3", +# config_path=self.relative_config_dir, +# job_name="test_oc22_is2re_load" +# ): +# cfg = hydra.compose( +# config_name="run.yaml", +# overrides=["dataset=graph/OC22_IS2RE"], +# ) +# print('Test that the OC20 IS2RE dataset loads correctly') +# loader = hydra.utils.instantiate(cfg.dataset.loader) +# dataset, data_dir = loader.load() + +# # Check dataset type +# assert isinstance(dataset, OC22IS2REDataset) + +# # Check dataset has required attributes +# assert hasattr(dataset, 'split_idx') +# assert 'train' in dataset.split_idx +# assert 'valid' in dataset.split_idx +# assert 'test' in dataset.split_idx + +# # Check splits are not empty (when max_samples is set) +# assert len(dataset.split_idx['train']) > 0 +# assert len(dataset.split_idx['valid']) > 0 +# assert len(dataset.split_idx['test']) > 0 + +# # Check dataset length +# assert len(dataset) > 0 + +# def test_dataset_item_access(self): +# """Test accessing individual items from the OC22 dataset.""" +# with hydra.initialize( +# version_base="1.3", +# config_path=self.relative_config_dir, +# 
job_name="test_oc22_is2re_item" +# ): +# cfg = hydra.compose( +# config_name="run.yaml", +# overrides=["dataset=graph/OC22_IS2RE"], +# ) +# print('Test that the OC20 IS2RE dataset loads correctly') +# loader = hydra.utils.instantiate(cfg.dataset.loader) +# dataset, _ = loader.load() + +# # Get first item +# data = dataset[0] + +# # Check data has required PyG attributes +# assert hasattr(data, 'x') +# assert hasattr(data, 'edge_index') +# assert hasattr(data, 'y') + +# # Check data types +# assert isinstance(data.x, torch.Tensor) +# assert isinstance(data.edge_index, torch.Tensor) +# assert isinstance(data.y, torch.Tensor) + +# # Check shapes +# assert data.x.dim() >= 1 +# assert data.edge_index.dim() == 2 +# assert data.edge_index.size(0) == 2 + +# def test_split_indices_validity(self): +# """Test that split indices are valid and non-overlapping.""" +# with hydra.initialize( +# version_base="1.3", +# config_path=self.relative_config_dir, +# job_name="test_oc22_is2re_splits" +# ): +# cfg = hydra.compose( +# config_name="run.yaml", +# overrides=["dataset=graph/OC22_IS2RE"], +# ) +# print('Test that the OC20 IS2RE dataset loads correctly') +# loader = hydra.utils.instantiate(cfg.dataset.loader) +# dataset, _ = loader.load() + +# train_idx = dataset.split_idx['train'].numpy() +# val_idx = dataset.split_idx['valid'].numpy() +# test_idx = dataset.split_idx['test'].numpy() + +# # Check no overlap between train and val +# assert len(set(train_idx) & set(val_idx)) == 0 + +# # Check all indices are within dataset bounds +# all_indices = list(train_idx) + list(val_idx) +# assert all(0 <= idx < len(dataset) for idx in all_indices) + + +class TestOC20S2EFDatasetLoader: + """Test suite for OC20 S2EF dataset loader.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Setup test environment.""" + hydra.core.global_hydra.GlobalHydra.instance().clear() + register_resolvers() + setup_project_root() + self.relative_config_dir = "../../../configs" + + def 
test_loader_initialization_200k(self): + """Test that the S2EF 200K loader can be initialized.""" + with hydra.initialize( + version_base="1.3", + config_path=self.relative_config_dir, + job_name="test_oc20_s2ef_200k" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=["dataset=graph/OC20_S2EF_200K_mock"], + ) + print('Test that the OC20 S2EF 200K loader can be initialized') + loader = hydra.utils.instantiate(cfg.dataset.loader) + assert isinstance(loader, OC20DatasetLoader) + assert loader.parameters.task == "s2ef" + assert loader.parameters.train_split == "200K" + + def test_dataset_loading_200k(self): + """Test that the S2EF 200K dataset loads correctly.""" + with hydra.initialize( + version_base="1.3", + config_path=self.relative_config_dir, + job_name="test_oc20_s2ef_200k_load" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=["dataset=graph/OC20_S2EF_200K_mock"], + ) + print('Test that the OC20 S2EF 200K dataset loads correctly') + loader = hydra.utils.instantiate(cfg.dataset.loader) + dataset, data_dir = loader.load() + + # Check dataset type (S2EF uses ASE DB backend) + assert isinstance(dataset, OC20ASEDBDataset) + + # Check dataset has required attributes + assert hasattr(dataset, 'split_idx') + assert 'train' in dataset.split_idx + assert 'valid' in dataset.split_idx + assert 'test' in dataset.split_idx + + # Check splits are not empty + assert len(dataset.split_idx['train']) > 0 + # With val_splits=[], validation data will come from random split (if split_type=random) + # Otherwise it will be empty + # S2EF test data is LMDB format (incompatible with .extxyz/ASE DB), so test split is empty + assert len(dataset.split_idx['test']) == 0 + + # Check dataset length + assert len(dataset) > 0 + + def test_dataset_item_access_s2ef(self): + """Test accessing individual items from the S2EF dataset.""" + with hydra.initialize( + version_base="1.3", + config_path=self.relative_config_dir, + job_name="test_oc20_s2ef_item" + ): + cfg = 
hydra.compose( + config_name="run.yaml", + overrides=["dataset=graph/OC20_S2EF_200K_mock"], + ) + print('Test that the OC20 S2EF dataset loads correctly') + loader = hydra.utils.instantiate(cfg.dataset.loader) + dataset, _ = loader.load() + + # Get first item + data = dataset[0] + + # Check data has required PyG attributes + assert hasattr(data, 'x') + assert hasattr(data, 'edge_index') + + # Check data types + assert isinstance(data.x, torch.Tensor) + assert isinstance(data.edge_index, torch.Tensor) + + # Check shapes + assert data.x.dim() >= 1 + assert data.edge_index.dim() == 2 + assert data.edge_index.size(0) == 2 + + def test_validation_splits_configuration(self): + """Test that validation splits can be configured.""" + with hydra.initialize( + version_base="1.3", + config_path=self.relative_config_dir, + job_name="test_oc20_s2ef_val_splits" + ): + # Test with val_id only + cfg = hydra.compose( + config_name="run.yaml", + overrides=["dataset=graph/OC20_S2EF_200K_mock"], + ) + print('Test that the OC20 S2EF 200K loader can be initialized') + loader = hydra.utils.instantiate(cfg.dataset.loader) + # val_splits=[] (empty list) for mock config means no separate validation files + # (validation will come from random split of train data) + assert loader.parameters.val_splits == [] + + def test_split_indices_validity_s2ef(self): + """Test that S2EF split indices are valid.""" + with hydra.initialize( + version_base="1.3", + config_path=self.relative_config_dir, + job_name="test_oc20_s2ef_splits" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=["dataset=graph/OC20_S2EF_200K_mock"], + ) + print('Test that the OC20 S2EF 200K dataset loads correctly') + loader = hydra.utils.instantiate(cfg.dataset.loader) + dataset, _ = loader.load() + + # ASE DB dataset uses lists for split indices, not tensors + train_idx = dataset.split_idx['train'] + valid_idx = dataset.split_idx['valid'] + test_idx = dataset.split_idx['test'] + + # Check indices are valid (note: 
max_samples truncates dataset but indices reflect original positions) + assert len(train_idx) > 0 + # With val_splits=[], valid_idx may be empty at dataset level + # (preprocessor will create splits later with random splitting) + # Only check non-overlap if both have data + if len(valid_idx) > 0: + assert len(set(train_idx) & set(valid_idx)) == 0 + + +class TestOC20DatasetIntegration: + """Integration tests for OC20 datasets with PreProcessor.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Setup test environment.""" + hydra.core.global_hydra.GlobalHydra.instance().clear() + register_resolvers() + setup_project_root() + self.relative_config_dir = "../../../configs" + +# def test_is2re_with_preprocessor(self): +# """Test IS2RE dataset with PreProcessor.""" +# with hydra.initialize( +# version_base="1.3", +# config_path=self.relative_config_dir, +# job_name="test_is2re_preprocessor" +# ): +# cfg = hydra.compose( +# config_name="run.yaml", +# overrides=[ +# "dataset=graph/OC20_IS2RE", +# "model=graph/gcn", +# ], +# return_hydra_config=True, +# ) + +# # Load dataset +# loader = hydra.utils.instantiate(cfg.dataset.loader) +# dataset, data_dir = loader.load() + +# # Use preprocessor +# from topobench.data.preprocessor import PreProcessor +# transform_config = cfg.get("transforms", None) +# preprocessor = PreProcessor(dataset, data_dir, transform_config) + +# # Load splits +# dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits( +# cfg.dataset.split_params +# ) + +# # Verify splits exist and are not empty +# assert dataset_train is not None +# assert dataset_val is not None +# assert dataset_test is not None +# assert len(dataset_train) > 0 +# assert len(dataset_val) > 0 +# assert len(dataset_test) > 0 + +# def test_oc22_is2re_with_preprocessor(self): +# """Test OC22 IS2RE dataset with PreProcessor.""" +# with hydra.initialize( +# version_base="1.3", +# config_path=self.relative_config_dir, +# job_name="test_oc22_preprocessor" +# ): +# 
cfg = hydra.compose( +# config_name="run.yaml", +# overrides=[ +# "dataset=graph/OC22_IS2RE", +# "model=graph/gcn", +# ], +# return_hydra_config=True, +# ) +# print('Config used for OC22 IS2RE preprocessor test:\n', OmegaConf.to_yaml(cfg)) + +# # Load dataset +# loader = hydra.utils.instantiate(cfg.dataset.loader) +# dataset, data_dir = loader.load() + +# # Use preprocessor +# from topobench.data.preprocessor import PreProcessor +# transform_config = cfg.get("transforms", None) +# preprocessor = PreProcessor(dataset, data_dir, transform_config) + +# # Load splits +# dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits( +# cfg.dataset.split_params +# ) + +# # Verify splits exist and are not empty +# assert dataset_train is not None +# assert dataset_val is not None +# assert dataset_test is not None +# assert len(dataset_train) > 0 +# assert len(dataset_val) > 0 +# assert len(dataset_test) > 0 + + def test_s2ef_with_preprocessor(self): + """Test mock S2EF dataset with PreProcessor.""" + with hydra.initialize( + version_base="1.3", + config_path=self.relative_config_dir, + job_name="test_s2ef_preprocessor" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=["dataset=graph/OC20_S2EF_200K_mock", "model=graph/gcn"] + ) + print('Test that the OC20 S2EF 200K loader can be initialized') + + # Load dataset + loader = hydra.utils.instantiate(cfg.dataset.loader) + dataset, data_dir = loader.load() + + # Use preprocessor + from topobench.data.preprocessor import PreProcessor + transform_config = cfg.get("transforms", None) + preprocessor = PreProcessor(dataset, data_dir, transform_config) + + # Load splits + dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits( + cfg.dataset.split_params + ) + + # Verify splits exist and train/val are not empty + assert dataset_train is not None + assert dataset_val is not None + assert dataset_test is not None + assert len(dataset_train) > 0 + assert len(dataset_val) > 0 + # S2EF 
test may be empty if using the official splits, no test for this \ No newline at end of file diff --git a/test/pipeline/test_pipeline.py b/test/pipeline/test_pipeline.py index 785987159..affb6ea39 100644 --- a/test/pipeline/test_pipeline.py +++ b/test/pipeline/test_pipeline.py @@ -1,35 +1,83 @@ -"""Test pipeline for a particular dataset and model.""" +"""Test pipeline for OC20/OC22 datasets.""" import hydra +import pytest from test._utils.simplified_pipeline import run -DATASET = "graph/MUTAG" # ADD YOUR DATASET HERE -MODELS = ["graph/gcn", "cell/topotune", "simplicial/topotune"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE - - class TestPipeline: - """Test pipeline for a particular dataset and model.""" + """Test pipeline for OC20 and OC22 datasets.""" def setup_method(self): """Setup method.""" hydra.core.global_hydra.GlobalHydra.instance().clear() - def test_pipeline(self): - """Test pipeline.""" + # IS2RE and OC22 tests commented out to prevent large dataset downloads during testing + # def test_pipeline_oc20_is2re(self): + # """Test pipeline with OC20 IS2RE dataset.""" + # dataset = "graph/OC20_IS2RE" + # model = "graph/gcn" + # + # with hydra.initialize(config_path="../../configs", job_name="job"): + # cfg = hydra.compose( + # config_name="run.yaml", + # overrides=[ + # f"model={model}", + # f"dataset={dataset}", + # "trainer.max_epochs=2", + # "trainer.min_epochs=1", + # "trainer.check_val_every_n_epoch=1", + # "paths=test", + # "callbacks=model_checkpoint", + # "dataset.dataloader_params.num_workers=0", + # "dataset.dataloader_params.persistent_workers=false", + # ], + # return_hydra_config=True + # ) + # run(cfg) + + # def test_pipeline_oc22_is2re(self): + # """Test pipeline with OC22 IS2RE dataset.""" + # dataset = "graph/OC22_IS2RE" + # model = "graph/gcn" + # + # with hydra.initialize(config_path="../../configs", job_name="job"): + # cfg = hydra.compose( + # config_name="run.yaml", + # overrides=[ + # f"model={model}", + # f"dataset={dataset}", + # 
"trainer.max_epochs=2", + # "trainer.min_epochs=1", + # "trainer.check_val_every_n_epoch=1", + # "paths=test", + # "callbacks=model_checkpoint", + # "dataset.dataloader_params.num_workers=0", + # "dataset.dataloader_params.persistent_workers=false", + # ], + # return_hydra_config=True + # ) + # run(cfg) + + def test_pipeline_oc20_s2ef(self): + """Test pipeline with OC20 S2EF dataset.""" + dataset = "graph/OC20_S2EF_200K_mock" + model = "graph/gcn" + with hydra.initialize(config_path="../../configs", job_name="job"): - for MODEL in MODELS: - cfg = hydra.compose( - config_name="run.yaml", - overrides=[ - f"model={MODEL}", - f"dataset={DATASET}", # IF YOU IMPLEMENT A LARGE DATASET WITH AN OPTION TO USE A SLICE OF IT, ADD BELOW THE CORRESPONDING OPTION - "trainer.max_epochs=2", - "trainer.min_epochs=1", - "trainer.check_val_every_n_epoch=1", - "paths=test", - "callbacks=model_checkpoint", - ], - return_hydra_config=True - ) - run(cfg) \ No newline at end of file + cfg = hydra.compose( + config_name="run.yaml", + overrides=[ + f"model={model}", + f"dataset={dataset}", + "trainer.max_epochs=2", + "trainer.min_epochs=1", + "trainer.check_val_every_n_epoch=1", + "paths=test", + "callbacks=model_checkpoint", + "dataset.dataloader_params.num_workers=0", + "dataset.dataloader_params.persistent_workers=false", + ], + return_hydra_config=True + ) + run(cfg) diff --git a/test_data_migration.py b/test_data_migration.py new file mode 100644 index 000000000..e336b43f5 --- /dev/null +++ b/test_data_migration.py @@ -0,0 +1,59 @@ +"""Test script to verify Data object migration from old PyG format.""" + +import copy +import pickle +from pathlib import Path + +import lmdb +from torch_geometric.data import Data + +# Open the first LMDB file +lmdb_path = Path( + "/Users/theos/Documents/code/TopoBench_contrib/datasets/graph/oc20/OC22_IS2RE/is2res_total_train_val_test_lmdbs/data/oc22/is2re-total/train/data.0000.lmdb" +) + +env = lmdb.open( + str(lmdb_path.resolve()), + subdir=False, + 
readonly=True, + lock=False, + readahead=True, + meminit=False, + max_readers=1, +) + +with env.begin() as txn: + cursor = txn.cursor() + cursor.first() + key, value = cursor.item() + old_data = pickle.loads(value) + +print(f"Old data type: {type(old_data)}") +print(f"Old data __dict__ keys: {old_data.__dict__.keys()}") + +# Try to migrate using the new approach +if "_store" not in old_data.__dict__ or any( + k in old_data.__dict__ for k in ["x", "edge_index", "pos"] +): + print("Detected old format (attributes in __dict__)") + + data_dict = {} + for key, val in old_data.__dict__.items(): + if not key.startswith("_") and val is not None: + data_dict[key] = val + print(f" {key}: {type(val)}") + + # Create new Data object + new_data = Data(**data_dict) + print(f"\nNew data type: {type(new_data)}") + print(f"New data __dict__ keys: {new_data.__dict__.keys()}") + + # Try to copy + try: + copied_data = copy.copy(new_data) + print("✅ Copy successful!") + print(f"Copied data __dict__ keys: {copied_data.__dict__.keys()}") + except Exception as e: + print(f"❌ Copy failed: {e}") + +env.close() diff --git a/topobench/data/datasets/__init__.py b/topobench/data/datasets/__init__.py index 68a4d6345..fe8d09678 100644 --- a/topobench/data/datasets/__init__.py +++ b/topobench/data/datasets/__init__.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import ClassVar -from torch_geometric.data import InMemoryDataset +from torch_geometric.data import Dataset, InMemoryDataset class DatasetManager: @@ -41,9 +41,7 @@ class DatasetManager: ] @classmethod - def discover_datasets( - cls, package_path: str - ) -> dict[str, type[InMemoryDataset]]: + def discover_datasets(cls, package_path: str) -> dict[str, type[Dataset]]: """Dynamically discover all dataset classes in the package. Parameters @@ -53,7 +51,7 @@ def discover_datasets( Returns ------- - Dict[str, Type[InMemoryDataset]] + Dict[str, Type[Dataset]] Dictionary mapping class names to their corresponding class objects. 
""" datasets = {} @@ -81,8 +79,9 @@ def discover_datasets( inspect.isclass(obj) and obj.__module__ == module.__name__ and not name.startswith("_") - and issubclass(obj, InMemoryDataset) + and issubclass(obj, Dataset) and obj != InMemoryDataset + and obj != Dataset ) } datasets.update(new_datasets) diff --git a/topobench/data/datasets/oc20_dataset.py b/topobench/data/datasets/oc20_dataset.py new file mode 100644 index 000000000..aa19a37e1 --- /dev/null +++ b/topobench/data/datasets/oc20_dataset.py @@ -0,0 +1,375 @@ +"""Dataset class for Open Catalyst 2020 (OC20) family of datasets.""" + +from __future__ import annotations + +import logging +import pickle +from collections.abc import Iterator +from pathlib import Path +from typing import ClassVar + +import torch +from omegaconf import DictConfig +from torch_geometric.data import Data, Dataset + +try: + import lmdb + + HAS_LMDB = True +except ImportError: + lmdb = None + HAS_LMDB = False + +logger = logging.getLogger(__name__) + + +class OC20Dataset(Dataset): + """Dataset class for Open Catalyst 2020 (OC20) family. + + Supports S2EF (Structure to Energy and Forces) task for catalyst + discovery and materials science. + + The OC20 dataset contains DFT calculations for catalyst-adsorbate systems, + enabling machine learning models to predict energies and forces for + accelerated materials discovery. + + Parameters + ---------- + root : str + Root directory where the dataset is stored. + name : str + Name of the dataset. + parameters : DictConfig + Configuration parameters for the dataset. + + Attributes + ---------- + task : str + Task type: "s2ef", "is2re", or "oc22_is2re". + train_split : str + Training split size for S2EF (e.g., "200K", "2M", "20M", "all"). + val_splits : list[str] + Validation splits to use. 
+ """ + + # S2EF validation splits + VALID_VAL_SPLITS: ClassVar = [ + "val_id", + "val_ood_ads", + "val_ood_cat", + "val_ood_both", + ] + + # S2EF train splits + VALID_TRAIN_SPLITS: ClassVar = ["200K", "2M", "20M", "all"] + + def __init__( + self, + root: str, + name: str, + parameters: DictConfig, + ) -> None: + self.name = name + self.parameters = parameters + + # Task configuration + self.task = parameters.get("task", "s2ef").lower() + if self.task != "s2ef": + raise ValueError(f"Unsupported task: {self.task}") + self.dtype = self._parse_dtype(parameters.get("dtype", "float32")) + self.legacy_format = parameters.get("legacy_format", False) + self.include_test = parameters.get("include_test", True) + + # S2EF-specific configuration + self.train_split = parameters.get("train_split", "200K") + if self.train_split not in self.VALID_TRAIN_SPLITS: + raise ValueError( + f"Invalid S2EF train split: {self.train_split}. " + f"Choose from {self.VALID_TRAIN_SPLITS}" + ) + + # Parse validation splits + val_splits = parameters.get("val_splits", None) + if val_splits is None: + self.val_splits = self.VALID_VAL_SPLITS + elif isinstance(val_splits, str): + self.val_splits = [val_splits] + else: + self.val_splits = list(val_splits) + + # Validate splits + for vs in self.val_splits: + if vs not in self.VALID_VAL_SPLITS: + raise ValueError( + f"Invalid S2EF val split: {vs}. 
" + f"Choose from {self.VALID_VAL_SPLITS}" + ) + + self.test_split = parameters.get("test_split", "test") + + # Limit for fast experimentation + self.max_samples = parameters.get("max_samples", None) + if self.max_samples is not None: + self.max_samples = int(self.max_samples) + logger.info( + f"⚠️ Limiting dataset to {self.max_samples} samples for fast experimentation" + ) + + super().__init__(root) + + # Open LMDB environments + self._open_lmdbs() + + def __repr__(self) -> str: + task_info = f"task={self.task}" + if self.task == "s2ef": + task_info += f", train={self.train_split}" + return f"{self.name}(root={self.root}, {task_info}, size={len(self)})" + + @staticmethod + def _parse_dtype(dtype) -> torch.dtype: + """Parse dtype parameter to torch.dtype.""" + if isinstance(dtype, str): + return getattr(torch, dtype) + return dtype + + def _get_data_paths(self) -> dict[str, list[Path]]: + """Get paths to LMDB files for each split. + + Returns + ------- + dict[str, list[Path]] + Dictionary mapping split names to lists of LMDB file paths. + """ + root = Path(self.root) + paths = {"train": [], "val": [], "test": []} + + # Training data path structure + train_subdir = f"s2ef_train_{self.train_split}" + train_dir = ( + root / "s2ef" / self.train_split / train_subdir / train_subdir + ) + paths["train"] = sorted(train_dir.glob("**/*.lmdb")) + + # Validation data paths + for val_split in self.val_splits: + val_subdir = f"s2ef_{val_split}" + val_dir = root / "s2ef" / "all" / val_subdir / val_subdir + paths["val"].extend(sorted(val_dir.glob("**/*.lmdb"))) + + # Test data path + if self.include_test: + test_dir = root / "s2ef" / "all" / "s2ef_test" / "s2ef_test" + paths["test"] = sorted(test_dir.glob("**/*.lmdb")) + + return paths + + def _open_lmdbs(self): + """Open LMDB files and create split mappings.""" + if not HAS_LMDB: + raise ImportError( + "LMDB is required for OC20 dataset. 
Install with: pip install lmdb" + ) + + paths = self._get_data_paths() + + # Initialize storage + self.envs = [] + self.cumulative_sizes = [0] + self.split_idx = {"train": [], "valid": [], "test": []} + + current_idx = 0 + + # Open train LMDBs + for lmdb_path in paths["train"]: + env, size = self._open_single_lmdb(lmdb_path) + self.envs.append((lmdb_path, env, size)) + self.cumulative_sizes.append(self.cumulative_sizes[-1] + size) + self.split_idx["train"].extend( + range(current_idx, current_idx + size) + ) + current_idx += size + + # Open validation LMDBs + for lmdb_path in paths["val"]: + env, size = self._open_single_lmdb(lmdb_path) + self.envs.append((lmdb_path, env, size)) + self.cumulative_sizes.append(self.cumulative_sizes[-1] + size) + self.split_idx["valid"].extend( + range(current_idx, current_idx + size) + ) + current_idx += size + + # Open test LMDBs + for lmdb_path in paths["test"]: + env, size = self._open_single_lmdb(lmdb_path) + self.envs.append((lmdb_path, env, size)) + self.cumulative_sizes.append(self.cumulative_sizes[-1] + size) + self.split_idx["test"].extend( + range(current_idx, current_idx + size) + ) + current_idx += size + + # If no test data, reuse validation indices + if not self.include_test or len(self.split_idx["test"]) == 0: + self.split_idx["test"] = list(self.split_idx["valid"]) + + # Convert to tensors + self.split_idx = { + k: torch.tensor(v, dtype=torch.long) + for k, v in self.split_idx.items() + } + + logger.info( + f"Loaded {self.task.upper()} dataset: " + f"{len(self.split_idx['train'])} train, " + f"{len(self.split_idx['valid'])} val, " + f"{len(self.split_idx['test'])} test" + ) + + def _open_single_lmdb(self, lmdb_path: Path) -> tuple: + """Open a single LMDB file and return (env, size). + + Parameters + ---------- + lmdb_path : Path + Path to LMDB file. + + Returns + ------- + tuple + (environment, size) tuple. 
+ """ + env = lmdb.open( + str(lmdb_path.resolve()), + subdir=False, + readonly=True, + lock=False, + readahead=True, + meminit=False, + max_readers=1, + ) + size = env.stat()["entries"] + return env, size + + def _find_lmdb_and_local_idx(self, idx: int) -> tuple: + """Find which LMDB contains the given index and the local index within it. + + Parameters + ---------- + idx : int + Global dataset index. + + Returns + ------- + tuple + (lmdb_index, local_index) tuple. + """ + if idx < 0 or idx >= len(self): + raise IndexError(f"Index {idx} out of range [0, {len(self)})") + + # Binary search for the LMDB + left, right = 0, len(self.envs) + while left < right - 1: + mid = (left + right) // 2 + if self.cumulative_sizes[mid] <= idx: + left = mid + else: + right = mid + + lmdb_idx = left + local_idx = idx - self.cumulative_sizes[lmdb_idx] + return lmdb_idx, local_idx + + def len(self) -> int: + """Get dataset length.""" + return self.cumulative_sizes[-1] + + def get(self, idx: int) -> Data: + """Get data sample at index. + + Parameters + ---------- + idx : int + Sample index. + + Returns + ------- + Data + PyTorch Geometric Data object. 
+ """ + lmdb_idx, local_idx = self._find_lmdb_and_local_idx(idx) + lmdb_path, env, _ = self.envs[lmdb_idx] + + with env.begin() as txn: + cursor = txn.cursor() + if not cursor.first(): + raise RuntimeError(f"Empty LMDB at {lmdb_path}") + + # Navigate to the target entry + for _ in range(local_idx): + if not cursor.next(): + raise RuntimeError( + f"Index {local_idx} out of range in {lmdb_path}" + ) + + key, value = cursor.item() + data = pickle.loads(value) + + # Convert to legacy format if needed + if self.legacy_format and isinstance(data, Data): + data = Data( + **{k: v for k, v in data.__dict__.items() if v is not None} + ) + + return data + + def __len__(self) -> int: + """Get dataset length.""" + return self.len() + + def __getitem__(self, idx: int) -> Data: + """Get item at index.""" + return self.get(idx) + + def __iter__(self) -> Iterator[Data]: + """Iterate over dataset.""" + for i in range(len(self)): + yield self[i] + + def __del__(self): + """Clean up LMDB environments.""" + if hasattr(self, "envs"): + for _, env, _ in self.envs: + env.close() + + @property + def num_node_features(self) -> int: + """Number of node features per atom.""" + # Will be determined by the actual data + return 1 # Atomic numbers + + @property + def data(self): + """Get combined data view for compatibility with InMemoryDataset API. + + Returns a Data object with x and y attributes representing stacked + features from a sample of the dataset. 
+ """ + if not hasattr(self, "_data_cache"): + # Get a sample to determine feature dimensions + if len(self) > 0: + sample = self.get(0) + # Create a mock data object with minimal info for compatibility + self._data_cache = Data( + x=sample.x if hasattr(sample, "x") else torch.zeros(1, 1), + y=sample.y if hasattr(sample, "y") else torch.zeros(1), + ) + else: + self._data_cache = Data(x=torch.zeros(1, 1), y=torch.zeros(1)) + return self._data_cache + + @property + def num_classes(self) -> int: + """Number of classes (regression task).""" + return 1 # Single regression target (energy) diff --git a/topobench/data/datasets/oc20_is2re_dataset.py b/topobench/data/datasets/oc20_is2re_dataset.py new file mode 100644 index 000000000..d0bc15ecb --- /dev/null +++ b/topobench/data/datasets/oc20_is2re_dataset.py @@ -0,0 +1,378 @@ +"""Dataset class for Open Catalyst 2020 (OC20) IS2RE dataset.""" + +from __future__ import annotations + +import logging +import pickle +from collections.abc import Iterator +from pathlib import Path + +import torch +from omegaconf import DictConfig +from torch_geometric.data import Data, Dataset + +try: + import lmdb + + HAS_LMDB = True +except ImportError: + lmdb = None + HAS_LMDB = False + +logger = logging.getLogger(__name__) + + +class IS2REDataset(Dataset): + """Dataset class for Open Catalyst 2020 (OC20) IS2RE task. + + IS2RE (Initial Structure to Relaxed Energy) is a task for catalyst + discovery and materials science. + + The OC20 dataset contains DFT calculations for catalyst-adsorbate systems, + enabling machine learning models to predict energies for accelerated + materials discovery. + + Parameters + ---------- + root : str + Root directory where the dataset is stored. + name : str + Name of the dataset. + parameters : DictConfig + Configuration parameters for the dataset. 
+ """ + + def __init__( + self, + root: str, + name: str, + parameters: DictConfig, + ) -> None: + self.name = name + self.parameters = parameters + + # Task configuration + self.task = "is2re" + self.dtype = self._parse_dtype(parameters.get("dtype", "float32")) + self.legacy_format = parameters.get("legacy_format", False) + self.include_test = parameters.get("include_test", True) + + # Limit for fast experimentation + self.max_samples = parameters.get("max_samples", None) + if self.max_samples is not None: + self.max_samples = int(self.max_samples) + logger.info( + f"⚠️ Limiting dataset to {self.max_samples} samples for fast experimentation" + ) + + super().__init__(root) + + # Open LMDB environments + self._open_lmdbs() + + def __repr__(self) -> str: + return f"{self.name}(root={self.root}, task={self.task}, size={len(self)})" + + @staticmethod + def _parse_dtype(dtype) -> torch.dtype: + """Parse dtype parameter to torch.dtype.""" + if isinstance(dtype, str): + return getattr(torch, dtype) + return dtype + + def _get_data_paths(self) -> dict[str, list[Path]]: + """Get paths to LMDB files for each split. + + Returns + ------- + dict[str, list[Path]] + Dictionary mapping split names to lists of LMDB file paths. 
+ """ + root = Path(self.root) + paths = {"train": [], "val": [], "test": []} + + # The downloaded data is extracted to: + # root/is2res_train_val_test_lmdbs/data/is2re/all/{train,val_id,test_id}/data.lmdb + base_path = ( + root / "is2res_train_val_test_lmdbs" / "data" / "is2re" / "all" + ) + + if base_path.exists(): + # Train data + train_lmdb = base_path / "train" / "data.lmdb" + if train_lmdb.exists(): + paths["train"] = [train_lmdb] + + # Validation data + val_lmdb = base_path / "val_id" / "data.lmdb" + if val_lmdb.exists(): + paths["val"] = [val_lmdb] + + # Test data + test_lmdb = base_path / "test_id" / "data.lmdb" + if test_lmdb.exists(): + paths["test"] = [test_lmdb] + + return paths + + def _open_lmdbs(self): + """Open LMDB files and create split mappings.""" + if not HAS_LMDB: + raise ImportError( + "LMDB is required for IS2RE dataset. Install with: pip install lmdb" + ) + + paths = self._get_data_paths() + + # Initialize storage + self.envs = [] + self.cumulative_sizes = [0] + self.split_idx = {"train": [], "valid": [], "test": []} + + current_idx = 0 + + # Open train LMDBs + for lmdb_path in paths["train"]: + env, size = self._open_single_lmdb(lmdb_path) + # Apply max_samples limit if specified (per split) + if self.max_samples is not None: + size = min(size, self.max_samples) + self.envs.append((lmdb_path, env, size)) + self.cumulative_sizes.append(self.cumulative_sizes[-1] + size) + self.split_idx["train"].extend( + range(current_idx, current_idx + size) + ) + current_idx += size + + # Open validation LMDBs + for lmdb_path in paths["val"]: + env, size = self._open_single_lmdb(lmdb_path) + # Apply max_samples limit if specified (per split) + if self.max_samples is not None: + size = min(size, self.max_samples) + self.envs.append((lmdb_path, env, size)) + self.cumulative_sizes.append(self.cumulative_sizes[-1] + size) + self.split_idx["valid"].extend( + range(current_idx, current_idx + size) + ) + current_idx += size + + # Open test LMDBs + for lmdb_path 
in paths["test"]: + env, size = self._open_single_lmdb(lmdb_path) + # Apply max_samples limit if specified (per split) + if self.max_samples is not None: + size = min(size, self.max_samples) + self.envs.append((lmdb_path, env, size)) + self.cumulative_sizes.append(self.cumulative_sizes[-1] + size) + self.split_idx["test"].extend( + range(current_idx, current_idx + size) + ) + current_idx += size + + # If no test data, reuse validation indices + if not self.include_test or len(self.split_idx["test"]) == 0: + self.split_idx["test"] = list(self.split_idx["valid"]) + + # Convert to tensors + self.split_idx = { + k: torch.tensor(v, dtype=torch.long) + for k, v in self.split_idx.items() + } + + logger.info( + f"Loaded {self.task.upper()} dataset: " + f"{len(self.split_idx['train'])} train, " + f"{len(self.split_idx['valid'])} val, " + f"{len(self.split_idx['test'])} test" + ) + + def _open_single_lmdb(self, lmdb_path: Path) -> tuple: + """Open a single LMDB file and return (env, size). + + Parameters + ---------- + lmdb_path : Path + Path to LMDB file. + + Returns + ------- + tuple + (environment, size) tuple. + """ + env = lmdb.open( + str(lmdb_path.resolve()), + subdir=False, + readonly=True, + lock=False, + readahead=True, + meminit=False, + max_readers=1, + ) + size = env.stat()["entries"] + return env, size + + def _find_lmdb_and_local_idx(self, idx: int) -> tuple: + """Find which LMDB contains the given index and the local index within it. + + Parameters + ---------- + idx : int + Global dataset index. + + Returns + ------- + tuple + (lmdb_index, local_index) tuple. 
+ """ + if idx < 0 or idx >= len(self): + raise IndexError(f"Index {idx} out of range [0, {len(self)})") + + # Binary search for the LMDB + left, right = 0, len(self.envs) + while left < right - 1: + mid = (left + right) // 2 + if self.cumulative_sizes[mid] <= idx: + left = mid + else: + right = mid + + lmdb_idx = left + local_idx = idx - self.cumulative_sizes[lmdb_idx] + return lmdb_idx, local_idx + + def len(self) -> int: + """Get dataset length.""" + return self.cumulative_sizes[-1] + + def get(self, idx: int) -> Data: + """Get data sample at index. + + Parameters + ---------- + idx : int + Sample index. + + Returns + ------- + Data + PyTorch Geometric Data object. + """ + lmdb_idx, local_idx = self._find_lmdb_and_local_idx(idx) + lmdb_path, env, _ = self.envs[lmdb_idx] + + with env.begin() as txn: + cursor = txn.cursor() + if not cursor.first(): + raise RuntimeError(f"Empty LMDB at {lmdb_path}") + + # Navigate to the target entry + for _ in range(local_idx): + if not cursor.next(): + raise RuntimeError( + f"Index {local_idx} out of range in {lmdb_path}" + ) + + key, value = cursor.item() + data = pickle.loads(value) + + # Convert old PyG Data objects to new format by extracting all attributes + if isinstance(data, Data): + try: + # Check if this is old format data + # Old PyG format has attributes directly in __dict__ without proper _store + if "_store" not in data.__dict__ or any( + k in data.__dict__ for k in ["x", "edge_index", "pos"] + ): + # Extract all data attributes + data_dict = {} + # Get all tensor/data attributes from __dict__ + data_dict = { + key: val + for key, val in data.__dict__.items() + if not key.startswith("_") and val is not None + } + # Convert y_relaxed to y before creating new Data object + if "y_relaxed" in data_dict: + data_dict["y"] = torch.tensor( + [data_dict["y_relaxed"]] + ).float() + elif "y" not in data_dict: + data_dict["y"] = torch.tensor([float("nan")]).float() + + # Use atomic numbers as node features (x) + if 
"atomic_numbers" in data_dict: + data_dict["x"] = ( + data_dict["atomic_numbers"].view(-1, 1).float() + ) + elif "x" not in data_dict and "pos" in data_dict: + data_dict["x"] = torch.ones( + (data_dict["pos"].shape[0], 1) + ) + + # Create edge_index from atomic positions using radius graph + if "edge_index" not in data_dict and "pos" in data_dict: + from torch_geometric.nn import radius_graph + + data_dict["edge_index"] = radius_graph( + data_dict["pos"], r=5.0, max_num_neighbors=50 + ) + + # Keep only standard PyG attributes + standard_attrs = [ + "x", + "edge_index", + "edge_attr", + "pos", + "y", + "batch", + ] + cleaned_dict = { + k: v + for k, v in data_dict.items() + if k in standard_attrs + } + + # Create a completely new Data object with current PyG format + data = Data(**cleaned_dict) + except (AttributeError, KeyError, RuntimeError, TypeError): + # If extraction fails, pass through + pass + + # Convert to legacy format if needed + if self.legacy_format and isinstance(data, Data): + data = Data( + **{k: v for k, v in data.__dict__.items() if v is not None} + ) + + return data + + def __len__(self) -> int: + """Get dataset length.""" + return self.len() + + def __getitem__(self, idx: int) -> Data: + """Get item at index.""" + return self.get(idx) + + def __iter__(self) -> Iterator[Data]: + """Iterate over dataset.""" + for i in range(len(self)): + yield self[i] + + def __del__(self): + """Clean up LMDB environments.""" + if hasattr(self, "envs"): + for _, env, _ in self.envs: + env.close() + + @property + def num_node_features(self) -> int: + """Number of node features per atom.""" + # Will be determined by the actual data + return 1 # Atomic numbers + + @property + def num_classes(self) -> int: + """Number of classes (regression task).""" + return 1 # Single regression target (energy) diff --git a/topobench/data/datasets/oc22_is2re_dataset.py b/topobench/data/datasets/oc22_is2re_dataset.py new file mode 100644 index 000000000..6fe7d3955 --- /dev/null +++ 
b/topobench/data/datasets/oc22_is2re_dataset.py @@ -0,0 +1,411 @@ +"""Dataset class for Open Catalyst 2022 (OC22) IS2RE dataset.""" + +from __future__ import annotations + +import logging +import pickle +from collections.abc import Iterator +from pathlib import Path + +import torch +from omegaconf import DictConfig +from torch_geometric.data import Data, Dataset + +try: + import lmdb + + HAS_LMDB = True +except ImportError: + lmdb = None + HAS_LMDB = False + +logger = logging.getLogger(__name__) + + +class OC22IS2REDataset(Dataset): + """Dataset class for Open Catalyst 2022 (OC22) IS2RE task. + + IS2RE (Initial Structure to Relaxed Energy) is a task for catalyst + discovery and materials science. + + The OC22 dataset contains DFT calculations for catalyst-adsorbate systems, + enabling machine learning models to predict energies for accelerated + materials discovery. + + Parameters + ---------- + root : str + Root directory where the dataset is stored. + name : str + Name of the dataset. + parameters : DictConfig + Configuration parameters for the dataset. 
+ """ + + def __init__( + self, + root: str, + name: str, + parameters: DictConfig, + ) -> None: + self.name = name + self.parameters = parameters + + # Task configuration + self.task = "oc22_is2re" + self.dtype = self._parse_dtype(parameters.get("dtype", "float32")) + self.legacy_format = parameters.get("legacy_format", False) + self.include_test = parameters.get("include_test", True) + + # Limit for fast experimentation + self.max_samples = parameters.get("max_samples", None) + if self.max_samples is not None: + self.max_samples = int(self.max_samples) + logger.info( + f"⚠️ Limiting dataset to {self.max_samples} samples for fast experimentation" + ) + + super().__init__(root) + + # Open LMDB environments + self._open_lmdbs() + + def __repr__(self) -> str: + return f"{self.name}(root={self.root}, task={self.task}, size={len(self)})" + + @staticmethod + def _parse_dtype(dtype) -> torch.dtype: + """Parse dtype parameter to torch.dtype.""" + if isinstance(dtype, str): + return getattr(torch, dtype) + return dtype + + def _get_data_paths(self) -> dict[str, list[Path]]: + """Get paths to LMDB files for each split. + + Returns + ------- + dict[str, list[Path]] + Dictionary mapping split names to lists of LMDB file paths. 
+ """ + root = Path(self.root) + paths = {"train": [], "val": [], "test": []} + + # The downloaded data is extracted to: + # root/is2res_total_train_val_test_lmdbs/data/oc22/is2re-total/{train,val_id,test_id}/*.lmdb + base_path = ( + root + / "is2res_total_train_val_test_lmdbs" + / "data" + / "oc22" + / "is2re-total" + ) + + if base_path.exists(): + # Train data - multiple LMDB files + train_dir = base_path / "train" + if train_dir.exists(): + paths["train"] = sorted(train_dir.glob("*.lmdb")) + + # Validation data - using val_id split + val_dir = base_path / "val_id" + if val_dir.exists(): + paths["val"] = sorted(val_dir.glob("*.lmdb")) + + # Test data - using test_id split + test_dir = base_path / "test_id" + if test_dir.exists(): + paths["test"] = sorted(test_dir.glob("*.lmdb")) + + return paths + + def _open_lmdbs(self): + """Open LMDB files and create split mappings.""" + if not HAS_LMDB: + raise ImportError( + "LMDB is required for OC22 IS2RE dataset. Install with: pip install lmdb" + ) + + paths = self._get_data_paths() + + # Initialize storage + self.envs = [] + self.cumulative_sizes = [0] + self.split_idx = {"train": [], "valid": [], "test": []} + + current_idx = 0 + + # Open train LMDBs with cumulative max_samples limiting (per split) + train_samples_remaining = ( + self.max_samples if self.max_samples is not None else None + ) + for lmdb_path in paths["train"]: + if ( + train_samples_remaining is not None + and train_samples_remaining <= 0 + ): + break + env, size = self._open_single_lmdb(lmdb_path) + # Apply remaining limit + if train_samples_remaining is not None: + size = min(size, train_samples_remaining) + train_samples_remaining -= size + self.envs.append((lmdb_path, env, size)) + self.cumulative_sizes.append(self.cumulative_sizes[-1] + size) + self.split_idx["train"].extend( + range(current_idx, current_idx + size) + ) + current_idx += size + + # Open validation LMDBs with cumulative max_samples limiting (per split) + val_samples_remaining = ( + 
self.max_samples if self.max_samples is not None else None + ) + for lmdb_path in paths["val"]: + if ( + val_samples_remaining is not None + and val_samples_remaining <= 0 + ): + break + env, size = self._open_single_lmdb(lmdb_path) + # Apply remaining limit + if val_samples_remaining is not None: + size = min(size, val_samples_remaining) + val_samples_remaining -= size + self.envs.append((lmdb_path, env, size)) + self.cumulative_sizes.append(self.cumulative_sizes[-1] + size) + self.split_idx["valid"].extend( + range(current_idx, current_idx + size) + ) + current_idx += size + + # Open test LMDBs with cumulative max_samples limiting (per split) + test_samples_remaining = ( + self.max_samples if self.max_samples is not None else None + ) + for lmdb_path in paths["test"]: + if ( + test_samples_remaining is not None + and test_samples_remaining <= 0 + ): + break + env, size = self._open_single_lmdb(lmdb_path) + # Apply remaining limit + if test_samples_remaining is not None: + size = min(size, test_samples_remaining) + test_samples_remaining -= size + self.envs.append((lmdb_path, env, size)) + self.cumulative_sizes.append(self.cumulative_sizes[-1] + size) + self.split_idx["test"].extend( + range(current_idx, current_idx + size) + ) + current_idx += size + + # If no test data, reuse validation indices + if not self.include_test or len(self.split_idx["test"]) == 0: + self.split_idx["test"] = list(self.split_idx["valid"]) + + # Convert to tensors + self.split_idx = { + k: torch.tensor(v, dtype=torch.long) + for k, v in self.split_idx.items() + } + + logger.info( + f"Loaded {self.task.upper()} dataset: " + f"{len(self.split_idx['train'])} train, " + f"{len(self.split_idx['valid'])} val, " + f"{len(self.split_idx['test'])} test" + ) + + def _open_single_lmdb(self, lmdb_path: Path) -> tuple: + """Open a single LMDB file and return (env, size). + + Parameters + ---------- + lmdb_path : Path + Path to LMDB file. + + Returns + ------- + tuple + (environment, size) tuple. 
+ """ + env = lmdb.open( + str(lmdb_path.resolve()), + subdir=False, + readonly=True, + lock=False, + readahead=True, + meminit=False, + max_readers=1, + ) + size = env.stat()["entries"] + return env, size + + def _find_lmdb_and_local_idx(self, idx: int) -> tuple: + """Find which LMDB contains the given index and the local index within it. + + Parameters + ---------- + idx : int + Global dataset index. + + Returns + ------- + tuple + (lmdb_index, local_index) tuple. + """ + if idx < 0 or idx >= len(self): + raise IndexError(f"Index {idx} out of range [0, {len(self)})") + + # Binary search for the LMDB + left, right = 0, len(self.envs) + while left < right - 1: + mid = (left + right) // 2 + if self.cumulative_sizes[mid] <= idx: + left = mid + else: + right = mid + + lmdb_idx = left + local_idx = idx - self.cumulative_sizes[lmdb_idx] + return lmdb_idx, local_idx + + def len(self) -> int: + """Get dataset length.""" + return self.cumulative_sizes[-1] + + def get(self, idx: int) -> Data: + """Get data sample at index. + + Parameters + ---------- + idx : int + Sample index. + + Returns + ------- + Data + PyTorch Geometric Data object. 
+ """ + lmdb_idx, local_idx = self._find_lmdb_and_local_idx(idx) + lmdb_path, env, _ = self.envs[lmdb_idx] + + with env.begin() as txn: + cursor = txn.cursor() + if not cursor.first(): + raise RuntimeError(f"Empty LMDB at {lmdb_path}") + + # Navigate to the target entry + for _ in range(local_idx): + if not cursor.next(): + raise RuntimeError( + f"Index {local_idx} out of range in {lmdb_path}" + ) + + key, value = cursor.item() + data = pickle.loads(value) + + # Convert old PyG Data objects to new format by extracting all attributes + if isinstance(data, Data): + try: + # Check if this is old format data + # Old PyG format has attributes directly in __dict__ without proper _store + if "_store" not in data.__dict__ or any( + k in data.__dict__ for k in ["x", "edge_index", "pos"] + ): + # Extract all data attributes + data_dict = {} + # Get all tensor/data attributes from __dict__ + data_dict = { + key: val + for key, val in data.__dict__.items() + if not key.startswith("_") and val is not None + } + + # Convert y_relaxed to y before creating new Data object + if "y_relaxed" in data_dict: + data_dict["y"] = torch.tensor( + [data_dict["y_relaxed"]] + ).float() + elif "y" not in data_dict: + data_dict["y"] = torch.tensor([float("nan")]).float() + + # Use atomic numbers as node features (x) + if "atomic_numbers" in data_dict: + data_dict["x"] = ( + data_dict["atomic_numbers"].view(-1, 1).float() + ) + elif "x" not in data_dict and "pos" in data_dict: + data_dict["x"] = torch.ones( + (data_dict["pos"].shape[0], 1) + ) + + # Create edge_index from atomic positions using radius graph + if "edge_index" not in data_dict and "pos" in data_dict: + from torch_geometric.nn import radius_graph + + data_dict["edge_index"] = radius_graph( + data_dict["pos"], r=5.0, max_num_neighbors=50 + ) + + # Keep only standard PyG attributes + standard_attrs = [ + "x", + "edge_index", + "edge_attr", + "pos", + "y", + "batch", + ] + cleaned_dict = { + k: v + for k, v in data_dict.items() + if k 
in standard_attrs + } + + # Create a completely new Data object with current PyG format + data = Data(**cleaned_dict) + + except (AttributeError, KeyError, RuntimeError, TypeError): + # If extraction fails, pass through + pass + + # Convert to legacy format if needed + if self.legacy_format and isinstance(data, Data): + data = Data( + **{k: v for k, v in data.__dict__.items() if v is not None} + ) + + return data + + def __len__(self) -> int: + """Get dataset length.""" + return self.len() + + def __getitem__(self, idx: int) -> Data: + """Get item at index.""" + return self.get(idx) + + def __iter__(self) -> Iterator[Data]: + """Iterate over dataset.""" + for i in range(len(self)): + yield self[i] + + def __del__(self): + """Clean up LMDB environments.""" + if hasattr(self, "envs"): + for _, env, _ in self.envs: + env.close() + + @property + def num_node_features(self) -> int: + """Number of node features per atom.""" + # Will be determined by the actual data + return 1 # Atomic numbers + + @property + def num_classes(self) -> int: + """Number of classes (regression task).""" + return 1 # Single regression target (energy) diff --git a/topobench/data/loaders/graph/oc20_asedbs2ef_loader.py b/topobench/data/loaders/graph/oc20_asedbs2ef_loader.py new file mode 100644 index 000000000..e343ca4a6 --- /dev/null +++ b/topobench/data/loaders/graph/oc20_asedbs2ef_loader.py @@ -0,0 +1,501 @@ +"""Loader for OC20 S2EF dataset using ASE DB backend.""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import torch +from omegaconf import DictConfig +from torch_geometric.data import Data, Dataset + +from topobench.data.loaders.base import AbstractLoader +from topobench.data.preprocessor.oc20_s2ef_preprocessor import ( + HAS_ASE, + AtomsToGraphs, + needs_preprocessing, + preprocess_s2ef_split_ase, +) + +if HAS_ASE: + import ase.db + +logger = logging.getLogger(__name__) + + +class OC20ASEDBDataset(Dataset): + """ASE DB dataset for OC20 S2EF 
structures.

    Reads structures lazily from one or more ASE database files and converts
    them to PyG graphs on access via an ``AtomsToGraphs`` converter.

    Parameters
    ----------
    db_paths : list[str | Path] | None
        Backwards-compatible single list of DBs without explicit splits.
    train_db_paths : list[str | Path] | None
        List of ASE DB file paths for the training split.
    val_db_paths : list[str | Path] | None
        List of ASE DB file paths for the validation split.
    test_db_paths : list[str | Path] | None
        List of ASE DB file paths for the test split.
    max_neigh : int
        Maximum number of neighbors per atom.
    radius : float
        Cutoff radius in Angstroms.
    dtype : torch.dtype
        Torch dtype used for tensors.
    include_energy : bool
        Whether to include energy information.
    include_forces : bool
        Whether to include forces information.
    max_samples : int | None
        Maximum number of samples to load for debugging or fast runs.
    """

    def __init__(
        self,
        db_paths: list[str | Path] | None = None,
        *,
        train_db_paths: list[str | Path] | None = None,
        val_db_paths: list[str | Path] | None = None,
        test_db_paths: list[str | Path] | None = None,
        max_neigh: int = 50,
        radius: float = 6.0,
        dtype: torch.dtype = torch.float32,
        include_energy: bool = True,
        include_forces: bool = True,
        max_samples: int | None = None,
    ):
        """Initialize dataset from ASE DB files.

        See class docstring for parameter descriptions.
        """
        if not HAS_ASE:
            raise ImportError("ASE required for S2EF datasets")

        super().__init__()
        self.dtype = dtype

        # Converter (atoms -> PyG Data); edges/distances/fixed flags always on here
        self.converter = AtomsToGraphs(
            max_neigh=max_neigh,
            radius=radius,
            r_energy=include_energy,
            r_forces=include_forces,
            r_distances=True,
            r_edges=True,
            r_fixed=True,
        )

        # Normalize input options: the legacy `db_paths` form and the
        # split-specific form are mutually exclusive
        if db_paths is not None and any(
            x is not None
            for x in (train_db_paths, val_db_paths, test_db_paths)
        ):
            raise ValueError(
                "Provide either `db_paths` or the split-specific lists, not both."
            )

        if db_paths is not None:
            # Legacy form: everything is treated as training data
            train_db_paths = list(db_paths)
            val_db_paths = []
            test_db_paths = []
        else:
            train_db_paths = list(train_db_paths or [])
            val_db_paths = list(val_db_paths or [])
            test_db_paths = list(test_db_paths or [])

        # Track DB files per split
        self._per_split_db_paths: dict[str, list[Path]] = {
            "train": [Path(p) for p in train_db_paths],
            "valid": [Path(p) for p in val_db_paths],
            "test": [Path(p) for p in test_db_paths],
        }
        self.db_paths: list[Path] = (
            self._per_split_db_paths["train"]
            + self._per_split_db_paths["valid"]
            + self._per_split_db_paths["test"]
        )

        # Count total structures and build split indices.
        # Global indices run train -> valid -> test, contiguous per DB file.
        self._num_samples = 0
        self._db_ranges: list[
            tuple[Path, int, int]
        ] = []  # (db_path, start, end)
        self.split_idx: dict[str, list[int]] = {
            "train": [],
            "valid": [],
            "test": [],
        }

        for split_name in ("train", "valid", "test"):
            for db_path in self._per_split_db_paths[split_name]:
                with ase.db.connect(str(db_path)) as db:
                    count = db.count()
                start = self._num_samples
                end = start + count
                self._db_ranges.append((db_path, start, end))
                # Append global indices for this DB to the right split
                self.split_idx[split_name].extend(range(start, end))
                self._num_samples = end

        # Apply max_samples limit if specified (per split, not total)
        if max_samples is not None:
            logger.info(f"Limiting each split to {max_samples} samples")
            # When limiting, we need to:
            # 1. Truncate the split_idx lists
            # 2. Create a mapping from limited indices to new contiguous indices
            # 3. Update _num_samples to reflect the new total
            # This ensures len(dataset) returns the correct value and prevents
            # unnecessary iteration over the full dataset during preprocessing

            # Collect indices from all splits that we want to keep
            all_limited_indices = []
            for split_name in ("train", "valid", "test"):
                if self.split_idx[split_name]:
                    original_len = len(self.split_idx[split_name])
                    new_len = min(max_samples, original_len)
                    self.split_idx[split_name] = self.split_idx[split_name][
                        :new_len
                    ]
                    all_limited_indices.extend(self.split_idx[split_name])
                    logger.info(
                        f" {split_name}: {original_len} -> {new_len} samples"
                    )

            # Create mapping from old indices to new contiguous indices
            old_to_new_idx = {
                old_idx: new_idx
                for new_idx, old_idx in enumerate(
                    sorted(set(all_limited_indices))
                )
            }

            # Remap split_idx to use new contiguous indices
            for split_name in ("train", "valid", "test"):
                self.split_idx[split_name] = [
                    old_to_new_idx[idx] for idx in self.split_idx[split_name]
                ]

            # Store mapping for __getitem__ to translate back to original indices
            self._index_mapping = sorted(set(all_limited_indices))

            # Update _num_samples to the actual limited size
            self._num_samples = len(self._index_mapping)
            logger.info(
                f"Dataset length limited to {self._num_samples} samples"
            )

        logger.info(
            f"Loaded {len(self.db_paths)} DB files with {self._num_samples} total structures"
        )

    def __len__(self) -> int:
        """Return dataset length.

        Returns
        -------
        int
            Number of samples in the dataset.
        """
        return self._num_samples

    @property
    def num_node_features(self) -> int:
        """Number of node features per atom.

        Returns
        -------
        int
            Number of node features (atomic numbers by default).
        """
        return 1  # Atomic numbers

    @property
    def data(self):
        """Get combined data view for compatibility with InMemoryDataset API.

        Returns a Data object whose x and y attributes come from the first
        sample of the dataset (or zero placeholders when the dataset is empty);
        it is a lightweight compatibility shim, not an aggregation of all samples.

        Returns
        -------
        Data
            A Data object with x and y attributes for API compatibility.
        """
        if not hasattr(self, "_data_cache"):
            # Get a sample to determine feature dimensions
            if len(self) > 0:
                sample = self[0]
                # Create a mock data object with minimal info for compatibility
                self._data_cache = Data(
                    x=sample.x if hasattr(sample, "x") else torch.zeros(1, 1),
                    y=sample.y if hasattr(sample, "y") else torch.zeros(1),
                )
            else:
                self._data_cache = Data(x=torch.zeros(1, 1), y=torch.zeros(1))
        return self._data_cache

    def _get_db_and_idx(self, idx: int) -> tuple[Path, int]:
        """Get DB path and local index for global index.

        Parameters
        ----------
        idx : int
            Global index.

        Returns
        -------
        tuple[Path, int]
            Database path and local index within that database.
        """
        if idx < 0 or idx >= self._num_samples:
            raise IndexError(
                f"Index {idx} out of range [0, {self._num_samples})"
            )
        # Binary search could be used; linear scan is fine for moderate DB counts
        for db_path, start, end in self._db_ranges:
            if start <= idx < end:
                local_idx = (idx - start) + 1  # ASE rows are 1-indexed
                return db_path, local_idx
        raise IndexError(f"Index {idx} not found in DB ranges")

    def __getitem__(self, idx: int) -> Data:
        """Get a single graph by index.

        Parameters
        ----------
        idx : int
            Index of the graph to retrieve.

        Returns
        -------
        Data
            PyTorch Geometric Data object.
        """
        # If we have an index mapping (from max_samples limiting), translate the index
        if hasattr(self, "_index_mapping"):
            if idx < 0 or idx >= len(self._index_mapping):
                raise IndexError(
                    f"Index {idx} out of range [0, {len(self._index_mapping)})"
                )
            idx = self._index_mapping[idx]

        db_path, local_idx = self._get_db_and_idx(idx)

        with ase.db.connect(str(db_path)) as db:
            row = db.get(id=local_idx)
            atoms = row.toatoms()

            # Add metadata
            if hasattr(row, "data") and row.data:
                atoms.info["data"] = row.data

        # Convert to PyG
        data = self.converter.convert(atoms)

        # Cast dtype
        if getattr(data, "pos", None) is not None:
            data.pos = data.pos.to(self.dtype)
        if getattr(data, "edge_attr", None) is not None:
            data.edge_attr = data.edge_attr.to(self.dtype)

        return data


class OC20S2EFDatasetLoader(AbstractLoader):
    """Loader for OC20 S2EF dataset using ASE DB.

    Parameters
    ----------
    parameters : DictConfig
        Configuration dictionary (usually hydra DictConfig) with dataset options.
    """

    def __init__(self, parameters: DictConfig) -> None:
        super().__init__(parameters)

    def load_dataset(self) -> Dataset:
        """Load OC20 S2EF dataset from ASE database files.

        Configuration parameters:
        - train_split: "200K", "2M", "20M", or "all"
        - val_splits: list like ["val_id", "val_ood_ads"] or None for all
        - include_test: bool
        - download: bool (download raw data)
        - max_neigh: int (default 50)
        - radius: float (default 6.0 Angstroms)
        - dtype: str (default "float32")

        Returns
        -------
        OC20ASEDBDataset
            Dataset with train/val/test splits.
+ """ + if not HAS_ASE: + raise ImportError("ASE required for OC20 S2EF dataset") + + data_dir = Path(self.get_data_dir()) + train_split = getattr(self.parameters, "train_split", "200K") + val_splits_param = getattr(self.parameters, "val_splits", None) + if val_splits_param is None: + # default: use all 4 validation splits + val_splits = [ + "val_id", + "val_ood_ads", + "val_ood_cat", + "val_ood_both", + ] + elif isinstance(val_splits_param, (list, tuple)): + val_splits = list(val_splits_param) + else: + val_splits = [str(val_splits_param)] + + include_test = bool(getattr(self.parameters, "include_test", False)) + download = bool(getattr(self.parameters, "download", False)) + max_neigh = int(getattr(self.parameters, "max_neigh", 50)) + radius = float(getattr(self.parameters, "radius", 6.0)) + dtype_str = str(getattr(self.parameters, "dtype", "float32")) + dtype = getattr(torch, dtype_str) + + # Download if needed (raw extxyz/txt files) - not implemented + if download: + logger.warning( + f"S2EF download not implemented. Please download manually to {data_dir}/s2ef/{train_split}/train" + ) + + # Preprocess to ASE DB if needed for train/val/test + self._ensure_asedb_preprocessed( + data_dir, train_split, val_splits, include_test, max_neigh, radius + ) + + # Collect DB files + train_db_files = self._collect_db_files( + data_dir / "s2ef" / train_split / "train" + ) + val_db_files: list[Path] = [] + for vs in val_splits: + val_dir = data_dir / "s2ef" / "all" / vs + val_db_files.extend(self._collect_db_files(val_dir)) + test_db_files: list[Path] = [] + if include_test: + test_dir = data_dir / "s2ef" / "all" / "test" + test_db_files = self._collect_db_files(test_dir) + + if not train_db_files: + raise RuntimeError( + f"No ASE DB files found in {data_dir}/s2ef/{train_split}/train. Preprocessing may have failed." 
+ ) + + logger.info( + f"Loading {len(train_db_files)} train DBs, {len(val_db_files)} val DBs, {len(test_db_files)} test DBs" + ) + + ds = OC20ASEDBDataset( + train_db_paths=[str(p) for p in train_db_files], + val_db_paths=[str(p) for p in val_db_files], + test_db_paths=[str(p) for p in test_db_files], + max_neigh=max_neigh, + radius=radius, + dtype=dtype, + include_energy=True, + include_forces=True, + ) + + # Expose split_idx for fixed split handling downstream + # Already set inside the dataset; nothing else to do here. + return ds + + def _ensure_asedb_preprocessed( + self, + root: Path, + train_split: str, + val_splits: list[str], + include_test: bool, + max_neigh: int, + radius: float, + ) -> None: + """Ensure ASE DB files are preprocessed for the requested splits. + + Parameters + ---------- + root : Path + Root data directory containing the S2EF dataset. + train_split : str + Name of the training split (e.g. "200K"). + val_splits : list[str] + List of validation split names. + include_test : bool + Whether to ensure preprocessing for the test split. + max_neigh : int + Maximum number of neighbors per atom. + radius : float + Cutoff radius for neighbor search in Angstroms. + + Returns + ------- + None + Performs preprocessing as a side-effect; no value is returned. 
+ """ + # Train + train_dir = root / "s2ef" / train_split / "train" + if train_dir.exists() and needs_preprocessing(train_dir): + logger.info(f"Preprocessing {train_dir}") + preprocess_s2ef_split_ase( + data_path=train_dir, + out_path=train_dir, + num_workers=4, + max_neigh=max_neigh, + radius=radius, + ) + + # Validation + for val_split in val_splits: + val_dir = root / "s2ef" / "all" / val_split + if val_dir.exists() and needs_preprocessing(val_dir): + logger.info(f"Preprocessing {val_dir}") + preprocess_s2ef_split_ase( + data_path=val_dir, + out_path=val_dir, + num_workers=4, + max_neigh=max_neigh, + radius=radius, + ) + + # Test + if include_test: + test_dir = root / "s2ef" / "all" / "test" + if test_dir.exists() and needs_preprocessing(test_dir): + logger.info(f"Preprocessing {test_dir}") + preprocess_s2ef_split_ase( + data_path=test_dir, + out_path=test_dir, + num_workers=4, + max_neigh=max_neigh, + radius=radius, + ) + + def _collect_db_files(self, directory: Path) -> list[Path]: + """Collect all ASE DB files in directory. + + Parameters + ---------- + directory : Path + Directory to search. + + Returns + ------- + list[Path] + Sorted list of DB file paths. + """ + if not directory.exists(): + return [] + return sorted(directory.glob("*.db")) + + def get_data_dir(self) -> Path: + """Get data directory path. + + Returns + ------- + Path + Path to data directory. 
+ """ + return Path(super().get_data_dir()) diff --git a/topobench/data/loaders/graph/oc20_dataset_loader.py b/topobench/data/loaders/graph/oc20_dataset_loader.py new file mode 100644 index 000000000..7a79cc9f9 --- /dev/null +++ b/topobench/data/loaders/graph/oc20_dataset_loader.py @@ -0,0 +1,365 @@ +"""Loader for OC20 S2EF dataset.""" + +import logging +from pathlib import Path + +import torch +from omegaconf import DictConfig +from torch_geometric.data import Dataset + +from topobench.data.datasets.oc20_dataset import OC20Dataset +from topobench.data.loaders.base import AbstractLoader +from topobench.data.utils.oc20_download import ( + download_s2ef_dataset, +) + +# Import ASE DB fallback +try: + from topobench.data.loaders.graph.oc20_asedbs2ef_loader import ( + OC20ASEDBDataset, + ) + + HAS_ASEDB = True +except ImportError: + HAS_ASEDB = False + +logger = logging.getLogger(__name__) + + +class OC20DatasetLoader(AbstractLoader): + """Load OC20 S2EF dataset for catalyst discovery and materials science. + + Supports S2EF (Structure to Energy and Forces) to predict energy/forces + from atomic structure. + + Parameters + ---------- + parameters : DictConfig + Configuration parameters containing: + - data_dir: Root directory for data + - data_name: Name of the dataset + - download: Whether to download if not present (default: False) + - train_split: Training split size ("200K", "2M", "20M", "all") + - val_splits: List of validation splits or None for all + - include_test: Whether to download test split (default: True) + - dtype: Data type for tensors (default: "float32") + - legacy_format: Use legacy PyG Data format (default: False) + - max_samples: Limit dataset size for testing (default: None) + """ + + def __init__(self, parameters: DictConfig) -> None: + super().__init__(parameters) + + def load_dataset(self) -> Dataset: + """Load the OC20 dataset. + + Returns + ------- + Dataset + The loaded OC20 dataset with the appropriate configuration. 

        Raises
        ------
        RuntimeError
            If dataset loading fails.
        """
        # Download if requested
        if self.parameters.get("download", False):
            self._download_dataset()

        # Check if we have LMDB files or need ASE DB fallback
        data_root = Path(self.get_data_dir())

        # Try LMDB first
        lmdb_present = any(data_root.glob("**/*.lmdb"))

        if not lmdb_present and HAS_ASEDB:
            # Fallback to ASE DB dataset
            logger.info("No LMDB files found, using ASE DB backend")
            return self._load_asedb_dataset(data_root)

        # Initialize LMDB dataset
        dataset = self._initialize_dataset()
        self.data_dir = self._redefine_data_dir(dataset)
        return dataset

    def _download_dataset(self):
        """Download the S2EF dataset."""
        root = Path(self.get_data_dir())
        train_split = self.parameters.get("train_split", "200K")
        val_splits = self.parameters.get("val_splits", None)
        include_test = self.parameters.get("include_test", True)

        # Parse val_splits: a bare string means a single split
        if val_splits is not None and isinstance(val_splits, str):
            val_splits = [val_splits]

        download_s2ef_dataset(
            root=root,
            train_split=train_split,
            val_splits=val_splits,
            include_test=include_test,
        )

    def _initialize_dataset(self) -> OC20Dataset:
        """Initialize the OC20 dataset.

        Returns
        -------
        OC20Dataset
            The initialized OC20 dataset.

        Raises
        ------
        RuntimeError
            If dataset initialization fails.
        """
        try:
            dataset = OC20Dataset(
                root=str(self.get_data_dir()),
                name=self.parameters.data_name,
                parameters=self.parameters,
            )
            return dataset
        except Exception as e:
            msg = f"Error initializing OC20 dataset: {e}"
            raise RuntimeError(msg) from e

    def _load_asedb_dataset(self, data_root: Path) -> OC20ASEDBDataset:
        """Load dataset using ASE DB backend (fallback when no LMDBs).

        Parameters
        ----------
        data_root : Path
            Root directory for data.

        Returns
        -------
        OC20ASEDBDataset
            Dataset using ASE DB backend.
        """
        train_split = self.parameters.get("train_split", "200K")
        val_splits = self.parameters.get("val_splits", None)
        include_test = self.parameters.get("include_test", True)

        # Parse val_splits
        if val_splits is None:
            val_splits = [
                "val_id",
                "val_ood_ads",
                "val_ood_cat",
                "val_ood_both",
            ]
        elif isinstance(val_splits, str):
            val_splits = [val_splits]

        # Import preprocessing utilities

        # Get preprocessing parameters
        max_neigh = int(self.parameters.get("max_neigh", 50))
        radius = float(self.parameters.get("radius", 6.0))

        # Ensure preprocessing is done for all required splits
        self._ensure_asedb_preprocessed(
            data_root, train_split, val_splits, include_test, max_neigh, radius
        )

        # Collect DB files
        # The data_root might already include the dataset name (e.g., datasets/graph/oc20/OC20_S2EF_200K)
        # or just the base (e.g., datasets/graph/oc20)
        # Try both patterns
        train_subdir_name = f"s2ef_train_{train_split}"

        # Pattern 1: data_root already includes dataset name
        train_dir_pattern1 = (
            data_root
            / "s2ef"
            / train_split
            / train_subdir_name
            / train_subdir_name
        )
        # Pattern 2: data_root is just base, dataset name in separate dir
        if not train_dir_pattern1.exists():
            # Try finding s2ef directory anywhere under data_root
            s2ef_roots = list(data_root.glob("**/s2ef"))
            if s2ef_roots:
                s2ef_root = s2ef_roots[0].parent
                train_dir_pattern1 = (
                    s2ef_root
                    / "s2ef"
                    / train_split
                    / train_subdir_name
                    / train_subdir_name
                )

        train_dbs = (
            sorted(train_dir_pattern1.glob("*.db"))
            if train_dir_pattern1.exists()
            else []
        )

        val_dbs = []
        for vs in val_splits:
            val_subdir_name = f"s2ef_{vs}"
            val_dir = (
                data_root / "s2ef" / "all" / val_subdir_name / val_subdir_name
            )
            # NOTE(review): `s2ef_roots` only exists when Pattern 1 above
            # failed for the train dir; the locals() check guards that case.
            if (
                not val_dir.exists()
                and "s2ef_roots" in locals()
                and s2ef_roots
            ):
                val_dir = (
                    s2ef_roots[0].parent
                    / "s2ef"
                    / "all"
                    / val_subdir_name
                    / val_subdir_name
                )
            if val_dir.exists():
                val_dbs.extend(sorted(val_dir.glob("*.db")))

        test_dbs = []
        if include_test:
            test_dir = data_root / "s2ef" / "all" / "s2ef_test" / "s2ef_test"
            if (
                not test_dir.exists()
                and "s2ef_roots" in locals()
                and s2ef_roots
            ):
                test_dir = (
                    s2ef_roots[0].parent
                    / "s2ef"
                    / "all"
                    / "s2ef_test"
                    / "s2ef_test"
                )
            if test_dir.exists():
                test_dbs = sorted(test_dir.glob("*.db"))

        logger.info(
            f"Using ASE DB backend: {len(train_dbs)} train, "
            f"{len(val_dbs)} val, {len(test_dbs)} test DB files"
        )

        # Parse dtype
        dtype = self.parameters.get("dtype", "float32")
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)

        max_samples = self.parameters.get("max_samples", None)
        if max_samples is not None:
            max_samples = int(max_samples)

        return OC20ASEDBDataset(
            train_db_paths=[str(p) for p in train_dbs],
            val_db_paths=[str(p) for p in val_dbs],
            test_db_paths=[str(p) for p in test_dbs],
            max_neigh=int(self.parameters.get("max_neigh", 50)),
            radius=float(self.parameters.get("radius", 6.0)),
            dtype=dtype,
            include_energy=True,
            include_forces=True,
            max_samples=max_samples,
        )

    def _ensure_asedb_preprocessed(
        self,
        root: Path,
        train_split: str,
        val_splits: list[str],
        include_test: bool,
        max_neigh: int,
        radius: float,
    ) -> None:
        """Ensure ASE DB files are preprocessed for the requested splits.

        Parameters
        ----------
        root : Path
            Root data directory containing the S2EF dataset.
        train_split : str
            Name of the training split (e.g. "200K").
        val_splits : list[str]
            List of validation split names.
        include_test : bool
            Whether to ensure preprocessing for the test split.
        max_neigh : int
            Maximum number of neighbors per atom.
        radius : float
            Cutoff radius for neighbor search in Angstroms.

        Returns
        -------
        None
            Performs preprocessing as a side-effect; no value is returned.
        """
        # Imported locally to keep the preprocessor an optional dependency
        from topobench.data.preprocessor.oc20_s2ef_preprocessor import (
            needs_preprocessing,
            preprocess_s2ef_split_ase,
        )

        # Find s2ef root directory
        s2ef_roots = list(root.glob("**/s2ef"))
        if not s2ef_roots:
            logger.warning(f"No s2ef directory found under {root}")
            return

        s2ef_root = s2ef_roots[0].parent / "s2ef"

        # Train directory
        train_subdir_name = f"s2ef_train_{train_split}"
        train_dir = (
            s2ef_root / train_split / train_subdir_name / train_subdir_name
        )
        if train_dir.exists() and needs_preprocessing(train_dir):
            logger.info(f"Preprocessing {train_dir}")
            preprocess_s2ef_split_ase(
                data_path=train_dir,
                out_path=train_dir,
                num_workers=4,
                max_neigh=max_neigh,
                radius=radius,
            )

        # Validation directories
        for val_split in val_splits:
            val_subdir_name = f"s2ef_{val_split}"
            val_dir = s2ef_root / "all" / val_subdir_name / val_subdir_name
            if val_dir.exists() and needs_preprocessing(val_dir):
                logger.info(f"Preprocessing {val_dir}")
                preprocess_s2ef_split_ase(
                    data_path=val_dir,
                    out_path=val_dir,
                    num_workers=4,
                    max_neigh=max_neigh,
                    radius=radius,
                )

        # Test directory
        if include_test:
            test_dir = s2ef_root / "all" / "s2ef_test" / "s2ef_test"
            if test_dir.exists() and needs_preprocessing(test_dir):
                logger.info(f"Preprocessing {test_dir}")
                preprocess_s2ef_split_ase(
                    data_path=test_dir,
                    out_path=test_dir,
                    num_workers=4,
                    max_neigh=max_neigh,
                    radius=radius,
                )

    def _redefine_data_dir(self, dataset: Dataset) -> Path:
        """Redefine the data directory based on dataset configuration.

        Parameters
        ----------
        dataset : Dataset
            The OC20 dataset instance.

        Returns
        -------
        Path
            The redefined data directory path.
+ """ + return self.get_data_dir() diff --git a/topobench/data/loaders/graph/oc20_is2re_dataset_loader.py b/topobench/data/loaders/graph/oc20_is2re_dataset_loader.py new file mode 100644 index 000000000..bd10e2fce --- /dev/null +++ b/topobench/data/loaders/graph/oc20_is2re_dataset_loader.py @@ -0,0 +1,98 @@ +"""Loader for OC20 IS2RE dataset.""" + +import logging +from pathlib import Path + +from omegaconf import DictConfig +from torch_geometric.data import Dataset + +from topobench.data.datasets.oc20_is2re_dataset import IS2REDataset +from topobench.data.loaders.base import AbstractLoader +from topobench.data.utils.oc20_download import download_is2re_dataset + +logger = logging.getLogger(__name__) + + +class IS2REDatasetLoader(AbstractLoader): + """Load OC20 IS2RE dataset. + + Parameters + ---------- + parameters : DictConfig + Configuration parameters containing: + - data_dir: Root directory for data + - data_name: Name of the dataset + - download: Whether to download if not present (default: False) + - dtype: Data type for tensors (default: "float32") + - legacy_format: Use legacy PyG Data format (default: False) + - max_samples: Limit dataset size for testing (default: None) + """ + + def __init__(self, parameters: DictConfig) -> None: + super().__init__(parameters) + + def load_dataset(self) -> Dataset: + """Load the IS2RE dataset. + + Returns + ------- + Dataset + The loaded IS2RE dataset with the appropriate configuration. + + Raises + ------ + RuntimeError + If dataset loading fails. + """ + # Download if requested + if self.parameters.get("download", False): + self._download_dataset() + + # Initialize LMDB dataset + dataset = self._initialize_dataset() + self.data_dir = self._redefine_data_dir(dataset) + return dataset + + def _download_dataset(self): + """Download the IS2RE dataset.""" + root = Path(self.get_data_dir()) + download_is2re_dataset(root=root, task="is2re") + + def _initialize_dataset(self) -> IS2REDataset: + """Initialize the IS2RE dataset. 
+ + Returns + ------- + IS2REDataset + The initialized IS2RE dataset. + + Raises + ------ + RuntimeError + If dataset initialization fails. + """ + try: + dataset = IS2REDataset( + root=str(self.get_data_dir()), + name=self.parameters.data_name, + parameters=self.parameters, + ) + return dataset + except Exception as e: + msg = f"Error initializing IS2RE dataset: {e}" + raise RuntimeError(msg) from e + + def _redefine_data_dir(self, dataset: Dataset) -> Path: + """Redefine the data directory based on dataset configuration. + + Parameters + ---------- + dataset : Dataset + The IS2RE dataset instance. + + Returns + ------- + Path + The redefined data directory path. + """ + return self.get_data_dir() diff --git a/topobench/data/loaders/graph/oc22_is2re_dataset_loader.py b/topobench/data/loaders/graph/oc22_is2re_dataset_loader.py new file mode 100644 index 000000000..90e0646db --- /dev/null +++ b/topobench/data/loaders/graph/oc22_is2re_dataset_loader.py @@ -0,0 +1,98 @@ +"""Loader for OC22 IS2RE dataset.""" + +import logging +from pathlib import Path + +from omegaconf import DictConfig +from torch_geometric.data import Dataset + +from topobench.data.datasets.oc22_is2re_dataset import OC22IS2REDataset +from topobench.data.loaders.base import AbstractLoader +from topobench.data.utils.oc20_download import download_is2re_dataset + +logger = logging.getLogger(__name__) + + +class OC22IS2REDatasetLoader(AbstractLoader): + """Load OC22 IS2RE dataset. 

    Parameters
    ----------
    parameters : DictConfig
        Configuration parameters containing:
        - data_dir: Root directory for data
        - data_name: Name of the dataset
        - download: Whether to download if not present (default: False)
        - dtype: Data type for tensors (default: "float32")
        - legacy_format: Use legacy PyG Data format (default: False)
        - max_samples: Limit dataset size for testing (default: None)
    """

    def __init__(self, parameters: DictConfig) -> None:
        super().__init__(parameters)

    def load_dataset(self) -> Dataset:
        """Load the OC22 IS2RE dataset.

        Returns
        -------
        Dataset
            The loaded OC22 IS2RE dataset with the appropriate configuration.

        Raises
        ------
        RuntimeError
            If dataset loading fails.
        """
        # Download if requested
        if self.parameters.get("download", False):
            self._download_dataset()

        # Initialize LMDB dataset
        dataset = self._initialize_dataset()
        self.data_dir = self._redefine_data_dir(dataset)
        return dataset

    def _download_dataset(self):
        """Download the OC22 IS2RE dataset."""
        root = Path(self.get_data_dir())
        download_is2re_dataset(root=root, task="oc22_is2re")

    def _initialize_dataset(self) -> OC22IS2REDataset:
        """Initialize the OC22 IS2RE dataset.

        Returns
        -------
        OC22IS2REDataset
            The initialized OC22 IS2RE dataset.

        Raises
        ------
        RuntimeError
            If dataset initialization fails.
        """
        try:
            dataset = OC22IS2REDataset(
                root=str(self.get_data_dir()),
                name=self.parameters.data_name,
                parameters=self.parameters,
            )
            return dataset
        except Exception as e:
            # Wrap any construction failure so callers get one exception type
            msg = f"Error initializing OC22 IS2RE dataset: {e}"
            raise RuntimeError(msg) from e

    def _redefine_data_dir(self, dataset: Dataset) -> Path:
        """Redefine the data directory based on dataset configuration.

        Parameters
        ----------
        dataset : Dataset
            The OC22 IS2RE dataset instance.

        Returns
        -------
        Path
            The redefined data directory path.
+ """ + return self.get_data_dir() diff --git a/topobench/data/preprocessor/oc20_s2ef_preprocessor.py b/topobench/data/preprocessor/oc20_s2ef_preprocessor.py new file mode 100644 index 000000000..27626a2d8 --- /dev/null +++ b/topobench/data/preprocessor/oc20_s2ef_preprocessor.py @@ -0,0 +1,605 @@ +"""S2EF preprocessing for OC20/OC22 datasets. + +Creates ASE DB files with extracted graph features from provided *.extxyz files +for the S2EF task. + +Copyright (c) Meta Platforms, Inc. and affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import glob +import logging +import sys +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path + +import numpy as np +import torch +from torch_geometric.data import Data +from tqdm import tqdm + +logger = logging.getLogger(__name__) + +# Try importing ASE +try: + import ase.db + import ase.io + from ase.atoms import Atoms + + HAS_ASE = True +except ImportError: + HAS_ASE = False + logger.warning( + "ASE not installed. S2EF preprocessing will not be available. " + "Install with: pip install ase" + ) + +# Try importing pymatgen for neighbor search +try: + from pymatgen.io.ase import AseAtomsAdaptor + + HAS_PYMATGEN = True +except ImportError: + HAS_PYMATGEN = False + logger.warning( + "pymatgen not installed. Will use slower ASE neighbor search. " + "Install with: pip install pymatgen" + ) + + +class AtomsToGraphs: + """Convert ASE Atoms objects to PyTorch Geometric Data objects. + + This class handles periodic boundary conditions and creates graph representations + suitable for machine learning on atomic structures. + + Parameters + ---------- + max_neigh : int + Maximum number of neighbors to consider per atom. + radius : float + Cutoff radius in Angstroms for neighbor search. + r_energy : bool + Whether to include energy in the created Data objects. 
    r_forces : bool
        Whether to include forces in the created Data objects.
    r_distances : bool
        Whether to include edge distances as edge attributes.
    r_edges : bool
        Whether to compute edges (can be disabled for debugging).
    r_fixed : bool
        Whether to include fixed atom flags.
    """

    def __init__(
        self,
        max_neigh: int = 50,
        radius: float = 6.0,
        r_energy: bool = True,
        r_forces: bool = True,
        r_distances: bool = True,
        r_edges: bool = True,
        r_fixed: bool = True,
    ):
        # Plain attribute assignment; see class docstring for semantics
        self.max_neigh = max_neigh
        self.radius = radius
        self.r_energy = r_energy
        self.r_forces = r_forces
        self.r_distances = r_distances
        self.r_fixed = r_fixed
        self.r_edges = r_edges

    def _get_neighbors_pymatgen(self, atoms: Atoms):
        """Get neighbors using pymatgen (faster for periodic systems).

        Parameters
        ----------
        atoms : ase.atoms.Atoms
            ASE Atoms object for which to compute neighbor lists.

        Returns
        -------
        tuple
            Tuple (c_index, n_index, n_distance, offsets) representing neighbor
            center indices, neighbor indices, distances and periodic offsets.
        """
        # Fall back to the ASE implementation when pymatgen is unavailable
        if not HAS_PYMATGEN:
            return self._get_neighbors_ase(atoms)

        struct = AseAtomsAdaptor.get_structure(atoms)
        _c_index, _n_index, _offsets, n_distance = struct.get_neighbor_list(
            r=self.radius, numerical_tol=0, exclude_self=True
        )

        # Limit to max_neigh neighbors per atom, sorted by distance
        _nonmax_idx = []
        for i in range(len(atoms)):
            idx_i = (_c_index == i).nonzero()[0]
            idx_sorted = np.argsort(n_distance[idx_i])[: self.max_neigh]
            _nonmax_idx.append(idx_i[idx_sorted])
        _nonmax_idx = np.concatenate(_nonmax_idx)

        _c_index = _c_index[_nonmax_idx]
        _n_index = _n_index[_nonmax_idx]
        n_distance = n_distance[_nonmax_idx]
        _offsets = _offsets[_nonmax_idx]

        return _c_index, _n_index, n_distance, _offsets

    def _get_neighbors_ase(self, atoms: Atoms):
        """Get neighbors using ASE (slower but always available).

        Parameters
        ----------
        atoms : ase.atoms.Atoms
            ASE Atoms object for which to compute neighbor lists.

        Returns
        -------
        tuple
            Tuple (idx_i, idx_j, distances, offsets) representing neighbor
            center indices, neighbor indices, distances and periodic offsets.
        """
        from ase.neighborlist import neighbor_list

        idx_i, idx_j, idx_S, distances = neighbor_list(
            "ijSd", atoms, self.radius, self_interaction=False
        )

        # Limit to max_neigh neighbors per atom
        _nonmax_idx = []
        for i in range(len(atoms)):
            mask = idx_i == i
            dists_i = distances[mask]
            idx_sorted = np.argsort(dists_i)[: self.max_neigh]
            _nonmax_idx.append(np.where(mask)[0][idx_sorted])
        _nonmax_idx = np.concatenate(_nonmax_idx)

        return (
            idx_i[_nonmax_idx],
            idx_j[_nonmax_idx],
            distances[_nonmax_idx],
            idx_S[_nonmax_idx],
        )

    def _reshape_features(self, c_index, n_index, n_distance, offsets):
        """Convert neighbor info to PyTorch tensors.

        Parameters
        ----------
        c_index : array-like
            Center atom indices for edges.
        n_index : array-like
            Neighbor atom indices for edges.
        n_distance : array-like
            Distances between center and neighbor atoms.
        offsets : array-like
            Periodic cell offsets corresponding to edges.

        Returns
        -------
        tuple
            (edge_index, edge_distances, cell_offsets) as PyTorch tensors.
        """
        # Edge direction is neighbor -> center (source row first)
        edge_index = torch.LongTensor(np.vstack((n_index, c_index)))
        edge_distances = torch.FloatTensor(n_distance)
        cell_offsets = torch.LongTensor(offsets)

        # Remove very small distances (self-interactions that slipped through)
        nonzero = torch.where(edge_distances >= 1e-8)[0]
        edge_index = edge_index[:, nonzero]
        edge_distances = edge_distances[nonzero]
        cell_offsets = cell_offsets[nonzero]

        return edge_index, edge_distances, cell_offsets

    def convert(self, atoms: Atoms) -> Data:
        """Convert a single ASE Atoms object to a PyG Data object.
+ + Parameters + ---------- + atoms : ase.atoms.Atoms + ASE Atoms object with positions, atomic numbers, cell, etc. + + Returns + ------- + torch_geometric.data.Data + PyG Data object containing node features, positions, edges, and + optional energy/forces/fixed flags. + """ + # Basic atomic structure info + atomic_numbers = torch.LongTensor(atoms.get_atomic_numbers()) + positions = torch.FloatTensor(atoms.get_positions()) + cell = torch.FloatTensor(np.array(atoms.get_cell())).view(1, 3, 3) + natoms = len(atoms) + + # Create base data object + # Create node features from atomic numbers (one-hot or simple embedding) + # For now, use atomic numbers as features (can be enhanced later) + node_features = atomic_numbers.unsqueeze(1).float() + + data = Data( + cell=cell, + pos=positions, + atomic_numbers=atomic_numbers, + natoms=natoms, + z=atomic_numbers, # Alias for compatibility + x=node_features, # Add node features for compatibility with transforms + ) + + # Add edges if requested + if self.r_edges: + if HAS_PYMATGEN: + split_idx_dist = self._get_neighbors_pymatgen(atoms) + else: + split_idx_dist = self._get_neighbors_ase(atoms) + edge_index, edge_distances, cell_offsets = self._reshape_features( + *split_idx_dist + ) + data.edge_index = edge_index + data.cell_offsets = cell_offsets + + if self.r_distances: + data.distances = edge_distances + data.edge_attr = edge_distances.view( + -1, 1 + ) # For compatibility + + # Add energy if available and requested + if self.r_energy: + try: + energy = atoms.get_potential_energy(apply_constraint=False) + data.y = torch.FloatTensor([energy]) + data.energy = torch.FloatTensor([energy]) + except Exception: + # Energy not available (e.g., no calculator) + pass + + # Add forces if available and requested + if self.r_forces: + try: + forces = atoms.get_forces(apply_constraint=False) + data.force = torch.FloatTensor(forces) + except Exception: + # Forces not available + pass + + # Add fixed atom flags if requested + if self.r_fixed: + 
fixed_idx = torch.zeros(natoms, dtype=torch.float32) + if hasattr(atoms, "constraints"): + from ase.constraints import FixAtoms + + for constraint in atoms.constraints: + if isinstance(constraint, FixAtoms): + fixed_idx[constraint.index] = 1 + data.fixed = fixed_idx + + # Add metadata from atoms.info if present + if hasattr(atoms, "info") and atoms.info: + info_data = atoms.info.get("data", {}) + if "sid" in info_data: + data.sid = info_data["sid"] + if "fid" in info_data: + data.fid = info_data["fid"] + if "ref_energy" in info_data: + data.ref_energy = torch.FloatTensor([info_data["ref_energy"]]) + + return data + + +class S2EFPreprocessor: + """Preprocessor for S2EF data using ASE database format. + + This class handles conversion from extxyz files to ASE database format. + """ + + def write_db( + self, + extxyz_paths: list[str | Path | list[str | Path]], + dbs: dict[str, ase.db.core.Database], + map_file_to_log: dict[str, str | None] | None = None, + num_workers: int = 1, + batch_size: int = 100, + ): + """Write ASE DB files from extxyz inputs. + + This stores for each extended XYZ file the Atoms objects and the + metadata from the corresponding xyz_log file into the ASE Database. + + Parameters + ---------- + extxyz_paths : list[str | Path | list[str | Path]] + Paths to the extended XYZ files or lists of paths per DB. + dbs : dict[str, ase.db.core.Database] + Mapping from extxyz file keys to ASE Database objects. + map_file_to_log : dict[str, str | None] | None + Mapping from extxyz paths to xyz_log paths. If None, no metadata is used. + num_workers : int, optional + Number of worker processes to use, by default 1. + batch_size : int, optional + Number of structures to write to the ASE DB file at a time, by default 100. + + Returns + ------- + None + This function writes DB files as a side effect. 
+ """ + if not HAS_ASE: + raise ImportError("ASE is required for S2EF preprocessing") + + if map_file_to_log is None: + map_file_to_log = {k: None for k in dbs} + + num_workers = min(num_workers, len(extxyz_paths)) + node_counts = [] # Track node counts for each structure + + with ProcessPoolExecutor(max_workers=num_workers) as executor: + futures = [ + executor.submit( + self._write_db_worker, + file_path, + map_file_to_log[db_path], + dbs[db_path], + batch_size, + ) + for file_path, db_path in zip( + extxyz_paths, + map_file_to_log.keys(), + strict=True, + ) + ] + for future in tqdm( + as_completed(futures), + desc="Writing DBs", + total=len(futures), + leave=False, + file=sys.stdout, + dynamic_ncols=True, + ): + db_path, num_atoms, structure_node_counts = future.result() + node_counts.extend(structure_node_counts) + + # Save node counts + node_counts_path = ( + Path(list(dbs.keys())[0]).parent.parent / "node_counts.npy" + ) + np.save(node_counts_path, np.array(node_counts)) + logger.info( + f"Saved {len(node_counts)} node counts to {node_counts_path}" + ) + + @staticmethod + def _write_db_worker( + file_paths: str | Path | list[str | Path], + xyz_log_path: str | Path | None, + db: ase.db.core.Database, + batch_size: int, + ): + """Worker function to write atoms to ASE database. + + Parameters + ---------- + file_paths : str | Path | list[str | Path] + Path or list of paths to extended XYZ file(s). + xyz_log_path : str | Path | None + Path to the log file with metadata or None. + db : ase.db.core.Database + ASE database object to write to. + batch_size : int + Number of structures to batch together. + + Returns + ------- + tuple + (db_path, total_atoms, node_counts). 
+ """ + if xyz_log_path is not None: + with open(xyz_log_path) as f: + xyz_log = f.read().splitlines() + else: + xyz_log = None + + if isinstance(file_paths, (str, Path)): + file_paths = [file_paths] + + node_counts = [] # Track node counts for each structure + total_atoms = 0 + + for file_path in file_paths: + atoms_list = ase.io.read(file_path, ":") + atoms_batch = [] + log_batch = [] + for i, atoms in enumerate(atoms_list): + if xyz_log is not None and i < len(xyz_log): + log_line = xyz_log[i].split(",") + log_info = { + "sid": int(log_line[0].split("random")[1]), + "fid": int(log_line[1].split("frame")[1]), + "ref_energy": float(log_line[2]), + } + else: + log_info = {} + + atoms_batch.append(atoms) + log_batch.append(log_info) + node_counts.append(len(atoms)) + + if len(atoms_batch) == batch_size: + # Write batch to database + for atom, log in zip(atoms_batch, log_batch, strict=True): + db.write(atom, data=log) + atoms_batch, log_batch = [], [] + + total_atoms += len(atoms_list) + + # Write remaining batch + if atoms_batch: + for atom, log in zip(atoms_batch, log_batch, strict=True): + db.write(atom, data=log) + + db_path = db.filename if hasattr(db, "filename") else str(db) + return db_path, total_atoms, node_counts + + +def needs_preprocessing(data_path: Path) -> bool: + """Check if data needs preprocessing (has extxyz but no db files).""" + has_extxyz = bool(list(data_path.glob("*.extxyz"))) + has_db = bool(list(data_path.glob("*.db"))) + return has_extxyz and not has_db + + +def preprocess_s2ef_split_ase( + data_path: Path, + out_path: Path, + num_workers: int = 1, + ref_energy: bool = True, + test_data: bool = False, + max_neigh: int = 50, + radius: float = 6.0, +) -> None: + """Preprocess S2EF data from extxyz/txt to ASE DB format. + + Parameters + ---------- + data_path : Path + Path to directory containing *.extxyz and *.txt files. + out_path : Path + Directory to save ASE DB files. + num_workers : int + Number of parallel workers for preprocessing. 
+ ref_energy : bool + Whether to include reference energies in metadata. + test_data : bool + Whether this is test data (no energy/forces in log). + max_neigh : int + Maximum number of neighbors per atom. + radius : float + Cutoff radius for neighbor search in Angstroms. + + Returns + ------- + None + This function writes ASE DB files as a side effect. + """ + if not HAS_ASE: + raise ImportError("ASE is required for S2EF preprocessing") + + logger.info( + f"Preprocessing S2EF data from {data_path} to {out_path} (ASE DB format)" + ) + + # Find all extxyz files + extxyz_files = sorted(glob.glob(str(data_path / "*.extxyz"))) + + if not extxyz_files: + logger.warning(f"No extxyz files found in {data_path}") + return + + out_path.mkdir(parents=True, exist_ok=True) + + # Create mapping from extxyz to log files, but only for files that need preprocessing + map_file_to_log = {} + dbs = {} + skipped_count = 0 + + for extxyz_file in extxyz_files: + extxyz_path = Path(extxyz_file) + base_name = extxyz_path.stem + + # Check if DB file already exists + db_path = out_path / f"{base_name}.db" + if db_path.exists(): + skipped_count += 1 + continue # Skip files that already have DB files + + # Find corresponding txt file + txt_path = data_path / f"{base_name}.txt" + if txt_path.exists() and ref_energy and not test_data: + map_file_to_log[str(extxyz_path)] = str(txt_path) + else: + map_file_to_log[str(extxyz_path)] = None + + # Create ASE DB for this file + dbs[str(extxyz_path)] = ase.db.connect(str(db_path)) + + if skipped_count > 0: + logger.info( + f"Skipping {skipped_count} files that already have DB files" + ) + + if not map_file_to_log: + logger.info("All files already preprocessed, skipping...") + return + + # Write to ASE databases + preprocessor = S2EFPreprocessor() + preprocessor.write_db( + extxyz_paths=list(map_file_to_log.keys()), + dbs=dbs, + map_file_to_log=map_file_to_log, + num_workers=num_workers, + batch_size=100, + ) + + logger.info(f"Preprocessing complete. 
DB files saved to {out_path}") + + +def preprocess_s2ef_dataset_ase( + data_path: Path, + num_workers: int = 1, + splits: list[str] | None = None, + max_neigh: int = 50, + radius: float = 6.0, +) -> None: + """Preprocess entire S2EF dataset with train/val/test splits. + + Parameters + ---------- + data_path : Path + Root path containing train/val/test subdirectories. + num_workers : int + Number of parallel workers. + splits : Optional[list[str]] + List of splits to process. If None, processes all found splits. + max_neigh : int + Maximum number of neighbors per atom. + radius : float + Cutoff radius for neighbor search in Angstroms. + + Returns + ------- + None + This function writes ASE DB files as a side effect. + """ + if not HAS_ASE: + raise ImportError("ASE is required for S2EF preprocessing") + + if splits is None: + # Auto-detect splits + splits = [d.name for d in data_path.iterdir() if d.is_dir()] + + for split in splits: + split_path = data_path / split + if not split_path.exists(): + logger.warning(f"Split {split} not found at {split_path}") + continue + + logger.info(f"Processing split: {split}") + is_test = "test" in split.lower() + + preprocess_s2ef_split_ase( + data_path=split_path, + out_path=split_path, + num_workers=num_workers, + ref_energy=True, + test_data=is_test, + max_neigh=max_neigh, + radius=radius, + ) + + logger.info("Dataset preprocessing complete") diff --git a/topobench/data/preprocessor/preprocessor.py b/topobench/data/preprocessor/preprocessor.py index e5c4a913a..f382c7ddf 100644 --- a/topobench/data/preprocessor/preprocessor.py +++ b/topobench/data/preprocessor/preprocessor.py @@ -54,8 +54,15 @@ def __init__(self, dataset, data_dir, transforms_config=None, **kwargs): self.transform = ( dataset.transform if hasattr(dataset, "transform") else None ) - self.data, self.slices = dataset._data, dataset.slices - self.data_list = [data for data in dataset] + # Handle datasets that don't have _data/slices (e.g., LMDB-based datasets like 
OC20) + if hasattr(dataset, "_data") and hasattr(dataset, "slices"): + self.data, self.slices = dataset._data, dataset.slices + self.data_list = [data for data in dataset] + else: + # For non-InMemoryDataset, store the dataset directly + self._data = None + self.slices = None + self.data_list = [data for data in dataset] # Some datasets have fixed splits, and those are stored as split_idx during loading # We need to store this information to be able to reproduce the splits afterwards @@ -76,6 +83,42 @@ def processed_dir(self) -> str: else: return self.root + "/processed" + def len(self) -> int: + """Return the number of samples in the dataset. + + Returns + ------- + int + Number of samples. + """ + # Only use data_list for length when transforms are NOT applied + # and dataset doesn't have standard InMemoryDataset structure + if not self.transforms_applied and self.slices is None: + if hasattr(self, "data_list") and self.data_list is not None: + return len(self.data_list) + # Fallback to dataset length during initialization + return len(self.dataset) + return super().len() + + def get(self, idx: int): + """Get data object at index. + + Parameters + ---------- + idx : int + Index of the data object. + + Returns + ------- + Data + Data object at the given index. + """ + # Only use data_list access when transforms are NOT applied + # and dataset doesn't have standard InMemoryDataset structure + if not self.transforms_applied and self.slices is None: + return self.data_list[idx] + return super().get(idx) + @property def processed_file_names(self) -> str: """Return the name of the processed file. 
diff --git a/topobench/data/utils/__init__.py b/topobench/data/utils/__init__.py index 8793f773e..31a15f290 100644 --- a/topobench/data/utils/__init__.py +++ b/topobench/data/utils/__init__.py @@ -62,4 +62,21 @@ # add function name here ] -__all__ = utils_functions + split_helper_functions + io_helper_functions +# OC20 download utilities +from .oc20_download import ( # noqa: E402 + download_is2re_dataset, # noqa: F401 + # import function here, add noqa: F401 for PR + download_s2ef_dataset, # noqa: F401 +) + +oc20_helper_functions = [ + "download_s2ef_dataset", + "download_is2re_dataset", +] + +__all__ = ( + utils_functions + + split_helper_functions + + io_helper_functions + + oc20_helper_functions +) diff --git a/topobench/data/utils/oc20_download.py b/topobench/data/utils/oc20_download.py new file mode 100644 index 000000000..d48a96fec --- /dev/null +++ b/topobench/data/utils/oc20_download.py @@ -0,0 +1,235 @@ +"""Utilities for downloading and preparing OC20 datasets.""" + +from __future__ import annotations + +import logging +import lzma +import os +import shutil +import tarfile +import urllib.request +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +from tqdm import tqdm + +logger = logging.getLogger(__name__) + +# OC20 dataset split URLs +S2EF_TRAIN_SPLITS = { + "200K": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_train_200K.tar", + "2M": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_train_2M.tar", + "20M": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_train_20M.tar", + "all": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_train_all.tar", +} + +S2EF_VAL_SPLITS = { + "val_id": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_val_id.tar", + "val_ood_ads": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_val_ood_ads.tar", + "val_ood_cat": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_val_ood_cat.tar", + 
    "val_ood_both": "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_val_ood_both.tar",
+}
+
+S2EF_TEST_SPLIT = "https://dl.fbaipublicfiles.com/opencatalystproject/data/s2ef_test_lmdbs.tar.gz"
+
+IS2RE_URL = "https://dl.fbaipublicfiles.com/opencatalystproject/data/is2res_train_val_test_lmdbs.tar.gz"
+OC22_IS2RE_URL = "https://dl.fbaipublicfiles.com/opencatalystproject/data/oc22/is2res_total_train_val_test_lmdbs.tar.gz"
+
+
+def uncompress_xz(file_path: str) -> str:
+    """Decompress .xz files.
+
+    Parameters
+    ----------
+    file_path : str
+        Path to file to decompress.
+
+    Returns
+    -------
+    str
+        Path to decompressed file.
+    """
+    if not file_path.endswith(".xz"):
+        return file_path
+
+    # Strip only the trailing ".xz"; str.replace would also mangle an ".xz"
+    # occurring earlier in the path (e.g. "a.xz.d/file.xz").
+    output_path = file_path[: -len(".xz")]
+    try:
+        with (
+            lzma.open(file_path, "rb") as f_in,
+            open(output_path, "wb") as f_out,
+        ):
+            shutil.copyfileobj(f_in, f_out)
+        os.remove(file_path)
+        return output_path
+    except Exception as e:
+        logger.error(f"Error uncompressing {file_path}: {e}")
+        return file_path
+
+
+def download_and_extract(
+    url: str, target_dir: Path, skip_if_extracted: bool = True
+) -> Path:
+    """Download and extract a tar archive.
+
+    Parameters
+    ----------
+    url : str
+        URL to download from.
+    target_dir : Path
+        Directory to extract to.
+    skip_if_extracted : bool
+        If True, skip extraction if extracted files already exist (default: True).
+
+    Returns
+    -------
+    Path
+        Path to extracted directory.
+    """
+    target_dir.mkdir(parents=True, exist_ok=True)
+    target_file = target_dir / os.path.basename(url)
+
+    # Download if needed
+    if not target_file.exists():
+        logger.info(f"Downloading {url}...")
+        with tqdm(
+            unit="B", unit_scale=True, desc=f"Downloading {target_file.name}"
+        ) as pbar:
+
+            def report(block_num, block_size, total_size):
+                if total_size > 0 and block_num == 0:
+                    pbar.total = total_size
+                pbar.update(block_size)
+
+            urllib.request.urlretrieve(url, target_file, reporthook=report)
+    else:
+        logger.info(f"Archive {target_file.name} already downloaded")
+
+    # Check if extraction is needed
+    extraction_marker = target_dir / ".extracted"
+    if skip_if_extracted and extraction_marker.exists():
+        logger.info(f"Archive {target_file.name} already extracted, skipping")
+        return target_dir
+
+    # Extract
+    # NOTE(review): extractall without an extraction filter is vulnerable to
+    # path traversal on untrusted archives; pass filter="data" once the
+    # minimum supported Python is >= 3.12 (PEP 706).
+    logger.info(f"Extracting {target_file.name}...")
+    if str(target_file).endswith((".tar.gz", ".tgz")):
+        with tarfile.open(target_file, "r:gz") as tar:
+            tar.extractall(path=target_dir)
+    elif str(target_file).endswith(".tar"):
+        with tarfile.open(target_file, "r:") as tar:
+            tar.extractall(path=target_dir)
+    else:
+        raise ValueError(f"Unsupported archive format: {target_file}")
+
+    # Mark as extracted
+    extraction_marker.touch()
+    return target_dir
+
+
+def decompress_xz_files(directory: Path):
+    """Decompress all .xz files in a directory.
+
+    Parameters
+    ----------
+    directory : Path
+        Directory to search for .xz files.
+    """
+    xz_files = list(directory.glob("**/*.xz"))
+    if xz_files:
+        logger.info(
+            f"Decompressing {len(xz_files)} .xz files in {directory}..."
+        )
+        # os.cpu_count() may return None; fall back to a single worker then.
+        num_workers = max(1, (os.cpu_count() or 2) - 1)
+        # Use threads to avoid pickling/import issues with processes on macOS
+        with ThreadPoolExecutor(max_workers=num_workers) as executor:
+            futures = [
+                executor.submit(uncompress_xz, str(f)) for f in xz_files
+            ]
+            for future in as_completed(futures):
+                future.result()
+
+
+def download_s2ef_dataset(
+    root: Path,
+    train_split: str = "200K",
+    val_splits: list[str] | None = None,
+    include_test: bool = True,
+):
+    """Download S2EF dataset splits.
+
+    Parameters
+    ----------
+    root : Path
+        Root directory for data storage.
+    train_split : str
+        Training split size: "200K", "2M", "20M", or "all".
+    val_splits : list[str] | None
+        List of validation splits to download. If None, downloads all.
+    include_test : bool
+        Whether to download test split.
+    """
+    if val_splits is None:
+        val_splits = list(S2EF_VAL_SPLITS.keys())
+
+    # Download train split
+    train_url = S2EF_TRAIN_SPLITS[train_split]
+    train_subdir_name = f"s2ef_train_{train_split}"
+    train_dir = (
+        root / "s2ef" / train_split / train_subdir_name / train_subdir_name
+    )
+    if not train_dir.exists():
+        logger.info(f"Downloading S2EF train split: {train_split}")
+        download_and_extract(train_url, root / "s2ef" / train_split)
+        decompress_xz_files(root / "s2ef" / train_split)
+    else:
+        logger.info(
+            f"S2EF train split {train_split} already exists, skipping download"
+        )
+
+    # Download validation splits
+    for val_split in val_splits:
+        val_url = S2EF_VAL_SPLITS[val_split]
+        val_subdir_name = f"s2ef_{val_split}"
+        val_dir = root / "s2ef" / "all" / val_subdir_name / val_subdir_name
+        if not val_dir.exists():
+            logger.info(f"Downloading S2EF validation split: {val_split}")
+            download_and_extract(val_url, root / "s2ef" / "all")
+            decompress_xz_files(root / "s2ef" / "all")
+        else:
+            logger.info(
+                f"S2EF validation split {val_split} already exists, skipping download"
+            )
+
+    # Download test split
+    if include_test:
+        test_subdir_name = "s2ef_test"
+        test_dir = root /
"s2ef" / "all" / test_subdir_name / test_subdir_name + if not test_dir.exists(): + logger.info("Downloading S2EF test split") + download_and_extract(S2EF_TEST_SPLIT, root / "s2ef" / "all") + decompress_xz_files(root / "s2ef" / "all") + else: + logger.info("S2EF test split already exists, skipping download") + + +def download_is2re_dataset(root: Path, task: str = "is2re"): + """Download IS2RE or OC22 IS2RE dataset. + + Parameters + ---------- + root : Path + Root directory for data storage. + task : str + Task name: "is2re" or "oc22_is2re". + """ + url = IS2RE_URL if task == "is2re" else OC22_IS2RE_URL + target_dir = root / task + + if not target_dir.exists(): + logger.info(f"Downloading {task.upper()} dataset") + download_and_extract(url, root) + decompress_xz_files(root) + else: + logger.info( + f"{task.upper()} dataset already exists, skipping download" + ) diff --git a/topobench/nn/backbones/graph/nsd_utils/inductive_discrete_models.py b/topobench/nn/backbones/graph/nsd_utils/inductive_discrete_models.py index 30ae553b8..4c2a2dc31 100644 --- a/topobench/nn/backbones/graph/nsd_utils/inductive_discrete_models.py +++ b/topobench/nn/backbones/graph/nsd_utils/inductive_discrete_models.py @@ -64,16 +64,12 @@ def __init__(self, config): ) nn.init.orthogonal_(self.lin_right_weights[-1].weight.data) for _i in range(self.layers): - self.lin_left_weights.append( - nn.Linear(self.d, self.d, bias=False) - ) + self.lin_left_weights.append(nn.Linear(self.d, self.d, bias=False)) nn.init.eye_(self.lin_left_weights[-1].weight.data) self.sheaf_learners = nn.ModuleList() - num_sheaf_learners = min( - self.layers, self.layers - ) + num_sheaf_learners = min(self.layers, self.layers) for _i in range(num_sheaf_learners): self.sheaf_learners.append( LocalConcatSheafLearner( @@ -208,17 +204,13 @@ def __init__(self, config): ) nn.init.orthogonal_(self.lin_right_weights[-1].weight.data) for _i in range(self.layers): - self.lin_left_weights.append( - nn.Linear(self.d, self.d, bias=False) - 
) + self.lin_left_weights.append(nn.Linear(self.d, self.d, bias=False)) nn.init.eye_(self.lin_left_weights[-1].weight.data) self.sheaf_learners = nn.ModuleList() self.weight_learners = nn.ModuleList() - num_sheaf_learners = min( - self.layers, self.layers - ) + num_sheaf_learners = min(self.layers, self.layers) for _i in range(num_sheaf_learners): self.sheaf_learners.append( LocalConcatSheafLearner( @@ -397,16 +389,12 @@ def __init__(self, config): ) nn.init.orthogonal_(self.lin_right_weights[-1].weight.data) for _i in range(self.layers): - self.lin_left_weights.append( - nn.Linear(self.d, self.d, bias=False) - ) + self.lin_left_weights.append(nn.Linear(self.d, self.d, bias=False)) nn.init.eye_(self.lin_left_weights[-1].weight.data) self.sheaf_learners = nn.ModuleList() - num_sheaf_learners = min( - self.layers, self.layers - ) + num_sheaf_learners = min(self.layers, self.layers) for _i in range(num_sheaf_learners): self.sheaf_learners.append( LocalConcatSheafLearner( diff --git a/topobench/nn/backbones/graph/nsd_utils/laplace.py b/topobench/nn/backbones/graph/nsd_utils/laplace.py index 55903396f..606311620 100644 --- a/topobench/nn/backbones/graph/nsd_utils/laplace.py +++ b/topobench/nn/backbones/graph/nsd_utils/laplace.py @@ -214,6 +214,7 @@ def compute_learnable_diag_laplacian_indices( return diag_indices, non_diag_indices + def mergesp(index1, value1, index2, value2): """ Merge two sparse matrices with disjoint indices into one. 
diff --git a/topobench/nn/backbones/graph/nsd_utils/laplacian_builders.py b/topobench/nn/backbones/graph/nsd_utils/laplacian_builders.py index fb6da3ea8..0da64aab6 100644 --- a/topobench/nn/backbones/graph/nsd_utils/laplacian_builders.py +++ b/topobench/nn/backbones/graph/nsd_utils/laplacian_builders.py @@ -103,7 +103,7 @@ def scalar_normalise(self, diag, tril, row, col): assert diag.dim() == 2 d = diag.size(-1) diag_sqrt_inv = (diag + 1).pow(-0.5) - + diag_sqrt_inv = ( diag_sqrt_inv.view(-1, 1, 1) if tril.dim() > 2 @@ -122,6 +122,7 @@ def scalar_normalise(self, diag, tril, row, col): return diag_maps, non_diag_maps + class DiagLaplacianBuilder(LaplacianBuilder): """ Builder for sheaf Laplacian with diagonal restriction maps. @@ -199,6 +200,7 @@ def forward(self, maps): return (edge_index, weights), saved_tril_maps + class NormConnectionLaplacianBuilder(LaplacianBuilder): """ Builder for normalized bundle sheaf Laplacian with orthogonal restriction maps. @@ -254,7 +256,7 @@ def forward(self, map_params): """ assert len(map_params.size()) == 2 assert map_params.size(1) == self.d * (self.d + 1) // 2 - + _, full_right_idx = self.full_left_right_idx left_idx, right_idx = self.left_right_idx tril_row, tril_col = self.vertex_tril_idx @@ -294,6 +296,7 @@ def forward(self, map_params): return (edge_index, weights), saved_tril_maps + class GeneralLaplacianBuilder(LaplacianBuilder): """ Builder for general sheaf Laplacian with full matrix restriction maps. 
diff --git a/topobench/transforms/liftings/graph2simplicial/latentclique_lifting.py b/topobench/transforms/liftings/graph2simplicial/latentclique_lifting.py index 55ac1c406..bc05ad9e9 100755 --- a/topobench/transforms/liftings/graph2simplicial/latentclique_lifting.py +++ b/topobench/transforms/liftings/graph2simplicial/latentclique_lifting.py @@ -605,7 +605,8 @@ def splitmerge(self): if clique_i == clique_j: clique_size = self.Z[clique_i].sum() - if clique_size <= 2: return # noqa + if clique_size <= 2: + return # noqa Z_prop = self.Z.copy() Z_prop = np.delete(Z_prop, clique_i, 0)