Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions deepmd/pt_expt/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ def _exchange_ghosts(
value=0.0,
)

from deepmd.pt_expt.utils.comm import (
ensure_comm_registered,
)

ensure_comm_registered()
exchanged = torch.ops.deepmd_export.border_op(
comm_dict["send_list"],
comm_dict["send_proc"],
Expand Down
5 changes: 5 additions & 0 deletions deepmd/pt_expt/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def _exchange_ghosts(
value=0.0,
)

from deepmd.pt_expt.utils.comm import (
ensure_comm_registered,
)

ensure_comm_registered()
exchanged = torch.ops.deepmd_export.border_op(
comm_dict["send_list"],
comm_dict["send_proc"],
Expand Down
10 changes: 6 additions & 4 deletions deepmd/pt_expt/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
# as it's a stateless utility class
register_dpmodel_mapping(EnvMat, lambda v: v)

# Register opaque deepmd_export::border_op wrapper (used by GNN MPI
# parallel inference; see comm.py module docstring).
# Register fake tensor implementations for custom tabulate ops
from deepmd.pt_expt.utils import comm # noqa: F401
# Register fake tensor implementations for custom tabulate ops.
# comm.py (border_op fake/autograd) is NOT imported here — its
# ensure_comm_registered() is called lazily by comm_dict entry points
# (the with_comm_dict export path in serialization.py and pt_expt
# descriptor comm_dict exchange paths) to avoid eager libdeepmd_op_pt.so
# loading that breaks fake-op registration order in tests.
from deepmd.pt_expt.utils import tabulate_ops # noqa: F401

__all__ = [
Expand Down
56 changes: 46 additions & 10 deletions deepmd/pt_expt/utils/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,13 @@
annotations,
)

import threading

import torch

_registered: bool = False
_register_lock = threading.Lock()


def _check_underlying_ops_loaded() -> None:
"""Surface a clearer error when libdeepmd_op_pt.so isn't loaded.
Expand Down Expand Up @@ -76,15 +81,11 @@ def _check_underlying_ops_loaded() -> None:
)


_check_underlying_ops_loaded()


# ---------------------------------------------------------------------------
# Fake (meta) impls — let make_fx / torch.export trace through.
# ---------------------------------------------------------------------------


@torch.library.register_fake("deepmd_export::border_op")
def _border_op_fake(
sendlist: torch.Tensor,
sendproc: torch.Tensor,
Expand All @@ -99,7 +100,6 @@ def _border_op_fake(
return torch.empty_like(g1)


@torch.library.register_fake("deepmd_export::border_op_backward")
def _border_op_backward_fake(
sendlist: torch.Tensor,
sendproc: torch.Tensor,
Expand Down Expand Up @@ -180,8 +180,44 @@ def _border_op_backward(
)


torch.library.register_autograd(
"deepmd_export::border_op",
_border_op_backward,
setup_context=_border_op_setup_context,
)
def _register_ignoring_duplicates(register) -> None:
    """Run one ``torch.library`` registration, tolerating duplicates.

    Parameters
    ----------
    register
        Zero-argument callable that performs a single registration.

    Raises
    ------
    RuntimeError
        Re-raised unchanged when its message contains neither
        ``"already has"`` nor ``"already registered"``.
    """
    try:
        register()
    except RuntimeError as err:
        message = str(err)
        # Only a duplicate registration is benign; anything else is a real
        # failure and must surface to the caller.
        if "already has" not in message and "already registered" not in message:
            raise


def ensure_comm_registered() -> None:
    """Load libdeepmd_op_pt.so and register fake/autograd metadata for border_op.

    Idempotent — safe to call multiple times. Must be called before any
    ``make_fx`` / ``torch.export`` trace that passes through border_op (i.e.
    before the ``with_comm_dict=True`` export path in serialization.py).

    Kept lazy (not called at import time) so that merely importing
    ``deepmd.pt_expt.utils`` does not force-load libdeepmd_op_pt.so and
    disrupt fake-op registration order in tests that don't exercise the comm
    path at all.

    Raises
    ------
    RuntimeError
        Propagated from ``torch.library`` when a registration fails for a
        reason other than the op already being registered.
    """
    global _registered
    # Cheap unlocked check so repeat callers skip the lock entirely.
    if _registered:
        return
    with _register_lock:
        # Re-check under the lock: another thread may have finished the
        # registration while this one was waiting.
        if _registered:
            return
        _check_underlying_ops_loaded()
        # ``torch.library.*`` is resolved inside each lambda (at call time),
        # so the registrations run in this fixed order: forward fake,
        # backward fake, then autograd.
        _register_ignoring_duplicates(
            lambda: torch.library.register_fake("deepmd_export::border_op")(
                _border_op_fake
            )
        )
        _register_ignoring_duplicates(
            lambda: torch.library.register_fake("deepmd_export::border_op_backward")(
                _border_op_backward_fake
            )
        )
        _register_ignoring_duplicates(
            lambda: torch.library.register_autograd(
                "deepmd_export::border_op",
                _border_op_backward,
                setup_context=_border_op_setup_context,
            )
        )
        # Set the flag last, so a failed attempt is retried on the next call.
        _registered = True
7 changes: 7 additions & 0 deletions deepmd/pt_expt/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,13 @@ def _trace_and_export(
"(has_message_passing_across_ranks() is False) — "
"there's nothing to compile."
)
# Load libdeepmd_op_pt.so and register border_op fake/autograd
# metadata only for models that actually need the comm path.
from deepmd.pt_expt.utils.comm import (
ensure_comm_registered,
)

ensure_comm_registered()
nloc_sample = nlist_t.shape[1]
nall_sample = ext_atype.shape[1]
nghost_sample = nall_sample - nloc_sample
Expand Down
7 changes: 4 additions & 3 deletions source/tests/pt_expt/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
_get_current_function_mode_stack,
)

# ``deepmd.pt_expt.utils.comm`` self-bootstraps libdeepmd_op_pt.so via
# ``_check_underlying_ops_loaded()``, so we no longer need to preload
# ``deepmd.pt`` here.
# ``deepmd.pt_expt.utils.comm`` is now lazy: libdeepmd_op_pt.so is only
# loaded when ``ensure_comm_registered()`` is explicitly called from a
# comm_dict code path. Tests that don't exercise comm_dict paths never
# load the op library, preserving fake-op registration order.


def _pop_device_contexts() -> list:
Expand Down
7 changes: 5 additions & 2 deletions source/tests/pt_expt/descriptor/test_repflow_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
import pytest
import torch

# Trigger registration of the deepmd_export::border_op opaque wrapper.
import deepmd.pt_expt.utils.comm # noqa: F401 # lgtm[py/unused-import]
from deepmd.dpmodel.descriptor.dpa3 import (
RepFlowArgs,
)
Expand All @@ -42,6 +40,9 @@
from deepmd.pt_expt.utils import (
env,
)
from deepmd.pt_expt.utils.comm import (
ensure_comm_registered,
)
from deepmd.pt_expt.utils.env import (
PRECISION_DICT,
)
Expand All @@ -54,6 +55,8 @@
GLOBAL_SEED,
)

ensure_comm_registered()

# ---------------------------------------------------------------------------
# Helpers for building the comm_dict tensors

Expand Down
7 changes: 5 additions & 2 deletions source/tests/pt_expt/descriptor/test_repformer_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import pytest
import torch

# Trigger registration of the deepmd_export::border_op opaque wrapper.
import deepmd.pt_expt.utils.comm # noqa: F401 # lgtm[py/unused-import]
from deepmd.dpmodel.descriptor.dpa2 import (
RepformerArgs,
RepinitArgs,
Expand All @@ -29,6 +27,9 @@
from deepmd.pt_expt.utils import (
env,
)
from deepmd.pt_expt.utils.comm import (
ensure_comm_registered,
)
from deepmd.pt_expt.utils.env import (
PRECISION_DICT,
)
Expand All @@ -41,6 +42,8 @@
GLOBAL_SEED,
)

ensure_comm_registered()


def _addr_of(np_arr: np.ndarray) -> int:
return np_arr.ctypes.data_as(ctypes.c_void_p).value
Expand Down
10 changes: 7 additions & 3 deletions source/tests/pt_expt/model/test_export_with_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,21 @@
import pytest
import torch

# Trigger registration of the deepmd_export::border_op opaque wrapper
# (needed by the with-comm artifact at runtime).
import deepmd.pt_expt.utils.comm # noqa: F401 # lgtm[py/unused-import]
from deepmd.pt_expt.model.get_model import (
get_model,
)

# Register deepmd_export::border_op fake/autograd metadata explicitly.
from deepmd.pt_expt.utils.comm import (
ensure_comm_registered,
)
from deepmd.pt_expt.utils.serialization import (
_make_sample_inputs,
deserialize_to_file,
)

ensure_comm_registered()

_DPA3_CONFIG = {
"type_map": ["O", "H"],
"descriptor": {
Expand Down
6 changes: 5 additions & 1 deletion source/tests/pt_expt/model/test_spin_export_with_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,15 @@
import numpy as np
import torch

import deepmd.pt_expt.utils.comm # noqa: F401 # lgtm[py/unused-import] - opaque op registration
from deepmd.dpmodel.model.model import get_model as get_model_dp
from deepmd.pt_expt.model.spin_ener_model import (
SpinEnergyModel,
)
from deepmd.pt_expt.utils.comm import (
ensure_comm_registered,
)

ensure_comm_registered()

SPIN_GNN_DATA = {
"type_map": ["O", "H", "B"],
Expand Down
9 changes: 5 additions & 4 deletions source/tests/pt_expt/utils/test_border_op_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@
import pytest
import torch

# comm self-bootstraps the underlying libdeepmd_op_pt.so when needed, so
# this single side-effect import is enough to register both the C++
# ops (deepmd::border_op_backward) and their fake/autograd metadata.
import deepmd.pt_expt.utils.comm # noqa: F401 # lgtm[py/unused-import] - registers deepmd_export::border_op
from deepmd.pt_expt.utils.comm import (
ensure_comm_registered,
)

ensure_comm_registered()


def _addr_of(np_arr: np.ndarray) -> int:
Expand Down
70 changes: 70 additions & 0 deletions source/tests/pt_expt/utils/test_comm_registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from __future__ import (
annotations,
)

import torch

from deepmd.pt_expt.utils import (
comm,
)


def test_ensure_comm_registered_idempotent(monkeypatch) -> None:
    """Calling ``ensure_comm_registered`` twice performs each step only once."""
    # One entry is appended per invocation of the corresponding stub.
    check_calls: list[None] = []
    fake_names: list[str] = []
    autograd_calls: list[None] = []

    def _record_check() -> None:
        check_calls.append(None)

    def _record_register_fake(name: str):
        fake_names.append(name)

        def _passthrough(fn):
            return fn

        return _passthrough

    def _record_register_autograd(*args, **kwargs) -> None:
        autograd_calls.append(None)

    # Start from the "not yet registered" state and stub out every side effect.
    monkeypatch.setattr(comm, "_registered", False)
    monkeypatch.setattr(comm, "_check_underlying_ops_loaded", _record_check)
    monkeypatch.setattr(torch.library, "register_fake", _record_register_fake)
    monkeypatch.setattr(torch.library, "register_autograd", _record_register_autograd)

    for _ in range(2):
        comm.ensure_comm_registered()

    assert len(check_calls) == 1
    assert fake_names == [
        "deepmd_export::border_op",
        "deepmd_export::border_op_backward",
    ]
    assert len(autograd_calls) == 1


def test_ensure_comm_registered_tolerates_duplicate_autograd(monkeypatch) -> None:
    """Duplicate-registration errors from torch.library do not abort registration."""

    def _raise_when_applied(fn):
        raise RuntimeError("already registered")

    def _duplicate_register_fake(name: str):
        # The error fires when the returned decorator is applied, not here.
        return _raise_when_applied

    def _duplicate_register_autograd(*args, **kwargs) -> None:
        raise RuntimeError("already has an autograd implementation")

    monkeypatch.setattr(comm, "_registered", False)
    monkeypatch.setattr(comm, "_check_underlying_ops_loaded", lambda: None)
    monkeypatch.setattr(torch.library, "register_fake", _duplicate_register_fake)
    monkeypatch.setattr(
        torch.library, "register_autograd", _duplicate_register_autograd
    )

    comm.ensure_comm_registered()

    assert comm._registered is True
51 changes: 51 additions & 0 deletions source/tests/pt_expt/utils/test_serialization_with_comm_guard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from __future__ import (
annotations,
)

import pytest
import torch

from deepmd.pt_expt.utils import (
serialization,
)


class _DummyModel:
def to(self, device):
return self

def eval(self):
return self


def test_trace_with_comm_rejects_non_comm_model_before_registration(
    monkeypatch,
) -> None:
    """``_trace_and_export(with_comm_dict=True)`` raises for a non-comm model."""
    import deepmd.pt_expt.model.model as model_module

    def _fake_deserialize(cls, data):
        return _DummyModel()

    def _fake_collect_metadata(model, is_spin):
        return {}

    def _fake_needs_with_comm_artifact(model):
        return False

    def _fake_make_sample_inputs(model, nframes, has_spin):
        # Six-element tuple of stub sample inputs; the last two are None.
        return (
            torch.zeros((1, 4, 3), dtype=torch.float64),
            torch.zeros((1, 4), dtype=torch.int32),
            torch.zeros((1, 2, 4), dtype=torch.int64),
            torch.zeros((1, 4), dtype=torch.int64),
            None,
            None,
        )

    monkeypatch.setattr(
        model_module.BaseModel,
        "deserialize",
        classmethod(_fake_deserialize),
    )
    monkeypatch.setattr(serialization, "_collect_metadata", _fake_collect_metadata)
    monkeypatch.setattr(
        serialization, "_needs_with_comm_artifact", _fake_needs_with_comm_artifact
    )
    monkeypatch.setattr(serialization, "_make_sample_inputs", _fake_make_sample_inputs)

    with pytest.raises(ValueError, match="nothing to compile"):
        serialization._trace_and_export(
            {"model": {"type": "ener"}},
            with_comm_dict=True,
        )