Conversation
ca7cba3 to 07ddfbe (Compare)
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
Important: Review skipped. Auto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI.
📝 Walkthrough

Updates the speculative decoding training pipeline to support loading offline pre-computed data.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant LaunchScript as launch_train.sh
    participant TrainingScript as main.py
    participant DataModule as make_eagle_supervised_data_module()
    participant OfflineDataset as OfflineSupervisedDataset
    participant DataCollator as EagleOfflineDataCollator
    participant Trainer
    User->>LaunchScript: Provide --draft_vocab_cache path
    LaunchScript->>TrainingScript: Pass draft_vocab_cache arg
    TrainingScript->>DataModule: Call with offline_data_path & train_len
    DataModule->>OfflineDataset: Create with dumped_files from cache
    OfflineDataset->>OfflineDataset: Load .pt files on initialization
    DataModule->>DataCollator: Create with train_len parameter
    TrainingScript->>Trainer: Initialize with dataset & collator
    Trainer->>OfflineDataset: Request batch samples
    OfflineDataset-->>Trainer: Return tensor dict items
    Trainer->>DataCollator: Call with feature list
    DataCollator->>DataCollator: Pad/truncate tensors to train_len
    DataCollator-->>Trainer: Return batched dict
    Trainer->>Trainer: Train on batched data
```
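For a concrete picture of the offline path sketched above, here is a minimal, hedged Python sketch. The class names echo the walkthrough (`OfflineSupervisedDataset`, `EagleOfflineDataCollator`, `train_len`), but the suffixed names, tensor layout, and padding rule below are assumptions for illustration, not the repository's actual implementation:

```python
from pathlib import Path

import torch
from torch.utils.data import Dataset


class OfflineSupervisedDatasetSketch(Dataset):
    """Loads pre-dumped .pt files, one tensor dict per sample (assumed layout)."""

    def __init__(self, dump_dir: str):
        # Sort for a reproducible ordering across file systems.
        self.files = sorted(Path(dump_dir).glob("*.pt"))

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, i: int) -> dict:
        # weights_only=True: the dumps are expected to contain plain tensors only.
        return torch.load(self.files[i], weights_only=True)


class EagleOfflineCollatorSketch:
    """Pads or truncates each tensor in the batch to a fixed train_len (assumed rule)."""

    def __init__(self, train_len: int):
        self.train_len = train_len

    def __call__(self, features: list) -> dict:
        batch = {}
        for key in features[0]:
            padded = []
            for feat in features:
                t = feat[key][: self.train_len]  # truncate along the sequence dim
                missing = self.train_len - t.shape[0]
                if missing > 0:
                    # Zero-pad up to train_len; the real collator may use other fill values.
                    t = torch.cat([t, t.new_zeros((missing, *t.shape[1:]))], dim=0)
                padded.append(t)
            batch[key] = torch.stack(padded)
        return batch
```

The real collator may also handle extra keys such as aux_hidden_states and use ignore indices rather than zeros when padding labels.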
Estimated code review effort: 🎯 4 (Complex), ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 2, ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main     #668      +/-   ##
==========================================
- Coverage   73.73%   73.72%   -0.01%
==========================================
  Files         199      199
  Lines       21165    21176      +11
==========================================
+ Hits        15606    15612       +6
- Misses       5559     5564       +5
```

☔ View full report in Codecov by Sentry.
d61f05e to 6197b58 (Compare)
Actionable comments posted: 12
🤖 Fix all issues with AI agents
In `@examples/speculative_decoding/eagle_utils.py`:
- Around line 135-150: The collators are created with return_labels=True but due
to upstream bugs labels never reach the collated batch: LanguageDataCollator
computes labels in _process_chat_sample but does not add them to the returned
dict, and VisionLanguageDataCollator.__init__ does not forward return_labels to
its parent; update the collator implementations so that
LanguageDataCollator._process_chat_sample attaches the computed labels to the
output dict (e.g., output["labels"] = labels) and modify
VisionLanguageDataCollator.__init__ to accept return_labels and pass it to
super().__init__(..., return_labels=return_labels) so that constructing these
classes in eagle_utils.py actually yields batches with labels.
- Around line 53-86: In OfflineSupervisedDataset.__getitem__, the torch.load
call relies on the default behavior changed in PyTorch 2.6; update the
torch.load invocation in __getitem__ (where offline_data is loaded) to
explicitly pass weights_only=True (e.g., torch.load(self.dumped_files[i],
weights_only=True)) so the .pt tensor dictionary is loaded safely and the intent
is clear.
In `@examples/speculative_decoding/main.py`:
- Around line 200-211: The config passed to mtsp.convert when training_args.mode
== "eagle" omits the user-provided draft vocab cache, so add the DataArguments
value into the config before calling mtsp.convert: include a key that maps
EagleConfig.draft_vocab_cache (e.g., "eagle_draft_vocab_cache") to
data_args.draft_vocab_cache alongside the existing "eagle_decoder_type",
"eagle_offline", and "eagle_architecture_config" entries so mtsp.convert(model,
[("eagle", config)]) receives the draft vocab cache.
In `@examples/speculative_decoding/README.md`:
- Line 247: Fix the typo in the README sentence: change "To user 2-layer eagle
with 8192 intermediate size for MLP, set `eagle_config.json` to:" to "To use
2-layer eagle with 8192 intermediate size for MLP, set `eagle_config.json` to:"
in examples/speculative_decoding/README.md so the sentence reads correctly;
search for the sentence fragment "To user 2-layer eagle" to locate the exact
place to edit.
In `@modelopt/torch/speculative/eagle/conversion.py`:
- Around line 42-48: The current dict lookup using config.eagle_decoder_type can
raise a raw KeyError; update conversion logic around the lookup that builds
default_arch_config to validate config.eagle_decoder_type first (accepting
"llama" or "kimik2"), and if it is unsupported raise a clear ValueError with a
descriptive message including the invalid value and allowed options; reference
the existing symbols eagle3_default_config, kimik2_eagle_default_config and
config.eagle_architecture_config so the check sits before merging
({**default_arch_config, **custom_config}) and uses those defaults when valid (a
standalone sketch of this validation pattern follows this list).
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 847-853: The offline/pre-computed branch that checks for
"base_model_outputs" in kwargs never assigns base_input_embeds, causing a
NameError when later referencing inputs_embeds = base_input_embeds.roll(-1, 1);
fix this by computing base_input_embeds from
self._base_model_embeddings(input_ids) inside that branch as a fallback (ensure
you respect input_ids device/dtype and any attention_mask or position handling
consistent with the online path), so base_input_embeds is always defined before
the roll operation.
In `@modelopt/torch/utils/plugins/transformers_dataset.py`:
- Around line 107-112: When num_streaming_samples is set, the current logic
appends every accessed streamed item into self._raw_samples and builds
self._stream_iterator = itertools.cycle(self._stream_samples), which makes
memory grow unboundedly and causes duplicate returns on subsequent cycles;
change the behavior so streaming mode does not retain all seen items: when
self.num_streaming_samples is not None, do not append streamed items into
self._raw_samples (remove or stop the append in __getitem__/iteration), keep
self._stream_samples as the original shard or an iterator, and build
self._stream_iterator from that non-caching iterator (e.g., use
itertools.islice/shard iterator or a bounded collections.deque(maxlen=...) if a
small cache is needed); ensure the code paths that reference _raw_samples and
_stream_iterator (look for __getitem__, _raw_samples, _stream_samples,
_stream_iterator) are updated so cycles do not iterate over a growing list and
streaming respects memory limits.
- Around line 185-189: The branch guarded by return_labels computes a labels
tensor but never attaches it to tokenized_examples, so callers never receive
labels; modify the block in transformer's dataset preprocessing (the code around
return_labels, tokenized_examples, labels, and IGNORE_TOKEN_ID) to assign the
constructed labels into tokenized_examples (e.g., tokenized_examples["labels"] =
labels) before returning so the caller gets the labels when return_labels is
True.
- Around line 229-250: The __init__ accepts return_labels but doesn't forward it
to the parent, so update VisionLanguageDataCollator/constructor to both store it
(self.return_labels = return_labels) and pass it into the super call (add
return_labels=return_labels in the super().__init__(...) argument list) so the
parent sees the intended value.
- Around line 167-172: _post_process_chat_template currently calls
self.tokenizer.chat_template.replace(...) which raises AttributeError if
chat_template is None; guard the operation by checking
self.tokenizer.chat_template is not None/falsey before calling replace (or
default it to an empty string) and only perform the replace when chat_template
is a str. Update the method so it safely handles a missing chat_template
(reference: _post_process_chat_template, self.tokenizer.chat_template,
REMOVE_THINK_CHAT_TEMPLATE) and preserves existing behavior when chat_template
is present.
- Around line 201-223: The __call__ method currently builds batch items as
either plain strings or message lists but always calls _process_chat_sample;
change __call__ (in the LanguageDataCollator class) to detect when all items in
batch are plain text (strings) and route those to _process_text_sample instead
of _process_chat_sample; keep existing logic that converts ShareGPT
conversations to OpenAI messages via _sharegpt_to_openai_messages and only call
_process_chat_sample when batch contains any message lists (to avoid passing raw
strings into tokenizer.apply_chat_template).
- Around line 86-93: The __getitem__ method currently divides the incoming index
by self.num_shards (index = index // self.num_shards), which double-shards the
already shard-local dataset; remove that division so __getitem__ uses the raw
shard-local index directly, i.e., make __getitem__ use the incoming index to
access self._raw_samples (and keep the streaming fill loop that advances
self._stream_iterator when index >= len(self._raw_samples)) so items aren’t
duplicated or skipped.
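As a companion to the `conversion.py` item above (which asks for explicit validation of `config.eagle_decoder_type`), here is a minimal sketch of that validation pattern. The default dicts and function name are placeholders standing in for `eagle3_default_config` and `kimik2_eagle_default_config`, not the repository's code:

```python
# Placeholder defaults standing in for eagle3_default_config / kimik2_eagle_default_config.
EAGLE_DEFAULTS = {
    "llama": {"decoder": "llama"},
    "kimik2": {"decoder": "kimik2"},
}


def resolve_arch_config(decoder_type: str, custom_config: dict | None) -> dict:
    """Validate the decoder type before merging defaults with user overrides."""
    if decoder_type not in EAGLE_DEFAULTS:
        raise ValueError(
            f"Unsupported eagle_decoder_type '{decoder_type}'. "
            f"Expected one of {sorted(EAGLE_DEFAULTS)}."
        )
    # Same merge shape as the prompt above: user overrides win over defaults.
    return {**EAGLE_DEFAULTS[decoder_type], **(custom_config or {})}
```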
🧹 Nitpick comments (7)
examples/speculative_decoding/main.py (2)
201-203: Unclosed file handle.
`open(eagle_args.eagle_config)` is never explicitly closed. Use a `with` statement or `pathlib` to avoid resource leaks.

Proposed fix
```diff
-    custom_config = (
-        json.load(open(eagle_args.eagle_config)) if eagle_args.eagle_config else {}
-    )
+    if eagle_args.eagle_config:
+        with open(eagle_args.eagle_config) as f:
+            custom_config = json.load(f)
+    else:
+        custom_config = {}
```
163-169: Model type detection heuristic is fragile.

The substring check `"vl" in model_config.model_type.lower()` could match unintended model types (e.g., a hypothetical model type containing "vl" as part of another word). Consider matching against known VLM model types explicitly.

modelopt/torch/utils/plugins/transformers_dataset.py (2)
36-53: Unhandled role values will raise an unhelpful `KeyError`.

If a conversation entry has a `role` not present in `role_mapping` (e.g., `"tool"`, `"function"`, or a typo), line 50 raises a raw `KeyError` with no context. Consider using `.get()` with a fallback or raising a descriptive `ValueError`.

Suggested improvement

```diff
     for msg in conversations:
-        role = role_mapping[msg["role"]]
+        raw_role = msg["role"]
+        role = role_mapping.get(raw_role)
+        if role is None:
+            raise ValueError(
+                f"Unknown role '{raw_role}' in conversation. "
+                f"Supported roles: {list(role_mapping.keys())}"
+            )
         content = msg["content"]
```
267-301: `VisionLanguageDataCollator.__call__` does not handle plain-text samples, duplicating validation logic.

The `__call__` method duplicates the message-format validation logic from the parent class (lines 272-280 are nearly identical to lines 212-220 in `LanguageDataCollator.__call__`). Consider extracting a shared `_normalize_messages` helper to DRY up the conversion logic.

examples/speculative_decoding/eagle_utils.py (2)
89-127: Typo in comment and minor observation on collator.

Line 104: "consturct" → "construct".

Fix

```diff
-    # consturct copy slice
+    # construct copy slice
```
154-162: Use a conditional + `raise` instead of `assert` for runtime validation; sort dumped files for reproducibility.

`assert` can be stripped with `python -O`. Also, `glob("*.pt")` returns files in filesystem-dependent order, which may differ across runs/machines.

Proposed fix

```diff
-    assert not data_args.vlm_processor, "Offline data is not supported for VLM."
+    if data_args.vlm_processor:
+        raise ValueError("Offline data is not supported for VLM.")
     offline_data_path = Path(data_args.offline_data_path)
-    dumped_files = [str(p) for p in offline_data_path.glob("*.pt")]
+    dumped_files = sorted(str(p) for p in offline_data_path.glob("*.pt"))
```

modelopt/torch/speculative/plugins/transformers.py (1)
244-262: `torch.load` without `weights_only`: security and deprecation concern.

Line 250: `torch.load(draft_vocab_cache)` should specify `weights_only=True` since the expected content is a single tensor. This avoids arbitrary code execution from untrusted files and suppresses the deprecation warning in newer PyTorch versions.

Proposed fix

```diff
-    d2t = torch.load(draft_vocab_cache)
+    d2t = torch.load(draft_vocab_cache, weights_only=True)
```
```python
if self.num_streaming_samples is not None:
    self._raw_samples = []
    self._stream_samples = shard
    self._stream_iterator = itertools.cycle(self._stream_samples)
else:
    self._raw_samples = shard
```
Streaming dataset grows unboundedly in memory, defeating the purpose of streaming.
When num_streaming_samples is not None, every accessed item is appended to self._raw_samples (line 91) and kept forever. Over the course of an epoch this accumulates all streamed data in memory, which undermines the memory benefit of using streaming mode. Also, itertools.cycle will restart the iterable from the beginning after exhaustion, but since _raw_samples retains all seen items, the second pass through the cycle will trigger next() calls that return already-cached items, creating duplicates in the list.
Consider using an approach that doesn't retain all streamed items, or document this as intentional behavior for datasets that fit in memory.
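One possible shape for such an approach, as a hedged sketch only (the class and method names here are invented for illustration and are not the module's API): keep at most a small bounded cache instead of appending every streamed item to a growing list.

```python
from collections import deque


class BoundedStreamSketch:
    """Illustrative non-caching streaming access: restart the shard iterator on exhaustion."""

    def __init__(self, make_shard_iter, cache_size: int = 8):
        self._make_shard_iter = make_shard_iter  # callable that returns a fresh iterator
        self._iter = make_shard_iter()
        self._recent = deque(maxlen=cache_size)  # bounded, so memory stays flat

    def next_sample(self):
        try:
            sample = next(self._iter)
        except StopIteration:
            # Re-create the iterator instead of cycling over a cached, growing list.
            self._iter = self._make_shard_iter()
            sample = next(self._iter)
        self._recent.append(sample)
        return sample


# Usage with a toy in-memory "shard"; a real shard would be a streaming dataset split.
stream = BoundedStreamSketch(lambda: iter(range(5)), cache_size=2)
samples = [stream.next_sample() for _ in range(7)]  # wraps around after 5 items
```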
🤖 Prompt for AI Agents
In `@modelopt/torch/utils/plugins/transformers_dataset.py` around lines 107 - 112,
When num_streaming_samples is set, the current logic appends every accessed
streamed item into self._raw_samples and builds self._stream_iterator =
itertools.cycle(self._stream_samples), which makes memory grow unboundedly and
causes duplicate returns on subsequent cycles; change the behavior so streaming
mode does not retain all seen items: when self.num_streaming_samples is not
None, do not append streamed items into self._raw_samples (remove or stop the
append in __getitem__/iteration), keep self._stream_samples as the original
shard or an iterator, and build self._stream_iterator from that non-caching
iterator (e.g., use itertools.islice/shard iterator or a bounded
collections.deque(maxlen=...) if a small cache is needed); ensure the code paths
that reference _raw_samples and _stream_iterator (look for __getitem__,
_raw_samples, _stream_samples, _stream_iterator) are updated so cycles do not
iterate over a growing list and streaming respects memory limits.
```python
def _post_process_chat_template(self):
    # [WAR]: For DeepSeek-V3/R1 tokenizer, we modify the chat_template such that the <think>
    # tokens are preserved for supervised learning.
    self.tokenizer.chat_template = self.tokenizer.chat_template.replace(
        REMOVE_THINK_CHAT_TEMPLATE, ""
    )
```
_post_process_chat_template crashes with AttributeError when the tokenizer has no chat template.
If the user doesn't supply a chat_template and self.tokenizer.chat_template is None, calling .replace() on None at line 170 raises AttributeError: 'NoneType' object has no attribute 'replace'. This occurs before the friendlier check on line 153–154.
Proposed fix
```diff
 def _post_process_chat_template(self):
+    if self.tokenizer.chat_template is None:
+        return
     # [WAR]: For DeepSeek-V3/R1 tokenizer, we modify the chat_template such that the <think>
     # tokens are preserved for supervised learning.
     self.tokenizer.chat_template = self.tokenizer.chat_template.replace(
         REMOVE_THINK_CHAT_TEMPLATE, ""
     )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def _post_process_chat_template(self):
    if self.tokenizer.chat_template is None:
        return
    # [WAR]: For DeepSeek-V3/R1 tokenizer, we modify the chat_template such that the <think>
    # tokens are preserved for supervised learning.
    self.tokenizer.chat_template = self.tokenizer.chat_template.replace(
        REMOVE_THINK_CHAT_TEMPLATE, ""
    )
```
🤖 Prompt for AI Agents
In `@modelopt/torch/utils/plugins/transformers_dataset.py` around lines 167 - 172,
_post_process_chat_template currently calls
self.tokenizer.chat_template.replace(...) which raises AttributeError if
chat_template is None; guard the operation by checking
self.tokenizer.chat_template is not None/falsey before calling replace (or
default it to an empty string) and only perform the replace when chat_template
is a str. Update the method so it safely handles a missing chat_template
(reference: _post_process_chat_template, self.tokenizer.chat_template,
REMOVE_THINK_CHAT_TEMPLATE) and preserves existing behavior when chat_template
is present.
b64548b to 7612abe (Compare)
```python
    )
    return new_examples


class OfflineSupervisedDataset(Dataset):
```
why not move this to modelopt.torch.utils.plugins.transformers_dataset as well?
I think these two classes are specific to eagle3 (e.g., they have "aux_hidden_states") and are not useful to other modules. The contents of torch.utils.plugins.transformers_dataset.py are algorithm-agnostic and are intended to be reused in the future.
eeed7bf to 3698f17 (Compare)
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
cb67d63 to 027ee36 (Compare)
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
What does this PR do?
Type of change: Refactor
Overview:
Jira ticket: https://jirasw.nvidia.com/browse/OMNIML-2955
Main changes:

- Consolidate Eagle data loading with @ChenhanYu's implementation of `transformers_dataset.py`.
- Refactor: moved the following logic from `example/main.py` into `modelopt/torch` for a cleaner example entry point:
- Implementation refactor: in the HF workflow, reuse the base model's input hidden states as the input embedding instead of calculating it from `input_ids` (a short hedged sketch of this idea follows this list). This has two main benefits:
- Deprecate eagle1 in the example. It is still available by setting a custom config.
- Other minor fixes and README updates.
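A hedged illustration of the embedding-reuse idea above, consistent with the `base_input_embeds.roll(-1, 1)` expression quoted in the review comments; the function name and tensor layout here are assumptions, not the repository's code:

```python
import torch


def draft_inputs_from_base(base_input_embeds: torch.Tensor) -> torch.Tensor:
    """Reuse the base model's (batch, seq, hidden) embeddings for the draft model.

    The sequence dimension is shifted left by one position, mirroring the
    roll(-1, 1) pattern cited in the review discussion, instead of re-embedding
    input_ids with a separate embedding lookup.
    """
    return base_input_embeds.roll(shifts=-1, dims=1)


# Dummy shapes only, to show the call.
embeds = torch.randn(2, 8, 16)
assert draft_inputs_from_base(embeds).shape == embeds.shape
```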
Usage
```python
# Add a code snippet demonstrating how to use this
```

Testing
Tested that training curves after the changes (both online and offline) are identical to the original branch:

Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
Release Notes
New Features
- `--draft_vocab_cache` parameter
- `--log_steps` configuration to training launcher

Documentation
Refactor
- `--input-data` replaces `--input-file`