[Feat]: Specdec Multinode Streaming#1611
Conversation
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughConverts streaming Dataset to synchronous map-style getitem, replaces server_url with server_urls and adds safetensors retry, removes async resume/prefetch/seed plumbing, rewrites multi-node serve/trainer orchestration and launcher CLI/wiring, and updates tests and launcher YAMLs. ChangesStreaming Dataset and Multi-Node Training
Sequence Diagram(s)sequenceDiagram
participant DistributedSampler
participant StreamingDataset
participant EagleVllmStreamingDataset
participant vLLMEndpoint
DistributedSampler->>StreamingDataset: provide sharded index -> __getitem__(i)
StreamingDataset->>EagleVllmStreamingDataset: validate/tokenize sample, call _fetch(sample)
EagleVllmStreamingDataset->>vLLMEndpoint: synchronous HTTP POST prompt
vLLMEndpoint-->>EagleVllmStreamingDataset: return hidden-state safetensor path
EagleVllmStreamingDataset->>StreamingDataset: return validated batch payload
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 5 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
aae3d2d to
11571c0
Compare
|
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #1611 +/- ##
==========================================
+ Coverage 74.58% 77.20% +2.61%
==========================================
Files 482 483 +1
Lines 52943 53033 +90
==========================================
+ Hits 39489 40945 +1456
+ Misses 13454 12088 -1366
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
9a73971 to
0265624
Compare
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
modelopt/torch/speculative/plugins/hf_training_args.py (1)
30-35: 🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick winAdd
__all__to keep the schema API explicit.This module exports public Pydantic models, so leaving the symbol list implicit works against the repo's Python API convention. Please declare
__all__near the top.As per coding guidelines, "
**/*.py: Define the public API with__all__at the top of each Python module ... to keep the public API explicit and make star-imports safe".🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/speculative/plugins/hf_training_args.py` around lines 30 - 35, Add an explicit __all__ declaration near the top of hf_training_args.py (immediately after the imports) that lists the module's public API: include the names of the Pydantic models and any helper functions or constants you intend to export (i.e., the public class names that subclass BaseModel and any functions that should be visible to consumers), so star-imports are safe and the module follows the repo convention.examples/speculative_decoding/launch_train.sh (1)
80-92:⚠️ Potential issue | 🔴 Critical | ⚡ Quick winStop using
sh -cto run the assembledaccelerate launchcommand.
sh -c "… ${EXTRA_ARGS[*]}"re-parses CLI overrides with shell syntax, so argument boundaries break (e.g.training.output_dir=/tmp/has spacebecomes two separate args) and command substitutions inside overrides execute (e.g.note=$(...)becomesnote=...). This same interpolation also embeds unquoted multi-node variables like$MACHINE_RANK/$HEAD_NODE_IPinto thesh -cstring.Build an argv array and invoke it directly (no
sh -c):Suggested fix
- set -x - start_time=$(date +%s) - sh -c "accelerate launch --mixed_precision bf16 $MULTI_NODE_ARGS ${SCRIPT_DIR}/main.py --config $CONFIG_FILE ${EXTRA_ARGS[*]}" + CMD=(accelerate launch --mixed_precision bf16) + if [[ "$NUM_NODES" != "1" ]]; then + CMD+=( + --multi_gpu + --num_processes "$TOTAL_GPU" + --num_machines "$NUM_NODES" + --machine_rank "${MACHINE_RANK:-$SLURM_PROCID}" + --main_process_ip "$HEAD_NODE_IP" + --main_process_port 29500 + ) + fi + CMD+=("${SCRIPT_DIR}/main.py" --config "$CONFIG_FILE" "${EXTRA_ARGS[@]}") + set -x + start_time=$(date +%s) + "${CMD[@]}" echo "Total time: $(( $(date +%s) - $start_time )) seconds"🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@examples/speculative_decoding/launch_train.sh` around lines 80 - 92, The script currently uses sh -c "accelerate launch ... $MULTI_NODE_ARGS ${SCRIPT_DIR}/main.py --config $CONFIG_FILE ${EXTRA_ARGS[*]}", which causes word-splitting and unintended shell interpolation (affecting MULTI_NODE_ARGS, MACHINE_RANK, HEAD_NODE_IP and EXTRA_ARGS). Fix by constructing an argv array for the command (e.g., base args: "accelerate" "launch" "--mixed_precision" "bf16" plus expanded MULTI_NODE_ARGS tokens, "${SCRIPT_DIR}/main.py", "--config" "$CONFIG_FILE" and each element of EXTRA_ARGS) and then invoke it directly (exec or run the array) instead of using sh -c so argument boundaries and literal values are preserved; ensure you expand EXTRA_ARGS as separate array elements rather than via ${EXTRA_ARGS[*]} or unquoted expansions.tools/launcher/common/eagle3/train_eagle_streaming.sh (1)
299-337:⚠️ Potential issue | 🟠 Major | ⚡ Quick winReject
SERVE_NODES >= SLURM_NNODESup front.In the new dispatch, that configuration turns every node into a serve node, so no trainer ever publishes the rendezvous address or creates
DONE_FILE. The job just waits forever.Suggested fix
NNODES="${SLURM_NNODES:-1}" NODEID="${SLURM_NODEID:-0}" + +if [ "$NNODES" -gt 1 ] && [ "$SERVE_NODES" -ge "$NNODES" ]; then + echo "ERROR: SERVE_NODES must be smaller than SLURM_NNODES." >&2 + exit 1 +fi🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tools/launcher/common/eagle3/train_eagle_streaming.sh` around lines 299 - 337, Add an explicit sanity check that fails fast when SERVE_NODES is >= NNODES (SLURM_NNODES) so we don't turn every node into a serve replica and deadlock waiting for a trainer; locate the topology dispatch logic that uses NNODES, NODEID and SERVE_NODES and before branching (or at start of that block) validate that if SERVE_NODES is set it is strictly less than NNODES, otherwise print an error and exit (include variables SERVE_NODES and NNODES in the message) so the job is rejected up front instead of hanging waiting for DONE_FILE or a trainer rendezvous.
🧹 Nitpick comments (5)
tests/unit/torch/speculative/plugins/test_hf_streaming_dataset.py (1)
128-143: ⚡ Quick winAdd one behavioral test for round-robin dispatch across multiple
server_urls.This refactor introduces multi-server fan-out as core behavior, but current coverage only checks URL normalization and a single-endpoint fetch path. A two-endpoint mock that asserts alternating destinations would catch regressions that silently pin all traffic to the first replica.
Also applies to: 176-214
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unit/torch/speculative/plugins/test_hf_streaming_dataset.py` around lines 128 - 143, Add a new unit test alongside test_server_urls_normalization that verifies round-robin dispatch across multiple server_urls: construct an EagleVllmStreamingConfig with server_urls set to two distinct endpoints, patch or mock the network/send/dispatch call used by the streaming plugin (replace the HTTP client or the function that actually sends requests) to record the target URL for each invocation, then invoke the dataset/request dispatching method twice (or more) and assert that recorded destinations alternate between the two endpoints (e.g., calls[0] == endpoint_a, calls[1] == endpoint_b, calls[2] == endpoint_a). Ensure the test uses pytest/monkeypatch to avoid real network I/O and name the test to reflect round_robin behavior so regressions that pin traffic to a single replica are caught.tools/launcher/examples/moonshotai/Kimi-K2.5/hf_dflash_dryrun.yaml (1)
50-56: ⚡ Quick winSet
data.mode=offlineexplicitly in the dry-run config.This smoke test currently depends on
offline_data_pathside effects to flip the recipe into offline mode. Since the DFlash path now keys offdata.mode, making that override explicit will keep this config aligned with the new contract and avoid brittle coupling to mode inference.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tools/launcher/examples/moonshotai/Kimi-K2.5/hf_dflash_dryrun.yaml` around lines 50 - 56, Add an explicit data.mode=offline entry to the dry-run config so the recipe uses offline mode deterministically; update the HF dry-run YAML that currently lists --dry_run, --config ..., model.* and data.offline_data_path to also include data.mode=offline (rather than relying on the presence of data.offline_data_path to infer mode).tools/launcher/slurm_config.py (1)
16-26: ⚡ Quick winAdd
__all__to make the factory/config API explicit.This module now exposes the launcher-facing
SlurmConfig/slurm_factorypair, but the public surface is still implicit.As per coding guidelines,
**/*.py: Define the public API with__all__at the top of each Python module and re-export submodules in__init__.pyfiles usingfrom .module import *to keep the public API explicit and make star-imports safe.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tools/launcher/slurm_config.py` around lines 16 - 26, This module should explicitly declare its public API: add an __all__ list near the top that exports the launcher-facing symbols, e.g. include "SlurmConfig" and "slurm_factory" (matching the actual class/function names in this file) so star-imports are safe; also update any package __init__.py to re-export with from .slurm_config import * if you want the same public surface at package level.tools/launcher/examples/moonshotai/Kimi-K2.5/specdec_bench.yaml (1)
60-61: ⚡ Quick winReplace the draft checkpoint placeholder with a real config input.
The checked-in example still requires hand-editing
--draft_model_dir, so it is easy to run the file exactly as documented and fail immediately. Promoting this toglobal_varsor defaulting to the standard export path would make the example self-contained.If you want, I can wire this through
global_varsso the example remains override-friendly without the inline TODO.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tools/launcher/examples/moonshotai/Kimi-K2.5/specdec_bench.yaml` around lines 60 - 61, The example hardcodes a placeholder for the DFLASH draft checkpoint (--draft_model_dir) which forces manual edits; update specdec_bench.yaml so the draft checkpoint is provided from a reusable config variable instead: add a global_vars entry (e.g., draft_model_dir) with a sensible default path (the standard HF export location) and replace the inline "- --draft_model_dir /hf-local/nvidia/Kimi-K2.5-DFlash" with a reference to that global var, or alternatively set the draft path to the project’s standard export path by default so the example runs out-of-the-box; ensure the variable name matches any existing templating scheme used in this file.tools/launcher/core.py (1)
16-30: ⚡ Quick winAdd
__all__for the launcher module's public surface.This module exports the launcher entry points and dataclasses, but it still leaves the public API implicit.
As per coding guidelines,
**/*.py: Define the public API with__all__at the top of each Python module and re-export submodules in__init__.pyfiles usingfrom .module import *to keep the public API explicit and make star-imports safe.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tools/launcher/core.py` around lines 16 - 30, Add an explicit __all__ list at the top of this module (immediately after the module docstring) that enumerates the public API symbols exported by tools.launcher.core — specifically the launcher entry points, dataclasses and any executor-builder and job-run-loop function/class names defined in this file; be careful not to accidentally export imported modules like the dataclasses module, nemo_run (imported as run), or yaml unless they are intentionally part of the public surface. Also update the package __init__.py to re-export the core module via from .core import * if you want core's public names to be available at package level.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/speculative_decoding/main.py`:
- Around line 280-286: The code unconditionally sets
training_args.ignore_data_skip = True when recipe.data.mode == "streaming";
change this to only enable ignore_data_skip when the run is effectively
single-epoch or when the user explicitly opts into restart-on-resume: check
training_args.num_train_epochs <= 1 before setting
training_args.ignore_data_skip, and otherwise either leave it False and emit a
warning or require an explicit flag (e.g., resume_without_fast_forward) to opt
in; also consider detecting a resume scenario (checkpoint present in output_dir)
to apply the gating only on resume. Ensure references: recipe.data.mode,
training_args.ignore_data_skip, training_args.num_train_epochs, and
output_dir/checkpoint are used to implement the guard and warning.
In `@modelopt/recipe/config.py`:
- Around line 179-185: Add an explicit __all__ at the top of the module that
lists the module's public recipe types and loader mapping so star-imports are
safe; include the exported class names such as ModelOptDFlashRecipe and
ModelOptEagleRecipe plus the recipe loader mapping name (e.g., RECIPE_LOADERS)
and any public helper functions used to load or validate recipes, then remove
reliance on implicit globals—place the __all__ declaration near the module
imports/header so it clearly documents the public API.
In `@modelopt/torch/speculative/config.py`:
- Around line 68-74: Add an explicit __all__ declaration at the top of this
public config module listing the exported names (e.g. "dflash_offline" and any
other public symbols such as ModeloptField, ModelOptDFlashRecipe or other config
types/constants defined in this file) so the module no longer relies on implicit
exports; place __all__ = ["dflash_offline", ...] near the top of the file and
ensure it includes every identifier meant to be part of the public API so
star-imports remain safe and explicit.
In `@modelopt/torch/speculative/plugins/hf_streaming_dataset.py`:
- Around line 202-214: The current broad except in the loop around
self._fetch(sample) swallows deterministic bugs; narrow it so only
transient/IO/transport errors are treated as fetch misses: catch specific
transport-related exceptions (e.g., network/IO exceptions used by your stack)
and handle them by logging via warn_rank_0 and incrementing
self._consecutive_fail as before, but immediately re-raise critical exceptions
such as RuntimeError and ValueError from _fetch to avoid masking contract
violations; ensure the RuntimeError/ValueError paths bypass the
resample/circuit-breaker logic that references self._consecutive_fail and
config.fail_after_consecutive_skips so real misconfigurations surface.
- Around line 48-65: Add an explicit module public API by defining __all__ near
the top (after imports) that lists the public symbols exported from this file —
include the dataset/class names defined in this module (e.g., the main dataset
class such as HFStreamingDataset or any other public classes/functions you
defined below) and constants like IGNORE_TOKEN_ID; place a single __all__ =
["HFStreamingDataset", "OtherPublicClass", "some_public_function",
"IGNORE_TOKEN_ID"] (adjust names to match the actual symbols in this file) so
star-imports and re-exports are stable.
In `@tools/launcher/common/eagle3/train_eagle_streaming.sh`:
- Around line 145-151: The SERVE_ADDR_FILE and DONE_FILE are global and must be
namespaced per run; modify the assignment of SERVE_ADDR_FILE and DONE_FILE so
they include a stable per-job identifier (e.g., JOB_ID, SLURM_JOB_ID, or
fallback to $$ or a timestamp) and use that same identifier everywhere the
script references SERVE_ADDR_FILE and DONE_FILE (so functions that publish/read
addresses and the head trainer check the same namespaced paths); ensure the
chosen identifier is exported or passed to serve/trainer subprocesses so
concurrent launcher runs write/read distinct files.
In `@tools/launcher/core.py`:
- Around line 289-291: build_slurm_executor currently assumes slurm_config has a
.segment attribute which can raise AttributeError for older/custom config types
patched via set_slurm_config_type; change the access to be defensive (use
getattr(slurm_config, "segment", None)) and only pass the segment option when it
is not None (or otherwise omit it) so older Pydantic-based configs and
ModeloptBaseConfig remain compatible; update any use sites in
build_slurm_executor that reference slurm_config.segment to follow this pattern.
In `@tools/launcher/examples/moonshotai/Kimi-K2.5/specdec_bench.yaml`:
- Around line 76-81: The slurm_config in the specdec_bench.yaml currently sets
container to "vllm/vllm-openai:latest", which may not support DFLASH on
GB200/aarch64; update the container value to a GB200-compatible image tag (pin
to a specific image known to include DFLASH/GB200 support) instead of the
floating "latest" tag so the documented HSG path works out of the box — change
the slurm_config.container entry to that pinned GB200-compatible image.
---
Outside diff comments:
In `@examples/speculative_decoding/launch_train.sh`:
- Around line 80-92: The script currently uses sh -c "accelerate launch ...
$MULTI_NODE_ARGS ${SCRIPT_DIR}/main.py --config $CONFIG_FILE ${EXTRA_ARGS[*]}",
which causes word-splitting and unintended shell interpolation (affecting
MULTI_NODE_ARGS, MACHINE_RANK, HEAD_NODE_IP and EXTRA_ARGS). Fix by constructing
an argv array for the command (e.g., base args: "accelerate" "launch"
"--mixed_precision" "bf16" plus expanded MULTI_NODE_ARGS tokens,
"${SCRIPT_DIR}/main.py", "--config" "$CONFIG_FILE" and each element of
EXTRA_ARGS) and then invoke it directly (exec or run the array) instead of using
sh -c so argument boundaries and literal values are preserved; ensure you expand
EXTRA_ARGS as separate array elements rather than via ${EXTRA_ARGS[*]} or
unquoted expansions.
In `@modelopt/torch/speculative/plugins/hf_training_args.py`:
- Around line 30-35: Add an explicit __all__ declaration near the top of
hf_training_args.py (immediately after the imports) that lists the module's
public API: include the names of the Pydantic models and any helper functions or
constants you intend to export (i.e., the public class names that subclass
BaseModel and any functions that should be visible to consumers), so
star-imports are safe and the module follows the repo convention.
In `@tools/launcher/common/eagle3/train_eagle_streaming.sh`:
- Around line 299-337: Add an explicit sanity check that fails fast when
SERVE_NODES is >= NNODES (SLURM_NNODES) so we don't turn every node into a serve
replica and deadlock waiting for a trainer; locate the topology dispatch logic
that uses NNODES, NODEID and SERVE_NODES and before branching (or at start of
that block) validate that if SERVE_NODES is set it is strictly less than NNODES,
otherwise print an error and exit (include variables SERVE_NODES and NNODES in
the message) so the job is rejected up front instead of hanging waiting for
DONE_FILE or a trainer rendezvous.
---
Nitpick comments:
In `@tests/unit/torch/speculative/plugins/test_hf_streaming_dataset.py`:
- Around line 128-143: Add a new unit test alongside
test_server_urls_normalization that verifies round-robin dispatch across
multiple server_urls: construct an EagleVllmStreamingConfig with server_urls set
to two distinct endpoints, patch or mock the network/send/dispatch call used by
the streaming plugin (replace the HTTP client or the function that actually
sends requests) to record the target URL for each invocation, then invoke the
dataset/request dispatching method twice (or more) and assert that recorded
destinations alternate between the two endpoints (e.g., calls[0] == endpoint_a,
calls[1] == endpoint_b, calls[2] == endpoint_a). Ensure the test uses
pytest/monkeypatch to avoid real network I/O and name the test to reflect
round_robin behavior so regressions that pin traffic to a single replica are
caught.
In `@tools/launcher/core.py`:
- Around line 16-30: Add an explicit __all__ list at the top of this module
(immediately after the module docstring) that enumerates the public API symbols
exported by tools.launcher.core — specifically the launcher entry points,
dataclasses and any executor-builder and job-run-loop function/class names
defined in this file; be careful not to accidentally export imported modules
like the dataclasses module, nemo_run (imported as run), or yaml unless they are
intentionally part of the public surface. Also update the package __init__.py to
re-export the core module via from .core import * if you want core's public
names to be available at package level.
In `@tools/launcher/examples/moonshotai/Kimi-K2.5/hf_dflash_dryrun.yaml`:
- Around line 50-56: Add an explicit data.mode=offline entry to the dry-run
config so the recipe uses offline mode deterministically; update the HF dry-run
YAML that currently lists --dry_run, --config ..., model.* and
data.offline_data_path to also include data.mode=offline (rather than relying on
the presence of data.offline_data_path to infer mode).
In `@tools/launcher/examples/moonshotai/Kimi-K2.5/specdec_bench.yaml`:
- Around line 60-61: The example hardcodes a placeholder for the DFLASH draft
checkpoint (--draft_model_dir) which forces manual edits; update
specdec_bench.yaml so the draft checkpoint is provided from a reusable config
variable instead: add a global_vars entry (e.g., draft_model_dir) with a
sensible default path (the standard HF export location) and replace the inline
"- --draft_model_dir /hf-local/nvidia/Kimi-K2.5-DFlash" with a reference to that
global var, or alternatively set the draft path to the project’s standard export
path by default so the example runs out-of-the-box; ensure the variable name
matches any existing templating scheme used in this file.
In `@tools/launcher/slurm_config.py`:
- Around line 16-26: This module should explicitly declare its public API: add
an __all__ list near the top that exports the launcher-facing symbols, e.g.
include "SlurmConfig" and "slurm_factory" (matching the actual class/function
names in this file) so star-imports are safe; also update any package
__init__.py to re-export with from .slurm_config import * if you want the same
public surface at package level.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 824b234f-9634-4bc3-9fea-29506a691a90
📒 Files selected for processing (20)
examples/specdec_bench/specdec_bench/utils.pyexamples/speculative_decoding/eagle_utils.pyexamples/speculative_decoding/launch_train.shexamples/speculative_decoding/main.pymodelopt/recipe/config.pymodelopt/torch/speculative/config.pymodelopt/torch/speculative/plugins/hf_streaming_dataset.pymodelopt/torch/speculative/plugins/hf_training_args.pytests/examples/speculative_decoding/test_eagle_streaming.pytests/unit/torch/speculative/plugins/test_hf_streaming_dataset.pytools/launcher/common/eagle3/train_eagle_streaming.shtools/launcher/core.pytools/launcher/examples/Qwen/Qwen3-8B/hf_streaming_eagle3_multi_node.yamltools/launcher/examples/moonshotai/Kimi-K2.5/hf_dflash_dryrun.yamltools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_dflash.yamltools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_dflash_multi_node.yamltools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_eagle3.yamltools/launcher/examples/moonshotai/Kimi-K2.5/hf_streaming_eagle3_multi_node.yamltools/launcher/examples/moonshotai/Kimi-K2.5/specdec_bench.yamltools/launcher/slurm_config.py
|
/claude review |
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/speculative/plugins/hf_training_args.py (1)
52-69:⚠️ Potential issue | 🟠 Major | ⚡ Quick winRestore
streaming_prefetchas a deprecated no-op for config compatibility.Line 52 uses
extra="forbid", andDataArgumentsno longer acceptsstreaming_prefetch. Any existing recipe YAML that still setsdata.streaming_prefetchwill now fail validation instead of being tolerated as a no-op.Proposed patch
-from pydantic import BaseModel, ConfigDict, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator @@ class DataArguments(BaseModel): @@ streaming_model_name: str | None = None # Mirror of the vLLM connector's ``shared_storage_path``; trainer-side allowlist. streaming_shared_storage_path: str | None = None + streaming_prefetch: int = Field( + default=64, + ge=1, + description=( + "Deprecated no-op kept for backward YAML compatibility. " + "Use dataloader_num_workers for fetch concurrency." + ), + )As per coding guidelines: "Preserve config and checkpoint backward compatibility; explicitly handle ModeloptBaseConfig and other Pydantic-based configs to ensure older checkpoints remain loadable when configs change."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/speculative/plugins/hf_training_args.py` around lines 52 - 69, Add back a declared `streaming_prefetch` field to the config so unknown-field validation (ConfigDict(extra="forbid")) does not reject legacy YAML: declare `streaming_prefetch: bool | None = None` alongside the other fields (e.g., near `mode`, `data_path`, etc.) and treat it as a deprecated no-op; additionally add a minimal deprecation warning when it is set (via the model's init/validator for the DataArguments / hf_training_args model) to inform users it is ignored while preserving backwards compatibility.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@modelopt/torch/speculative/plugins/hf_training_args.py`:
- Around line 52-69: Add back a declared `streaming_prefetch` field to the
config so unknown-field validation (ConfigDict(extra="forbid")) does not reject
legacy YAML: declare `streaming_prefetch: bool | None = None` alongside the
other fields (e.g., near `mode`, `data_path`, etc.) and treat it as a deprecated
no-op; additionally add a minimal deprecation warning when it is set (via the
model's init/validator for the DataArguments / hf_training_args model) to inform
users it is ignored while preserving backwards compatibility.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 7f6214fe-beb9-4efe-a615-05835a00ffcc
📒 Files selected for processing (5)
examples/speculative_decoding/eagle_utils.pyexamples/speculative_decoding/main.pymodelopt/torch/speculative/plugins/hf_training_args.pytests/unit/torch/speculative/plugins/test_hf_streaming_dataset.pytools/launcher/examples/Qwen/Qwen3-8B/hf_streaming_eagle3.yaml
💤 Files with no reviewable changes (1)
- tools/launcher/examples/Qwen/Qwen3-8B/hf_streaming_eagle3.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/unit/torch/speculative/plugins/test_hf_streaming_dataset.py
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
|
/claude review |
There was a problem hiding this comment.
Claude review summary
Findings: CRITICAL: 0 · IMPORTANT: 1 · SUGGESTION: 2
Most impactful finding
- [IMPORTANT Compatibility]
streaming_prefetchwas removed fromDataArguments(which hasextra="forbid"), but the PR description claims the field is "kept so existing yamls still validate". Existing recipe YAMLs withdata.streaming_prefetch=…will now hard-fail Pydantic validation rather than be silently tolerated. Re-declare it as a deprecated no-op field. (CodeRabbit raised the same issue and it remains unaddressed.)
Suggestions
- DataLoader workers all start their round-robin cursor at
_rr=0, so cold-start sends the first request from every worker × every rank toserver_urls[0]— exactly the flood pattern the PR's docstring warns about. Stagger initial cursor byworker_id(and/or rank). train_eagle_streaming.shfalls back to/scratchspace/eagle3when derivingout_dir, but it now serves DFlash runs too — the failure mode is silent if a user forgets to forwardtraining.output_dir=.
Overall assessment
Solid refactor — the map-style + DistributedSampler design is the right shape, the tests are good (especially the resume test that proves no re-fetch on skip_first_batches), and the writer-race retry + transient-error narrowing are nice clean-ups. The one IMPORTANT finding is a real but easy backward-compat fix; the rest are nice-to-have. Risk level: low-to-moderate (a single-line Pydantic addition closes the only blocking concern).
Algorithm correctness, mode/state composition, and HF/TRT-LLM export paths are unaffected by this PR (it stays inside the data-loading layer).
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
What does this PR do?
Type of change: New feature
Scales speculative-decoding streaming training to multiple nodes. The streaming
dataset is rewritten from a rank-0-fetches-and-broadcasts
IterableDatasetto amap-style
Dataset: each rank fetches only its ownDistributedSamplershard, soaggregate read bandwidth scales with trainer ranks. Fetch concurrency now comes from
dataloader_num_workersinstead of an in-process producer thread. A fetcher canround-robin across multiple vLLM endpoints (
server_urls), and multi-node acceleratelaunch is fixed (
--multi_gpu, explicit--machine_rank, optional Slurm--segmentto keep nodes in one NVLink domain).
Usage
Testing
tests/unit/torch/speculative/plugins/test_hf_streaming_dataset.pyfor the map-style dataset.single- and multi-node; all converged and exported. Throughput below.
Training throughput (trainer vs serve scaling, Qwen3-8B EAGLE3)
Rough throughput comparison across multi-node topologies for the streaming pipeline (H100). Per-device batch size = 1, so throughput (samples/s) ∝ steps/s × world_size.
Training throughput (Kimi-K2.5-NVFP4, GB200)
Per-device batch = 1, so throughput (samples/s) ∝ steps/s × world_size. Single-node = 1 serve (TP=4) + 1 trainer node (4 GPU, world 4); multi-node = 2 serves (TP=4) + 2 trainer nodes (4 GPU each, world 8). All four runs converged (loss decreasing) and exported.
EAGLE3 scales better than DFlash because DFlash captures 6 hidden-state layers (~350 MB/sample) vs EAGLE3's 4 (~235 MB), so DFlash is more data-movement-bound across nodes.
Before your PR is "Ready for review"
EagleVllmStreamingConfig.server_urlis nowserver_urls;streaming_prefetchis now a no-op (kept so existing yamls still validate).CONTRIBUTING.md: N/ASummary by CodeRabbit
New Features
Bug Fixes
Refactor
Behavior