diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 7bfcd226c0..c1196963c2 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -219,6 +219,9 @@ expert_balance: False use_random_routing: false # whether to use random routing for debug/test purpose use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism +# If true, shard the MoE dispatch/MLP batch dim without 'expert' (activation_batch_no_exp) so the +# expert GEMM stays expert-parallel (AllToAll); false keeps 'expert' on it (activation_batch_moe). +moe_dispatch_no_expert_sharding: false use_ragged_sort: false # whether to use the Pallas ragged-sort kernels in the MoE permute path; valid both with and # without `use_ring_of_experts` (with EP > 1). When `use_ring_of_experts=True` the kernels run # inside `permute`/`unpermute`; otherwise they run inside `local_permute`/local-unpermute. @@ -527,6 +530,8 @@ logical_axis_rules: [ # ========================================== # MoE Activations ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + # Batch axis WITHOUT 'expert' so MoE dispatch/mlp GEMM stays expert-parallel (AllToAll), not FSDP. 
+ ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], diff --git a/src/maxtext/configs/models/mixtral-8x22b.yml b/src/maxtext/configs/models/mixtral-8x22b.yml index b6d1fc71c0..778882a5c0 100644 --- a/src/maxtext/configs/models/mixtral-8x22b.yml +++ b/src/maxtext/configs/models/mixtral-8x22b.yml @@ -31,3 +31,6 @@ num_experts: 8 num_experts_per_tok: 2 rope_max_timescale: 1_000_000 decoder_block: "mixtral" +# Few large experts: keep the MoE dispatch/MLP GEMM expert-parallel (AllToAll) rather than +# letting the batch dim double-map onto 'expert' and fall back to FSDP-style collectives. +moe_dispatch_no_expert_sharding: true diff --git a/src/maxtext/configs/models/mixtral-8x7b.yml b/src/maxtext/configs/models/mixtral-8x7b.yml index 9528a667c5..211030cec8 100644 --- a/src/maxtext/configs/models/mixtral-8x7b.yml +++ b/src/maxtext/configs/models/mixtral-8x7b.yml @@ -31,3 +31,6 @@ num_experts: 8 num_experts_per_tok: 2 rope_max_timescale: 1_000_000 decoder_block: "mixtral" +# Few large experts: keep the MoE dispatch/MLP GEMM expert-parallel (AllToAll) rather than +# letting the batch dim double-map onto 'expert' and fall back to FSDP-style collectives. +moe_dispatch_no_expert_sharding: true diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 00d77f071e..009f2c4f29 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -737,6 +737,13 @@ class MoEGeneral(BaseModel): False, description="Whether to use Ring of Experts for sparse matmul expert parallelism.", ) + moe_dispatch_no_expert_sharding: bool = Field( + False, + description=( + "If true, shard the MoE dispatch/MLP batch dim without 'expert' so the expert GEMM " + "stays expert-parallel (AllToAll); false keeps 'expert' on it (activation_batch_moe)." 
+ ), + ) use_ragged_sort: bool = Field( False, description="Whether to use ragged kernel for sorting, improve performance when EP is enabled." ) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 01f248244d..ae1dd9f70a 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -2197,15 +2197,20 @@ def dense_matmul( top_k_indices, weights # pylint: disable=undefined-variable,possibly-used-before-assignment ) mask_axes = ("activation_batch_moe", "activation_norm_length_moe", None, None) + # Dispatch/MLP are already expert-sharded via "activation_exp"; with no_expert_sharding the + # batch dim drops 'expert' so the GEMM stays expert-parallel (AllToAll) instead of FSDP. + moe_dispatch_batch_axis = ( + "activation_batch_no_exp" if self.config.moe_dispatch_no_expert_sharding else "activation_batch_moe" + ) dispatch_axis = ( "activation_exp", - "activation_batch_moe", + moe_dispatch_batch_axis, None, "activation_embed_moe", ) mlp_axis = ( "activation_exp", - "activation_batch_moe", + moe_dispatch_batch_axis, None, "activation_mlp", ) diff --git a/tests/unit/moe_test.py b/tests/unit/moe_test.py index 51b7d1ba3e..9bd0000076 100644 --- a/tests/unit/moe_test.py +++ b/tests/unit/moe_test.py @@ -1534,5 +1534,38 @@ def test_prefused_vs_sparse_softmax(self): self.assertIsNone(bias_updates) +@pytest.mark.parametrize("model_name", ["mixtral-8x7b", "mixtral-8x22b", "deepseek3-671b"]) +def test_moe_dispatch_keeps_expert_on_expert_dim(model_name): + """Regression guard for the MoE dispatch/MLP expert-parallel sharding. + + When moe_dispatch_no_expert_sharding is set, the expert (E) dim must be sharded by + the 'expert' mesh axis and the batch (B) dim must NOT also take 'expert' (which + would double-map two tensor dims onto one mesh axis and force an FSDP-style + fallback instead of expert-parallel AllToAll). Models that leave the flag off are + not asserted on. Mirrors dense_matmul's axis selection. 
+ """ + cfg = pyconfig.initialize( + [None, get_test_config_path()], + run_name=f"moe_shard_{model_name}", + enable_checkpointing=False, + model_name=model_name, + ) + rules = cfg.logical_axis_rules + batch_axis = "activation_batch_no_exp" if cfg.moe_dispatch_no_expert_sharding else "activation_batch_moe" + spec = nn_partitioning.logical_to_mesh_axes(("activation_exp", batch_axis, None, "activation_embed_moe"), rules=rules) + + def _as_set(entry): + if entry is None: + return set() + return {entry} if isinstance(entry, str) else set(entry) + + e_axes, b_axes = _as_set(spec[0]), _as_set(spec[1]) + if cfg.moe_dispatch_no_expert_sharding: + # Flag set: the expert dim must be sharded by 'expert' and the batch dim must not steal it. + assert "expert" in e_axes, "expert dim must be sharded by the 'expert' mesh axis" + assert "expert" not in b_axes, "with moe_dispatch_no_expert_sharding the batch dim must not take 'expert'" + # else: model keeps the default activation_batch_moe; don't assert the (intentional) collision. + + if __name__ == "__main__": unittest.main()