diff --git a/deepmd/pt_expt/descriptor/repflows.py b/deepmd/pt_expt/descriptor/repflows.py index dacab9f464..c850efc2b3 100644 --- a/deepmd/pt_expt/descriptor/repflows.py +++ b/deepmd/pt_expt/descriptor/repflows.py @@ -99,6 +99,11 @@ def _exchange_ghosts( value=0.0, ) + from deepmd.pt_expt.utils.comm import ( + ensure_comm_registered, + ) + + ensure_comm_registered() exchanged = torch.ops.deepmd_export.border_op( comm_dict["send_list"], comm_dict["send_proc"], diff --git a/deepmd/pt_expt/descriptor/repformers.py b/deepmd/pt_expt/descriptor/repformers.py index 9b8ddb4a85..f289d55d2d 100644 --- a/deepmd/pt_expt/descriptor/repformers.py +++ b/deepmd/pt_expt/descriptor/repformers.py @@ -74,6 +74,11 @@ def _exchange_ghosts( value=0.0, ) + from deepmd.pt_expt.utils.comm import ( + ensure_comm_registered, + ) + + ensure_comm_registered() exchanged = torch.ops.deepmd_export.border_op( comm_dict["send_list"], comm_dict["send_proc"], diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py index 99da68fe4f..cc627d13ec 100644 --- a/deepmd/pt_expt/utils/__init__.py +++ b/deepmd/pt_expt/utils/__init__.py @@ -22,10 +22,12 @@ # as it's a stateless utility class register_dpmodel_mapping(EnvMat, lambda v: v) -# Register opaque deepmd_export::border_op wrapper (used by GNN MPI -# parallel inference; see comm.py module docstring). -# Register fake tensor implementations for custom tabulate ops -from deepmd.pt_expt.utils import comm # noqa: F401 +# Register fake tensor implementations for custom tabulate ops. +# comm.py (border_op fake/autograd) is NOT imported here — its +# ensure_comm_registered() is called lazily by comm_dict entry points +# (the with_comm_dict export path in serialization.py and pt_expt +# descriptor comm_dict exchange paths) to avoid eager libdeepmd_op_pt.so +# loading that breaks fake-op registration order in tests. from deepmd.pt_expt.utils import tabulate_ops # noqa: F401 __all__ = [ diff --git a/deepmd/pt_expt/utils/comm.py b/deepmd/pt_expt/utils/comm.py index 434d2a97b0..4eef330230 100644 --- a/deepmd/pt_expt/utils/comm.py +++ b/deepmd/pt_expt/utils/comm.py @@ -33,8 +33,13 @@ annotations, ) +import threading + import torch +_registered: bool = False +_register_lock = threading.Lock() + def _check_underlying_ops_loaded() -> None: """Surface a clearer error when libdeepmd_op_pt.so isn't loaded. @@ -76,15 +81,11 @@ def _check_underlying_ops_loaded() -> None: ) -_check_underlying_ops_loaded() - - # --------------------------------------------------------------------------- # Fake (meta) impls — let make_fx / torch.export trace through. # --------------------------------------------------------------------------- -@torch.library.register_fake("deepmd_export::border_op") def _border_op_fake( sendlist: torch.Tensor, sendproc: torch.Tensor, @@ -99,7 +100,6 @@ def _border_op_fake( return torch.empty_like(g1) -@torch.library.register_fake("deepmd_export::border_op_backward") def _border_op_backward_fake( sendlist: torch.Tensor, sendproc: torch.Tensor, @@ -180,8 +180,44 @@ def _border_op_backward( ) -torch.library.register_autograd( - "deepmd_export::border_op", - _border_op_backward, - setup_context=_border_op_setup_context, -) +def ensure_comm_registered() -> None: + """Load libdeepmd_op_pt.so and register fake/autograd metadata for border_op. + + Idempotent — safe to call multiple times. Must be called before any + ``make_fx`` / ``torch.export`` trace that passes through border_op (i.e. + before the ``with_comm_dict=True`` export path in serialization.py). + + Kept lazy (not called at import time) so that merely importing + ``deepmd.pt_expt.utils`` does not force-load libdeepmd_op_pt.so and + disrupt fake-op registration order in tests that don't exercise the comm + path at all. + """ + global _registered + if _registered: + return + with _register_lock: + if _registered: + return + _check_underlying_ops_loaded() + try: + torch.library.register_fake("deepmd_export::border_op")(_border_op_fake) + except RuntimeError as e: + if "already has" not in str(e) and "already registered" not in str(e): + raise + try: + torch.library.register_fake("deepmd_export::border_op_backward")( + _border_op_backward_fake + ) + except RuntimeError as e: + if "already has" not in str(e) and "already registered" not in str(e): + raise + try: + torch.library.register_autograd( + "deepmd_export::border_op", + _border_op_backward, + setup_context=_border_op_setup_context, + ) + except RuntimeError as e: + if "already has" not in str(e) and "already registered" not in str(e): + raise + _registered = True diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index d85a334493..fe2a4e29cd 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -677,6 +677,13 @@ def _trace_and_export( "(has_message_passing_across_ranks() is False) — " "there's nothing to compile." ) + # Load libdeepmd_op_pt.so and register border_op fake/autograd + # metadata only for models that actually need the comm path. + from deepmd.pt_expt.utils.comm import ( + ensure_comm_registered, + ) + + ensure_comm_registered() nloc_sample = nlist_t.shape[1] nall_sample = ext_atype.shape[1] nghost_sample = nall_sample - nloc_sample diff --git a/source/tests/pt_expt/conftest.py b/source/tests/pt_expt/conftest.py index d4d987fe95..18453d418a 100644 --- a/source/tests/pt_expt/conftest.py +++ b/source/tests/pt_expt/conftest.py @@ -17,9 +17,10 @@ _get_current_function_mode_stack, ) -# ``deepmd.pt_expt.utils.comm`` self-bootstraps libdeepmd_op_pt.so via -# ``_check_underlying_ops_loaded()``, so we no longer need to preload -# ``deepmd.pt`` here. +# ``deepmd.pt_expt.utils.comm`` is now lazy: libdeepmd_op_pt.so is only +# loaded when ``ensure_comm_registered()`` is explicitly called from a +# comm_dict code path. Tests that don't exercise comm_dict paths never +# load the op library, preserving fake-op registration order. def _pop_device_contexts() -> list: diff --git a/source/tests/pt_expt/descriptor/test_repflow_parallel.py b/source/tests/pt_expt/descriptor/test_repflow_parallel.py index f5b4d40bcd..566661e95b 100644 --- a/source/tests/pt_expt/descriptor/test_repflow_parallel.py +++ b/source/tests/pt_expt/descriptor/test_repflow_parallel.py @@ -31,8 +31,6 @@ import pytest import torch -# Trigger registration of the deepmd_export::border_op opaque wrapper. -import deepmd.pt_expt.utils.comm # noqa: F401 # lgtm[py/unused-import] from deepmd.dpmodel.descriptor.dpa3 import ( RepFlowArgs, ) @@ -42,6 +40,9 @@ from deepmd.pt_expt.utils import ( env, ) +from deepmd.pt_expt.utils.comm import ( + ensure_comm_registered, +) from deepmd.pt_expt.utils.env import ( PRECISION_DICT, ) @@ -54,6 +55,8 @@ GLOBAL_SEED, ) +ensure_comm_registered() + # --------------------------------------------------------------------------- # Helpers for building the comm_dict tensors diff --git a/source/tests/pt_expt/descriptor/test_repformer_parallel.py b/source/tests/pt_expt/descriptor/test_repformer_parallel.py index 1a6413d08f..ed21592747 100644 --- a/source/tests/pt_expt/descriptor/test_repformer_parallel.py +++ b/source/tests/pt_expt/descriptor/test_repformer_parallel.py @@ -17,8 +17,6 @@ import pytest import torch -# Trigger registration of the deepmd_export::border_op opaque wrapper. -import deepmd.pt_expt.utils.comm # noqa: F401 # lgtm[py/unused-import] from deepmd.dpmodel.descriptor.dpa2 import ( RepformerArgs, RepinitArgs, @@ -29,6 +27,9 @@ from deepmd.pt_expt.utils import ( env, ) +from deepmd.pt_expt.utils.comm import ( + ensure_comm_registered, +) from deepmd.pt_expt.utils.env import ( PRECISION_DICT, ) @@ -41,6 +42,8 @@ GLOBAL_SEED, ) +ensure_comm_registered() + def _addr_of(np_arr: np.ndarray) -> int: return np_arr.ctypes.data_as(ctypes.c_void_p).value diff --git a/source/tests/pt_expt/model/test_export_with_comm.py b/source/tests/pt_expt/model/test_export_with_comm.py index dcbc628e53..9ec337b842 100644 --- a/source/tests/pt_expt/model/test_export_with_comm.py +++ b/source/tests/pt_expt/model/test_export_with_comm.py @@ -33,17 +33,21 @@ import pytest import torch -# Trigger registration of the deepmd_export::border_op opaque wrapper -# (needed by the with-comm artifact at runtime). -import deepmd.pt_expt.utils.comm # noqa: F401 # lgtm[py/unused-import] from deepmd.pt_expt.model.get_model import ( get_model, ) + +# Register deepmd_export::border_op fake/autograd metadata explicitly. +from deepmd.pt_expt.utils.comm import ( + ensure_comm_registered, +) from deepmd.pt_expt.utils.serialization import ( _make_sample_inputs, deserialize_to_file, ) +ensure_comm_registered() + _DPA3_CONFIG = { "type_map": ["O", "H"], "descriptor": { diff --git a/source/tests/pt_expt/model/test_spin_export_with_comm.py b/source/tests/pt_expt/model/test_spin_export_with_comm.py index 0e403d2b42..74ea7a498f 100644 --- a/source/tests/pt_expt/model/test_spin_export_with_comm.py +++ b/source/tests/pt_expt/model/test_spin_export_with_comm.py @@ -30,11 +30,15 @@ import numpy as np import torch -import deepmd.pt_expt.utils.comm # noqa: F401 # lgtm[py/unused-import] - opaque op registration from deepmd.dpmodel.model.model import get_model as get_model_dp from deepmd.pt_expt.model.spin_ener_model import ( SpinEnergyModel, ) +from deepmd.pt_expt.utils.comm import ( + ensure_comm_registered, +) + +ensure_comm_registered() SPIN_GNN_DATA = { "type_map": ["O", "H", "B"], diff --git a/source/tests/pt_expt/utils/test_border_op_backward.py b/source/tests/pt_expt/utils/test_border_op_backward.py index b33e575f1a..80e1edc99f 100644 --- a/source/tests/pt_expt/utils/test_border_op_backward.py +++ b/source/tests/pt_expt/utils/test_border_op_backward.py @@ -30,10 +30,11 @@ import pytest import torch -# comm self-bootstraps the underlying libdeepmd_op_pt.so when needed, so -# this single side-effect import is enough to register both the C++ -# ops (deepmd::border_op_backward) and their fake/autograd metadata. -import deepmd.pt_expt.utils.comm # noqa: F401 # lgtm[py/unused-import] - registers deepmd_export::border_op +from deepmd.pt_expt.utils.comm import ( + ensure_comm_registered, +) + +ensure_comm_registered() def _addr_of(np_arr: np.ndarray) -> int: diff --git a/source/tests/pt_expt/utils/test_comm_registration.py b/source/tests/pt_expt/utils/test_comm_registration.py new file mode 100644 index 0000000000..2188a07c9e --- /dev/null +++ b/source/tests/pt_expt/utils/test_comm_registration.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import ( + annotations, +) + +import torch + +from deepmd.pt_expt.utils import ( + comm, +) + + +def test_ensure_comm_registered_idempotent(monkeypatch) -> None: + monkeypatch.setattr(comm, "_registered", False) + calls: dict[str, object] = { + "check": 0, + "fake_names": [], + "autograd": 0, + } + + def _check() -> None: + calls["check"] = int(calls["check"]) + 1 + + def _register_fake(name: str): + fake_names = calls["fake_names"] + assert isinstance(fake_names, list) + fake_names.append(name) + + def _decorator(fn): + return fn + + return _decorator + + def _register_autograd(*args, **kwargs) -> None: + calls["autograd"] = int(calls["autograd"]) + 1 + + monkeypatch.setattr(comm, "_check_underlying_ops_loaded", _check) + monkeypatch.setattr(torch.library, "register_fake", _register_fake) + monkeypatch.setattr(torch.library, "register_autograd", _register_autograd) + + comm.ensure_comm_registered() + comm.ensure_comm_registered() + + assert calls["check"] == 1 + assert calls["fake_names"] == [ + "deepmd_export::border_op", + "deepmd_export::border_op_backward", + ] + assert calls["autograd"] == 1 + + +def test_ensure_comm_registered_tolerates_duplicate_autograd(monkeypatch) -> None: + monkeypatch.setattr(comm, "_registered", False) + monkeypatch.setattr(comm, "_check_underlying_ops_loaded", lambda: None) + + def _register_fake(name: str): + def _decorator(fn): + raise RuntimeError("already registered") + + return _decorator + + def _register_autograd(*args, **kwargs) -> None: + raise RuntimeError("already has an autograd implementation") + + monkeypatch.setattr(torch.library, "register_fake", _register_fake) + monkeypatch.setattr(torch.library, "register_autograd", _register_autograd) + + comm.ensure_comm_registered() + + assert comm._registered is True diff --git a/source/tests/pt_expt/utils/test_serialization_with_comm_guard.py b/source/tests/pt_expt/utils/test_serialization_with_comm_guard.py new file mode 100644 index 0000000000..9d58451fa6 --- /dev/null +++ b/source/tests/pt_expt/utils/test_serialization_with_comm_guard.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import ( + annotations, +) + +import pytest +import torch + +from deepmd.pt_expt.utils import ( + serialization, +) + + +class _DummyModel: + def to(self, device): + return self + + def eval(self): + return self + + +def test_trace_with_comm_rejects_non_comm_model_before_registration( + monkeypatch, +) -> None: + import deepmd.pt_expt.model.model as model_module + + monkeypatch.setattr( + model_module.BaseModel, + "deserialize", + classmethod(lambda cls, data: _DummyModel()), + ) + monkeypatch.setattr(serialization, "_collect_metadata", lambda model, is_spin: {}) + monkeypatch.setattr(serialization, "_needs_with_comm_artifact", lambda model: False) + monkeypatch.setattr( + serialization, + "_make_sample_inputs", + lambda model, nframes, has_spin: ( + torch.zeros((1, 4, 3), dtype=torch.float64), + torch.zeros((1, 4), dtype=torch.int32), + torch.zeros((1, 2, 4), dtype=torch.int64), + torch.zeros((1, 4), dtype=torch.int64), + None, + None, + ), + ) + + with pytest.raises(ValueError, match="nothing to compile"): + serialization._trace_and_export( + {"model": {"type": "ener"}}, + with_comm_dict=True, + )