diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 9031cf4298..259b095a09 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -218,6 +218,9 @@ load_balance_loss_weight: 0.0 # weight for the load balance loss use_random_routing: false # whether to use random routing for debug/test purpose use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism +# If true, peel the 'expert' mesh axis off the MoE dispatch/MLP batch dim so the expert GEMM +# stays expert-parallel (AllToAll); false keeps 'expert' on the batch dim (activation_batch_moe). +moe_dispatch_no_expert_sharding: false use_ragged_sort: false # whether to use the Pallas ragged-sort kernels in the MoE permute path; valid both with and # without `use_ring_of_experts` (with EP > 1). When `use_ring_of_experts=True` the kernels run # inside `permute`/`unpermute`; otherwise they run inside `local_permute`/local-unpermute. diff --git a/src/maxtext/configs/models/mixtral-8x22b.yml b/src/maxtext/configs/models/mixtral-8x22b.yml index b6d1fc71c0..778882a5c0 100644 --- a/src/maxtext/configs/models/mixtral-8x22b.yml +++ b/src/maxtext/configs/models/mixtral-8x22b.yml @@ -31,3 +31,6 @@ num_experts: 8 num_experts_per_tok: 2 rope_max_timescale: 1_000_000 decoder_block: "mixtral" +# Few large experts: keep the MoE dispatch/MLP GEMM expert-parallel (AllToAll) rather than +# letting the batch dim double-map onto 'expert' and fall back to FSDP-style collectives. 
+moe_dispatch_no_expert_sharding: true diff --git a/src/maxtext/configs/models/mixtral-8x7b.yml b/src/maxtext/configs/models/mixtral-8x7b.yml index 9528a667c5..211030cec8 100644 --- a/src/maxtext/configs/models/mixtral-8x7b.yml +++ b/src/maxtext/configs/models/mixtral-8x7b.yml @@ -31,3 +31,6 @@ num_experts: 8 num_experts_per_tok: 2 rope_max_timescale: 1_000_000 decoder_block: "mixtral" +# Few large experts: keep the MoE dispatch/MLP GEMM expert-parallel (AllToAll) rather than +# letting the batch dim double-map onto 'expert' and fall back to FSDP-style collectives. +moe_dispatch_no_expert_sharding: true diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 1eca954e53..d5173f3d3e 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -736,6 +736,13 @@ class MoEGeneral(BaseModel): False, description="Whether to use Ring of Experts for sparse matmul expert parallelism.", ) + moe_dispatch_no_expert_sharding: bool = Field( + False, + description=( + "If true, shard the MoE dispatch/MLP batch dim without 'expert' so the expert GEMM " + "stays expert-parallel (AllToAll); false keeps 'expert' on it (activation_batch_moe)." + ), + ) use_ragged_sort: bool = Field( False, description="Whether to use ragged kernel for sorting, improve performance when EP is enabled." 
def _maybe_shard_moe_dispatch(self, inputs, logical_axis, peel_expert):
  """Apply a sharding constraint to a MoE dispatch/MLP activation.

  With `peel_expert` unset this simply defers to `_maybe_shard_with_logical`.
  With it set, every logical dim in `logical_axis` is resolved to mesh axes on
  its own (so the shared 'expert' axis is not deduped off the expert dim before
  the peel), the 'expert' mesh axis is removed from the batch dim (index 1), and
  the resulting spec is applied via `_maybe_shard_with_pspec`. This keeps the
  GEMM expert-parallel (AllToAll) instead of double-mapping E and B onto
  'expert'.

  NOTE(review): because dims are resolved one at a time, any *other* mesh axis
  shared by two logical dims would presumably also appear twice in the spec --
  confirm the active logical axis rules never produce that.

  Args:
    inputs: The activation to constrain.
    logical_axis: Sequence of logical axis names (or None), one per dim.
    peel_expert: Whether to drop 'expert' from the batch dim's mesh axes.

  Returns:
    Whatever the underlying `_maybe_shard_with_*` helper returns for `inputs`.
  """
  if peel_expert:
    per_dim_axes = []
    for logical_name in logical_axis:
      if logical_name is None:
        per_dim_axes.append(None)
      else:
        # Resolve a single-name tuple so no cross-dim dedup can take place.
        per_dim_axes.append(self._logical_to_mesh_axes((logical_name,))[0])
    resolved = jax.sharding.PartitionSpec(*per_dim_axes)
    # Index 1 is the batch dim of the dispatch/MLP activations.
    peeled = remove_expert_from_partition_spec(resolved, dims_to_peel=(1,))
    return self._maybe_shard_with_pspec(inputs, peeled)
  return self._maybe_shard_with_logical(inputs, logical_axis)
dispatch_mask, combine_mask = self.generate_masks( top_k_indices, weights # pylint: disable=undefined-variable,possibly-used-before-assignment ) mask_axes = ("activation_batch_moe", "activation_norm_length_moe", None, None) + # Dispatch/MLP are already expert-sharded via "activation_exp". With + # moe_dispatch_no_expert_sharding we peel 'expert' off the batch dim of these specs + # (see _maybe_shard_moe_dispatch) so the GEMM stays expert-parallel (AllToAll) instead + # of double-mapping E and B onto 'expert' (FSDP-style fallback). + moe_peel_expert = self.config.moe_dispatch_no_expert_sharding dispatch_axis = ( "activation_exp", "activation_batch_moe", @@ -2280,10 +2298,7 @@ def dense_matmul( "activation_embed_moe", ), ) - dispatch = self._maybe_shard_with_logical( - dispatch, - dispatch_axis, - ) + dispatch = self._maybe_shard_moe_dispatch(dispatch, dispatch_axis, moe_peel_expert) with jax.named_scope("wi_0"): w0_kernel_axes = ("exp", None, "mlp") w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes) @@ -2296,10 +2311,7 @@ def dense_matmul( if self.config.activations_in_float32: layer_w0 = layer_w0.astype(jnp.float32) - layer_w0 = self._maybe_shard_with_logical( - layer_w0, - mlp_axis, - ) + layer_w0 = self._maybe_shard_moe_dispatch(layer_w0, mlp_axis, moe_peel_expert) layer_w0 = adc.checkpoint_name(adc.checkpoint_name(layer_w0, "mlpwi_0"), "moe_mlpwi_0") with jax.named_scope("wi_1"): w1_kernel_axes = ("exp", None, "mlp") @@ -2312,10 +2324,7 @@ def dense_matmul( layer_w1 = layer_w1 + w1_bias if self.config.activations_in_float32: layer_w1 = layer_w1.astype(jnp.float32) - layer_w1 = self._maybe_shard_with_logical( - layer_w1, - mlp_axis, - ) + layer_w1 = self._maybe_shard_moe_dispatch(layer_w1, mlp_axis, moe_peel_expert) layer_w1 = adc.checkpoint_name(adc.checkpoint_name(layer_w1, "mlpwi_1"), "moe_mlpwi_1") layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1) with jax.named_scope("wo"): diff --git 
def remove_expert_from_partition_spec(pspec, dims_to_peel):
  """Return `pspec` with the 'expert' mesh axis removed from the given dim indices.

  Used by the MoE dispatch/MLP sharding: the expert dim is already sharded over the
  'expert' mesh axis via the `activation_exp` rule, so the batch dim must not also map
  to 'expert' (that double-maps two tensor dims onto one mesh axis and makes GSPMD fall
  back to FSDP-style AllGather+ReduceScatter instead of expert-parallel AllToAll). Only
  the dims listed in `dims_to_peel` (the batch dim) are modified; the expert dim is left
  untouched. Avoids needing a separate `activation_batch_no_exp` logical rule that every
  `custom_mesh_and_rule` set would have to redefine.

  Args:
    pspec: A `jax.sharding.PartitionSpec` (or any iterable of per-dim entries). Each
      entry is None, a mesh axis name, or a list/tuple of mesh axis names.
    dims_to_peel: Iterable of dim indices from which 'expert' is removed. An index at
      or past the end of `pspec` is skipped: a spec shorter than the tensor rank
      leaves those trailing dims unsharded, so there is nothing to peel.

  Returns:
    A new `jax.sharding.PartitionSpec`. A peeled dim that ends up with no mesh axes
    becomes None.

  Raises:
    ValueError: If a peeled entry is not None, a str, a list or a tuple.
  """
  new_spec = list(pspec)
  num_dims = len(new_spec)
  for i in dims_to_peel:
    if i >= num_dims:
      # Trailing dims omitted from the spec are already unsharded.
      continue
    axis = new_spec[i]
    if axis is None:
      continue
    if isinstance(axis, str):
      new_spec[i] = None if axis == "expert" else axis
    elif isinstance(axis, (list, tuple)):
      remaining = tuple(name for name in axis if name != "expert")
      new_spec[i] = remaining or None
    else:
      raise ValueError(f"Unsupported axis type: {type(axis)}")
  return jax.sharding.PartitionSpec(*new_spec)
@pytest.mark.parametrize(
    "model_name,override_flag",
    [
        ("mixtral-8x7b", None),  # model config opts in, so the flag is on
        ("mixtral-8x22b", None),  # model config opts in, so the flag is on
        ("mixtral-8x7b", False),  # flag forced off: 'expert' must stay on the batch dim
    ],
)
def test_moe_dispatch_keeps_expert_on_expert_dim(model_name, override_flag):
  """Guards the expert-parallel sharding of the MoE dispatch/MLP activations.

  The expert (E) dim must always carry the 'expert' mesh axis (via activation_exp).
  When moe_dispatch_no_expert_sharding is on, the batch (B) dim must not carry it as
  well, since that would double-map two tensor dims onto one mesh axis and force an
  FSDP-style fallback instead of expert-parallel AllToAll. When the flag is off, the
  batch dim keeps 'expert'. The axis selection here mirrors dense_matmul's.
  """
  overrides = {}
  if override_flag is not None:
    overrides = {
        "moe_dispatch_no_expert_sharding": override_flag,
        "override_model_config": True,
    }
  cfg = pyconfig.initialize(
      [None, get_test_config_path()],
      run_name=f"moe_shard_{model_name}_{override_flag}",
      enable_checkpointing=False,
      model_name=model_name,
      **overrides,
  )
  axis_rules = cfg.logical_axis_rules

  def mesh_axes_of(entry):
    # Normalise one PartitionSpec entry (None / str / sequence) to a set of axis names.
    if isinstance(entry, str):
      return {entry}
    return set() if entry is None else set(entry)

  # Same recipe as _maybe_shard_moe_dispatch: E and B are resolved separately (so the
  # shared 'expert' axis is not deduped off E), then 'expert' is peeled from B if enabled.
  expert_dim_spec = nn_partitioning.logical_to_mesh_axes(("activation_exp",), rules=axis_rules)
  batch_dim_spec = nn_partitioning.logical_to_mesh_axes(("activation_batch_moe",), rules=axis_rules)
  peel_enabled = cfg.moe_dispatch_no_expert_sharding
  if peel_enabled:
    batch_dim_spec = remove_expert_from_partition_spec(batch_dim_spec, dims_to_peel=(0,))

  expert_dim_axes = mesh_axes_of(expert_dim_spec[0])
  batch_dim_axes = mesh_axes_of(batch_dim_spec[0])
  assert "expert" in expert_dim_axes, "expert dim must be sharded by the 'expert' mesh axis"
  if peel_enabled:
    assert "expert" not in batch_dim_axes, "flag on: the batch dim must not take 'expert'"
  else:
    assert "expert" in batch_dim_axes, "flag off (default): the batch dim keeps 'expert' (activation_batch_moe)"