2 changes: 1 addition & 1 deletion src/dependencies/extra_deps/post_train_base_deps.txt
@@ -1 +1 @@
-google-tunix @ https://github.com/google/tunix/archive/9d25e2647520b205ef43e5182ded030d20f3f52b.zip
+google-tunix @ https://github.com/google/tunix/archive/387072374f99a100cb11f99dec951940b1475a04.zip
4 changes: 2 additions & 2 deletions src/dependencies/extra_deps/post_train_github_deps.txt
@@ -1,3 +1,3 @@
 -r post_train_base_deps.txt
-tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/40876e81f04226f9b7b1e4bbdc9051d6b1364b9d.zip
-vllm @ git+https://github.com/vllm-project/vllm@595562651a5a4539ffa910d8570c08fb5169bdc9
+tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/39d9a9d38d3c96a7e1e57f9e693cf1c96a44e87d.zip
+vllm @ git+https://github.com/vllm-project/vllm@529c671e8075d265a48b72e0eaaeb5e30d2f1630
36 changes: 24 additions & 12 deletions src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py
@@ -103,31 +103,42 @@ def generate_maxtext_config(vllm_config: VllmConfig, mesh: Mesh) -> pyconfig.Hyp

   # Gather information on the hidden size of MoE models to determine if padding is needed
   # to meet MLP MoE requirements for tpu-inference GMM_v2 kernel.
-  hidden_size = getattr(vllm_config.model_config.hf_config, "moe_intermediate_size", None)
-  padded_hidden_size = hidden_size
+  hf_config = (
+      vllm_config.model_config.hf_config.text_config
+      if hasattr(vllm_config.model_config.hf_config, "text_config")
+      else vllm_config.model_config.hf_config
+  )
+  hidden_size = getattr(hf_config, "moe_intermediate_size", None)
   num_lanes = pltpu.get_tpu_info().num_lanes
+  use_global_kv_heads = hasattr(hf_config, "num_global_key_value_heads")
+
+  # Global KV heads used by Gemma4
+  if use_global_kv_heads:
+    num_kv_heads = hf_config.num_global_key_value_heads
+  else:
+    num_kv_heads = hf_config.num_key_value_heads
+
   max_logging.log(
-      f"vLLM sharding config: hidden_size={hidden_size}, tp={tp}, "
+      f"vLLM sharding config: hidden_size={hidden_size}, num_kv_heads={num_kv_heads}, tp={tp}, "
       f"attn_dp={attn_dp}, ep={ep}, moe_mlp_tp_size={moe_mlp_tp_size}"
   )

   # Replicate the number of KV heads if it's less than the total degree of model parallelism
-  if (
-      kv_tp_size % vllm_config.model_config.get_total_num_kv_heads() == 0
-      and vllm_config.model_config.get_total_num_kv_heads() < kv_tp_size
-  ):
+  if kv_tp_size % num_kv_heads == 0 and num_kv_heads < kv_tp_size:
     max_logging.log(
-        f"Padding num_kv_heads from {vllm_config.model_config.get_total_num_kv_heads()} "
-        f"to {kv_tp_size} to match the degree of tensor parallelism."
+        f"Padding num_kv_heads from {num_kv_heads} " f"to {kv_tp_size} to match the degree of tensor parallelism."
     )
-    overrides["base_num_kv_heads"] = kv_tp_size
+    if use_global_kv_heads:
+      overrides["global_num_kv_heads"] = kv_tp_size
+    else:
+      overrides["base_num_kv_heads"] = kv_tp_size

   # Pad the hidden size of MoE models if the MLP dimension is less than expected by the GMM_v2 kernel in tpu-inference.
   # The GMM_v2 kernel requires the MLP dimension per expert to be at least 2x the number of TPU lanes
   # to ensure efficient execution. See the validate_inputs() method in the following file for more details:
   # https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/kernels/megablox/gmm_v2.py
-  if hidden_size is not None and tp > 1 and (hidden_size // moe_mlp_tp_size) % (2 * num_lanes) != 0:
+  if hidden_size is not None and (hidden_size // moe_mlp_tp_size) % (2 * num_lanes) != 0:
     padded_hidden_size = next_power_of_two(hidden_size)
     while (padded_hidden_size // moe_mlp_tp_size) < (2 * num_lanes):
       padded_hidden_size = next_power_of_two(padded_hidden_size)
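To make the two padding rules in the hunk above concrete, here is a minimal standalone sketch (not code from this PR). It assumes `num_lanes` is 128, uses a strictly increasing `next_power_of_two` stand-in so the loop always makes progress, and the helper names `pad_num_kv_heads` and `pad_moe_hidden_size` are hypothetical, not the MaxText implementations.

```python
# Standalone sketch of the KV-head and MoE hidden-size padding rules (illustrative only).


def next_power_of_two(n: int) -> int:
  """Hypothetical stand-in: smallest power of two strictly greater than n."""
  return 1 << n.bit_length()


def pad_num_kv_heads(num_kv_heads: int, kv_tp_size: int) -> int:
  """Replicate KV heads up to the tensor-parallel degree when they divide it evenly."""
  if kv_tp_size % num_kv_heads == 0 and num_kv_heads < kv_tp_size:
    return kv_tp_size
  return num_kv_heads


def pad_moe_hidden_size(hidden_size: int, moe_mlp_tp_size: int, num_lanes: int = 128) -> int:
  """Grow the MoE intermediate size until each shard holds at least 2 * num_lanes columns."""
  padded = hidden_size
  if (padded // moe_mlp_tp_size) % (2 * num_lanes) != 0:
    padded = next_power_of_two(padded)
    while (padded // moe_mlp_tp_size) < (2 * num_lanes):
      padded = next_power_of_two(padded)
  return padded


if __name__ == "__main__":
  # 8 KV heads replicated to a tensor-parallel degree of 16.
  print(pad_num_kv_heads(num_kv_heads=8, kv_tp_size=16))  # -> 16
  # moe_intermediate_size=1408 sharded 8 ways gives 176 columns per shard,
  # below 2 * 128 = 256, so it is padded up to 2048 (256 per shard).
  print(pad_moe_hidden_size(hidden_size=1408, moe_mlp_tp_size=8))  # -> 2048
```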

@@ -219,6 +230,7 @@ def __call__(

     with self.mesh, nn.logical_axis_rules(self.maxtext_config.logical_axis_rules):
       aux_hidden_states = []
+      expert_indices = None
       hidden, kv_caches = self.model(
           decoder_input_tokens=input_ids,
           decoder_positions=input_positions,
@@ -231,7 +243,7 @@
       # To be compatible with vLLM, we reshape to (batch * seq, dim).
       hidden = hidden.reshape((-1, hidden.shape[-1]))

-      return kv_caches, hidden, aux_hidden_states
+      return kv_caches, hidden, aux_hidden_states, expert_indices

   def forward(self, *args, **kwargs):
     """Alias for __call__ for compatibility.
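Finally, a small self-contained sketch of the widened `__call__` contract from the adapter change above: callers now unpack a fourth `expert_indices` element, which this adapter currently always returns as `None`. The mock function and argument values below are illustrative only, not the tpu-inference calling convention.

```python
# Mocked caller-side view of the new 4-tuple return value (illustrative, not PR code).
def fake_adapter_call(input_ids, input_positions, kv_caches):
  # Stand-ins for the real outputs: hidden is already flattened to (batch * seq, dim) for vLLM.
  hidden = [[0.0] * 4 for _ in range(len(input_ids))]
  aux_hidden_states = []
  expert_indices = None  # always None in this adapter for now
  return kv_caches, hidden, aux_hidden_states, expert_indices


kv_caches, hidden, aux_hidden_states, expert_indices = fake_adapter_call(
    input_ids=[1, 2, 3], input_positions=[0, 1, 2], kv_caches=[]
)
assert expert_indices is None
```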