diff --git a/examples/deepseek/ptq.py b/examples/deepseek/ptq.py
index 6b1086a30..d451758c8 100644
--- a/examples/deepseek/ptq.py
+++ b/examples/deepseek/ptq.py
@@ -56,6 +56,7 @@
from modelopt.torch.export.model_config import KV_CACHE_FP8
from modelopt.torch.export.quant_utils import get_quant_config
from modelopt.torch.quantization.nn import TensorQuantizer
+from modelopt.torch.quantization.triton import weight_dequant
from modelopt.torch.quantization.utils import (
is_quantized_column_parallel_linear,
is_quantized_parallel_linear,
@@ -77,7 +78,6 @@
)
import model as deekseep_model # noqa: E402
-from ds_kernel import weight_dequant # noqa: E402
from kernel import act_quant, fp8_gemm # noqa: E402
diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md
index 9d85c0998..187eed7f1 100755
--- a/examples/llm_ptq/README.md
+++ b/examples/llm_ptq/README.md
@@ -111,6 +111,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
| DeepSeek V3, R1, V3.1, V3.27 | - | - | - | - | ✅ |
| GLM-4.78 | ✅ | - | - | - | ✅ |
| Kimi K2 | - | - | - | - | ✅ |
+| MiniMax M2.1 | - | - | - | - | ✅ |
| T5 | ✅ | ✅ | ✅ | ✅ | - |
| Whisper | ✅ | ❌ | ❌ | ❌ | - |
diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py
index 93687a8d0..d613e027b 100755
--- a/examples/llm_ptq/example_utils.py
+++ b/examples/llm_ptq/example_utils.py
@@ -230,7 +230,7 @@ def build_quant_cfg(
quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
- if model_type in ["qwen3moe", "qwen3next"] and qformat == "nvfp4":
+ if model_type in ["qwen3moe", "qwen3next", "minimax"] and qformat == "nvfp4":
# Disable the attention projection layers to retain accuracy
quant_cfg["quant_cfg"]["model*.*attn*in_proj*"] = {"enable": False}
quant_cfg["quant_cfg"]["model*.*attn*q_proj*"] = {"enable": False}
diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py
index 9346e074b..9c68899d9 100755
--- a/modelopt/torch/export/layer_utils.py
+++ b/modelopt/torch/export/layer_utils.py
@@ -346,7 +346,14 @@ def is_moe(module: nn.Module) -> bool:
def is_quantlinear(module: nn.Module) -> bool:
"""Returns whether the module is a quantized linear layer."""
name = type(module).__name__
- return ("QuantLinear" in name or "QuantCompressedLinear" in name) and "lora" not in name.lower()
+ return (
+ any(
+ keyword in name
+ for keyword in ["QuantLinear", "QuantCompressedLinear", "QuantFP8Linear"]
+ )
+ and "lora" not in name.lower()
+ and "ds_kernel" not in name.lower()
+ )
def dup_kv_weight(v: torch.Tensor, head_size: int, num_head: int, tp_size: int) -> torch.Tensor:
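
A quick illustration of how the widened `is_quantlinear` check behaves. The classes below are toy stand-ins (not real ModelOpt modules); only the class-name filters shown in the hunk above are exercised.

```python
import torch.nn as nn

from modelopt.torch.export.layer_utils import is_quantlinear

class QuantFP8Linear(nn.Module): ...   # now matched via the new "QuantFP8Linear" keyword
class QuantLinearLoRA(nn.Module): ...  # rejected: "lora" appears in the lowered name
class QuantLinear(nn.Module): ...      # still matched, as before

assert is_quantlinear(QuantFP8Linear())
assert not is_quantlinear(QuantLinearLoRA())
assert is_quantlinear(QuantLinear())
```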
diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py
index 5a24429ad..2b6e3c52b 100755
--- a/modelopt/torch/export/model_utils.py
+++ b/modelopt/torch/export/model_utils.py
@@ -55,6 +55,7 @@
"Deepseek": "deepseek",
"Whisper": "whisper",
"gptoss": "gptoss",
+ "MiniMax": "minimax",
}
__doc__ = f"""Utility functions for model type detection and classification.
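
A hedged sketch of how a name map like this is typically consumed: a case-insensitive substring match against the model class name. `infer_model_type` below is a hypothetical helper for illustration, not the actual ModelOpt API, and only an excerpt of the mapping is shown.

```python
# Excerpt of the mapping above plus a hypothetical lookup helper; the real
# consumer in model_utils.py may differ in name and signature.
MODEL_NAME_TO_TYPE = {"Deepseek": "deepseek", "MiniMax": "minimax"}

def infer_model_type(architecture: str) -> str | None:
    """Return the model type whose key appears (case-insensitively) in the class name."""
    for key, model_type in MODEL_NAME_TO_TYPE.items():
        if key.lower() in architecture.lower():
            return model_type
    return None

assert infer_model_type("MiniMaxM2ForCausalLM") == "minimax"
```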
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index 5703f4515..89154eb21 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -589,7 +589,9 @@ def _process_quantized_modules(
if is_modelopt_qlora and (hasattr(sub_module, "base_layer")):
continue
- if hasattr(sub_module, "weight_packed"):
+ if hasattr(sub_module, "weight_packed") or (
+ "QuantFP8Linear" in type(sub_module).__name__ and sub_module.weight.element_size() <= 1
+ ):
sub_module.unpack_weight()
if get_quantization_format(sub_module) != QUANTIZATION_NONE:
if is_quantlinear(sub_module):
diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py
index a29d7c754..807c92c2c 100644
--- a/modelopt/torch/quantization/plugins/huggingface.py
+++ b/modelopt/torch/quantization/plugins/huggingface.py
@@ -22,6 +22,8 @@
from typing import TYPE_CHECKING
import torch
+import transformers
+from packaging import version
from torch import Tensor
from torch.nn.functional import linear
@@ -38,7 +40,6 @@
kitchen = None
import torch.nn as nn
-import transformers
from transformers.models.t5.modeling_t5 import T5Attention
from modelopt.torch.opt.dynamic import DynamicModule
@@ -48,6 +49,13 @@
from ..conversion import register
from ..nn import QuantInputBase, QuantModule, QuantModuleRegistry, TensorQuantizer
from ..nn.modules.quant_linear import _QuantLinear
+from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE
+
+if IS_TRITON_AVAILABLE:
+ from ..triton import weight_dequant
+else:
+ weight_dequant = None
+
from ..utils import replace_function
from .attention import register_attention_for_kv_quant
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin
@@ -57,6 +65,8 @@
__all__ = ["register_hf_attentions_on_the_fly"]
+TRANSFORMERS_VERSION_GE_5_0 = version.parse(transformers.__version__) >= version.parse("5.0.0")
+
class _QuantAttention(QuantModule):
"""Attention class for KV Cache quantization compatible with new_attention_interface in transformers >= 4.48.0."""
@@ -447,10 +457,24 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# If any of the experts are in calibration mode, we will forward all tokens to all experts
# This is used only for calibration, we need to re-calculate the actual outputs again using
# the original top_k
- original_top_k = self.top_k
- self.top_k = self.num_experts
- super().forward(hidden_states)
- self.top_k = original_top_k
+ if TRANSFORMERS_VERSION_GE_5_0:
+ assert hasattr(self, "gate")
+ # Path for transformers >= 5.0
+ original_top_k = self.gate.topk
+ self.gate.topk = self.gate.num_experts
+ super().forward(hidden_states)
+ self.gate.topk = original_top_k
+ else:
+ # Path for transformers < 5.0
+ original_top_k = self.top_k
+ if hasattr(self, "num_experts"):
+ self.top_k = self.num_experts
+ elif hasattr(self, "experts"):
+ self.top_k = self.experts.num_experts
+ else:
+ raise ValueError(f"Could not find num_experts in module {self}")
+ super().forward(hidden_states)
+ self.top_k = original_top_k
return super().forward(hidden_states)
@@ -693,6 +717,53 @@ def unpack_weight(self):
del self.weight_scale
+class _QuantFP8Linear(QuantModule):
+ def _setup(self):
+ self.input_quantizer = TensorQuantizer()
+ self.weight_quantizer = TensorQuantizer()
+ assert self.weight_scale_inv.ndim == 2, "Weight scale inverse must be 2D"
+ assert self.weight.ndim == 2, "Weight must be 2D"
+ self.block_size = max(
+ self.weight.shape[0] // self.weight_scale_inv.shape[0],
+ self.weight.shape[1] // self.weight_scale_inv.shape[1],
+ )
+ assert self.block_size == 128, "Block size must be 128"
+
+ def _get_weight_and_scale_inv(self):
+ if isinstance(self.weight, torch.distributed.tensor.DTensor):
+ weight = self.weight._local_tensor.contiguous()
+ scale_inv = self.weight_scale_inv._local_tensor.contiguous()
+ else:
+ weight = self.weight.contiguous()
+ scale_inv = self.weight_scale_inv.contiguous()
+ return weight, scale_inv
+
+ def forward(self, input: Tensor) -> Tensor:
+ assert weight_dequant is not None, "Triton is not available"
+ if self.weight.element_size() == 1:
+ with torch.cuda.device(self.weight.device):
+ weight, scale_inv = self._get_weight_and_scale_inv()
+ weight = weight_dequant(weight, scale_inv, self.block_size, dtype=input.dtype)
+ else:
+ weight = self.weight
+ return linear(
+ self.input_quantizer(input),
+ self.weight_quantizer(weight),
+ self.bias,
+ )
+
+ def unpack_weight(self):
+ assert weight_dequant is not None, "Triton is not available"
+ with torch.cuda.device(self.weight.device):
+ weight, scale_inv = self._get_weight_and_scale_inv()
+ self.weight = nn.Parameter(
+ weight_dequant(weight, scale_inv, self.block_size, dtype=torch.get_default_dtype()),
+ requires_grad=False,
+ )
+ if hasattr(self, "weight_scale_inv"):
+ del self.weight_scale_inv
+
+
try:
from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe
@@ -796,6 +867,14 @@ def unpack_weight(self):
except ImportError:
pass
+try:
+ from transformers.integrations.finegrained_fp8 import FP8Linear
+
+ if FP8Linear not in QuantModuleRegistry:
+ QuantModuleRegistry.register({FP8Linear: "hf.FP8Linear"})(_QuantFP8Linear)
+except ImportError:
+ pass
+
class _QuantGptOssExperts(_QuantFunctionalMixin):
"""Quantized wrapper for `transformers.GptOssExperts`.
@@ -910,6 +989,17 @@ def register_falcon_linears_on_the_fly(model):
QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear)
+def register_minimax_m2_moe_on_the_fly(model):
+ """Register MiniMax M2 MoE modules as a QUANT_MODULE.
+
+    MiniMax M2 MoE modules are defined in the model's remote modeling code rather than in transformers, so we need to register them on the fly.
+ """
+ if type(model).__name__ in ["MiniMaxM2ForCausalLM"]:
+ moe_type = type(model.model.layers[0].block_sparse_moe)
+ if QuantModuleRegistry.get(moe_type) is None:
+ QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantSparseMoe)
+
+
def _is_supported_hf_model(model):
"""Check if the model a valid model for transformers quantization specific support."""
supported_models = [transformers.PreTrainedModel]
@@ -975,6 +1065,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
[
register_falcon_linears_on_the_fly,
register_dbrx_moe_on_the_fly,
+ register_minimax_m2_moe_on_the_fly,
register_hf_attentions_on_the_fly,
convert_hf_parallel_linears_on_the_fly,
]
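
To make the `_QuantFP8Linear` path above concrete, here is a minimal PyTorch-only reference of the 128x128 block-wise dequantization that `weight_dequant` performs on Triton. The helper below is illustrative, not part of ModelOpt; shapes follow the asserts in `_setup`.

```python
import torch

def blockwise_dequant_reference(
    x: torch.Tensor,
    s: torch.Tensor,
    block_size: int = 128,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Dequantize an (M, N) FP8 weight with per-block scales s of shape (ceil(M/bs), ceil(N/bs))."""
    M, N = x.shape
    # Broadcast each block scale over its block_size x block_size tile, then crop to (M, N).
    s_full = s.repeat_interleave(block_size, dim=0)[:M]
    s_full = s_full.repeat_interleave(block_size, dim=1)[:, :N]
    return x.to(dtype) * s_full.to(dtype)
```

This mirrors the Triton kernel's `y = x * s[i // block_size, j // block_size]`; `_QuantFP8Linear.forward` then runs the usual fake-quantized `linear` on the dequantized weight.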
diff --git a/modelopt/torch/quantization/triton/__init__.py b/modelopt/torch/quantization/triton/__init__.py
index 0af34b21f..def70e591 100644
--- a/modelopt/torch/quantization/triton/__init__.py
+++ b/modelopt/torch/quantization/triton/__init__.py
@@ -32,6 +32,7 @@
):
# fp4_kernel works on any CUDA GPU with triton
from .fp4_kernel import *
+ from .fp8_kernel import *
# fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv)
if torch.cuda.get_device_capability() >= (8, 9):
diff --git a/examples/deepseek/ds_kernel.py b/modelopt/torch/quantization/triton/fp8_kernel.py
similarity index 75%
rename from examples/deepseek/ds_kernel.py
rename to modelopt/torch/quantization/triton/fp8_kernel.py
index 00586acc2..0b3c93e3e 100644
--- a/examples/deepseek/ds_kernel.py
+++ b/modelopt/torch/quantization/triton/fp8_kernel.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -35,32 +35,18 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
-# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+"""FP8 Triton Kernel Implementations."""
import torch
import triton
import triton.language as tl
-"""Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py"""
-
@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
- """
- Dequantizes weights using the provided scaling factors and stores the result.
+ """Dequantizes weights using the provided scaling factors and stores the result.
+
+ Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
Args:
x_ptr (tl.pointer): Pointer to the quantized weights.
@@ -86,14 +72,21 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
tl.store(y_ptr + offs, y, mask=mask)
-def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
- """
- Dequantizes the given weight tensor using the provided scale tensor.
+def weight_dequant(
+ x: torch.Tensor,
+ s: torch.Tensor,
+ block_size: int = 128,
+ dtype: torch.dtype = torch.get_default_dtype(),
+) -> torch.Tensor:
+ """Dequantizes the given weight tensor using the provided scale tensor.
+
+ Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
Args:
x (torch.Tensor): The quantized weight tensor of shape (M, N).
s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size).
block_size (int, optional): The block size to use for dequantization. Defaults to 128.
+ dtype (torch.dtype, optional): The dtype of the output tensor. Defaults to torch.get_default_dtype().
Returns:
torch.Tensor: The dequantized weight tensor of the same shape as `x`.
@@ -104,7 +97,7 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous"
assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
M, N = x.size()
- y = torch.empty_like(x, dtype=torch.get_default_dtype())
+ y = torch.empty_like(x, dtype=dtype)
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y
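
A usage sketch for the relocated kernel, assuming a CUDA device with Triton available; the tensor values are dummies and the shapes follow the docstring above.

```python
import torch

from modelopt.torch.quantization.triton import weight_dequant

M, N, block = 1024, 512, 128
w_fp8 = torch.randn(M, N, device="cuda").to(torch.float8_e4m3fn)        # quantized weight
scale_inv = torch.rand(M // block, N // block, device="cuda")            # per-block scales
w_bf16 = weight_dequant(w_fp8, scale_inv, block, dtype=torch.bfloat16)  # new dtype argument
assert w_bf16.shape == (M, N) and w_bf16.dtype == torch.bfloat16
```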