Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ expert_balance: False
use_random_routing: false # whether to use random routing for debug/test purpose
use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism
# If true, shard the MoE dispatch/MLP batch dim without 'expert' (activation_batch_no_exp) so the
# expert GEMM stays expert-parallel (AllToAll); false keeps 'expert' on it (activation_batch_moe).
moe_dispatch_no_expert_sharding: false
use_ragged_sort: false # whether to use the Pallas ragged-sort kernels in the MoE permute path; valid both with and
# without `use_ring_of_experts` (with EP > 1). When `use_ring_of_experts=True` the kernels run
# inside `permute`/`unpermute`; otherwise they run inside `local_permute`/local-unpermute.
Expand Down Expand Up @@ -527,6 +530,8 @@ logical_axis_rules: [
# ==========================================
# MoE Activations
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
# Batch axis WITHOUT 'expert' so MoE dispatch/mlp GEMM stays expert-parallel (AllToAll), not FSDP.
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
['activation_length_moe', ['context']],
['activation_norm_length_moe', ['tensor_sequence', 'context']],
['activation_embed_moe', ['tensor', 'tensor_transpose']],
Expand Down
3 changes: 3 additions & 0 deletions src/maxtext/configs/models/mixtral-8x22b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,6 @@ num_experts: 8
num_experts_per_tok: 2
rope_max_timescale: 1_000_000
decoder_block: "mixtral"
# Few large experts: keep the MoE dispatch/MLP GEMM expert-parallel (AllToAll) rather than
# letting the batch dim double-map onto 'expert' and fall back to FSDP-style collectives.
moe_dispatch_no_expert_sharding: true
3 changes: 3 additions & 0 deletions src/maxtext/configs/models/mixtral-8x7b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,6 @@ num_experts: 8
num_experts_per_tok: 2
rope_max_timescale: 1_000_000
decoder_block: "mixtral"
# Few large experts: keep the MoE dispatch/MLP GEMM expert-parallel (AllToAll) rather than
# letting the batch dim double-map onto 'expert' and fall back to FSDP-style collectives.
moe_dispatch_no_expert_sharding: true
7 changes: 7 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,13 @@ class MoEGeneral(BaseModel):
False,
description="Whether to use Ring of Experts for sparse matmul expert parallelism.",
)
moe_dispatch_no_expert_sharding: bool = Field(
False,
description=(
"If true, shard the MoE dispatch/MLP batch dim without 'expert' so the expert GEMM "
"stays expert-parallel (AllToAll); false keeps 'expert' on it (activation_batch_moe)."
),
)
use_ragged_sort: bool = Field(
False, description="Whether to use ragged kernel for sorting, improve performance when EP is enabled."
)
Expand Down
9 changes: 7 additions & 2 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2197,15 +2197,20 @@ def dense_matmul(
top_k_indices, weights # pylint: disable=undefined-variable,possibly-used-before-assignment
)
mask_axes = ("activation_batch_moe", "activation_norm_length_moe", None, None)
# Dispatch/MLP are already expert-sharded via "activation_exp"; with no_expert_sharding the
# batch dim drops 'expert' so the GEMM stays expert-parallel (AllToAll) instead of FSDP.
moe_dispatch_batch_axis = (
"activation_batch_no_exp" if self.config.moe_dispatch_no_expert_sharding else "activation_batch_moe"
)
dispatch_axis = (
"activation_exp",
"activation_batch_moe",
moe_dispatch_batch_axis,
None,
"activation_embed_moe",
)
mlp_axis = (
"activation_exp",
"activation_batch_moe",
moe_dispatch_batch_axis,
None,
"activation_mlp",
)
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/moe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,5 +1534,38 @@ def test_prefused_vs_sparse_softmax(self):
self.assertIsNone(bias_updates)


@pytest.mark.parametrize("model_name", ["mixtral-8x7b", "mixtral-8x22b", "deepseek3-671b"])
def test_moe_dispatch_keeps_expert_on_expert_dim(model_name):
"""Regression guard for the MoE dispatch/MLP expert-parallel sharding.

The expert (E) dim must always be sharded by the 'expert' mesh axis. When
moe_dispatch_no_expert_sharding is set, the batch (B) dim must NOT also take
'expert' (which would double-map two tensor dims onto one mesh axis and force an
FSDP-style fallback instead of expert-parallel AllToAll). Mirrors dense_matmul's
axis selection.
"""
cfg = pyconfig.initialize(
[None, get_test_config_path()],
run_name=f"moe_shard_{model_name}",
enable_checkpointing=False,
model_name=model_name,
)
rules = cfg.logical_axis_rules
batch_axis = "activation_batch_no_exp" if cfg.moe_dispatch_no_expert_sharding else "activation_batch_moe"
spec = nn_partitioning.logical_to_mesh_axes(("activation_exp", batch_axis, None, "activation_embed_moe"), rules=rules)

def _as_set(entry):
if entry is None:
return set()
return {entry} if isinstance(entry, str) else set(entry)

e_axes, b_axes = _as_set(spec[0]), _as_set(spec[1])
if cfg.moe_dispatch_no_expert_sharding:
# EP enabled: the expert dim must be sharded by 'expert' and the batch dim must not steal it.
assert "expert" in e_axes, "expert dim must be sharded by the 'expert' mesh axis"
assert "expert" not in b_axes, "with moe_dispatch_no_expert_sharding the batch dim must not take 'expert'"
# else: model keeps the default activation_batch_moe; don't assert the (intentional) collision.


if __name__ == "__main__":
unittest.main()
Loading