From 36a85de47a8cfc9fe0ad32692749d1f5996e3999 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Tue, 16 Jun 2026 19:42:43 +0000 Subject: [PATCH 1/3] Make MoE dispatch/MLP expert-axis batch sharding configurable The dispatch/MLP MoE activations are already expert-sharded via activation_exp. Since #4007, their batch dim also maps to activation_batch_moe, which includes 'expert'. Under single-node expert parallelism (ici_expert_parallelism=-1) this double-maps two tensor dims onto the 'expert' mesh axis, so GSPMD falls back from expert-parallel AllToAll to FSDP-style AllGather+ReduceScatter, regressing throughput for few-large-expert models (e.g. Mixtral-8x7b: ~7.4k -> ~10.9k tok/s/device at bs=11 on 8x MI355X). Add a config flag moe_dispatch_no_expert_sharding (default false) that selects a new activation_batch_no_exp rule ([data, fsdp, fsdp_transpose], no 'expert') for the training dispatch/MLP batch axis. Enable it for mixtral-8x7b. Default-false keeps every other model and all TPU/non-EP paths byte-identical; the flag only changes sharding when the 'expert' mesh axis size > 1. --- src/maxtext/configs/base.yml | 5 +++++ src/maxtext/configs/models/mixtral-8x7b.yml | 3 +++ src/maxtext/configs/types.py | 7 +++++++ src/maxtext/layers/moe.py | 9 +++++++-- 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 7bfcd226c0..c1196963c2 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -219,6 +219,9 @@ expert_balance: False use_random_routing: false # whether to use random routing for debug/test purpose use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism +# If true, shard the MoE dispatch/MLP batch dim without 'expert' (activation_batch_no_exp) so the +# expert GEMM stays expert-parallel (AllToAll); false keeps 'expert' on it (activation_batch_moe). +moe_dispatch_no_expert_sharding: false use_ragged_sort: false # whether to use the Pallas ragged-sort kernels in the MoE permute path; valid both with and # without `use_ring_of_experts` (with EP > 1). When `use_ring_of_experts=True` the kernels run # inside `permute`/`unpermute`; otherwise they run inside `local_permute`/local-unpermute. @@ -527,6 +530,8 @@ logical_axis_rules: [ # ========================================== # MoE Activations ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + # Batch axis WITHOUT 'expert' so MoE dispatch/mlp GEMM stays expert-parallel (AllToAll), not FSDP. + ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], diff --git a/src/maxtext/configs/models/mixtral-8x7b.yml b/src/maxtext/configs/models/mixtral-8x7b.yml index 9528a667c5..211030cec8 100644 --- a/src/maxtext/configs/models/mixtral-8x7b.yml +++ b/src/maxtext/configs/models/mixtral-8x7b.yml @@ -31,3 +31,6 @@ num_experts: 8 num_experts_per_tok: 2 rope_max_timescale: 1_000_000 decoder_block: "mixtral" +# Few large experts: keep the MoE dispatch/MLP GEMM expert-parallel (AllToAll) rather than +# letting the batch dim double-map onto 'expert' and fall back to FSDP-style collectives. +moe_dispatch_no_expert_sharding: true diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 00d77f071e..009f2c4f29 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -737,6 +737,13 @@ class MoEGeneral(BaseModel): False, description="Whether to use Ring of Experts for sparse matmul expert parallelism.", ) + moe_dispatch_no_expert_sharding: bool = Field( + False, + description=( + "If true, shard the MoE dispatch/MLP batch dim without 'expert' so the expert GEMM " + "stays expert-parallel (AllToAll); false keeps 'expert' on it (activation_batch_moe)." + ), + ) use_ragged_sort: bool = Field( False, description="Whether to use ragged kernel for sorting, improve performance when EP is enabled." ) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 01f248244d..ae1dd9f70a 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -2197,15 +2197,20 @@ def dense_matmul( top_k_indices, weights # pylint: disable=undefined-variable,possibly-used-before-assignment ) mask_axes = ("activation_batch_moe", "activation_norm_length_moe", None, None) + # Dispatch/MLP are already expert-sharded via "activation_exp"; with no_expert_sharding the + # batch dim drops 'expert' so the GEMM stays expert-parallel (AllToAll) instead of FSDP. + moe_dispatch_batch_axis = ( + "activation_batch_no_exp" if self.config.moe_dispatch_no_expert_sharding else "activation_batch_moe" + ) dispatch_axis = ( "activation_exp", - "activation_batch_moe", + moe_dispatch_batch_axis, None, "activation_embed_moe", ) mlp_axis = ( "activation_exp", - "activation_batch_moe", + moe_dispatch_batch_axis, None, "activation_mlp", ) From 5bff15500e6e8d2da304ff161af91c5c1093f7ca Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Wed, 17 Jun 2026 20:11:20 +0000 Subject: [PATCH 2/3] Enable moe_dispatch_no_expert_sharding for mixtral-8x22b. Same 8-large-expert geometry as 8x7b, so it benefits from the same expert-parallel MoE dispatch/MLP sharding. --- src/maxtext/configs/models/mixtral-8x22b.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/maxtext/configs/models/mixtral-8x22b.yml b/src/maxtext/configs/models/mixtral-8x22b.yml index b6d1fc71c0..778882a5c0 100644 --- a/src/maxtext/configs/models/mixtral-8x22b.yml +++ b/src/maxtext/configs/models/mixtral-8x22b.yml @@ -31,3 +31,6 @@ num_experts: 8 num_experts_per_tok: 2 rope_max_timescale: 1_000_000 decoder_block: "mixtral" +# Few large experts: keep the MoE dispatch/MLP GEMM expert-parallel (AllToAll) rather than +# letting the batch dim double-map onto 'expert' and fall back to FSDP-style collectives. +moe_dispatch_no_expert_sharding: true From 9644edd1896d0c87e365d30e5f36250bc5c0e626 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Wed, 17 Jun 2026 20:11:52 +0000 Subject: [PATCH 3/3] Add MoE dispatch expert-sharding regression test. Asserts that with moe_dispatch_no_expert_sharding the expert dim is sharded by 'expert' and the batch dim is not, guarding the expert-parallel dispatch/MLP sharding. --- tests/unit/moe_test.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/unit/moe_test.py b/tests/unit/moe_test.py index 51b7d1ba3e..9bd0000076 100644 --- a/tests/unit/moe_test.py +++ b/tests/unit/moe_test.py @@ -1534,5 +1534,38 @@ def test_prefused_vs_sparse_softmax(self): self.assertIsNone(bias_updates) +@pytest.mark.parametrize("model_name", ["mixtral-8x7b", "mixtral-8x22b", "deepseek3-671b"]) +def test_moe_dispatch_keeps_expert_on_expert_dim(model_name): + """Regression guard for the MoE dispatch/MLP expert-parallel sharding. + + The expert (E) dim must always be sharded by the 'expert' mesh axis. When + moe_dispatch_no_expert_sharding is set, the batch (B) dim must NOT also take + 'expert' (which would double-map two tensor dims onto one mesh axis and force an + FSDP-style fallback instead of expert-parallel AllToAll). Mirrors dense_matmul's + axis selection. + """ + cfg = pyconfig.initialize( + [None, get_test_config_path()], + run_name=f"moe_shard_{model_name}", + enable_checkpointing=False, + model_name=model_name, + ) + rules = cfg.logical_axis_rules + batch_axis = "activation_batch_no_exp" if cfg.moe_dispatch_no_expert_sharding else "activation_batch_moe" + spec = nn_partitioning.logical_to_mesh_axes(("activation_exp", batch_axis, None, "activation_embed_moe"), rules=rules) + + def _as_set(entry): + if entry is None: + return set() + return {entry} if isinstance(entry, str) else set(entry) + + e_axes, b_axes = _as_set(spec[0]), _as_set(spec[1]) + if cfg.moe_dispatch_no_expert_sharding: + # EP enabled: the expert dim must be sharded by 'expert' and the batch dim must not steal it. + assert "expert" in e_axes, "expert dim must be sharded by the 'expert' mesh axis" + assert "expert" not in b_axes, "with moe_dispatch_no_expert_sharding the batch dim must not take 'expert'" + # else: model keeps the default activation_batch_moe; don't assert the (intentional) collision. + + if __name__ == "__main__": unittest.main()