[Bug] DPA-4 / SeZM: use_compile=true crashes with InductorError: CantSplit on first batch (variable atom counts) #5454

@SchrodingersCattt

Summary

DPA-4 / SeZM training crashes on the first batch when model.use_compile = true is set, with a Torch Inductor CantSplit error. Reverting to model.use_compile = false (the default) makes training run normally. The dataset has variable atom counts per frame (the model sees a wide range of system sizes during training), which appears to be what the inductor split logic chokes on.

This is presumably a known sharp edge of use_compile (the argcheck doc string already calls it "Experimental"), but the problem does not surface until the first forward pass: there is no upfront validation, no clear log line saying "use_compile requires uniform shapes / disable for ragged batches", and the resulting traceback is opaque to a user trying the option for the first time.

Branch / commit: OutisLi/dpmd-public@dpa4 at commit 37cdf46 ("add atom_modify map yes"). This is also the code being upstreamed in PR #5448.

Environment

  • GPU: NVIDIA A800-SXM4-80GB (1×), driver 535.129.03
  • CUDA runtime: 12.6
  • PyTorch: 2.12.0+cu126
  • Python: 3.12
  • DeePMD-kit (this branch): 0.1.dev1+g37cdf4629
  • Other deps: einops==0.8.2, e3nn==0.6.0
  • OS: Ubuntu 22.04 (vemlp-cn-beijing.cr.volces.com/preset-images/python:3.12-ubuntu22.04)

All of the above match the requirements stated in the announcement (CUDA ≥ 12.6, torch ≥ 2.11), so this should be a supported configuration.

Reproduction

Single-task DPA-4 from-scratch training. Relevant model section of input.json:

"model": {
  "type": "dpa4",
  "type_map": ["H", "He", "...", "Og"],
  "descriptor": {
    "type": "dpa4",
    "sel": 200,
    "rcut": 6.0,
    "env_exp": [7, 5],
    "channels": 64,
    "n_radial": 16,
    "lmax": 3,
    "mmax": 1,
    "n_blocks": 3,
    "so2_layers": 4,
    "use_amp": true,
    "precision": "float32"
  },
  "fitting_net": { "type": "ener", "neuron": [240, 240, 240], "precision": "float32" },
  "use_compile": true,
  "enable_tf32": true
}

Training section uses batch_size: "auto:64", default auto_prob over multiple systems, MAE loss, HybridMuon optimizer, WSD LR schedule. Dataset: a heterogeneous collection of crystal/cluster structures, frame atom counts span roughly 1–100 atoms per system across ~1k systems. Nothing exotic about the data on the deepmd-kit side — it's plain deepmd/npy shards.

Launch:

dp --pt train input.json --skip-neighbor-stat

Initialization succeeds (model params print, dataset stats print, Start to train ... steps. is logged). The crash happens on the very first forward / backward, before any optimizer step:

torch._inductor.exc.InductorError: CantSplit: 16*s38 + 16 not divisible by s38 + 1

(Raised through the standard torch.distributed.elastic.multiprocessing.errors.wrapper path. The full traceback was unfortunately overwritten when the run was relaunched, but the proximate CantSplit line is reproducible, and a separate run on different but similarly heterogeneous data fails in the same way.)
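
One observation about the message itself: 16*s38 + 16 factors as 16*(s38 + 1), so arithmetically it is always divisible by s38 + 1. The check below is plain Python with an ordinary integer `s` standing in for the symbol s38 (the range is illustrative only). It says nothing about Inductor internals, but it suggests the failure may lie in the symbolic divisibility reasoning rather than in a genuinely incompatible shape.

```python
# Stand-in for the symbolic size s38 from the error message.
# This only checks the arithmetic; it does not reproduce the Inductor failure.

def numerator(s: int) -> int:
    return 16 * s + 16


def denominator(s: int) -> int:
    return s + 1


# 16*s + 16 == 16*(s + 1), so the remainder is zero and the quotient is 16
# for every non-negative s.
remainders = {numerator(s) % denominator(s) for s in range(1000)}
quotients = {numerator(s) // denominator(s) for s in range(1000)}

print("remainders:", remainders)
print("quotients:", quotients)
```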

Workaround

Setting "use_compile": false (and dropping enable_tf32) makes training run cleanly on the same input.json, same data, same cluster:

[DEEPMD INFO] Batch 1: trn: mae_e = ..., mae_f = ..., lr = 9.00e-05
[DEEPMD INFO] Batch 1000: trn: mae_e = ..., mae_f = ...
...

Throughput on A800-SXM4-80GB is ~210 s / 1k batches at this model size.

Suggestions

  1. Fail loudly upfront: when use_compile=true is set, validate at Trainer.__init__ time that the user's training data has uniform per-batch shapes (or at least warn that variable atom counts will trigger inductor CantSplit). The current behaviour — a successful initialization followed by an opaque inductor error inside make_fx — is hard to debug without prior knowledge.
  2. Document the constraint: extend the doc_use_compile string in deepmd/utils/argcheck.py (around line 2992) to spell out that the compact-sparse-edges path needs constant-shaped batches; right now it only says "Experimental feature" and lists hardware/version requirements.
  3. Optional: if it's feasible, automatically fall back to the eager path with a single warning when make_fx fails, instead of propagating InductorError. This would make use_compile=true a safe default to recommend in a future release.
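
A minimal sketch of what suggestions 1 and 3 could look like, written in plain Python so it stands alone. The names `warn_if_ragged` and `with_eager_fallback` are hypothetical and are not part of DeePMD-kit; the real hook points (Trainer.__init__, the compiled forward) would need to be wired in by the maintainers. The fallback catches `Exception` broadly only to keep the sketch free of a torch import; a real implementation would presumably catch torch._inductor.exc.InductorError specifically.

```python
import warnings
from typing import Any, Callable, Iterable


def warn_if_ragged(atom_counts: Iterable[int]) -> bool:
    """Warn and return True when frames do not all share one atom count.

    Hypothetical helper for suggestion 1; `atom_counts` is assumed to be
    the per-frame (or per-system) number of atoms in the training data.
    """
    distinct = set(atom_counts)
    if len(distinct) > 1:
        warnings.warn(
            "use_compile=true with variable atom counts "
            f"(min={min(distinct)}, max={max(distinct)}) may fail inside "
            "Inductor; consider use_compile=false.",
            RuntimeWarning,
            stacklevel=2,
        )
        return True
    return False


def with_eager_fallback(
    compiled_fn: Callable[..., Any], eager_fn: Callable[..., Any]
) -> Callable[..., Any]:
    """Wrap a compiled callable so one failure switches to the eager path.

    Hypothetical helper for suggestion 3. After the first failure a single
    warning is emitted and every later call goes straight to `eager_fn`.
    """
    use_eager = False

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal use_eager
        if not use_eager:
            try:
                return compiled_fn(*args, **kwargs)
            except Exception as exc:  # sketch only; narrow this in real code
                warnings.warn(
                    f"compiled path failed ({exc!r}); falling back to eager.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                use_eager = True
        return eager_fn(*args, **kwargs)

    return wrapper
```

With this shape, a dataset whose atom counts are, say, 1, 50 and 100 would produce one warning at setup time, and a compiled forward that raises on the first batch would log one more warning and continue on the eager path instead of aborting the run.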

cc @OutisLi (via PR #5448).

——Co-authored by Cursor Agent
