From 4f86655d05eec1e956d73476b004331ebb652630 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Tue, 16 Jun 2026 19:42:43 +0000 Subject: [PATCH 1/5] Make MoE dispatch/MLP expert-axis batch sharding configurable The dispatch/MLP MoE activations are already expert-sharded via activation_exp. Since #4007, their batch dim also maps to activation_batch_moe, which includes 'expert'. Under single-node expert parallelism (ici_expert_parallelism=-1) this double-maps two tensor dims onto the 'expert' mesh axis, so GSPMD falls back from expert-parallel AllToAll to FSDP-style AllGather+ReduceScatter, regressing throughput for few-large-expert models (e.g. Mixtral-8x7b: ~7.4k -> ~10.9k tok/s/device at bs=11 on 8x MI355X). Add a config flag moe_dispatch_no_expert_sharding (default false) that selects a new activation_batch_no_exp rule ([data, fsdp, fsdp_transpose], no 'expert') for the training dispatch/MLP batch axis. Enable it for mixtral-8x7b. Default-false keeps every other model and all TPU/non-EP paths byte-identical; the flag only changes sharding when the 'expert' mesh axis size > 1. --- src/maxtext/configs/base.yml | 5 +++++ src/maxtext/configs/models/mixtral-8x7b.yml | 3 +++ src/maxtext/configs/types.py | 7 +++++++ src/maxtext/layers/moe.py | 9 +++++++-- 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 9031cf4298..09ae8e6d48 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -218,6 +218,9 @@ load_balance_loss_weight: 0.0 # weight for the load balance loss use_random_routing: false # whether to use random routing for debug/test purpose use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism +# If true, shard the MoE dispatch/MLP batch dim without 'expert' (activation_batch_no_exp) so the +# expert GEMM stays expert-parallel (AllToAll); false keeps 'expert' on it (activation_batch_moe). +moe_dispatch_no_expert_sharding: false use_ragged_sort: false # whether to use the Pallas ragged-sort kernels in the MoE permute path; valid both with and # without `use_ring_of_experts` (with EP > 1). When `use_ring_of_experts=True` the kernels run # inside `permute`/`unpermute`; otherwise they run inside `local_permute`/local-unpermute. @@ -526,6 +529,8 @@ logical_axis_rules: [ # ========================================== # MoE Activations ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']], + # Batch axis WITHOUT 'expert' so MoE dispatch/mlp GEMM stays expert-parallel (AllToAll), not FSDP. + ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], diff --git a/src/maxtext/configs/models/mixtral-8x7b.yml b/src/maxtext/configs/models/mixtral-8x7b.yml index 9528a667c5..211030cec8 100644 --- a/src/maxtext/configs/models/mixtral-8x7b.yml +++ b/src/maxtext/configs/models/mixtral-8x7b.yml @@ -31,3 +31,6 @@ num_experts: 8 num_experts_per_tok: 2 rope_max_timescale: 1_000_000 decoder_block: "mixtral" +# Few large experts: keep the MoE dispatch/MLP GEMM expert-parallel (AllToAll) rather than +# letting the batch dim double-map onto 'expert' and fall back to FSDP-style collectives. +moe_dispatch_no_expert_sharding: true diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 1eca954e53..d5173f3d3e 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -736,6 +736,13 @@ class MoEGeneral(BaseModel): False, description="Whether to use Ring of Experts for sparse matmul expert parallelism.", ) + moe_dispatch_no_expert_sharding: bool = Field( + False, + description=( + "If true, shard the MoE dispatch/MLP batch dim without 'expert' so the expert GEMM " + "stays expert-parallel (AllToAll); false keeps 'expert' on it (activation_batch_moe)." + ), + ) use_ragged_sort: bool = Field( False, description="Whether to use ragged kernel for sorting, improve performance when EP is enabled." ) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 020956098c..ce37a52e54 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -2176,15 +2176,20 @@ def dense_matmul( top_k_indices, weights # pylint: disable=undefined-variable,possibly-used-before-assignment ) mask_axes = ("activation_batch_moe", "activation_norm_length_moe", None, None) + # Dispatch/MLP are already expert-sharded via "activation_exp"; with no_expert_sharding the + # batch dim drops 'expert' so the GEMM stays expert-parallel (AllToAll) instead of FSDP. + moe_dispatch_batch_axis = ( + "activation_batch_no_exp" if self.config.moe_dispatch_no_expert_sharding else "activation_batch_moe" + ) dispatch_axis = ( "activation_exp", - "activation_batch_moe", + moe_dispatch_batch_axis, None, "activation_embed_moe", ) mlp_axis = ( "activation_exp", - "activation_batch_moe", + moe_dispatch_batch_axis, None, "activation_mlp", ) From 859527540de5c43a9468dd2085bdc23fe6dbe837 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Wed, 17 Jun 2026 20:11:20 +0000 Subject: [PATCH 2/5] Enable moe_dispatch_no_expert_sharding for mixtral-8x22b. Same 8-large-expert geometry as 8x7b, so it benefits from the same expert-parallel MoE dispatch/MLP sharding. --- src/maxtext/configs/models/mixtral-8x22b.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/maxtext/configs/models/mixtral-8x22b.yml b/src/maxtext/configs/models/mixtral-8x22b.yml index b6d1fc71c0..778882a5c0 100644 --- a/src/maxtext/configs/models/mixtral-8x22b.yml +++ b/src/maxtext/configs/models/mixtral-8x22b.yml @@ -31,3 +31,6 @@ num_experts: 8 num_experts_per_tok: 2 rope_max_timescale: 1_000_000 decoder_block: "mixtral" +# Few large experts: keep the MoE dispatch/MLP GEMM expert-parallel (AllToAll) rather than +# letting the batch dim double-map onto 'expert' and fall back to FSDP-style collectives. +moe_dispatch_no_expert_sharding: true From b6b511c6dfaa7d5db85f4543c32029ac1855bba6 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Wed, 17 Jun 2026 20:11:52 +0000 Subject: [PATCH 3/5] Add MoE dispatch expert-sharding regression test. Asserts that with moe_dispatch_no_expert_sharding the expert dim is sharded by 'expert' and the batch dim is not, guarding the expert-parallel dispatch/MLP sharding. --- tests/unit/moe_test.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/unit/moe_test.py b/tests/unit/moe_test.py index 18c675948d..15e031fbff 100644 --- a/tests/unit/moe_test.py +++ b/tests/unit/moe_test.py @@ -1534,5 +1534,38 @@ def test_prefused_vs_sparse_softmax(self): self.assertIsNone(bias_updates) +@pytest.mark.parametrize("model_name", ["mixtral-8x7b", "mixtral-8x22b", "deepseek3-671b"]) +def test_moe_dispatch_keeps_expert_on_expert_dim(model_name): + """Regression guard for the MoE dispatch/MLP expert-parallel sharding. + + The expert (E) dim must always be sharded by the 'expert' mesh axis. When + moe_dispatch_no_expert_sharding is set, the batch (B) dim must NOT also take + 'expert' (which would double-map two tensor dims onto one mesh axis and force an + FSDP-style fallback instead of expert-parallel AllToAll). Mirrors dense_matmul's + axis selection. + """ + cfg = pyconfig.initialize( + [None, get_test_config_path()], + run_name=f"moe_shard_{model_name}", + enable_checkpointing=False, + model_name=model_name, + ) + rules = cfg.logical_axis_rules + batch_axis = "activation_batch_no_exp" if cfg.moe_dispatch_no_expert_sharding else "activation_batch_moe" + spec = nn_partitioning.logical_to_mesh_axes(("activation_exp", batch_axis, None, "activation_embed_moe"), rules=rules) + + def _as_set(entry): + if entry is None: + return set() + return {entry} if isinstance(entry, str) else set(entry) + + e_axes, b_axes = _as_set(spec[0]), _as_set(spec[1]) + if cfg.moe_dispatch_no_expert_sharding: + # EP enabled: the expert dim must be sharded by 'expert' and the batch dim must not steal it. + assert "expert" in e_axes, "expert dim must be sharded by the 'expert' mesh axis" + assert "expert" not in b_axes, "with moe_dispatch_no_expert_sharding the batch dim must not take 'expert'" + # else: model keeps the default activation_batch_moe; don't assert the (intentional) collision. + + if __name__ == "__main__": unittest.main() From d16d561b54b1d2077c8c652e7a7b81da07e68992 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Thu, 18 Jun 2026 17:31:57 +0000 Subject: [PATCH 4/5] Address review: peel 'expert' from MoE dispatch/MLP batch dim instead of a new logical rule Replace the activation_batch_no_exp logical rule with a remove_expert_from_partition_spec util (mirrors remove_fsdp_sharding), applied at the training dispatch/MLP sites when moe_dispatch_no_expert_sharding is set. Avoids a logical name that every custom_mesh_and_rule set would have to redefine. Same result (verified on 8xMI355X): Mixtral stays expert-parallel (a2a=5, ~11k tok/s/device), DeepSeek unchanged (a2a=0, ~17.8k). --- src/maxtext/configs/base.yml | 6 ++--- src/maxtext/layers/moe.py | 44 +++++++++++++++++++---------------- src/maxtext/utils/sharding.py | 26 +++++++++++++++++++++ tests/unit/moe_test.py | 12 +++++++--- 4 files changed, 61 insertions(+), 27 deletions(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 09ae8e6d48..259b095a09 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -218,8 +218,8 @@ load_balance_loss_weight: 0.0 # weight for the load balance loss use_random_routing: false # whether to use random routing for debug/test purpose use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism -# If true, shard the MoE dispatch/MLP batch dim without 'expert' (activation_batch_no_exp) so the -# expert GEMM stays expert-parallel (AllToAll); false keeps 'expert' on it (activation_batch_moe). +# If true, peel the 'expert' mesh axis off the MoE dispatch/MLP batch dim so the expert GEMM +# stays expert-parallel (AllToAll); false keeps 'expert' on the batch dim (activation_batch_moe). moe_dispatch_no_expert_sharding: false use_ragged_sort: false # whether to use the Pallas ragged-sort kernels in the MoE permute path; valid both with and # without `use_ring_of_experts` (with EP > 1). When `use_ring_of_experts=True` the kernels run @@ -529,8 +529,6 @@ logical_axis_rules: [ # ========================================== # MoE Activations ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - # Batch axis WITHOUT 'expert' so MoE dispatch/mlp GEMM stays expert-parallel (AllToAll), not FSDP. - ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], ['activation_length_moe', ['context']], ['activation_norm_length_moe', ['tensor_sequence', 'context']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index ce37a52e54..539e6d4e05 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -43,7 +43,7 @@ from maxtext.utils import max_utils from maxtext.utils import maxtext_utils from maxtext.utils.sharding import create_sharding, maybe_shard_with_logical, maybe_shard_with_pspec -from maxtext.utils.sharding import logical_to_mesh_axes +from maxtext.utils.sharding import logical_to_mesh_axes, remove_expert_from_partition_spec import numpy as np import qwix from qwix.contrib.sparsity import sparsity_module @@ -624,6 +624,18 @@ def _maybe_shard_with_pspec(self, inputs, pspec: jax.sharding.PartitionSpec | No extra_stack_level=1, ) + def _maybe_shard_moe_dispatch(self, inputs, logical_axis, peel_expert): + """Shard a MoE dispatch/MLP activation. When `peel_expert` is set, drop the 'expert' + mesh axis from the batch dim (index 1) so the GEMM stays expert-parallel (AllToAll) + instead of double-mapping E and B onto 'expert'. Each logical dim is resolved + independently so the shared 'expert' axis is not deduped off the expert dim before + the peel.""" + if not peel_expert: + return self._maybe_shard_with_logical(inputs, logical_axis) + spec = [None if name is None else self._logical_to_mesh_axes((name,))[0] for name in logical_axis] + pspec = remove_expert_from_partition_spec(jax.sharding.PartitionSpec(*spec), dims_to_peel=(1,)) + return self._maybe_shard_with_pspec(inputs, pspec) + def get_expert_parallelism_size(self): # When expert parallelism has more than one physical axes, take product of their shapes if isinstance(self._expert_parallelism_name, tuple): @@ -2170,26 +2182,27 @@ def dense_matmul( if self.config.capacity_factor > 0: # token dropping if needed + moe_peel_expert = False # only the training dispatch/MLP path peels 'expert' from the batch dim if self.config.model_call_mode != "inference": # TODO(b/425930949): remove this pylint by refactoring the logic here. dispatch_mask, combine_mask = self.generate_masks( top_k_indices, weights # pylint: disable=undefined-variable,possibly-used-before-assignment ) mask_axes = ("activation_batch_moe", "activation_norm_length_moe", None, None) - # Dispatch/MLP are already expert-sharded via "activation_exp"; with no_expert_sharding the - # batch dim drops 'expert' so the GEMM stays expert-parallel (AllToAll) instead of FSDP. - moe_dispatch_batch_axis = ( - "activation_batch_no_exp" if self.config.moe_dispatch_no_expert_sharding else "activation_batch_moe" - ) + # Dispatch/MLP are already expert-sharded via "activation_exp". With + # moe_dispatch_no_expert_sharding we peel 'expert' off the batch dim of these specs + # (see _maybe_shard_moe_dispatch) so the GEMM stays expert-parallel (AllToAll) instead + # of double-mapping E and B onto 'expert' (FSDP-style fallback). + moe_peel_expert = self.config.moe_dispatch_no_expert_sharding dispatch_axis = ( "activation_exp", - moe_dispatch_batch_axis, + "activation_batch_moe", None, "activation_embed_moe", ) mlp_axis = ( "activation_exp", - moe_dispatch_batch_axis, + "activation_batch_moe", None, "activation_mlp", ) @@ -2285,10 +2298,7 @@ def dense_matmul( "activation_embed_moe", ), ) - dispatch = self._maybe_shard_with_logical( - dispatch, - dispatch_axis, - ) + dispatch = self._maybe_shard_moe_dispatch(dispatch, dispatch_axis, moe_peel_expert) with jax.named_scope("wi_0"): w0_kernel_axes = ("exp", None, "mlp") w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes) @@ -2301,10 +2311,7 @@ def dense_matmul( if self.config.activations_in_float32: layer_w0 = layer_w0.astype(jnp.float32) - layer_w0 = self._maybe_shard_with_logical( - layer_w0, - mlp_axis, - ) + layer_w0 = self._maybe_shard_moe_dispatch(layer_w0, mlp_axis, moe_peel_expert) layer_w0 = adc.checkpoint_name(adc.checkpoint_name(layer_w0, "mlpwi_0"), "moe_mlpwi_0") with jax.named_scope("wi_1"): w1_kernel_axes = ("exp", None, "mlp") @@ -2317,10 +2324,7 @@ def dense_matmul( layer_w1 = layer_w1 + w1_bias if self.config.activations_in_float32: layer_w1 = layer_w1.astype(jnp.float32) - layer_w1 = self._maybe_shard_with_logical( - layer_w1, - mlp_axis, - ) + layer_w1 = self._maybe_shard_moe_dispatch(layer_w1, mlp_axis, moe_peel_expert) layer_w1 = adc.checkpoint_name(adc.checkpoint_name(layer_w1, "mlpwi_1"), "moe_mlpwi_1") layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1) with jax.named_scope("wo"): diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index 4a500e2fe1..50115cae72 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -732,6 +732,32 @@ def _remove_fsdp_from_partition_spec(named_sharding): return jax.tree.map(_remove_fsdp_from_partition_spec, sharding_tree) +def remove_expert_from_partition_spec(pspec, dims_to_peel): + """Return `pspec` with the 'expert' mesh axis removed from the given dim indices. + + Used by the MoE dispatch/MLP sharding: the expert dim is already sharded over the + 'expert' mesh axis via the `activation_exp` rule, so the batch dim must not also map + to 'expert' (that double-maps two tensor dims onto one mesh axis and makes GSPMD fall + back to FSDP-style AllGather+ReduceScatter instead of expert-parallel AllToAll). Only + the dims listed in `dims_to_peel` (the batch dim) are modified; the expert dim is left + untouched. Avoids needing a separate `activation_batch_no_exp` logical rule that every + `custom_mesh_and_rule` set would have to redefine. + """ + new_spec = list(pspec) + for i in dims_to_peel: + axis = new_spec[i] + if axis is None: + continue + if isinstance(axis, str): + new_spec[i] = None if axis == "expert" else axis + elif isinstance(axis, (list, tuple)): + filtered = tuple(a for a in axis if a != "expert") + new_spec[i] = filtered or None + else: + raise ValueError(f"Unsupported axis type: {type(axis)}") + return jax.sharding.PartitionSpec(*new_spec) + + def get_physical_spec_no_fsdp(full_logical, mesh, logical_axis_rules): """ Generates a physical sharding spec for fully replicated weights. diff --git a/tests/unit/moe_test.py b/tests/unit/moe_test.py index 15e031fbff..ab921741a8 100644 --- a/tests/unit/moe_test.py +++ b/tests/unit/moe_test.py @@ -30,6 +30,7 @@ from maxtext.layers.initializers import NdInitializer, nd_dense_init, variable_to_logically_partitioned from maxtext.layers.quantizations import Fp8Quantization from maxtext.utils import maxtext_utils +from maxtext.utils.sharding import remove_expert_from_partition_spec from tests.utils.test_helpers import get_test_config_path import pytest @@ -1551,15 +1552,20 @@ def test_moe_dispatch_keeps_expert_on_expert_dim(model_name): model_name=model_name, ) rules = cfg.logical_axis_rules - batch_axis = "activation_batch_no_exp" if cfg.moe_dispatch_no_expert_sharding else "activation_batch_moe" - spec = nn_partitioning.logical_to_mesh_axes(("activation_exp", batch_axis, None, "activation_embed_moe"), rules=rules) def _as_set(entry): if entry is None: return set() return {entry} if isinstance(entry, str) else set(entry) - e_axes, b_axes = _as_set(spec[0]), _as_set(spec[1]) + # Mirror _maybe_shard_moe_dispatch: resolve E and batch dims independently (so the shared + # 'expert' axis isn't deduped off E), then peel 'expert' from the batch dim. + e_spec = nn_partitioning.logical_to_mesh_axes(("activation_exp",), rules=rules) + b_spec = nn_partitioning.logical_to_mesh_axes(("activation_batch_moe",), rules=rules) + if cfg.moe_dispatch_no_expert_sharding: + b_spec = remove_expert_from_partition_spec(b_spec, dims_to_peel=(0,)) + + e_axes, b_axes = _as_set(e_spec[0]), _as_set(b_spec[0]) if cfg.moe_dispatch_no_expert_sharding: # EP enabled: the expert dim must be sharded by 'expert' and the batch dim must not steal it. assert "expert" in e_axes, "expert dim must be sharded by the 'expert' mesh axis" From edb3286c2a64aff59ecf08b6f28c1cd04c6b5b19 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Thu, 18 Jun 2026 19:21:15 +0000 Subject: [PATCH 5/5] Test both flag states for MoE dispatch sharding; drop deepseek3-671b --- tests/unit/moe_test.py | 46 +++++++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/tests/unit/moe_test.py b/tests/unit/moe_test.py index ab921741a8..0c1b783492 100644 --- a/tests/unit/moe_test.py +++ b/tests/unit/moe_test.py @@ -1535,22 +1535,32 @@ def test_prefused_vs_sparse_softmax(self): self.assertIsNone(bias_updates) -@pytest.mark.parametrize("model_name", ["mixtral-8x7b", "mixtral-8x22b", "deepseek3-671b"]) -def test_moe_dispatch_keeps_expert_on_expert_dim(model_name): +@pytest.mark.parametrize( + "model_name,override_flag", + [ + ("mixtral-8x7b", None), # opts in via model config (flag on) + ("mixtral-8x22b", None), # opts in via model config (flag on) + ("mixtral-8x7b", False), # force the flag off: default must keep 'expert' on the batch dim + ], +) +def test_moe_dispatch_keeps_expert_on_expert_dim(model_name, override_flag): """Regression guard for the MoE dispatch/MLP expert-parallel sharding. - The expert (E) dim must always be sharded by the 'expert' mesh axis. When - moe_dispatch_no_expert_sharding is set, the batch (B) dim must NOT also take - 'expert' (which would double-map two tensor dims onto one mesh axis and force an - FSDP-style fallback instead of expert-parallel AllToAll). Mirrors dense_matmul's - axis selection. + The expert (E) dim is always sharded by the 'expert' mesh axis (via activation_exp). + With moe_dispatch_no_expert_sharding the batch (B) dim must NOT also take 'expert' + (which would double-map two tensor dims onto one mesh axis and force an FSDP-style + fallback instead of expert-parallel AllToAll); with the flag off, the default keeps + 'expert' on the batch dim. Mirrors dense_matmul's axis selection. """ - cfg = pyconfig.initialize( - [None, get_test_config_path()], - run_name=f"moe_shard_{model_name}", - enable_checkpointing=False, - model_name=model_name, - ) + init_kwargs = { + "run_name": f"moe_shard_{model_name}_{override_flag}", + "enable_checkpointing": False, + "model_name": model_name, + } + if override_flag is not None: + init_kwargs["moe_dispatch_no_expert_sharding"] = override_flag + init_kwargs["override_model_config"] = True + cfg = pyconfig.initialize([None, get_test_config_path()], **init_kwargs) rules = cfg.logical_axis_rules def _as_set(entry): @@ -1559,18 +1569,18 @@ def _as_set(entry): return {entry} if isinstance(entry, str) else set(entry) # Mirror _maybe_shard_moe_dispatch: resolve E and batch dims independently (so the shared - # 'expert' axis isn't deduped off E), then peel 'expert' from the batch dim. + # 'expert' axis isn't deduped off E), then peel 'expert' from the batch dim when the flag is set. e_spec = nn_partitioning.logical_to_mesh_axes(("activation_exp",), rules=rules) b_spec = nn_partitioning.logical_to_mesh_axes(("activation_batch_moe",), rules=rules) if cfg.moe_dispatch_no_expert_sharding: b_spec = remove_expert_from_partition_spec(b_spec, dims_to_peel=(0,)) e_axes, b_axes = _as_set(e_spec[0]), _as_set(b_spec[0]) + assert "expert" in e_axes, "expert dim must be sharded by the 'expert' mesh axis" if cfg.moe_dispatch_no_expert_sharding: - # EP enabled: the expert dim must be sharded by 'expert' and the batch dim must not steal it. - assert "expert" in e_axes, "expert dim must be sharded by the 'expert' mesh axis" - assert "expert" not in b_axes, "with moe_dispatch_no_expert_sharding the batch dim must not take 'expert'" - # else: model keeps the default activation_batch_moe; don't assert the (intentional) collision. + assert "expert" not in b_axes, "flag on: the batch dim must not take 'expert'" + else: + assert "expert" in b_axes, "flag off (default): the batch dim keeps 'expert' (activation_batch_moe)" if __name__ == "__main__":