Description
In transformers v4, MoE experts were implemented as an nn.ModuleList of nn.Linear layers. Starting with transformers v5, newer, more efficient implementations (likely optimized for grouped GEMM or custom CUDA kernels) store all expert weights in a single fused nn.Parameter tensor (e.g., with shape [num_experts, in_features, out_features]) instead of separate nn.Linear modules.
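For context, the structural difference looks roughly like this (a minimal sketch; the class names are illustrative, not the actual transformers implementations). Since bitsandbytes quantizes by replacing nn.Linear modules, the v4-style layout gets quantized automatically, while the fused 3D parameter in the v5-style layout is not touched:

import torch
import torch.nn as nn

class ExpertsV4Style(nn.Module):
    """v4-style: one nn.Linear per expert, wrapped in an nn.ModuleList."""
    def __init__(self, num_experts, hidden_dim, intermediate_dim):
        super().__init__()
        self.experts = nn.ModuleList(
            [nn.Linear(hidden_dim, intermediate_dim, bias=False) for _ in range(num_experts)]
        )

class ExpertsV5Style(nn.Module):
    """v5-style: all expert weights fused into a single 3D nn.Parameter."""
    def __init__(self, num_experts, hidden_dim, intermediate_dim):
        super().__init__()
        # shape: [num_experts, out_features, in_features]
        self.weight = nn.Parameter(torch.empty(num_experts, intermediate_dim, hidden_dim))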
One typical model is Qwen3MoeForCausalLM, with its expert implementation shown below:
@use_experts_implementation
class Qwen3MoeExperts(nn.Module):
    """Collection of expert weights stored as 3D tensors."""

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.moe_intermediate_size
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        self.act_fn = ACT2FN[config.hidden_act]

It's worth noting that use_experts_implementation allows switching the MoE implementation to the more efficient group_gemm version. The bitsandbytes library may also want to consider implementing a grouped forward pass to improve speed on MoE models.
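For illustration, a grouped forward over the fused 3D weights could look roughly like the sketch below. This is not the actual transformers group_gemm kernel; it reuses the gate_up_proj, down_proj, and act_fn names from the class above and omits the token routing/gather step:

import torch

def grouped_expert_forward(hidden_states, gate_up_proj, down_proj, act_fn):
    # hidden_states: [num_experts, tokens_per_expert, hidden_dim]
    # gate_up_proj:  [num_experts, 2 * intermediate_dim, hidden_dim]
    # down_proj:     [num_experts, hidden_dim, intermediate_dim]
    gate_up = torch.bmm(hidden_states, gate_up_proj.transpose(1, 2))
    gate, up = gate_up.chunk(2, dim=-1)
    # [num_experts, tokens_per_expert, hidden_dim]
    return torch.bmm(act_fn(gate) * up, down_proj.transpose(1, 2))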
Reproduction
transformers 5.0.0rc3 should be used.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
model_path = "Qwen/Qwen3-30B-A3B"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
print(f"Memory: {model.get_memory_footprint() / 1024**3:.2f} GB")

This prints Memory: 55.60 GB, showing that quantization is not working correctly.
Expected behavior
With transformers v4, the same script prints Memory: 15.09 GB, showing the model is properly quantized.
Fix proposal
I've opened a feature request in transformers to address the root cause of this issue: huggingface/transformers#43472.
Instead of handling raw parameters, I proposed introducing a standardized BatchLinear module. If adopted, downstream libraries would only need to support replacing this specific module type, ensuring compatibility for Qwen3-MoE, DeepSeek, and future MoE architectures without model-specific hacks.
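A hypothetical sketch of what such a BatchLinear module could look like is shown below; the name and shapes follow the idea in the RFC, but this is not an existing transformers or bitsandbytes API:

import torch
import torch.nn as nn

class BatchLinear(nn.Module):
    """A linear layer applied independently per expert over a batched 3D weight tensor."""
    def __init__(self, num_experts, in_features, out_features, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_experts, out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(num_experts, out_features)) if bias else None

    def forward(self, x):
        # x: [num_experts, tokens_per_expert, in_features]
        out = torch.bmm(x, self.weight.transpose(1, 2))
        if self.bias is not None:
            out = out + self.bias.unsqueeze(1)
        return out

A quantization backend would then only need to recognize and replace this one module type, much as it currently replaces nn.Linear, rather than special-casing each model's raw expert parameters.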
Upvoting or commenting on that RFC would help prioritize a unified solution and significantly reduce the maintenance burden for bitsandbytes.