Add Gemma4 MaxText→vLLM Weight Converter (torchax) #3794

Open
hengtaoguo wants to merge 1 commit into main from hengtaoguo-gemma4

Conversation


@hengtaoguo hengtaoguo commented May 1, 2026

Description

Original author: @khatwanimohit @aireenmei #3677

Extracts the Gemma4-specific weight conversion logic from bench_weight_sync.py into a proper converter class, following the same base/model-specific split established for Qwen3.

New: gemma4_moe.py

  • Adds Gemma4MaxTextToVLLMConverter, inheriting from BaseMaxTextToVLLMConverter
  • Supports gemma4-26b (MoE: 128 routed + 1 shared expert)
  • Handles Gemma4's scanned-block layout (6 slots × N reps, local + global attention)
  • Overrides convert() to add the _convert_norms step and dispatch MoE vs. dense MLP (see the sketch below)
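
A rough sketch of the class shape, for orientation only. Apart from convert(), _convert_norms, and self.vllm_state, the helper names and the MoE check below are illustrative assumptions, not the PR's exact code:

class Gemma4MaxTextToVLLMConverter(BaseMaxTextToVLLMConverter):
  """Converts Gemma4 MaxText weights (scanned-block layout) into flat vLLM keys."""

  def convert(self, maxtext_state):
    self._convert_embeddings(maxtext_state)   # hypothetical base step
    self._convert_attention(maxtext_state)    # handles the local + global attention blocks
    self._convert_norms(maxtext_state)        # Gemma4-specific extra step
    if self.config.num_experts > 1:           # assumed config field; gemma4-26b: 128 routed + 1 shared expert
      self._convert_moe_mlp(maxtext_state)    # hypothetical helper
    else:
      self._convert_dense_mlp(maxtext_state)  # hypothetical helper
    return self.vllm_state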

Updated: validate_converter.py

  • Imports Gemma4MaxTextToVLLMConverter and dispatches on gemma4-* model names (sketched below)
  • Adds gemma4-26b entry to vllm_model_name_mapping
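
Roughly, the dispatch looks like this (a sketch inferred from the diff context later in this thread; the surrounding code may differ):

if config.model_name.startswith("gemma4"):
  converter = Gemma4MaxTextToVLLMConverter(config, mesh)
else:
  converter = Qwen3MaxTextToVLLMConverter(config, mesh)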

Notes:

  1. Set the env var MODEL_IMPL_TYPE=vllm to force the torchax-backed vLLM model for Gemma4 (the default "auto" resolves to "flax_nnx" in newer tpu-inference, which uses a nested Flax state that is incompatible with the flat-key converter output)
  2. Gemma4 prompts need to start with <bos>, for example: prompt="<bos>Paris is"

Tests

Tested with validate_converter, full logs:

export MODEL_IMPL_TYPE="vllm"
python -m maxtext.integration.vllm.torchax_converter.validate_converter src/maxtext/configs/base.yml model_name=gemma4-26b tokenizer_type=huggingface tokenizer_path=google/gemma-4-26b-it load_parameters_path=gs://maxtext-gemma/gemma4/26b/converted/2026-04-07-23-04/0/items run_name=gemma4_converter_validation per_device_batch_size=1 max_prefill_predict_length=8 max_target_length=16 steps=1 scan_layers=true skip_jax_distributed_system=true weight_dtype=bfloat16 attention=dot_product remat_policy=custom decoder_layer_input=offload query_proj=offload key_proj=offload value_proj=offload ici_expert_parallelism=4 rollout_tensor_parallelism=4 hbm_utilization_vllm=0.8 async_scheduling=false prompt=\<bos\>Paris\ is hf_access_token=xxx 
[RequestOutput(request_id=0, prompt='<bos>Paris is', prompt_token_ids=[2, 50429, 563], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=' the capital of France. It is the 1', token_ids=[506, 5279, 529, 7001, 236761, 1030, 563, 506, 236743, 236770], routed_experts=None, cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=None, lora_request=None, num_cached_tokens=0)]

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@hengtaoguo hengtaoguo force-pushed the hengtaoguo-gemma4 branch 2 times, most recently from e1c342e to 36c846d on May 1, 2026 at 21:45

github-actions Bot commented May 1, 2026

🤖 Hi @aireenmei, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

github-actions Bot left a comment:

📋 Review Summary

This Pull Request introduces a weight converter for the Gemma4 model to facilitate MaxText to vLLM conversion, specifically for the gemma4-26b MoE variant. The implementation is well-structured and follows the established patterns for weight conversion in the project, including proper JIT usage and memory management.

🔍 General Feedback

  • Naming Consistency: There is a critical naming discrepancy in the MoE weight keys (moe.per_expert_scale vs router.per_expert_scale) that should be resolved to ensure compatibility with vLLM's expectation.
  • Code Cleanup: A few minor items like unused arguments in JIT functions and commented-out debugging code in the validator should be addressed to maintain code quality.
  • Performance: The use of jax.jit for batch weight processing is a good practice for minimizing conversion overhead.

tensor_parallel_size=config.rollout_tensor_parallelism,
gpu_memory_utilization=getattr(config, "hbm_utilization_vllm", 0.5),
async_scheduling=getattr(config, "async_scheduling", False),
# load_format="dummy", # Load actual weights instead of dummy for debugging

🟡 Using `load_format="dummy"` is usually preferred for converter validation to ensure the test is fast, deterministic, and doesn't depend on downloading large model weights.
Suggested change
# load_format="dummy", # Load actual weights instead of dummy for debugging
# load_format="dummy",

Collaborator:

load_format="dummy" was not working for me with torchax

)
self.vllm_state[f"{p}.moe.experts.w13_weight"] = processed.w13_weight
self.vllm_state[f"{p}.moe.experts.w2_weight"] = processed.w2_weight
# Shared expert: gate+up fused, TP-interleaved (MergedColumnParallelLinear,

🟢 Typo in comment: `swigluoai` -> `swiglu`.
Suggested change
w13_interleave=False, # Gemma4 uses gelu, not swigluoai
w13_interleave=False, # Gemma4 uses gelu, not swiglu

FusedMoEWeights(
w13_weight=gate_up_proj[rep],
w13_weight_scale=None,
w13_bias=None,

🔴 There is a naming discrepancy between the docstring (line 293) and the code here. The docstring specifies `router.per_expert_scale`, which is more consistent with the other `router.*` keys. Using `moe.per_expert_scale` might lead to a `KeyError` or incorrect weight mapping in vLLM.
Suggested change
self.vllm_state[f"{p}.moe.per_expert_scale"] = per_expert_scale[rep]
self.vllm_state[f"{p}.router.per_expert_scale"] = per_expert_scale[rep]

async_scheduling=getattr(config, "async_scheduling", False),
# load_format="dummy", # Load actual weights instead of dummy for debugging
)
print("\n" + "=" * 80)

🟡 Commented-out debugging code and unused helper functions should generally be removed before merging to keep the codebase clean.
Suggested change
print("\n" + "=" * 80)
llm_state = llm.llm_engine.model_executor.driver_worker.model_runner.state

# wi_0/wi_1: (d_model, L, d_sh) -> L × (d_sh, d_model)
# wo: (d_sh, L, d_model) -> L × (d_model, d_sh)
sh_gate = jnp.unstack(jnp.transpose(shared["wi_0"]["kernel"], (1, 2, 0)), axis=0)
sh_up = jnp.unstack(jnp.transpose(shared["wi_1"]["kernel"], (1, 2, 0)), axis=0)

🟢 The `vllm_tp` argument and its inclusion in `static_argnames` are unnecessary as the variable is not used within the `_pack_moe` function.
Suggested change
sh_up = jnp.unstack(jnp.transpose(shared["wi_1"]["kernel"], (1, 2, 0)), axis=0)
@jax.jit
def _pack_moe(routed, shared, extra):
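
For illustration, a minimal jitted version along the lines of this suggestion, assuming shared carries stacked wi_0/wi_1/wo kernels as the comments above describe (routed/extra handling and the rest of the real body are omitted):

import jax
import jax.numpy as jnp

@jax.jit  # vllm_tp dropped from the signature and from static_argnames, per the suggestion
def _pack_moe(routed, shared, extra):
  # wi_0/wi_1: (d_model, L, d_sh) -> L arrays of (d_sh, d_model)
  sh_gate = jnp.unstack(jnp.transpose(shared["wi_0"]["kernel"], (1, 2, 0)), axis=0)
  sh_up = jnp.unstack(jnp.transpose(shared["wi_1"]["kernel"], (1, 2, 0)), axis=0)
  # wo: (d_sh, L, d_model) -> L arrays of (d_model, d_sh)
  sh_down = jnp.unstack(jnp.transpose(shared["wo"]["kernel"], (1, 2, 0)), axis=0)
  return sh_gate, sh_up, sh_down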

"qwen3-30b-a3b": "Qwen/Qwen3-30B-A3B",
"qwen3-30b-a3b-base": "Qwen/Qwen3-30B-A3B",
"qwen3-235b-a22b": "Qwen/Qwen3-235B-A22B",
"gemma4-26b": "google/gemma-4-26B-A4B-it",
Collaborator:

I thought we wanted to support the non-tuned one, without "-it"?

logging.info("\tSharding: %s", leaf.sharding)

converter = Qwen3MaxTextToVLLMConverter(config, mesh)
if config.model_name.startswith("gemma4"):
Collaborator:

When "gemma4", how about we add to the beginning of prompt if it doesn't have
