fix(pt/pd): fix incompatibility between AutoBatchSize and eval hooks#5181
fix(pt/pd): fix incompatibility between AutoBatchSize and eval hooks#5181njzjz wants to merge 17 commits into
Conversation
for more information, see https://pre-commit.ci
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds RetrySignal and an oom_retry_mode to AutoBatchSize, raises RetrySignal on OOM when retry mode is enabled, wraps descriptor and fitting-last-layer evaluation in Paddle and PyTorch backends to toggle oom-retry mode and retry on RetrySignal, and adds tests verifying hook/state cleanup. ChangesOOM Retry Integration
Sequence Diagram(s)sequenceDiagram
participant DeepEval
participant AutoBatchSize
participant GPU
participant Model
DeepEval->>AutoBatchSize: set_oom_retry_mode(True)
DeepEval->>Model: set_eval_*_hook(True)
DeepEval->>AutoBatchSize: execute(batch)
AutoBatchSize->>GPU: run_batch()
GPU-->>AutoBatchSize: OOM error
AutoBatchSize->>DeepEval: raise RetrySignal
DeepEval->>Model: set_eval_*_hook(False)
DeepEval->>AutoBatchSize: set_oom_retry_mode(False)
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Pull request overview
This pull request attempts to fix an incompatibility between AutoBatchSize and evaluation hooks (eval_descriptor and eval_fitting_last_layer) that caused mismatched descriptor output when OOM errors occurred during batch processing. The issue manifested as descriptors having more frames than the input system (e.g., 241 frames vs 175 in the reported issue).
Changes:
- Introduces a retry mechanism via
RetrySignalexception and@retrydecorator to restart processing from the beginning when OOM occurs during hook-based evaluation - Adds
oom_retry_modeflag toAutoBatchSizeto control whether OOM errors trigger a full retry - Enables retry mode in
eval_descriptorandeval_fitting_last_layermethods for both PyTorch (pt) and Paddle (pd) backends
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| deepmd/utils/batch_size.py | Adds RetrySignal exception, retry decorator, oom_retry_mode flag, and logic to raise RetrySignal on OOM when retry mode is enabled |
| deepmd/pt/infer/deep_eval.py | Enables/disables oom_retry_mode around eval calls in eval_descriptor and eval_fitting_last_layer methods |
| deepmd/pd/infer/deep_eval.py | Enables/disables oom_retry_mode around eval calls in eval_descriptor and eval_fitting_last_layer methods |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
deepmd/pt/infer/deep_eval.py (1)
796-812: Ensure hooks & OOM retry mode are reset on exceptions.
Lines 796-812 and 855-871 enable hooks/retry mode but only disable them on the success path. Ifself.eval(...)ormodel.eval_*throws, the flags remain enabled and can corrupt subsequent calls.✅ Safer pattern (apply to both methods)
- if self.auto_batch_size is not None: - self.auto_batch_size.set_oom_retry_mode(True) - model.set_eval_descriptor_hook(True) - self.eval( - coords, - cells, - atom_types, - atomic=False, - fparam=fparam, - aparam=aparam, - **kwargs, - ) - descriptor = model.eval_descriptor() - model.set_eval_descriptor_hook(False) - if self.auto_batch_size is not None: - self.auto_batch_size.set_oom_retry_mode(False) + if self.auto_batch_size is not None: + self.auto_batch_size.set_oom_retry_mode(True) + model.set_eval_descriptor_hook(True) + try: + self.eval( + coords, + cells, + atom_types, + atomic=False, + fparam=fparam, + aparam=aparam, + **kwargs, + ) + descriptor = model.eval_descriptor() + finally: + model.set_eval_descriptor_hook(False) + if self.auto_batch_size is not None: + self.auto_batch_size.set_oom_retry_mode(False)Also applies to: 855-871
deepmd/pd/infer/deep_eval.py (1)
823-842: Ensure hooks & OOM retry mode are reset on exceptions.
Lines 823-842 and 884-901 toggle hooks/retry mode without afinally. Any exception duringself.eval(...)ormodel.eval_*can leave the backend in a bad state.✅ Safer pattern (apply to both methods)
- if self.auto_batch_size is not None: - self.auto_batch_size.set_oom_retry_mode(True) - model.set_eval_descriptor_hook(True) - self.eval( - coords, - cells, - atom_types, - atomic=False, - fparam=fparam, - aparam=aparam, - **kwargs, - ) - descriptor = model.eval_descriptor() - model.set_eval_descriptor_hook(False) - if self.auto_batch_size is not None: - self.auto_batch_size.set_oom_retry_mode(False) + if self.auto_batch_size is not None: + self.auto_batch_size.set_oom_retry_mode(True) + model.set_eval_descriptor_hook(True) + try: + self.eval( + coords, + cells, + atom_types, + atomic=False, + fparam=fparam, + aparam=aparam, + **kwargs, + ) + descriptor = model.eval_descriptor() + finally: + model.set_eval_descriptor_hook(False) + if self.auto_batch_size is not None: + self.auto_batch_size.set_oom_retry_mode(False)Also applies to: 884-901
🤖 Fix all issues with AI agents
In `@deepmd/utils/batch_size.py`:
- Around line 161-162: The OOM handler incorrectly checks the method object
self.set_oom_retry_mode rather than its boolean result, so every OOM triggers a
retry; change the condition to call the method (if self.set_oom_retry_mode():
...) so it evaluates the returned bool before raising RetrySignal, and ensure
the method returns a proper bool.
🧹 Nitpick comments (1)
deepmd/utils/batch_size.py (1)
32-54: Clarify retry semantics (docstring vs behavior).
Line 32 says retries happen for “certain times,” but the wrapper loops forever. Either document it as unbounded or add a cap/backoff.✏️ Minimal doc fix
- """Decorator to retry the function until it succeeds or fails for certain times. + """Decorator to retry the function until it succeeds (no max retry cap).
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #5181 +/- ##
==========================================
+ Coverage 81.95% 82.48% +0.53%
==========================================
Files 714 829 +115
Lines 73441 88810 +15369
Branches 3616 4225 +609
==========================================
+ Hits 60187 73258 +13071
- Misses 12091 14260 +2169
- Partials 1163 1292 +129 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
…deepmd-kit into eval_desc_auto_batch_size
for more information, see https://pre-commit.ci
|
@QuantumMisaka could you test whether this PR fixes your issue? |
njzjz-bot
left a comment
There was a problem hiding this comment.
I found two OOM-retry correctness issues in the PT/PD eval hook paths and left apply-ready suggestions below.
— OpenClaw 2026.5.12 (model: custom-chat-jinzhezeng-group/gpt-5.5)
Co-authored-by: A bot of @njzjz <48687836+njzjz-bot@users.noreply.github.com> Signed-off-by: Jinzhe Zeng <njzjz@qq.com>
wanghan-iapcm
left a comment
There was a problem hiding this comment.
The fix introduces three new code paths with zero test coverage. Please add unit tests for each — they exist solely because of this PR and will silently regress otherwise.
1. RetrySignal raise path in AutoBatchSize.execute()
Target: deepmd/utils/batch_size.py:130-131
Subclass AutoBatchSize with a fake executor that raises an OOM-sentinel on first call. Then:
- With
set_oom_retry_mode(True): assertRetrySignalis raised, its__cause__is the original OOM, andcurrent_batch_sizewas halved before the raise. - With
set_oom_retry_mode(False): assert it returns(0, None)instead — pins the flag actually gates behavior.
2. Recursive retry in eval_descriptor / eval_fitting_last_layer clears the hook between attempts
Target: deepmd/{pt,pd}/infer/deep_eval.py
This is the actual fix for #5180 (descriptor frame count doubling). Per backend:
- Inject an
AutoBatchSizethat raisesRetrySignalon the firstexecute, succeeds on the second. - Call
eval_descriptorwith N frames. - Assert the returned array has exactly N frames (not 2N — the #5180 regression check).
- Spy on
model.set_eval_descriptor_hook; assert the call sequence is[True, False, True, False]. - Assert
auto_batch_size.oom_retry_mode == Falseafter return.
Repeat for eval_fitting_last_layer with set_eval_fitting_last_layer_hook.
3. finally clears state on non-RetrySignal exceptions
Target: same files as Test 2.
The PR moves hook cleanup into finally — a real improvement, since arbitrary errors no longer leave the hook stuck True. Per backend, per method:
- Monkey-patch
self.evalto raise a genericRuntimeError. - Call inside
pytest.raises(RuntimeError). - Assert
set_eval_*_hookwas called withFalse(sequence[True, False]). - Assert
auto_batch_size.oom_retry_mode == False.
Out of scope (do NOT add)
oom_retry_mode == True+ no OOM — exercises no new code.oom_retry_mode == False+ OOM returning(0, None)— pre-existing behavior.
Acceptance
- All new tests pass.
- Mutation check: removing
raise RetrySignalfails Test 1; removing thefinallyfails Tests 2 and 3.
Add regression coverage for the AutoBatchSize RetrySignal path and eval hook cleanup around retry and non-retry failures. Move recursive retries until after finally so hooks are cleared between attempts. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
|
Addressed review feedback by adding regression coverage and ensuring retry attempts restart only after hook state is cleared. Opened follow-up PR against this PR branch: njzjz#229 Local focused test passed: uv run --no-project --with pytest --with array-api-strict --with array-api-compat --with numpy --with packaging --with typing-extensions --with pyyaml --with wcmatch python -m pytest source/tests/common/test_oom_retry.py -q
# 6 passed— OpenClaw 2026.5.12 (model: custom-chat-jinzhezeng-group/gpt-5.5) |
Avoid recursive RetrySignal handling in eval_descriptor and eval_fitting_last_layer so repeated OOM retries do not consume Python stack frames. The loop still clears hook and retry state between attempts before retrying. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
|
Addressed Copilot's latest recursion comments by opening another follow-up PR against this PR branch: njzjz#230 It converts the RetrySignal handling in PT/PD Local focused test passed: uv run --no-project --with pytest --with array-api-strict --with array-api-compat --with numpy --with packaging --with typing-extensions --with pyyaml --with wcmatch python -m pytest source/tests/common/test_oom_retry.py -q
# 6 passed— OpenClaw 2026.5.12 (model: custom-chat-jinzhezeng-group/gpt-5.5) |
wanghan-iapcm
left a comment
There was a problem hiding this comment.
The new tests address the three asks from my prior review, with one structural concern worth raising before merge.
DummyDeepEval.eval_descriptor / eval_fitting_last_layer (lines 68-104 of source/tests/common/test_oom_retry.py) are hand-written copies of the production try/except RetrySignal/finally orchestration in deepmd/pt/infer/deep_eval.py and
deepmd/pd/infer/deep_eval.py. The tests call the dummy's methods, not production's. This is testing the test code, not the code that ships.
Concrete evidence the gap matters: commit b5f789ae (the day after the tests were added) refactored the production code from recursion to iteration (while True + retry flag). The dummy still uses recursion. The tests passed throughout — they never
noticed the production-side change because they don't exercise it. If a future refactor drops the finally block from production, the same tests will keep passing.
The mutation-check guarantee from the prior review ("removing the finally fails Tests 2 and 3") does not currently hold for Tests 2 and 3. It holds for Test 1, which calls the real AutoBatchSize.execute.
Mocking is right for a UT — but mock the dependencies of the production method, not the production method itself. Roughly:
class TestEvalDescriptorRetry(unittest.TestCase):
def setUp(self):
# Construct DeepEval without loading a real model.
self.dp_eval = DeepPotPT.__new__(DeepPotPT)
self.dp_eval.dp = MagicMock()
self.dp_eval.dp.model = {"Default": MagicMock()}
self.dp_eval.auto_batch_size = MagicMock()
def test_retry_clears_hook_between_attempts(self):
model = self.dp_eval.dp.model["Default"]
model.eval_descriptor.return_value = np.array([1, 2, 3])
with patch.object(self.dp_eval, "eval",
side_effect=[RetrySignal, None]):
result = self.dp_eval.eval_descriptor(
coords=..., cells=..., atom_types=...,
)
self.assertEqual(
model.set_eval_descriptor_hook.call_args_list,
[call(True), call(False), call(True), call(False)],
)
np.testing.assert_array_equal(result, [1, 2, 3])
def test_finally_clears_hook_on_runtime_error(self):
with patch.object(self.dp_eval, "eval",
side_effect=RuntimeError("non-retry failure")):
with self.assertRaisesRegex(RuntimeError, "non-retry failure"):
self.dp_eval.eval_descriptor(
coords=..., cells=..., atom_types=...,
)
model = self.dp_eval.dp.model["Default"]
self.assertEqual(
model.set_eval_descriptor_hook.call_args_list,
[call(True), call(False)],
) The reusable DummyModel / DummyAutoBatchSize stubs (lines 24-49) are fine as collaborator fakes — keep them. Only DummyDeepEval should go; replace its usages with the real DeepPotPT / DeepPotPD constructed via __new__ (or a tiny init helper)
with the dependencies above patched in. Same shape, similar line count, but the assertions then pin the production code paths.
Test 1 (the AutoBatchSize.execute / RetrySignal raise path) is already correct — no changes needed there.
Replace the hand-written DummyDeepEval orchestration with real PT/PD DeepEval instances constructed through __new__, while mocking their dependencies. This keeps the AutoBatchSize test intact and makes retry/finally assertions pin the production eval_descriptor and eval_fitting_last_layer methods. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
|
Addressed Wang Han's latest structural test feedback by opening another follow-up PR against this PR branch: njzjz#231 It removes the hand-written Local focused checks passed: uv run --no-project --with pytest --with array-api-strict --with array-api-compat --with numpy --with packaging --with typing-extensions --with pyyaml --with wcmatch python -m pytest source/tests/common/test_oom_retry.py -q
# 2 passed, 8 skipped locally because torch/paddle are unavailable in this lightweight environment
uv run --no-project --with ruff ruff check source/tests/common/test_oom_retry.py
# All checks passed— OpenClaw 2026.5.12 (model: custom-chat-jinzhezeng-group/gpt-5.5) |
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jinzhe Zeng <njzjz@qq.com>
Production eval helpers convert model outputs with backend precision handling, which rejects integer arrays. Use floating arrays in the mocked descriptor and fitting outputs so the production PT/PD retry tests exercise the intended cleanup path. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
|
Follow-up to fix my test mistake: njzjz#232 The production PT retry tests in #231 mocked descriptor/fitting outputs with integer NumPy arrays. In CI, the real PT backend then passed those through Local focused checks passed: uv run --no-project --with pytest --with array-api-strict --with array-api-compat --with numpy --with packaging --with typing-extensions --with pyyaml --with wcmatch python -m pytest source/tests/common/test_oom_retry.py -q
# 2 passed, 8 skipped locally because torch/paddle are unavailable here
uv run --no-project --with ruff ruff check source/tests/common/test_oom_retry.py
# All checks passedThe separate TensorFlow — OpenClaw 2026.5.12 (model: custom-chat-jinzhezeng-group/gpt-5.5) |
test(oom): return floating mock outputs
Fix #5180.
Summary by CodeRabbit