From 3cfe54b6ecd7702b0bc1296d220ff5d2168c4a8a Mon Sep 17 00:00:00 2001
From: Nicolas Grande
Date: Sat, 2 May 2026 00:02:11 +0000
Subject: [PATCH] adding adapter changes needed for gemma4 inference.

---
 .../extra_deps/post_train_base_deps.txt  |  2 +-
 .../extra_deps/post_train_github_deps.txt |  4 +--
 .../vllm/maxtext_vllm_adapter/adapter.py  | 36 ++++++++++++-------
 3 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/src/dependencies/extra_deps/post_train_base_deps.txt b/src/dependencies/extra_deps/post_train_base_deps.txt
index 97dd437435..f57ad482c5 100644
--- a/src/dependencies/extra_deps/post_train_base_deps.txt
+++ b/src/dependencies/extra_deps/post_train_base_deps.txt
@@ -1 +1 @@
-google-tunix @ https://github.com/google/tunix/archive/9d25e2647520b205ef43e5182ded030d20f3f52b.zip
+google-tunix @ https://github.com/google/tunix/archive/387072374f99a100cb11f99dec951940b1475a04.zip
diff --git a/src/dependencies/extra_deps/post_train_github_deps.txt b/src/dependencies/extra_deps/post_train_github_deps.txt
index 80483bd68e..474d664122 100644
--- a/src/dependencies/extra_deps/post_train_github_deps.txt
+++ b/src/dependencies/extra_deps/post_train_github_deps.txt
@@ -1,3 +1,3 @@
 -r post_train_base_deps.txt
-tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/40876e81f04226f9b7b1e4bbdc9051d6b1364b9d.zip
-vllm @ git+https://github.com/vllm-project/vllm@595562651a5a4539ffa910d8570c08fb5169bdc9
+tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/39d9a9d38d3c96a7e1e57f9e693cf1c96a44e87d.zip
+vllm @ git+https://github.com/vllm-project/vllm@529c671e8075d265a48b72e0eaaeb5e30d2f1630
diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
index 773ee0822a..a521f5c2ae 100644
--- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
+++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
@@ -103,31 +103,42 @@ def generate_maxtext_config(vllm_config: VllmConfig, mesh: Mesh) -> pyconfig.Hyp
 
   # Gather information on the hidden size of MoE models to determine if padding is needed
   # to meet MLP MoE requirements for tpu-inference GMM_v2 kernel.
-  hidden_size = getattr(vllm_config.model_config.hf_config, "moe_intermediate_size", None)
-  padded_hidden_size = hidden_size
+  hf_config = (
+      vllm_config.model_config.hf_config.text_config
+      if hasattr(vllm_config.model_config.hf_config, "text_config")
+      else vllm_config.model_config.hf_config
+  )
+  hidden_size = getattr(hf_config, "moe_intermediate_size", None)
   num_lanes = pltpu.get_tpu_info().num_lanes
+  use_global_kv_heads = hasattr(hf_config, "num_global_key_value_heads")
+
+  # Global KV heads used by Gemma4
+  if use_global_kv_heads:
+    num_kv_heads = hf_config.num_global_key_value_heads
+  else:
+    num_kv_heads = hf_config.num_key_value_heads
 
   max_logging.log(
-      f"vLLM sharding config: hidden_size={hidden_size}, tp={tp}, "
+      f"vLLM sharding config: hidden_size={hidden_size}, num_kv_heads={num_kv_heads}, tp={tp}, "
       f"attn_dp={attn_dp}, ep={ep}, moe_mlp_tp_size={moe_mlp_tp_size}"
   )
 
   # Replicate the number of KV heads if its less than the total degree of model parallelism
-  if (
-      kv_tp_size % vllm_config.model_config.get_total_num_kv_heads() == 0
-      and vllm_config.model_config.get_total_num_kv_heads() < kv_tp_size
-  ):
+  if kv_tp_size % num_kv_heads == 0 and num_kv_heads < kv_tp_size:
     max_logging.log(
-        f"Padding num_kv_heads from {vllm_config.model_config.get_total_num_kv_heads()} "
-        f"to {kv_tp_size} to match the degree of tensor parallelism."
+        f"Padding num_kv_heads from {num_kv_heads} " f"to {kv_tp_size} to match the degree of tensor parallelism."
     )
-    overrides["base_num_kv_heads"] = kv_tp_size
+    if use_global_kv_heads:
+      overrides["global_num_kv_heads"] = kv_tp_size
+    else:
+      overrides["base_num_kv_heads"] = kv_tp_size
 
   # Pad the hidden size of MoE models if the MLP dimension is less than expected by the GMM_v2 kernel in tpu-inference.
   # The GMM_v2 kernel requires the MLP dimension per expert to be at least 2x the number of TPU lanes
   # to ensure efficient execution. See the validate_inputs() method in the following file for more details:
   # https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/kernels/megablox/gmm_v2.py
-  if hidden_size is not None and tp > 1 and (hidden_size // moe_mlp_tp_size) % (2 * num_lanes) != 0:
+  if hidden_size is not None and (hidden_size // moe_mlp_tp_size) % (2 * num_lanes) != 0:
+    padded_hidden_size = next_power_of_two(hidden_size)
     while (padded_hidden_size // moe_mlp_tp_size) < (2 * num_lanes):
       padded_hidden_size = next_power_of_two(padded_hidden_size)
 
@@ -219,6 +230,7 @@ def __call__(
 
     with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
       aux_hidden_states = []
+      expert_indices = None
       hidden, kv_caches = self.model(
          decoder_input_tokens=input_ids,
          decoder_positions=input_positions,
@@ -231,7 +243,7 @@ def __call__(
       # To be compatible with vLLM, we reshape to (batch * seq, dim).
       hidden = hidden.reshape((-1, hidden.shape[-1]))
 
-    return kv_caches, hidden, aux_hidden_states
+    return kv_caches, hidden, aux_hidden_states, expert_indices
 
   def forward(self, *args, **kwargs):
     """Alias for __call__ for compatibility.
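Note: the following is a minimal standalone sketch, not part of the patch, illustrating the GMM_v2 hidden-size padding rule that the adapter comments above describe. The numeric example values and the local next_power_of_two helper are assumptions for illustration only; the real adapter uses its own helper and reads moe_intermediate_size, moe_mlp_tp_size, and num_lanes from the model config and TPU info.

# Illustrative sketch of the GMM_v2 MoE hidden-size padding rule (values are assumptions).
def next_power_of_two(n: int) -> int:
  """Smallest power of two strictly greater than n (local stand-in for the adapter's helper)."""
  return 1 << n.bit_length()

def pad_moe_hidden_size(hidden_size: int, moe_mlp_tp_size: int, num_lanes: int) -> int:
  """Pad hidden_size until each expert's MLP shard is at least 2 * num_lanes wide."""
  if (hidden_size // moe_mlp_tp_size) % (2 * num_lanes) == 0:
    return hidden_size  # already satisfies the kernel's alignment requirement
  padded = next_power_of_two(hidden_size)
  while (padded // moe_mlp_tp_size) < (2 * num_lanes):
    padded = next_power_of_two(padded)
  return padded

# Example (assumed values): hidden_size=768 sharded over moe_mlp_tp_size=8 with num_lanes=128
# gives 768 // 8 = 96, which is below 2 * 128 = 256, so the size is padded to 2048
# (2048 // 8 = 256, meeting the per-expert minimum expected by the GMM_v2 kernel).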