Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/deepseek/ptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from modelopt.torch.export.model_config import KV_CACHE_FP8
from modelopt.torch.export.quant_utils import get_quant_config
from modelopt.torch.quantization.nn import TensorQuantizer
from modelopt.torch.quantization.triton import weight_dequant
from modelopt.torch.quantization.utils import (
is_quantized_column_parallel_linear,
is_quantized_parallel_linear,
Expand All @@ -77,7 +78,6 @@
)

import model as deekseep_model # noqa: E402
from ds_kernel import weight_dequant # noqa: E402
from kernel import act_quant, fp8_gemm # noqa: E402


Expand Down
1 change: 1 addition & 0 deletions examples/llm_ptq/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
| DeepSeek V3, R1, V3.1, V3.2<sup>7</sup> | - | - | - | - | ✅ |
| GLM-4.7<sup>8</sup> | ✅ | - | - | - | ✅ |
| Kimi K2 | - | - | - | - | ✅ |
| MiniMax M2.1 | - | - | - | - | ✅ |
| T5 | ✅ | ✅ | ✅ | ✅ | - |
| Whisper | ✅ | ❌ | ❌ | ❌ | - |

Expand Down
2 changes: 1 addition & 1 deletion examples/llm_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def build_quant_cfg(
quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}

if model_type in ["qwen3moe", "qwen3next"] and qformat == "nvfp4":
if model_type in ["qwen3moe", "qwen3next", "minimax"] and qformat == "nvfp4":
# Disable the attention projection layers to retain accuracy
quant_cfg["quant_cfg"]["model*.*attn*in_proj*"] = {"enable": False}
quant_cfg["quant_cfg"]["model*.*attn*q_proj*"] = {"enable": False}
Expand Down
9 changes: 8 additions & 1 deletion modelopt/torch/export/layer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,14 @@ def is_moe(module: nn.Module) -> bool:
def is_quantlinear(module: nn.Module) -> bool:
    """Return whether ``module`` is a quantized linear layer.

    The check is purely name-based: the module's class name must contain one of
    the known quantized-linear family names, while LoRA adapter wrappers and
    ds_kernel-backed layers are explicitly excluded.

    Args:
        module: The module to classify.

    Returns:
        True if the class name matches a quantized linear wrapper, else False.
    """
    name = type(module).__name__
    return (
        any(
            keyword in name
            for keyword in ["QuantLinear", "QuantCompressedLinear", "QuantFP8Linear"]
        )
        and "lora" not in name.lower()
        and "ds_kernel" not in name.lower()
    )


def dup_kv_weight(v: torch.Tensor, head_size: int, num_head: int, tp_size: int) -> torch.Tensor:
Expand Down
1 change: 1 addition & 0 deletions modelopt/torch/export/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"Deepseek": "deepseek",
"Whisper": "whisper",
"gptoss": "gptoss",
"MiniMax": "minimax",
}

__doc__ = f"""Utility functions for model type detection and classification.
Expand Down
4 changes: 3 additions & 1 deletion modelopt/torch/export/unified_export_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,9 @@ def _process_quantized_modules(
if is_modelopt_qlora and (hasattr(sub_module, "base_layer")):
continue

if hasattr(sub_module, "weight_packed"):
if hasattr(sub_module, "weight_packed") or (
"QuantFP8Linear" in type(sub_module).__name__ and sub_module.weight.element_size() <= 1
):
sub_module.unpack_weight()
if get_quantization_format(sub_module) != QUANTIZATION_NONE:
if is_quantlinear(sub_module):
Expand Down
101 changes: 96 additions & 5 deletions modelopt/torch/quantization/plugins/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from typing import TYPE_CHECKING

import torch
import transformers
from packaging import version
from torch import Tensor
from torch.nn.functional import linear

Expand All @@ -38,7 +40,6 @@
kitchen = None

import torch.nn as nn
import transformers
from transformers.models.t5.modeling_t5 import T5Attention

from modelopt.torch.opt.dynamic import DynamicModule
Expand All @@ -48,6 +49,13 @@
from ..conversion import register
from ..nn import QuantInputBase, QuantModule, QuantModuleRegistry, TensorQuantizer
from ..nn.modules.quant_linear import _QuantLinear
from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE

if IS_TRITON_AVAILABLE:
from ..triton import weight_dequant
else:
weight_dequant = None

from ..utils import replace_function
from .attention import register_attention_for_kv_quant
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin
Expand All @@ -57,6 +65,8 @@

__all__ = ["register_hf_attentions_on_the_fly"]

TRANSFORMERS_VERSION_GE_5_0 = version.parse(transformers.__version__) >= version.parse("5.0.0")


class _QuantAttention(QuantModule):
"""Attention class for KV Cache quantization compatible with new_attention_interface in transformers >= 4.48.0."""
Expand Down Expand Up @@ -447,10 +457,24 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# If any of the experts are in calibration mode, we will forward all tokens to all experts
# This is used only for calibration, we need to re-calculate the actual outputs again using
# the original top_k
original_top_k = self.top_k
self.top_k = self.num_experts
super().forward(hidden_states)
self.top_k = original_top_k
if TRANSFORMERS_VERSION_GE_5_0:
assert hasattr(self, "gate")
# Path for transformers >= 5.0
original_top_k = self.gate.topk
self.gate.topk = self.gate.num_experts
super().forward(hidden_states)
self.gate.topk = original_top_k
else:
# Path for transformers < 5.0
original_top_k = self.top_k
if hasattr(self, "num_experts"):
self.top_k = self.num_experts
elif hasattr(self, "experts"):
self.top_k = self.experts.num_experts
else:
raise ValueError(f"Could not find num_experts in module {self}")
super().forward(hidden_states)
self.top_k = original_top_k
return super().forward(hidden_states)


Expand Down Expand Up @@ -693,6 +717,53 @@ def unpack_weight(self):
del self.weight_scale


class _QuantFP8Linear(QuantModule):
    """Quantized wrapper for linear layers whose weights are stored as block-scaled FP8.

    Expects a 2D ``weight`` plus a 2D ``weight_scale_inv`` holding one scale per
    (block_size x block_size) weight tile (the HF finegrained-FP8 layout).
    """

    def _setup(self):
        # Quantizers applied to activations and the (dequantized) weight in forward().
        self.input_quantizer = TensorQuantizer()
        self.weight_quantizer = TensorQuantizer()
        assert self.weight_scale_inv.ndim == 2, "Weight scale inverse must be 2D"
        assert self.weight.ndim == 2, "Weight must be 2D"
        # Derive the tile size from the weight/scale shape ratio; max() covers the
        # case where one dimension is smaller than a full block.
        self.block_size = max(
            self.weight.shape[0] // self.weight_scale_inv.shape[0],
            self.weight.shape[1] // self.weight_scale_inv.shape[1],
        )
        assert self.block_size == 128, "Block size must be 128"

    def _get_weight_and_scale_inv(self):
        # For distributed DTensor weights, operate on the local shard only.
        if isinstance(self.weight, torch.distributed.tensor.DTensor):
            weight = self.weight._local_tensor.contiguous()
            scale_inv = self.weight_scale_inv._local_tensor.contiguous()
        else:
            weight = self.weight.contiguous()
            scale_inv = self.weight_scale_inv.contiguous()
        return weight, scale_inv

    def forward(self, input: Tensor) -> Tensor:
        assert weight_dequant is not None, "Triton is not available"
        # element_size() == 1 means the weight is still packed FP8: dequantize it
        # on the fly to the activation dtype. Otherwise it was already unpacked.
        if self.weight.element_size() == 1:
            with torch.cuda.device(self.weight.device):
                weight, scale_inv = self._get_weight_and_scale_inv()
                weight = weight_dequant(weight, scale_inv, self.block_size, dtype=input.dtype)
        else:
            weight = self.weight
        return linear(
            self.input_quantizer(input),
            self.weight_quantizer(weight),
            self.bias,
        )

    def unpack_weight(self):
        """Dequantize the FP8 weight in place and drop the block scales (export path)."""
        assert weight_dequant is not None, "Triton is not available"
        with torch.cuda.device(self.weight.device):
            weight, scale_inv = self._get_weight_and_scale_inv()
            # NOTE(review): dequantizes to torch.get_default_dtype() (typically fp32),
            # while forward() uses input.dtype — confirm the mismatch is intended.
            self.weight = nn.Parameter(
                weight_dequant(weight, scale_inv, self.block_size, dtype=torch.get_default_dtype()),
                requires_grad=False,
            )
        if hasattr(self, "weight_scale_inv"):
            del self.weight_scale_inv
Comment on lines 755 to 764
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Consider specifying dtype in unpack_weight for consistency.

In forward(), weight_dequant is called with dtype=input.dtype, preserving the input's precision. However, unpack_weight() omits dtype, defaulting to torch.get_default_dtype() (typically float32).

If this is intentional for export, a comment would clarify the design. Otherwise, consider accepting/storing a target dtype to ensure consistency.

💡 Suggested improvement
     def unpack_weight(self):
         with torch.cuda.device(self.weight.device):
             weight, scale_inv = self._get_weight_and_scale_inv()
             self.weight = nn.Parameter(
-                weight_dequant(weight, scale_inv, self.block_size),
+                weight_dequant(weight, scale_inv, self.block_size, dtype=torch.bfloat16),
                 requires_grad=False,
             )

Or store the original dtype during _setup and use it here.

🤖 Prompt for AI Agents
In `@modelopt/torch/quantization/plugins/huggingface.py` around lines 746 - 754,
unpack_weight currently calls weight_dequant without specifying dtype, which can
mismatch forward (which uses dtype=input.dtype); fix by storing the target dtype
during _setup (e.g., save self._orig_dtype or self.target_dtype) or by passing
an explicit dtype parameter into unpack_weight, then call weight_dequant(weight,
scale_inv, self.block_size, dtype=self._orig_dtype) so unpacked weights match
forward; update any callers and remove weight_scale_inv as before. Reference:
unpack_weight, forward, weight_dequant, and _setup.



try:
from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe

Expand Down Expand Up @@ -796,6 +867,14 @@ def unpack_weight(self):
except ImportError:
pass

try:
from transformers.integrations.finegrained_fp8 import FP8Linear

if FP8Linear not in QuantModuleRegistry:
QuantModuleRegistry.register({FP8Linear: "hf.FP8Linear"})(_QuantFP8Linear)
except ImportError:
pass


class _QuantGptOssExperts(_QuantFunctionalMixin):
"""Quantized wrapper for `transformers.GptOssExperts`.
Expand Down Expand Up @@ -910,6 +989,17 @@ def register_falcon_linears_on_the_fly(model):
QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear)


def register_minimax_m2_moe_on_the_fly(model):
    """Register MiniMax M2 MoE modules as a QUANT_MODULE.

    MiniMax M2 MoE modules are defined in the model card (remote code), so the
    concrete module type is only known at runtime and must be registered on the fly.

    Args:
        model: The loaded HF model; only acted on for ``MiniMaxM2ForCausalLM``.
    """
    if type(model).__name__ in ["MiniMaxM2ForCausalLM"]:
        # The MoE class comes from the model's remote code, so look it up from an
        # instantiated layer rather than importing it.
        moe_type = type(model.model.layers[0].block_sparse_moe)
        if QuantModuleRegistry.get(moe_type) is None:
            QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantSparseMoe)


def _is_supported_hf_model(model):
"""Check if the model a valid model for transformers quantization specific support."""
supported_models = [transformers.PreTrainedModel]
Expand Down Expand Up @@ -975,6 +1065,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
[
register_falcon_linears_on_the_fly,
register_dbrx_moe_on_the_fly,
register_minimax_m2_moe_on_the_fly,
register_hf_attentions_on_the_fly,
convert_hf_parallel_linears_on_the_fly,
]
Expand Down
1 change: 1 addition & 0 deletions modelopt/torch/quantization/triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
):
# fp4_kernel works on any CUDA GPU with triton
from .fp4_kernel import *
from .fp8_kernel import *

# fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv)
if torch.cuda.get_device_capability() >= (8, 9):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -35,32 +35,18 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FP8 Triton Kernel Implementations."""

import torch
import triton
import triton.language as tl

"""Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py"""


@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
"""
Dequantizes weights using the provided scaling factors and stores the result.
"""Dequantizes weights using the provided scaling factors and stores the result.

Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py

Args:
x_ptr (tl.pointer): Pointer to the quantized weights.
Expand All @@ -86,14 +72,21 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
tl.store(y_ptr + offs, y, mask=mask)


def weight_dequant(
    x: torch.Tensor,
    s: torch.Tensor,
    block_size: int = 128,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """Dequantizes the given weight tensor using the provided scale tensor.

    Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py

    Args:
        x (torch.Tensor): The quantized weight tensor of shape (M, N).
        s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size).
        block_size (int, optional): The block size to use for dequantization. Defaults to 128.
        dtype (torch.dtype, optional): The dtype of the output tensor. Defaults to
            ``torch.get_default_dtype()`` resolved at call time.

    Returns:
        torch.Tensor: The dequantized weight tensor of the same shape as `x`.

    Raises:
        AssertionError: If the input tensors are not contiguous or are not 2D.
    """
    assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous"
    assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
    # Resolve the default lazily: a `dtype=torch.get_default_dtype()` parameter default
    # is evaluated once at import time and would ignore later torch.set_default_dtype().
    if dtype is None:
        dtype = torch.get_default_dtype()
    M, N = x.size()
    y = torch.empty_like(x, dtype=dtype)
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
    return y