diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 1059af0be6..2d5d47ea78 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -166,23 +166,30 @@ def get_additional_data_requirement(_model: Any) -> list[DataRequirementItem]: def _remove_detach_nodes(gm: torch.fx.GraphModule) -> None: - """Remove ``aten.detach.default`` nodes from an FX graph in-place. - - ``make_fx`` inserts these nodes when recording saved tensors from the - autograd backward pass (``autograd.grad`` with ``create_graph=True``). - The detach breaks the gradient connection between saved activations and - model parameters, causing incorrect second-order derivatives — e.g. - bias gradients become zero for force-loss training. - - Removing these nodes restores the gradient path so that higher-order - derivatives flow correctly through the decomposed backward ops. + """Replace ``aten.detach.default`` nodes with ``aten.clone.default``. + + ``make_fx`` inserts detach nodes for saved tensors in the decomposed + autograd backward. The detach breaks the gradient path from saved + activations back to model parameters, causing incorrect second-order + derivatives (e.g. bias gradients become zero for force-loss training). + + We replace detach with clone rather than erasing the node entirely. + Erasing makes the output alias the input — AOT autograd detects the + alias and stores SymInt shape values as raw pointers in a C++ + ``view_meta_sequence``. When Python GC later collects those SymInt + objects the pointers dangle, producing a crash of the form + ``shape '[139...008, ...]' is invalid for input of size N``. + Clone breaks the alias so no ``view_meta_sequence`` is generated. """ graph = gm.graph for node in list(graph.nodes): if node.op == "call_function" and node.target == torch.ops.aten.detach.default: - input_node = node.args[0] - node.replace_all_uses_with(input_node) - graph.erase_node(node) + # Replace detach with clone to break the input-output alias. + # Alias-free outputs mean AOT autograd never writes SymInt raw + # pointers into C++ view_meta_sequence, so GC of SymInt objects + # cannot produce dangling pointers and apply_view_meta_sequence + # crashes (shape '[139...008, ...]' is invalid for input ...). + node.target = torch.ops.aten.clone.default graph.lint() gm.recompile() @@ -311,6 +318,15 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None: decomposition_table=decomp_table, )(ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin) + # make_fx has captured the graph; input tensors are no longer needed. + del ext_coord, ext_atype, nlist, mapping + if fparam is not None: + del fparam + if aparam is not None: + del aparam + if charge_spin is not None: + del charge_spin + # make_fx inserts aten.detach.default for saved tensors used in the # decomposed autograd.grad backward ops. These detach nodes break # second-order gradient flow (d(force)/d(params) for force training). @@ -1019,13 +1035,52 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: compile_opts=compile_opts, ) + # torch.compile is lazy: inductor only compiles on the first + # call. In DDP multi-task training, different ranks may first + # hit a task at different training steps, so one rank can block + # inside inductor for minutes while others spin in AllReduce — + # causing an NCCL timeout. Warmup here, while sample inputs + # still exist, forces eager compilation before training starts. + # + # Match _CompiledModel.forward which sets requires_grad_(True) on + # ext_coord: Dynamo's guard includes requires_grad, so a mismatch + # causes every task's first training call to miss the warmup cache. + ext_coord = ext_coord.detach().requires_grad_(True) + _warmup_out = compiled_lower( + ext_coord, ext_atype, nlist_t, mapping, fparam, aparam, charge_spin + ) + del _warmup_out + if DEVICE.type == "cuda": + torch.cuda.synchronize() + wrapper_mod.model[task_key] = _CompiledModel(model, compiled_lower) + + # Release all intermediate tensors built for this task so they don't + # accumulate across tasks in multi-task scenarios. + del ext_coord, ext_atype, mapping, nlist_t + del coord, atype, coord_3d, coord_norm + if box is not None: + del box, box_flat + if fparam is not None: + del fparam + if aparam is not None: + del aparam + if charge_spin is not None: + del charge_spin + del inp, _ + log.info( "Model compiled (task=%s, tracing_mode=symbolic, " "dynamic=True, backend=inductor).", task_key, ) + # All tasks compiled on this rank — wait for all ranks before + # training starts so no rank enters the training loop while another + # is still blocked in inductor compilation. + if self.is_distributed: + dist.barrier() + # ------------------------------------------------------------------ # Data helpers # ------------------------------------------------------------------ @@ -1212,7 +1267,9 @@ def run(self) -> None: if self.rank == 0: if not self.multi_task: train_results = { - k: v for k, v in more_loss.items() if "l2_" not in k + k: (v.item() if isinstance(v, torch.Tensor) else v) + for k, v in more_loss.items() + if "l2_" not in k } # validation @@ -1233,7 +1290,13 @@ def run(self) -> None: for k, v in _vmore.items(): if "l2_" not in k: valid_results[k] = ( - valid_results.get(k, 0.0) + v * natoms + valid_results.get(k, 0.0) + + ( + v.item() + if isinstance(v, torch.Tensor) + else v + ) + * natoms ) if sum_natoms > 0: valid_results = { @@ -1246,13 +1309,15 @@ def run(self) -> None: # current task already has loss train_results[task_key] = { - k: v for k, v in more_loss.items() if "l2_" not in k + k: (v.item() if isinstance(v, torch.Tensor) else v) + for k, v in more_loss.items() + if "l2_" not in k } # compute loss for other tasks for _key in self.model_keys: if _key != task_key: - self.optimizer.zero_grad() + self.optimizer.zero_grad(set_to_none=True) _inp, _lab = self.get_data(is_train=True, task_key=_key) _, _loss, _more = self._unwrapped( **_inp, @@ -1260,9 +1325,17 @@ def run(self) -> None: label=_lab, task_key=_key, ) + # Use .item() so the backward graph (and its + # saved activations) can be freed immediately. + # Display passes never call loss.backward(), so + # without this the computation graphs for all + # tasks accumulate simultaneously in GPU memory. train_results[_key] = { - k: v for k, v in _more.items() if "l2_" not in k + k: (v.item() if isinstance(v, torch.Tensor) else v) + for k, v in _more.items() + if "l2_" not in k } + del _loss, _more, _inp, _lab # validation for each task _vdata = self.validation_data[_key] @@ -1285,7 +1358,15 @@ def run(self) -> None: _sum_natoms += natoms for k, v in _vmore.items(): if "l2_" not in k: - _vres[k] = _vres.get(k, 0.0) + v * natoms + _vres[k] = ( + _vres.get(k, 0.0) + + ( + v.item() + if isinstance(v, torch.Tensor) + else v + ) + * natoms + ) if _sum_natoms > 0: _vres = { k: v / _sum_natoms for k, v in _vres.items()