From bb75291c09e7eaa29adf782f1c60add1368640d8 Mon Sep 17 00:00:00 2001 From: Param Bole Date: Wed, 17 Jun 2026 18:10:16 +0000 Subject: [PATCH 1/2] feat(models): Integrate DeepSeek V4 architecture and routing --- .../utils/hf_model_configs.py | 1 + src/maxtext/common/common_types.py | 1 + src/maxtext/configs/models/deepseek4-284b.yml | 64 ++++ src/maxtext/configs/types.py | 6 +- src/maxtext/layers/attention_compressed.py | 37 +-- src/maxtext/layers/attentions.py | 1 + src/maxtext/layers/decoders.py | 110 ++++++- src/maxtext/layers/embeddings.py | 17 +- src/maxtext/layers/moe.py | 19 +- src/maxtext/models/deepseek.py | 68 ++--- src/maxtext/models/deepseek4.py | 274 ++++++++++++++++++ src/maxtext/utils/globals.py | 1 + tests/unit/deepseek_v4_vs_reference_test.py | 39 +-- tests/unit/train_compile_test.py | 20 ++ 14 files changed, 579 insertions(+), 79 deletions(-) create mode 100644 src/maxtext/configs/models/deepseek4-284b.yml create mode 100644 src/maxtext/models/deepseek4.py diff --git a/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py b/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py index 371bb24ce4..ecf9fd4b36 100644 --- a/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py +++ b/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py @@ -1016,6 +1016,7 @@ class DeepseekV32Config(PTConfig): def __init__(self, **kwargs): self.max_position_embeddings = kwargs.get("max_position_embeddings", 163840) + self.rope_scaling = kwargs.pop("rope_scaling", None) super().__init__(**kwargs) diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index d4b52207fc..71dbc105d4 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -113,6 +113,7 @@ class DecoderBlockType(enum.Enum): SIMPLE_MLP = "simple_mlp" LLAMA4 = "llama4" OLMO3 = "olmo3" + DEEPSEEK4 = "deepseek4" class AttentionType(enum.Enum): diff --git a/src/maxtext/configs/models/deepseek4-284b.yml 
b/src/maxtext/configs/models/deepseek4-284b.yml new file mode 100644 index 0000000000..5ba2dd062f --- /dev/null +++ b/src/maxtext/configs/models/deepseek4-284b.yml @@ -0,0 +1,64 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Model config for DeepSeek-V4-Flash 284B (https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash) + +base_emb_dim: 4096 +base_num_query_heads: 64 +base_num_kv_heads: 1 +base_num_decoder_layers: 43 +base_mlp_dim: 2048 +base_moe_mlp_dim: 2048 +vocab_size: 129280 +head_dim: 512 + +# --- Standard Defaults --- +enable_dropout: false +logits_via_embedding: false +normalization_layer_epsilon: 1.0e-6 + +# --- V4 Specific Architectural Keys --- +decoder_block: "deepseek4" +mhc_expansion_rate: 4 +first_num_hash_layers: 3 +indexer_head_dim: 128 +indexer_n_heads: 64 +indexer_topk: 512 + +# Note: Layers (0, 1, 2) are prefix layers. +# The 44th layer (MTP module with compress_ratio=0) has been explicitly dropped for now. +# This leaves exactly 43 layers: 3 prefix [0,0,4] + 40 scanned. 
+compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4] + +# --- MoE configuration --- +mlp_activations: ["silu", "linear"] +num_experts: 256 +num_experts_per_tok: 6 +mlp_activations_limit: 10 +shared_experts: 1 +routed_score_func: "sqrtsoftplus" + +# --- Attention configuration --- +attention_type: 'compressed' +q_lora_rank: 1024 +o_groups: 8 +o_lora_rank: 1024 +sliding_window_size: 128 + +# --- RoPE --- + +rope_type: "default" +rope_max_timescale: 10000 # Main RoPE theta +compressed_rope_max_timescale: 160000 # Compressed RoPE theta +max_position_embeddings: 1048576 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index e43f34f247..d1f293aae8 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -227,7 +227,7 @@ class ProfilerType(str, Enum): "deepseek3-test", "deepseek3-tiny", "deepseek3.2-671b", - "deepseek4", + "deepseek4-284b", "deepseek-custom", "kimi-k2-1t", "gemma-7b", @@ -553,7 +553,7 @@ class Attention(BaseModel): "autoselected", description="The attention algorithm to use (dot_product, flash, etc).", ) - attention_type: Literal["global", "local_sliding", "chunk", "mla", "full"] = Field( + attention_type: Literal["global", "local_sliding", "chunk", "mla", "full", "compressed"] = Field( "global", description="The variant of attention to use." 
) share_kv_projections: bool = Field( @@ -2925,6 +2925,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de raise ValueError("`local_checkpoint_period` must be > 0 for emergency checkpointing.") if self.moba and self.attention not in ("dot_product"): raise ValueError("MoBA is only supported with dot_product attention.") + if self.decoder_block == DecoderBlockType.DEEPSEEK4 and self.attention != "dot_product": + raise ValueError("DeepSeek4 decoder block currently only supports dot_product attention.") if self.use_indexer: if self.q_lora_rank == 0: raise NotImplementedError("Sparse indexer has not implemented for q_lora_rank = 0.") diff --git a/src/maxtext/layers/attention_compressed.py b/src/maxtext/layers/attention_compressed.py index e9a25f46b5..391ec6cedd 100644 --- a/src/maxtext/layers/attention_compressed.py +++ b/src/maxtext/layers/attention_compressed.py @@ -680,24 +680,23 @@ def __init__( rngs: Optional[nnx.Rngs] = None, **kwargs, ): - """Initializes the CompressedAttention layer. + """Inherits all standard Attention hyperparameters and selectively instantiates + an underlying HCA or CSA compressor based on the provided `compress_ratio`. - Inherits all standard Attention hyperparameters and selectively instantiates - an underlying HCA or CSA compressor based on the provided `layer_type`. + Highlights of DeepSeek-V4 attention integration: + - Shared-KV: The layer supports decoupling Q and KV heads for heavy compression. + - MQA: Multi-Query Attention used alongside heavy KV compression. + - 3 Different Attention Modes: Sliding Window (prefix), HCA (128x), and CSA (4x). + - Dual RoPE Theta: Uses 10000 for standard uncompressed tokens and 160000 for compressed. Args: (See maxtext.layers.attentions.Attention for standard attention arguments) q_lora_rank: The rank for the LoRA projection in the compressed query. - compress_ratio: The compression ratio for the compressor. 
+ compress_ratio: The compression ratio (0, 4, or 128) for the compressor. """ - """Initializes the Compressed Attention module.""" self.q_lora_rank = q_lora_rank self.compress_ratio = compress_ratio - # Determine the correct underlying attention type based on the compress_ratio - if self.compress_ratio == 0: - attention_type = AttentionType.LOCAL_SLIDING - super().__init__( config=config, num_query_heads=num_query_heads, @@ -809,20 +808,22 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No rngs=self.rngs, ) - # DeepSeek-V4 uses a separate RoPE theta (160000) for compressed tokens. - # We must instantiate a dedicated rotary embedding for the compressors - self.compress_rotary_embedding = DeepSeekV4RotaryEmbedding( + # Override the base rotary embedding with the correct theta for this layer. + # CSA / HCA layers use compressed_rope_max_timescale (160000). + # Sliding window prefix layers use rope_max_timescale (10000). + rope_theta = self.config.compressed_rope_max_timescale if self.compress_ratio > 0 else self.config.rope_max_timescale + self.rotary_embedding = DeepSeekV4RotaryEmbedding( head_dim=self.config.head_dim, - partial_rotary_factor=1.0, - rope_theta=self.config.compressed_rope_max_timescale, - dtype=self.dtype, + partial_rotary_factor=self.config.qk_rope_head_dim / self.config.head_dim, + rope_theta=rope_theta, + fprop_dtype=self.dtype, ) if self.compress_ratio > 4: self.hca_compressor = DeepseekV4HCACompressor( config=self.config, compress_ratio=self.compress_ratio, - rotary_embedding=self.compress_rotary_embedding, + rotary_embedding=self.rotary_embedding, kernel_init=self.kernel_init, quant=self.quant, model_mode=self.model_mode, @@ -832,7 +833,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No self.csa_compressor = DeepseekV4CSACompressor( config=self.config, compress_ratio=self.compress_ratio, - rotary_embedding=self.compress_rotary_embedding, + 
rotary_embedding=self.rotary_embedding, kernel_init=self.kernel_init, quant=self.quant, model_mode=self.model_mode, @@ -1047,7 +1048,7 @@ def __call__( # -> [batch, q_length, emb_dim] final_out = self.o_b_proj(grouped_flat) - return final_out + return final_out, None def compressed_attention( diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 679c891360..ab7673d1d4 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -850,6 +850,7 @@ def init_rotary_embedding(self): shard_mode=self.config.shard_mode, rngs=self.rngs, ) + elif self.is_qwen3_hybrid: rotary_embedding = PartialRotaryEmbedding( min_timescale=self.config.rope_min_timescale, diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index b28b6dcb7a..0150c7b401 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -41,6 +41,7 @@ from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.models import ( deepseek, + deepseek4, deepseek_batchsplit, deepseek_batchsplit_fp8, gemma, @@ -467,6 +468,10 @@ def get_decoder_layers(self): deepseek.DeepSeekDenseLayerToLinen, deepseek.DeepSeekMoELayerToLinen, ] + case DecoderBlockType.DEEPSEEK4: + return ( + [deepseek4.DeepSeek4ScannableBlockToLinen] if self.config.scan_layers else [deepseek4.DeepSeek4LayerToLinen] + ) case DecoderBlockType.GEMMA: return [gemma.GemmaDecoderLayerToLinen] case DecoderBlockType.GEMMA2: @@ -632,6 +637,7 @@ def get_norm_layer(self, num_features: int): DecoderBlockType.MISTRAL, DecoderBlockType.MIXTRAL, DecoderBlockType.DEEPSEEK, + DecoderBlockType.DEEPSEEK4, DecoderBlockType.GEMMA, DecoderBlockType.GEMMA2, DecoderBlockType.GEMMA3, @@ -1061,6 +1067,17 @@ def __call__( previous_chunk, slot, ) + elif cfg.decoder_block == DecoderBlockType.DEEPSEEK4: + y = self._apply_deepseek4_scanned_blocks( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + slot, + 
decoder_input_tokens, + ) else: RemattedBlockLayer = RemattedBlockLayers[0] scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) @@ -1195,7 +1212,7 @@ def __call__( "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval), "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step), } - if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5): + if cfg.decoder_block in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5, DecoderBlockType.DEEPSEEK4): layer_kwargs = {"layer_idx": lyr} kv_cache = None if kv_caches is not None: @@ -1423,6 +1440,97 @@ def _apply_gemma4_scanned_blocks( return y + def _apply_deepseek4_scanned_blocks( + self, + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk, + slot, + decoder_input_tokens, + ): + """Applies DeepSeek V4 scanned decoder blocks. + + DeepSeek V4 has some number of prefix layers (defined by `first_num_hash_layers`) + that use static Hash Routing. The remaining layers alternate `compress_ratio=128` (HCA) + and `compress_ratio=4` (CSA) and are evaluated in a single `nn.scan` block. + + For DeepSeek4-Flash (43 hidden layers total): + - 3 Prefix layers (Indices 0, 1, 2) + - 40 Scanned layers: 20 perfectly repeating chunks of [128, 4] + """ + + cfg = self.config + mesh = self.mesh + + broadcast_args = ( + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + slot, + previous_chunk, + ) + + layer_call_kwargs = { + "previous_chunk": previous_chunk, + "slot": slot, + "decoder_input_tokens": decoder_input_tokens, + } + + # 1. Prefix Unrolling + # These layers use Hash Routing. 
+ num_hash_layers = cfg.first_num_hash_layers + for layer_idx in range(num_hash_layers): + prefix_layer = deepseek4.DeepSeek4LayerToLinen( + config=cfg, + mesh=mesh, + name=f"layers_{layer_idx}", + quant=self.quant, + model_mode=self.model_mode, + layer_idx=layer_idx, + ) + y, _ = prefix_layer( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + **layer_call_kwargs, + ) + + # 2. Chunked Scanning + # Assumes whole [HCA (128), CSA (4)] pairs remain; an odd trailing layer would be dropped by `// 2`. + num_remaining_layers = cfg.num_decoder_layers - num_hash_layers + num_full_blocks = num_remaining_layers // 2 + + if num_full_blocks > 0: + ScannableBlockToLinen = deepseek4.DeepSeek4ScannableBlockToLinen + policy = self.get_remat_policy() + RemattedDeepSeek4Block = self.set_remat_policy([ScannableBlockToLinen], policy)[0] + + y, _ = nn.scan( + RemattedDeepSeek4Block, + variable_axes={ + "params": cfg.param_scan_axis, + "cache": 0, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, + split_rngs={"params": True, "dropout": cfg.enable_dropout}, + in_axes=(nn.broadcast,) * len(broadcast_args), + length=num_full_blocks, + metadata_params={ + nn.PARTITION_NAME: "layers", + "abstract_init": False, + }, + )(config=cfg, mesh=mesh, quant=self.quant, model_mode=model_mode, name="scanned_blocks",)(y, *broadcast_args) + + return y + def _apply_gemma4_small_layers( self, y, diff --git a/src/maxtext/layers/embeddings.py b/src/maxtext/layers/embeddings.py index 86b6723bd5..ad6b171f2f 100644 --- a/src/maxtext/layers/embeddings.py +++ b/src/maxtext/layers/embeddings.py @@ -1803,7 +1803,7 @@ def qwen3_omni_mrope_embedding_as_linen( ) -class DeepSeekV4RotaryEmbedding(nnx.Module): +class DeepSeekV4RotaryEmbedding(RotaryEmbedding): """DeepSeek-V4 partial rotary embedding with interleaved frequencies. 
DeepSeek-V4 uses an interleaved positional encoding where consecutive channels @@ -1822,12 +1822,23 @@ def __init__( head_dim: int, partial_rotary_factor: float = 64.0 / 512.0, rope_theta: float = 10000.0, - dtype: Any = jnp.float32, + fprop_dtype: Any = jnp.float32, + min_timescale: int = 10000, + max_timescale: int = 10000, + mesh: Any = None, + **kwargs, ): + super().__init__( + min_timescale=min_timescale, + max_timescale=max_timescale, + mesh=mesh, + fprop_dtype=fprop_dtype, + **kwargs, + ) self.head_dim = head_dim self.partial_rotary_factor = partial_rotary_factor self.rope_theta = rope_theta - self.dtype = dtype + self.fprop_dtype = fprop_dtype # Compute the partial rotary dimension (rope_head_dim) self.dim = int(head_dim * partial_rotary_factor) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 020956098c..4bb7cc7c08 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -208,6 +208,10 @@ def calculate_load_balance_updates(top_k_indices, num_experts, rate): return output +class Tid2EidVar(nnx.Variable): + """Custom variable to hold tid2eid without trainable param overhead.""" + + class GateLogit(nnx.Module): """A layer used to compute gate logits, allowing to return the pre bias values for DeepSeek routing.""" @@ -399,8 +403,11 @@ def __init__( # DeepSeek V4 Hash Routing if self.is_hash_routing: # Token-ID to Expert-ID lookup table for static routing - self.tid2eid = nnx.Variable( - jnp.zeros((self.config.vocab_size, self.num_experts_per_tok), dtype=jnp.int32), + # Must be stored as float32 because MaxText passes the entire variable tree + # through jax.value_and_grad, which strictly requires all leaves to be inexact types + # (even if they receive no gradients). We cast to int32 dynamically during routing. 
+ self.tid2eid = Tid2EidVar( + jnp.zeros((self.config.vocab_size, self.num_experts_per_tok), dtype=jnp.float32), out_sharding=None, # Replicated across shards for local lookup ) else: @@ -665,7 +672,13 @@ def get_topk(self, gate_logits, pre_bias_logits, rngs=None, input_ids=None): return top_k_weights, top_k_indices if self.is_hash_routing: - top_k_indices = self.tid2eid[input_ids] + if input_ids is None: + raise ValueError("input_ids cannot be None when is_hash_routing is True") + # Access the static routing table + tid2eid_int = self.tid2eid.value + # Cast the float32 array to int32 (JAX automatically assigns 0.0 gradients to integer casts) + tid2eid_int = tid2eid_int.astype(jnp.int32) + top_k_indices = tid2eid_int[input_ids] top_k_weights = jnp.take_along_axis(pre_bias_logits, top_k_indices, axis=-1) # NOTE: deepseek2 has a different pattern elif self.config.model_name.startswith(("deepseek3", "deepseek4")): diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index 27e1a6f7ad..d3a72b31bf 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -25,7 +25,7 @@ import jax.numpy as jnp from jax.sharding import Mesh from maxtext.common.common_types import Config -from maxtext.common.common_types import HyperConnectionType, MODEL_MODE_PREFILL +from maxtext.common.common_types import HyperConnectionType, MODEL_MODE_PREFILL, DecoderBlockType from maxtext.layers import attention_mla from maxtext.layers import initializers from maxtext.layers import linears @@ -138,37 +138,39 @@ def __init__( self.engram_layer_norm = None self.engram = None - self.self_attention = attention_mla.MLA( - config=self.config, - num_query_heads=self.config.num_query_heads, - num_kv_heads=self.config.num_kv_heads, - head_dim=self.config.head_dim, - max_target_length=self.config.max_target_length, - max_prefill_predict_length=self.config.max_prefill_predict_length, - attention_kernel=self.config.attention, - 
attention_type=self.config.attention_type, - inputs_q_shape=self.dummy_inputs_shape, - inputs_kv_shape=self.dummy_inputs_shape, - mesh=mesh, - dtype=self.config.dtype, - weight_dtype=self.config.weight_dtype, - dropout_rate=self.config.dropout_rate, - name="self_attention", - quant=quant, - kv_quant=quantizations.configure_kv_quant(config), - q_lora_rank=self.config.q_lora_rank, - kv_lora_rank=self.config.kv_lora_rank, - qk_nope_head_dim=self.config.qk_nope_head_dim, - qk_rope_head_dim=self.config.qk_rope_head_dim, - v_head_dim=self.config.v_head_dim, - max_position_embeddings=self.config.max_position_embeddings, - original_max_position_embeddings=self.config.original_max_position_embeddings, - mscale=self.config.mscale, - rope_factor=self.config.rope_factor, - model_mode=model_mode, - rngs=rngs, - attn_logits_soft_cap=self.config.attn_logits_soft_cap, - ) + # DeepSeek V4 natively overrides this block with CompressedAttention. + if self.config.decoder_block != DecoderBlockType.DEEPSEEK4: + self.self_attention = attention_mla.MLA( + config=self.config, + num_query_heads=self.config.num_query_heads, + num_kv_heads=self.config.num_kv_heads, + head_dim=self.config.head_dim, + max_target_length=self.config.max_target_length, + max_prefill_predict_length=self.config.max_prefill_predict_length, + attention_kernel=self.config.attention, + attention_type=self.config.attention_type, + inputs_q_shape=self.dummy_inputs_shape, + inputs_kv_shape=self.dummy_inputs_shape, + mesh=mesh, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + dropout_rate=self.config.dropout_rate, + name="self_attention", + quant=quant, + kv_quant=quantizations.configure_kv_quant(self.config), + q_lora_rank=self.config.q_lora_rank, + kv_lora_rank=self.config.kv_lora_rank, + qk_nope_head_dim=self.config.qk_nope_head_dim, + qk_rope_head_dim=self.config.qk_rope_head_dim, + v_head_dim=self.config.v_head_dim, + max_position_embeddings=self.config.max_position_embeddings, + 
original_max_position_embeddings=self.config.original_max_position_embeddings, + mscale=self.config.mscale, + rope_factor=self.config.rope_factor, + model_mode=model_mode, + rngs=rngs, + attn_logits_soft_cap=self.config.attn_logits_soft_cap, + ) self.dropout = Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs) if self.is_mhc_enabled: @@ -333,7 +335,7 @@ def __init__( rngs=self.rngs, ) - def mlp_op(self, x, deterministic): + def mlp_op(self, x, deterministic, *args, **kwargs): mlp = self.mlp(x, deterministic, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding) return self.with_logical_constraint(mlp) diff --git a/src/maxtext/models/deepseek4.py b/src/maxtext/models/deepseek4.py new file mode 100644 index 0000000000..12b0b83823 --- /dev/null +++ b/src/maxtext/models/deepseek4.py @@ -0,0 +1,274 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""DeepSeek-V4 model definition.""" + +from typing import Optional + +from flax import nnx +import flax.linen as nn +from jax.sharding import Mesh + +from maxtext.common.common_types import Config, AttentionType +from maxtext.common.common_types import HyperConnectionType +from maxtext.layers import attention_compressed +from maxtext.layers import initializers +from maxtext.layers import moe +from maxtext.layers import nnx_wrappers +from maxtext.layers import quantizations +from maxtext.models import deepseek +from jax.ad_checkpoint import checkpoint_name + + +class DeepSeek4DecoderLayer(deepseek.DeepSeekGenericLayer): + """DeepSeek-V4 specific decoder layer. + + Note: V4 does not utilize purely dense layers in the initial transformer blocks. + Every layer is a Sparse MoE layer (which internally contains shared dense experts). + + Args: + config: Configuration for the model. + model_mode: The mode of the model (e.g. 'train', 'inference'). + mesh: JAX sharding mesh. + rngs: NNX Rngs. + quant: Optional AQT quantization config. + layer_idx: The index of the layer. + compress_ratio: DeepSeek V4 specific parameter defining the KV cache compression + ratio. Expected values are 0 (no compression, sliding window), 4 (CSA), or 128 (HCA). + is_hash_routing: DeepSeek V4 specific parameter defining if this layer uses + static deterministic hash routing (used in prefix layers). + """ + + def __init__( + self, + config: Config, + model_mode: str, + mesh: Mesh, + rngs: nnx.Rngs, + quant: Optional[quantizations.AqtQuantization] = None, + layer_idx: int = -1, + compress_ratio: Optional[int] = None, + is_hash_routing: Optional[bool] = None, + ) -> None: + super().__init__( + config=config, + model_mode=model_mode, + mesh=mesh, + rngs=rngs, + quant=quant, + layer_idx=layer_idx, + ) + + # DeepSeek V4 applies Hash Routing to the first `config.first_num_hash_layers` layers. + # For the unscannable prefix layers, we can safely determine this using `layer_idx`. 
+ # However, the layers built by `DeepSeek4ScannableBlock` are constructed without a + # `layer_idx` (it keeps the default of -1), so the comparison below would wrongly + # enable hash routing for them. Since all scannable layers occur after the hash-routed + # prefix, the scannable block explicitly passes `is_hash_routing=False` instead. + if is_hash_routing is None: + is_hash_routing = layer_idx < config.first_num_hash_layers + self.mlp = moe.RoutedAndSharedMoE( + config=self.config, + mesh=self.mesh, + kernel_init=initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"), + kernel_axes=("embed", None), + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + quant=quant, + is_hash_routing=is_hash_routing, + rngs=rngs, + ) + + if compress_ratio is None: + compress_ratio = config.compress_ratios[layer_idx] + + # Route to LOCAL_SLIDING if compression is disabled for this layer, + # otherwise default to the globally configured attention type (e.g., COMPRESSED). + layer_attention_type = ( + AttentionType.LOCAL_SLIDING if compress_ratio == 0 else AttentionType(self.config.attention_type) + ) + + self.self_attention = attention_compressed.CompressedAttention( + config=self.config, + compress_ratio=compress_ratio, + num_query_heads=self.config.num_query_heads, + num_kv_heads=self.config.num_kv_heads, + head_dim=self.config.head_dim, + max_target_length=self.config.max_target_length, + max_prefill_predict_length=self.config.max_prefill_predict_length, + attention_kernel=self.config.attention, + attention_type=layer_attention_type, + inputs_q_shape=self.dummy_inputs_shape, + inputs_kv_shape=self.dummy_inputs_shape, + mesh=self.mesh, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + dropout_rate=self.config.dropout_rate, + sliding_window_size=self.config.sliding_window_size, + q_lora_rank=self.config.q_lora_rank, + name=f"compressed_attention_layer_{layer_idx}", + quant=quant, + kv_quant=quantizations.configure_kv_quant(config), + model_mode=model_mode, 
+ rngs=rngs, + ) + + # pylint: disable=arguments-differ + def mlp_op(self, inputs, deterministic, *args, **kwargs): + input_ids = kwargs.get("input_ids") + mlp_lnx, load_balance_loss, moe_bias_updates = self.mlp( + inputs=inputs, + input_ids=input_ids, + ) + return self.with_logical_constraint(mlp_lnx), load_balance_loss, moe_bias_updates + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=None, + slot: None | int = None, + kv_cache=None, + attention_metadata=None, + decoder_input_tokens=None, + ): + if isinstance(inputs, tuple): + inputs = inputs[0] + + x = self.with_logical_constraint(inputs) + x = checkpoint_name(x, "decoder_layer_input") + + _, intermediate_inputs = self.self_attention_with_norm_op( + x, + decoder_segment_ids, + decoder_positions, + deterministic, + previous_chunk, + slot, + ) + + layer_output, metadata = self.mhc_mlp( + self.post_attention_norm_op, + self.mlp_op, + x=intermediate_inputs, + mhc_type=HyperConnectionType.MLP_MOE, + deterministic=deterministic, + input_ids=decoder_input_tokens, + ) + load_balance_loss = metadata.get("load_balance_loss", None) + moe_bias_updates = metadata.get("moe_bias_updates", None) + + layer_output = self.dropout_op(layer_output, deterministic=deterministic) + return self.post_process(layer_output, load_balance_loss, moe_bias_updates, kv_cache) + + +class DeepSeek4ScannableBlock(nnx.Module): + """A scannable block containing exactly two DeepSeek V4 layers (HCA and CSA). + + DeepSeek V4 layers alternate `compress_ratio=128` (HCA) and `compress_ratio=4` (CSA) + throughout the middle of the network. This block encapsulates one full `[128, 4]` + cycle so it can be scanned as a single unit with Flax `nn.scan`. 
+ """ + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + rngs: nnx.Rngs, + quant: None | quantizations.AqtQuantization = None, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.rngs = rngs + + # Layer 0 in the block: HCA (compress_ratio=128) with Standard MoE (is_hash_routing=False) + self.layers_0 = DeepSeek4DecoderLayer( + config=self.config, + mesh=self.mesh, + model_mode=self.model_mode, + rngs=self.rngs, + quant=self.quant, + compress_ratio=128, + is_hash_routing=False, + ) + + # Layer 1 in the block: CSA (compress_ratio=4) with Standard MoE (is_hash_routing=False) + self.layers_1 = DeepSeek4DecoderLayer( + config=self.config, + mesh=self.mesh, + model_mode=self.model_mode, + rngs=self.rngs, + quant=self.quant, + compress_ratio=4, + is_hash_routing=False, + ) + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + slot=None, + previous_chunk=None, + attention_metadata=None, + kv_cache=None, + ): + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed")) + inputs = checkpoint_name(inputs, "decoder_layer_input") + y = inputs + + y, _ = self.layers_0( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + + y, _ = self.layers_1( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + slot=slot, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + + return y, None + + +DeepSeek4LayerToLinen = nnx_wrappers.to_linen_class( + DeepSeek4DecoderLayer, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) + +DeepSeek4ScannableBlockToLinen = nnx_wrappers.to_linen_class( + DeepSeek4ScannableBlock, + 
base_metadata_fn=initializers.variable_to_logically_partitioned, +) diff --git a/src/maxtext/utils/globals.py b/src/maxtext/utils/globals.py index e3b3aadf2d..48caa91ef1 100644 --- a/src/maxtext/utils/globals.py +++ b/src/maxtext/utils/globals.py @@ -75,6 +75,7 @@ "deepseek2-16b": "deepseek-ai/DeepSeek-V2-Lite", "deepseek3-671b": "deepseek-ai/DeepSeek-V3", "deepseek3.2-671b": "deepseek-ai/DeepSeek-V3.2", + "deepseek4": "deepseek-ai/DeepSeek-V4-Flash", "gpt-oss-20b": "openai/gpt-oss-20b", "gpt-oss-120b": "openai/gpt-oss-120b", "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct", diff --git a/tests/unit/deepseek_v4_vs_reference_test.py b/tests/unit/deepseek_v4_vs_reference_test.py index 1da95a184e..0b75aa9ff4 100644 --- a/tests/unit/deepseek_v4_vs_reference_test.py +++ b/tests/unit/deepseek_v4_vs_reference_test.py @@ -57,13 +57,13 @@ # Tests # ============================================================================== -# HuggingFace reference: https://huggingface.co/deepseek-ai/DeepSeek-V4/blob/main/modeling_deepseek_v4.py +# HuggingFace reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py # pylint: disable=line-too-long from jax.experimental import mesh_utils from jax.sharding import Mesh from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.configs import pyconfig from maxtext.layers.attention_compressed import CompressedAttention -from maxtext.layers.embeddings import DeepSeekV4RotaryEmbedding as MTRope + from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.models.deepseek_v4.modeling_deepseek_v4 import DeepseekV4Attention from transformers.models.deepseek_v4.modeling_deepseek_v4 import DeepseekV4RotaryEmbedding as PTRope @@ -75,7 +75,7 @@ class DeepSeekV4RotaryEmbeddingTest(unittest.TestCase): def setUp(self): self.batch_size = 2 - self.seq_len = 16 + self.seq_len = 4096 self.head_dim = 128 self.num_heads = 4 
self.main_rope_theta = 10000.0 @@ -408,6 +408,8 @@ def setUp(self): self.q_lora_rank = 32 self.o_groups = 2 self.o_lora_rank = 64 + self.qk_rope_head_dim = 64 + self.partial_rotary_factor = self.qk_rope_head_dim / self.head_dim self.rngs = nnx.Rngs(0) @@ -431,8 +433,12 @@ def setUp(self): layer_types=["sliding_attention"], num_hidden_layers=1, rope_parameters={ - "main": {"rope_type": "default", "rope_theta": 10000.0, "partial_rotary_factor": 1.0}, - "compress": {"rope_type": "default", "rope_theta": 160000.0, "partial_rotary_factor": 1.0}, + "main": {"rope_type": "default", "rope_theta": 10000.0, "partial_rotary_factor": self.partial_rotary_factor}, + "compress": { + "rope_type": "default", + "rope_theta": 160000.0, + "partial_rotary_factor": self.partial_rotary_factor, + }, }, sliding_window=2048, attention_dropout=0.0, @@ -524,9 +530,13 @@ def _run_e2e_test(self, layer_type, is_packed=False): "compressed_sparse_attention": self.pt_config.compress_rates["compressed_sparse_attention"], "heavily_compressed_attention": self.pt_config.compress_rates["heavily_compressed_attention"], } + compress_ratio = compress_ratio_map[layer_type] + layer_attention_type = AttentionType.LOCAL_SLIDING if compress_ratio == 0 else AttentionType.COMPRESSED + mt_attn = CompressedAttention( config=mt_config, - compress_ratio=compress_ratio_map[layer_type], + compress_ratio=compress_ratio, + attention_type=layer_attention_type, num_query_heads=self.num_heads, num_kv_heads=1, head_dim=self.head_dim, @@ -540,14 +550,6 @@ def _run_e2e_test(self, layer_type, is_packed=False): rngs=self.rngs, ) self.mt_attn = mt_attn - if layer_type == "sliding_attention": - rope_factor = self.pt_config.rope_parameters["main"]["partial_rotary_factor"] - mt_rope = MTRope(head_dim=self.head_dim, partial_rotary_factor=rope_factor, rope_theta=10000.0) - else: - rope_factor = self.pt_config.rope_parameters["compress"]["partial_rotary_factor"] - mt_rope = MTRope(head_dim=self.head_dim, 
partial_rotary_factor=rope_factor, rope_theta=160000.0) - - mt_attn.rotary_embedding = mt_rope # 3. Copy Weights self._copy_linear(mt_attn.wq_a, ref_attn.q_a_proj) @@ -652,8 +654,7 @@ def _run_e2e_test(self, layer_type, is_packed=False): print(f"top_k_indices mismatches: {num_mismatches}") # 6. Execute MaxText - - mt_out = mt_attn(x_mt, x_mt, segs_mt, pos_mt, deterministic=True, model_mode=MODEL_MODE_TRAIN) + mt_out, _ = mt_attn(x_mt, x_mt, segs_mt, pos_mt, deterministic=True, model_mode=MODEL_MODE_TRAIN) # 7. Asserts if not is_packed: @@ -771,7 +772,7 @@ def setUp(self): "vocab_size": self.vocab_size, "first_num_hash_layers": 3, "decoder_block": "deepseek", - "model_name": "deepseek4", + "model_name": "deepseek4-284b", "attention": "dot_product", "base_mlp_dim": 256, "base_moe_mlp_dim": 256, @@ -809,7 +810,7 @@ def test_hash_router(self): ) # Sync weights - mx_moe.tid2eid.value = jnp.array(pt_router.tid2eid.numpy()) + mx_moe.tid2eid.value = jnp.array(pt_router.tid2eid.numpy(), dtype=jnp.float32) mx_moe.gate.kernel.value = jnp.array(pt_router.weight.detach().numpy()).T hidden_states = torch.randn(self.batch_size, self.seq_len, self.hidden_dim) @@ -910,7 +911,7 @@ def test_swiglu_clamp(self): "topk_routing_group": 1, "mlp_activations_limit": limit, "decoder_block": "deepseek", - "model_name": "deepseek4", + "model_name": "deepseek4-284b", "attention": "dot_product", "base_mlp_dim": 256, "base_moe_mlp_dim": 256, diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index 1975ad1abf..41557c8c3c 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -804,6 +804,26 @@ def test_deepseek32(self): ) ) + def test_deepseek4(self): + # test deepseek4 compile + compiled_trainstep_file = "/tmp/test_deepseek4.pickle" + train_compile_main( + ( + "", + get_test_config_path(), + f"compiled_trainstep_file={compiled_trainstep_file}", + "compile_topology=v5p-256", + "use_iota_embed=true", + "compile_topology_num_slices=1", + 
"model_name=deepseek4-284b", + "per_device_batch_size=1", + "max_target_length=1024", + "attention=dot_product", + "dtype=bfloat16", + "weight_dtype=bfloat16", + ) + ) + @pytest.mark.cpu_only def test_indexer_dense_warmup(self): # test deepseek3.2 with sparse attention From 96b2fc1962d7c463a6c62bb6f74edfcd3d7deaca Mon Sep 17 00:00:00 2001 From: Dipak Gaikwad Date: Wed, 17 Jun 2026 20:00:50 +0000 Subject: [PATCH 2/2] Enabled auxiliary-loss-free load balancing and sequence-wise load balancing for DeepSeek. Tested by running training loop with new tiny DeepSeek V4 model added as part of the commit, here are the logs for testing Without load balancing active logs : https://paste.googleplex.com/6421399878107136 with load balancing logs : https://paste.googleplex.com/6551357300539392 Here are the results achieved for reducing the variance : 1 === DeepSeek V4 Load Balancing Variance Analysis (Step 0 vs Step 20) === 2 3 | Layer Index | Routing Type | Step 0 Var (Baseline) | Step 20 Var (Run A) | Step 20 Var (Run B) | Improvement (A vs B) | 4 |-------------|--------------|-----------------------|---------------------|---------------------|----------------------| 5 | 0 | Hash Routed | 3932160.00 | 3932160.00 | 3932160.00 | 0.00% | 6 | 1 | Hash Routed | 3932160.00 | 3932160.00 | 3932160.00 | 0.00% | 7 | 2 | Hash Routed | 3932160.00 | 3932160.00 | 3932160.00 | 0.00% | 8 | 3 | Top-K Routed | 7409.38 | 7509.25 | 3672.12 | 51.10% | 9 | 4 | Top-K Routed | 3158.38 | 3230.12 | 1216.00 | 62.35% | 10 | 5 | Top-K Routed | 5713.38 | 5772.75 | 2359.38 | 59.13% | 11 | 6 | Top-K Routed | 8295.25 | 8082.50 | 3674.12 | 54.54% | 12 | 7 | Top-K Routed | 4765.62 | 4614.62 | 1212.75 | 73.72% | 13 | 8 | Top-K Routed | 4960.75 | 4923.12 | 1663.50 | 66.21% | 14 | 9 | Top-K Routed | 3905.50 | 3816.25 | 1316.88 | 65.49% | 15 | 10 | Top-K Routed | 5057.00 | 4981.12 | 2257.75 | 54.67% | 16 | 11 | Top-K Routed | 10446.62 | 10381.62 | 5565.75 | 46.39% | 17 | 12 | Top-K Routed | 9538.50 | 9529.25 |
5319.12 | 44.18% | 18 | 13 | Top-K Routed | 7031.38 | 7131.25 | 3270.25 | 54.14% | 19 | 14 | Top-K Routed | 4852.00 | 4900.12 | 1906.88 | 61.09% | 20 | 15 | Top-K Routed | 9306.12 | 9342.88 | 4733.75 | 49.33% | 21 | 16 | Top-K Routed | 5811.25 | 5749.50 | 2110.88 | 63.29% | 22 | 17 | Top-K Routed | 6715.62 | 6874.25 | 2664.12 | 61.24% | 23 | 18 | Top-K Routed | 8145.50 | 7869.25 | 3383.75 | 57.00% | 24 | 19 | Top-K Routed | 6042.12 | 5908.62 | 2353.00 | 60.18% | 25 | 20 | Top-K Routed | 8559.88 | 8158.25 | 4333.38 | 46.88% | 26 | 21 | Top-K Routed | 11742.25 | 11943.62 | 7563.50 | 36.67% | 27 | 22 | Top-K Routed | 4959.62 | 5014.88 | 1998.62 | 60.15% | 28 | 23 | Top-K Routed | 7717.12 | 7751.88 | 3879.88 | 49.95% | 29 | 24 | Top-K Routed | 9017.75 | 9307.88 | 4702.75 | 49.48% | 30 | 25 | Top-K Routed | 14127.12 | 14111.25 | 8079.25 | 42.75% | 31 | 26 | Top-K Routed | 5074.25 | 5194.12 | 1675.50 | 67.74% | 32 | 27 | Top-K Routed | 11919.50 | 11204.38 | 6470.75 | 42.25% | 33 | 28 | Top-K Routed | 12241.75 | 12998.62 | 7624.12 | 41.35% | 34 | 29 | Top-K Routed | 9384.50 | 9005.00 | 5052.00 | 43.90% | 35 | 30 | Top-K Routed | 9698.62 | 9678.25 | 5231.75 | 45.94% | 36 | 31 | Top-K Routed | 12244.25 | 12392.75 | 7249.25 | 41.50% | 37 | 32 | Top-K Routed | 10030.00 | 9972.62 | 4755.50 | 52.31% | 38 | 33 | Top-K Routed | 7265.00 | 6973.62 | 3271.75 | 53.08% | 39 | 34 | Top-K Routed | 11945.50 | 11940.62 | 6076.88 | 49.11% | 40 | 35 | Top-K Routed | 12917.50 | 13740.00 | 7210.62 | 47.52% | 41 | 36 | Top-K Routed | 15011.62 | 15083.00 | 8870.62 | 41.19% | 42 | 37 | Top-K Routed | 10294.12 | 10176.25 | 5907.50 | 41.95% | 43 | 38 | Top-K Routed | 8928.62 | 9236.00 | 5136.62 | 44.38% | 44 | 39 | Top-K Routed | 15633.62 | 15171.00 | 9684.75 | 36.16% | 45 | 40 | Top-K Routed | 7687.75 | 7658.12 | 4521.25 | 40.96% | 46 | 41 | Top-K Routed | 12485.12 | 12270.38 | 6933.25 | 43.50% | 47 | 42 | Top-K Routed | 17641.25 | 17163.50 | 10974.12 | 36.06% | 48 
|-------------|--------------|-----------------------|---------------------|---------------------|----------------------| 49 | TOTAL/AVG | Top-K Only | 357681.12 | 356762.50 | 185883.62 | 47.90% | Raw data collected for this analysis: https://paste.googleplex.com/5060754624610304 https://paste.googleplex.com/5473518849490944 --- src/maxtext/configs/models/deepseek4-tiny.yml | 69 +++++++++++++++++++ src/maxtext/configs/types.py | 1 + src/maxtext/layers/quantizations.py | 2 +- src/maxtext/trainers/pre_train/train.py | 64 ++++++++++++----- tests/unit/deepseek_routed_bias_test.py | 65 +++++++++++++++++ 5 files changed, 183 insertions(+), 18 deletions(-) create mode 100644 src/maxtext/configs/models/deepseek4-tiny.yml create mode 100644 tests/unit/deepseek_routed_bias_test.py diff --git a/src/maxtext/configs/models/deepseek4-tiny.yml b/src/maxtext/configs/models/deepseek4-tiny.yml new file mode 100644 index 0000000000..881043777b --- /dev/null +++ b/src/maxtext/configs/models/deepseek4-tiny.yml @@ -0,0 +1,69 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Tiny model config for DeepSeek V4 for CPU execution and testing + +base_emb_dim: 64 +base_num_query_heads: 4 +base_num_kv_heads: 1 +base_num_decoder_layers: 43 +base_mlp_dim: 64 +base_moe_mlp_dim: 64 +vocab_size: 129280 +head_dim: 32 +qk_rope_head_dim: 32 + +# --- Standard Defaults --- +enable_dropout: false +logits_via_embedding: false +normalization_layer_epsilon: 1.0e-6 + +# --- V4 Specific Architectural Keys --- +decoder_block: "deepseek4" +mhc_expansion_rate: 4 +first_num_hash_layers: 3 +indexer_head_dim: 32 +indexer_n_heads: 4 +indexer_topk: 16 + +# Note: Layers (0,1) are not compressed. +# The 44th layer (MTP module with compress_ratio=0) has been explicitly dropped for now. +# This leaves exactly 43 layers: 2 prefix [0,0] + 40 scanned + 1 suffix [4]. +compress_ratios: [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4] + +# --- MoE configuration --- +mlp_activations: ["silu", "linear"] +num_experts: 16 +num_experts_per_tok: 4 +shared_experts: 1 +routed_score_func: "sqrtsoftplus" +routed_bias: true +routed_bias_update_rate: 0.001 +load_balance_loss_weight: 0.0001 +adamw_mask: [".*gate.*bias.*"] + +# --- Attention configuration --- +attention: 'dot_product' +attention_type: 'compressed' +q_lora_rank: 16 +o_groups: 4 +o_lora_rank: 16 +sliding_window_size: 32 + +# --- RoPE --- + +rope_type: "default" +rope_max_timescale: 10000 # Main RoPE theta +compressed_rope_max_timescale: 160000 # Compressed RoPE theta +max_position_embeddings: 4096 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index d1f293aae8..0d4e1cc174 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -228,6 +228,7 @@ class ProfilerType(str, Enum): "deepseek3-tiny", "deepseek3.2-671b", "deepseek4-284b", + "deepseek4-tiny", "deepseek-custom", "kimi-k2-1t", "gemma-7b", diff --git a/src/maxtext/layers/quantizations.py 
b/src/maxtext/layers/quantizations.py index 95bd79eb9f..86b61c7480 100644 --- a/src/maxtext/layers/quantizations.py +++ b/src/maxtext/layers/quantizations.py @@ -38,7 +38,7 @@ import qwix from qwix._src.core import dot_general_qt from qwix._src.core import sparsity -from qwix._src.utils import flax_util +from qwix._src import flax_util import qwix.pallas as qpl # Params used to define mixed precision quantization configs diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index fd2cc7b56c..0d6fe59b45 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -278,12 +278,6 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr else: max_logging.debug("\nNo MoE load balance loss found. Defaulting to 0.0.") - # get MoE routed bias term updates - moe_bias_updates = None - if config.routed_bias and config.routed_bias_update_rate > 0.0: - nested_key = ("intermediates", "decoder", "moe_layers", "moe_bias_updates") - moe_bias_updates = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, None) - # Add the model's primary output to the intermediates dict so it can be used # by the acceptance rate calculation in eval_step. 
intermediate_outputs["logits"] = logits @@ -295,7 +289,6 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr "total_weights": total_weights, "moe_lb_loss": moe_lb_loss, "indexer_loss": indexer_loss, - "moe_bias_updates": moe_bias_updates, "mtp_loss": mtp_loss, "batch_stats": (intermediate_outputs.get("batch_stats", None) if hasattr(intermediate_outputs, "get") else None), } @@ -421,9 +414,9 @@ def diff_wrapper(curr_params, custom_params, rest, config, data): moe_lb_loss = aux["moe_lb_loss"] indexer_loss = aux.get("indexer_loss", 0.0) z_loss = aux.get("z_loss", 0.0) - moe_bias_updates = aux.get("moe_bias_updates") mtp_loss = aux.get("mtp_loss", 0.0) new_opt_state = None + bias_metrics = {} if isinstance(model, nn.Module): if config.gradient_clipping_threshold > 0: @@ -480,12 +473,30 @@ def move(path, value): else: new_state = state.apply_gradients(grads=full_grads) - # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") - # Updates the shape to be aligned with state. - moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() - new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) + # Apply updates for Auxiliary-Loss-Free load balancing for the DeepSeek family. + # We dynamically traverse the PyTree to apply updates because the topology varies drastically: + # 1. DeepSeek V3 mixes dense layers (no bias updates) with MoE layers. + # 2. DeepSeek V4 introduces Hash Routing in early layers (which lack a learnable bias entirely). + # 3. DeepSeek V4 groups alternating attention topologies into nested `ScannableBlocks`. + # Dynamic traversal ensures we only target the correct `gate.bias` parameters without hardcoded, brittle paths. 
+ if config.routed_bias and config.routed_bias_update_rate > 0.0: + from flax import traverse_util + flat_intermediates = traverse_util.flatten_dict(aux.get("intermediate_outputs", {})) + flat_params = traverse_util.flatten_dict(new_state.params) + new_flat_params = dict(flat_params) + + for path, update in flat_intermediates.items(): + if path[-1] == "moe_bias_updates": + prefix = path[1:-1] if path[0] == "intermediates" else path[:-1] + for param_path in flat_params.keys(): + param_prefix = param_path[1:] if param_path[0] == "params" else param_path + if len(param_prefix) >= len(prefix) and param_prefix[:len(prefix)] == prefix and param_path[-2:] == ("gate", "bias"): + update_val = update[0] if isinstance(update, (tuple, list)) else update + bias_metrics[f"learning/moe_bias_before_norm_{'-'.join(map(str, param_path))}"] = jnp.linalg.norm(new_flat_params[param_path]) + new_flat_params[param_path] = new_flat_params[param_path] + jnp.array(update_val).transpose() + bias_metrics[f"learning/moe_bias_update_norm_{'-'.join(map(str, param_path))}"] = jnp.linalg.norm(jnp.array(update_val)) + + new_state = new_state.replace(params=traverse_util.unflatten_dict(new_flat_params)) else: if config.gradient_clipping_threshold > 0: grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) @@ -506,9 +517,27 @@ def move(path, value): new_state = state # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_bias = new_state.model.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias - target_bias.value = target_bias.value + jnp.array(moe_bias_updates[0]).transpose() + if config.routed_bias and config.routed_bias_update_rate > 0.0: + from flax import traverse_util + flat_intermediates = traverse_util.flatten_dict(aux.get("intermediate_outputs", {})) + jax.debug.print("FLAT_INTERMEDIATE_KEYS_NNX: {}", 
flat_intermediates.keys()) + for path, update in flat_intermediates.items(): + if path[-1] == "moe_bias_updates": + target = new_state.model + for key in path[:-1]: + if hasattr(target, key): + target = getattr(target, key) + elif isinstance(target, dict) and key in target: + target = target[key] + else: + break + else: + for _, node in nnx.iter_graph(target): + if type(node).__name__ == "GateLogit" and hasattr(node, "bias") and node.bias is not None: + update_val = update[0] if isinstance(update, (tuple, list)) else update + bias_metrics[f"learning/moe_bias_before_norm_{'-'.join(map(str, path[:-1]))}"] = jnp.linalg.norm(node.bias.value) + node.bias.value = node.bias.value + jnp.array(update_val).transpose() + bias_metrics[f"learning/moe_bias_update_norm_{'-'.join(map(str, path[:-1]))}"] = jnp.linalg.norm(jnp.array(update_val)) lm_loss = xent_sum / (total_weights + EPS) scalar_metrics = { @@ -521,6 +550,7 @@ def move(path, value): "learning/mtp_loss": mtp_loss, "learning/total_weights": total_weights, } + scalar_metrics.update(bias_metrics) if config.use_qk_clip: if isinstance(model, nn.Module): new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) diff --git a/tests/unit/deepseek_routed_bias_test.py b/tests/unit/deepseek_routed_bias_test.py new file mode 100644 index 0000000000..9e12a7da9a --- /dev/null +++ b/tests/unit/deepseek_routed_bias_test.py @@ -0,0 +1,65 @@ +import unittest +import jax +import jax.numpy as jnp +import optax +from flax.training import train_state +from maxtext.configs import pyconfig +from maxtext.models import models +from maxtext.trainers.pre_train import train as pre_train +class DeepSeekRoutedBiasTest(unittest.TestCase): + def setUp(self): + self.mesh = jax.sharding.Mesh(jax.devices(), ('data',)) + def _make_dummy_data(self, batch=1, seq=16): + return { + "inputs": jnp.zeros((batch, seq), dtype=jnp.int32), + "inputs_position": jnp.broadcast_to(jnp.arange(seq), (batch, seq)), + "inputs_segmentation": 
jnp.ones((batch, seq), dtype=jnp.int32), + "targets": jnp.zeros((batch, seq), dtype=jnp.int32), + "targets_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + } + def _create_and_run_train_step(self, config_args): + config = pyconfig.initialize(config_args) + rngs = jax.nnx.Rngs(0) if hasattr(jax, 'nnx') else __import__('flax.nnx', fromlist=['Rngs']).Rngs(0) + import flax.nnx as nnx + from maxtext.common import train_state_nnx + rngs = nnx.Rngs(0) + model = models.Transformer(config, self.mesh, quant=None, rngs=rngs) + data = self._make_dummy_data(batch=config.micro_batch_size_to_train_on, seq=config.max_target_length) + optimizer = nnx.Optimizer(model, optax.sgd(0.01), wrt=nnx.Param) + ts = train_state_nnx.TrainStateNNX(model, optimizer) + state_graphdef, state_pure = nnx.split(ts) + new_state, metrics = pre_train.train_step( + state_graphdef, config, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + return new_state, metrics + def test_deepseek_v3_dense_routed_bias_success(self): + """Proves that a DeepSeek V3 model with dense layers (no moe_layers attribute) + successfully traverses the state tree and updates routed bias without crashing. + """ + config_args = [ + "", + "src/maxtext/configs/base.yml", + "model_name=deepseek3-tiny", + "decoder_block=deepseek", + "num_decoder_layers=2", + "per_device_batch_size=1", + "max_target_length=16", + "routed_bias=True", + "routed_bias_update_rate=0.001", + "skip_jax_distributed_system=True", + "base_emb_dim=64", + "base_mlp_dim=64", + "base_moe_mlp_dim=64", + "base_num_query_heads=1", + "base_num_kv_heads=1", + "num_experts=2", + "num_experts_per_tok=2", + "first_num_dense_layers=1", + "sparse_matmul=False", + "override_model_config=True", + ] + new_state, metrics = self._create_and_run_train_step(config_args) + self.assertIsNotNone(new_state) + self.assertIn("learning/loss", metrics["scalar"]) +if __name__ == '__main__': + unittest.main()