add support for gemma4 model #1655

Open
n1ck-guo wants to merge 26 commits into main from hengguo/support_for_gemma4
Conversation

Contributor

@n1ck-guo n1ck-guo commented Apr 3, 2026

Description

Please briefly describe your main changes and the motivation behind them.

Type of Change

  • Bug fix
  • New feature
  • Documentation update
  • Performance improvement
  • Code refactoring
  • Other (please specify):

Related Issues

Fixes or relates to #

Checklist Before Submitting

  • My code has been tested locally.
  • Documentation has been updated as needed.
  • New or updated tests are included where applicable.

Signed-off-by: n1ck-guo <heng.guo@intel.com>
Copilot AI review requested due to automatic review settings April 3, 2026 07:25
Contributor

Copilot AI left a comment

Pull request overview

Adds runtime handling for the gemma4 model type by patching Gemma4 decoder layers to avoid shape mismatches during auto-round block-wise quantization.

Changes:

  • Added gemma4 to the special model list and introduced a Gemma4-specific patch routine.
  • Hooked the patch into _handle_special_model when model.config.model_type == "gemma4".
  • Removed a couple of stray whitespace-only lines near ignore-layer registrations.

Comment thread auto_round/special_model_handler.py
Comment thread auto_round/special_model_handler.py Outdated
wenhuach21 and others added 10 commits April 3, 2026 15:46
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
wenhuach21 and others added 14 commits April 3, 2026 17:26
Signed-off-by: lvliang-intel <liang1.lv@intel.com>
Signed-off-by: n1ck-guo <heng.guo@intel.com>
Signed-off-by: n1ck-guo <heng.guo@intel.com>
…port_for_gemma4

# Conflicts:
#	auto_round/compressors/base.py
#	auto_round/special_model_handler.py
#	auto_round/utils/common.py
Signed-off-by: n1ck-guo <heng.guo@intel.com>
…emma4

# Conflicts:
#	auto_round/compressors/base.py
#	auto_round/utils/model.py
Signed-off-by: n1ck-guo <heng.guo@intel.com>
@n1ck-guo
Contributor Author

/azp run Unit-Test-CUDA-AutoRound

@azure-pipelines

Azure Pipelines successfully started running 1 pipeline(s).

@JGSphaela

Hi @n1ck-guo , I encountered a crash while testing this PR with google/gemma-4-E2B-it using iters 0 (RTN).

The Issue:
The quantization crashes with TypeError: 'NoneType' object does not support item assignment at transformers/models/gemma4/modeling_gemma4.py:1226. This happens because shared_kv_states is passed as None to the attention module.

Root Cause:
The patched_layer_forward signature in auto_round/special_model_handler.py is missing the shared_kv_states argument, which is the 3rd positional argument in the latest Gemma4DecoderLayer.forward implementation.

For Gemma-4 E2B/E4B, this dictionary is required to manage the shared KV cache between anchor and sharer layers. When auto-round patches the layer, it accidentally drops this argument.

Suggested Fix:
Update the signature in _patch_gemma4_model to include shared_kv_states and ensure it is propagated to orig_fwd. Additionally, since auto-round processes blocks individually, we should initialize a shared dictionary in the patching closure to maintain the state across layers:

# In _patch_gemma4_model: closure-level state shared across patched layers
shared_kv_states_global = {}

# In patched_layer_forward: accept the argument auto-round currently drops
def patched_layer_forward(self, hidden_states, per_layer_input=None, shared_kv_states=None, **kwargs):
    # Fall back to the shared dict so anchor and sharer layers see the same
    # state even though auto-round calls each block individually.
    if shared_kv_states is None:
        shared_kv_states = shared_kv_states_global
    return orig_fwd(hidden_states, per_layer_input=per_layer_input, shared_kv_states=shared_kv_states, **kwargs)

This fix resolved the issue in my local tests and allowed quantization to complete.
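The closure pattern described above can be demonstrated in isolation with a minimal sketch. Note that `DummyLayer` and `patch_layer` below are illustrative stand-ins, not names from auto_round or transformers; the point is only to show how a closure-level dict keeps a dropped keyword argument alive across per-block calls:

```python
# Hypothetical stand-ins: DummyLayer and patch_layer are illustrative,
# not names from auto_round or transformers.

class DummyLayer:
    """Minimal stand-in for a decoder layer whose forward requires a state dict."""

    def forward(self, hidden_states, shared_kv_states):
        # The real layer would raise TypeError here if shared_kv_states were None.
        shared_kv_states["last_input"] = hidden_states
        return hidden_states * 2


def patch_layer(layer):
    orig_fwd = layer.forward
    shared_kv_states_global = {}  # closure state shared across per-block calls

    def patched_forward(hidden_states, shared_kv_states=None):
        # Substitute the shared dict when the caller omits the argument,
        # mirroring the suggested fix for _patch_gemma4_model.
        if shared_kv_states is None:
            shared_kv_states = shared_kv_states_global
        return orig_fwd(hidden_states, shared_kv_states)

    layer.forward = patched_forward
    return layer


layer = patch_layer(DummyLayer())
out = layer.forward(3)  # no shared_kv_states passed, yet no crash
```

Because `shared_kv_states_global` lives in the `patch_layer` closure, every call through the patched forward that omits the argument reuses the same dict, which is what lets state carry over between individually quantized blocks.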



5 participants