Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions auto_round/special_model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,16 @@ def get_glm_flash_ignore_layers(model) -> list[str]:
],
)

# Kimi K2.5 (multimodal): keep the vision encoder and the multimodal projector
# out of quantization by registering them as ignored layers.
# NOTE(review): module names "vision_tower" / "mm_projector" are presumably the
# checkpoint's attribute names — confirm against the actual model definition.
register_ignore_layers(
    matchers=[
        # mode="full" — match the model_type string exactly, not as a prefix/substring.
        ModelTypeMatcher(r"kimi_k25", mode="full"),
    ],
    ignore_layers=[
        "vision_tower",
        "mm_projector",
    ],
)


def get_predefined_ignore_layers(model: torch.nn.Module) -> list[str]:
layers = []
Expand Down
2 changes: 1 addition & 1 deletion auto_round/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def __getitem__(self, key):
SUPPORTED_FORMATS = SupportedFormats()
SUPPORTED_LAYER_TYPES = (torch.nn.Linear, transformers.pytorch_utils.Conv1D)
# Changed to str as it relies on triton or others lib to load this
INNER_SUPPORTED_LAYER_TYPES = ("FP8Linear",)
INNER_SUPPORTED_LAYER_TYPES = ("FP8Linear", "CompressedLinear")
# transformers.integrations.finegrained_fp8.FP8Linear
if deepspeed_exists:
from deepspeed.module_inject import LinearAllreduce, LinearLayer
Expand Down
115 changes: 114 additions & 1 deletion auto_round/utils/weight_handler.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
- MXFP8Handler: CompressedLinear with MXFP8PackedCompressor
- MXFP4Handler: CompressedLinear with MXFP4PackedCompressor
- NVFP4Handler: CompressedLinear with NVFP4PackedCompressor
- WOQHandler: CompressedLinear with weight-only quantization

Quick Start Guide:
Usage - Detect and Convert:
Expand Down Expand Up @@ -158,6 +159,7 @@ class ModuleWeightType(Enum):
MXFP8 = auto() # MX FP8 (CompressedLinear with MXFP8PackedCompressor)
MXFP4 = auto() # MX FP4 (CompressedLinear with MXFP4PackedCompressor)
NVFP4 = auto() # NV FP4 (CompressedLinear with NVFP4PackedCompressor)
WOQ = auto() # Weight-Only Quantization (CompressedLinear with weight-only quantization)


class WeightTypeHandler(ABC):
Expand All @@ -180,6 +182,15 @@ def detect_layer(self, module: torch.nn.Module) -> bool:
"""
pass

def attach_weight_shape(self, module: torch.nn.Module):
Comment thread
yiliu30 marked this conversation as resolved.
"""Optional helper to attach weight shape information to the module for detection."""
if not hasattr(module, "weight") or module.weight is None:
module.weight = torch.empty(
module.out_features,
module.in_features,
device="meta",
)

@abstractmethod
def convert_layer(
self,
Expand Down Expand Up @@ -297,6 +308,7 @@ def check_and_mark_quantized_module(model: torch.nn.Module) -> Set[ModuleWeightT
for weight_type, handler in _WEIGHT_TYPE_HANDLERS.items():
# Check model itself first
if handler.detect_layer(model):
handler.attach_weight_shape(model)
Comment thread
xin3he marked this conversation as resolved.
model._is_quantized_input_module = True
model.quantized_weight_type = weight_type
detected_types.add(weight_type)
Expand All @@ -305,6 +317,7 @@ def check_and_mark_quantized_module(model: torch.nn.Module) -> Set[ModuleWeightT
for n, m in model.named_modules():
# Use handler to detect based on actual characteristics
if handler.detect_layer(m):
handler.attach_weight_shape(m)
# Mark the layer itself
m.quantized_weight_type = weight_type
# for gguf format, gguf format need to mark the quantized input module
Expand All @@ -314,6 +327,9 @@ def check_and_mark_quantized_module(model: torch.nn.Module) -> Set[ModuleWeightT
# Record detected types
detected_types.add(weight_type)

# remove decompress_hook for CT models
if hasattr(model, "ct_decompress_hook"):
model.ct_decompress_hook.remove()
return detected_types


Expand Down Expand Up @@ -341,6 +357,24 @@ def is_quantized_input_module(model: torch.nn.Module) -> Optional[ModuleWeightTy
return None


def remove_existed_quantization_config(model: torch.nn.Module):
    """Strip any pre-existing ``quantization_config`` from the model's config tree.

    This prevents conflicts during conversion for models whose top-level config,
    or any of its sub-configs (``text_config``, ``vision_config``, ...), carries a
    ``quantization_config`` attribute: each such attribute found is deleted.
    """
    config = getattr(model, "config", None)
    if config is None:
        return
    if hasattr(config, "quantization_config"):
        delattr(config, "quantization_config")
    # Sub-configs are discovered by name: any attribute containing "config".
    for attr_name in dir(config):
        if "config" not in attr_name:
            continue
        sub_config = getattr(config, attr_name)
        if hasattr(sub_config, "quantization_config"):
            delattr(sub_config, "quantization_config")


# --- Main Conversion Function ---
def convert_module_to_hp_if_necessary(
model_or_layer: torch.nn.Module,
Expand All @@ -367,10 +401,22 @@ def convert_module_to_hp_if_necessary(
from auto_round.utils.device import clear_memory
from auto_round.utils.model import set_module

def _sync_serialization_attrs(src_module: torch.nn.Module, dst_module: torch.nn.Module) -> None:
"""Copy serialization-related attributes from source to destination module."""
from auto_round.compressors.base import SERIALIZATION_KEYS

orig_module_keys = list(SERIALIZATION_KEYS) + ["global_name"]
for key in orig_module_keys:
if hasattr(src_module, key):
setattr(dst_module, key, getattr(src_module, key))

remove_existed_quantization_config(model_or_layer)
# Check if it's a single quantized layer (has the attribute directly)
if hasattr(model_or_layer, "quantized_weight_type") and model_or_layer.quantized_weight_type is not None:
handler = get_handler(model_or_layer.quantized_weight_type)
return handler.convert_layer(model_or_layer, dtype, device, to_cpu)
new_module = handler.convert_layer(model_or_layer, dtype, device, to_cpu)
_sync_serialization_attrs(model_or_layer, new_module)
return new_module

# Otherwise, traverse model and convert all quantized layers
# Get handler for each layer to support mixed quantization types
Expand All @@ -379,6 +425,8 @@ def convert_module_to_hp_if_necessary(
if hasattr(m, "quantized_weight_type") and m.quantized_weight_type is not None:
handler = get_handler(m.quantized_weight_type)
new_module = handler.convert_layer(m, dtype, device, to_cpu)
_sync_serialization_attrs(m, new_module)
new_module.quantized_weight_type = None # Clear quantized type after conversion
set_module(model_or_layer, n, new_module)
cnt += 1
if cnt % 10 == 0:
Expand Down Expand Up @@ -819,3 +867,68 @@ def convert_layer(
new_layer = new_layer.to("cpu")

return new_layer


# ----------------------------------------------------------------------------
# WOQ Handler - CompressedLinear with weight-only quantization
# ----------------------------------------------------------------------------


@register_weight_type_handler(ModuleWeightType.WOQ)
class WOQHandler(WeightTypeHandler):
    """Handler for 2/4/8-bit integer weight-only quantized layers (compressed-tensors only)."""

    @staticmethod
    def _is_int_weight_only(q_scheme) -> bool:
        """Return True when the scheme is 2/4/8-bit integer weight-only quantization.

        Guards against schemes whose ``weights`` field is None (e.g. activation-only
        quantization), which would otherwise raise AttributeError.
        """
        weights = getattr(q_scheme, "weights", None)
        return (
            weights is not None
            and weights.num_bits in (2, 4, 8)
            and weights.type == "int"
            and q_scheme.input_activations is None
        )

    def detect_layer(self, module: torch.nn.Module) -> bool:
        """Check whether *module* is an int weight-only quantized layer.

        Two forms are recognized:
        1. a compressed-tensors ``CompressedLinear`` wrapper carrying a compressor;
        2. a plain module tagged with a ``quantization_scheme`` whose status is
           ``"compressed"`` (per ``compressed_tensors``' own helper).
        """
        # Both detection paths require a quantization scheme; the original code
        # assumed it was present on CompressedLinear, which could AttributeError.
        q_scheme = getattr(module, "quantization_scheme", None)
        if q_scheme is None:
            return False

        # Case 1: CompressedLinear wrapper with an attached compressor.
        if module.__class__.__name__ == "CompressedLinear" and getattr(module, "compressor", None) is not None:
            if self._is_int_weight_only(q_scheme):
                return True

        # Case 2: compressed-tensors-tagged module in "compressed" status.
        from compressed_tensors.quantization.utils import is_module_quantized  # pylint: disable=E0401

        if is_module_quantized(module) and module.quantization_status.value == "compressed":
            return self._is_int_weight_only(q_scheme)
        return False

    def convert_layer(
        self,
        layer: torch.nn.Module,
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cpu",
        to_cpu: bool = False,
    ) -> torch.nn.Module:
        """Convert an integer weight-only quantized layer to a dense ``nn.Linear``.

        Args:
            layer: the quantized layer (CompressedLinear, or a tagged plain Linear).
            dtype: target dtype of the dequantized weights.
            device: kept for interface parity with other handlers (unused here).
            to_cpu: move the resulting layer to CPU before returning.

        Returns:
            A standard ``torch.nn.Linear`` with dequantized weights, or the input
            layer decompressed in place when it is already a plain ``Linear``.
        """
        # Plain Linear tagged with a scheme: decompress in place and return as-is.
        if hasattr(layer, "quantization_scheme") and layer.__class__.__name__ == "Linear":
            from compressed_tensors.compressors.base import decompress_module  # pylint: disable=E0401

            decompress_module(layer)
            return layer

        new_layer = torch.nn.Linear(layer.in_features, layer.out_features, bias=layer.bias is not None, dtype=dtype)
        if layer.bias is not None:
            new_layer.bias.data.copy_(layer.bias.data.to(dtype=dtype))

        # Dequantize the packed weights through the layer's own compressor.
        dq_weight = layer.compressor.decompress_module(layer)
        new_layer.weight.data.copy_(dq_weight.to(dtype=dtype))

        # Free intermediate tensors to avoid memory buildup; parking the old
        # layer on "meta" releases its packed-weight storage.
        del dq_weight
        layer.to("meta")

        if to_cpu:
            new_layer = new_layer.to("cpu")

        return new_layer
33 changes: 33 additions & 0 deletions test/test_cpu/advanced/test_low_precision_input_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest
import torch
import transformers
from packaging import version

from auto_round import AutoRound
from auto_round.utils.weight_handler import (
ModuleWeightType,
check_and_mark_quantized_module,
Expand All @@ -15,6 +17,7 @@ class TestCompressedTensor:
nvfp4_model_path = "kaitchup/Qwen3-0.6B-NVFP4"
mxfp4_model_path = "QuixiAI/Llama-3.2-1B-MXFP4"
fp8_block_model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK"
w4a16_model_path = "RedHatAI/Qwen3-0.6B-quantized.w4a16"

def test_fp8_block(self):
model = get_tiny_model(get_model_path(self.fp8_block_model_path))
Expand Down Expand Up @@ -67,3 +70,33 @@ def test_mxfp4(self):
assert (
model.model.layers[0].mlp.up_proj.weight.dtype == torch.bfloat16
), "CompressedLinear layer was not converted to Linear"

def test_w4a16(self):
model = get_tiny_model(get_model_path(self.w4a16_model_path))
assert (
model.model.layers[0].mlp.up_proj.weight_packed.dtype == torch.int32
), "Original weight is not in INT4 format"
assert hasattr(
model.model.layers[0].mlp.up_proj, "quantization_scheme"
), "Model does not contain CompressedLinear layers"
detected_types = check_and_mark_quantized_module(model)
Comment thread
xin3he marked this conversation as resolved.
assert ModuleWeightType.WOQ in detected_types
model = convert_module_to_hp_if_necessary(model)
assert (
model.model.layers[0].mlp.up_proj.weight.dtype == torch.bfloat16
), "CompressedLinear layer was not converted to Linear"

def test_w4a16_to_mxfp4(self, tmp_path):
model = get_tiny_model(get_model_path(self.w4a16_model_path))
model.config.name_or_path = None # Clear the name_or_path to avoid MTP copying issues
tokenizer = transformers.AutoTokenizer.from_pretrained(self.w4a16_model_path)
ar = AutoRound(
model,
tokenizer=tokenizer,
scheme="MXFP4",
iters=2,
nsamples=2,
)
ar.quantize_and_save(tmp_path, format="llm_compressor")
model = transformers.AutoModelForCausalLM.from_pretrained(tmp_path)
assert model, "Failed to load the quantized model"
Loading