Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ load_balance_loss_weight: 0.0 # weight for the load balance loss
use_random_routing: false # whether to use random routing for debug/test purpose
use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism
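# If true, peel the 'expert' mesh axis off the batch dim of the MoE dispatch/MLP activations so the
# expert GEMM stays expert-parallel (AllToAll); if false, 'expert' stays on the batch dim
# (activation_batch_moe). Currently only consulted by the dense-matmul token-dropping path
# (capacity_factor > 0) when model_call_mode is not "inference".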
moe_dispatch_no_expert_sharding: false
use_ragged_sort: false # whether to use the Pallas ragged-sort kernels in the MoE permute path; valid both with and
# without `use_ring_of_experts` (with EP > 1). When `use_ring_of_experts=True` the kernels run
# inside `permute`/`unpermute`; otherwise they run inside `local_permute`/local-unpermute.
Expand Down
3 changes: 3 additions & 0 deletions src/maxtext/configs/models/mixtral-8x22b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,6 @@ num_experts: 8
num_experts_per_tok: 2
rope_max_timescale: 1_000_000
decoder_block: "mixtral"
# Few large experts: keep the MoE dispatch/MLP GEMM expert-parallel (AllToAll) rather than
# letting the batch dim double-map onto 'expert' and fall back to FSDP-style collectives.
moe_dispatch_no_expert_sharding: true
3 changes: 3 additions & 0 deletions src/maxtext/configs/models/mixtral-8x7b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,6 @@ num_experts: 8
num_experts_per_tok: 2
rope_max_timescale: 1_000_000
decoder_block: "mixtral"
# Few large experts: keep the MoE dispatch/MLP GEMM expert-parallel (AllToAll) rather than
# letting the batch dim double-map onto 'expert' and fall back to FSDP-style collectives.
moe_dispatch_no_expert_sharding: true
Comment thread
gulsumgudukbay marked this conversation as resolved.
7 changes: 7 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,13 @@ class MoEGeneral(BaseModel):
False,
description="Whether to use Ring of Experts for sparse matmul expert parallelism.",
)
moe_dispatch_no_expert_sharding: bool = Field(
False,
description=(
"If true, shard the MoE dispatch/MLP batch dim without 'expert' so the expert GEMM "
"stays expert-parallel (AllToAll); false keeps 'expert' on it (activation_batch_moe)."
),
)
use_ragged_sort: bool = Field(
False, description="Whether to use ragged kernel for sorting, improve performance when EP is enabled."
)
Expand Down
35 changes: 22 additions & 13 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from maxtext.utils import max_utils
from maxtext.utils import maxtext_utils
from maxtext.utils.sharding import create_sharding, maybe_shard_with_logical, maybe_shard_with_pspec
from maxtext.utils.sharding import logical_to_mesh_axes
from maxtext.utils.sharding import logical_to_mesh_axes, remove_expert_from_partition_spec
import numpy as np
import qwix
from qwix.contrib.sparsity import sparsity_module
Expand Down Expand Up @@ -624,6 +624,18 @@ def _maybe_shard_with_pspec(self, inputs, pspec: jax.sharding.PartitionSpec | No
extra_stack_level=1,
)

def _maybe_shard_moe_dispatch(self, inputs, logical_axis, peel_expert):
    """Apply the sharding constraint for a MoE dispatch/MLP activation.

    When `peel_expert` is false this is a plain pass-through to
    `_maybe_shard_with_logical`.

    When `peel_expert` is true, each entry of `logical_axis` is translated to
    mesh axes with its own single-name lookup (instead of one joint lookup over
    the whole tuple) so that the shared 'expert' mesh axis is not deduplicated
    off the expert dim. The 'expert' mesh axis is then removed from dim index 1
    (the batch dim) and the result is applied with `_maybe_shard_with_pspec`.
    The goal is to keep the GEMM expert-parallel (AllToAll) rather than mapping
    both E and B onto 'expert'.

    Args:
      inputs: the activation to constrain.
      logical_axis: sequence of logical axis names (or None), one per dim.
      peel_expert: whether to strip 'expert' from the batch dim.

    Returns:
      Whatever the underlying `_maybe_shard_with_*` helper returns.
    """
    if peel_expert:
        # Resolve dim by dim; None entries stay unsharded.
        # NOTE(review): per-dim resolution also bypasses cross-dim dedup for mesh
        # axes other than 'expert' — confirm no rule set maps two of these dims
        # onto the same mesh axis.
        per_dim_axes = []
        for logical_name in logical_axis:
            if logical_name is None:
                per_dim_axes.append(None)
            else:
                per_dim_axes.append(self._logical_to_mesh_axes((logical_name,))[0])
        unpeeled = jax.sharding.PartitionSpec(*per_dim_axes)
        peeled = remove_expert_from_partition_spec(unpeeled, dims_to_peel=(1,))
        return self._maybe_shard_with_pspec(inputs, peeled)

    return self._maybe_shard_with_logical(inputs, logical_axis)

def get_expert_parallelism_size(self):
# When expert parallelism has more than one physical axes, take product of their shapes
if isinstance(self._expert_parallelism_name, tuple):
Expand Down Expand Up @@ -2170,12 +2182,18 @@ def dense_matmul(

if self.config.capacity_factor > 0:
# token dropping if needed
moe_peel_expert = False # only the training dispatch/MLP path peels 'expert' from the batch dim
if self.config.model_call_mode != "inference":
# TODO(b/425930949): remove this pylint by refactoring the logic here.
dispatch_mask, combine_mask = self.generate_masks(
top_k_indices, weights # pylint: disable=undefined-variable,possibly-used-before-assignment
)
mask_axes = ("activation_batch_moe", "activation_norm_length_moe", None, None)
# Dispatch/MLP are already expert-sharded via "activation_exp". With
# moe_dispatch_no_expert_sharding we peel 'expert' off the batch dim of these specs
# (see _maybe_shard_moe_dispatch) so the GEMM stays expert-parallel (AllToAll) instead
# of double-mapping E and B onto 'expert' (FSDP-style fallback).
moe_peel_expert = self.config.moe_dispatch_no_expert_sharding
dispatch_axis = (
"activation_exp",
"activation_batch_moe",
Expand Down Expand Up @@ -2280,10 +2298,7 @@ def dense_matmul(
"activation_embed_moe",
),
)
dispatch = self._maybe_shard_with_logical(
dispatch,
dispatch_axis,
)
dispatch = self._maybe_shard_moe_dispatch(dispatch, dispatch_axis, moe_peel_expert)
with jax.named_scope("wi_0"):
w0_kernel_axes = ("exp", None, "mlp")
w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes)
Expand All @@ -2296,10 +2311,7 @@ def dense_matmul(

if self.config.activations_in_float32:
layer_w0 = layer_w0.astype(jnp.float32)
layer_w0 = self._maybe_shard_with_logical(
layer_w0,
mlp_axis,
)
layer_w0 = self._maybe_shard_moe_dispatch(layer_w0, mlp_axis, moe_peel_expert)
layer_w0 = adc.checkpoint_name(adc.checkpoint_name(layer_w0, "mlpwi_0"), "moe_mlpwi_0")
with jax.named_scope("wi_1"):
w1_kernel_axes = ("exp", None, "mlp")
Expand All @@ -2312,10 +2324,7 @@ def dense_matmul(
layer_w1 = layer_w1 + w1_bias
if self.config.activations_in_float32:
layer_w1 = layer_w1.astype(jnp.float32)
layer_w1 = self._maybe_shard_with_logical(
layer_w1,
mlp_axis,
)
layer_w1 = self._maybe_shard_moe_dispatch(layer_w1, mlp_axis, moe_peel_expert)
layer_w1 = adc.checkpoint_name(adc.checkpoint_name(layer_w1, "mlpwi_1"), "moe_mlpwi_1")
layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)
with jax.named_scope("wo"):
Expand Down
26 changes: 26 additions & 0 deletions src/maxtext/utils/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,32 @@ def _remove_fsdp_from_partition_spec(named_sharding):
return jax.tree.map(_remove_fsdp_from_partition_spec, sharding_tree)


def remove_expert_from_partition_spec(
    pspec: jax.sharding.PartitionSpec, dims_to_peel: tuple[int, ...]
) -> jax.sharding.PartitionSpec:
    """Return `pspec` with the 'expert' mesh axis removed from the given dim indices.

    Used by the MoE dispatch/MLP sharding: the expert dim is already sharded over the
    'expert' mesh axis via the `activation_exp` rule, so the batch dim must not also map
    to 'expert' (that double-maps two tensor dims onto one mesh axis and makes GSPMD fall
    back to FSDP-style AllGather+ReduceScatter instead of expert-parallel AllToAll).
    Avoids needing a separate `activation_batch_no_exp` logical rule that every
    `custom_mesh_and_rule` set would have to redefine.

    Args:
      pspec: the partition spec to copy and edit; it is not modified.
      dims_to_peel: indices of the dims to strip 'expert' from (the batch dim for
        the MoE caller). Dims not listed are left untouched. An index at or past
        the end of `pspec` is ignored, since a dim the spec does not mention
        carries no mesh axis to remove.

    Returns:
      A new PartitionSpec of the same length as `pspec`. A peeled entry that
      becomes empty is replaced by None.

    Raises:
      ValueError: if a peeled entry is neither None, a string, nor a list/tuple.
    """
    entries = list(pspec)
    for dim in dims_to_peel:
        if dim >= len(entries):
            # The spec is shorter than this dim index: nothing is sharded there,
            # so there is no 'expert' to peel (previously an IndexError).
            continue
        entry = entries[dim]
        if entry is None:
            continue
        if isinstance(entry, str):
            if entry == "expert":
                entries[dim] = None
        elif isinstance(entry, (list, tuple)):
            kept = tuple(mesh_axis for mesh_axis in entry if mesh_axis != "expert")
            # An empty tuple means the dim ends up unsharded.
            entries[dim] = kept or None
        else:
            raise ValueError(f"Unsupported axis type: {type(entry)}")
    return jax.sharding.PartitionSpec(*entries)


def get_physical_spec_no_fsdp(full_logical, mesh, logical_axis_rules):
"""
Generates a physical sharding spec for fully replicated weights.
Expand Down
49 changes: 49 additions & 0 deletions tests/unit/moe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from maxtext.layers.initializers import NdInitializer, nd_dense_init, variable_to_logically_partitioned
from maxtext.layers.quantizations import Fp8Quantization
from maxtext.utils import maxtext_utils
from maxtext.utils.sharding import remove_expert_from_partition_spec
from tests.utils.test_helpers import get_test_config_path
import pytest

Expand Down Expand Up @@ -1534,5 +1535,53 @@ def test_prefused_vs_sparse_softmax(self):
self.assertIsNone(bias_updates)


@pytest.mark.parametrize(
    "model_name,override_flag",
    [
        ("mixtral-8x7b", None),  # opts in via model config (flag on)
        ("mixtral-8x22b", None),  # opts in via model config (flag on)
        ("mixtral-8x7b", False),  # force the flag off: default must keep 'expert' on the batch dim
    ],
)
def test_moe_dispatch_keeps_expert_on_expert_dim(model_name, override_flag):
    """Regression guard for the MoE dispatch/MLP expert-parallel sharding.

    The expert (E) dim is always sharded by the 'expert' mesh axis (via activation_exp).
    With moe_dispatch_no_expert_sharding the batch (B) dim must NOT also take 'expert'
    (which would double-map two tensor dims onto one mesh axis and force an FSDP-style
    fallback instead of expert-parallel AllToAll); with the flag off, the default keeps
    'expert' on the batch dim. Mirrors dense_matmul's axis selection.

    The resolved flag value is asserted explicitly, so a model config that loses
    its opt-in fails here instead of silently passing through the flag-off branch.
    """
    init_kwargs = {
        "run_name": f"moe_shard_{model_name}_{override_flag}",
        "enable_checkpointing": False,
        "model_name": model_name,
    }
    if override_flag is not None:
        init_kwargs["moe_dispatch_no_expert_sharding"] = override_flag
        init_kwargs["override_model_config"] = True
    cfg = pyconfig.initialize([None, get_test_config_path()], **init_kwargs)
    rules = cfg.logical_axis_rules

    # Every parametrized model opts in through its model config, so with no override
    # the flag must resolve to True; with an override it must equal the override.
    expected_flag = True if override_flag is None else override_flag
    assert (
        cfg.moe_dispatch_no_expert_sharding == expected_flag
    ), f"{model_name}: expected moe_dispatch_no_expert_sharding={expected_flag}"

    def _as_set(entry):
        """Normalize one PartitionSpec entry (None, str, or sequence) to a set of mesh axes."""
        if entry is None:
            return set()
        return {entry} if isinstance(entry, str) else set(entry)

    # Mirror _maybe_shard_moe_dispatch: resolve E and batch dims independently (so the shared
    # 'expert' axis isn't deduped off E), then peel 'expert' from the batch dim when the flag is set.
    e_spec = nn_partitioning.logical_to_mesh_axes(("activation_exp",), rules=rules)
    b_spec = nn_partitioning.logical_to_mesh_axes(("activation_batch_moe",), rules=rules)
    if cfg.moe_dispatch_no_expert_sharding:
        b_spec = remove_expert_from_partition_spec(b_spec, dims_to_peel=(0,))

    e_axes, b_axes = _as_set(e_spec[0]), _as_set(b_spec[0])
    assert "expert" in e_axes, "expert dim must be sharded by the 'expert' mesh axis"
    if cfg.moe_dispatch_no_expert_sharding:
        assert "expert" not in b_axes, "flag on: the batch dim must not take 'expert'"
    else:
        assert "expert" in b_axes, "flag off (default): the batch dim keeps 'expert' (activation_batch_moe)"


if __name__ == "__main__":
unittest.main()
Loading