Add Gemma4 MaxText→vLLM Weight Converter (torchax) #3794

hengtaoguo wants to merge 1 commit into main from
Conversation
🤖 Hi @aireenmei, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This Pull Request introduces a weight converter for the Gemma4 model to facilitate MaxText to vLLM conversion, specifically for the gemma4-26b MoE variant. The implementation is well-structured and follows the established patterns for weight conversion in the project, including proper JIT usage and memory management.
🔍 General Feedback
- Naming Consistency: There is a critical naming discrepancy in the MoE weight keys (`moe.per_expert_scale` vs `router.per_expert_scale`) that should be resolved to ensure compatibility with vLLM's expectations.
- Code Cleanup: A few minor items like unused arguments in JIT functions and commented-out debugging code in the validator should be addressed to maintain code quality.
- Performance: The use of `jax.jit` for batch weight processing is a good practice for minimizing conversion overhead.
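
As a rough illustration of the `jax.jit` pattern this feedback refers to (the helper and key names below are hypothetical, not the converter's actual API), a per-kernel transform can be compiled once and reused across all layers:

```python
import jax
import jax.numpy as jnp

@jax.jit
def _transpose_and_cast(kernel):
    # Hypothetical helper: MaxText linear kernels are typically (in_dim, out_dim);
    # the torch-side modules expect (out_dim, in_dim), so transpose and cast once
    # inside a compiled function to keep per-layer Python overhead low.
    return jnp.transpose(kernel, (1, 0)).astype(jnp.bfloat16)

# Sketch of per-layer use during conversion (illustrative key names):
# vllm_state[f"{prefix}.o_proj.weight"] = _transpose_and_cast(maxtext_kernel)
```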
        tensor_parallel_size=config.rollout_tensor_parallelism,
        gpu_memory_utilization=getattr(config, "hbm_utilization_vllm", 0.5),
        async_scheduling=getattr(config, "async_scheduling", False),
        # load_format="dummy", # Load actual weights instead of dummy for debugging
Suggested change:
-         # load_format="dummy", # Load actual weights instead of dummy for debugging
+         # load_format="dummy",
load_format="dummy" was not working for me with torchax
        )
        self.vllm_state[f"{p}.moe.experts.w13_weight"] = processed.w13_weight
        self.vllm_state[f"{p}.moe.experts.w2_weight"] = processed.w2_weight
        # Shared expert: gate+up fused, TP-interleaved (MergedColumnParallelLinear,
Suggested change:
-         # Shared expert: gate+up fused, TP-interleaved (MergedColumnParallelLinear,
+         w13_interleave=False,  # Gemma4 uses gelu, not swiglu
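
To illustrate the layout question behind this suggestion, here is a minimal sketch under the assumption that `w13` means the gate and up projections stacked along the output dimension (the converter's real helper may differ):

```python
import jax.numpy as jnp

def fuse_gate_up(gate, up, interleave=False, tp_size=1):
    # gate, up: (d_ff, d_model) kernels already in vLLM's (out, in) orientation.
    if not interleave:
        # Plain [gate; up] concatenation along the output dimension.
        return jnp.concatenate([gate, up], axis=0)
    # TP-interleaved layout: alternate gate/up slices so that each tensor-parallel
    # shard receives its own contiguous [gate_i; up_i] pair after row-splitting.
    gate_parts = jnp.split(gate, tp_size, axis=0)
    up_parts = jnp.split(up, tp_size, axis=0)
    interleaved = [block for pair in zip(gate_parts, up_parts) for block in pair]
    return jnp.concatenate(interleaved, axis=0)
```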
            FusedMoEWeights(
                w13_weight=gate_up_proj[rep],
                w13_weight_scale=None,
                w13_bias=None,
Suggested change:
-                 w13_bias=None,
+         self.vllm_state[f"{p}.router.per_expert_scale"] = per_expert_scale[rep]
        async_scheduling=getattr(config, "async_scheduling", False),
        # load_format="dummy", # Load actual weights instead of dummy for debugging
    )
    print("\n" + "=" * 80)
Suggested change:
-     print("\n" + "=" * 80)
+     llm_state = llm.llm_engine.model_executor.driver_worker.model_runner.state
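
Along the lines of this suggestion, a simple sanity check is to diff the converter output's key set against the runner state grabbed via that attribute path (the path is taken from the suggestion above and is specific to this torchax setup rather than a stable vLLM API; the helper below is only a sketch):

```python
def report_key_mismatches(vllm_state, converted_state):
    # Both arguments are flat dicts mapping parameter names to tensors.
    missing = sorted(set(vllm_state) - set(converted_state))
    unexpected = sorted(set(converted_state) - set(vllm_state))
    for name in missing:
        print(f"missing from converter output: {name}")
    for name in unexpected:
        print(f"not present in vLLM state:     {name}")
    return not missing and not unexpected
```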
        # wi_0/wi_1: (d_model, L, d_sh) -> L × (d_sh, d_model)
        # wo: (d_sh, L, d_model) -> L × (d_model, d_sh)
        sh_gate = jnp.unstack(jnp.transpose(shared["wi_0"]["kernel"], (1, 2, 0)), axis=0)
        sh_up = jnp.unstack(jnp.transpose(shared["wi_1"]["kernel"], (1, 2, 0)), axis=0)
Suggested change:
-         sh_up = jnp.unstack(jnp.transpose(shared["wi_1"]["kernel"], (1, 2, 0)), axis=0)
+         @jax.jit
+         def _pack_moe(routed, shared, extra):
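
A self-contained sketch of the transpose pattern in this hunk, wrapped in `jax.jit` as the suggestion proposes (shapes follow the comments above; `_pack_moe`'s real signature and body live in the converter, and the function name below is a placeholder):

```python
import jax
import jax.numpy as jnp

@jax.jit
def pack_shared_expert(wi_0, wi_1, wo):
    # wi_0, wi_1: (d_model, L, d_sh) -> (L, d_sh, d_model)
    # wo:         (d_sh, L, d_model) -> (L, d_model, d_sh)
    gate = jnp.transpose(wi_0, (1, 2, 0))
    up = jnp.transpose(wi_1, (1, 2, 0))
    down = jnp.transpose(wo, (1, 2, 0))
    return gate, up, down

# Per-layer arrays can then be peeled off outside the compiled function,
# e.g. sh_gate = list(gate)  # L arrays of shape (d_sh, d_model)
```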
| "qwen3-30b-a3b": "Qwen/Qwen3-30B-A3B", | ||
| "qwen3-30b-a3b-base": "Qwen/Qwen3-30B-A3B", | ||
| "qwen3-235b-a22b": "Qwen/Qwen3-235B-A22B", | ||
| "gemma4-26b": "google/gemma-4-26B-A4B-it", |
I thought we want to support the non-tuned one without "-it"?
        tensor_parallel_size=config.rollout_tensor_parallelism,
        gpu_memory_utilization=getattr(config, "hbm_utilization_vllm", 0.5),
        async_scheduling=getattr(config, "async_scheduling", False),
        # load_format="dummy", # Load actual weights instead of dummy for debugging
load_format="dummy" was not working for me with torchax
            logging.info("\tSharding: %s", leaf.sharding)

    converter = Qwen3MaxTextToVLLMConverter(config, mesh)
    if config.model_name.startswith("gemma4"):
When "gemma4", how about we add to the beginning of prompt if it doesn't have
Description
Original author: @khatwanimohit @aireenmei #3677
Extracts the Gemma4-specific weight conversion logic from `bench_weight_sync.py` into a proper converter class, following the same base/model-specific split established for Qwen3.

New:
- `gemma4_moe.py`: `Gemma4MaxTextToVLLMConverter`, inheriting from `BaseMaxTextToVLLMConverter`
- Supports `gemma4-26b` (MoE: 128 routed + 1 shared expert)
- Overrides `convert()` to add the `_convert_norms` step and dispatch MoE vs. dense MLP (see the structural sketch after the Notes below)

Updated:
- `validate_converter.py`: uses `Gemma4MaxTextToVLLMConverter` and dispatches on `gemma4-*` model names
- Added a `gemma4-26b` entry to `vllm_model_name_mapping`

Notes:
- Set `MODEL_IMPL_TYPE=vllm` to force the torchax-backed vLLM model for Gemma4 (default `"auto"` resolves to `"flax_nnx"` in newer tpu-inference, which uses a nested Flax state incompatible with the flat-key converter output)
- Prompts need a leading `<bos>`, example: `prompt="<bos>Paris is"`
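
For orientation, a structural sketch of the split described under "New:" above. Apart from `convert()`, `_convert_norms`, and the base class name, everything here is a placeholder (including the stubbed base class and the layer-level MoE/dense dispatch, which is an assumption about where the branch sits):

```python
class BaseMaxTextToVLLMConverter:
    # Minimal stand-in for the real base class, just enough to run the sketch.
    def __init__(self, num_layers):
        self.num_layers = num_layers
        self.vllm_state = {}


class Gemma4MaxTextToVLLMConverter(BaseMaxTextToVLLMConverter):
    """Sketch of the Gemma4-specific convert() override described above."""

    def convert(self, maxtext_params, moe_layers):
        # Gemma4 addition 1: convert the extra norm weights.
        self._convert_norms(maxtext_params)
        # Gemma4 addition 2: per-layer dispatch between MoE and dense MLP.
        for layer_idx in range(self.num_layers):
            if layer_idx in moe_layers:
                self._convert_moe_layer(maxtext_params, layer_idx)
            else:
                self._convert_dense_mlp(maxtext_params, layer_idx)
        return self.vllm_state

    # The real implementations hold the weight math; these are placeholders.
    def _convert_norms(self, params):
        pass

    def _convert_moe_layer(self, params, idx):
        pass

    def _convert_dense_mlp(self, params, idx):
        pass
```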
Tests

Tested with `validate_converter`, full logs:

Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.