diff --git a/docs/reference/core_concepts/moe_configuration.md b/docs/reference/core_concepts/moe_configuration.md index 96b3bbe65e..73114f868a 100644 --- a/docs/reference/core_concepts/moe_configuration.md +++ b/docs/reference/core_concepts/moe_configuration.md @@ -97,6 +97,8 @@ Dropping: `mlp_bias`: If enabled, add learnable bias terms for MLP matmul. Originally implemented to support the GPT-OSS model architecture. +`fused_mlp`: If enabled alongside `sparse_matmul=True`, fuses the two FFN1 grouped GEMMs (wi\_0 and wi\_1) into a single grouped GEMM call. Expert weights are stored in a concatenated `(num_experts, embed_dim, 2 * mlp_dim)` shape, so input activations are loaded from HBM once per forward pass instead of twice. This is analogous to `fused_mlp` for dense models and is backend-agnostic (works with Megablox, JAX Ragged Dot, and Tokamax). + `use_batch_split_schedule` (experimental): If enabled, split batch into micro-batches to hide communications that yields performance benefits. ## 2. Sharding diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 7abe0782f3..b3203549e5 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -265,7 +265,6 @@ def quant_dot_general(self) -> nnx_wrappers.ToNNX | None: return getattr(self, self._quant_dot_general_name) def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.Array, Optional[jax.Array]]: - inputs = jnp.asarray(inputs, self.dtype) norm_axis = linears.normalize_axes(self.axis, inputs.ndim) @@ -410,7 +409,7 @@ def __init__( self.wi_0 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim)) self.wi_1 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim)) self.wo = jnp.zeros((num_experts, intermediate_dim, self.moe_expert_input_dim)) - elif self.config.prefuse_moe_weights and self.config.attention == "vllm_rpa": + elif (self.config.prefuse_moe_weights and self.config.attention == "vllm_rpa") or self.config.fused_mlp: self.wi = nnx.Param( self.kernel_init( self.rngs.params(), @@ -1318,29 +1317,44 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): self.config.wo_tile_drhs_embed_dim, # Called n in megablox, and indeed is the RHS batch dim ) - layer_w0 = gmm_fn( - x, - w0, - tiling=wi_tile_size, - weight_gather_axes=wi_gather_axes, - ) - if self.get_tensor_transpose_parallelism_size() > 1: - layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose") - if self.config.mlp_bias: - layer_w0 = layer_w0 + w0_bias - layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0") - - layer_w1 = gmm_fn( - x, - w1, - tiling=wi_tile_size, - weight_gather_axes=wi_gather_axes, - ) - if self.get_tensor_transpose_parallelism_size() > 1: - layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose") - if self.config.mlp_bias: - layer_w1 = layer_w1 + w1_bias - layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1") + if self.config.fused_mlp: + # Weights are stored as (G,K,2N); w0/w1 are adjacent slices so XLA elides this concat. 
+ w_fused = jnp.concatenate([w0, w1], axis=-1) + out = gmm_fn(x, w_fused, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes) + n = w0.shape[-1] + layer_w0, layer_w1 = out[:, :n], out[:, n:] + if self.get_tensor_transpose_parallelism_size() > 1: + layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose") + layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose") + if self.config.mlp_bias: + layer_w0 = layer_w0 + w0_bias + layer_w1 = layer_w1 + w1_bias + layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0") + layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1") + else: + layer_w0 = gmm_fn( + x, + w0, + tiling=wi_tile_size, + weight_gather_axes=wi_gather_axes, + ) + if self.get_tensor_transpose_parallelism_size() > 1: + layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose") + if self.config.mlp_bias: + layer_w0 = layer_w0 + w0_bias + layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0") + + layer_w1 = gmm_fn( + x, + w1, + tiling=wi_tile_size, + weight_gather_axes=wi_gather_axes, + ) + if self.get_tensor_transpose_parallelism_size() > 1: + layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose") + if self.config.mlp_bias: + layer_w1 = layer_w1 + w1_bias + layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1") intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1) intermediate_output = gmm_fn( @@ -2144,6 +2158,11 @@ def __call__( w1_kernel = None if cfg.prefuse_moe_weights and cfg.attention == "vllm_rpa": fused_kernel = jnp.asarray(self.wi[...], self.dtype) + elif cfg.fused_mlp: + wi = jnp.asarray(self.wi[...], self.dtype) + n = wi.shape[-1] // 2 + w0_kernel = wi[..., :n] + w1_kernel = wi[..., n:] else: w0_kernel = jnp.asarray(self.wi_0[...], self.dtype) w1_kernel = jnp.asarray(self.wi_1[...], self.dtype) diff --git a/tests/unit/moe_test.py b/tests/unit/moe_test.py index 10d13830c2..d498d39fd1 100644 --- a/tests/unit/moe_test.py +++ b/tests/unit/moe_test.py @@ -68,42 +68,38 @@ def setUp(self): def test_generate_masks(self): # expert_capacity = (tokens_per_batch / num_experts) * capacity_factor # expert_capacity_in_batch = (4 * 2 / 8) * 2 = 2 - top_k_indices = jnp.array( + top_k_indices = jnp.array([ + [[0, 5], [0, 4], [1, 0], [3, 5]], + [[1, 2], [4, 1], [5, 0], [7, 1]], + [[6, 2], [2, 3], [4, 2], [1, 2]], + [[4, 1], [0, 7], [5, 0], [4, 7]], + ]) + softmax_probs = jnp.array([ [ - [[0, 5], [0, 4], [1, 0], [3, 5]], - [[1, 2], [4, 1], [5, 0], [7, 1]], - [[6, 2], [2, 3], [4, 2], [1, 2]], - [[4, 1], [0, 7], [5, 0], [4, 7]], - ] - ) - softmax_probs = jnp.array( + [0.20, 0, 0, 0, 0, 0.80, 0, 0], + [0.68, 0, 0, 0, 0.32, 0, 0, 0], + [0.22, 0.78, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0.32, 0, 0.68, 0, 0], + ], [ - [ - [0.20, 0, 0, 0, 0, 0.80, 0, 0], - [0.68, 0, 0, 0, 0.32, 0, 0, 0], - [0.22, 0.78, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0.32, 0, 0.68, 0, 0], - ], - [ - [0, 0.26, 0.74, 0, 0, 0, 0, 0], - [0, 0.79, 0, 0, 0.21, 0, 0, 0], - [0.89, 0, 0, 0, 0, 0.11, 0, 0], - [0, 0.11, 0, 0, 0, 0, 0, 0.89], - ], - [ - [0, 0, 0.26, 0, 0, 0, 0.74, 0], - [0, 0, 0.88, 0.12, 0, 0, 0, 0], - [0, 0, 0.17, 0, 0.83, 0, 0, 0], - [0, 0.35, 0.65, 0, 0, 0, 0, 0], - ], - [ - [0, 0.47, 0, 0, 0.53, 0, 0, 0], - [0.36, 0, 0, 0, 0, 0, 0, 0.64], - [0.15, 0, 0, 0, 0, 0.85, 0, 0], - [0, 0, 0, 0, 0.18, 0, 0, 0.82], - ], - ] - ) + [0, 0.26, 0.74, 0, 0, 0, 0, 0], + [0, 0.79, 0, 0, 0.21, 0, 0, 0], + [0.89, 0, 0, 0, 0, 0.11, 0, 0], + [0, 0.11, 0, 0, 0, 0, 0, 0.89], + ], + [ + [0, 0, 0.26, 0, 0, 0, 0.74, 0], + [0, 0, 0.88, 0.12, 0, 0, 0, 0], + [0, 0, 0.17, 0, 0.83, 0, 0, 0], + [0, 0.35, 0.65, 0, 0, 0, 0, 0], + ], + [ + [0, 0.47, 0, 0, 0.53, 
0, 0, 0], + [0.36, 0, 0, 0, 0, 0, 0, 0.64], + [0.15, 0, 0, 0, 0, 0.85, 0, 0], + [0, 0, 0, 0, 0.18, 0, 0, 0.82], + ], + ]) # As expert_capacity_in_batch=2, so updated softmax_probs become (4 tokens were dropped): # softmax_probs = jnp.array([[[0.20, 0, 0, 0, 0, 0.80, 0, 0], @@ -238,14 +234,10 @@ def setUp(self): def test_deepseek_routing(self): # shape as [batch, sequence, num_experts] = [1,2,16] - gate_logits = jnp.array( - [ - [ - [0.20, 0.10, 0.05, 0.10, 0.10, 0.60, 0.30, 0.10, 0.80, 0.01, 0.01, 0.01, 0.05, 0.80, 0.20, 0.10], - [0.68, 0.20, 0.06, 0.03, 0.32, 0.10, 0.05, 0.02, 0.65, 0.20, 0.04, 0.01, 0.32, 0.10, 0.05, 0.02], - ] - ] - ) + gate_logits = jnp.array([[ + [0.20, 0.10, 0.05, 0.10, 0.10, 0.60, 0.30, 0.10, 0.80, 0.01, 0.01, 0.01, 0.05, 0.80, 0.20, 0.10], + [0.68, 0.20, 0.06, 0.03, 0.32, 0.10, 0.05, 0.02, 0.65, 0.20, 0.04, 0.01, 0.32, 0.10, 0.05, 0.02], + ]]) pre_bias_logits = gate_logits - 0.5 # 4 groups of 1st token: @@ -1402,5 +1394,73 @@ def test_prefused_vs_sparse_softmax(self): self.assertIsNone(bias_updates) +@pytest.mark.tpu_only +class FusedMlpMoETest(unittest.TestCase): + """Tests that fused_mlp=True and fused_mlp=False produce identical outputs for MoE.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._B = 1 + self._S = 16 + + def setUp(self): + super().setUp() + self.rng = jax.random.PRNGKey(0) + extra_args = get_decoupled_parallelism_overrides() + self.ref_cfg = pyconfig.initialize( + [None, get_test_config_path()], + run_name="fused_mlp_moe_ref", + enable_checkpointing=False, + model_name="mixtral-8x7b", + dtype="bfloat16", + sparse_matmul=True, + megablox=True, + fused_mlp=False, + ici_expert_parallelism=jax.device_count(), + max_target_length=self._S, + per_device_batch_size=self._B, + **extra_args, + ) + ref_devices = maxtext_utils.create_device_mesh(self.ref_cfg) + self.ref_mesh = Mesh(ref_devices, self.ref_cfg.mesh_axes) + self.ref_model = make_moe(self.ref_cfg, self.ref_mesh) + + def _inputs(self): + return jax.random.normal(self.rng, (self._B, self._S, self.ref_cfg.base_emb_dim), dtype=jnp.bfloat16) + + def test_fused_mlp_matches_unfused(self): + """fused_mlp=True output matches fused_mlp=False with sparse_matmul (Megablox).""" + extra_args = get_decoupled_parallelism_overrides() + fused_cfg = pyconfig.initialize( + [None, get_test_config_path()], + run_name="fused_mlp_moe_fused", + enable_checkpointing=False, + model_name="mixtral-8x7b", + dtype="bfloat16", + sparse_matmul=True, + megablox=True, + fused_mlp=True, + ici_expert_parallelism=jax.device_count(), + max_target_length=self._S, + per_device_batch_size=self._B, + **extra_args, + ) + fused_devices = maxtext_utils.create_device_mesh(fused_cfg) + fused_mesh = Mesh(fused_devices, fused_cfg.mesh_axes) + fused_model = make_moe(fused_cfg, fused_mesh) + copy_weights_prefused(self.ref_model, fused_model) + + inputs = self._inputs() + ref_out, _, _ = self.ref_model(inputs) + fused_out, _, _ = fused_model(inputs) + + np.testing.assert_allclose( + np.array(ref_out, dtype=np.float32), + np.array(fused_out, dtype=np.float32), + rtol=1e-2, + atol=1e-2, + ) + + if __name__ == "__main__": unittest.main()
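
Reviewer note (not part of the patch): below is a minimal standalone sketch, under assumed shapes and names, of the equivalence the fused path relies on: concatenating `wi_0` and `wi_1` along the last axis, running a single matmul, and splitting the output column-wise gives the same result as two separate matmuls. It uses a plain per-expert `jnp.einsum` as a stand-in for the grouped GEMM (`gmm_fn`/Megablox in the patch), so it illustrates only the math, not the kernel or sharding behavior.

# Illustrative sketch only; shapes, names, and the einsum stand-in are
# assumptions for the example, not code from the patch.
import jax
import jax.numpy as jnp

E, K, N, T = 4, 8, 16, 32  # experts, embed dim, mlp dim, routed tokens
k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k0, (T, K))        # token activations after routing
wi_0 = jax.random.normal(k1, (E, K, N))  # FFN1 "gate" weights per expert
wi_1 = jax.random.normal(k2, (E, K, N))  # FFN1 "up" weights per expert
group = jnp.repeat(jnp.arange(E), T // E)  # expert assignment per token

def per_expert_matmul(x, w, group):
  # Stand-in for the grouped GEMM: each token row uses its expert's weights.
  return jnp.einsum("tk,tkn->tn", x, w[group])

# Unfused path: two matmuls reading the same activations twice.
layer_w0 = per_expert_matmul(x, wi_0, group)
layer_w1 = per_expert_matmul(x, wi_1, group)

# Fused path: one matmul against (E, K, 2N), then a column-wise split.
wi_fused = jnp.concatenate([wi_0, wi_1], axis=-1)
out = per_expert_matmul(x, wi_fused, group)
n = wi_0.shape[-1]
fused_w0, fused_w1 = out[:, :n], out[:, n:]

assert jnp.allclose(layer_w0, fused_w0)
assert jnp.allclose(layer_w1, fused_w1)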