From 1859f35848d8a28212426370c0bef55ecd7ba5ee Mon Sep 17 00:00:00 2001 From: hengtaoguo Date: Fri, 1 May 2026 21:22:32 +0000 Subject: [PATCH] Add Gemma4 maxtext-vllm converter (torchax) --- .../vllm/torchax_converter/gemma4_moe.py | 414 ++++++++++++++++++ .../torchax_converter/validate_converter.py | 25 +- 2 files changed, 435 insertions(+), 4 deletions(-) create mode 100644 src/maxtext/integration/vllm/torchax_converter/gemma4_moe.py diff --git a/src/maxtext/integration/vllm/torchax_converter/gemma4_moe.py b/src/maxtext/integration/vllm/torchax_converter/gemma4_moe.py new file mode 100644 index 0000000000..47feea9f8b --- /dev/null +++ b/src/maxtext/integration/vllm/torchax_converter/gemma4_moe.py @@ -0,0 +1,414 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gemma4 MaxText to vLLM weight converter. + +Supports gemma4-26b (MoE: 128 routed + 1 shared expert). + +MaxText Gemma4 stores layers in a scanned-block structure: + state['base']['decoder']['scanned_blocks']['layers_{slot}'] +where slot ∈ [0..5]. Slots 0–4 are local-sliding-window attention layers +and slot 5 is a global attention layer. The 'L' dimension (axis 1 of each +weight tensor) holds 'num_reps = num_layers // 6' repetitions of each slot. +Final vLLM layer index = rep * 6 + slot. + +Global attention (slot 5) uses a shared KV projection — 'key' serves as +both K and V; there is no separate 'value' tensor. + +Key names and tensor transformations are derived from the MaxText HF param mapping +at src/maxtext/checkpoint_conversion/utils/param_mapping.py. + +Attention: Gemma4 uses SEPARATE q/k/v proj weights (not fused QKV). +MoE (26B): gate+up proj are fused into experts.gate_up_proj (E, 2*d_inner, d_model). +Embedding: MaxText stores embedding * sqrt(d_model); divide out before writing to vLLM. +""" + +import functools +import gc +import logging + +import jax +import jax.numpy as jnp +from tpu_inference.layers.common.moe import MoEBackend +from tpu_inference.layers.common.process_weights.moe_weights import FusedMoEWeights +from tpu_inference.layers.common.process_weights.moe_weights import process_moe_weights + +from maxtext.integration.vllm.torchax_converter.base import BaseMaxTextToVLLMConverter +from maxtext.integration.vllm.torchax_converter.base import timer +from maxtext.integration.vllm.torchax_converter.base import GREEN +from maxtext.integration.vllm.torchax_converter.base import RESET + + +class Gemma4MaxTextToVLLMConverter(BaseMaxTextToVLLMConverter): + """Converts MaxText Gemma4 weights to the layout expected by a vLLM Gemma4 model.""" + + NUM_SLOTS = 6 # 5 local + 1 global + + def __init__(self, config, mesh): + super().__init__(config, mesh) + assert self.num_layers % self.NUM_SLOTS == 0, f"num_layers {self.num_layers} must be divisible by {self.NUM_SLOTS}" + self.num_reps = self.num_layers // self.NUM_SLOTS + self.is_moe = config.model_name == "gemma4-26b" + self.d_model = config.base_emb_dim + + # --- 1. 
Top-Level Entry Point --- + + def convert(self, model_state: dict): + """Convert a MaxText Gemma4 model state into vLLM weight tensors.""" + logging.info( + "\n%sStarting Gemma4 Conversion (is_moe=%s, num_layers=%d, num_reps=%d)...%s", + GREEN, + self.is_moe, + self.num_layers, + self.num_reps, + RESET, + ) + self.vllm_state = {} + blocks = model_state["base"]["decoder"]["scanned_blocks"] + prefix = "vllm_model.language_model.model.layers" + + with timer("Convert Global Weights"): + self._convert_global(model_state) + with timer("Convert Layer Norms"): + self._convert_norms(blocks, prefix) + with timer("Convert Attention Weights"): + self._convert_attn_weights(blocks, prefix) + if self.is_moe: + with timer("Convert MoE Weights"): + self._convert_moe_weights(blocks, prefix) + else: + with timer("Convert Dense MLP Weights"): + self._convert_dense_mlp_weights(blocks, prefix) + + return self.vllm_state + + # --- Abstract method implementations (delegate to Gemma4-specific methods) --- + + def _convert_global(self, params): + """Convert non-layered weights (embed_tokens, lm_head, final norm).""" + # Gemma4 uses tied embeddings: no logits_dense; lm_head.weight = embed_tokens.weight. + # MaxText stores embedding pre-multiplied by sqrt(hidden_size) (applied during HF->MaxText + # conversion in param_mapping.py). vLLM/tpu-inference apply sqrt(hidden_size) at runtime, + # so divide out the pre-multiplied factor to give vLLM the raw embedding. + logging.info("_convert_global: embed_tokens (de-normalize) + lm_head (tied) + final_norm...") + normalizer = self.d_model**0.5 + + @jax.jit + def _denorm_embed(x): + return (x / normalizer).astype(x.dtype) + + raw_embedding = _denorm_embed(params["base"]["token_embedder"]["embedding"]) + self.vllm_state["vllm_model.language_model.model.embed_tokens.weight"] = raw_embedding + self.vllm_state["vllm_model.language_model.lm_head.weight"] = raw_embedding # tied + self.vllm_state["vllm_model.language_model.model.norm.weight"] = params["base"]["decoder"]["decoder_norm"]["scale"] + logging.info("_convert_global: done") + + def _convert_attn(self, params): + """Satisfy abstract interface; Gemma4 uses _convert_attn_weights instead.""" + blocks = params["base"]["decoder"]["scanned_blocks"] + prefix = "vllm_model.language_model.model.layers" + self._convert_attn_weights(blocks, prefix) + + def _convert_moe(self, params): + """Satisfy abstract interface; Gemma4 uses _convert_moe_weights/_convert_dense_mlp_weights.""" + blocks = params["base"]["decoder"]["scanned_blocks"] + prefix = "vllm_model.language_model.model.layers" + if self.is_moe: + self._convert_moe_weights(blocks, prefix) + else: + self._convert_dense_mlp_weights(blocks, prefix) + + # --- 2. Static JIT helper --- + + @staticmethod + @jax.jit + def _pack_attn(q, k, v, o, qnorm, knorm): + """Prepares separate q/k/v, o, and norms for all layers in a slot. + + Input shapes (MaxText scanned, scan axis at index 1): + q/k/v: (d_model, L, nH, D) + o: (nH, L, D, d_model) # scan axis is 1 + norms: (d_model, L) + Returns: L × (nH*D, d_model) for q/k/v, L × (d_model, nH*D) for o. 
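    q/k norm scales are likewise unstacked along the scan axis, one 1-D scale
    vector per layer.

    Shape walk-through (hypothetical dimensions, chosen only to illustrate the
    bookkeeping, not taken from any real Gemma4 config): with nH=4, D=128,
    d_model=1024, a query kernel of shape (1024, L, 4, 128) becomes L tensors
    of shape (512, 1024), and an out kernel of shape (4, L, 128, 1024) becomes
    L tensors of shape (1024, 512).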
+ """ + # q/k/v: (d_model, L, nH, D) -> (L, nH, D, d_model) -> (L, nH*D, d_model) + q = jnp.transpose(q, (1, 2, 3, 0)).reshape(q.shape[1], -1, q.shape[0]) + k = jnp.transpose(k, (1, 2, 3, 0)).reshape(k.shape[1], -1, k.shape[0]) + v = jnp.transpose(v, (1, 2, 3, 0)).reshape(v.shape[1], -1, v.shape[0]) + # o: (nH, L, D, d_model) -> (L, d_model, nH, D) -> (L, d_model, nH*D) + o = jnp.transpose(o, (1, 3, 0, 2)).reshape(o.shape[1], o.shape[3], -1) + # norms: (D, L) -> (L, D) + qnorm = jnp.transpose(qnorm, (1, 0)) + knorm = jnp.transpose(knorm, (1, 0)) + return ( + jnp.unstack(q), + jnp.unstack(k), + jnp.unstack(v), + jnp.unstack(o), + jnp.unstack(qnorm), + jnp.unstack(knorm), + ) + + # --- 3. Per-layer norms --- + + def _convert_norms(self, blocks, prefix): + """Converts all 4 per-layer norm vectors across all layers.""" + + @jax.jit + def _unstack_norm(x): + # x: (d_model, L) -> L tensors of (d_model,) + return jnp.unstack(x, axis=1) + + for slot in range(self.NUM_SLOTS): + slot_data = blocks[f"layers_{slot}"] + pre_attn = _unstack_norm(slot_data["pre_self_attention_norm"]["scale"]) + post_attn = _unstack_norm(slot_data["post_self_attention_norm"]["scale"]) + pre_ffw = _unstack_norm(slot_data["pre_ffw_norm"]["scale"]) + post_ffw = _unstack_norm(slot_data["post_ffw_norm"]["scale"]) + for rep in range(self.num_reps): + i = rep * self.NUM_SLOTS + slot + self.vllm_state[f"{prefix}.{i}.input_layernorm.weight"] = pre_attn[rep] + self.vllm_state[f"{prefix}.{i}.post_attention_layernorm.weight"] = post_attn[rep] + self.vllm_state[f"{prefix}.{i}.pre_feedforward_layernorm.weight"] = pre_ffw[rep] + self.vllm_state[f"{prefix}.{i}.post_feedforward_layernorm.weight"] = post_ffw[rep] + del pre_attn, post_attn, pre_ffw, post_ffw + gc.collect() + + # --- 4. Per-layer attention weights --- + + def _convert_attn_weights(self, blocks, prefix): + """Converts separate q/k/v proj, o proj, q-norm, k-norm for all layers. + + HF/vLLM Gemma4 uses separate projections (not fused QKV). Global attention + layers (slot 5) have no 'value' tensor; vLLM sets v_proj = k_proj. + + Tensor transformations (MaxText → HF): + q/k/v kernel: (d_model, nH, D) → (nH*D, d_model) [reshape then transpose] + out kernel: (nH, D, d_model) → (d_model, nH*D) [reshape then transpose] + norms: (D,) → (D,) [identity] + """ + + @jax.jit + def _pack_local(attn): + q = attn["query"]["kernel"] + k = attn["key"]["kernel"] + v = attn["value"]["kernel"] + return Gemma4MaxTextToVLLMConverter._pack_attn( + q, + k, + v, + attn["out"]["kernel"], + attn["query_norm"]["scale"], + attn["key_norm"]["scale"], + ) + + @jax.jit + def _pack_global(attn): + # Global: no 'value'; key used as both K and V (shared KV projection). + q = attn["query"]["kernel"] + k = attn["key"]["kernel"] + return Gemma4MaxTextToVLLMConverter._pack_attn( + q, + k, + k, + attn["out"]["kernel"], + attn["query_norm"]["scale"], + attn["key_norm"]["scale"], + ) + + for slot in range(self.NUM_SLOTS): + is_global = slot == self.NUM_SLOTS - 1 + attn = blocks[f"layers_{slot}"]["self_attention"] + pack_fn = _pack_global if is_global else _pack_local + q_layers, k_layers, v_layers, o_layers, qnorm_layers, knorm_layers = pack_fn(attn) + num_kv_heads = self.config.global_num_kv_heads if is_global else self.config.base_num_kv_heads + tp = min(self.vllm_tp, num_kv_heads) + for rep in range(self.num_reps): + i = rep * self.NUM_SLOTS + slot + q, k, v = q_layers[rep], k_layers[rep], v_layers[rep] + # QKVParallelLinear (vLLM) expects TP-interleaved layout: + # [q_tp0, k_tp0, v_tp0, q_tp1, k_tp1, v_tp1, ...] 
+ q_per_tp = q.shape[0] // tp + kv_per_tp = k.shape[0] // tp + qkv = jnp.concatenate( + [ + q.reshape(tp, q_per_tp, q.shape[1]), + k.reshape(tp, kv_per_tp, k.shape[1]), + v.reshape(tp, kv_per_tp, v.shape[1]), + ], + axis=1, + ).reshape(-1, q.shape[1]) + self.vllm_state[f"{prefix}.{i}.self_attn.qkv_proj.weight"] = qkv + self.vllm_state[f"{prefix}.{i}.self_attn.o_proj.weight"] = o_layers[rep] + self.vllm_state[f"{prefix}.{i}.self_attn.q_norm.weight"] = qnorm_layers[rep] + self.vllm_state[f"{prefix}.{i}.self_attn.k_norm.weight"] = knorm_layers[rep] + del q_layers, k_layers, v_layers, o_layers, qnorm_layers, knorm_layers + gc.collect() + + # --- 5a. MoE weights (gemma4-26b only) --- + + def _convert_moe_weights(self, blocks, prefix): + """Converts router, routed experts (fused gate_up_proj), shared expert, MoE norms (26B). + + Tensor transformations: + router.proj.weight: gate.kernel (d_model, L, E) → (E, d_model) + router.scale: pre_forward_scale_2 (d_model, L) → (d_model,) + router.per_expert_scale: per_expert_scale (E, L) → (E,) + experts.gate_up_proj: fuse wi_0+wi_1 (E, L, d_model, d_inner) → (E, 2*d_inner, d_model) + experts.down_proj: wo (E, L, d_inner, d_model) → (E, d_model, d_inner) + shared mlp.*: (d_model, L, d_sh) or (d_sh, L, d_model) → HF convention + extra norms: (d_model, L) → (d_model,) + """ + + @functools.partial(jax.jit, static_argnames=["vllm_tp"]) + def _pack_moe(routed, shared, extra, vllm_tp): + # Router proj: (d_model, L, E) -> L × (E, d_model) + router_proj = jnp.unstack(jnp.transpose(routed["gate"]["kernel"], (1, 2, 0)), axis=0) + # Router scale: (d_model, L) -> L × (d_model,) + router_scale = jnp.unstack(extra["pre_forward_scale_2"], axis=1) + # Per-expert scale: (E, L) -> L × (E,) + per_expert_scale = jnp.unstack(routed["per_expert_scale"], axis=1) + + # Fused gate+up proj for routed experts (HF format): + # wi_0 (gate): (E, L, d_model, d_inner) -> (L, E, d_inner, d_model) + # wi_1 (up): (E, L, d_model, d_inner) -> (L, E, d_inner, d_model) + # concat along axis 2: (L, E, 2*d_inner, d_model) = gate_up_proj + w0 = jnp.transpose(routed["wi_0"], (1, 0, 3, 2)) # (L, E, d_inner, d_model) + w1 = jnp.transpose(routed["wi_1"], (1, 0, 3, 2)) # (L, E, d_inner, d_model) + gate_up = jnp.concatenate([w0, w1], axis=2) # (L, E, 2*d_inner, d_model) + gate_up_proj = jnp.unstack(gate_up, axis=0) + + # Down proj: (E, L, d_inner, d_model) -> L × (E, d_model, d_inner) + down_proj = jnp.unstack(jnp.transpose(routed["wo"], (1, 0, 3, 2)), axis=0) + + # Shared expert: + # wi_0/wi_1: (d_model, L, d_sh) -> L × (d_sh, d_model) + # wo: (d_sh, L, d_model) -> L × (d_model, d_sh) + sh_gate = jnp.unstack(jnp.transpose(shared["wi_0"]["kernel"], (1, 2, 0)), axis=0) + sh_up = jnp.unstack(jnp.transpose(shared["wi_1"]["kernel"], (1, 2, 0)), axis=0) + sh_down = jnp.unstack(jnp.transpose(shared["wo"]["kernel"], (1, 2, 0)), axis=0) + + # Extra MoE norms: (d_model, L) -> L × (d_model,) + pre_ln_2 = jnp.unstack(extra["pre_feedforward_layernorm_2"]["scale"], axis=1) + post_ln_1 = jnp.unstack(extra["post_feedforward_layernorm_1"]["scale"], axis=1) + post_ln_2 = jnp.unstack(extra["post_feedforward_layernorm_2"]["scale"], axis=1) + + return ( + router_proj, + router_scale, + per_expert_scale, + gate_up_proj, + down_proj, + sh_gate, + sh_up, + sh_down, + pre_ln_2, + post_ln_1, + post_ln_2, + ) + + for slot in range(self.NUM_SLOTS): + moe_block = blocks[f"layers_{slot}"]["mlp"]["moe_block"] + routed = moe_block["MoeBlock_0"] + shared = moe_block["shared_experts"] + extra = blocks[f"layers_{slot}"]["mlp"] + ( + 
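          # Each name below is a list of num_reps per-layer tensors for this slot,
          # unstacked along the scan axis L inside _pack_moe.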
router_proj, + router_scale, + per_expert_scale, + gate_up_proj, + down_proj, + sh_gate, + sh_up, + sh_down, + pre_ln_2, + post_ln_1, + post_ln_2, + ) = _pack_moe(routed, shared, extra, self.vllm_tp) + + for rep in range(self.num_reps): + i = rep * self.NUM_SLOTS + slot + p = f"{prefix}.{i}" + # Router + self.vllm_state[f"{p}.router.proj.weight"] = router_proj[rep] + self.vllm_state[f"{p}.router.scale"] = router_scale[rep] + self.vllm_state[f"{p}.moe.per_expert_scale"] = per_expert_scale[rep] + # Routed experts: apply process_moe_weights (GMM_TP: swapaxes + pad + TP reorder) + # to produce the post-processed format that llm_state holds after model init. + processed = process_moe_weights( + FusedMoEWeights( + w13_weight=gate_up_proj[rep], + w13_weight_scale=None, + w13_bias=None, + w2_weight=down_proj[rep], + w2_weight_scale=None, + w2_bias=None, + ), + moe_backend=MoEBackend.GMM_TP, + w13_reorder_size=self.vllm_tp, + w13_interleave=False, # Gemma4 uses gelu, not swigluoai + ) + self.vllm_state[f"{p}.moe.experts.w13_weight"] = processed.w13_weight + self.vllm_state[f"{p}.moe.experts.w2_weight"] = processed.w2_weight + # Shared expert: gate+up fused, TP-interleaved (MergedColumnParallelLinear, + # spec=P('model', None)): [gate_tp0, up_tp0, gate_tp1, up_tp1, ...] + sh_g, sh_u = sh_gate[rep], sh_up[rep] # each (d_sh, d_model) + sh_per_tp = sh_g.shape[0] // self.vllm_tp + shared_gate_up = jnp.concatenate( + [ + sh_g.reshape(self.vllm_tp, sh_per_tp, sh_g.shape[1]), + sh_u.reshape(self.vllm_tp, sh_per_tp, sh_u.shape[1]), + ], + axis=1, + ).reshape(-1, sh_g.shape[1]) + self.vllm_state[f"{p}.mlp.gate_up_proj.weight"] = shared_gate_up + self.vllm_state[f"{p}.mlp.down_proj.weight"] = sh_down[rep] + # Extra MoE norms + self.vllm_state[f"{p}.pre_feedforward_layernorm_2.weight"] = pre_ln_2[rep] + self.vllm_state[f"{p}.post_feedforward_layernorm_1.weight"] = post_ln_1[rep] + self.vllm_state[f"{p}.post_feedforward_layernorm_2.weight"] = post_ln_2[rep] + + del router_proj, router_scale, per_expert_scale, gate_up_proj, down_proj + del sh_gate, sh_up, sh_down, pre_ln_2, post_ln_1, post_ln_2 + gc.collect() + + # --- 5b. Dense MLP weights (gemma4-31b only) --- + + def _convert_dense_mlp_weights(self, blocks, prefix): + """Converts gate/up/down projections for all layers (31B only). 
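    Unlike the MoE path, the dense projections are written unfused, as separate
    mlp.gate_proj / mlp.up_proj / mlp.down_proj weights for each layer.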
+ + Tensor transformations: + wi_0 (gate): (d_model, L, d_mlp) → L × (d_mlp, d_model) + wi_1 (up): (d_model, L, d_mlp) → L × (d_mlp, d_model) + wo (down): (d_mlp, L, d_model) → L × (d_model, d_mlp) + """ + + @jax.jit + def _pack_mlp(mlp): + gate = jnp.unstack(jnp.transpose(mlp["wi_0"]["kernel"], (1, 2, 0)), axis=0) + up = jnp.unstack(jnp.transpose(mlp["wi_1"]["kernel"], (1, 2, 0)), axis=0) + down = jnp.unstack(jnp.transpose(mlp["wo"]["kernel"], (1, 2, 0)), axis=0) + return gate, up, down + + for slot in range(self.NUM_SLOTS): + mlp = blocks[f"layers_{slot}"]["mlp"] + gate_layers, up_layers, down_layers = _pack_mlp(mlp) + for rep in range(self.num_reps): + i = rep * self.NUM_SLOTS + slot + p = f"{prefix}.{i}" + self.vllm_state[f"{p}.mlp.gate_proj.weight"] = gate_layers[rep] + self.vllm_state[f"{p}.mlp.up_proj.weight"] = up_layers[rep] + self.vllm_state[f"{p}.mlp.down_proj.weight"] = down_layers[rep] + del gate_layers, up_layers, down_layers + gc.collect() diff --git a/src/maxtext/integration/vllm/torchax_converter/validate_converter.py b/src/maxtext/integration/vllm/torchax_converter/validate_converter.py index cf00a5aa64..effc1b48de 100644 --- a/src/maxtext/integration/vllm/torchax_converter/validate_converter.py +++ b/src/maxtext/integration/vllm/torchax_converter/validate_converter.py @@ -31,7 +31,7 @@ rollout_tensor_parallelism=4 hbm_utilization_vllm=0.6 async_scheduling=false \ prompt="Paris is" hf_access_token= -Currently this validator supports qwen3 converter flows. +Currently this validator supports: qwen3-30b-a3b, qwen3-30b-a3b-base, qwen3-235b-a22b, gemma4-26b. """ import functools @@ -51,6 +51,7 @@ from maxtext.integration.vllm.torchax_converter.base import GREEN from maxtext.integration.vllm.torchax_converter.base import RESET from maxtext.integration.vllm.torchax_converter.base import timer +from maxtext.integration.vllm.torchax_converter.gemma4_moe import Gemma4MaxTextToVLLMConverter from maxtext.integration.vllm.torchax_converter.qwen3_moe import Qwen3MaxTextToVLLMConverter from maxtext.utils import model_creation_utils @@ -62,6 +63,7 @@ "qwen3-30b-a3b": "Qwen/Qwen3-30B-A3B", "qwen3-30b-a3b-base": "Qwen/Qwen3-30B-A3B", "qwen3-235b-a22b": "Qwen/Qwen3-235B-A22B", + "gemma4-26b": "google/gemma-4-26B-A4B-it", # Add more mappings as needed } @@ -96,10 +98,19 @@ def _get_maxtext_model(config): return model, mesh +def save_dict_to_file(state_dict, filename): + with open(filename, "w", encoding="utf-8") as f: + for key in sorted(state_dict.keys()): + f.write(f"{key}: {state_dict[key].shape}\n") + + def validate_converter(config) -> None: """Run end-to-end validation for MaxText to vLLM weight conversion.""" - if not config.model_name.startswith("qwen3"): - raise ValueError("validate_converter.py currently supports qwen3 models only. " f"Got {config.model_name}.") + if config.model_name not in vllm_model_name_mapping: + raise ValueError( + f"validate_converter.py does not support model '{config.model_name}'. 
" + f"Supported models: {sorted(vllm_model_name_mapping.keys())}" + ) model, mesh = _get_maxtext_model(config) print(f"{GREEN}MaxText model loaded successfully{RESET}") @@ -116,7 +127,10 @@ def validate_converter(config) -> None: logging.info("Name: %s, shape: %s", path_str, leaf.shape) logging.info("\tSharding: %s", leaf.sharding) - converter = Qwen3MaxTextToVLLMConverter(config, mesh) + if config.model_name.startswith("gemma4"): + converter = Gemma4MaxTextToVLLMConverter(config, mesh) + else: + converter = Qwen3MaxTextToVLLMConverter(config, mesh) with timer("Overall Conversion"): vllm_state = converter.convert(model_state) del model_state @@ -131,9 +145,12 @@ def validate_converter(config) -> None: tensor_parallel_size=config.rollout_tensor_parallelism, gpu_memory_utilization=getattr(config, "hbm_utilization_vllm", 0.5), async_scheduling=getattr(config, "async_scheduling", False), + # load_format="dummy", # Load actual weights instead of dummy for debugging ) print("\n" + "=" * 80) llm_state = llm.llm_engine.model_executor.driver_worker.model_runner.state + # save_dict_to_file(llm_state, "vllm_model_state.txt") + # save_dict_to_file(vllm_state, "converted_vllm_state.txt") any_src = next(iter(vllm_state.values())) any_src_arr = any_src.value if hasattr(any_src, "value") else any_src