-
Notifications
You must be signed in to change notification settings - Fork 107
[step 1]support variable block input shapes for gemma4 #1656
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
wenhuach21
wants to merge
16
commits into
main
Choose a base branch
from
support_gemma4
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
8ab6ebe
try to support gemma4
wenhuach21 0cc631d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 6a6afcf
Update auto_round/compressors/base.py
wenhuach21 7f1af02
Update auto_round/compressors/base.py
wenhuach21 416797c
Update auto_round/compressors/base.py
wenhuach21 4638d8a
refine
wenhuach21 a8dd583
Update auto_round/special_model_handler.py
wenhuach21 8256d92
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] d3d1ed8
fix
wenhuach21 f2e5332
Merge branch 'support_gemma4' of https://github.com/intel/auto-round …
wenhuach21 222838e
support opt_rtn
wenhuach21 bc93840
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 771b003
update
wenhuach21 f5849bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 4c087b6
fix offload immediate_saving issue
lvliang-intel 5b4ae05
Merge branch 'main' into support_gemma4
lvliang-intel File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -69,7 +69,7 @@ | |
| preset_name_to_scheme, | ||
| ) | ||
| from auto_round.sign_sgd import SignSGD | ||
| from auto_round.special_model_handler import get_predefined_ignore_layers, update_module | ||
| from auto_round.special_model_handler import get_predefined_fixed_attr, get_predefined_ignore_layers, update_module | ||
| from auto_round.utils import ( | ||
| INNER_SUPPORTED_LAYER_TYPES, | ||
| SUPPORTED_DTYPES, | ||
|
|
@@ -569,6 +569,14 @@ def __init__( | |
| ) | ||
|
|
||
| self.hadamard_config = normalize_hadamard_config(hadamard_config) | ||
| self.has_variable_block_shape = False | ||
| all_blocks = self.quant_block_list if self.quant_block_list else get_block_names(self.model) | ||
| if not all_blocks: | ||
| raise ValueError("Could not find any blocks. Check the model or quant_block_list.") | ||
| self.blocks_requiring_input_ids = [data if isinstance(data, str) else data[0] for data in all_blocks] | ||
| fixed_attr = get_predefined_fixed_attr(self.model) or {} | ||
| for key, value in fixed_attr.items(): | ||
| setattr(self, key, value) | ||
|
|
||
| def _gen_auto_scheme(self) -> dict[str, dict]: | ||
| if self.mllm: | ||
|
|
@@ -1517,17 +1525,20 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) | |
| if not all_blocks: | ||
| raise ValueError("Could not find any blocks. Check the model or quant_block_list.") | ||
|
|
||
| all_first_block_names = [block[0] for block in all_blocks] | ||
| if not self.has_variable_block_shape: | ||
| to_cache_block_names = [block[0] for block in all_blocks] | ||
| else: | ||
| to_cache_block_names = flatten_list(all_blocks) | ||
| layer_names = self._get_quantized_layer_names_outside_blocks() | ||
| if self.act_bits < 16 and (not self.act_dynamic or len(layer_names) > 0): | ||
| if self.act_bits < 16 and (not self.act_dynamic or len(layer_names) > 0) or self.has_variable_block_shape: | ||
| if len(layer_names) > 0: | ||
| logger.warning( | ||
| "quantize layers outside blocks for static activation quantizaiton" | ||
| " will significantly increase calibration time" | ||
| ) | ||
| all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names) | ||
| all_inputs = self.try_cache_inter_data_gpucpu(to_cache_block_names, self.nsamples, layer_names) | ||
| else: | ||
| all_inputs = self.cache_inter_data(all_first_block_names, self.nsamples) | ||
| all_inputs = self.cache_inter_data(to_cache_block_names, self.nsamples) | ||
|
|
||
| # Clear hooks for multi-GPU setups | ||
| if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1: | ||
|
|
@@ -1551,20 +1562,28 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) | |
| if total_samples < self.batch_size: | ||
| self.batch_size = total_samples | ||
| logger.warning(f"Forcing batch size to {total_samples}") | ||
| tmp_dtype = self.amp_dtype if self.amp else torch.float32 | ||
|
|
||
| input_ids = to_device(inputs.pop("input_ids"), self.cache_device) | ||
| input_others = to_device(inputs, self.cache_device) | ||
|
|
||
| tmp_dtype = self.amp_dtype if self.amp else torch.float32 | ||
| input_ids = [id_.to(tmp_dtype) for id_ in input_ids] | ||
|
|
||
| for key, val in input_others.items(): | ||
| if isinstance(val, torch.Tensor) and val.dtype in (torch.float16, torch.bfloat16): | ||
| input_others[key] = val.to(tmp_dtype) | ||
| elif isinstance(val, list): | ||
| input_others[key] = [to_dtype(v, tmp_dtype) for v in val] | ||
| def process_input_others(input_others): | ||
|
|
||
| input_others = to_device(input_others, self.cache_device) | ||
| for key, val in input_others.items(): | ||
| if isinstance(val, torch.Tensor) and val.dtype in (torch.float16, torch.bfloat16): | ||
| input_others[key] = val.to(tmp_dtype) | ||
| elif isinstance(val, list): | ||
| input_others[key] = [to_dtype(v, tmp_dtype) for v in val] | ||
| return input_others | ||
|
|
||
| input_others = inputs | ||
| input_others = process_input_others(input_others) | ||
| for block_name in block_names: | ||
| if block_name in all_inputs.keys(): | ||
| input_others = all_inputs[block_name] | ||
| input_others = process_input_others(input_others) | ||
| all_inputs.pop(block_name) | ||
| pbar.set_description(f"Quantizing {block_name}") | ||
| block = get_module(self.model, block_name) | ||
| materialize_model_(block) | ||
|
|
@@ -1809,22 +1828,25 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool: | |
| self.model = self.model.to(self.amp_dtype) | ||
|
|
||
| layer_names = self._get_quantized_layer_names_outside_blocks() | ||
| all_first_block_names = [block[0] for block in all_blocks] | ||
| if not self.has_variable_block_shape: | ||
| to_cache_block_names = [block[0] for block in all_blocks] | ||
| else: | ||
| to_cache_block_names = flatten_list(all_blocks) | ||
| if len(layer_names) > 0: | ||
| logger.info( | ||
| "Starting to cache block inputs. This may be slow due to external block layers: %s", layer_names | ||
| ) | ||
| else: | ||
| logger.info("start to cache block inputs") | ||
| all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names) | ||
| all_inputs = self.try_cache_inter_data_gpucpu(to_cache_block_names, self.nsamples, layer_names=layer_names) | ||
| is_quantized_embedding = self._quantize_embedding_layer() | ||
| clear_memory(device_list=self.device_list) | ||
| all_q_inputs = None | ||
| if is_quantized_embedding: | ||
| all_inputs = copy.deepcopy(self.inputs) | ||
| clear_memory(self.inputs, device_list=self.device_list) | ||
| all_q_inputs = self.try_cache_inter_data_gpucpu( | ||
| all_first_block_names, self.nsamples, layer_names=layer_names | ||
| to_cache_block_names, self.nsamples, layer_names=layer_names | ||
| ) | ||
| # Remove accelerate dispatch hooks before moving parameters. | ||
| # hf_device_map is kept for reference but hooks are no longer needed. | ||
|
|
@@ -1872,6 +1894,7 @@ def _should_disable_inplace_due_to_layers_outside_block() -> bool: | |
| nblocks=self.nblocks, | ||
| device=self.device, | ||
| pbar=pbar, | ||
| input_others_extra_blocks=all_inputs, | ||
| ) | ||
| if self.is_immediate_packing and len(self.formats) != 1: | ||
| raise ValueError( | ||
|
|
@@ -2257,6 +2280,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l | |
| Raises: | ||
| Exception: If caching on GPU fails, switches to CPU and caches there. | ||
| """ | ||
| block_names = flatten_list(block_names) | ||
| if is_quantized_input_module(self.model): | ||
| layer_names = [] | ||
| if layer_names is None: | ||
|
|
@@ -2397,17 +2421,18 @@ def cache_inter_data(self, block_names, nsamples, layer_names=None, last_cache_n | |
| if layer_names is None: | ||
| layer_names = [] | ||
| self.inputs = {} | ||
| block_names = flatten_list(block_names) | ||
| self.to_cached_layers = block_names + layer_names | ||
|
|
||
| tmp_dtype = None # TODO delete this as most model is not fp32 now | ||
| ## have bug if block name is not the first block | ||
| # There is a bug if the block name is not the first block | ||
| if (len(block_names) > 1 or len(layer_names) > 0) and self.low_gpu_mem_usage: | ||
| tmp_dtype = self.model.dtype | ||
| if self.amp: | ||
| if self.model.dtype != self.model.dtype: | ||
| self.model = self.model.to(torch.bfloat16) | ||
| else: | ||
| self.model = self.model.to(torch.float32) ##model on cpu | ||
| self.model = self.model.to(torch.float32) # model on cpu | ||
|
|
||
| self.last_cache_name = self._infer_last_cache_name(block_names, layer_names, last_cache_name) | ||
| self._cache_target_set = set(self.to_cached_layers) | ||
|
|
@@ -2539,6 +2564,8 @@ def forward(m, hidden_states=None, *positional_inputs, **kwargs): | |
| or isinstance(kwargs[key], list) | ||
| or isinstance(kwargs[key], tuple) | ||
| ): | ||
| if name not in self.blocks_requiring_input_ids and key == "hidden_states": | ||
| continue | ||
| if key not in self.inputs[name].keys(): # initialization | ||
| data = to_device(kwargs[key], device=torch.device("cpu")) | ||
| if data is None or (self.batch_size > 1 and key in self.shared_cache_keys): | ||
|
|
@@ -3224,21 +3251,21 @@ def _quantize_block( | |
| return None, output | ||
|
|
||
| def _split_inputs(self, inputs: dict, first_input_name: str) -> tuple[torch.Tensor, dict]: | ||
| input_ids = inputs[first_input_name] | ||
| input_ids = inputs.get(first_input_name, None) | ||
| inputs.pop(first_input_name, None) | ||
| input_others = inputs | ||
| return input_ids, input_others | ||
|
|
||
| def _preprocess_block_inputs(self, inputs, first_input_name="input_ids"): | ||
| input_ids, input_others = self._split_inputs(inputs, first_input_name) | ||
| clear_memory(device_list=self.device_list) | ||
| input_ids = to_device(input_ids, self.cache_device) | ||
| tmp_dtype = self.amp_dtype if self.amp else torch.float32 | ||
| if input_ids is not None: | ||
| input_ids = to_device(input_ids, self.cache_device) | ||
| input_ids = to_dtype(input_ids, tmp_dtype) | ||
| input_others = to_device(input_others, self.cache_device) | ||
| # As in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage | ||
|
|
||
| tmp_dtype = self.amp_dtype if self.amp else torch.float32 | ||
| input_ids = to_dtype(input_ids, tmp_dtype) | ||
|
|
||
| for key in input_others.keys(): | ||
| if isinstance(input_others[key], torch.Tensor) and ( | ||
| input_others[key].dtype == torch.float16 or input_others[key].dtype == torch.bfloat16 | ||
|
|
@@ -3258,6 +3285,7 @@ def _quantize_blocks( | |
| nblocks: int = 1, | ||
| device: str = "cpu", | ||
| pbar: tqdm = None, | ||
| input_others_extra_blocks: dict = None, | ||
|
wenhuach21 marked this conversation as resolved.
|
||
| ): | ||
| """Quantize and dequantize the weights of the specified blocks in the model. | ||
|
|
||
|
|
@@ -3281,12 +3309,17 @@ def _quantize_blocks( | |
| pbar = tqdm(range(0, len(block_names), nblocks)) | ||
|
|
||
| for i in range(0, len(block_names), nblocks): | ||
| if block_names[i] in input_others_extra_blocks: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better to change to "if input_others_extra_blocks and block_names[i] in input_others_extra_blocks:" |
||
| input_others = input_others_extra_blocks[block_names[i]] | ||
| _, input_others = self._preprocess_block_inputs(input_others) | ||
| input_others_extra_blocks.pop(block_names[i]) | ||
|
wenhuach21 marked this conversation as resolved.
|
||
| if i != 0: | ||
| pbar.update(1) | ||
| if nblocks == 1: | ||
| n = block_names[i] | ||
| pbar.set_description(f"Quantizing {n}") | ||
| m = get_module(model, n) | ||
|
|
||
| else: | ||
| names = block_names[i : min(i + nblocks, len(block_names))] | ||
| pbar.set_description(f"Quantizing [{i + 1}-{min(i + nblocks, len(block_names))}]/{len(block_names)}") | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.