Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 100 additions & 19 deletions deepmd/pt_expt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,23 +166,30 @@ def get_additional_data_requirement(_model: Any) -> list[DataRequirementItem]:


def _remove_detach_nodes(gm: torch.fx.GraphModule) -> None:
"""Remove ``aten.detach.default`` nodes from an FX graph in-place.

``make_fx`` inserts these nodes when recording saved tensors from the
autograd backward pass (``autograd.grad`` with ``create_graph=True``).
The detach breaks the gradient connection between saved activations and
model parameters, causing incorrect second-order derivatives — e.g.
bias gradients become zero for force-loss training.

Removing these nodes restores the gradient path so that higher-order
derivatives flow correctly through the decomposed backward ops.
"""Replace ``aten.detach.default`` nodes with ``aten.clone.default``.

``make_fx`` inserts detach nodes for saved tensors in the decomposed
autograd backward. The detach breaks the gradient path from saved
activations back to model parameters, causing incorrect second-order
derivatives (e.g. bias gradients become zero for force-loss training).

We replace detach with clone rather than erasing the node entirely.
Erasing makes the output alias the input — AOT autograd detects the
alias and stores SymInt shape values as raw pointers in a C++
``view_meta_sequence``. When Python GC later collects those SymInt
objects the pointers dangle, producing a crash of the form
``shape '[139...008, ...]' is invalid for input of size N``.
Clone breaks the alias so no ``view_meta_sequence`` is generated.
"""
graph = gm.graph
for node in list(graph.nodes):
if node.op == "call_function" and node.target == torch.ops.aten.detach.default:
input_node = node.args[0]
node.replace_all_uses_with(input_node)
graph.erase_node(node)
# Replace detach with clone to break the input-output alias.
# Alias-free outputs mean AOT autograd never writes SymInt raw
# pointers into C++ view_meta_sequence, so GC of SymInt objects
# cannot produce dangling pointers and apply_view_meta_sequence
# crashes (shape '[139...008, ...]' is invalid for input ...).
node.target = torch.ops.aten.clone.default
graph.lint()
gm.recompile()

Expand Down Expand Up @@ -311,6 +318,15 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None:
decomposition_table=decomp_table,
)(ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin)

# make_fx has captured the graph; input tensors are no longer needed.
del ext_coord, ext_atype, nlist, mapping
if fparam is not None:
del fparam
if aparam is not None:
del aparam
if charge_spin is not None:
del charge_spin

# make_fx inserts aten.detach.default for saved tensors used in the
# decomposed autograd.grad backward ops. These detach nodes break
# second-order gradient flow (d(force)/d(params) for force training).
Expand Down Expand Up @@ -1019,13 +1035,52 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None:
compile_opts=compile_opts,
)

# torch.compile is lazy: inductor only compiles on the first
# call. In DDP multi-task training, different ranks may first
# hit a task at different training steps, so one rank can block
# inside inductor for minutes while others spin in AllReduce —
# causing an NCCL timeout. Warmup here, while sample inputs
# still exist, forces eager compilation before training starts.
#
# Match _CompiledModel.forward which sets requires_grad_(True) on
# ext_coord: Dynamo's guard includes requires_grad, so a mismatch
# causes every task's first training call to miss the warmup cache.
ext_coord = ext_coord.detach().requires_grad_(True)
_warmup_out = compiled_lower(
ext_coord, ext_atype, nlist_t, mapping, fparam, aparam, charge_spin
)
del _warmup_out
if DEVICE.type == "cuda":
torch.cuda.synchronize()

wrapper_mod.model[task_key] = _CompiledModel(model, compiled_lower)

# Release all intermediate tensors built for this task so they don't
# accumulate across tasks in multi-task scenarios.
del ext_coord, ext_atype, mapping, nlist_t
del coord, atype, coord_3d, coord_norm
if box is not None:
del box, box_flat
if fparam is not None:
del fparam
if aparam is not None:
del aparam
if charge_spin is not None:
del charge_spin
del inp, _

log.info(
"Model compiled (task=%s, tracing_mode=symbolic, "
"dynamic=True, backend=inductor).",
task_key,
)

# All tasks compiled on this rank — wait for all ranks before
# training starts so no rank enters the training loop while another
# is still blocked in inductor compilation.
if self.is_distributed:
dist.barrier()

# ------------------------------------------------------------------
# Data helpers
# ------------------------------------------------------------------
Expand Down Expand Up @@ -1212,7 +1267,9 @@ def run(self) -> None:
if self.rank == 0:
if not self.multi_task:
train_results = {
k: v for k, v in more_loss.items() if "l2_" not in k
k: (v.item() if isinstance(v, torch.Tensor) else v)
for k, v in more_loss.items()
if "l2_" not in k
}

# validation
Expand All @@ -1233,7 +1290,13 @@ def run(self) -> None:
for k, v in _vmore.items():
if "l2_" not in k:
valid_results[k] = (
valid_results.get(k, 0.0) + v * natoms
valid_results.get(k, 0.0)
+ (
v.item()
if isinstance(v, torch.Tensor)
else v
)
* natoms
)
if sum_natoms > 0:
valid_results = {
Expand All @@ -1246,23 +1309,33 @@ def run(self) -> None:

# current task already has loss
train_results[task_key] = {
k: v for k, v in more_loss.items() if "l2_" not in k
k: (v.item() if isinstance(v, torch.Tensor) else v)
for k, v in more_loss.items()
if "l2_" not in k
}

# compute loss for other tasks
for _key in self.model_keys:
if _key != task_key:
self.optimizer.zero_grad()
self.optimizer.zero_grad(set_to_none=True)
_inp, _lab = self.get_data(is_train=True, task_key=_key)
_, _loss, _more = self._unwrapped(
**_inp,
cur_lr=cur_lr_sched,
label=_lab,
task_key=_key,
)
# Use .item() so the backward graph (and its
# saved activations) can be freed immediately.
# Display passes never call loss.backward(), so
# without this the computation graphs for all
# tasks accumulate simultaneously in GPU memory.
train_results[_key] = {
k: v for k, v in _more.items() if "l2_" not in k
k: (v.item() if isinstance(v, torch.Tensor) else v)
for k, v in _more.items()
if "l2_" not in k
}
del _loss, _more, _inp, _lab

# validation for each task
_vdata = self.validation_data[_key]
Expand All @@ -1285,7 +1358,15 @@ def run(self) -> None:
_sum_natoms += natoms
for k, v in _vmore.items():
if "l2_" not in k:
_vres[k] = _vres.get(k, 0.0) + v * natoms
_vres[k] = (
_vres.get(k, 0.0)
+ (
v.item()
if isinstance(v, torch.Tensor)
else v
)
* natoms
)
if _sum_natoms > 0:
_vres = {
k: v / _sum_natoms for k, v in _vres.items()
Expand Down
Loading