
Support Qwen3 Next MTP load and export #860

Open
cjluo-nv wants to merge 1 commit into main from chenjiel/support_mtp_qwen_next

Conversation

@cjluo-nv cjluo-nv (Collaborator) commented Feb 6, 2026

What does this PR do?

Fix MTP export for Qwen3 Next

Overview: ?

For Qwen3 Next, the MTP weights are not stored in separate safetensors files, so we use the "mtp" substring in the weight keys to decide whether a weight belongs to the MTP module.
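For illustration only, a minimal sketch of that key-based detection; the function and variable names here are assumptions, not the PR's exact code, which lives in examples/llm_ptq/example_utils.py:

import json
from collections import defaultdict

def build_mtp_weight_map(index_file):
    """Group MTP-related weight keys by the safetensors shard that stores them."""
    with open(index_file) as f:
        index = json.load(f)

    # Qwen3 Next marks MTP weights via an "mtp" substring in the key,
    # not via a separate shard file, so we filter the index by key.
    mtp_weight_map = defaultdict(list)
    for key, shard in index["weight_map"].items():
        if "mtp" in key:
            mtp_weight_map[shard].append(key)
    return dict(mtp_weight_map)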

Testing

Ran Qwen3 Next PTQ and checked that the MTP weights are present in the exported checkpoint.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • Refactor
    • Optimized Multi-Token Prediction weight loading with improved layer detection and handling.
  • Chores
    • Simplified status reporting to display total loaded weights and detected layers.
    • Removed verbose per-file warnings for cleaner console output.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
@cjluo-nv cjluo-nv requested a review from a team as a code owner February 6, 2026 00:22
@cjluo-nv cjluo-nv requested a review from sugunav14 February 6, 2026 00:22

coderabbitai bot commented Feb 6, 2026

📝 Walkthrough

The code refactors MTP weight loading logic in a utility file. Changes include filtering index entries for MTP-related weights, introducing a helper function to extract layer prefixes, and simplifying the loading process to only load missing keys from designated MTP files with cleaner status reporting.
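As a self-contained sketch, the layer-prefix extraction described above amounts to the following; it is reconstructed from the review diff later in this thread, so treat it as an approximation of the PR's _extract_layer_prefixes rather than the exact implementation:

def _extract_layer_prefixes(keys):
    """Collect unique layer prefixes such as "model.layers.92" from weight keys."""
    prefixes = set()
    for key in keys:
        parts = key.split(".")
        for i, part in enumerate(parts):
            # The prefix ends right after the numeric index that follows "layers".
            if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
                prefixes.add(".".join(parts[: i + 2]))
                break
    return list(prefixes)

print(_extract_layer_prefixes([
    "model.layers.92.mtp.mlp.weight",
    "model.layers.92.mtp.norm.weight",
]))  # ['model.layers.92']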

Changes

Cohort / File(s): MTP Weight Loading Refactor (examples/llm_ptq/example_utils.py)
Summary: Restructured weight loading to filter for MTP entries, added an _extract_layer_prefixes helper, changed weight iteration to use the filtered map instead of all files, updated missing-key detection, and simplified status messaging to report total loaded weights and detected MTP layers.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 50.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title accurately describes the main change: adding support for Qwen3 Next MTP load and export, which aligns with the code changes reworking the MTP weight-loading logic.


@cjluo-nv cjluo-nv requested a review from Edwardf0t1 February 6, 2026 00:25

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@examples/llm_ptq/example_utils.py`:
- Around line 376-391: The logic that builds missing_keys is inverted:
mtp_weight_map[filename] is filtered using "if k not in model_state" which
yields keys that don't exist in the model, but model.load_state_dict(...,
strict=False) only accepts keys that DO exist; change the filter to collect keys
present in model_state (i.e., use "if k in model_state") so weights_to_load
contains loadable keys, and add a guard after computing missing_keys to continue
if it's empty to avoid unnecessary load_file() I/O; update references around
mtp_weight_map, model_state, missing_keys, weights, weights_to_load and the call
to model.load_state_dict to implement this fix.
🧹 Nitpick comments (2)
examples/llm_ptq/example_utils.py (2)

345-345: Unclosed file handle — use a with statement.

open(index_file) is never explicitly closed. While CPython's refcount GC will collect it, this leaks under alternative runtimes and emits ResourceWarning.

♻️ Proposed fix
-    index = json.load(open(index_file))
+    with open(index_file) as f:
+        index = json.load(f)

356-366: Inner variable mtp_layer_prefixes shadows the outer variable of the same name.

The local set on line 357 shadows the outer mtp_layer_prefixes (line 339 / assigned on line 370). Not a bug, but it makes the code harder to follow. Consider renaming the inner variable (e.g., prefixes).

♻️ Proposed fix
     def _extract_layer_prefixes(keys):
-        mtp_layer_prefixes = set()
+        prefixes = set()
         for key in keys:
             parts = key.split(".")
             for i, part in enumerate(parts):
                 if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
                     prefix = ".".join(parts[: i + 2])
-                    mtp_layer_prefixes.add(prefix)
+                    prefixes.add(prefix)
                     break
 
-        return list(mtp_layer_prefixes)
+        return list(prefixes)

Comment on lines +376 to 391

     for filename in mtp_weight_map:
         filepath = model_path / filename
         if not filepath.exists():
             continue

-        # Find keys that should be in this file
-        expected_keys = [k for k, v in index["weight_map"].items() if v == filename]
-
         # Check which are missing from the model
-        missing_keys = [k for k in expected_keys if k not in model_state]
-
-        if not missing_keys:
-            # Even if weights are loaded, record the layer prefixes for exclusion
-            # Extract unique layer prefixes (e.g., "model.layers.92" from "model.layers.92.mlp.weight")
-            for key in expected_keys:
-                # Extract layer prefix like "model.layers.92" or "layers.92"
-                parts = key.split(".")
-                for i, part in enumerate(parts):
-                    if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
-                        prefix = ".".join(parts[: i + 2])  # e.g., "model.layers.92"
-                        if prefix not in mtp_layer_prefixes:
-                            mtp_layer_prefixes.append(prefix)
-                        break
-            continue
+        missing_keys = [k for k in mtp_weight_map[filename] if k not in model_state]

         print(f"Loading {len(missing_keys)} missing weights from {filename}...")

-        # Extract unique layer prefixes for exclusion from quantization
-        for key in missing_keys:
-            parts = key.split(".")
-            for i, part in enumerate(parts):
-                if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
-                    prefix = ".".join(parts[: i + 2])  # e.g., "model.layers.92"
-                    if prefix not in mtp_layer_prefixes:
-                        mtp_layer_prefixes.append(prefix)
-                    break
-
         # Load the weights to CPU first, load_state_dict will handle device placement
         weights = load_file(str(filepath), device="cpu")
         weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}

         # Load into model
-        missing, unexpected = model.load_state_dict(weights_to_load, strict=False)
+        model.load_state_dict(weights_to_load, strict=False)
         total_loaded += len(weights_to_load)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Get full context of the function
echo "=== Function definition and context ==="
cat -n examples/llm_ptq/example_utils.py | sed -n '350,410p'

echo ""
echo "=== Check git history for this file ==="
git log --oneline -20 -- examples/llm_ptq/example_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2765


🏁 Script executed:

#!/bin/bash
# Search for callers and tests
echo "=== Callers of load_mtp_weights_if_needed ==="
rg -n 'load_mtp_weights_if_needed' --type=py -B 2 -A 5

echo ""
echo "=== Test files for example_utils ==="
fd -n '*test*.py' | grep -i example

Repository: NVIDIA/Model-Optimizer

Length of output: 2887


🏁 Script executed:

#!/bin/bash
# Check for MTP-related documentation or comments
echo "=== MTP-related comments in the file ==="
rg -n 'mtp|MTP' examples/llm_ptq/example_utils.py -i -B 1 -A 1

echo ""
echo "=== Check for from_pretrained calls and how they handle weights ==="
rg -n 'from_pretrained' examples/llm_ptq/example_utils.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 4662


🏁 Script executed:

#!/bin/bash
# Get the recent changes to understand the PR context
echo "=== Full commit info for latest MTP change ==="
git show --stat aa380d9

echo ""
echo "=== Previous version of load_mtp_weights_if_needed ==="
git show aa380d9^:examples/llm_ptq/example_utils.py | sed -n '319,400p'

Repository: NVIDIA/Model-Optimizer

Length of output: 50434


🏁 Script executed:

#!/bin/bash
# Check if there are model definitions that use MTP
echo "=== Search for MTP module definitions ==="
rg -n 'class.*MTP|mtp.*Module|MTPLayer' --type=py -B 2 -A 5

echo ""
echo "=== Check what from_pretrained actually loads ==="
rg -n 'from_pretrained' examples/llm_ptq/example_utils.py -A 15 | head -60

Repository: NVIDIA/Model-Optimizer

Length of output: 4684


🏁 Script executed:

#!/bin/bash
# Check if there are any tests or usage patterns for Qwen3 Next vs GLM-4.7
echo "=== Search for Qwen3 Next and GLM-4.7 in examples ==="
rg -n 'qwen3|glm.?4' --type=py --type=sh -i examples/llm_ptq/ | head -40

echo ""
echo "=== Check comments around the weight loading logic ==="
sed -n '315,400p' examples/llm_ptq/example_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 3873


🏁 Script executed:

#!/bin/bash
# Verify state_dict behavior with load_state_dict strict=False
echo "=== Check how load_state_dict strict=False actually behaves ==="
python3 << 'EOF'
import torch

# Demonstrate load_state_dict behavior
model = torch.nn.Linear(10, 5)
print("Original state_dict keys:", model.state_dict().keys())

# Try to load a key that doesn't exist in the model
extra_weights = {'nonexistent_param': torch.randn(5)}
model.load_state_dict(extra_weights, strict=False)
print("After load_state_dict(strict=False) with non-existent key:")
print("  Model state_dict keys:", model.state_dict().keys())
print("  Did 'nonexistent_param' get loaded? No - extra keys are silently ignored")

# Try to load a key that exists
existing_weights = {'weight': torch.randn(5, 10)}
model.load_state_dict(existing_weights, strict=False)
print("\nAfter load_state_dict(strict=False) with existing key:")
print("  Model weight shape:", model.weight.shape)
print("  Weight was loaded successfully")
EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 264


🏁 Script executed:

#!/bin/bash
# Search for documentation or examples of how PyTorch load_state_dict works
echo "=== Understanding the weight loading logic ==="
cat << 'EOF'
PyTorch load_state_dict behavior:
- load_state_dict(state_dict, strict=False) loads values for keys that exist in both:
  1. The state_dict being loaded FROM (weights argument)
  2. The model's current state_dict
- Keys in weights but NOT in model.state_dict() → silently ignored (extra keys)
- Keys in model.state_dict() but NOT in weights → unchanged (remaining keys)

Current code flow:
1. missing_keys = [k for k in mtp_weight_map[filename] if k NOT in model_state]
2. weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}
   → weights_to_load contains only keys that are NOT in model_state
3. model.load_state_dict(weights_to_load, strict=False)
   → This call will IGNORE all keys in weights_to_load because they're NOT in model.state_dict()

Result: Nothing gets loaded even when missing_keys is non-empty!
EOF

echo ""
echo "=== Check if there are actual tests that verify this loads correctly ==="
fd --type f -n '*.py' examples/llm_ptq tests/examples | xargs rg -l 'load_mtp_weights'

Repository: NVIDIA/Model-Optimizer

Length of output: 1329


🏁 Script executed:

#!/bin/bash
# Check if there's any logic that actually depends on weights being loaded
echo "=== How are the loaded prefixes used ==="
rg -n '_mtp_layer_prefixes' examples/llm_ptq/ -A 3 -B 1

Repository: NVIDIA/Model-Optimizer

Length of output: 1083


Correct: missing_keys logic contradicts load_state_dict() behavior, preventing actual weight loading.

Line 382 filters for keys not in model.state_dict(), then line 390 calls load_state_dict(strict=False) which can only load keys that exist in the model. This means weights_to_load will always be rejected as "unexpected keys" and silently ignored.

The primary function (extracting MTP layer prefixes for quantization exclusion) works correctly, but the weight-loading portion is broken. For models like GLM-4.7 that rely on loading weights from separate files, nothing is actually loaded despite the print statement claiming otherwise.

Invert the condition to if k in model_state and add a continue guard when missing_keys is empty to avoid unnecessary file I/O:

Proposed fix
     for filename in mtp_weight_map:
         filepath = model_path / filename
         if not filepath.exists():
             continue
 
         # Check which are missing from the model
-        missing_keys = [k for k in mtp_weight_map[filename] if k not in model_state]
+        missing_keys = [k for k in mtp_weight_map[filename] if k in model_state]
+
+        if not missing_keys:
+            continue
 
         print(f"Loading {len(missing_keys)} missing weights from {filename}...")
         # Load the weights to CPU first, load_state_dict will handle device placement
         weights = load_file(str(filepath), device="cpu")
         weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}

Note: This loads all keys in model_state, which may include already-correctly-loaded weights from Qwen3 Next. Consider adding an additional check (e.g., tensor zero/meta device detection) to avoid redundant loads for models where from_pretrained() already handled MTP weights.
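A rough sketch of such a check follows; it is an illustration of the reviewer's suggestion, not code from the PR, and the all-zeros heuristic can misfire on weights that are legitimately zero:

import torch

def needs_loading(tensor):
    """Treat meta-device or all-zero tensors as not yet materialized."""
    if tensor.is_meta:
        return True
    return tensor.count_nonzero().item() == 0

# Hypothetical use inside the loop, with model_state = model.state_dict():
# missing_keys = [k for k in mtp_weight_map[filename]
#                 if k in model_state and needs_loading(model_state[k])]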

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (replaces the block quoted above):
    for filename in mtp_weight_map:
        filepath = model_path / filename
        if not filepath.exists():
            continue

        # Check which are missing from the model
        missing_keys = [k for k in mtp_weight_map[filename] if k in model_state]

        if not missing_keys:
            continue

        print(f"Loading {len(missing_keys)} missing weights from {filename}...")
        # Load the weights to CPU first, load_state_dict will handle device placement
        weights = load_file(str(filepath), device="cpu")
        weights_to_load = {k: v for k, v in weights.items() if k in missing_keys}

        # Load into model
        model.load_state_dict(weights_to_load, strict=False)
        total_loaded += len(weights_to_load)
🤖 Prompt for AI Agents
In `@examples/llm_ptq/example_utils.py` around lines 376 - 391, The logic that
builds missing_keys is inverted: mtp_weight_map[filename] is filtered using "if
k not in model_state" which yields keys that don't exist in the model, but
model.load_state_dict(..., strict=False) only accepts keys that DO exist; change
the filter to collect keys present in model_state (i.e., use "if k in
model_state") so weights_to_load contains loadable keys, and add a guard after
computing missing_keys to continue if it's empty to avoid unnecessary
load_file() I/O; update references around mtp_weight_map, model_state,
missing_keys, weights, weights_to_load and the call to model.load_state_dict to
implement this fix.


codecov bot commented Feb 6, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.72%. Comparing base (452c5a0) to head (aa380d9).

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #860   +/-   ##
=======================================
  Coverage   73.72%   73.72%           
=======================================
  Files         196      196           
  Lines       20457    20457           
=======================================
  Hits        15082    15082           
  Misses       5375     5375           

☔ View full report in Codecov by Sentry.


@Edwardf0t1 Edwardf0t1 left a comment


@cjluo-nv Could you add the PTQ command in the description? Also, I'm wondering why we didn't have the MTP issue for Qwen3 Next PTQ previously?

# Extract layer prefix like "model.layers.92" or "layers.92"
parts = key.split(".")
for i, part in enumerate(parts):
    if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():

Did you also test GLM-4.7 with the changes?
