Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 56 additions & 23 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
preset_name_to_scheme,
)
from auto_round.sign_sgd import SignSGD
from auto_round.special_model_handler import get_predefined_ignore_layers, update_module
from auto_round.special_model_handler import get_predefined_fixed_attr, get_predefined_ignore_layers, update_module
from auto_round.utils import (
INNER_SUPPORTED_LAYER_TYPES,
SUPPORTED_DTYPES,
Expand Down Expand Up @@ -569,6 +569,14 @@ def __init__(
)

self.hadamard_config = normalize_hadamard_config(hadamard_config)
self.has_variable_block_shape = False
all_blocks = self.quant_block_list if self.quant_block_list else get_block_names(self.model)
if not all_blocks:
raise ValueError("Could not find any blocks. Check the model or quant_block_list.")
self.blocks_requiring_input_ids = [data if isinstance(data, str) else data[0] for data in all_blocks]
fixed_attr = get_predefined_fixed_attr(self.model) or {}
for key, value in fixed_attr.items():
setattr(self, key, value)

def _gen_auto_scheme(self) -> dict[str, dict]:
if self.mllm:
Expand Down Expand Up @@ -1517,17 +1525,20 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
if not all_blocks:
raise ValueError("Could not find any blocks. Check the model or quant_block_list.")

all_first_block_names = [block[0] for block in all_blocks]
if not self.has_variable_block_shape:
to_cache_block_names = [block[0] for block in all_blocks]
else:
to_cache_block_names = flatten_list(all_blocks)
layer_names = self._get_quantized_layer_names_outside_blocks()
if self.act_bits < 16 and (not self.act_dynamic or len(layer_names) > 0):
if self.act_bits < 16 and (not self.act_dynamic or len(layer_names) > 0) or self.has_variable_block_shape:
if len(layer_names) > 0:
logger.warning(
"quantize layers outside blocks for static activation quantizaiton"
" will significantly increase calibration time"
)
all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names)
all_inputs = self.try_cache_inter_data_gpucpu(to_cache_block_names, self.nsamples, layer_names)
else:
all_inputs = self.cache_inter_data(all_first_block_names, self.nsamples)
all_inputs = self.cache_inter_data(to_cache_block_names, self.nsamples)

# Clear hooks for multi-GPU setups
if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
Expand All @@ -1551,20 +1562,28 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
if total_samples < self.batch_size:
self.batch_size = total_samples
logger.warning(f"Forcing batch size to {total_samples}")
tmp_dtype = self.amp_dtype if self.amp else torch.float32

input_ids = to_device(inputs.pop("input_ids"), self.cache_device)
input_others = to_device(inputs, self.cache_device)

tmp_dtype = self.amp_dtype if self.amp else torch.float32
input_ids = [id_.to(tmp_dtype) for id_ in input_ids]

for key, val in input_others.items():
if isinstance(val, torch.Tensor) and val.dtype in (torch.float16, torch.bfloat16):
input_others[key] = val.to(tmp_dtype)
elif isinstance(val, list):
input_others[key] = [to_dtype(v, tmp_dtype) for v in val]
def process_input_others(input_others):

input_others = to_device(input_others, self.cache_device)
for key, val in input_others.items():
if isinstance(val, torch.Tensor) and val.dtype in (torch.float16, torch.bfloat16):
input_others[key] = val.to(tmp_dtype)
elif isinstance(val, list):
input_others[key] = [to_dtype(v, tmp_dtype) for v in val]
return input_others

input_others = inputs
input_others = process_input_others(input_others)
for block_name in block_names:
if block_name in all_inputs.keys():
input_others = all_inputs[block_name]
input_others = process_input_others(input_others)
all_inputs.pop(block_name)
pbar.set_description(f"Quantizing {block_name}")
block = get_module(self.model, block_name)
materialize_model_(block)
Expand Down Expand Up @@ -1809,22 +1828,25 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool:
self.model = self.model.to(self.amp_dtype)

layer_names = self._get_quantized_layer_names_outside_blocks()
all_first_block_names = [block[0] for block in all_blocks]
if not self.has_variable_block_shape:
to_cache_block_names = [block[0] for block in all_blocks]
else:
to_cache_block_names = flatten_list(all_blocks)
if len(layer_names) > 0:
logger.info(
"Starting to cache block inputs. This may be slow due to external block layers: %s", layer_names
)
else:
logger.info("start to cache block inputs")
all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names)
all_inputs = self.try_cache_inter_data_gpucpu(to_cache_block_names, self.nsamples, layer_names=layer_names)
is_quantized_embedding = self._quantize_embedding_layer()
clear_memory(device_list=self.device_list)
all_q_inputs = None
if is_quantized_embedding:
all_inputs = copy.deepcopy(self.inputs)
clear_memory(self.inputs, device_list=self.device_list)
all_q_inputs = self.try_cache_inter_data_gpucpu(
all_first_block_names, self.nsamples, layer_names=layer_names
to_cache_block_names, self.nsamples, layer_names=layer_names
)
# Remove accelerate dispatch hooks before moving parameters.
# hf_device_map is kept for reference but hooks are no longer needed.
Expand Down Expand Up @@ -1872,6 +1894,7 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool:
nblocks=self.nblocks,
device=self.device,
pbar=pbar,
input_others_extra_blocks=all_inputs,
)
if self.is_immediate_packing and len(self.formats) != 1:
raise ValueError(
Expand Down Expand Up @@ -2257,6 +2280,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
Raises:
Exception: If caching on GPU fails, switches to CPU and caches there.
"""
block_names = flatten_list(block_names)
if is_quantized_input_module(self.model):
layer_names = []
if layer_names is None:
Expand Down Expand Up @@ -2397,17 +2421,18 @@ def cache_inter_data(self, block_names, nsamples, layer_names=None, last_cache_n
if layer_names is None:
layer_names = []
self.inputs = {}
block_names = flatten_list(block_names)
self.to_cached_layers = block_names + layer_names

tmp_dtype = None # TODO delete this as most model is not fp32 now
## have bug if block name is not the first block
# There is a bug if the block name is not the first block
if (len(block_names) > 1 or len(layer_names) > 0) and self.low_gpu_mem_usage:
tmp_dtype = self.model.dtype
if self.amp:
if self.model.dtype != self.model.dtype:
self.model = self.model.to(torch.bfloat16)
else:
self.model = self.model.to(torch.float32) ##model on cpu
self.model = self.model.to(torch.float32) # model on cpu

self.last_cache_name = self._infer_last_cache_name(block_names, layer_names, last_cache_name)
self._cache_target_set = set(self.to_cached_layers)
Expand Down Expand Up @@ -2539,6 +2564,8 @@ def forward(m, hidden_states=None, *positional_inputs, **kwargs):
or isinstance(kwargs[key], list)
or isinstance(kwargs[key], tuple)
):
if name not in self.blocks_requiring_input_ids and key == "hidden_states":
continue
if key not in self.inputs[name].keys(): # initialization
data = to_device(kwargs[key], device=torch.device("cpu"))
if data is None or (self.batch_size > 1 and key in self.shared_cache_keys):
Expand Down Expand Up @@ -3224,21 +3251,21 @@ def _quantize_block(
return None, output

def _split_inputs(self, inputs: dict, first_input_name: str) -> tuple[torch.Tensor, dict]:
    """Separate the primary block input from the remaining cached inputs.

    Args:
        inputs: Cached inputs for a block. This dict is modified in place:
            the ``first_input_name`` entry is removed from it when present.
        first_input_name: Key under which the primary input is stored.

    Returns:
        A ``(input_ids, input_others)`` pair. ``input_ids`` is the value that
        was stored under ``first_input_name``, or ``None`` when that key is
        absent. ``input_others`` is the same ``inputs`` dict object, now
        without the ``first_input_name`` entry.
    """
    # A single pop tolerates a missing key; a strict ``inputs[...]`` lookup
    # would raise KeyError for inputs that carry no primary entry.
    input_ids = inputs.pop(first_input_name, None)
    input_others = inputs
    return input_ids, input_others

def _preprocess_block_inputs(self, inputs, first_input_name="input_ids"):
input_ids, input_others = self._split_inputs(inputs, first_input_name)
clear_memory(device_list=self.device_list)
input_ids = to_device(input_ids, self.cache_device)
tmp_dtype = self.amp_dtype if self.amp else torch.float32
if input_ids is not None:
input_ids = to_device(input_ids, self.cache_device)
input_ids = to_dtype(input_ids, tmp_dtype)
input_others = to_device(input_others, self.cache_device)
# As in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage

tmp_dtype = self.amp_dtype if self.amp else torch.float32
input_ids = to_dtype(input_ids, tmp_dtype)

for key in input_others.keys():
if isinstance(input_others[key], torch.Tensor) and (
input_others[key].dtype == torch.float16 or input_others[key].dtype == torch.bfloat16
Expand All @@ -3258,6 +3285,7 @@ def _quantize_blocks(
nblocks: int = 1,
device: str = "cpu",
pbar: tqdm = None,
Comment thread
wenhuach21 marked this conversation as resolved.
input_others_extra_blocks: dict = None,
Comment thread
wenhuach21 marked this conversation as resolved.
):
"""Quantize and dequantize the weights of the specified blocks in the model.

Expand All @@ -3281,12 +3309,17 @@ def _quantize_blocks(
pbar = tqdm(range(0, len(block_names), nblocks))

for i in range(0, len(block_names), nblocks):
if block_names[i] in input_others_extra_blocks:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to change to "if input_others_extra_blocks and block_names[i] in input_others_extra_blocks:"

input_others = input_others_extra_blocks[block_names[i]]
_, input_others = self._preprocess_block_inputs(input_others)
input_others_extra_blocks.pop(block_names[i])
Comment thread
wenhuach21 marked this conversation as resolved.
if i != 0:
pbar.update(1)
if nblocks == 1:
n = block_names[i]
pbar.set_description(f"Quantizing {n}")
m = get_module(model, n)

else:
names = block_names[i : min(i + nblocks, len(block_names))]
pbar.set_description(f"Quantizing [{i + 1}-{min(i + nblocks, len(block_names))}]/{len(block_names)}")
Expand Down
24 changes: 22 additions & 2 deletions auto_round/compressors/shard_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from collections import OrderedDict

import torch
from torch.nn import Parameter

from auto_round.logger import logger
from auto_round.utils import get_lm_head_name, get_module
Expand Down Expand Up @@ -173,7 +174,26 @@ def _offload_to_meta(self, saved_params):
and isinstance(module, torch.nn.Module)
and all(f"{module_path}.{k}" in self._all_saved for k in module.state_dict().keys())
):
module.to("meta")
self._move_module_to_meta(module)

def _move_module_to_meta(self, module: torch.nn.Module):
    """Recursively rebind a module's parameters and buffers to the meta device.

    Child modules are processed first. Each entry of ``module._parameters``
    and ``module._buffers`` that holds a tensor is then replaced by a detached
    copy created with ``.to(device="meta")``. ``None`` entries and non-tensor
    entries are left untouched.

    Args:
        module: Module whose parameters and buffers are replaced in place.
    """
    for child in module.children():
        self._move_module_to_meta(child)

    # Iterate over a snapshot because the mapping is reassigned inside the loop.
    for name, param in list(module._parameters.items()):
        # ``Parameter`` is a ``Tensor`` subclass, so one check covers both
        # cases and also skips ``None`` placeholders.
        if not isinstance(param, torch.Tensor):
            continue
        module._parameters[name] = Parameter(
            param.detach().to(device="meta"),
            requires_grad=param.requires_grad,
        )

    for name, buffer in list(module._buffers.items()):
        if isinstance(buffer, torch.Tensor):
            module._buffers[name] = buffer.detach().to(device="meta")

def finalize(self):
"""Saves remaining weights, renames files, and writes the index JSON."""
Expand Down Expand Up @@ -202,7 +222,7 @@ def finalize(self):
layer_name = ".".join(pname.split(".")[:-1])
if self.lm_head_name is not None and layer_name == self.lm_head_name and tie_word_embeddings:
lm_head_module = get_module(self.model, self.lm_head_name)
lm_head_module.to("meta") # Must to meta, otherwise model's saver will dump it again
self._move_module_to_meta(lm_head_module) # Must to meta, otherwise model's saver will dump it again
continue
self._add_tensor(pname, tensor.detach().to("cpu"))

Expand Down
11 changes: 11 additions & 0 deletions auto_round/special_model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,14 @@ def get_predefined_ignore_layers(model: torch.nn.Module) -> list[str]:
layers.append(name)

return list(dict.fromkeys(layers))


_PRE_DEFINED_FIXED_ATTR = {"gemma4": {"has_variable_block_shape": True}}


def get_predefined_fixed_attr(model: torch.nn.Module) -> dict | None:
config = getattr(model, "config", None)
if config is not None and hasattr(config, "model_type"):
key = config.model_type
return _PRE_DEFINED_FIXED_ATTR.get(key, None)
return None
Loading
Loading