# Summary of this patch (explanatory header only; it precedes the first
# "diff --git" line and belongs to no hunk).
#
# * deepmd/utils/batch_size.py: adds the RetrySignal exception and an
#   AutoBatchSize.oom_retry_mode flag (initialised to False, toggled through
#   set_oom_retry_mode). When the flag is set, execute() still shrinks the
#   batch size after an OOM error but raises RetrySignal (chained to the OOM
#   error) instead of returning (0, None).
# * deepmd/pd/infer/deep_eval.py and deepmd/pt/infer/deep_eval.py:
#   eval_descriptor() and eval_fitting_last_layer() enable retry mode, call
#   self.eval() inside a loop, and start the whole evaluation again when
#   RetrySignal is caught. The model hook and the retry mode are reset in a
#   finally clause on every attempt, including when another exception escapes.
# * source/tests/common/test_oom_retry.py: new unit tests covering both
#   execute() modes and the retry / cleanup behaviour of the four wrappers.
#
# NOTE(review): the input had its line breaks and indentation collapsed; the
# layout below was rebuilt from the hunk line counts, so the whitespace of
# context lines should be confirmed against the target files before applying.
diff --git a/deepmd/pd/infer/deep_eval.py b/deepmd/pd/infer/deep_eval.py
index 6c0ffed7ec..72dcc2bff0 100644
--- a/deepmd/pd/infer/deep_eval.py
+++ b/deepmd/pd/infer/deep_eval.py
@@ -64,6 +64,9 @@
     to_numpy_array,
     to_paddle_tensor,
 )
+from deepmd.utils.batch_size import (
+    RetrySignal,
+)
 from deepmd.utils.econf_embd import (
     sort_element_type,
 )
@@ -823,19 +826,30 @@ def eval_descriptor(
         model = (
             self.dp.model["Default"] if isinstance(self.dp, ModelWrapper) else self.dp
         )
-        model.set_eval_descriptor_hook(True)
-        self.eval(
-            coords,
-            cells,
-            atom_types,
-            atomic=False,
-            fparam=fparam,
-            aparam=aparam,
-            **kwargs,
-        )
-        descriptor = model.eval_descriptor()
-        model.set_eval_descriptor_hook(False)
-        return to_numpy_array(descriptor)
+        while True:
+            if self.auto_batch_size is not None:
+                self.auto_batch_size.set_oom_retry_mode(True)
+            model.set_eval_descriptor_hook(True)
+            retry = False
+            try:
+                self.eval(
+                    coords,
+                    cells,
+                    atom_types,
+                    atomic=False,
+                    fparam=fparam,
+                    aparam=aparam,
+                    **kwargs,
+                )
+                descriptor = model.eval_descriptor()
+            except RetrySignal:
+                retry = True
+            finally:
+                model.set_eval_descriptor_hook(False)
+                if self.auto_batch_size is not None:
+                    self.auto_batch_size.set_oom_retry_mode(False)
+            if not retry:
+                return to_numpy_array(descriptor)
 
     def eval_fitting_last_layer(
         self,
@@ -878,16 +892,27 @@ def eval_fitting_last_layer(
             Fitting output before last layer.
         """
         model = self.dp.model["Default"]
-        model.set_eval_fitting_last_layer_hook(True)
-        self.eval(
-            coords,
-            cells,
-            atom_types,
-            atomic=False,
-            fparam=fparam,
-            aparam=aparam,
-            **kwargs,
-        )
-        fitting_net = model.eval_fitting_last_layer()
-        model.set_eval_fitting_last_layer_hook(False)
-        return to_numpy_array(fitting_net)
+        while True:
+            if self.auto_batch_size is not None:
+                self.auto_batch_size.set_oom_retry_mode(True)
+            model.set_eval_fitting_last_layer_hook(True)
+            retry = False
+            try:
+                self.eval(
+                    coords,
+                    cells,
+                    atom_types,
+                    atomic=False,
+                    fparam=fparam,
+                    aparam=aparam,
+                    **kwargs,
+                )
+                fitting_net = model.eval_fitting_last_layer()
+            except RetrySignal:
+                retry = True
+            finally:
+                model.set_eval_fitting_last_layer_hook(False)
+                if self.auto_batch_size is not None:
+                    self.auto_batch_size.set_oom_retry_mode(False)
+            if not retry:
+                return to_numpy_array(fitting_net)
diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py
index 6e63ecb2fc..e48b26057e 100644
--- a/deepmd/pt/infer/deep_eval.py
+++ b/deepmd/pt/infer/deep_eval.py
@@ -67,6 +67,9 @@
     to_numpy_array,
     to_torch_tensor,
 )
+from deepmd.utils.batch_size import (
+    RetrySignal,
+)
 from deepmd.utils.econf_embd import (
     sort_element_type,
 )
@@ -793,19 +796,30 @@ def eval_descriptor(
             Descriptors.
         """
         model = self.dp.model["Default"]
-        model.set_eval_descriptor_hook(True)
-        self.eval(
-            coords,
-            cells,
-            atom_types,
-            atomic=False,
-            fparam=fparam,
-            aparam=aparam,
-            **kwargs,
-        )
-        descriptor = model.eval_descriptor()
-        model.set_eval_descriptor_hook(False)
-        return to_numpy_array(descriptor)
+        while True:
+            if self.auto_batch_size is not None:
+                self.auto_batch_size.set_oom_retry_mode(True)
+            model.set_eval_descriptor_hook(True)
+            retry = False
+            try:
+                self.eval(
+                    coords,
+                    cells,
+                    atom_types,
+                    atomic=False,
+                    fparam=fparam,
+                    aparam=aparam,
+                    **kwargs,
+                )
+                descriptor = model.eval_descriptor()
+            except RetrySignal:
+                retry = True
+            finally:
+                model.set_eval_descriptor_hook(False)
+                if self.auto_batch_size is not None:
+                    self.auto_batch_size.set_oom_retry_mode(False)
+            if not retry:
+                return to_numpy_array(descriptor)
 
     def eval_fitting_last_layer(
         self,
@@ -848,16 +862,27 @@ def eval_fitting_last_layer(
             Fitting output before last layer.
         """
         model = self.dp.model["Default"]
-        model.set_eval_fitting_last_layer_hook(True)
-        self.eval(
-            coords,
-            cells,
-            atom_types,
-            atomic=False,
-            fparam=fparam,
-            aparam=aparam,
-            **kwargs,
-        )
-        fitting_net = model.eval_fitting_last_layer()
-        model.set_eval_fitting_last_layer_hook(False)
-        return to_numpy_array(fitting_net)
+        while True:
+            if self.auto_batch_size is not None:
+                self.auto_batch_size.set_oom_retry_mode(True)
+            model.set_eval_fitting_last_layer_hook(True)
+            retry = False
+            try:
+                self.eval(
+                    coords,
+                    cells,
+                    atom_types,
+                    atomic=False,
+                    fparam=fparam,
+                    aparam=aparam,
+                    **kwargs,
+                )
+                fitting_net = model.eval_fitting_last_layer()
+            except RetrySignal:
+                retry = True
+            finally:
+                model.set_eval_fitting_last_layer_hook(False)
+                if self.auto_batch_size is not None:
+                    self.auto_batch_size.set_oom_retry_mode(False)
+            if not retry:
+                return to_numpy_array(fitting_net)
diff --git a/deepmd/utils/batch_size.py b/deepmd/utils/batch_size.py
index e701e82ec6..a38eed0224 100644
--- a/deepmd/utils/batch_size.py
+++ b/deepmd/utils/batch_size.py
@@ -22,6 +22,10 @@
 log = logging.getLogger(__name__)
 
 
+class RetrySignal(Exception):
+    """Signal to retry execution after OOM error."""
+
+
 class AutoBatchSize(ABC):
     """This class allows DeePMD-kit to automatically decide the maximum
     batch size that will not cause an OOM error.
@@ -75,6 +79,7 @@ def __init__(self, initial_batch_size: int = 1024, factor: float = 2.0) -> None:
             )
 
         self.factor = factor
+        self.oom_retry_mode = False
 
     def execute(
         self, callable: Callable, start_index: int, natoms: int
@@ -125,6 +130,8 @@ def execute(
                 ) from e
             # adjust the next batch size
             self._adjust_batch_size(1.0 / self.factor)
+            if self.oom_retry_mode:
+                raise RetrySignal from e
             return 0, None
         else:
             n_tot = n_batch * natoms
@@ -281,3 +288,20 @@ def is_oom_error(self, e: Exception) -> bool:
         bool
             True if the exception is an OOM error
         """
+
+    def set_oom_retry_mode(self, enable: bool) -> None:
+        """Set OOM retry mode.
+
+        In OOM retry mode, an OOM during execution may reduce the current
+        batch size and raise :class:`RetrySignal` to indicate that execution
+        should be retried.
+
+        Callers that want all data to be re-executed must catch
+        :class:`RetrySignal` and restart the full evaluation themselves.
+
+        Parameters
+        ----------
+        enable : bool
+            True to enable OOM retry mode
+        """
+        self.oom_retry_mode = enable
diff --git a/source/tests/common/test_oom_retry.py b/source/tests/common/test_oom_retry.py
new file mode 100644
index 0000000000..132d129f0b
--- /dev/null
+++ b/source/tests/common/test_oom_retry.py
@@ -0,0 +1,218 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import unittest
+from types import (
+    SimpleNamespace,
+)
+from typing import (
+    Any,
+)
+from unittest.mock import (
+    MagicMock,
+    call,
+    patch,
+)
+
+import numpy as np
+
+from deepmd.utils.batch_size import (
+    AutoBatchSize,
+    RetrySignal,
+)
+from deepmd.utils.errors import (
+    OutOfMemoryError,
+)
+
+
+class CustomizedAutoBatchSizeGPU(AutoBatchSize):
+    def is_gpu_available(self) -> bool:
+        return True
+
+    def is_oom_error(self, e):
+        return isinstance(e, OutOfMemoryError)
+
+
+class DummyAutoBatchSize:
+    def __init__(self) -> None:
+        self.oom_retry_mode = False
+        self.modes: list[bool] = []
+
+    def set_oom_retry_mode(self, enable: bool) -> None:
+        self.oom_retry_mode = enable
+        self.modes.append(enable)
+
+
+class TestOOMRetry(unittest.TestCase):
+    def test_execute_oom_retry_mode_raises_retry_signal(self) -> None:
+        auto_batch_size = CustomizedAutoBatchSizeGPU(256, 2.0)
+
+        oom = OutOfMemoryError("oom")
+
+        def executor(batch_size: int, start_index: int) -> tuple[int, None]:
+            raise oom
+
+        auto_batch_size.set_oom_retry_mode(True)
+        with self.assertRaises(RetrySignal) as context:
+            auto_batch_size.execute(executor, 0, 1)
+        self.assertIs(context.exception.__cause__, oom)
+        self.assertEqual(auto_batch_size.current_batch_size, 128)
+
+    def test_execute_oom_retry_mode_false_returns_zero(self) -> None:
+        auto_batch_size = CustomizedAutoBatchSizeGPU(256, 2.0)
+
+        def executor(batch_size: int, start_index: int) -> tuple[int, None]:
+            raise OutOfMemoryError("oom")
+
+        auto_batch_size.set_oom_retry_mode(False)
+        n_batch, result = auto_batch_size.execute(executor, 0, 1)
+        self.assertEqual(n_batch, 0)
+        self.assertIsNone(result)
+        self.assertEqual(auto_batch_size.current_batch_size, 128)
+
+    def _make_backend(self, backend: str, method_name: str) -> tuple[Any, MagicMock]:
+        try:
+            if backend == "pt":
+                from deepmd.pt.infer.deep_eval import (
+                    DeepEval,
+                )
+            else:
+                from deepmd.pd.infer.deep_eval import (
+                    DeepEval,
+                )
+        except ModuleNotFoundError as exc:
+            self.skipTest(f"{backend} backend dependencies are unavailable: {exc}")
+
+        abstract_methods = getattr(DeepEval, "__abstractmethods__", frozenset())
+        try:
+            DeepEval.__abstractmethods__ = frozenset()
+            deep_eval = object.__new__(DeepEval)
+        finally:
+            DeepEval.__abstractmethods__ = abstract_methods
+
+        model = MagicMock()
+        model.eval_descriptor.return_value = np.array([1.0, 2.0, 3.0])
+        model.eval_fitting_last_layer.return_value = np.array([4.0, 5.0, 6.0])
+
+        if backend == "pd" and method_name == "eval_descriptor":
+            # Paddle eval_descriptor accepts either a ModelWrapper or a direct model.
+            deep_eval.dp = model
+        else:
+            deep_eval.dp = SimpleNamespace(model={"Default": model})
+        deep_eval.auto_batch_size = DummyAutoBatchSize()
+        return deep_eval, model
+
+    def _assert_retry_clears_hook_between_attempts(
+        self,
+        backend: str,
+        method_name: str,
+        hook_name: str,
+        expected: np.ndarray,
+    ) -> None:
+        deep_eval, model = self._make_backend(backend, method_name)
+        with patch.object(
+            deep_eval, "eval", side_effect=[RetrySignal, None]
+        ) as eval_mock:
+            result = getattr(deep_eval, method_name)(
+                coords=np.zeros((3, 1, 3)),
+                cells=None,
+                atom_types=np.array([0]),
+            )
+        self.assertEqual(eval_mock.call_count, 2)
+        np.testing.assert_array_equal(result, expected)
+        self.assertEqual(
+            getattr(model, hook_name).call_args_list,
+            [call(True), call(False), call(True), call(False)],
+        )
+        self.assertFalse(deep_eval.auto_batch_size.oom_retry_mode)
+        self.assertEqual(deep_eval.auto_batch_size.modes, [True, False, True, False])
+
+    def _assert_runtime_error_clears_state(
+        self,
+        backend: str,
+        method_name: str,
+        hook_name: str,
+    ) -> None:
+        deep_eval, model = self._make_backend(backend, method_name)
+        with patch.object(
+            deep_eval,
+            "eval",
+            side_effect=RuntimeError("non-retry failure"),
+        ):
+            with self.assertRaisesRegex(RuntimeError, "non-retry failure"):
+                getattr(deep_eval, method_name)(
+                    coords=np.zeros((3, 1, 3)),
+                    cells=None,
+                    atom_types=np.array([0]),
+                )
+        self.assertEqual(
+            getattr(model, hook_name).call_args_list, [call(True), call(False)]
+        )
+        self.assertFalse(deep_eval.auto_batch_size.oom_retry_mode)
+        self.assertEqual(deep_eval.auto_batch_size.modes, [True, False])
+
+    def test_pt_eval_descriptor_retry_clears_hook_between_attempts(self) -> None:
+        self._assert_retry_clears_hook_between_attempts(
+            "pt",
+            "eval_descriptor",
+            "set_eval_descriptor_hook",
+            np.array([1.0, 2.0, 3.0]),
+        )
+
+    def test_pt_eval_fitting_last_layer_retry_clears_hook_between_attempts(
+        self,
+    ) -> None:
+        self._assert_retry_clears_hook_between_attempts(
+            "pt",
+            "eval_fitting_last_layer",
+            "set_eval_fitting_last_layer_hook",
+            np.array([4.0, 5.0, 6.0]),
+        )
+
+    def test_pd_eval_descriptor_retry_clears_hook_between_attempts(self) -> None:
+        self._assert_retry_clears_hook_between_attempts(
+            "pd",
+            "eval_descriptor",
+            "set_eval_descriptor_hook",
+            np.array([1.0, 2.0, 3.0]),
+        )
+
+    def test_pd_eval_fitting_last_layer_retry_clears_hook_between_attempts(
+        self,
+    ) -> None:
+        self._assert_retry_clears_hook_between_attempts(
+            "pd",
+            "eval_fitting_last_layer",
+            "set_eval_fitting_last_layer_hook",
+            np.array([4.0, 5.0, 6.0]),
+        )
+
+    def test_pt_eval_descriptor_runtime_error_clears_state(self) -> None:
+        self._assert_runtime_error_clears_state(
+            "pt",
+            "eval_descriptor",
+            "set_eval_descriptor_hook",
+        )
+
+    def test_pt_eval_fitting_last_layer_runtime_error_clears_state(self) -> None:
+        self._assert_runtime_error_clears_state(
+            "pt",
+            "eval_fitting_last_layer",
+            "set_eval_fitting_last_layer_hook",
+        )
+
+    def test_pd_eval_descriptor_runtime_error_clears_state(self) -> None:
+        self._assert_runtime_error_clears_state(
+            "pd",
+            "eval_descriptor",
+            "set_eval_descriptor_hook",
+        )
+
+    def test_pd_eval_fitting_last_layer_runtime_error_clears_state(self) -> None:
+        self._assert_runtime_error_clears_state(
+            "pd",
+            "eval_fitting_last_layer",
+            "set_eval_fitting_last_layer_hook",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()