test: add a diagnostic script for prefix caching naning #1987
Conversation
📝 Walkthrough

Documentation updated to include a new diagnostic section for prefix caching NaN logprobs validation. A new Python script added to tools/model_diagnostics/ that reproduces and validates prefix caching behavior in vLLM, including multi-iteration generation with prefix cache reuse and NaN logprob detection.
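A minimal sketch of the kind of reproduction the script performs is shown below. It is assembled from the details quoted in the review comments (the model name, the constants TP/MAX_TOKENS/MAX_MODEL_LEN/COUNT_UP_TO, the --model/--tp flags, greedy sampling with logprobs=1, and the out1/out2/nan_count variables); the concrete constant values, the prompt text, and the enable_prefix_caching flag are assumptions, not the actual contents of the script.

```python
import argparse
import math

import vllm
from vllm import LLM, SamplingParams

# Constants named in the review; these particular values are placeholders.
MODEL = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16"
TP = 1
MAX_TOKENS = 64
MAX_MODEL_LEN = 4096
COUNT_UP_TO = 500

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default=MODEL)
    parser.add_argument("--tp", type=int, default=TP)
    args = parser.parse_args()

    print(f"vLLM version: {vllm.__version__}")

    # A long, deterministic prompt gives the second pass a large shared prefix to reuse.
    numbers = " ".join(str(i) for i in range(COUNT_UP_TO))
    prompt = f"Count upward: {numbers}. Continue:"

    llm = LLM(
        model=args.model,
        tensor_parallel_size=args.tp,
        max_model_len=MAX_MODEL_LEN,
        enable_prefix_caching=True,  # assumption: the issue requires prefix caching on
    )
    sampling_params = SamplingParams(temperature=0.0, max_tokens=MAX_TOKENS, logprobs=1)

    # Iteration 1 populates the prefix cache; iteration 2 reuses it.
    out1 = llm.generate([prompt], sampling_params)[0].outputs[0]
    out2 = llm.generate([prompt], sampling_params)[0].outputs[0]

    # Any NaN logprob on the cache-reusing pass flags the issue under investigation.
    nan_count = 0
    for step in out2.logprobs or []:
        for _tid, lp_obj in step.items():
            lp = lp_obj.logprob if hasattr(lp_obj, "logprob") else lp_obj
            if isinstance(lp, float) and math.isnan(lp):
                nan_count += 1

    print(f"Iteration 1 text: {out1.text[:60]!r}")
    print(f"NaN logprobs found in iteration 2: {nan_count}")
```

The review comments below refer to this same overall shape of code: the NaN-counting loop near the end and the module-level setup at the top.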
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~12 minutes
🚥 Pre-merge checks: 3 passed ✅ | 1 failed ❌ (inconclusive)
Actionable comments posted: 3
🧹 Nitpick comments (2)
tools/model_diagnostics/5.prefix_caching_nan.py (2)
75-84: Unconditional `break` silently under-counts NaNs if `logprobs > 1`.

The `break` at line 84 exits after inspecting only the first token-id entry per step, regardless of whether a NaN was found. With `logprobs=1` and `temperature=0.0` (greedy), each step's dict has exactly one entry, so this is functionally correct today. However, vLLM returns up to `logprobs + 1` elements per step, meaning that if `logprobs` is ever bumped above 1, the counter would silently undercount NaN occurrences (at most 1 per step). Removing the `break` makes the intent clear and future-proof:

♻️ Proposed fix
```diff
 for _tid, lp_obj in step.items():
     lp = lp_obj.logprob if hasattr(lp_obj, "logprob") else lp_obj
     if isinstance(lp, float) and math.isnan(lp):
         nan_count += 1
-    break
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tools/model_diagnostics/5.prefix_caching_nan.py` around lines 75-84: the loop over out2.logprobs currently contains an unconditional break after inspecting the first token-id entry, which causes NaNs to be under-counted when a step contains multiple entries; remove the break so the inner loop over step.items() examines every lp_obj (keep the existing hasattr(lp_obj, "logprob") check and NaN detection for lp) and nan_count increments for every NaN across all token-id entries rather than just the first one per step.
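To see the under-counting concretely, here is a small self-contained illustration (no vLLM required). The per-step dict shape, a token id mapped to an object with a .logprob attribute, follows the review's description; the values are made up.

```python
import math
from types import SimpleNamespace

# Two synthetic "steps", each with more than one entry (as with logprobs > 1).
steps = [
    {1: SimpleNamespace(logprob=math.nan), 2: SimpleNamespace(logprob=-0.5)},
    {3: SimpleNamespace(logprob=math.nan), 4: SimpleNamespace(logprob=math.nan)},
]

def count_nans(steps, stop_after_first):
    nan_count = 0
    for step in steps:
        for _tid, lp_obj in step.items():
            lp = lp_obj.logprob if hasattr(lp_obj, "logprob") else lp_obj
            if isinstance(lp, float) and math.isnan(lp):
                nan_count += 1
            if stop_after_first:
                break  # mimics the unconditional break in the current loop
    return nan_count

print(count_nans(steps, stop_after_first=True))   # 2: at most one NaN counted per step
print(count_nans(steps, stop_after_first=False))  # 3: the true count
```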
35-41: Consider consolidating under `if __name__ == "__main__":` and moving imports to the top.

All module-level logic (argparse, LLM instantiation, generation) runs unconditionally on import. A `__main__` guard is the standard protection against accidental execution when scripts are discovered by tooling. Additionally, the `vllm` import (lines 40-41) appears mid-file after `parse_args()`; while this speeds up `--help`, it diverges from PEP 8 and can surprise readers.

♻️ Suggested structure
```diff
 import argparse
 import math
+from vllm import LLM, SamplingParams
+import vllm

 MODEL = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16"
 ...
-parser = argparse.ArgumentParser()
-parser.add_argument("--model", type=str, default=MODEL)
-parser.add_argument("--tp", type=int, default=TP)
-args = parser.parse_args()
-
-import vllm
-from vllm import LLM, SamplingParams
-
-print(f"vLLM version: {vllm.__version__}")
-...
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default=MODEL)
+    parser.add_argument("--tp", type=int, default=TP)
+    args = parser.parse_args()
+
+    print(f"vLLM version: {vllm.__version__}")
+    ...
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tools/model_diagnostics/5.prefix_caching_nan.py` around lines 35-41: move all runtime logic (argparse setup and calls using parser/args, LLM instantiation, and generation) under a guarded block: wrap code that calls parser.parse_args(), creates the vllm LLM and SamplingParams, and runs generation inside if __name__ == "__main__":. Also relocate the imports (import vllm and from vllm import LLM, SamplingParams) to the top of the file with the other imports to follow PEP 8; keep only lightweight module-level constants like MODEL and TP outside the guard and ensure no heavy side-effect code runs at import time.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/adding-new-models.md`:
- Around line 330-344: The two fenced code blocks under "Expected pass output
(vLLM 0.13.0)" and "Expected fail output (vLLM 0.14.0)" are missing language
specifiers causing markdownlint MD040; update both fences to include a language
(e.g., use ```text) so the pass-output and fail-output blocks explicitly start
with ```text and end with ``` to satisfy MD040 and preserve formatting.
In `@tools/model_diagnostics/5.prefix_caching_nan.py`:
- Line 87: The print statement currently uses an unnecessary f-string: locate
the call print(f"\n Sample logprobs from iteration 2:") and remove the leading
"f" so it becomes a plain string literal; this eliminates the Ruff F541 spurious
f-string warning without changing behavior (a before/after snippet follows this list).
- Around line 29-75: The module defines several module-level mutable bindings
(parser, args, numbers, prompt, llm, sampling_params, out1, out2, nan_count)
which violate the global naming guideline; wrap all runtime code that creates or
mutates these symbols inside an if __name__ == "__main__": block so they become
local to main (keep MODEL, TP, MAX_TOKENS, MAX_MODEL_LEN, COUNT_UP_TO and
imports at module scope), e.g., move creation of
argparse.ArgumentParser()/parser, args = parser.parse_args(), numbers, prompt
construction, LLM() instantiation, SamplingParams(), the two generate calls that
produce out1/out2, and nan_count into that guard; alternatively if you must keep
any of them global, rename using upper snake_case with the G_ prefix (e.g.,
G_PARSER, G_LLM) to satisfy the guideline.
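For the F541 item above, the change is mechanical; a before/after sketch is shown here. The quoted print call is taken from the review's description of line 87, and the comments are editorial.

```python
# Before: Ruff F541, an f-string with no placeholders
print(f"\n Sample logprobs from iteration 2:")

# After: identical output, no lint warning
print("\n Sample logprobs from iteration 2:")
```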
What does this PR do?
Adds a diagnostic script, tools/model_diagnostics/5.prefix_caching_nan.py, that reproduces and validates prefix caching behavior in vLLM and checks for NaN logprobs, along with a corresponding diagnostic section in docs/adding-new-models.md.
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use this
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
Documentation
Added a diagnostic section for prefix caching NaN logprobs validation.
New Features
Added a diagnostic script under tools/model_diagnostics/ that reproduces prefix caching behavior in vLLM with multi-iteration generation, prefix cache reuse, and NaN logprob detection.