[feat] Add initial Gemma 4 multimodal bridge support #15

Open
JimmyMa99 wants to merge 3 commits into modelscope:main from JimmyMa99:feat/gemma4-support

Conversation

@JimmyMa99

Summary

This PR adds initial Gemma 4 support to mcore-bridge, including:

  • mixed sliding/full attention config parsing
  • Gemma 4-specific text backbone handling
  • per_layer_input module support for small Gemma 4 models
  • multimodal bridge support for vision and audio paths
  • Gemma 4 model registration and loader wiring

The implementation is validated primarily on google/gemma-4-E2B-it, with additional structure-
level verification on google/gemma-4-E4B-it.

What Changed

Config / parser

Added Gemma 4-related config fields and parser mappings:

  • layer_types
  • global_kv_channels
  • num_global_query_groups
  • hidden_size_per_layer_input
  • vocab_size_per_layer_input
  • num_kv_shared_layers
  • use_double_wide_mlp
  • enable_moe_block
  • top_k_experts

The parser also preserves the mixed sliding_attention / full_attention layout from the HF config and
converts it into MCore-usable config values.
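As a rough illustration, the per-layer attention layout could be recovered from an HF-style config dict along these lines. This is a hedged sketch, not the PR's actual parser: the field names `layer_types` and `num_hidden_layers` follow HF conventions, but the fallback behavior is an assumption.

```python
# Sketch of parsing the mixed sliding/full attention layout from an HF-style
# config dict. Field names follow HF conventions; the fallback is an assumption.
def parse_layer_types(hf_config: dict) -> list[str]:
    """Return a per-layer list like ['sliding_attention', ..., 'full_attention']."""
    layer_types = hf_config.get("layer_types")
    if layer_types is None:
        # Assume an all-full-attention layout when the field is absent.
        layer_types = ["full_attention"] * hf_config["num_hidden_layers"]
    if len(layer_types) != hf_config["num_hidden_layers"]:
        raise ValueError("layer_types length must match num_hidden_layers")
    return layer_types

hf_cfg = {
    "num_hidden_layers": 4,
    "layer_types": ["sliding_attention", "sliding_attention",
                    "sliding_attention", "full_attention"],
}
print(parse_layer_types(hf_cfg))
```

The resulting list can then be consumed layer-by-layer when building the MCore config.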

Text backbone

Added Gemma 4-specific model components:

  • Gemma4SelfAttention
  • Gemma4MLP
  • Gemma4TransformerLayer
  • Gemma4GPTModel
  • Gemma4MultimodalGPTModel
  • Gemma4Loader

This addresses Gemma 4-specific structure differences, especially:

  • mixed local/global attention head dimensions
  • per-layer input modules used by small Gemma 4 models
  • post-attention / post-MLP normalization wiring
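The mixed head-dimension handling can be pictured with a toy selector. The parameter names mirror the config fields added in this PR (`kv_channels`, `global_kv_channels`), but the selection rule itself is an assumption about how the bridge distinguishes local and global layers.

```python
# Toy sketch of mixed local/global head-dimension selection. The rule that
# full-attention layers take the wider (global) head dimension is an assumption.
def kv_channels_for_layer(layer_type: str, kv_channels: int,
                          global_kv_channels: int) -> int:
    if layer_type == "full_attention":
        # Global layers use the wider head dimension.
        return global_kv_channels
    # Sliding-window (local) layers keep the narrower one.
    return kv_channels
```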

Multimodal bridge

Added initial Gemma 4 multimodal support for:

  • vision_tower
  • embed_vision
  • audio_tower
  • embed_audio

This change also fixes the HF module mapping and placeholder handling so that multimodal soft-token
replacement works correctly for Gemma 4.
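Conceptually, soft-token replacement swaps the text embeddings at placeholder positions for vision/audio embeddings, in order. The sketch below uses plain lists in place of tensors; the function name and IDs are made up for illustration and do not come from the PR.

```python
# Illustrative sketch of multimodal soft-token replacement: embeddings at
# placeholder positions are replaced by vision/audio embeddings, in order.
# Lists stand in for tensors; names and IDs are hypothetical.
def replace_placeholders(token_embeds, token_ids, placeholder_id, mm_embeds):
    out = list(token_embeds)
    mm_iter = iter(mm_embeds)
    for i, tid in enumerate(token_ids):
        if tid == placeholder_id:
            out[i] = next(mm_iter)  # consume multimodal embeddings in order
    return out
```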

Validation

Passed

google/gemma-4-E2B-it

  • 2-GPU load_weights smoke passed
  • vision embedding injection smoke passed
  • audio embedding injection smoke passed
  • text-only demo train step passed
  • vision demo train step passed

google/gemma-4-E4B-it

  • model build smoke passed
  • mixed-attention shape smoke passed
  • verified first full_attention layer uses larger QKV projection shape
  • verified multimodal/audio modules are present
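The "larger QKV projection shape" check follows from fused grouped-query attention: the QKV output width is `(n_heads + 2 * n_kv_heads) * head_dim`, so a full-attention layer with a wider head dimension gets a larger projection. The numbers below are illustrative only, not the real E4B-it dimensions.

```python
# Back-of-the-envelope check behind the mixed-attention shape smoke test.
# All dimensions here are illustrative, not the real E4B-it values.
def qkv_out_dim(n_heads: int, n_kv_heads: int, head_dim: int) -> int:
    # Fused QKV projection width under grouped-query attention.
    return (n_heads + 2 * n_kv_heads) * head_dim

sliding = qkv_out_dim(n_heads=8, n_kv_heads=2, head_dim=128)
full = qkv_out_dim(n_heads=8, n_kv_heads=2, head_dim=256)
assert full > sliding
```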

Config-level validation

HF config parsing was checked for:

  • google/gemma-4-E2B-it
  • google/gemma-4-E4B-it
  • google/gemma-4-31B-it
  • google/gemma-4-26B-A4B-it

Not Yet Fully Validated

The following are not yet validated at weight-loading / execution level in this PR:

  • google/gemma-4-31B-it
  • google/gemma-4-26B-A4B-it

In particular, 26B-A4B MoE execution is not yet runtime-validated here. This PR should be
treated as initial Gemma 4 support, with E2B-it as the main validated target.

Notes

This PR is intentionally scoped as initial support rather than claiming full runtime validation
across the entire Gemma 4 family.


gemini-code-assist bot left a comment


Code Review

This pull request introduces support for the Gemma4 model architecture, including necessary configuration updates, parser logic for converting HuggingFace configurations, and the implementation of Gemma4-specific model components. Feedback identifies a missing forward method override in Gemma4GPTModel required for per-layer input processing and suggests making the activation function in Gemma4TransformerLayer configurable via the model configuration.

Comment on lines +114 to +135
```python
class Gemma4GPTModel(GPTModel):

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.hidden_size_per_layer_input = getattr(config, 'hidden_size_per_layer_input', 0) or 0
        if self.hidden_size_per_layer_input:
            self.embed_tokens_per_layer = VocabParallelEmbedding(
                config.vocab_size_per_layer_input,
                config.num_layers * self.hidden_size_per_layer_input,
                init_method=config.init_method,
                reduce_scatter_embeddings=False,
                config=config,
                tp_group=self.pg_collection.tp,
            )
            self.per_layer_model_projection = nn.Linear(
                config.hidden_size,
                config.num_layers * self.hidden_size_per_layer_input,
                bias=False,
            )
            self.per_layer_projection_norm = Gemma4RMSNorm(
                self.hidden_size_per_layer_input, eps=config.layernorm_epsilon)
            self.per_layer_input_scale = 2.0 ** -0.5
            self.per_layer_model_projection_scale = config.hidden_size ** -0.5
```

Severity: high

The Gemma4GPTModel class defines methods for handling per-layer inputs (get_per_layer_inputs and project_per_layer_inputs) and initializes the necessary parameters, but it does not override the forward method to actually invoke this logic. In Gemma 4 models (specifically the smaller variants like E2B), the per-layer input is a critical architectural component. Without overriding forward to compute these inputs and pass them to the transformer layers via kwargs, the model will not produce correct outputs. You should override forward to compute per_layer_input and ensure it is passed down to the decoder.
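A minimal runnable sketch of the suggested override is below. The stubbed base class only echoes its kwargs so the threading of `per_layer_input` is visible; the method and argument names (`get_per_layer_inputs`, the `per_layer_input` kwarg) follow the review's description and are assumptions about the bridge's API.

```python
# Minimal sketch of the review's suggestion: override forward so per-layer
# inputs are computed and threaded to the decoder via kwargs. The classes
# here are stubs; real names follow the review text and are assumptions.
class BaseGPTModel:
    def forward(self, input_ids, **kwargs):
        # A real model would run the decoder; here we just echo the kwargs.
        return kwargs

class Gemma4GPTModelSketch(BaseGPTModel):
    hidden_size_per_layer_input = 256  # non-zero enables the per-layer path

    def get_per_layer_inputs(self, input_ids):
        # Stand-in for the embed_tokens_per_layer lookup + reshape.
        return [f"per_layer_embed({tid})" for tid in input_ids]

    def forward(self, input_ids, **kwargs):
        if self.hidden_size_per_layer_input:
            kwargs["per_layer_input"] = self.get_per_layer_inputs(input_ids)
        return super().forward(input_ids, **kwargs)

out = Gemma4GPTModelSketch().forward([1, 2, 3])
```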

```python
if per_layer_input is not None and self.hidden_size_per_layer_input:
    residual = hidden_states
    hidden_states = self.per_layer_input_gate(hidden_states)
    hidden_states = torch.nn.functional.gelu(hidden_states, approximate='tanh')
```

Severity: medium

The activation function for the per-layer input module is hardcoded to GELU with approximate='tanh'. While this matches Gemma 4's architecture, it would be more maintainable to derive this from the ModelConfig or use a named activation consistent with the rest of the bridge to avoid magic strings.
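One way to realize the suggestion is a small named-activation registry looked up from the config. This is a hedged sketch, not the bridge's actual mechanism; the registry, the config key `per_layer_activation`, and the pure-Python tanh-approximate GELU are all illustrative.

```python
import math

# Sketch of the review's suggestion: look up the activation by name from the
# model config instead of hardcoding gelu(approximate='tanh'). The registry
# and the config key "per_layer_activation" are hypothetical.
ACTIVATIONS = {
    # Tanh-approximate GELU, matching torch's approximate='tanh' variant.
    "gelu_tanh": lambda x: 0.5 * x * (1.0 + math.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3))),
    "relu": lambda x: max(0.0, x),
}

def get_activation(config: dict):
    return ACTIVATIONS[config.get("per_layer_activation", "gelu_tanh")]
```

The default keeps Gemma 4's behavior while letting other bridges reuse the layer with a different named activation.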
