Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
cf55efb
add support for gemma4 model
n1ck-guo Apr 3, 2026
8dfdca8
Merge branch 'main' into hengguo/support_for_gemma4
n1ck-guo Apr 3, 2026
8ab6ebe
try to support gemma4
wenhuach21 Apr 3, 2026
0cc631d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 3, 2026
6a6afcf
Update auto_round/compressors/base.py
wenhuach21 Apr 3, 2026
7f1af02
Update auto_round/compressors/base.py
wenhuach21 Apr 3, 2026
416797c
Update auto_round/compressors/base.py
wenhuach21 Apr 3, 2026
4638d8a
refine
wenhuach21 Apr 3, 2026
a8dd583
Update auto_round/special_model_handler.py
wenhuach21 Apr 3, 2026
8256d92
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 3, 2026
d3d1ed8
fix
wenhuach21 Apr 3, 2026
f2e5332
Merge branch 'support_gemma4' of https://github.com/intel/auto-round …
wenhuach21 Apr 3, 2026
222838e
support opt_rtn
wenhuach21 Apr 3, 2026
bc93840
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 3, 2026
771b003
update
wenhuach21 Apr 3, 2026
f5849bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 3, 2026
4c087b6
fix offload immediate_saving issue
lvliang-intel Apr 4, 2026
5b4ae05
Merge branch 'main' into support_gemma4
lvliang-intel Apr 4, 2026
a53ea86
merge pr1656
n1ck-guo Apr 8, 2026
52873d3
sync
n1ck-guo Apr 8, 2026
9218ad6
Merge remote-tracking branch 'origin/support_gemma4' into hengguo/sup…
n1ck-guo Apr 10, 2026
9987bba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2026
4f0d3fe
update
n1ck-guo Apr 10, 2026
d508e02
Merge remote-tracking branch 'origin/main' into hengguo/support_for_g…
n1ck-guo Apr 14, 2026
6f6a2d3
fix merge
n1ck-guo Apr 14, 2026
edea4e1
Merge branch 'hengguo/support_for_gemma4' of https://github.com/intel…
n1ck-guo Apr 14, 2026
a4b1302
Merge remote-tracking branch 'origin/main' into hengguo/support_for_g…
n1ck-guo Apr 17, 2026
4979f2d
fix by comment
n1ck-guo Apr 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 62 additions & 23 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
scheme_to_preset_name,
)
from auto_round.sign_sgd import SignSGD
from auto_round.special_model_handler import get_predefined_ignore_layers, update_module
from auto_round.special_model_handler import get_predefined_fixed_attr, get_predefined_ignore_layers, update_module
from auto_round.utils import (
INNER_SUPPORTED_LAYER_TYPES,
SUPPORTED_DTYPES,
Expand Down Expand Up @@ -568,6 +568,12 @@ def __init__(
self.hadamard_config = normalize_hadamard_config(hadamard_config, self.data_type)
self.model = apply_hadamard_transform(self.model, self.hadamard_config, data_type=self.data_type)

self.blocks_requiring_input_ids = []
self.has_variable_block_shape = False
fixed_attr = get_predefined_fixed_attr(self.model) or {}
for key, value in fixed_attr.items():
setattr(self, key, value)

def _gen_auto_scheme(self) -> dict[str, dict]:
if self.mllm:
logger.info("AutoScheme is not yet supported for multimodal LLMs.")
Expand Down Expand Up @@ -1515,17 +1521,20 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
if not all_blocks:
raise ValueError("Could not find any blocks. Check the model or quant_block_list.")

all_first_block_names = [block[0] for block in all_blocks]
if not self.has_variable_block_shape:
to_cache_block_names = [block[0] for block in all_blocks]
else:
to_cache_block_names = flatten_list(all_blocks)
layer_names = self._get_quantized_layer_names_outside_blocks()
if self.act_bits < 16 and (not self.act_dynamic or len(layer_names) > 0):
if self.act_bits < 16 and (not self.act_dynamic or len(layer_names) > 0) or self.has_variable_block_shape:
if len(layer_names) > 0:
logger.warning(
"quantize layers outside blocks for static activation quantizaiton"
" will significantly increase calibration time"
)
all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names)
all_inputs = self.try_cache_inter_data_gpucpu(to_cache_block_names, self.nsamples, layer_names)
else:
all_inputs = self.cache_inter_data(all_first_block_names, self.nsamples)
all_inputs = self.cache_inter_data(to_cache_block_names, self.nsamples)

# Clear hooks for multi-GPU setups
if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
Expand All @@ -1549,20 +1558,28 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
if total_samples < self.batch_size:
self.batch_size = total_samples
logger.warning(f"Forcing batch size to {total_samples}")
tmp_dtype = self.amp_dtype if self.amp else torch.float32

input_ids = to_device(inputs.pop("input_ids"), self.cache_device)
input_others = to_device(inputs, self.cache_device)

tmp_dtype = self.amp_dtype if self.amp else torch.float32
input_ids = [id_.to(tmp_dtype) for id_ in input_ids]

for key, val in input_others.items():
if isinstance(val, torch.Tensor) and val.dtype in (torch.float16, torch.bfloat16):
input_others[key] = val.to(tmp_dtype)
elif isinstance(val, list):
input_others[key] = [to_dtype(v, tmp_dtype) for v in val]
def process_input_others(input_others):

input_others = to_device(input_others, self.cache_device)
for key, val in input_others.items():
if isinstance(val, torch.Tensor) and val.dtype in (torch.float16, torch.bfloat16):
input_others[key] = val.to(tmp_dtype)
elif isinstance(val, list):
input_others[key] = [to_dtype(v, tmp_dtype) for v in val]
return input_others

input_others = inputs
input_others = process_input_others(input_others)
for block_name in block_names:
if block_name in all_inputs.keys():
input_others = all_inputs[block_name]
input_others = process_input_others(input_others)
all_inputs.pop(block_name)
pbar.set_description(f"Quantizing {block_name}")
block = get_module(self.model, block_name)
materialize_model_(block)
Expand Down Expand Up @@ -1813,22 +1830,25 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool:
self.model = self.model.to(self.amp_dtype)

layer_names = self._get_quantized_layer_names_outside_blocks()
all_first_block_names = [block[0] for block in all_blocks]
if not self.has_variable_block_shape:
to_cache_block_names = [block[0] for block in all_blocks]
else:
to_cache_block_names = flatten_list(all_blocks)
if len(layer_names) > 0:
logger.info(
"Starting to cache block inputs. This may be slow due to external block layers: %s", layer_names
)
else:
logger.info("start to cache block inputs")
all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names)
all_inputs = self.try_cache_inter_data_gpucpu(to_cache_block_names, self.nsamples, layer_names=layer_names)
is_quantized_embedding = self._quantize_embedding_layer()
clear_memory(device_list=self.device_list)
all_q_inputs = None
if is_quantized_embedding:
all_inputs = copy.deepcopy(self.inputs)
clear_memory(self.inputs, device_list=self.device_list)
all_q_inputs = self.try_cache_inter_data_gpucpu(
all_first_block_names, self.nsamples, layer_names=layer_names
to_cache_block_names, self.nsamples, layer_names=layer_names
)
# Remove accelerate dispatch hooks before moving parameters.
# hf_device_map is kept for reference but hooks are no longer needed.
Expand Down Expand Up @@ -1876,6 +1896,7 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool:
nblocks=self.nblocks,
device=self.device,
pbar=pbar,
input_others_extra_blocks=all_inputs,
)
if self.is_immediate_packing and len(self.formats) != 1:
raise ValueError(
Expand Down Expand Up @@ -2084,6 +2105,7 @@ def input_capture_hook(module, *args, **kwargs):
first_block_name = self.quant_block_list[0][0]

class _FakeDecodingLayer(torch.nn.Module):

def forward(self, *args, **kwargs):
return args, kwargs

Expand Down Expand Up @@ -2268,10 +2290,12 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
Raises:
Exception: If caching on GPU fails, switches to CPU and caches there.
"""
block_names = flatten_list(block_names)
if is_quantized_input_module(self.model):
layer_names = []
if layer_names is None:
layer_names = []
self.blocks_requiring_input_ids = [data if isinstance(data, str) else data[0] for data in block_names]

calibrate_on_cpu = False
cannot_calibrate_on_cpu = False
Expand All @@ -2280,6 +2304,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
and len(layer_names) == 0
and not self.has_qlayer_outside_block
and (last_cache_name is None or last_cache_name in block_names)
and not getattr(self, "mllm", False)
):
# low_gpu_mem_usage or calibrate only the embedding layer, which is also very fast on CPU
calibrate_on_cpu = True
Expand Down Expand Up @@ -2426,17 +2451,18 @@ def cache_inter_data(self, block_names, nsamples, layer_names=None, last_cache_n
if layer_names is None:
layer_names = []
self.inputs = {}
block_names = flatten_list(block_names)
self.to_cached_layers = block_names + layer_names

tmp_dtype = None # TODO delete this as most model is not fp32 now
## have bug if block name is not the first block
# There is a bug if the block name is not the first block
if (len(block_names) > 1 or len(layer_names) > 0) and self.low_gpu_mem_usage:
tmp_dtype = self.model.dtype
if self.amp:
if self.model.dtype != self.model.dtype:
self.model = self.model.to(torch.bfloat16)
else:
self.model = self.model.to(torch.float32) ##model on cpu
self.model = self.model.to(torch.float32) # model on cpu

self.last_cache_name = self._infer_last_cache_name(block_names, layer_names, last_cache_name)
self._cache_target_set = set(self.to_cached_layers)
Expand Down Expand Up @@ -2571,6 +2597,12 @@ def forward(m, hidden_states=None, *positional_inputs, **kwargs):
or isinstance(kwargs[key], list)
or isinstance(kwargs[key], tuple)
):
if (
self.has_variable_block_shape
and name not in self.blocks_requiring_input_ids
and key == "hidden_states"
):
continue
if key not in self.inputs[name].keys(): # initialization
data = to_device(kwargs[key], device=torch.device("cpu"))
if data is None or (self.batch_size > 1 and key in self.shared_cache_keys):
Expand Down Expand Up @@ -2655,6 +2687,7 @@ def _replace_forward(self):
self.hook_handles.append(hook_handle)

def _register_act_max_hook(self, model):

def get_act_max_hook(module, input, output):
if isinstance(input, (tuple, list)):
input = input[0]
Expand Down Expand Up @@ -3259,21 +3292,21 @@ def _quantize_block(
return None, output

def _split_inputs(self, inputs: dict, first_input_name: str) -> tuple[torch.Tensor, dict]:
input_ids = inputs[first_input_name]
input_ids = inputs.get(first_input_name, None)
inputs.pop(first_input_name, None)
input_others = inputs
return input_ids, input_others

def _preprocess_block_inputs(self, inputs, first_input_name="input_ids"):
input_ids, input_others = self._split_inputs(inputs, first_input_name)
clear_memory(device_list=self.device_list)
input_ids = to_device(input_ids, self.cache_device)
tmp_dtype = self.amp_dtype if self.amp else torch.float32
if input_ids is not None:
input_ids = to_device(input_ids, self.cache_device)
input_ids = to_dtype(input_ids, tmp_dtype)
input_others = to_device(input_others, self.cache_device)
# As in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage

tmp_dtype = self.amp_dtype if self.amp else torch.float32
input_ids = to_dtype(input_ids, tmp_dtype)

for key in input_others.keys():
if isinstance(input_others[key], torch.Tensor) and (
input_others[key].dtype == torch.float16 or input_others[key].dtype == torch.bfloat16
Expand All @@ -3293,6 +3326,7 @@ def _quantize_blocks(
nblocks: int = 1,
device: str = "cpu",
pbar: tqdm = None,
input_others_extra_blocks: dict = None,
):
"""Quantize and dequantize the weights of the specified blocks in the model.

Expand All @@ -3316,12 +3350,17 @@ def _quantize_blocks(
pbar = tqdm(range(0, len(block_names), nblocks))

for i in range(0, len(block_names), nblocks):
if input_others_extra_blocks and block_names[i] in input_others_extra_blocks:
input_others = input_others_extra_blocks[block_names[i]]
_, input_others = self._preprocess_block_inputs(input_others)
input_others_extra_blocks.pop(block_names[i])
if i != 0:
pbar.update(1)
if nblocks == 1:
n = block_names[i]
pbar.set_description(f"Quantizing {n}")
m = get_module(model, n)

else:
names = block_names[i : min(i + nblocks, len(block_names))]
pbar.set_description(f"Quantizing [{i + 1}-{min(i + nblocks, len(block_names))}]/{len(block_names)}")
Expand Down
5 changes: 4 additions & 1 deletion auto_round/compressors/mllm/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,10 @@ def __init__(
)
dataset = "liuhaotian/llava_conv_58k"
elif self.template is not None and not _only_text_test(
model, tokenizer, self.device, self.template.model_type
model,
tokenizer,
self.device,
getattr(getattr(model, "config", None), "model_type", None) or self.template.model_type,
):
logger.warning(
f"{model.config.model_type} does not support for {dataset},"
Expand Down
24 changes: 22 additions & 2 deletions auto_round/compressors/shard_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from collections import OrderedDict

import torch
from torch.nn import Parameter

from auto_round.logger import logger
from auto_round.utils import get_lm_head_name, get_module
Expand Down Expand Up @@ -196,7 +197,26 @@ def _offload_to_meta(self, saved_params):
and isinstance(module, torch.nn.Module)
and all(f"{module_path}.{k}" in self._all_saved for k in module.state_dict().keys())
):
module.to("meta")
self._move_module_to_meta(module)

def _move_module_to_meta(self, module: torch.nn.Module):
for child in module.children():
self._move_module_to_meta(child)

for name, param in list(module._parameters.items()):
if param is None:
continue
if isinstance(param, Parameter):
meta_param = Parameter(param.detach().to(device="meta"), requires_grad=param.requires_grad)
elif isinstance(param, torch.Tensor):
meta_param = Parameter(param.detach().to(device="meta"), requires_grad=param.requires_grad)
else:
continue
module._parameters[name] = meta_param

for name, buffer in list(module._buffers.items()):
if isinstance(buffer, torch.Tensor):
module._buffers[name] = buffer.detach().to(device="meta")

def finalize(self):
"""Saves remaining weights, renames files, and writes the index JSON."""
Expand Down Expand Up @@ -225,7 +245,7 @@ def finalize(self):
layer_name = ".".join(pname.split(".")[:-1])
if self.lm_head_name is not None and layer_name == self.lm_head_name and tie_word_embeddings:
lm_head_module = get_module(self.model, self.lm_head_name)
lm_head_module.to("meta") # Must to meta, otherwise model's saver will dump it again
self._move_module_to_meta(lm_head_module) # Must to meta, otherwise model's saver will dump it again
continue
self._add_tensor(pname, tensor.detach().to("cpu"))

Expand Down
Loading
Loading