diff --git a/examples/deepseek/ptq.py b/examples/deepseek/ptq.py index 6b1086a30..d451758c8 100644 --- a/examples/deepseek/ptq.py +++ b/examples/deepseek/ptq.py @@ -56,6 +56,7 @@ from modelopt.torch.export.model_config import KV_CACHE_FP8 from modelopt.torch.export.quant_utils import get_quant_config from modelopt.torch.quantization.nn import TensorQuantizer +from modelopt.torch.quantization.triton import weight_dequant from modelopt.torch.quantization.utils import ( is_quantized_column_parallel_linear, is_quantized_parallel_linear, @@ -77,7 +78,6 @@ ) import model as deekseep_model # noqa: E402 -from ds_kernel import weight_dequant # noqa: E402 from kernel import act_quant, fp8_gemm # noqa: E402 diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md index 9d85c0998..187eed7f1 100755 --- a/examples/llm_ptq/README.md +++ b/examples/llm_ptq/README.md @@ -111,6 +111,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http | DeepSeek V3, R1, V3.1, V3.27 | - | - | - | - | ✅ | | GLM-4.78 | ✅ | - | - | - | ✅ | | Kimi K2 | - | - | - | - | ✅ | +| MiniMax M2.1 | - | - | - | - | ✅ | | T5 | ✅ | ✅ | ✅ | ✅ | - | | Whisper | ✅ | ❌ | ❌ | ❌ | - | diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 93687a8d0..d613e027b 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -230,7 +230,7 @@ def build_quant_cfg( quant_cfg["quant_cfg"]["*image*"] = {"enable": False} quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} - if model_type in ["qwen3moe", "qwen3next"] and qformat == "nvfp4": + if model_type in ["qwen3moe", "qwen3next", "minimax"] and qformat == "nvfp4": # Disable the attention projection layers to retain accuracy quant_cfg["quant_cfg"]["model*.*attn*in_proj*"] = {"enable": False} quant_cfg["quant_cfg"]["model*.*attn*q_proj*"] = {"enable": False} diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py index 9346e074b..9c68899d9 100755 --- a/modelopt/torch/export/layer_utils.py +++ b/modelopt/torch/export/layer_utils.py @@ -346,7 +346,14 @@ def is_moe(module: nn.Module) -> bool: def is_quantlinear(module: nn.Module) -> bool: """Returns whether the module is a quantized linear layer.""" name = type(module).__name__ - return ("QuantLinear" in name or "QuantCompressedLinear" in name) and "lora" not in name.lower() + return ( + any( + keyword in name + for keyword in ["QuantLinear", "QuantCompressedLinear", "QuantFP8Linear"] + ) + and "lora" not in name.lower() + and "ds_kernel" not in name.lower() + ) def dup_kv_weight(v: torch.Tensor, head_size: int, num_head: int, tp_size: int) -> torch.Tensor: diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py index 5a24429ad..2b6e3c52b 100755 --- a/modelopt/torch/export/model_utils.py +++ b/modelopt/torch/export/model_utils.py @@ -55,6 +55,7 @@ "Deepseek": "deepseek", "Whisper": "whisper", "gptoss": "gptoss", + "MiniMax": "minimax", } __doc__ = f"""Utility functions for model type detection and classification. diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 5703f4515..89154eb21 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -589,7 +589,9 @@ def _process_quantized_modules( if is_modelopt_qlora and (hasattr(sub_module, "base_layer")): continue - if hasattr(sub_module, "weight_packed"): + if hasattr(sub_module, "weight_packed") or ( + "QuantFP8Linear" in type(sub_module).__name__ and sub_module.weight.element_size() <= 1 + ): sub_module.unpack_weight() if get_quantization_format(sub_module) != QUANTIZATION_NONE: if is_quantlinear(sub_module): diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index a29d7c754..807c92c2c 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -22,6 +22,8 @@ from typing import TYPE_CHECKING import torch +import transformers +from packaging import version from torch import Tensor from torch.nn.functional import linear @@ -38,7 +40,6 @@ kitchen = None import torch.nn as nn -import transformers from transformers.models.t5.modeling_t5 import T5Attention from modelopt.torch.opt.dynamic import DynamicModule @@ -48,6 +49,13 @@ from ..conversion import register from ..nn import QuantInputBase, QuantModule, QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import _QuantLinear +from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE + +if IS_TRITON_AVAILABLE: + from ..triton import weight_dequant +else: + weight_dequant = None + from ..utils import replace_function from .attention import register_attention_for_kv_quant from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin @@ -57,6 +65,8 @@ __all__ = ["register_hf_attentions_on_the_fly"] +TRANSFORMERS_VERSION_GE_5_0 = version.parse(transformers.__version__) >= version.parse("5.0.0") + class _QuantAttention(QuantModule): """Attention class for KV Cache quantization compatible with new_attention_interface in transformers >= 4.48.0.""" @@ -447,10 +457,24 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # If any of the experts are in calibration mode, we will forward all tokens to all experts # This is used only for calibration, we need to re-calculate the actual outputs again using # the original top_k - original_top_k = self.top_k - self.top_k = self.num_experts - super().forward(hidden_states) - self.top_k = original_top_k + if TRANSFORMERS_VERSION_GE_5_0: + assert hasattr(self, "gate") + # Path for transformers >= 5.0 + original_top_k = self.gate.topk + self.gate.topk = self.gate.num_experts + super().forward(hidden_states) + self.gate.topk = original_top_k + else: + # Path for transformers < 5.0 + original_top_k = self.top_k + if hasattr(self, "num_experts"): + self.top_k = self.num_experts + elif hasattr(self, "experts"): + self.top_k = self.experts.num_experts + else: + raise ValueError(f"Could not find num_experts in module {self}") + super().forward(hidden_states) + self.top_k = original_top_k return super().forward(hidden_states) @@ -693,6 +717,53 @@ def unpack_weight(self): del self.weight_scale +class _QuantFP8Linear(QuantModule): + def _setup(self): + self.input_quantizer = TensorQuantizer() + self.weight_quantizer = TensorQuantizer() + assert self.weight_scale_inv.ndim == 2, "Weight scale inverse must be 2D" + assert self.weight.ndim == 2, "Weight must be 2D" + self.block_size = max( + self.weight.shape[0] // self.weight_scale_inv.shape[0], + self.weight.shape[1] // self.weight_scale_inv.shape[1], + ) + assert self.block_size == 128, "Block size must be 128" + + def _get_weight_and_scale_inv(self): + if isinstance(self.weight, torch.distributed.tensor.DTensor): + weight = self.weight._local_tensor.contiguous() + scale_inv = self.weight_scale_inv._local_tensor.contiguous() + else: + weight = self.weight.contiguous() + scale_inv = self.weight_scale_inv.contiguous() + return weight, scale_inv + + def forward(self, input: Tensor) -> Tensor: + assert weight_dequant is not None, "Triton is not available" + if self.weight.element_size() == 1: + with torch.cuda.device(self.weight.device): + weight, scale_inv = self._get_weight_and_scale_inv() + weight = weight_dequant(weight, scale_inv, self.block_size, dtype=input.dtype) + else: + weight = self.weight + return linear( + self.input_quantizer(input), + self.weight_quantizer(weight), + self.bias, + ) + + def unpack_weight(self): + assert weight_dequant is not None, "Triton is not available" + with torch.cuda.device(self.weight.device): + weight, scale_inv = self._get_weight_and_scale_inv() + self.weight = nn.Parameter( + weight_dequant(weight, scale_inv, self.block_size, dtype=torch.get_default_dtype()), + requires_grad=False, + ) + if hasattr(self, "weight_scale_inv"): + del self.weight_scale_inv + + try: from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe @@ -796,6 +867,14 @@ def unpack_weight(self): except ImportError: pass +try: + from transformers.integrations.finegrained_fp8 import FP8Linear + + if FP8Linear not in QuantModuleRegistry: + QuantModuleRegistry.register({FP8Linear: "hf.FP8Linear"})(_QuantFP8Linear) +except ImportError: + pass + class _QuantGptOssExperts(_QuantFunctionalMixin): """Quantized wrapper for `transformers.GptOssExperts`. @@ -910,6 +989,17 @@ def register_falcon_linears_on_the_fly(model): QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear) +def register_minimax_m2_moe_on_the_fly(model): + """Register MiniMax M2 MoE modules as a QUANT_MODULE. + + MiniMax M2 MoE modules are defined in the model card, so we need to register them on the fly. + """ + if type(model).__name__ in ["MiniMaxM2ForCausalLM"]: + moe_type = type(model.model.layers[0].block_sparse_moe) + if QuantModuleRegistry.get(moe_type) is None: + QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantSparseMoe) + + def _is_supported_hf_model(model): """Check if the model a valid model for transformers quantization specific support.""" supported_models = [transformers.PreTrainedModel] @@ -975,6 +1065,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model): [ register_falcon_linears_on_the_fly, register_dbrx_moe_on_the_fly, + register_minimax_m2_moe_on_the_fly, register_hf_attentions_on_the_fly, convert_hf_parallel_linears_on_the_fly, ] diff --git a/modelopt/torch/quantization/triton/__init__.py b/modelopt/torch/quantization/triton/__init__.py index 0af34b21f..def70e591 100644 --- a/modelopt/torch/quantization/triton/__init__.py +++ b/modelopt/torch/quantization/triton/__init__.py @@ -32,6 +32,7 @@ ): # fp4_kernel works on any CUDA GPU with triton from .fp4_kernel import * + from .fp8_kernel import * # fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv) if torch.cuda.get_device_capability() >= (8, 9): diff --git a/examples/deepseek/ds_kernel.py b/modelopt/torch/quantization/triton/fp8_kernel.py similarity index 75% rename from examples/deepseek/ds_kernel.py rename to modelopt/torch/quantization/triton/fp8_kernel.py index 00586acc2..0b3c93e3e 100644 --- a/examples/deepseek/ds_kernel.py +++ b/modelopt/torch/quantization/triton/fp8_kernel.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,32 +35,18 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +"""FP8 Triton Kernel Implementations.""" import torch import triton import triton.language as tl -"""Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py""" - @triton.jit def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): - """ - Dequantizes weights using the provided scaling factors and stores the result. + """Dequantizes weights using the provided scaling factors and stores the result. + + Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py Args: x_ptr (tl.pointer): Pointer to the quantized weights. @@ -86,14 +72,21 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): tl.store(y_ptr + offs, y, mask=mask) -def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: - """ - Dequantizes the given weight tensor using the provided scale tensor. +def weight_dequant( + x: torch.Tensor, + s: torch.Tensor, + block_size: int = 128, + dtype: torch.dtype = torch.get_default_dtype(), +) -> torch.Tensor: + """Dequantizes the given weight tensor using the provided scale tensor. + + Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py Args: x (torch.Tensor): The quantized weight tensor of shape (M, N). s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size). block_size (int, optional): The block size to use for dequantization. Defaults to 128. + dtype (torch.dtype, optional): The dtype of the output tensor. Defaults to torch.get_default_dtype(). Returns: torch.Tensor: The dequantized weight tensor of the same shape as `x`. @@ -104,7 +97,7 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous" assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" M, N = x.size() - y = torch.empty_like(x, dtype=torch.get_default_dtype()) + y = torch.empty_like(x, dtype=dtype) grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"])) weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) return y