[feat] Add initial Gemma 4 multimodal bridge support #15
JimmyMa99 wants to merge 3 commits into modelscope:main from Feat/gemma4 support
Conversation
Code Review
This pull request introduces support for the Gemma4 model architecture, including necessary configuration updates, parser logic for converting HuggingFace configurations, and the implementation of Gemma4-specific model components. Feedback identifies a missing forward method override in Gemma4GPTModel required for per-layer input processing and suggests making the activation function in Gemma4TransformerLayer configurable via the model configuration.
```python
class Gemma4GPTModel(GPTModel):

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.hidden_size_per_layer_input = getattr(config, 'hidden_size_per_layer_input', 0) or 0
        if self.hidden_size_per_layer_input:
            self.embed_tokens_per_layer = VocabParallelEmbedding(
                config.vocab_size_per_layer_input,
                config.num_layers * self.hidden_size_per_layer_input,
                init_method=config.init_method,
                reduce_scatter_embeddings=False,
                config=config,
                tp_group=self.pg_collection.tp,
            )
            self.per_layer_model_projection = nn.Linear(
                config.hidden_size,
                config.num_layers * self.hidden_size_per_layer_input,
                bias=False,
            )
            self.per_layer_projection_norm = Gemma4RMSNorm(
                self.hidden_size_per_layer_input, eps=config.layernorm_epsilon
            )
            self.per_layer_input_scale = 2.0 ** -0.5
            self.per_layer_model_projection_scale = config.hidden_size ** -0.5
```
The Gemma4GPTModel class defines methods for handling per-layer inputs (get_per_layer_inputs and project_per_layer_inputs) and initializes the necessary parameters, but it does not override the forward method to actually invoke this logic. In Gemma 4 models (specifically the smaller variants like E2B), the per-layer input is a critical architectural component. Without overriding forward to compute these inputs and pass them to the transformer layers via kwargs, the model will not produce correct outputs. You should override forward to compute per_layer_input and ensure it is passed down to the decoder.
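As a reference for the suggested override, below is a minimal, self-contained sketch of the per-layer input computation that `forward` would need to perform before handing the result to the decoder layers. Plain `nn.Embedding` and `nn.LayerNorm` stand in for `VocabParallelEmbedding` and `Gemma4RMSNorm`, and the way the per-layer embedding is combined with the projected hidden states is an assumption patterned on the Gemma 3n per-layer-input design, not the PR's actual implementation:

```python
import torch
from torch import nn


class PerLayerInputStub(nn.Module):
    """Illustrative sketch: build [batch, seq, num_layers, d_pl] per-layer
    inputs from a small per-layer vocabulary embedding plus a scaled,
    normalized projection of the main hidden states."""

    def __init__(self, vocab, hidden, num_layers, d_pl):
        super().__init__()
        self.num_layers, self.d_pl = num_layers, d_pl
        self.embed_tokens_per_layer = nn.Embedding(vocab, num_layers * d_pl)
        self.per_layer_model_projection = nn.Linear(hidden, num_layers * d_pl, bias=False)
        # LayerNorm stands in for Gemma4RMSNorm here
        self.per_layer_projection_norm = nn.LayerNorm(d_pl, eps=1e-6)
        self.per_layer_input_scale = 2.0 ** -0.5
        self.per_layer_model_projection_scale = hidden ** -0.5

    def forward(self, input_ids, hidden_states):
        b, s = input_ids.shape
        emb = self.embed_tokens_per_layer(input_ids).view(b, s, self.num_layers, self.d_pl)
        proj = self.per_layer_model_projection(hidden_states) * self.per_layer_model_projection_scale
        proj = self.per_layer_projection_norm(proj.view(b, s, self.num_layers, self.d_pl))
        # average the two sources so variance stays roughly constant
        return (emb + proj) * self.per_layer_input_scale
```

An overridden `forward` would compute this tensor once per step and pass the slice for layer `i` to each `Gemma4TransformerLayer` via kwargs (e.g. `per_layer_input=...`).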
```python
if per_layer_input is not None and self.hidden_size_per_layer_input:
    residual = hidden_states
    hidden_states = self.per_layer_input_gate(hidden_states)
    hidden_states = torch.nn.functional.gelu(hidden_states, approximate='tanh')
```
Summary
This PR adds initial Gemma 4 support to mcore-bridge.
The implementation is validated primarily on google/gemma-4-E2B-it, with additional structure-level verification on google/gemma-4-E4B-it.
What Changed
Config / parser
Added Gemma 4-related config fields and parser mappings:
It also preserves the mixed sliding_attention / full_attention layer layout from the HF config and converts it into MCore-usable config values.
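The layout conversion above can be sketched as a small pure-Python helper. The string values follow recent HF Gemma-family configs (a `layer_types` list of `"sliding_attention"` / `"full_attention"` entries); the boolean-mask representation on the MCore side is an illustrative assumption:

```python
def layer_types_to_sliding_mask(layer_types):
    """Convert an HF-style per-layer attention layout into a boolean mask,
    True where a layer uses sliding-window attention. Raises on any
    unrecognized layer type so config drift fails loudly."""
    valid = {"sliding_attention", "full_attention"}
    unknown = set(layer_types) - valid
    if unknown:
        raise ValueError(f"unexpected layer types: {sorted(unknown)}")
    return [t == "sliding_attention" for t in layer_types]
```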
Text backbone
Added Gemma 4-specific model components:
This addresses Gemma 4-specific structure differences, especially:
Multimodal bridge
Added initial Gemma 4 multimodal support for:
Also fixed HF module mapping and placeholder handling so multimodal soft-token replacement works
correctly for Gemma 4.
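For context, soft-token replacement of the kind described above is commonly implemented with a masked scatter over the text embeddings. The sketch below shows the general pattern; the function and parameter names are illustrative, not this PR's actual API:

```python
import torch


def replace_placeholder_tokens(inputs_embeds, input_ids, image_embeds, image_token_id):
    """Wherever `input_ids` holds the image placeholder id, overwrite the
    corresponding text embedding with the next image embedding, in order."""
    mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
    # masked_scatter consumes `image_embeds` row-major, one vector per placeholder
    return inputs_embeds.masked_scatter(mask, image_embeds.to(inputs_embeds.dtype))
```

If the placeholder count and the number of projected image embeddings disagree, this fails at the scatter rather than silently misaligning, which is usually the desired behavior.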
Validation
Passed
google/gemma-4-E2B-it
google/gemma-4-E4B-it
Config-level validation
HF config parsing was checked for:
Not Yet Fully Validated
The following are not yet validated at weight-loading / execution level in this PR:
In particular, 26B-A4B MoE execution is not yet runtime-validated here. This PR should be
treated as initial Gemma 4 support, with E2B-it as the main validated target.
Notes
This PR is intentionally scoped as initial support rather than claiming full runtime validation
across the entire Gemma 4 family.