Conversation
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
📝 Walkthrough

The code refactors the MTP weight loading logic in a utility file. Changes include filtering index entries for MTP-related weights, introducing a helper function to extract layer prefixes, and simplifying the loading process to only load missing keys from designated MTP files, with cleaner status reporting.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~22 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
🧹 Nitpick comments (2)
examples/llm_ptq/example_utils.py (2)
345-345: Unclosed file handle — use a `with` statement.

`open(index_file)` is never explicitly closed. While CPython's refcount GC will collect it, this leaks under alternative runtimes and emits `ResourceWarning`.

♻️ Proposed fix

```diff
- index = json.load(open(index_file))
+ with open(index_file) as f:
+     index = json.load(f)
```
356-366: Inner variable `mtp_layer_prefixes` shadows the outer variable of the same name.

The local `set` on line 357 shadows the outer `mtp_layer_prefixes` (line 339 / assigned on line 370). Not a bug, but it makes the code harder to follow. Consider renaming the inner variable (e.g., `prefixes`).

♻️ Proposed fix

```diff
 def _extract_layer_prefixes(keys):
-    mtp_layer_prefixes = set()
+    prefixes = set()
     for key in keys:
         parts = key.split(".")
         for i, part in enumerate(parts):
             if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
                 prefix = ".".join(parts[: i + 2])
-                mtp_layer_prefixes.add(prefix)
+                prefixes.add(prefix)
                 break
-    return list(mtp_layer_prefixes)
+    return list(prefixes)
```
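For reference, a quick sanity check of the renamed helper (a sketch, not from the review; the sample weight keys below are made up for illustration):

```python
# Reproduces the proposed helper above so the snippet runs standalone.
def _extract_layer_prefixes(keys):
    prefixes = set()
    for key in keys:
        parts = key.split(".")
        for i, part in enumerate(parts):
            if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
                prefixes.add(".".join(parts[: i + 2]))
                break
    return list(prefixes)

# Hypothetical MTP-style keys; only the "layers.<N>" prefix matters here.
sample_keys = [
    "model.layers.92.mtp.input_layernorm.weight",
    "model.layers.92.mtp.mlp.gate_proj.weight",
    "model.layers.93.mtp.mlp.up_proj.weight",
]
print(sorted(_extract_layer_prefixes(sample_keys)))
# ['model.layers.92', 'model.layers.93']
```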
```diff
 for filename in mtp_weight_map:
     filepath = model_path / filename
     if not filepath.exists():
         continue

-    # Find keys that should be in this file
-    expected_keys = [k for k, v in index["weight_map"].items() if v == filename]
-
     # Check which are missing from the model
-    missing_keys = [k for k in expected_keys if k not in model_state]
-
-    if not missing_keys:
-        # Even if weights are loaded, record the layer prefixes for exclusion
-        # Extract unique layer prefixes (e.g., "model.layers.92" from "model.layers.92.mlp.weight")
-        for key in expected_keys:
-            # Extract layer prefix like "model.layers.92" or "layers.92"
-            parts = key.split(".")
-            for i, part in enumerate(parts):
-                if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
-                    prefix = ".".join(parts[: i + 2])  # e.g., "model.layers.92"
-                    if prefix not in mtp_layer_prefixes:
-                        mtp_layer_prefixes.append(prefix)
-                    break
-        continue
+    missing_keys = [k for k in mtp_weight_map[filename] if k not in model_state]

     print(f"Loading {len(missing_keys)} missing weights from {filename}...")

-    # Extract unique layer prefixes for exclusion from quantization
-    for key in missing_keys:
-        parts = key.split(".")
-        for i, part in enumerate(parts):
-            if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
-                prefix = ".".join(parts[: i + 2])  # e.g., "model.layers.92"
-                if prefix not in mtp_layer_prefixes:
-                    mtp_layer_prefixes.append(prefix)
-                break
-
     # Load the weights to CPU first, load_state_dict will handle device placement
     weights = load_file(str(filepath), device="cpu")
     weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}

     # Load into model
-    missing, unexpected = model.load_state_dict(weights_to_load, strict=False)
+    model.load_state_dict(weights_to_load, strict=False)
     total_loaded += len(weights_to_load)
```
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
# Get full context of the function
echo "=== Function definition and context ==="
cat -n examples/llm_ptq/example_utils.py | sed -n '350,410p'
echo ""
echo "=== Check git history for this file ==="
git log --oneline -20 -- examples/llm_ptq/example_utils.py
```

Repository: NVIDIA/Model-Optimizer

Length of output: 2765
🏁 Script executed:

```bash
#!/bin/bash
# Search for callers and tests
echo "=== Callers of load_mtp_weights_if_needed ==="
rg -n 'load_mtp_weights_if_needed' --type=py -B 2 -A 5
echo ""
echo "=== Test files for example_utils ==="
fd -n '*test*.py' | grep -i example
```

Repository: NVIDIA/Model-Optimizer

Length of output: 2887
🏁 Script executed:

```bash
#!/bin/bash
# Check for MTP-related documentation or comments
echo "=== MTP-related comments in the file ==="
rg -n 'mtp|MTP' examples/llm_ptq/example_utils.py -i -B 1 -A 1
echo ""
echo "=== Check for from_pretrained calls and how they handle weights ==="
rg -n 'from_pretrained' examples/llm_ptq/example_utils.py -B 2 -A 2
```

Repository: NVIDIA/Model-Optimizer

Length of output: 4662
🏁 Script executed:

```bash
#!/bin/bash
# Get the recent changes to understand the PR context
echo "=== Full commit info for latest MTP change ==="
git show --stat aa380d9
echo ""
echo "=== Previous version of load_mtp_weights_if_needed ==="
git show aa380d9^:examples/llm_ptq/example_utils.py | sed -n '319,400p'
```

Repository: NVIDIA/Model-Optimizer

Length of output: 50434
🏁 Script executed:

```bash
#!/bin/bash
# Check if there are model definitions that use MTP
echo "=== Search for MTP module definitions ==="
rg -n 'class.*MTP|mtp.*Module|MTPLayer' --type=py -B 2 -A 5
echo ""
echo "=== Check what from_pretrained actually loads ==="
rg -n 'from_pretrained' examples/llm_ptq/example_utils.py -A 15 | head -60
```

Repository: NVIDIA/Model-Optimizer

Length of output: 4684
🏁 Script executed:

```bash
#!/bin/bash
# Check if there are any tests or usage patterns for Qwen3 Next vs GLM-4.7
echo "=== Search for Qwen3 Next and GLM-4.7 in examples ==="
rg -n 'qwen3|glm.?4' --type=py --type=sh -i examples/llm_ptq/ | head -40
echo ""
echo "=== Check comments around the weight loading logic ==="
sed -n '315,400p' examples/llm_ptq/example_utils.py
```

Repository: NVIDIA/Model-Optimizer

Length of output: 3873
🏁 Script executed:

```bash
#!/bin/bash
# Verify state_dict behavior with load_state_dict strict=False
echo "=== Check how load_state_dict strict=False actually behaves ==="
python3 << 'EOF'
import torch
# Demonstrate load_state_dict behavior
model = torch.nn.Linear(10, 5)
print("Original state_dict keys:", model.state_dict().keys())
# Try to load a key that doesn't exist in the model
extra_weights = {'nonexistent_param': torch.randn(5)}
model.load_state_dict(extra_weights, strict=False)
print("After load_state_dict(strict=False) with non-existent key:")
print("  Model state_dict keys:", model.state_dict().keys())
print("  Did 'nonexistent_param' get loaded? No - extra keys are silently ignored")
# Try to load a key that exists
existing_weights = {'weight': torch.randn(5, 10)}
model.load_state_dict(existing_weights, strict=False)
print("\nAfter load_state_dict(strict=False) with existing key:")
print("  Model weight shape:", model.weight.shape)
print("  Weight was loaded successfully")
EOF
```

Repository: NVIDIA/Model-Optimizer

Length of output: 264
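For reference (not part of the original review), a minimal standalone sketch of the `strict=False` behavior the script above probes, assuming PyTorch is installed locally:

```python
import torch

# Keys absent from the model are not loaded; they are only reported back.
model = torch.nn.Linear(10, 5)
result = model.load_state_dict({"nonexistent_param": torch.randn(5)}, strict=False)
print(result.unexpected_keys)  # ['nonexistent_param'] is reported, nothing was loaded
print(result.missing_keys)     # ['weight', 'bias'] were not covered by the provided dict

# Keys that do exist in the model are loaded as expected.
new_weight = torch.randn(5, 10)
model.load_state_dict({"weight": new_weight}, strict=False)
print(torch.equal(model.weight.data, new_weight))  # True
```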
🏁 Script executed:

```bash
#!/bin/bash
# Search for documentation or examples of how PyTorch load_state_dict works
echo "=== Understanding the weight loading logic ==="
cat << 'EOF'
PyTorch load_state_dict behavior:
- load_state_dict(state_dict, strict=False) loads values for keys that exist in both:
  1. The state_dict being loaded FROM (weights argument)
  2. The model's current state_dict
- Keys in weights but NOT in model.state_dict() → silently ignored (extra keys)
- Keys in model.state_dict() but NOT in weights → unchanged (remaining keys)

Current code flow:
1. missing_keys = [k for k in mtp_weight_map[filename] if k NOT in model_state]
2. weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}
   → weights_to_load contains only keys that are NOT in model_state
3. model.load_state_dict(weights_to_load, strict=False)
   → This call will IGNORE all keys in weights_to_load because they're NOT in model.state_dict()

Result: Nothing gets loaded even when missing_keys is non-empty!
EOF
echo ""
echo "=== Check if there are actual tests that verify this loads correctly ==="
fd --type f -n '*.py' examples/llm_ptq tests/examples | xargs rg -l 'load_mtp_weights'
```

Repository: NVIDIA/Model-Optimizer

Length of output: 1329
🏁 Script executed:

```bash
#!/bin/bash
# Check if there's any logic that actually depends on weights being loaded
echo "=== How are the loaded prefixes used ==="
rg -n '_mtp_layer_prefixes' examples/llm_ptq/ -A 3 -B 1
```

Repository: NVIDIA/Model-Optimizer

Length of output: 1083
Correct: missing_keys logic contradicts load_state_dict() behavior, preventing actual weight loading.
Line 382 filters for keys not in model.state_dict(), then line 390 calls load_state_dict(strict=False) which can only load keys that exist in the model. This means weights_to_load will always be rejected as "unexpected keys" and silently ignored.
The primary function (extracting MTP layer prefixes for quantization exclusion) works correctly, but the weight-loading portion is broken. For models like GLM-4.7 that rely on loading weights from separate files, nothing is actually loaded despite the print statement claiming otherwise.
Invert the condition to if k in model_state and add a continue guard when missing_keys is empty to avoid unnecessary file I/O:
Proposed fix

```diff
 for filename in mtp_weight_map:
     filepath = model_path / filename
     if not filepath.exists():
         continue

     # Check which are missing from the model
-    missing_keys = [k for k in mtp_weight_map[filename] if k not in model_state]
+    missing_keys = [k for k in mtp_weight_map[filename] if k in model_state]
+
+    if not missing_keys:
+        continue

     print(f"Loading {len(missing_keys)} missing weights from {filename}...")

     # Load the weights to CPU first, load_state_dict will handle device placement
     weights = load_file(str(filepath), device="cpu")
     weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}
```

Note: This loads all keys in model_state, which may include already-correctly-loaded weights from Qwen3 Next. Consider adding an additional check (e.g., tensor zero/meta device detection) to avoid redundant loads for models where from_pretrained() already handled MTP weights.
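The additional check mentioned in the note could look roughly like this (a sketch under that assumption, separate from the committable suggestion below; it is not part of the PR):

```python
import torch

def needs_load(tensor: torch.Tensor) -> bool:
    # Treat meta tensors (never materialized) and all-zero tensors as placeholders
    # that still need their real weights loaded from the MTP file.
    if tensor.device.type == "meta":
        return True
    return tensor.numel() > 0 and torch.count_nonzero(tensor).item() == 0

print(needs_load(torch.empty(4, device="meta")))  # True  (not materialized yet)
print(needs_load(torch.zeros(4)))                 # True  (placeholder zeros)
print(needs_load(torch.randn(4)))                 # False (already has real values)
```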
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
for filename in mtp_weight_map:
    filepath = model_path / filename
    if not filepath.exists():
        continue

    # Check which are missing from the model
    missing_keys = [k for k in mtp_weight_map[filename] if k in model_state]

    if not missing_keys:
        continue

    print(f"Loading {len(missing_keys)} missing weights from {filename}...")

    # Load the weights to CPU first, load_state_dict will handle device placement
    weights = load_file(str(filepath), device="cpu")
    weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}

    # Load into model
    model.load_state_dict(weights_to_load, strict=False)
    total_loaded += len(weights_to_load)
```
🤖 Prompt for AI Agents
In `@examples/llm_ptq/example_utils.py` around lines 376 - 391, The logic that
builds missing_keys is inverted: mtp_weight_map[filename] is filtered using "if
k not in model_state" which yields keys that don't exist in the model, but
model.load_state_dict(..., strict=False) only accepts keys that DO exist; change
the filter to collect keys present in model_state (i.e., use "if k in
model_state") so weights_to_load contains loadable keys, and add a guard after
computing missing_keys to continue if it's empty to avoid unnecessary
load_file() I/O; update references around mtp_weight_map, model_state,
missing_keys, weights, weights_to_load and the call to model.load_state_dict to
implement this fix.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

```
@@           Coverage Diff           @@
##             main     #860   +/-   ##
=======================================
  Coverage   73.72%   73.72%
=======================================
  Files         196      196
  Lines       20457    20457
=======================================
  Hits        15082    15082
  Misses       5375     5375
```

☔ View full report in Codecov by Sentry.
Edwardf0t1 left a comment
@cjluo-nv Could you add the PTQ command in the description? Also, I'm wondering why we didn't have the MTP issue for Qwen3 Next PTQ previously?
```python
# Extract layer prefix like "model.layers.92" or "layers.92"
parts = key.split(".")
for i, part in enumerate(parts):
    if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
```
Did you also test GLM-4.7 with the changes?
What does this PR do?
Fix MTP export for Qwen3 Next
Overview: For Qwen3 Next, the MTP weights are not stored in a separate safetensors file, so we use the "mtp" substring in the weight keys to decide whether a weight belongs to the MTP module or not.
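A rough sketch of that idea (hypothetical index path and variable names; the actual filtering lives in examples/llm_ptq/example_utils.py and may differ):

```python
import json

# Map each safetensors shard to the MTP-related keys it holds, using the presence
# of "mtp" in the weight key as the signal (hypothetical index file path).
with open("model.safetensors.index.json") as f:
    index = json.load(f)

mtp_weight_map: dict[str, list[str]] = {}
for key, filename in index["weight_map"].items():
    if "mtp" in key.lower():
        mtp_weight_map.setdefault(filename, []).append(key)

print({fname: len(keys) for fname, keys in mtp_weight_map.items()})
```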
Testing
Run Qwen3 Next PTQ and check that MTP is present in the exported checkpoint.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
Refactor
Chores