Support MiniMax M2.1 (FP8 checkpoint) #817
Changes from all commits: 0bb9291, 3b95a19, d45fd45, 494f5b4, 1881b8a
@@ -22,6 +22,8 @@
from typing import TYPE_CHECKING

import torch
import transformers
from packaging import version
from torch import Tensor
from torch.nn.functional import linear
@@ -38,7 +40,6 @@
    kitchen = None

import torch.nn as nn
import transformers
from transformers.models.t5.modeling_t5 import T5Attention

from modelopt.torch.opt.dynamic import DynamicModule
@@ -48,6 +49,13 @@
from ..conversion import register
from ..nn import QuantInputBase, QuantModule, QuantModuleRegistry, TensorQuantizer
from ..nn.modules.quant_linear import _QuantLinear
from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE

if IS_TRITON_AVAILABLE:
    from ..triton import weight_dequant
else:
    weight_dequant = None

from ..utils import replace_function
from .attention import register_attention_for_kv_quant
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin
@@ -57,6 +65,8 @@
__all__ = ["register_hf_attentions_on_the_fly"]

TRANSFORMERS_VERSION_GE_5_0 = version.parse(transformers.__version__) >= version.parse("5.0.0")


class _QuantAttention(QuantModule):
    """Attention class for KV Cache quantization compatible with new_attention_interface in transformers >= 4.48.0."""
@@ -447,10 +457,24 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             # If any of the experts are in calibration mode, we will forward all tokens to all experts
             # This is used only for calibration, we need to re-calculate the actual outputs again using
             # the original top_k
-            original_top_k = self.top_k
-            self.top_k = self.num_experts
-            super().forward(hidden_states)
-            self.top_k = original_top_k
+            if TRANSFORMERS_VERSION_GE_5_0:
+                assert hasattr(self, "gate")
+                # Path for transformers >= 5.0
+                original_top_k = self.gate.topk
+                self.gate.topk = self.gate.num_experts
+                super().forward(hidden_states)
+                self.gate.topk = original_top_k
+            else:
+                # Path for transformers < 5.0
+                original_top_k = self.top_k
+                if hasattr(self, "num_experts"):
+                    self.top_k = self.num_experts
+                elif hasattr(self, "experts"):
+                    self.top_k = self.experts.num_experts
+                else:
+                    raise ValueError(f"Could not find num_experts in module {self}")
+                super().forward(hidden_states)
+                self.top_k = original_top_k
         return super().forward(hidden_states)
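The comments in this hunk describe the calibration trick: while experts are being calibrated, top_k is temporarily raised to num_experts so every expert's quantizers observe activation statistics, and the real output is then recomputed with the original routing. A toy illustration of that pattern follows; the module and field names here are made up for the sketch, not the classes touched by this PR.

    import torch
    import torch.nn as nn

    class ToyMoE(nn.Module):
        """Toy top-k MoE used only to illustrate the calibration-time routing override."""

        def __init__(self, num_experts=4, top_k=2, dim=8):
            super().__init__()
            self.num_experts, self.top_k = num_experts, top_k
            self.router = nn.Linear(dim, num_experts)
            self.experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_experts))

        def forward(self, x):
            # Pick the top_k experts per token; only those experts see the token.
            topk = self.router(x).topk(self.top_k, dim=-1)
            out = torch.zeros_like(x)
            for i, expert in enumerate(self.experts):
                mask = (topk.indices == i).any(dim=-1)
                if mask.any():
                    out[mask] += expert(x[mask])
            return out

    moe = ToyMoE()
    x = torch.randn(16, 8)

    # Calibration pass: route every token to every expert so each expert's
    # quantizers observe data, then restore the original routing.
    original_top_k = moe.top_k
    moe.top_k = moe.num_experts
    moe(x)
    moe.top_k = original_top_k

    out = moe(x)  # actual output, recomputed with the original top_k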
@@ -693,6 +717,53 @@ def unpack_weight(self):
            del self.weight_scale


class _QuantFP8Linear(QuantModule):
    def _setup(self):
        self.input_quantizer = TensorQuantizer()
        self.weight_quantizer = TensorQuantizer()
        assert self.weight_scale_inv.ndim == 2, "Weight scale inverse must be 2D"
        assert self.weight.ndim == 2, "Weight must be 2D"
        self.block_size = max(
            self.weight.shape[0] // self.weight_scale_inv.shape[0],
            self.weight.shape[1] // self.weight_scale_inv.shape[1],
        )
        assert self.block_size == 128, "Block size must be 128"

    def _get_weight_and_scale_inv(self):
        if isinstance(self.weight, torch.distributed.tensor.DTensor):
            weight = self.weight._local_tensor.contiguous()
            scale_inv = self.weight_scale_inv._local_tensor.contiguous()
        else:
            weight = self.weight.contiguous()
            scale_inv = self.weight_scale_inv.contiguous()
        return weight, scale_inv

    def forward(self, input: Tensor) -> Tensor:
        assert weight_dequant is not None, "Triton is not available"
        if self.weight.element_size() == 1:
            with torch.cuda.device(self.weight.device):
                weight, scale_inv = self._get_weight_and_scale_inv()
                weight = weight_dequant(weight, scale_inv, self.block_size, dtype=input.dtype)
        else:
            weight = self.weight
        return linear(
            self.input_quantizer(input),
            self.weight_quantizer(weight),
            self.bias,
        )

    def unpack_weight(self):
        assert weight_dequant is not None, "Triton is not available"
        with torch.cuda.device(self.weight.device):
            weight, scale_inv = self._get_weight_and_scale_inv()
            self.weight = nn.Parameter(
                weight_dequant(weight, scale_inv, self.block_size, dtype=torch.get_default_dtype()),
                requires_grad=False,
            )
        if hasattr(self, "weight_scale_inv"):
            del self.weight_scale_inv
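For context, here is a rough pure-PyTorch sketch of what weight_dequant is assumed to compute for this layout (the real implementation is the Triton kernel imported above): each 128x128 block of the FP8 weight is multiplied by its per-block scale from weight_scale_inv. This reference is an illustration under that assumption, not code from the PR.

    import torch

    def block_dequant_reference(weight_fp8: torch.Tensor,
                                scale_inv: torch.Tensor,
                                block_size: int = 128,
                                dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
        """Dequantize a block-scaled FP8 weight: out[i, j] = w[i, j] * s[i // bs, j // bs]."""
        # weight_fp8: (out_features, in_features); scale_inv: (ceil(out/bs), ceil(in/bs))
        scale = scale_inv.repeat_interleave(block_size, dim=0)[: weight_fp8.shape[0]]
        scale = scale.repeat_interleave(block_size, dim=1)[:, : weight_fp8.shape[1]]
        return weight_fp8.to(dtype) * scale.to(dtype)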
Comment on lines 755 to 764

Contributor:
Consider specifying a dtype in `unpack_weight`. If this is intentional for export, a comment would clarify the design. Otherwise, consider accepting/storing a target dtype to ensure consistency.

Suggested improvement:

    def unpack_weight(self):
        with torch.cuda.device(self.weight.device):
            weight, scale_inv = self._get_weight_and_scale_inv()
            self.weight = nn.Parameter(
-               weight_dequant(weight, scale_inv, self.block_size),
+               weight_dequant(weight, scale_inv, self.block_size, dtype=torch.bfloat16),
                requires_grad=False,
            )

Or store the original dtype during `_setup`.
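As background on the reviewer's point, torch.get_default_dtype() (used in the committed version of unpack_weight above) follows whatever default the surrounding code has set, so the unpacked weight's dtype can vary by call site:

    import torch

    print(torch.get_default_dtype())        # torch.float32 unless changed elsewhere
    torch.set_default_dtype(torch.bfloat16)
    print(torch.get_default_dtype())        # torch.bfloat16; unpack_weight would now produce bf16
    torch.set_default_dtype(torch.float32)  # restore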
try:
    from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe

@@ -796,6 +867,14 @@ def unpack_weight(self):
except ImportError:
    pass

try:
    from transformers.integrations.finegrained_fp8 import FP8Linear

    if FP8Linear not in QuantModuleRegistry:
        QuantModuleRegistry.register({FP8Linear: "hf.FP8Linear"})(_QuantFP8Linear)
except ImportError:
    pass
class _QuantGptOssExperts(_QuantFunctionalMixin):
    """Quantized wrapper for `transformers.GptOssExperts`.

@@ -910,6 +989,17 @@ def register_falcon_linears_on_the_fly(model):
        QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear)
def register_minimax_m2_moe_on_the_fly(model):
    """Register MiniMax M2 MoE modules as a QUANT_MODULE.

    MiniMax M2 MoE modules are defined in the model card, so we need to register them on the fly.
    """
    if type(model).__name__ in ["MiniMaxM2ForCausalLM"]:
        moe_type = type(model.model.layers[0].block_sparse_moe)
        if QuantModuleRegistry.get(moe_type) is None:
            QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantSparseMoe)

Contributor:
Does the latest HF transformers not support MiniMax M2 MoE?

Collaborator (Author):
It requires transformers 5.0.x.
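A hedged end-to-end sketch of how this registration would typically be exercised. The checkpoint name, quantization config, and calibration text below are assumptions for illustration, not taken from this PR; per the thread above, transformers >= 5.0 is required.

    import modelopt.torch.quantization as mtq
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "MiniMaxAI/MiniMax-M2.1"  # hypothetical FP8 checkpoint id
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    def forward_loop(m):
        # Tiny calibration loop; real calibration would iterate over a dataset.
        inputs = tokenizer("Hello, MiniMax!", return_tensors="pt").to(m.device)
        m(**inputs)

    # register_minimax_m2_moe_on_the_fly (with the other on-the-fly hooks listed below)
    # is applied during conversion, so the model's block_sparse_moe modules pick up the
    # _QuantSparseMoe wrapper before calibration.
    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)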
def _is_supported_hf_model(model):
    """Check if the model a valid model for transformers quantization specific support."""
    supported_models = [transformers.PreTrainedModel]

@@ -975,6 +1065,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
    [
        register_falcon_linears_on_the_fly,
        register_dbrx_moe_on_the_fly,
        register_minimax_m2_moe_on_the_fly,
        register_hf_attentions_on_the_fly,
        convert_hf_parallel_linears_on_the_fly,
    ]