diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 93c58afa0..4ba669aab 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -18,14 +18,17 @@ The pipeline to solve differential equations with PINA follows just five steps: 5. Train the model with the PINA :doc:`Trainer `, enhance the train with `Callbacks`_ -Trainer, Dataset and Datamodule --------------------------------- +Trainer, Data Loader and Data Module +---------------------------------------- .. toctree:: :titlesonly: Trainer - Dataset - DataModule + Data Module + Single-Batch Data Loader + Aggregator + Creator + Condition Subset Data Types ------------ diff --git a/docs/source/_rst/data/aggregator.rst b/docs/source/_rst/data/aggregator.rst new file mode 100644 index 000000000..738a57524 --- /dev/null +++ b/docs/source/_rst/data/aggregator.rst @@ -0,0 +1,9 @@ +Aggregator +================ +.. currentmodule:: pina.data.aggregator + +.. automodule:: pina._src.data.aggregator + +.. autoclass:: pina._src.data.aggregator._Aggregator + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/data/condition_subset.rst b/docs/source/_rst/data/condition_subset.rst new file mode 100644 index 000000000..84c032dc8 --- /dev/null +++ b/docs/source/_rst/data/condition_subset.rst @@ -0,0 +1,9 @@ +Condition Subset +================ +.. currentmodule:: pina.data.condition_subset + +.. automodule:: pina._src.data.condition_subset + +.. autoclass:: pina._src.data.condition_subset._ConditionSubset + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/data/creator.rst b/docs/source/_rst/data/creator.rst new file mode 100644 index 000000000..5d836292d --- /dev/null +++ b/docs/source/_rst/data/creator.rst @@ -0,0 +1,9 @@ +Creator +======= +.. currentmodule:: pina.data.creator + +.. automodule:: pina._src.data.creator + +.. 
autoclass:: pina._src.data.creator._Creator + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/data/data_module.rst b/docs/source/_rst/data/data_module.rst index a9236ed15..e31dae2b9 100644 --- a/docs/source/_rst/data/data_module.rst +++ b/docs/source/_rst/data/data_module.rst @@ -2,6 +2,6 @@ DataModule ====================== .. currentmodule:: pina.data.data_module -.. autoclass:: pina._src.data.data_module.PinaDataModule +.. autoclass:: pina._src.data.data_module.DataModule :members: :show-inheritance: diff --git a/docs/source/_rst/data/dataset.rst b/docs/source/_rst/data/dataset.rst deleted file mode 100644 index 264722b07..000000000 --- a/docs/source/_rst/data/dataset.rst +++ /dev/null @@ -1,19 +0,0 @@ -Dataset -====================== -.. currentmodule:: pina.data.dataset - -.. autoclass:: pina._src.data.dataset.PinaDataset - :members: - :show-inheritance: - -.. autoclass:: pina._src.data.dataset.PinaDatasetFactory - :members: - :show-inheritance: - -.. autoclass:: pina._src.data.dataset.PinaGraphDataset - :members: - :show-inheritance: - -.. autoclass:: pina._src.data.dataset.PinaTensorDataset - :members: - :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/data/single_batch_data_loader.rst b/docs/source/_rst/data/single_batch_data_loader.rst new file mode 100644 index 000000000..7c1debb92 --- /dev/null +++ b/docs/source/_rst/data/single_batch_data_loader.rst @@ -0,0 +1,9 @@ +Single-Batch Data Loader +=========================== +.. currentmodule:: pina.data.single_batch_data_loader + +.. automodule:: pina._src.data.single_batch_data_loader + +.. autoclass:: pina._src.data.single_batch_data_loader._SingleBatchDataLoader + :members: + :show-inheritance: \ No newline at end of file diff --git a/pina/__init__.py b/pina/__init__.py index c3dc00f3b..cafab2d31 100644 --- a/pina/__init__.py +++ b/pina/__init__.py @@ -1,6 +1,4 @@ """ -PINA: Physics-Informed Neural Analysis. 
- A specialized framework for Scientific Machine Learning (SciML), providing tools for Physics-Informed Neural Networks (PINNs), Neural Operators, and data-driven physical modeling. @@ -10,7 +8,7 @@ "LabelTensor", "Trainer", "Condition", - "PinaDataModule", + "DataModule", "Graph", ] @@ -18,4 +16,33 @@ from pina._src.core.graph import Graph from pina._src.core.trainer import Trainer from pina._src.condition.condition import Condition -from pina._src.data.data_module import PinaDataModule +from pina._src.data.data_module import DataModule + +# Back-compatibility with version 0.2, to be removed soon +import warnings + +_DEPRECATED_IMPORTS = {"PinaDataModule": "DataModule"} + + +def __getattr__(name): + """ + Resolve deprecated top-level names and emit a deprecation warning. + + :param str name: The attribute requested on the ``pina`` module. + :return: The object currently exported under the replacement name. + :raises AttributeError: If ``name`` is not a deprecated alias. + """ + if name in _DEPRECATED_IMPORTS: + + warnings.warn( + f"Importing '{name}' from 'pina' is deprecated; use " + f"pina.{_DEPRECATED_IMPORTS[name]} instead.", + DeprecationWarning, + stacklevel=2, + ) + + return globals()[_DEPRECATED_IMPORTS[name]] + + # A module-level __getattr__ must raise for unknown names: falling through + # would return None and make every missing attribute appear to exist. + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/pina/_src/callback/refinement/base_refinement.py b/pina/_src/callback/refinement/base_refinement.py index d1e8033b3..2f83cd11f 100644 --- a/pina/_src/callback/refinement/base_refinement.py +++ b/pina/_src/callback/refinement/base_refinement.py @@ -99,7 +99,7 @@ def on_train_start(self, trainer, solver): # Initialize dataset and compute initial population size self._dataset = trainer.datamodule.train_datasets self._initial_population_size = { - cond: self.dataset[cond].length + cond: self.dataset[cond].dataset_length for cond in self._condition_to_update } diff --git a/pina/_src/condition/base_condition.py b/pina/_src/condition/base_condition.py index 939c75e39..9adf19b2d 100644 --- a/pina/_src/condition/base_condition.py +++ b/pina/_src/condition/base_condition.py @@ -8,7 +8,7 @@ from pina._src.core.graph import LabelBatch from pina._src.core.label_tensor import LabelTensor from pina._src.core.utils import check_consistency -from pina._src.data.dummy_dataloader import DummyDataloader +from 
pina._src.data.single_batch_data_loader import _SingleBatchDataLoader from pina._src.problem.problem_interface import ProblemInterface @@ -74,9 +74,9 @@ def create_dataloader( :return: The DataLoader for the condition. :rtype: torch.utils.data.DataLoader """ - # If batching the entire dataset, return a DummyDataloader + # If batching the entire dataset, return a _SingleBatchDataLoader if batch_size == len(dataset): - return DummyDataloader(dataset) + return _SingleBatchDataLoader(dataset) # Otherwise, return a regular DataLoader with the appropriate collate return DataLoader( diff --git a/pina/_src/condition/time_series_condition.py b/pina/_src/condition/time_series_condition.py index a999ded28..300956a33 100644 --- a/pina/_src/condition/time_series_condition.py +++ b/pina/_src/condition/time_series_condition.py @@ -200,7 +200,7 @@ def evaluate(self, batch, solver): raise ValueError( "The provided input tensor must have at least 4 dimensions:" " [trajectories, windows, time_steps, *features]." - f" Got shape {batch["input"].shape}." + f" Got shape {batch['input'].shape}." 
) # Copy the kwargs to avoid modifying the original settings diff --git a/pina/_src/core/trainer.py b/pina/_src/core/trainer.py index 575c2bfa2..1f25dfc0f 100644 --- a/pina/_src/core/trainer.py +++ b/pina/_src/core/trainer.py @@ -1,19 +1,18 @@ -"""Module for the Trainer.""" +"""Trainer utilities built on top of the PyTorch Lightning Trainer class.""" import sys import warnings import torch import lightning -from pina._src.core.utils import check_consistency, custom_warning_format -from pina._src.data.data_module import PinaDataModule -from pina._src.solver.solver_interface import ( - SolverInterface, +from pina._src.solver.base_solver import BaseSolver +from pina._src.data.data_module import DataModule +from pina._src.solver.pinn import PINN +from pina._src.core.utils import ( + check_consistency, + custom_warning_format, + check_positive_integer, ) -# from pina._src.solver.physics_informed_solver.pinn_interface import ( -# PINNInterface, -# ) - # set the warning for compile options warnings.formatwarning = custom_warning_format warnings.filterwarnings("always", category=UserWarning) @@ -21,14 +20,20 @@ class Trainer(lightning.pytorch.Trainer): """ - PINA custom Trainer class to extend the standard Lightning functionality. + PINA-specific extension of :class:`lightning.pytorch.Trainer`. - This class enables specific features or behaviors required by the PINA - framework. It modifies the standard - :class:`lightning.pytorch.Trainer ` - class to better support the training process in PINA. + The trainer configures solver execution, dataset splitting, batching, + logging, compilation support, device placement for unknown parameters, and + gradient tracking requirements for physics-informed solvers. 
""" + # Available batching modes + _AVAIL_BATCHING_MODES = { + "common_batch_size", + "proportional", + "separate_conditions", + } + def __init__( self, solver, @@ -36,143 +41,194 @@ def __init__( train_size=1.0, test_size=0.0, val_size=0.0, - compile=None, + compile=False, batching_mode="common_batch_size", - automatic_batching=None, - num_workers=None, - pin_memory=None, - shuffle=None, + automatic_batching=False, + num_workers=0, + pin_memory=False, + shuffle=True, **kwargs, ): """ Initialization of the :class:`Trainer` class. - :param SolverInterface solver: A - :class:`~pina.solver.solver.SolverInterface` solver used to solve a - :class:`~pina.problem.base_problem.BaseProblem`. - :param int batch_size: The number of samples per batch to load. - If ``None``, all samples are loaded and data is not batched. - Default is ``None``. - :param float train_size: The percentage of elements to include in the - training dataset. Default is ``1.0``. - :param float test_size: The percentage of elements to include in the - test dataset. Default is ``0.0``. - :param float val_size: The percentage of elements to include in the - validation dataset. Default is ``0.0``. - :param bool compile: If ``True``, the model is compiled before training. - Default is ``False``. For Windows users, it is always disabled. Not - supported for python version greater or equal than 3.14. - :param str batching_mode: The batching mode to use. Options are - ``"common_batch_size"``, ``"proportional"``, and - ``"separate_conditions"``. Default is ``"common_batch_size"``. - ``False``. - :param bool automatic_batching: If ``True``, automatic PyTorch batching - is performed, otherwise the items are retrieved from the dataset - all at once. For further details, see the - :class:`~pina.data.data_module.PinaDataModule` class. Default is - ``False``. - :param int num_workers: The number of worker threads for data loading. - Default is ``0`` (serial loading). 
- :param bool pin_memory: Whether to use pinned memory for faster data - transfer to GPU. Default is ``False``. - :param bool shuffle: Whether to shuffle the data during training. - Default is ``True``. - :param dict kwargs: Additional keyword arguments that specify the - training setup. These can be selected from the `pytorch-lightning - Trainer API - `_. + :param SolverInterface solver: The solver used to train, validate, and + test the associated problem. + :param int batch_size: The number of samples per batch. If ``None``, the + entire dataset is processed as a single batch. Default is ``None``. + :param float train_size: The fraction of samples assigned to the + training split. Must belong to the interval ``[0, 1]``. + Default is ``1.0``. + :param float val_size: The fraction of samples assigned to the + validation split. Must belong to the interval ``[0, 1]``. + Default is ``0.0``. + :param float test_size: The fraction of samples assigned to the test + split. Must belong to the interval ``[0, 1]``. Default is ``0.0``. + :param bool compile: Whether to compile the model before training. + Compilation is disabled on Windows and with Python 3.14 or later. + Default is ``False``. + :param str batching_mode: The strategy used to aggregate batches across + dataloaders. Available options are ``"common_batch_size"`` for + uniform batch sizes across conditions, ``"proportional"`` for batch + sizes proportional to dataset sizes, and ``"separate_conditions"`` + for iterating through each condition separately. + Default is ``"common_batch_size"``. + :param bool automatic_batching: Whether PyTorch automatic batching + should be enabled. If ``True``, dataset elements are retrieved + individually and collated into batches by the dataloader. + If ``False``, entire subsets are retrieved directly from the + condition object. Default is ``False``. + :param int num_workers: The number of worker processes used by + dataloaders. Default is ``0`` for sequential loading. 
+ :param bool pin_memory: Whether pinned memory should be enabled during + data loading. Default is ``False``. + :param bool shuffle: Whether condition samples should be shuffled before + splitting. Default is ``True``. + :param dict kwargs: Additional keyword arguments forwarded to the + Lightning trainer. + :raises ValueError: If ``solver`` is not a PINA solver. + :raises ValueError: If ``train_size``, ``val_size``, or ``test_size`` is + not a float in the interval ``[0, 1]``. + :raises ValueError: If the sum of ``train_size``, ``val_size``, and + ``test_size`` is not equal to 1. + :raises ValueError: If ``compile``, ``automatic_batching``, + ``pin_memory``, or ``shuffle`` is not a boolean. + :raises AssertionError: If ``num_workers`` is a negative integer. + :raises ValueError: If ``batch_size``, when provided, is not a positive + integer. + :raises ValueError: If ``batching_mode`` is not one of the available + options. + :raises UserWarning: If compilation is requested on an unsupported + platform or Python version. + :raises UserWarning: If the provided ``batching_mode`` is incompatible + with the ``batch_size``. + :raises RuntimeError: If any domain in the problem has not been + discretised. 
""" - # check consistency for init types - self._check_input_consistency( - solver=solver, - train_size=train_size, - test_size=test_size, - val_size=val_size, - batching_mode=batching_mode, - automatic_batching=automatic_batching, - compile=compile, - ) - pin_memory, num_workers, shuffle, batch_size = ( - self._check_consistency_and_set_defaults( - pin_memory, num_workers, shuffle, batch_size + # Check consistency + check_consistency(solver, BaseSolver) + check_consistency(train_size, float) + check_consistency(test_size, float) + check_consistency(val_size, float) + check_consistency(compile, bool) + check_consistency(automatic_batching, bool) + check_consistency(pin_memory, bool) + check_consistency(shuffle, bool) + check_positive_integer(num_workers, strict=False) + if batch_size is not None: + check_positive_integer(batch_size, strict=True) + + # Check that train_size, test_size and val_size sum to 1 + total = train_size + val_size + test_size + if not torch.isclose(torch.tensor(total), torch.tensor(1.0)): + raise ValueError( + "`train_size`, `val_size`, and `test_size` must sum to 1." ) - ) - # inference mode set to false when validating/testing PINNs otherwise - # gradient is not tracked and optimization_cycle fails - # if isinstance(solver, PINNInterface): - kwargs["inference_mode"] = False + # Check consistency + if batching_mode not in self._AVAIL_BATCHING_MODES: + raise ValueError( + f"Invalid batching mode '{batching_mode}'. " + f"Expected one of: {sorted(self._AVAIL_BATCHING_MODES)}." 
+ ) - # Logging depends on the batch size, when batch_size is None then - # log_every_n_steps should be zero - if batch_size is None: - kwargs["log_every_n_steps"] = 0 - else: - kwargs.setdefault("log_every_n_steps", 50) # default for lightning + # Set inference mode to false for PINN solvers to track gradients + if isinstance(solver, PINN): + kwargs["inference_mode"] = False - # Setting default kwargs, overriding lightning defaults + # Set log_every_n_steps to 0 if batch_size is None, otherwise default + kwargs["log_every_n_steps"] = ( + 0 if batch_size is None else kwargs.get("log_every_n_steps", 50) + ) + + # Set default value for enable_progress_bar to True if not provided kwargs.setdefault("enable_progress_bar", True) + # Initialize the parent class with the provided keyword arguments super().__init__(**kwargs) - # checking compilation and automatic batching - # compilation disabled for Windows and for Python 3.14+ - if ( - compile is None - or sys.platform == "win32" - or sys.version_info >= (3, 14) - ): - compile = False + # Disable compilation on Windows and Python 3.14+, only when requested + if compile and (sys.platform == "win32" or sys.version_info >= (3, 14)): + + # Raise a warning if compilation is requested but not supported warnings.warn( - "Compilation is disabled for Python 3.14+ and for Windows.", + "Model compilation is not supported on Windows or with Python " + "3.14+. Compilation has been disabled.", UserWarning, ) - automatic_batching = ( - automatic_batching if automatic_batching is not None else False - ) + # Set compile to False if not supported + compile = False + # Raise warning if batch size and batching mode are incompatible if batch_size is None and batching_mode != "common_batch_size": warnings.warn( - "Batching mode is set to " - f"{batching_mode} but batch_size is None. " - "Batching mode will be set to common_batch_size.", + f"Batching mode '{batching_mode}' is ignored when the batch " + "size is None. 
Setting batching_mode to 'common_batch_size'.", UserWarning, ) + + # Set batching mode to common_batch_size if incompatible batching_mode = "common_batch_size" + # Raise warning if batch size and batching mode are incompatible if ( batch_size is not None - and batch_size <= len(solver.problem.conditions) and batching_mode == "proportional" + and batch_size <= len(solver.problem.conditions) ): warnings.warn( - "Batching mode is set to proportional but batch_size is 1. " - "Batching mode will be set to common_batch_size.", + "Batching mode 'proportional' requires the batch size to be " + "larger than the number of conditions. Setting batching_mode " + "to 'common_batch_size'.", UserWarning, ) + + # Set batching mode to common_batch_size if incompatible batching_mode = "common_batch_size" - # set attributes - self.compile = compile + # Initialize the class attributes self.solver = solver + self.compile = compile self.batch_size = batch_size + + # Move the unknown parameters to the correct device self._move_to_device() - self.data_module = None - self._create_datamodule( + # Check that all domains are discretised, otherwise raise an error + if not self.solver.problem.are_all_domains_discretised: + + # Get the list of sampled domains from the problem + sampled_domains = self.solver.problem.discretised_domains + + # Create a status message for each domain + status = "\n".join( + f" - Domain '{name}': " + f"{'sampled' if name in sampled_domains else 'not sampled'}" + for name in self.solver.problem.domains + ) + + # Raise an error with the status of each domain + raise RuntimeError( + "Cannot create the Trainer because some domains have not been " + f"sampled. 
Domain status:\n{status}" + ) + + # Create the data module + self.data_module = DataModule( + problem=self.solver.problem, train_size=train_size, test_size=test_size, val_size=val_size, - batch_size=batch_size, + batch_size=self.batch_size, batching_mode=batching_mode, automatic_batching=automatic_batching, - pin_memory=pin_memory, num_workers=num_workers, + pin_memory=pin_memory, shuffle=shuffle, ) - # logging + # Set logging kwargs self.logging_kwargs = { "sync_dist": bool( len(self._accelerator_connector._parallel_devices) > 1 @@ -184,109 +240,53 @@ def __init__( def _move_to_device(self): """ - Moves the ``unknown_parameters`` of an instance of - :class:`~pina.problem.base_problem.BaseProblem` to the - :class:`Trainer` device. + Move problem unknown parameters to the trainer device. + + If the associated problem defines ``unknown_parameters``, each parameter + is moved to the first device configured by the Lightning accelerator + connector. """ + # Get the device from the accelerator connector device = self._accelerator_connector._parallel_devices[0] - # move parameters to device - pb = self.solver.problem - if hasattr(pb, "unknown_parameters"): - for key in pb.unknown_parameters: - pb.unknown_parameters[key] = torch.nn.Parameter( - pb.unknown_parameters[key].data.to(device) - ) - def _create_datamodule( - self, - train_size, - test_size, - val_size, - batch_size, - batching_mode, - automatic_batching, - pin_memory, - num_workers, - shuffle, - ): - """ - This method is designed to handle the creation of a data module when - resampling is needed during training. Instead of manually defining and - modifying the trainer's dataloaders, this method is called to - automatically configure the data module. - - :param float train_size: The percentage of elements to include in the - training dataset. - :param float test_size: The percentage of elements to include in the - test dataset. 
- :param float val_size: The percentage of elements to include in the - validation dataset. - :param int batch_size: The number of samples per batch to load. - :param str batching_mode: The batching mode to use. Options are - ``"common_batch_size"``, ``"proportional"``, and - ``"separate_conditions"``. - :param bool automatic_batching: Whether to perform automatic batching - with PyTorch. - :param bool pin_memory: Whether to use pinned memory for faster data - transfer to GPU. - :param int num_workers: The number of worker threads for data loading. - :param bool shuffle: Whether to shuffle the data during training. - :raises RuntimeError: If not all conditions are sampled. - """ - if not self.solver.problem.are_all_domains_discretised: - error_message = "\n".join( - [ - f"""{" " * 13} ---> Domain {key} { - "sampled" if key in self.solver.problem.discretised_domains - else - "not sampled"}""" - for key in self.solver.problem.domains.keys() - ] - ) - raise RuntimeError( - "Cannot create Trainer if not all conditions " - "are sampled. The Trainer got the following:\n" - f"{error_message}" - ) - self.data_module = PinaDataModule( - self.solver.problem, - train_size=train_size, - test_size=test_size, - val_size=val_size, - batch_size=batch_size, - batching_mode=batching_mode, - automatic_batching=automatic_batching, - num_workers=num_workers, - pin_memory=pin_memory, - shuffle=shuffle, - ) + # Get the problem instance from the solver + problem = self.solver.problem + + # Move the unknown parameters to the correct device if they exist + if hasattr(problem, "unknown_parameters"): + for key in problem.unknown_parameters: + problem.unknown_parameters[key] = torch.nn.Parameter( + problem.unknown_parameters[key].data.to(device) + ) def train(self, **kwargs): """ - Manage the training process of the solver. + Fit the solver using the trainer data module. - :param dict kwargs: Additional keyword arguments. See `pytorch-lightning - Trainer API `_ - for details. 
+ :param dict kwargs: Additional keyword arguments forwarded to the + Lightning trainer ``fit`` method. + :return: Result returned by Lightning's ``fit`` method. + :rtype: Any """ return super().fit(self.solver, datamodule=self.data_module, **kwargs) def test(self, **kwargs): """ - Manage the test process of the solver. + Test the solver using the trainer data module. - :param dict kwargs: Additional keyword arguments. See `pytorch-lightning - Trainer API `_ - for details. + :param dict kwargs: Additional keyword arguments forwarded to the + Lightning trainer ``test`` method. + :return: Result returned by Lightning's ``test`` method. + :rtype: Any """ return super().test(self.solver, datamodule=self.data_module, **kwargs) @property def solver(self): """ - Get the solver. + Return the solver attached to the trainer. - :return: The solver. + :return: The solver used by the trainer. :rtype: SolverInterface """ return self._solver @@ -294,86 +294,18 @@ def solver(self): @solver.setter def solver(self, solver): """ - Set the solver. + Set the solver attached to the trainer. - :param SolverInterface solver: The solver to set. + :param SolverInterface solver: The solver instance to attach. """ self._solver = solver - @staticmethod - def _check_input_consistency( - solver, - train_size, - test_size, - val_size, - batching_mode, - automatic_batching, - compile, - ): - """ - Verifies the consistency of the parameters for the solver configuration. - - :param SolverInterface solver: The solver. - :param float train_size: The percentage of elements to include in the - training dataset. - :param float test_size: The percentage of elements to include in the - test dataset. - :param float val_size: The percentage of elements to include in the - validation dataset. - :param str batching_mode: The batching mode to use. Options are - ``"common_batch_size"``, ``"proportional"``, and - ``"separate_conditions"``. 
- :param bool automatic_batching: Whether to perform automatic batching - with PyTorch. - :param bool compile: If ``True``, the model is compiled before training. - """ - - check_consistency(solver, SolverInterface) - check_consistency(train_size, float) - check_consistency(test_size, float) - check_consistency(val_size, float) - check_consistency(batching_mode, str) - if automatic_batching is not None: - check_consistency(automatic_batching, bool) - if compile is not None: - check_consistency(compile, bool) - - @staticmethod - def _check_consistency_and_set_defaults( - pin_memory, num_workers, shuffle, batch_size - ): - """ - Checks the consistency of input parameters and sets default values - for missing or invalid parameters. - - :param bool pin_memory: Whether to use pinned memory for faster data - transfer to GPU. - :param int num_workers: The number of worker threads for data loading. - :param bool shuffle: Whether to shuffle the data during training. - :param int batch_size: The number of samples per batch to load. - """ - if pin_memory is not None: - check_consistency(pin_memory, bool) - else: - pin_memory = False - if num_workers is not None: - check_consistency(num_workers, int) - else: - num_workers = 0 - if shuffle is not None: - check_consistency(shuffle, bool) - else: - shuffle = True - if batch_size is not None: - check_consistency(batch_size, int) - return pin_memory, num_workers, shuffle, batch_size - @property def compile(self): """ - Whether compilation is required or not. + Return whether model compilation is enabled. - :return: ``True`` if compilation is required, ``False`` otherwise. + :return: ``True`` if compilation is enabled, otherwise ``False``. :rtype: bool """ return self._compile @@ -381,9 +313,8 @@ def compile(self): @compile.setter def compile(self, value): """ - Setting the value of compile. + Set the value of compile. :param bool value: Whether compilation is required or not. 
""" - check_consistency(value, bool) self._compile = value diff --git a/pina/_src/data/aggregator.py b/pina/_src/data/aggregator.py index 605af5d46..d6e149a3f 100644 --- a/pina/_src/data/aggregator.py +++ b/pina/_src/data/aggregator.py @@ -1,61 +1,87 @@ -""" -Aggregator for multiple dataloaders. -""" +"""Utility class for aggregating multiple dataloaders into a single iterable.""" class _Aggregator: """ - The class :class:`_Aggregator` is responsible for aggregating multiple - dataloaders into a single iterable object. It supports different batching - modes to accommodate various training requirements. + Aggregate multiple dataloaders into a unified iterable object. + + The aggregator combines batches produced by multiple dataloaders according + to the selected batching strategy. It is primarily used to coordinate the + iteration of multiple training conditions within a single training loop. """ def __init__(self, dataloaders, batching_mode): """ Initialization of the :class:`_Aggregator` class. - :param dataloaders: A dictionary mapping condition names to their - respective dataloaders. - :type dataloaders: dict[str, DataLoader] - :param batching_mode: The batching mode to use. Options are - ``"common_batch_size"``, ``"proportional"``, and - ``"separate_conditions"``. - :type batching_mode: str + :param dict[str, DataLoader] dataloaders: The mapping between condition + names and their corresponding dataloaders. + :param str batching_mode: The strategy used to aggregate batches across + dataloaders. Available options are ``"common_batch_size"`` for + uniform batch sizes across conditions, ``"proportional"`` for batch + sizes proportional to dataset sizes, and ``"separate_conditions"`` + for iterating through each condition separately. + :raises NotImplementedError: If the selected batching mode is not yet + implemented. 
""" + # Raise not implemented error for separate_conditions mode + if batching_mode == "separate_conditions": + raise NotImplementedError( + "Batching mode 'separate_conditions' is not implemented yet." + ) + + # Initialize attributes self.dataloaders = dataloaders self.batching_mode = batching_mode def __len__(self): """ - Return the length of the aggregated dataloader. + Return the length of the aggregated dataloader. The length is determined + by the number of iterations required to exhaust the dataloaders based on + the selected batching mode. + + For ``"separate_conditions"``, the total number of iterations is the sum + of the lengths of all dataloaders. For all other batching modes, the + length corresponds to the maximum length among the aggregated + dataloaders. :return: The length of the aggregated dataloader. :rtype: int """ + # Separate conditions case if self.batching_mode == "separate_conditions": return sum(len(dl) for dl in self.dataloaders.values()) + return max(len(dl) for dl in self.dataloaders.values()) def __iter__(self): """ - Return an iterator over the aggregated dataloader. + Iterate over the aggregated dataloaders. - :return: An iterator over the aggregated dataloader. - :rtype: iterator - """ - if self.batching_mode == "separate_conditions": - # TODO: implement separate_conditions batching mode - raise NotImplementedError( - "Batching mode 'separate_conditions' is not implemented yet." - ) + At each iteration, a dictionary containing one batch per dataloader is + yielded. If a dataloader is exhausted before the others, its iterator is + restarted automatically to ensure continuous batch generation. + :yield: The dictionary mapping each condition name to its batch. 
+ :rtype: Iterator[dict[str, Any]] + """ + # Initialize iterators for each dataloader iterators = {name: iter(dl) for name, dl in self.dataloaders.items()} + + # Iterate until the maximum number of iterations is reached for _ in range(len(self)): batch = {} - for name, it in iterators.items(): + + # Generate a batch for each dataloader + for name, dataloader in self.dataloaders.items(): + + # Attempt to get the next batch from the dataloader's iterator try: - batch[name] = next(it) + batch[name] = next(iterators[name]) + + # Restart the iterator if it is exhausted except StopIteration: - iterators[name] = iter(self.dataloaders[name]) + iterators[name] = iter(dataloader) batch[name] = next(iterators[name]) + yield batch diff --git a/pina/_src/data/condition_subset.py b/pina/_src/data/condition_subset.py new file mode 100644 index 000000000..068e833a2 --- /dev/null +++ b/pina/_src/data/condition_subset.py @@ -0,0 +1,101 @@ +"""Utilities for handling condition dataset subsets.""" + +from torch_geometric.data import Batch +from pina._src.core.graph import LabelBatch, Graph + + +class _ConditionSubset: + """ + Wrapper around a condition dataset restricted to a subset of indices. + + The class behaves similarly to :class:`torch.utils.data.Subset` and supports + cyclic indexing together with optional automatic batching. + """ + + def __init__(self, condition, indices, automatic_batching): + """ + Initialization of the :class:`_ConditionSubset` class. + + :param BaseCondition condition: The underlying condition. + :param list[int] indices: The list of indices identifying the subset + samples. + :param bool automatic_batching: Whether dataset items should be returned + directly or as raw indices. 
+ """ + super().__init__() + + # Initialize the class attributes + self.condition = condition + self.indices = indices + self.automatic_batching = automatic_batching + + # Actual number of samples contained in the subset + self.dataset_length = len(self.indices) + + # Effective iterable length used and modified during batching + self.iterable_length = self.dataset_length + + def __len__(self): + """ + Return the effective iterable length of the subset. + + :return: The number of accessible elements in the subset. + :rtype: int + """ + return self.iterable_length + + def __getitem__(self, idx): + """ + Retrieve an element from the subset. + + If the requested index exceeds the actual dataset size, cyclic indexing + is applied through modulo wrapping. When automatic batching is disabled, + the raw dataset index is returned instead of the corresponding sample. + + :param int idx: The position of the element inside the subset. + :return: The dataset sample or raw dataset index depending on the + batching configuration. + :rtype: dict | int + """ + # Apply cyclic indexing if the requested index exceeds the subset length + if idx >= self.dataset_length: + idx = idx % self.dataset_length + + # Fetch the corresponding dataset index from the list of indices + idx = self.indices[idx] + + # Return the raw dataset index if automatic batching is disabled + if not self.automatic_batching: + return idx + + return self.condition[idx] + + def get_all_data(self): + """ + Retrieve and aggregate all subset samples. + + If the returned data contains a ``"data"`` field composed of graph + objects, the samples are merged into a single batched graph structure + using the appropriate batching implementation. + + :return: The aggregated subset data. 
+ :rtype: dict + """ + # Fetch the data corresponding to the subset indices + data = self.condition[self.indices] + + # Data as a list of graph objects merged into a single batched graph + if "data" in data and isinstance(data["data"], list): + + # Define the batching function + batch_fn = ( + LabelBatch.from_data_list + if isinstance(data["data"][0], Graph) + else Batch.from_data_list + ) + + # Merge the list of graph objects into a single batched graph + data["data"] = batch_fn(data["data"]) + data = {"input": data["data"], "target": data["data"].y} + + return data diff --git a/pina/_src/data/creator.py b/pina/_src/data/creator.py index 95140082b..4a5e3207b 100644 --- a/pina/_src/data/creator.py +++ b/pina/_src/data/creator.py @@ -1,18 +1,15 @@ -""" -Module defining the Creator class, responsible for creating dataloaders -for multiple conditions with various batching strategies. -""" +"""Module for creating dataloaders for multiple conditions.""" import torch -from torch.utils.data import RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler class _Creator: """ - The class :class:`_Creator` is responsible for creating dataloaders for - multiple conditions based on specified batching strategies. It supports - different batching modes to accommodate various training requirements. + Utility class for creating data loaders associated with multiple conditions. + + The class supports different batching strategies to adapt data loading + behavior to specific training requirements """ def __init__( @@ -28,28 +25,29 @@ def __init__( """ Initialization of the :class:`_Creator` class. - :param batching_mode: The batching mode to use. Options are - ``"common_batch_size"``, ``"proportional"``, and - ``"separate_conditions"``. - :type batching_mode: str - :param batch_size: The batch size to use for dataloaders. If - ``batching_mode`` is ``"proportional"``, this represents the total - batch size across all conditions. 
- :type batch_size: int | None - :param shuffle: Whether to shuffle the data in the dataloaders. - :type shuffle: bool - :param automatic_batching: Whether to use automatic batching in the - dataloaders. - :type automatic_batching: bool - :param num_workers: The number of worker processes to use for data + :param str batching_mode: The strategy used to aggregate batches across + data loaders. Available options are ``"common_batch_size"`` for + uniform batch sizes across conditions, ``"proportional"`` for batch + sizes proportional to dataset sizes, and ``"separate_conditions"`` + for iterating through each condition separately. + :param int batch_size: Batch size configuration used by the selected + batching strategy. For ``"common_batch_size"``, the same batch size + is assigned to all conditions. For ``"proportional"``, this value + represents the total batch size distributed proportionally across + conditions. For ``"separate_conditions"``, this value is applied + independently to each condition and capped by the corresponding + dataset size. + :param bool shuffle: Whether samples should be shuffled during loading. + :param bool automatic_batching: Whether automatic batching should be + enabled in the data loaders. + :param int num_workers: The number of worker processes used for data loading. - :type num_workers: int - :param pin_memory: Whether to pin memory in the dataloaders. - :type pin_memory: bool - :param conditions: A dictionary mapping condition names to their - respective condition objects. - :type conditions: dict[str, Condition] + :param bool pin_memory: Whether data loaders should pin memory. + :param dict[str, BaseCondition] conditions: The mapping between + condition names and condition objects responsible for data loader + creation. 
""" + # Initialize attributes self.batching_mode = batching_mode self.batch_size = batch_size self.shuffle = shuffle @@ -58,132 +56,185 @@ def __init__( self.pin_memory = pin_memory self.conditions = conditions + def __call__(self, datasets): + """ + Create data loaders for all provided datasets. + + Batch sizes are computed according to the selected batching mode, and a + dedicated data loader is created for each condition. + + :param dict[str, _ConditionSubset] datasets: The mapping between + condition names and datasets. + :return: The mapping between condition names and the corresponding + data loaders. + :rtype: dict[str, DataLoader] + """ + # Compute batch sizes per condition based on batching_mode + batch_sizes = self._compute_batch_sizes(datasets) + dataloaders = {} + + # If common_batch_size mode, ensure all datasets have the same length + if self.batching_mode == "common_batch_size": + iterable_length = max(len(dataset) for dataset in datasets.values()) + + # Iterate through datasets and create dataloaders + for name, dataset in datasets.items(): + + # If common_batch_size mode, set max_len for datasets + if ( + self.batching_mode == "common_batch_size" + and dataset.dataset_length != batch_sizes[name] + ): + dataset.iterable_length = iterable_length + + # Create dataloader for the current condition + dataloaders[name] = self.conditions[name].create_dataloader( + dataset=dataset, + batch_size=batch_sizes[name], + automatic_batching=self.automatic_batching, + sampler=self._define_sampler(dataset, self.shuffle), + num_workers=self.num_workers, + pin_memory=self.pin_memory, + ) + + return dataloaders + def _define_sampler(self, dataset, shuffle): + """ + Define the sampling strategy for a dataset. + + Distributed training uses :class:`DistributedSampler`, while + non-distributed execution uses either :class:`RandomSampler` or + :class:`SequentialSampler` depending on ``shuffle``. + + :param _ConditionSubset dataset: The dataset associated with the + sampler. 
+ :param bool shuffle: Whether samples should be shuffled during loading. + :return: The configured sampler instance. + :rtype: Sampler + """ + # Distributed training case if torch.distributed.is_initialized(): return DistributedSampler(dataset, shuffle=shuffle) + + # Non-distributed training case - shuffle True if shuffle: - return RandomSampler(dataset) - return SequentialSampler(dataset) + return torch.utils.data.RandomSampler(dataset) + + # Non-distributed training case - shuffle False + return torch.utils.data.SequentialSampler(dataset) def _compute_batch_sizes(self, datasets): """ - Compute batch sizes for each condition based on the specified - batching mode. + Compute batch sizes for each dataset according to the selected batching + mode. - :param datasets: A dictionary mapping condition names to their - respective datasets. - :type datasets: dict[str, Dataset] - :return: A dictionary mapping condition names to their computed batch - sizes. + :param dict[str, _ConditionSubset] datasets: The mapping between + condition names and datasets. + :return: The mapping between condition names and computed batch sizes. 
:rtype: dict[str, int] """ - batch_sizes = {} + # Common batch size mode if self.batching_mode == "common_batch_size": - if self.batch_size is None: - batch_size = max( - dataset.length for dataset in datasets.values() - ) - else: - batch_size = self.batch_size + # Compute batch size + batch_size = ( + max(dataset.dataset_length for dataset in datasets.values()) + if self.batch_size is None + else self.batch_size + ) + + return { + name: min(batch_size, len(dataset)) + for name, dataset in datasets.items() + } - for name in datasets.keys(): - batch_sizes[name] = min(batch_size, len(datasets[name])) - return batch_sizes + # Proportional batch size mode if self.batching_mode == "proportional": return self._compute_proportional_batch_sizes(datasets) - if self.batching_mode == "separate_conditions": - for name in datasets.keys(): - condition = self.conditions[name] - if self.batch_size is None: - batch_sizes[name] = len(datasets[name]) - else: - batch_sizes[name] = min( - self.batch_size, len(datasets[name]) - ) - return batch_sizes - raise ValueError(f"Unknown batching mode: {self.batching_mode}") + + # Separate conditions mode + return { + name: ( + len(dataset) + if self.batch_size is None + else min(self.batch_size, len(dataset)) + ) + for name, dataset in datasets.items() + } def _compute_proportional_batch_sizes(self, datasets): """ - Compute batch sizes for each condition proportionally based on the - size of their datasets. - :param datasets: A dictionary mapping condition names to their - respective datasets. - :type datasets: dict[str, Dataset] - :return: A dictionary mapping condition names to their computed batch + Compute batch sizes proportionally to dataset sizes. + + Each dataset receives a fraction of the total batch size proportional to + its number of samples, while ensuring that each dataset contributes at + least one sample. + + :param dict[str, _ConditionSubset] datasets: The mapping between + condition names and datasets. 
+ :return: The mapping between condition names and proportional batch sizes. :rtype: dict[str, int] """ - # Compute number of elements per dataset - elements_per_dataset = { - dataset_name: len(dataset) - for dataset_name, dataset in datasets.items() + # Compute the sizes of each dataset + dataset_sizes = { + name: len(dataset) for name, dataset in datasets.items() } - # Compute the total number of elements - total_elements = sum(el for el in elements_per_dataset.values()) - # Compute the portion of each dataset - portion_per_dataset = { - name: el / total_elements - for name, el in elements_per_dataset.items() - } - # Compute batch size per dataset. Ensure at least 1 element per - # dataset. - batch_size_per_dataset = { - name: max(1, int(portion * self.batch_size)) - for name, portion in portion_per_dataset.items() + + # Determine the total number of elements across all datasets + total_size = sum(dataset_sizes.values()) + + # Compute the batch sizes + batch_sizes = { + name: max(1, int(self.batch_size * (size / total_size))) + for name, size in dataset_sizes.items() } - # Adjust batch sizes to match the specified total batch size - tot_el_per_batch = sum(el for el in batch_size_per_dataset.values()) - if self.batch_size > tot_el_per_batch: - difference = self.batch_size - tot_el_per_batch - while difference > 0: - for k, v in batch_size_per_dataset.items(): - if difference == 0: - break - if v > 1: - batch_size_per_dataset[k] += 1 - difference -= 1 - if self.batch_size < tot_el_per_batch: - difference = tot_el_per_batch - self.batch_size - while difference > 0: - for k, v in batch_size_per_dataset.items(): - if difference == 0: - break - if v > 1: - batch_size_per_dataset[k] -= 1 - difference -= 1 - return batch_size_per_dataset - def __call__(self, datasets): - """ - Create dataloaders for each condition based on the specified batching - mode. - :param datasets: A dictionary mapping condition names to their - respective datasets. 
- :type datasets: dict[str, Dataset] - :return: A dictionary mapping condition names to their created - dataloaders. - :rtype: dict[str, DataLoader] - """ - # Compute batch sizes per condition based on batching_mode - batch_sizes = self._compute_batch_sizes(datasets) - dataloaders = {} - if self.batching_mode == "common_batch_size": - max_len = max(len(dataset) for dataset in datasets.values()) + # Compute assigned batch size and difference with the total batch size + assigned_batch_size = sum(batch_sizes.values()) + difference = self.batch_size - assigned_batch_size - for name, dataset in datasets.items(): - if ( - self.batching_mode == "common_batch_size" - and dataset.length != batch_sizes[name] - ): - dataset.max_len = max_len - dataloaders[name] = self.conditions[name].create_dataloader( - dataset=dataset, - batch_size=batch_sizes[name], - automatic_batching=self.automatic_batching, - sampler=self._define_sampler(dataset, self.shuffle), - num_workers=self.num_workers, - pin_memory=self.pin_memory, + # If difference > 0, distribute to datasets with more than 1 sample + if difference > 0: + + # Sort datasets by size in descending order + sorted_datasets = sorted( + dataset_sizes, + key=lambda name: dataset_sizes[name], + reverse=True, ) - return dataloaders + + # Distribute to datasets with more than 1 sample + for name in sorted_datasets: + + # Stop distribution when the difference is fully allocated + if difference == 0: + break + + # Distribute to datasets with more than 1 sample + if dataset_sizes[name] > 1: + batch_sizes[name] += 1 + difference -= 1 + + # If difference < 0, reduce from datasets with more than 1 sample + if difference < 0: + + # Sort batches by size in descending order + sorted_batches = sorted( + batch_sizes, key=lambda name: batch_sizes[name], reverse=True + ) + + # Reduce from datasets with more than 1 sample + for name in sorted_batches: + + # Stop reduction when the difference is fully allocated + if difference == 0: + break + + # 
Reduce from datasets with more than 1 sample + if batch_sizes[name] > 1: + batch_sizes[name] -= 1 + difference += 1 + + return batch_sizes diff --git a/pina/_src/data/data_module.py b/pina/_src/data/data_module.py index 4a5b2c66a..c5d3804a5 100644 --- a/pina/_src/data/data_module.py +++ b/pina/_src/data/data_module.py @@ -1,253 +1,189 @@ """ -This module contains the PinaDataModule class, which extends the -LightningDataModule class to allow proper creation and management of -different types of Datasets defined in PINA. +Utilities for creating and managing datasets and dataloaders. + +This module defines a custom extension of the Lighting DataModule used to handle +dataset splitting, batching, and dataloader creation for PINA conditions. """ import warnings -from lightning.pytorch import LightningDataModule import torch -from torch_geometric.data import Batch -from pina._src.data.creator import _Creator -from pina._src.core.graph import LabelBatch, Graph +from lightning.pytorch import LightningDataModule +from pina._src.data.condition_subset import _ConditionSubset from pina._src.data.aggregator import _Aggregator +from pina._src.data.creator import _Creator -class _ConditionSubset: - """ - This class extends the :class:`torch.utils.data.Subset` class, allowing to - fetch the data from the dataset based on a list of indices. +class DataModule(LightningDataModule): """ + An extension of the Lightning data module for managing PINA condition + datasets. - def __init__(self, condition, indices, automatic_batching): - super().__init__() - self.condition = condition - self.indices = indices - self.automatic_batching = automatic_batching - self.length = len(self.indices) - self.max_len = self.length - - def __len__(self): - return self.max_len - - def __getitem__(self, idx): - """ - Fetch the data from the dataset based on the list of indices. - - :param int idx: The index of the data to be fetched. - :return: The data corresponding to the given index. 
- :rtype: dict - """ - if idx >= self.length: - idx = idx % self.length - idx = self.indices[idx] - if not self.automatic_batching: - return idx - return self.condition[idx] - - def get_all_data(self): - data = self.condition[self.indices] - if "data" in data and isinstance(data["data"], list): - batch_fn = ( - LabelBatch.from_data_list - if isinstance(data["data"][0], Graph) - else Batch.from_data_list - ) - data["data"] = batch_fn(data["data"]) - data = { - "input": data["data"], - "target": data["data"].y, - } - return data - + The data module handles train/validation/test dataset splitting, condition + subset creation, dataloader construction, and batching coordination across + multiple conditions. -class PinaDataModule(LightningDataModule): - """ - This class extends :class:`~lightning.pytorch.core.LightningDataModule`, - allowing proper creation and management of different types of datasets - defined in PINA. + Dataset splitting is performed independently for each condition, and the + resulting subsets are wrapped into :class:`_ConditionSubset` objects. + Dataloaders are then created and aggregated according to the selected + batching strategy. """ def __init__( self, problem, - train_size=0.7, - test_size=0.2, - val_size=0.1, - batch_size=None, - shuffle=True, - batching_mode="common_batch_size", - automatic_batching=None, - num_workers=0, - pin_memory=False, + train_size, + val_size, + test_size, + batch_size, + batching_mode, + automatic_batching, + shuffle, + num_workers, + pin_memory, ): """ - Initialize the object and creating datasets based on the input problem. - - :param BaseProblem problem: The problem containing the data on which - to create the datasets and dataloaders. - :param float train_size: Fraction of elements in the training split. It - must be in the range [0, 1]. - :param float test_size: Fraction of elements in the test split. It must - be in the range [0, 1]. - :param float val_size: Fraction of elements in the validation split. 
It - must be in the range [0, 1]. - :param int batch_size: The batch size used for training. If ``None``, - the entire dataset is returned in a single batch. - Default is ``None``. - :param bool shuffle: Whether to shuffle the dataset before splitting. - Default ``True``. - :param str batching_mode: The batching mode to use. Options are - ``"common_batch_size"``, ``"proportional"``, and - ``"separate_conditions"``. Default is ``"common_batch_size"``. - :param automatic_batching: If ``True``, automatic PyTorch batching - is performed, which consists of extracting one element at a time - from the dataset and collating them into a batch. This is useful - when the dataset is too large to fit into memory. On the other hand, - if ``False``, the items are retrieved from the dataset all at once - avoind the overhead of collating them into a batch and reducing the - ``__getitem__`` calls to the dataset. This is useful when the - dataset fits into memory. Avoid using automatic batching when - ``batch_size`` is large. Default is ``False``. - :param int num_workers: Number of worker threads for data loading. - Default ``0`` (serial loading). - :param bool pin_memory: Whether to use pinned memory for faster data - transfer to GPU. Default ``False``. - - :raises ValueError: If at least one of the splits is negative. - :raises ValueError: If the sum of the splits is different from 1. - - .. seealso:: - For more information on multi-process data loading, see: - https://pytorch.org/docs/stable/data.html#multi-process-data-loading - - For details on memory pinning, see: - https://pytorch.org/docs/stable/data.html#memory-pinning + Initialization of the :class:`DataModule` class. + + :param BaseProblem problem: The problem containing the conditions and + sampled data used to construct datasets and dataloaders. + :param float train_size: The fraction of samples assigned to the + training split. Must belong to the interval ``[0, 1]``. 
+ :param float val_size: The fraction of samples assigned to the + validation split. Must belong to the interval ``[0, 1]``. + :param float test_size: The fraction of samples assigned to the test + split. Must belong to the interval ``[0, 1]``. + :param int batch_size: The number of samples per batch. If ``None``, the + entire dataset is processed as a single batch. + :param str batching_mode: The strategy used to aggregate batches across + dataloaders. Available options are ``"common_batch_size"`` for + uniform batch sizes across conditions, ``"proportional"`` for batch + sizes proportional to dataset sizes, and ``"separate_conditions"`` + for iterating through each condition separately. + :param bool automatic_batching: Whether PyTorch automatic batching + should be enabled. If ``True``, dataset elements are retrieved + individually and collated into batches by the dataloader. + If ``False``, entire subsets are retrieved directly from the + condition object. + :param bool shuffle: Whether condition samples should be shuffled before + splitting. + :param int num_workers: The number of worker processes used by + dataloaders. + :param bool pin_memory: Whether pinned memory should be enabled during + data loading. + :raises UserWarning: If ``num_workers`` is set to non-default value + while ``batch_size`` is None. + :raises UserWarning: If ``pin_memory`` is set to ``True`` while + ``batch_size`` is None. 
""" super().__init__() + # Initialize the attributes -- consistency checked in trainer self.problem = problem - # Store fixed attributes self.batch_size = batch_size - self.shuffle = shuffle self.batching_mode = batching_mode self.automatic_batching = automatic_batching - self.batching_mode = batching_mode + self.shuffle = shuffle + self.num_workers = num_workers + self.pin_memory = pin_memory # If batch size is None, num_workers has no effect if batch_size is None and num_workers != 0: - warnings.warn( - "Setting num_workers when batch_size is None has no effect on " - "the DataLoading process." - ) + warnings.warn("num_workers has no effect when batch_size is None.") self.num_workers = 0 - else: - self.num_workers = num_workers # If batch size is None, pin_memory has no effect if batch_size is None and pin_memory: - warnings.warn( - "Setting pin_memory to True has no effect when " - "batch_size is None." - ) + warnings.warn("pin_memory has no effect when batch_size is None.") self.pin_memory = False - else: - self.pin_memory = pin_memory + + # Move domain discretisation into conditions subsets self.problem.move_discretisation_into_conditions() - self._check_slit_sizes(train_size, test_size, val_size) - - # TODO: singular forms (train_dataset, val_dataset, test_dataset) seem - # to be unused. Clean code. 
- if train_size > 0: - self.train_dataset = None - else: - # Use the super method to create the train dataloader which - # raises NotImplementedError + + # If no splits are defined, use the default dataloaders + if train_size == 0: self.train_dataloader = super().train_dataloader - if test_size > 0: - self.test_dataset = None - else: - # Use the super method to create the train dataloader which - # raises NotImplementedError - self.test_dataloader = super().test_dataloader - if val_size > 0: - self.val_dataset = None - else: - # Use the super method to create the train dataloader which - # raises NotImplementedError + if val_size == 0: self.val_dataloader = super().val_dataloader + if test_size == 0: + self.test_dataloader = super().test_dataloader - self._create_condition_splits(problem, train_size, test_size, val_size) + # Otherwise, create the condition splits and initialize the creator + self._create_condition_splits(train_size, test_size) self.creator = _Creator( - batching_mode=batching_mode, - batch_size=batch_size, - shuffle=shuffle, - automatic_batching=automatic_batching, - num_workers=num_workers, - pin_memory=pin_memory, - conditions=problem.conditions, + batching_mode=self.batching_mode, + batch_size=self.batch_size, + shuffle=self.shuffle, + automatic_batching=self.automatic_batching, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + conditions=self.problem.conditions, ) - @staticmethod - def _check_slit_sizes(train_size, test_size, val_size): + def _create_condition_splits(self, train_size, test_size): """ - Check if the splits are correct. The splits sizes must be positive and - the sum of the splits must be 1. + Create train/validation/test index splits for each condition. - :param float train_size: The size of the training split. - :param float test_size: The size of the testing split. - :param float val_size: The size of the validation split. 
+ Samples belonging to each condition are optionally shuffled before being + partitioned into train, validation, and test subsets according to the + specified split fractions. - :raises ValueError: If at least one of the splits is negative. - :raises ValueError: If the sum of the splits is different - from 1. + :param float train_size: The fraction of samples assigned to the + training split. Must belong to the interval ``[0, 1]``. + :param float test_size: The fraction of samples assigned to the test + split. Must belong to the interval ``[0, 1]``. """ + # Initialize the dictionary to store the split idx for each condition + self.split_idxs = {} - if train_size < 0 or test_size < 0 or val_size < 0: - raise ValueError("The splits must be positive") - if abs(train_size + test_size + val_size - 1) > 1e-6: - raise ValueError("The sum of the splits must be 1") + # Iterate through conditions and create the splits + for condition_name, condition in self.problem.conditions.items(): - def _create_condition_splits( - self, problem, train_size, test_size, val_size - ): - self.split_idxs = {} - for condition_name, condition in problem.conditions.items(): - len_condition = len(condition) - # Create the indices for shuffling and splitting + # Get the total number of samples for the current condition + condition_length = len(condition) + + # Generate shuffled or sequential indices for the condition samples indices = ( - torch.randperm(len_condition).tolist() + torch.randperm(condition_length).tolist() if self.shuffle - else list(range(len_condition)) + else list(range(condition_length)) ) - # Determine split sizes - train_end = int(train_size * len_condition) - test_end = train_end + int(test_size * len_condition) - - # Split indices - train_indices = indices[:train_end] - test_indices = indices[train_end:test_end] - val_indices = indices[test_end:] - splits = {} - splits["train"], splits["test"], splits["val"] = ( - train_indices, - test_indices, - val_indices, - ) - 
self.split_idxs[condition_name] = splits + # Compute the split indices for train, validation, and test subsets + train_end = int(train_size * condition_length) + test_end = train_end + int(test_size * condition_length) + + # Store the computed split indices in the dictionary + self.split_idxs[condition_name] = { + "train": indices[:train_end], + "test": indices[train_end:test_end], + "val": indices[test_end:], + } def setup(self, stage=None): """ - Create the dataset objects for the given stage. - If the stage is "fit", the training and validation datasets are created. - If the stage is "test", the testing dataset is created. + Create dataset subsets for the requested execution stage. - :param str stage: The stage for which to perform the dataset setup. + Depending on the selected stage, it initializes the ``train_datasets``, + the ``val_datasets``, or the ``test_datasets`` attributes. Each dataset + is represented as a mapping between condition names and + :class:`_ConditionSubset` instances. - :raises ValueError: If the stage is neither "fit" nor "test". + :param str stage: The execution stage. Available options are ``"fit"`` + for training/validation and ``"test"`` for testing. If ``None``, both + training/validation and testing datasets are created. + Default is ``None``. + :raises ValueError: If the provided stage is invalid. """ + # Validate the stage argument + if stage not in ("fit", "test", None): + raise ValueError( + f"Invalid stage. Got {stage}, expected either 'fit' or 'test'." 
+ ) + + # Fit stage: create training and validation datasets if stage in ("fit", None): + + # Train dataset self.train_datasets = { name: _ConditionSubset( condition, @@ -258,6 +194,7 @@ def setup(self, stage=None): if len(self.split_idxs[name]["train"]) > 0 } + # Validation dataset self.val_datasets = { name: _ConditionSubset( condition, @@ -268,7 +205,10 @@ def setup(self, stage=None): if len(self.split_idxs[name]["val"]) > 0 } + # Test stage: create testing dataset if stage in ("test", None): + + # Test dataset self.test_datasets = { name: _ConditionSubset( condition, @@ -278,56 +218,61 @@ def setup(self, stage=None): for name, condition in self.problem.conditions.items() if len(self.split_idxs[name]["test"]) > 0 } - if stage not in ("fit", "test", None): - raise ValueError( - f"Invalid stage {stage}. Stage must be either 'fit' or 'test'." - ) + + def transfer_batch_to_device(self, batch, device, _): + """ + Transfer a batch to the target device. + + The method transfers all condition batches contained in the aggregated + batch dictionary to the specified device. + + :param dict batch: The mapping between the condition names and the + condition batches. + :param torch.device device: The target device. + :param _: Placeholder argument, not used. + :return: A list of tuples containing condition names and transferred + batches. + :rtype: list[tuple[str, Any]] + """ + return [ + (condition_name, condition.to(device)) + for condition_name, condition in batch.items() + ] def train_dataloader(self): + """ + Create the aggregated train dataloader. + + :return: The aggregated dataloader coordinating all train condition + dataloaders. + :rtype: _Aggregator + """ return _Aggregator( self.creator(self.train_datasets), batching_mode=self.batching_mode, ) def val_dataloader(self): + """ + Create the aggregated validation dataloader. + + :return: The aggregated dataloader coordinating all validation condition + dataloaders. 
+ :rtype: _Aggregator + """ return _Aggregator( self.creator(self.val_datasets), batching_mode=self.batching_mode ) def test_dataloader(self): + """ + Create the aggregated test dataloader. + + :return: The aggregated dataloader coordinating all test condition + dataloaders. + :rtype: _Aggregator + """ return _Aggregator( self.creator(self.test_datasets), batching_mode=self.batching_mode, ) - - @staticmethod - def _transfer_batch_to_device_dummy(batch, device, dataloader_idx): - """ - Transfer the batch to the device. This method is used when the batch - size is None: batch has already been transferred to the device. - - :param list[tuple] batch: List of tuple where the first element of the - tuple is the condition name and the second element is the data. - :param torch.device device: Device to which the batch is transferred. - :param int dataloader_idx: Index of the dataloader. - :return: The batch transferred to the device. - :rtype: list[tuple] - """ - return batch - - def transfer_batch_to_device(self, batch, device, dataloader_idx): - """ - Transfer the batch to the device. This method is called in the - training loop and is used to transfer the batch to the device. - - :param dict batch: The batch to be transferred to the device. - :param torch.device device: The device to which the batch is - transferred. - :param int dataloader_idx: The index of the dataloader. - :return: The batch transferred to the device. 
- :rtype: list[tuple] - """ - to_return = [] - for condition_name, condition in batch.items(): - to_return.append((condition_name, condition.to(device))) - return to_return diff --git a/pina/_src/data/dataset.py b/pina/_src/data/dataset.py deleted file mode 100644 index dcad84662..000000000 --- a/pina/_src/data/dataset.py +++ /dev/null @@ -1,328 +0,0 @@ -"""Module for the PINA dataset classes.""" - -from abc import abstractmethod, ABC -from torch.utils.data import Dataset -from torch_geometric.data import Data -from pina._src.core.graph import Graph, LabelBatch - -# TODO: the whole file seems to be unused, check if it can be safely deleted. - - -class PinaDatasetFactory: - """ - Factory class for the PINA dataset. - - Depending on the data type inside the conditions, it instanciate an object - belonging to the appropriate subclass of - :class:`~pina.data.dataset.PinaDataset`. The possible subclasses are: - - - :class:`~pina.data.dataset.PinaTensorDataset`, for handling \ - :class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data. - - :class:`~pina.data.dataset.PinaGraphDataset`, for handling \ - :class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data. - """ - - def __new__(cls, conditions_dict, **kwargs): - """ - Instantiate the appropriate subclass of - :class:`~pina.data.dataset.PinaDataset`. - - If a graph is present in the conditions, returns a - :class:`~pina.data.dataset.PinaGraphDataset`, otherwise returns a - :class:`~pina.data.dataset.PinaTensorDataset`. - - :param dict conditions_dict: Dictionary containing all the conditions - to be included in the dataset instance. - :return: A subclass of :class:`~pina.data.dataset.PinaDataset`. - :rtype: PinaTensorDataset | PinaGraphDataset - - :raises ValueError: If an empty dictionary is provided. 
- """ - - # Check if conditions_dict is empty - if len(conditions_dict) == 0: - raise ValueError("No conditions provided") - - # Check is a Graph is present in the conditions - is_graph = cls._is_graph_dataset(conditions_dict) - if is_graph: - # If a Graph is present, return a PinaGraphDataset - return PinaGraphDataset(conditions_dict, **kwargs) - # If no Graph is present, return a PinaTensorDataset - return PinaTensorDataset(conditions_dict, **kwargs) - - @staticmethod - def _is_graph_dataset(conditions_dict): - """ - Check if a graph is present in the conditions (at least one time). - - :param conditions_dict: Dictionary containing the conditions. - :type conditions_dict: dict - :return: True if a graph is present in the conditions, False otherwise. - :rtype: bool - """ - - # Iterate over the conditions dictionary - for v in conditions_dict.values(): - # Iterate over the values of the current condition - for cond in v.values(): - # Check if the current value is a list of Data objects - if isinstance(cond, (Data, Graph, list, tuple)): - return True - return False - - -class PinaDataset(Dataset, ABC): - """ - Abstract class for the PINA dataset which extends the PyTorch - :class:`~torch.utils.data.Dataset` class. It defines the common interface - for :class:`~pina.data.dataset.PinaTensorDataset` and - :class:`~pina.data.dataset.PinaGraphDataset` classes. - """ - - def __init__( - self, conditions_dict, max_conditions_lengths, automatic_batching - ): - """ - Initialize the instance by storing the conditions dictionary, the - maximum number of items per conditions to consider, and the automatic - batching flag. - - :param dict conditions_dict: A dictionary mapping condition names to - their respective data. Each key represents a condition name, and the - corresponding value is a dictionary containing the associated data. - :param dict max_conditions_lengths: Maximum number of data points that - can be included in a single batch per condition. 
- :param bool automatic_batching: Indicates whether PyTorch automatic - batching is enabled in - :class:`~pina.data.data_module.PinaDataModule`. - """ - - # Store the conditions dictionary - self.conditions_dict = conditions_dict - # Store the maximum number of conditions to consider - self.max_conditions_lengths = max_conditions_lengths - # Store length of each condition - self.conditions_length = { - k: len(v["input"]) for k, v in self.conditions_dict.items() - } - # Store the maximum length of the dataset - self.length = max(self.conditions_length.values()) - # Dynamically set the getitem function based on automatic batching - if automatic_batching: - self._getitem_func = self._getitem_int - else: - self._getitem_func = self._getitem_dummy - - def _get_max_len(self): - """ - Returns the length of the longest condition in the dataset. - - :return: Length of the longest condition in the dataset. - :rtype: int - """ - - max_len = 0 - for condition in self.conditions_dict.values(): - max_len = max(max_len, len(condition["input"])) - return max_len - - def __len__(self): - return self.length - - def __getitem__(self, idx): - return self._getitem_func(idx) - - def _getitem_dummy(self, idx): - """ - Return the index itself. This is used when automatic batching is - disabled to postpone the data retrieval to the dataloader. - - :param int idx: Index. - :return: Index. - :rtype: int - """ - - # If automatic batching is disabled, return the data at the given index - return idx - - def _getitem_int(self, idx): - """ - Return the data at the given index in the dataset. This is used when - automatic batching is enabled. - - :param int idx: Index. - :return: A dictionary containing the data at the given index. 
- :rtype: dict - """ - - # If automatic batching is enabled, return the data at the given index - return { - k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()} - for k, v in self.conditions_dict.items() - } - - def get_all_data(self): - """ - Return all data in the dataset. - - :return: A dictionary containing all the data in the dataset. - :rtype: dict - """ - to_return_dict = {} - for condition, data in self.conditions_dict.items(): - len_condition = len( - data["input"] - ) # Length of the current condition - to_return_dict[condition] = self._retrive_data( - data, list(range(len_condition)) - ) # Retrieve the data from the current condition - return to_return_dict - - def fetch_from_idx_list(self, idx): - """ - Return data from the dataset given a list of indices. - - :param list[int] idx: List of indices. - :return: A dictionary containing the data at the given indices. - :rtype: dict - """ - - to_return_dict = {} - for condition, data in self.conditions_dict.items(): - # Get the indices for the current condition - cond_idx = idx[: self.max_conditions_lengths[condition]] - # Get the length of the current condition - condition_len = self.conditions_length[condition] - # If the length of the dataset is greater than the length of the - # current condition, repeat the indices - if self.length > condition_len: - cond_idx = [idx % condition_len for idx in cond_idx] - # Retrieve the data from the current condition - to_return_dict[condition] = self._retrive_data(data, cond_idx) - return to_return_dict - - @abstractmethod - def _retrive_data(self, data, idx_list): - """ - Abstract method to retrieve data from the dataset given a list of - indices. - """ - - -class PinaTensorDataset(PinaDataset): - """ - Dataset class for the PINA dataset with :class:`torch.Tensor` and - :class:`~pina.label_tensor.LabelTensor` data. 
- """ - - # Override _retrive_data method for torch.Tensor data - def _retrive_data(self, data, idx_list): - """ - Retrieve data from the dataset given a list of indices. - - :param dict data: Dictionary containing the data - (only :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor`). - :param list[int] idx_list: indices to retrieve. - :return: Dictionary containing the data at the given indices. - :rtype: dict - """ - - return {k: v[idx_list] for k, v in data.items()} - - @property - def input(self): - """ - Return the input data for the dataset. - - :return: Dictionary containing the input points. - :rtype: dict - """ - return {k: v["input"] for k, v in self.conditions_dict.items()} - - def update_data(self, new_conditions_dict): - """ - Update the dataset with new data. - This method is used to update the dataset with new data. It replaces - the current data with the new data provided in the new_conditions_dict - parameter. - - :param dict new_conditions_dict: Dictionary containing the new data. - :return: None - """ - for condition, data in new_conditions_dict.items(): - if condition in self.conditions_dict: - self.conditions_dict[condition].update(data) - else: - self.conditions_dict[condition] = data - - -class PinaGraphDataset(PinaDataset): - """ - Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data` - and :class:`~pina.graph.Graph` data. - """ - - def _create_graph_batch(self, data): - """ - Create a LabelBatch object from a list of - :class:`~torch_geometric.data.Data` objects. - - :param data: List of items to collate in a single batch. - :type data: list[Data] | list[Graph] - :return: LabelBatch object all the graph collated in a single batch - disconnected graphs. - :rtype: LabelBatch - """ - batch = LabelBatch.from_data_list(data) - return batch - - def create_batch(self, data): - """ - Create a Batch object from a list of :class:`~torch_geometric.data.Data` - objects. 
- - :param data: List of items to collate in a single batch. - :type data: list[Data] | list[Graph] - :return: Batch object. - :rtype: :class:`~torch_geometric.data.Batch` - | :class:`~pina.graph.LabelBatch` - """ - - if isinstance(data[0], Data): - return self._create_graph_batch(data) - return self._create_tensor_batch(data) - - # Override _retrive_data method for graph handling - def _retrive_data(self, data, idx_list): - """ - Retrieve data from the dataset given a list of indices. - - :param dict data: Dictionary containing the data. - :param list[int] idx_list: List of indices to retrieve. - :return: Dictionary containing the data at the given indices. - :rtype: dict - """ - - # Return the data from the current condition - # If the data is a list of Data objects, create a Batch object - # If the data is a list of torch.Tensor objects, create a torch.Tensor - return { - k: ( - self._create_graph_batch([v[i] for i in idx_list]) - if isinstance(v, list) - else v[idx_list] - ) - for k, v in data.items() - } - - @property - def input(self): - """ - Return the input data for the dataset. - - :return: Dictionary containing the input points. - :rtype: dict - """ - return {k: v["input"] for k, v in self.conditions_dict.items()} diff --git a/pina/_src/data/dummy_dataloader.py b/pina/_src/data/dummy_dataloader.py deleted file mode 100644 index c236e9d30..000000000 --- a/pina/_src/data/dummy_dataloader.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -Module containing the ``DummyDataloader`` class -""" - -import torch - - -class DummyDataloader: - """ - A dummy dataloader that returns the entire dataset in a single batch. This - is used when the batch size is ``None``. It supports both distributed and - non-distributed environments. In a distributed environment, it divides the - dataset across processes using the rank and world size, fetching only the - portion of data corresponding to the current process. In a non-distributed - environment, it fetches the entire dataset. 
- """ - - def __init__(self, dataset): - """ - Prepare a dataloader object that returns the entire dataset in a single - batch. Depending on the number of GPUs, the dataset is managed - as follows: - - - **Distributed Environment** (multiple GPUs): Divides dataset across - processes using the rank and world size. Fetches only portion of - data corresponding to the current process. - - **Non-Distributed Environment** (single GPU): Fetches the entire - dataset. - - :param PinaDataset dataset: The dataset object to be processed. - - .. note:: - This dataloader is used when the batch size is ``None``. - """ - - if ( - torch.distributed.is_available() - and torch.distributed.is_initialized() - ): - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - if len(dataset) < world_size: - raise RuntimeError( - "Dimension of the dataset smaller than world size." - " Increase the size of the partition or use a single GPU" - ) - idx, i = [], rank - while i < len(dataset): - idx.append(i) - i += world_size - self.dataset = dataset.fetch_from_idx_list(idx).to_batch() - else: - self.dataset = dataset.get_all_data().to_batch() - - def __iter__(self): - return self - - def __len__(self): - return 1 - - def __next__(self): - return self.dataset diff --git a/pina/_src/data/single_batch_data_loader.py b/pina/_src/data/single_batch_data_loader.py new file mode 100644 index 000000000..bec4cf93e --- /dev/null +++ b/pina/_src/data/single_batch_data_loader.py @@ -0,0 +1,106 @@ +"""Module for the Single-Batch Data Loader class.""" + +import torch + + +class _SingleBatchDataLoader: + """ + Data loader wrapper that returns the entire dataset as a single batch. + + This utility is intended for cases where mini-batching is disabled (e.g. + ``batch_size=None``). The loader yields exactly one batch per iteration. + + In distributed environments, the dataset is automatically partitioned across + processes according to the current rank and world size. 
Each process + receives only its corresponding subset of data. + + In non-distributed environments, the full dataset is returned. + """ + + def __init__(self, dataset): + """ + Initialization of the :class:`_SingleBatchDataLoader` class. + + In distributed training, the dataset indices are split across processes + using the current rank and world size, so that each process receives + only its corresponding subset of data. + + In non-distributed training, the full dataset is loaded. + + The resulting data is converted into a single batch and stored + internally. + + :param dataset: Dataset object. + :raises RuntimeError: If the dataset size is smaller than the number of + distributed processes. + """ + # Initialize the flag to track if the batch has been yielded + self._has_yielded = False + + # Distributed training + if ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + ): + # Get rank and world_size + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + # Raise runtime error if the dataset is smaller than the world size + if len(dataset) < world_size: + raise RuntimeError( + "Dataset size is smaller than the distributed world size. " + "Increase the dataset size or use a single GPU." + ) + + # Select dataset idx assigned to the current distributed process + idx, i = [], rank + while i < len(dataset): + idx.append(i) + i += world_size + + # Fetch the process-specific subset + self.dataset = dataset.fetch_from_idx_list(idx).to_batch() + + # Non-distributed training + else: + self.dataset = dataset.get_all_data().to_batch() + + def __iter__(self): + """ + Return the data loader iterator. + + :return: The data loader instance itself. + :rtype: _SingleBatchDataLoader + """ + # Reset the flag to yield the batch again if iterator is restarted + self._has_yielded = False + return self + + def __len__(self): + """ + Return the number of batches produced by the data loader. 
+ + Since the entire dataset is returned as a single batch, the length is + always ``1``. + + :return: The number of batches. + :rtype: int + """ + return 1 + + def __next__(self): + """ + Return the next batch. + + :return: The dataset converted into a single batch. + :rtype: _BatchManager + """ + # Yield the batch only once per iteration + if self._has_yielded: + raise StopIteration + + # Set the flag to indicate that the batch has been yielded + self._has_yielded = True + + return self.dataset diff --git a/pina/_src/problem/base_problem.py b/pina/_src/problem/base_problem.py index 28f64c54f..3408b7bdf 100644 --- a/pina/_src/problem/base_problem.py +++ b/pina/_src/problem/base_problem.py @@ -296,14 +296,3 @@ def are_all_domains_discretised(self): :rtype: bool """ return all(d in self.discretised_domains for d in self.domains) - - -# Back-compatibility with version 0.2, to be removed soon -class AbstractProblem(BaseProblem): - def __init__(self, *args, **kwargs): - warnings.warn( - "AbstractProblem is deprecated, use BaseProblem instead.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/pina/data/__init__.py b/pina/data/__init__.py index f274d5bd9..1ebcf2b9f 100644 --- a/pina/data/__init__.py +++ b/pina/data/__init__.py @@ -1,14 +1,33 @@ -"""Data management utilities for PINA. - -This module provides specialized Dataset and DataModule implementations -designed to handle physical coordinates, experimental observations, and -graph-structured data within the PINA training pipeline. 
-"""
-
-from pina._src.data.data_module import (
-    PinaDataModule,
-)
+"""Module containing utilities for dataset and data loader management."""
 
 __all__ = [
-    "PinaDataModule",
+    "DataModule",
+    "_SingleBatchDataLoader",
+    "_Aggregator",
+    "_Creator",
+    "_ConditionSubset",
 ]
+
+from pina._src.data.data_module import DataModule
+from pina._src.data.single_batch_data_loader import _SingleBatchDataLoader
+from pina._src.data.aggregator import _Aggregator
+from pina._src.data.creator import _Creator
+from pina._src.data.condition_subset import _ConditionSubset
+
+# Back-compatibility with version 0.2, to be removed soon
+import warnings
+
+_DEPRECATED_IMPORTS = {"PinaDataModule": "DataModule"}
+
+
+def __getattr__(name):
+    if name in _DEPRECATED_IMPORTS:
+        warnings.warn(
+            f"Importing '{name}' from 'pina.data' is deprecated; use "
+            f"pina.data.{_DEPRECATED_IMPORTS[name]} instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return globals()[_DEPRECATED_IMPORTS[name]]
+    # PEP 562: unknown names must raise instead of silently resolving to None
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/pina/problem/__init__.py b/pina/problem/__init__.py
index dd8ae0950..3248c22e5 100644
--- a/pina/problem/__init__.py
+++ b/pina/problem/__init__.py
@@ -1,7 +1,6 @@
 """Module for the Problems."""
 
 __all__ = [
-    "AbstractProblem",  # back-compatibility with version 0.2, to be removed soon
     "ProblemInterface",
     "BaseProblem",
     "SpatialProblem",
@@ -18,4 +17,19 @@
 from pina._src.problem.inverse_problem import InverseProblem
 
 # Back-compatibility with version 0.2, to be removed soon
-from pina._src.problem.base_problem import AbstractProblem
+import warnings
+
+_DEPRECATED_IMPORTS = {"AbstractProblem": "BaseProblem"}
+
+
+def __getattr__(name):
+    if name in _DEPRECATED_IMPORTS:
+        warnings.warn(
+            f"Importing '{name}' from 'pina.problem' is deprecated; use "
+            f"pina.problem.{_DEPRECATED_IMPORTS[name]} instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return globals()[_DEPRECATED_IMPORTS[name]]
+    # PEP 562: unknown names must raise instead of silently resolving to None
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/tests/test_data/test_aggregator.py b/tests/test_data/test_aggregator.py
new file
mode 100644 index 000000000..2f79f213e --- /dev/null +++ b/tests/test_data/test_aggregator.py @@ -0,0 +1,112 @@ +import pytest +from pina.data import _Aggregator + + +""" +Note: this test intentionally avoids relying on the actual DataLoader +implementation in order to keep the test focused on the aggregator logic itself +and independent from the behavior of external classes. The full pipeline is +tested in the DataLoader tests, which ensures that the aggregator works +correctly when used in the intended context. +""" + + +# Define a dummy dataloader for testing purposes +class DummyDataloader: + def __init__(self, data): + self.data = data + + def __iter__(self): + return iter(self.data) + + def __len__(self): + return len(self.data) + + +# Create dataloaders for testing +data_loaders1 = { + "condition_1": DummyDataloader([1, 2, 3]), + "condition_2": DummyDataloader([10, 20, 30]), +} +data_loaders2 = { + "condition_1": DummyDataloader([1, 2]), + "condition_2": DummyDataloader([10, 20, 30, 40, 50]), +} +data_loaders3 = { + "condition_1": DummyDataloader([1]), + "condition_2": DummyDataloader([10, 20, 30]), +} + +# Create expected batches for testing +expected_batches1 = [ + {"condition_1": 1, "condition_2": 10}, + {"condition_1": 2, "condition_2": 20}, + {"condition_1": 3, "condition_2": 30}, +] +expected_batches2 = [ + {"condition_1": 1, "condition_2": 10}, + {"condition_1": 2, "condition_2": 20}, + {"condition_1": 1, "condition_2": 30}, + {"condition_1": 2, "condition_2": 40}, + {"condition_1": 1, "condition_2": 50}, +] +expected_batches3 = [ + {"condition_1": 1, "condition_2": 10}, + {"condition_1": 1, "condition_2": 20}, + {"condition_1": 1, "condition_2": 30}, +] + + +@pytest.mark.parametrize("batching_mode", ["common_batch_size", "proportional"]) +def test_constructor(batching_mode): + + # Create dummy dataloaders + dataloaders = { + "condition_1": DummyDataloader([1, 2, 3]), + "condition_2": DummyDataloader([10, 20]), + } + + # Initialize the aggregator + 
_Aggregator(dataloaders, batching_mode=batching_mode) + + # Should raise NotImplementedError for separate_conditions mode + with pytest.raises(NotImplementedError): + _Aggregator(dataloaders, batching_mode="separate_conditions") + + +@pytest.mark.parametrize("batching_mode", ["common_batch_size", "proportional"]) +def test_len(batching_mode): + + # Create dummy dataloaders + dataloaders = { + "condition_1": DummyDataloader([1, 2]), + "condition_2": DummyDataloader([10, 20, 30]), + } + + # Initialize the aggregator and check its length + aggregator = _Aggregator(dataloaders, batching_mode=batching_mode) + assert len(aggregator) == 3 + + +@pytest.mark.parametrize("batching_mode", ["common_batch_size", "proportional"]) +@pytest.mark.parametrize( + "dataloaders, expected", + [ + (data_loaders1, expected_batches1), + (data_loaders2, expected_batches2), + (data_loaders3, expected_batches3), + ], +) +def test_iter(batching_mode, dataloaders, expected): + + # Initialize the aggregator + aggregator = _Aggregator(dataloaders, batching_mode=batching_mode) + + # Check yielded batches + assert list(aggregator) == expected + + # Check that the number of yielded batches matches len(aggregator) + assert len(expected) == len(aggregator) + + # Check that the aggregator can be iterated multiple times + assert list(aggregator) == expected diff --git a/tests/test_data/test_condition_subset.py b/tests/test_data/test_condition_subset.py new file mode 100644 index 000000000..f7e54afb0 --- /dev/null +++ b/tests/test_data/test_condition_subset.py @@ -0,0 +1,149 @@ +import torch +import pytest +from pina.equation.zoo import FixedValue +from pina import Condition, LabelTensor +from pina.domain import CartesianDomain +from pina.data import _ConditionSubset + +# Define an equation and a domain for testing purposes +equation = FixedValue(value=0.0) +domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) + +# Define input and target tensors for testing purposes +n_val, n_dim = 5, 2 +input_tensor = 
torch.rand(n_val, n_dim) +input_label_tensor = LabelTensor(torch.rand(n_val, n_dim), labels=["x", "y"]) +target_tensor = torch.rand(n_val, n_dim) +cond_vars = torch.rand(n_val, 1) + +# Define conditions for testing purposes +# Domain - equation condition is not tested as __get_item__ is not implemented +input_target_cond = Condition(input=input_tensor, target=target_tensor) +input_equation_cond = Condition(input=input_label_tensor, equation=equation) +data_cond = Condition(input=input_tensor, conditional_variables=cond_vars) + +# Define indexes for testing purposes +indices = torch.randperm(n_val).tolist() + + +@pytest.mark.parametrize("automatic_batching", [True, False]) +@pytest.mark.parametrize("indices", [indices[:3], indices[:2]]) +@pytest.mark.parametrize( + "condition", [input_target_cond, input_equation_cond, data_cond] +) +def test_constructor(condition, automatic_batching, indices): + + # Initialize the condition subset + subset = _ConditionSubset( + condition=condition, + indices=indices, + automatic_batching=automatic_batching, + ) + + # Verify that the attributes are correctly assigned + assert subset.condition is condition + assert subset.indices == indices + assert subset.automatic_batching is automatic_batching + assert subset.dataset_length == len(indices) + assert subset.iterable_length == len(indices) + + +@pytest.mark.parametrize("automatic_batching", [True, False]) +@pytest.mark.parametrize("indices", [indices[:3], indices[:2]]) +@pytest.mark.parametrize( + "condition", [input_target_cond, input_equation_cond, data_cond] +) +def test_len(condition, automatic_batching, indices): + + # Initialize the condition subset + subset = _ConditionSubset( + condition=condition, + indices=indices, + automatic_batching=automatic_batching, + ) + + # Verify that the length of the subset is correctly computed + assert len(subset) == len(indices) + + +@pytest.mark.parametrize("automatic_batching", [True, False]) +@pytest.mark.parametrize("indices", [indices[:3], 
indices[:2]])
+@pytest.mark.parametrize(
+    "condition", [input_target_cond, input_equation_cond, data_cond]
+)
+def test_get_item(condition, automatic_batching, indices):
+
+    # Initialize the condition subset
+    subset = _ConditionSubset(
+        condition=condition,
+        indices=indices,
+        automatic_batching=automatic_batching,
+    )
+
+    # Verify that the correct data is returned for each index in the subset
+    for local_idx in range(len(indices)):
+
+        # Retrieve the true dataset index
+        true_idx = indices[local_idx]
+
+        # If automatic batching, check data equivalence
+        if automatic_batching:
+
+            # Save actual and expected data for debugging purposes
+            actual_data = subset[local_idx].data
+            expected_data = condition[true_idx].data
+
+            # Check that the keys of the returned data match
+            assert actual_data.keys() == expected_data.keys()
+
+            # Check that the values of the returned data are equal
+            for key in actual_data:
+                assert torch.equal(actual_data[key], expected_data[key])
+
+        # Otherwise, check that the raw dataset index is returned
+        else:
+            assert subset[local_idx] == true_idx
+
+    # Check cyclic indexing
+    cyclic_idx = len(indices)
+    true_idx = indices[0]
+
+    # If automatic batching, check data equivalence for cyclic index
+    if automatic_batching:
+        # Re-fetch both sides; the loop variables still hold the last index
+        cyclic_data = subset[cyclic_idx].data
+        true_data = condition[true_idx].data
+        assert cyclic_data.keys() == true_data.keys()
+        # Check that the values of the returned data are equal
+        for key in cyclic_data:
+            assert torch.equal(cyclic_data[key], true_data[key])
+
+    # Otherwise, check that the raw dataset index is returned for cyclic index
+    else:
+        assert subset[cyclic_idx] == true_idx
+
+
+@pytest.mark.parametrize("automatic_batching", [True, False])
+@pytest.mark.parametrize(
+    "condition",
+    [input_target_cond, input_equation_cond, data_cond],
+)
+def test_get_all_data(condition, automatic_batching):
+
+    # Initialize the condition subset
+    subset = _ConditionSubset(
+        condition=condition,
+        indices=indices,
automatic_batching=automatic_batching, + ) + + # Retrieve all data from the subset and check that it matches expected data + data = subset.get_all_data() + expected = condition[indices] + + # Check that the keys of the returned data match + assert data.keys == expected.keys + + # Check that the values of the returned data are equal + for key in data.keys: + assert torch.equal(data.data[key], expected.data[key]) diff --git a/tests/test_data/test_creator.py b/tests/test_data/test_creator.py new file mode 100644 index 000000000..c173e7f29 --- /dev/null +++ b/tests/test_data/test_creator.py @@ -0,0 +1,159 @@ +import torch +import pytest +from pina.data import _Creator + + +""" +Note: this test intentionally avoids relying on the actual Condition and +DataLoader implementations in order to keep the test focused on the creator +logic itself and independent from the behavior of external classes. The full +pipeline is tested in the DataLoader tests, which ensures that the creator works +correctly when used in the intended context. 
+""" + + +# Define a dummy dataset for testing purposes +class DummyDataset: + def __init__(self, data, length=None): + self.data = data + self.dataset_length = len(data) if length is None else length + self.iterable_length = None + + def __len__(self): + return len(self.data) + + +# Define a dummy dataloader for testing purposes +class DummyDataloader: + def create_dataloader( + self, + dataset, + batch_size, + automatic_batching, + sampler, + num_workers, + pin_memory, + ): + return { + "dataset": dataset, + "batch_size": batch_size, + "automatic_batching": automatic_batching, + "sampler": sampler, + "num_workers": num_workers, + "pin_memory": pin_memory, + } + + +# Create dataloaders for testing +dataloaders = { + "dataset_1": DummyDataloader(), + "dataset_2": DummyDataloader(), +} + + +@pytest.mark.parametrize( + "batching_mode", + ["common_batch_size", "separate_conditions", "proportional"], +) +def test_constructor(batching_mode): + + _Creator( + batching_mode=batching_mode, + batch_size=4, + shuffle=False, + automatic_batching=True, + num_workers=0, + pin_memory=False, + conditions=dataloaders, + ) + + +@pytest.mark.parametrize( + "batching_mode, batch_size, expected_batch_sizes, expected_max_len", + [ + ( + "common_batch_size", + 2, + {"dataset_1": 2, "dataset_2": 2}, + {"dataset_1": None, "dataset_2": 4}, + ), + ( + "common_batch_size", + None, + {"dataset_1": 3, "dataset_2": 4}, + {"dataset_1": 4, "dataset_2": None}, + ), + ( + "separate_conditions", + 2, + {"dataset_1": 2, "dataset_2": 2}, + {"dataset_1": None, "dataset_2": None}, + ), + ( + "separate_conditions", + None, + {"dataset_1": 3, "dataset_2": 4}, + {"dataset_1": None, "dataset_2": None}, + ), + ( + "proportional", + 4, + {"dataset_1": 1, "dataset_2": 3}, + {"dataset_1": None, "dataset_2": None}, + ), + ], +) +@pytest.mark.parametrize("shuffle", [True, False]) +def test_call( + batching_mode, + batch_size, + expected_batch_sizes, + expected_max_len, + shuffle, +): + + # Create dummy datasets + 
datasets = { + "dataset_1": DummyDataset([1, 2, 3], length=2), + "dataset_2": DummyDataset([10, 20, 30, 40], length=4), + } + + # Initialize the creator + creator = _Creator( + batching_mode=batching_mode, + batch_size=batch_size, + shuffle=shuffle, + automatic_batching=True, + num_workers=0, + pin_memory=False, + conditions=dataloaders, + ) + + # Call the creator to create dataloaders + created_loaders = creator(datasets) + + # Check that dataloaders are created for all conditions + assert set(created_loaders.keys()) == set(datasets.keys()) + + # Iterate over datasets + for name in datasets: + + # Assert that the dataloader is created with the correct parameters + assert created_loaders[name]["dataset"] is datasets[name] + assert created_loaders[name]["batch_size"] == expected_batch_sizes[name] + assert created_loaders[name]["automatic_batching"] is True + assert created_loaders[name]["num_workers"] == 0 + assert created_loaders[name]["pin_memory"] is False + assert datasets[name].iterable_length == expected_max_len[name] + + # Check that the correct sampler is used based on the shuffle parameter + if shuffle: + assert isinstance( + created_loaders[name]["sampler"], + torch.utils.data.RandomSampler, + ) + else: + assert isinstance( + created_loaders[name]["sampler"], + torch.utils.data.SequentialSampler, + ) diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index 8419a68f2..9bb81fbad 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -1,318 +1,170 @@ import torch import pytest -from pina.data import PinaDataModule - -# from pina.data import PinaTensorDataset, PinaGraphDataset -from pina.problem.zoo import SupervisedProblem +from copy import copy +from pina.problem.zoo import SupervisedProblem, Poisson2DSquareProblem +from pina.data import DataModule, _ConditionSubset from pina.graph import RadiusGraph -# from pina.data import DummyDataloader -from pina._src.data.data_module import 
_ConditionSubset -from pina import Trainer -from pina.solver import SupervisedSolver -from torch_geometric.data import Batch -from torch.utils.data import DataLoader -from pina.problem.zoo import Poisson2DSquareProblem -from pina._src.data.aggregator import _Aggregator -from pina.solver import PINN +# Number of samples in the synthetic datasets +n_samples = 100 + +# Define helper functions to create synthetic tensor data +def _create_tensor_data(n=n_samples): + return (torch.rand((n, 4)), torch.rand((n, 2))) -def _create_tensor_data(): - input_tensor = torch.rand((100, 10)) - output_tensor = torch.rand((100, 2)) - return input_tensor, output_tensor +# Define helper function to create synthetic graph data +def _create_graph_data(n=n_samples): -def _create_graph_data(): - x = torch.rand((100, 50, 10)) - pos = torch.rand((100, 50, 2)) - input_graph = [ - RadiusGraph(x=x_, pos=pos_, radius=0.2) for x_, pos_, in zip(x, pos) + # Define input graphs and output tensor + input_graphs = [ + RadiusGraph(x=torch.rand((20, 4)), pos=torch.rand((20, 2)), radius=0.2) + for _ in range(n) ] - output_graph = torch.rand((100, 50, 2)) - return input_graph, output_graph + output_tensor = torch.rand((n, 50, 2)) + return input_graphs, output_tensor -def test_init_tensor(): - input_tensor, output_tensor = _create_tensor_data() - problem = SupervisedProblem(input_=input_tensor, output_=output_tensor) - dm = PinaDataModule(problem) - assert dm.problem == problem - assert dm.trainer is None - assert hasattr(dm, "split_idxs") - assert isinstance(dm.split_idxs, dict) - assert set(dm.split_idxs.keys()) == {"data"} - assert isinstance(dm.split_idxs["data"], dict) - assert set(dm.split_idxs["data"].keys()) == {"train", "val", "test"} - assert isinstance(dm.split_idxs["data"]["train"], list) - assert isinstance(dm.split_idxs["data"]["val"], list) - assert isinstance(dm.split_idxs["data"]["test"], list) - assert len(dm.split_idxs["data"]["train"]) == 70 - assert len(dm.split_idxs["data"]["val"]) == 
10 - assert len(dm.split_idxs["data"]["test"]) == 20 - - -def test_init_graph(): - input_graph, output_graph = _create_graph_data() - problem = SupervisedProblem(input_=input_graph, output_=output_graph) - dm = PinaDataModule(problem) - assert dm.problem == problem - assert dm.trainer is None - assert hasattr(dm, "split_idxs") - assert isinstance(dm.split_idxs, dict) - assert set(dm.split_idxs.keys()) == {"data"} - assert isinstance(dm.split_idxs["data"], dict) - assert set(dm.split_idxs["data"].keys()) == {"train", "val", "test"} - assert isinstance(dm.split_idxs["data"]["train"], list) - assert isinstance(dm.split_idxs["data"]["val"], list) - assert isinstance(dm.split_idxs["data"]["test"], list) - assert len(dm.split_idxs["data"]["train"]) == 70 - assert len(dm.split_idxs["data"]["val"]) == 10 - assert len(dm.split_idxs["data"]["test"]) == 20 - - -def test_init_poisson(): - problem = Poisson2DSquareProblem() - problem.discretise_domain(n=10, mode="grid") - dm = PinaDataModule(problem) - assert dm.problem == problem - assert dm.trainer is None - assert hasattr(dm, "split_idxs") - assert isinstance(dm.split_idxs, dict) - assert set(dm.split_idxs.keys()) == {"D", "boundary"} - assert isinstance(dm.split_idxs["D"], dict) - assert set(dm.split_idxs["D"].keys()) == {"train", "val", "test"} - assert isinstance(dm.split_idxs["D"]["train"], list) - assert isinstance(dm.split_idxs["D"]["val"], list) - assert isinstance(dm.split_idxs["D"]["test"], list) - assert len(dm.split_idxs["D"]["train"]) == 70 - assert len(dm.split_idxs["D"]["val"]) == 10 - assert len(dm.split_idxs["D"]["test"]) == 20 - - assert isinstance(dm.split_idxs["boundary"], dict) - assert set(dm.split_idxs["boundary"].keys()) == {"train", "val", "test"} - assert isinstance(dm.split_idxs["boundary"]["train"], list) - assert isinstance(dm.split_idxs["boundary"]["val"], list) - assert isinstance(dm.split_idxs["boundary"]["test"], list) - assert len(dm.split_idxs["boundary"]["train"]) == 7 - assert 
len(dm.split_idxs["boundary"]["val"]) == 1 - assert len(dm.split_idxs["boundary"]["test"]) == 2 - - -def test_setup_tensor(): - input_tensor, output_tensor = _create_tensor_data() - problem = SupervisedProblem(input_=input_tensor, output_=output_tensor) - dm = PinaDataModule(problem) - dm.setup() - assert hasattr(dm, "train_datasets") - assert isinstance(dm.train_datasets, dict) - assert set(dm.train_datasets.keys()) == {"data"} - assert isinstance(dm.train_datasets["data"], _ConditionSubset) - assert hasattr(dm, "val_datasets") - assert isinstance(dm.val_datasets, dict) - assert set(dm.val_datasets.keys()) == {"data"} - assert isinstance(dm.val_datasets["data"], _ConditionSubset) - assert hasattr(dm, "test_datasets") - assert isinstance(dm.test_datasets, dict) - assert set(dm.test_datasets.keys()) == {"data"} - assert isinstance(dm.test_datasets["data"], _ConditionSubset) - - -def test_setup_graph(): - input_graph, output_graph = _create_graph_data() - problem = SupervisedProblem(input_=input_graph, output_=output_graph) - dm = PinaDataModule(problem) - dm.setup() - assert hasattr(dm, "train_datasets") - assert isinstance(dm.train_datasets, dict) - assert set(dm.train_datasets.keys()) == {"data"} - assert isinstance(dm.train_datasets["data"], _ConditionSubset) - assert hasattr(dm, "val_datasets") - assert isinstance(dm.val_datasets, dict) - assert set(dm.val_datasets.keys()) == {"data"} - assert isinstance(dm.val_datasets["data"], _ConditionSubset) - assert hasattr(dm, "test_datasets") - assert isinstance(dm.test_datasets, dict) - assert set(dm.test_datasets.keys()) == {"data"} - assert isinstance(dm.test_datasets["data"], _ConditionSubset) - - -def test_setup_poisson(): - problem = Poisson2DSquareProblem() - problem.discretise_domain(n=10, mode="grid") - dm = PinaDataModule(problem) - dm.setup() - assert hasattr(dm, "train_datasets") - assert isinstance(dm.train_datasets, dict) - assert set(dm.train_datasets.keys()) == {"D", "boundary"} - assert 
isinstance(dm.train_datasets["D"], _ConditionSubset) - assert isinstance(dm.train_datasets["boundary"], _ConditionSubset) - assert hasattr(dm, "val_datasets") - assert isinstance(dm.val_datasets, dict) - assert set(dm.val_datasets.keys()) == {"D", "boundary"} - assert isinstance(dm.val_datasets["D"], _ConditionSubset) - assert isinstance(dm.val_datasets["boundary"], _ConditionSubset) - assert hasattr(dm, "test_datasets") - assert isinstance(dm.test_datasets, dict) - assert set(dm.test_datasets.keys()) == {"D", "boundary"} - assert isinstance(dm.test_datasets["D"], _ConditionSubset) - assert isinstance(dm.test_datasets["boundary"], _ConditionSubset) - - -@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) -def test_dataloader_tensor(batch_size): - input_tensor, output_tensor = _create_tensor_data() - problem = SupervisedProblem(input_=input_tensor, output_=output_tensor) - trainer = Trainer( - solver=SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)), - batch_size=batch_size, - train_size=0.7, - val_size=0.2, - test_size=0.1, - ) - dm = trainer.data_module - dm.setup() - dataloader = dm.train_dataloader() - assert isinstance(dataloader, _Aggregator) - data = next(iter(dataloader)) - assert isinstance(data, dict) - assert isinstance(data["data"]["input"], torch.Tensor) - assert isinstance(data["data"]["target"], torch.Tensor) - assert ( - len(data["data"]["input"]) == batch_size - if batch_size is not None - else 70 - ) - dataloader = dm.val_dataloader() - assert isinstance(dataloader, _Aggregator) - data = next(iter(dataloader)) - assert isinstance(data, dict) - assert isinstance(data["data"]["input"], torch.Tensor) - assert isinstance(data["data"]["target"], torch.Tensor) - assert ( - len(data["data"]["input"]) == batch_size - if batch_size is not None - else 10 - ) +# Fixture remove data condition from pinns, caused by external tests in suite +@pytest.fixture(autouse=True) +def remove_data_from_pinn_conditions(): + yield + # Remove the data 
condition + Poisson2DSquareProblem.conditions.pop("data", None) -@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) -def test_dataloader_graph(batch_size): - input_graph, output_graph = _create_graph_data() - problem = SupervisedProblem(input_=input_graph, output_=output_graph) - trainer = Trainer( - solver=SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)), - train_size=0.7, - val_size=0.2, - test_size=0.1, - batch_size=batch_size, - ) - dm = trainer.data_module - dm.setup() - dataloader = dm.train_dataloader() - assert isinstance(dataloader, _Aggregator) - data = next(iter(dataloader)) - assert isinstance(data, dict) - assert isinstance(data["data"]["input"], Batch) - assert isinstance(data["data"]["target"], torch.Tensor) - assert ( - len(data["data"]["input"]) == batch_size - if batch_size is not None - else 70 - ) - dataloader = dm.val_dataloader() - assert isinstance(dataloader, _Aggregator) - data = next(iter(dataloader)) - assert isinstance(data, dict) - assert isinstance(data["data"]["input"], Batch) - assert isinstance(data["data"]["target"], torch.Tensor) - assert ( - len(data["data"]["input"]) == batch_size - if batch_size is not None - else 10 - ) +@pytest.mark.parametrize("problem_type", ["tensor", "graph", "pinn"]) +@pytest.mark.parametrize("batch_size", [None, 5]) +@pytest.mark.parametrize( + "train_size, val_size, test_size", + [(0.7, 0.2, 0.1), (0.8, 0.2, 0.0), (0.0, 0.8, 0.2)], +) +def test_constructor(problem_type, batch_size, train_size, val_size, test_size): + + # Build a tensor problem + if problem_type == "tensor": + input_tensor, output_tensor = _create_tensor_data() + problem = SupervisedProblem(input_=input_tensor, output_=output_tensor) + + # Build a graph problem + elif problem_type == "graph": + input_graph, output_graph = _create_graph_data() + problem = SupervisedProblem(input_=input_graph, output_=output_graph) + # Build a pinn problem + elif problem_type == "pinn": + problem = Poisson2DSquareProblem() + 
problem.discretise_domain(n=n_samples, mode="random") -@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) -def test_dataloader_poisson_cbs(batch_size): - problem = Poisson2DSquareProblem() - problem.discretise_domain(n=10, mode="grid") - trainer = Trainer( - solver=PINN(problem=problem, model=torch.nn.Linear(10, 10)), + # Initialize the data module + dm = DataModule( + problem=problem, + train_size=train_size, + val_size=val_size, + test_size=test_size, batch_size=batch_size, - val_size=0.1, - test_size=0.2, - train_size=0.7, - batching_mode="common_batch_size", + batching_mode="proportional", + automatic_batching=True, + shuffle=True, + num_workers=0, + pin_memory=False, ) - dm = trainer.data_module - dm.setup() - dataloader = dm.train_dataloader() - assert isinstance(dataloader, _Aggregator) - data = next(iter(dataloader)) - assert isinstance(data, dict) - assert isinstance(data["D"]["input"], torch.Tensor) - assert isinstance(data["D"]["input"], torch.Tensor) - assert isinstance(data["boundary"]["input"], torch.Tensor) - assert isinstance(data["boundary"]["input"], torch.Tensor) - assert ( - len(data["D"]["input"]) == batch_size if batch_size is not None else 70 - ) - assert ( - len(data["boundary"]["input"]) == min(batch_size, 7) - if batch_size is not None - else 7 - ) + # Check that the data module has been initialized correctly + assert dm.problem == problem + assert dm.trainer is None - dataloader = dm.val_dataloader() - assert isinstance(dataloader, _Aggregator) - data = next(iter(dataloader)) - assert isinstance(data, dict) - assert isinstance(data["D"]["input"], torch.Tensor) - assert isinstance(data["D"]["input"], torch.Tensor) - assert isinstance(data["boundary"]["input"], torch.Tensor) - assert isinstance(data["boundary"]["input"], torch.Tensor) - assert ( - len(data["D"]["input"]) == min(batch_size, 10) - if batch_size is not None - else 10 - ) - assert ( - len(data["boundary"]["input"]) == min(batch_size, 1) - if batch_size is not None - else 
1 + # Expected keys in the split_idxs dictionary + expected_keys = ( + {"data"} if problem_type in ["tensor", "graph"] else {"D", "boundary"} ) - -@pytest.mark.parametrize("batch_size", [None, 5, 20]) -def test_dataloader_poisson_proportional(batch_size): - problem = Poisson2DSquareProblem() - problem.discretise_domain(n=10, mode="grid") - trainer = Trainer( - solver=PINN(problem=problem, model=torch.nn.Linear(10, 10)), + # Check that the split_idxs attribute has been created correctly + assert hasattr(dm, "split_idxs") + assert isinstance(dm.split_idxs, dict) + assert set(dm.split_idxs.keys()) == expected_keys + + # Iterate over keys in split_idxs + for k in dm.split_idxs.keys(): + + # Assert that the value corresponding to each key is a dictionary + assert isinstance(dm.split_idxs[k], dict) + assert set(dm.split_idxs[k].keys()) == {"train", "val", "test"} + + # Expected lengths of splits + expected_lengths = { + "train": int(train_size * n_samples), + "val": int(val_size * n_samples), + "test": int(test_size * n_samples), + } + + # Iterate over splits + for split in ["train", "val", "test"]: + assert isinstance(dm.split_idxs[k][split], list) + assert len(dm.split_idxs[k][split]) == expected_lengths[split] + + +@pytest.mark.parametrize("problem_type", ["tensor", "graph", "pinn"]) +@pytest.mark.parametrize("batch_size", [None, 5]) +@pytest.mark.parametrize( + "train_size, val_size, test_size", + [(0.7, 0.2, 0.1), (0.8, 0.2, 0.0), (0.0, 0.8, 0.2)], +) +def test_setup(problem_type, batch_size, train_size, val_size, test_size): + + # Build a tensor problem + if problem_type == "tensor": + input_tensor, output_tensor = _create_tensor_data() + problem = SupervisedProblem(input_=input_tensor, output_=output_tensor) + + # Build a graph problem + elif problem_type == "graph": + input_graph, output_graph = _create_graph_data() + problem = SupervisedProblem(input_=input_graph, output_=output_graph) + + # Build a pinn problem + elif problem_type == "pinn": + problem = 
Poisson2DSquareProblem() + problem.discretise_domain(n=n_samples, mode="random") + + # Initialize the data module + dm = DataModule( + problem=problem, + train_size=train_size, + val_size=val_size, + test_size=test_size, batch_size=batch_size, - val_size=0.1, - test_size=0.2, - train_size=0.7, batching_mode="proportional", + automatic_batching=True, + shuffle=True, + num_workers=0, + pin_memory=False, ) - dm = trainer.data_module + + # Call setup dm.setup() - dataloader = dm.train_dataloader() - assert isinstance(dataloader, _Aggregator) - data = next(iter(dataloader)) - assert isinstance(data, dict) - assert isinstance(data["D"]["input"], torch.Tensor) - assert isinstance(data["D"]["input"], torch.Tensor) - assert isinstance(data["boundary"]["input"], torch.Tensor) - assert isinstance(data["boundary"]["input"], torch.Tensor) - assert ( - len(data["D"]["input"]) == batch_size - 1 - if batch_size is not None - else 70 + # Expected keys in each dataset dictionary + expected_keys = ( + {"data"} if problem_type in ["tensor", "graph"] else {"D", "boundary"} ) - assert len(data["boundary"]["input"]) == 1 if batch_size is not None else 7 + + # Iterate over datasets + for dataset in ["train_datasets", "val_datasets", "test_datasets"]: + + # Assert that each dataset has been created correctly + assert hasattr(dm, dataset) + assert isinstance(getattr(dm, dataset), dict) + + # Assert that the keys in each dataset are correct, if not empty + if getattr(dm, dataset): + assert set(getattr(dm, dataset).keys()) == expected_keys + + # Iterate over keys in each dataset + for key in expected_keys: + + # Assert that the corresponding value is a _ConditionSubset + assert isinstance(getattr(dm, dataset)[key], _ConditionSubset) diff --git a/tests/test_data/test_single_batch_data_loader.py b/tests/test_data/test_single_batch_data_loader.py new file mode 100644 index 000000000..14a1aeed2 --- /dev/null +++ b/tests/test_data/test_single_batch_data_loader.py @@ -0,0 +1,111 @@ +import torch
+import pytest +from pina.data import _SingleBatchDataLoader + + +# Initialize the test environment +size = 8 +distributed_rank = 1 +distributed_world_size = 3 +full_data_value = "all" + + +# Helper functions for testing +def _distributed_idx(size): + return list(range(distributed_rank, size, distributed_world_size)) + + +# Helper function to set up the distributed environment for testing +def _setup_distributed_environment(monkeypatch, distribute): + monkeypatch.setattr(torch.distributed, "is_available", lambda: True) + monkeypatch.setattr(torch.distributed, "is_initialized", lambda: distribute) + + if distribute: + monkeypatch.setattr( + torch.distributed, "get_rank", lambda: distributed_rank + ) + monkeypatch.setattr( + torch.distributed, + "get_world_size", + lambda: distributed_world_size, + ) + + +# Create a dummy data class for testing purposes +class DummyData: + def __init__(self, value): + self.value = value + + def to_batch(self): + return self + + +# Create a dummy dataset class for testing purposes +class DummyDataset: + def __init__(self, size=size): + self.size = size + self.fetched_indices = None + self.get_all_data_called = False + + def __len__(self): + return self.size + + def get_all_data(self): + self.get_all_data_called = True + return DummyData(full_data_value) + + def fetch_from_idx_list(self, idx): + self.fetched_indices = idx + return DummyData(idx) + + +@pytest.mark.parametrize("distribute", [True, False]) +def test_constructor(monkeypatch, distribute): + + # Set up distributed mock environment + _setup_distributed_environment(monkeypatch, distribute) + + # Create dataset and data loader + dataset = DummyDataset() + data_loader = _SingleBatchDataLoader(dataset) + + # Distributed case + if distribute: + expected_value = _distributed_idx(size) + assert data_loader.dataset.value == expected_value + assert dataset.fetched_indices == expected_value + + # Non-distributed case + else: + assert data_loader.dataset.value == full_data_value + 
assert dataset.get_all_data_called is True + + # Verify that the data loader yields exactly one batch per iteration + assert len(data_loader) == 1 + + # Should fail if dataset is smaller than world size in distributed case + if distribute: + small_dataset = DummyDataset(size=distributed_world_size - 1) + with pytest.raises(RuntimeError): + _SingleBatchDataLoader(small_dataset) + + +@pytest.mark.parametrize("distribute", [True, False]) +def test_iter(monkeypatch, distribute): + + # Set up distributed mock environment + _setup_distributed_environment(monkeypatch, distribute) + + # Create dataset and data loader + dataset = DummyDataset() + data_loader = _SingleBatchDataLoader(dataset) + + # Iterate through the data loader + batches = list(data_loader) + + # Expected value based on the distributed setting + expected_value = _distributed_idx(size) if distribute else full_data_value + + # Verify iteration behavior + assert len(batches) == 1 + assert batches[0].value == expected_value diff --git a/tests/test_trainer.py b/tests/test_trainer.py new file mode 100644 index 000000000..87353a6b7 --- /dev/null +++ b/tests/test_trainer.py @@ -0,0 +1,213 @@ +import pytest +from pina import Trainer +from pina.solver import PINN +from pina.model import FeedForward +from pina.problem.zoo import Poisson2DSquareProblem + + +# Define the problem, the model and the solver for testing purposes +problem = Poisson2DSquareProblem() +problem.discretise_domain(n=10, mode="random") +model = FeedForward(len(problem.input_variables), len(problem.output_variables)) +solver = PINN(model=model, problem=problem) + + +@pytest.mark.parametrize("batching_mode", Trainer._AVAIL_BATCHING_MODES) +@pytest.mark.parametrize("automatic_batching", [True, False]) +@pytest.mark.parametrize("pin_memory", [True, False]) +@pytest.mark.parametrize("shuffle", [True, False]) +@pytest.mark.parametrize("compile", [True, False]) +@pytest.mark.parametrize("batch_size", [None, 5]) +@pytest.mark.parametrize( + "train_size, 
test_size, val_size", [(0.8, 0.1, 0.1), (0.7, 0.2, 0.1)] +) +def test_constructor( + batch_size, + train_size, + test_size, + val_size, + compile, + batching_mode, + automatic_batching, + pin_memory, + shuffle, +): + + Trainer( + solver=solver, + batch_size=batch_size, + train_size=train_size, + test_size=test_size, + val_size=val_size, + compile=compile, + batching_mode=batching_mode if batch_size else "common_batch_size", + automatic_batching=automatic_batching, + num_workers=0, + pin_memory=pin_memory if batch_size else False, + shuffle=shuffle, + ) + + # Should raise ValueError if solver is not an instance of SolverInterface + with pytest.raises(ValueError): + Trainer( + solver="not_a_solver", + batch_size=batch_size, + train_size=train_size, + test_size=test_size, + val_size=val_size, + compile=compile, + batching_mode=batching_mode if batch_size else "common_batch_size", + automatic_batching=automatic_batching, + num_workers=0, + pin_memory=pin_memory if batch_size else False, + shuffle=shuffle, + ) + + # Should raise ValueError if train_size + test_size + val_size != 1.0 + with pytest.raises(ValueError): + Trainer( + solver=solver, + batch_size=batch_size, + train_size=0.5, + test_size=0.3, + val_size=0.3, + compile=compile, + batching_mode=batching_mode if batch_size else "common_batch_size", + automatic_batching=automatic_batching, + num_workers=0, + pin_memory=pin_memory if batch_size else False, + shuffle=shuffle, + ) + + # Should raise ValueError if compile is not a boolean + with pytest.raises(ValueError): + Trainer( + solver=solver, + batch_size=batch_size, + train_size=train_size, + test_size=test_size, + val_size=val_size, + compile="not_a_boolean", + batching_mode=batching_mode if batch_size else "common_batch_size", + automatic_batching=automatic_batching, + num_workers=0, + pin_memory=pin_memory if batch_size else False, + shuffle=shuffle, + ) + + # Should raise ValueError if automatic_batching is not a boolean + with pytest.raises(ValueError): + 
Trainer( + solver=solver, + batch_size=batch_size, + train_size=train_size, + test_size=test_size, + val_size=val_size, + compile=compile, + batching_mode=batching_mode if batch_size else "common_batch_size", + automatic_batching="not_a_boolean", + num_workers=0, + pin_memory=pin_memory if batch_size else False, + shuffle=shuffle, + ) + + # Should raise ValueError if shuffle is not a boolean + with pytest.raises(ValueError): + Trainer( + solver=solver, + batch_size=batch_size, + train_size=train_size, + test_size=test_size, + val_size=val_size, + compile=compile, + batching_mode=batching_mode if batch_size else "common_batch_size", + automatic_batching=automatic_batching, + num_workers=0, + pin_memory=pin_memory if batch_size else False, + shuffle="not_a_boolean", + ) + + # Should raise ValueError if pin_memory is not a boolean + with pytest.raises(ValueError): + Trainer( + solver=solver, + batch_size=batch_size, + train_size=train_size, + test_size=test_size, + val_size=val_size, + compile=compile, + batching_mode=batching_mode if batch_size else "common_batch_size", + automatic_batching=automatic_batching, + num_workers=0, + pin_memory="not_a_boolean", + shuffle=shuffle, + ) + + # Should raise AssertionError if num_workers is negative + with pytest.raises(AssertionError): + Trainer( + solver=solver, + batch_size=batch_size, + train_size=train_size, + test_size=test_size, + val_size=val_size, + compile=compile, + batching_mode=batching_mode if batch_size else "common_batch_size", + automatic_batching=automatic_batching, + num_workers=-1, + pin_memory=pin_memory if batch_size else False, + shuffle=shuffle, + ) + + # Should raise AssertionError if batch_size is not a positive integer + with pytest.raises(AssertionError): + Trainer( + solver=solver, + batch_size=-1, + train_size=train_size, + test_size=test_size, + val_size=val_size, + compile=compile, + batching_mode=batching_mode if batch_size else "common_batch_size", + automatic_batching=automatic_batching, +
num_workers=0, + pin_memory=pin_memory if batch_size else False, + shuffle=shuffle, + ) + + # Should raise ValueError if an invalid batching mode is provided + with pytest.raises(ValueError): + Trainer( + solver=solver, + batch_size=batch_size, + train_size=train_size, + test_size=test_size, + val_size=val_size, + compile=compile, + batching_mode="invalid_mode", + automatic_batching=automatic_batching, + num_workers=0, + pin_memory=pin_memory if batch_size else False, + shuffle=shuffle, + ) + + # Should raise RuntimeError if any domain has not been discretised + with pytest.raises(RuntimeError): + + # Create a new problem without discretising the domain + new_problem = Poisson2DSquareProblem() + new_solver = PINN(model=model, problem=new_problem) + + Trainer( + solver=new_solver, + batch_size=batch_size, + train_size=train_size, + test_size=test_size, + val_size=val_size, + compile=compile, + batching_mode=batching_mode if batch_size else "common_batch_size", + automatic_batching=automatic_batching, + num_workers=0, + pin_memory=pin_memory if batch_size else False, + shuffle=shuffle, + )