2 changes: 2 additions & 0 deletions docs/reference/core_concepts/moe_configuration.md
@@ -97,6 +97,8 @@ Dropping:

`mlp_bias`: If enabled, add learnable bias terms for MLP matmul. Originally implemented to support the GPT-OSS model architecture.

`fused_mlp`: If enabled alongside `sparse_matmul=True`, fuses the two FFN1 grouped GEMMs (wi\_0 and wi\_1) into a single grouped GEMM call. Expert weights are stored in a concatenated `(num_experts, embed_dim, 2 * mlp_dim)` shape, so input activations are loaded from HBM once per forward pass instead of twice. This is analogous to `fused_mlp` for dense models and is backend-agnostic (works with Megablox, JAX Ragged Dot, and Tokamax).
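
As an illustration only, here is a minimal sketch of the fused FFN1 pattern, not the MaxText kernel path itself: it assumes tokens have already been sorted by expert, uses `jax.lax.ragged_dot` as a stand-in for whichever grouped-GEMM backend is configured, and uses made-up shapes.

```python
import jax
import jax.numpy as jnp

num_experts, embed_dim, mlp_dim, num_tokens = 4, 8, 16, 32
key = jax.random.PRNGKey(0)

# Activations already sorted by expert; group_sizes[e] = tokens routed to expert e.
x = jax.random.normal(key, (num_tokens, embed_dim))
group_sizes = jnp.full((num_experts,), num_tokens // num_experts, dtype=jnp.int32)

# Expert weights stored pre-concatenated as (num_experts, embed_dim, 2 * mlp_dim),
# i.e. wi_0 and wi_1 side by side along the last axis.
w_fused = jax.random.normal(key, (num_experts, embed_dim, 2 * mlp_dim))

# One grouped GEMM instead of two, so x is read from HBM only once.
out = jax.lax.ragged_dot(x, w_fused, group_sizes)  # (num_tokens, 2 * mlp_dim)
h0, h1 = out[:, :mlp_dim], out[:, mlp_dim:]        # recover the wi_0 / wi_1 halves
hidden = jax.nn.silu(h0) * h1                      # gated activation (e.g. SwiGLU)
```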

`use_batch_split_schedule` (experimental): If enabled, splits the batch into micro-batches so that communication can be overlapped with (hidden behind) computation, which can yield performance benefits.
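
The sketch below illustrates only the micro-batching idea; it is not the actual schedule in MaxText, and `moe_block` is a hypothetical per-micro-batch function.

```python
import jax.numpy as jnp

def batch_split_forward(x, moe_block, num_microbatches=2):
  """Run an MoE block over batch micro-slices so that one slice's expert
  communication can overlap with another slice's compute."""
  micro_batches = jnp.split(x, num_microbatches, axis=0)  # split along the batch axis
  outputs = [moe_block(mb) for mb in micro_batches]
  return jnp.concatenate(outputs, axis=0)
```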

## 2. Sharding
69 changes: 44 additions & 25 deletions src/maxtext/layers/moe.py
@@ -265,7 +265,6 @@ def quant_dot_general(self) -> nnx_wrappers.ToNNX | None:
return getattr(self, self._quant_dot_general_name)

def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.Array, Optional[jax.Array]]:

inputs = jnp.asarray(inputs, self.dtype)
norm_axis = linears.normalize_axes(self.axis, inputs.ndim)

@@ -410,7 +409,7 @@ def __init__(
self.wi_0 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
self.wi_1 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
self.wo = jnp.zeros((num_experts, intermediate_dim, self.moe_expert_input_dim))
elif self.config.prefuse_moe_weights and self.config.attention == "vllm_rpa":
elif (self.config.prefuse_moe_weights and self.config.attention == "vllm_rpa") or self.config.fused_mlp:
self.wi = nnx.Param(
self.kernel_init(
self.rngs.params(),
@@ -1318,29 +1317,44 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
self.config.wo_tile_drhs_embed_dim, # Called n in megablox, and indeed is the RHS batch dim
)

layer_w0 = gmm_fn(
x,
w0,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
)
if self.get_tensor_transpose_parallelism_size() > 1:
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
if self.config.mlp_bias:
layer_w0 = layer_w0 + w0_bias
layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")

layer_w1 = gmm_fn(
x,
w1,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
)
if self.get_tensor_transpose_parallelism_size() > 1:
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
if self.config.mlp_bias:
layer_w1 = layer_w1 + w1_bias
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
if self.config.fused_mlp:
# Weights are stored as (G,K,2N); w0/w1 are adjacent slices so XLA elides this concat.
w_fused = jnp.concatenate([w0, w1], axis=-1)
Collaborator @gobbleturk (Apr 24, 2026): We may want an implementation that initializes the weight in the concatenated shape (G, K, 2N), which guarantees there are no expensive weight HBM movements for the concat. This is the format assumed by vllm-tpu for concatenated weights, which gets higher performance by avoiding the HBM concat; it may be a less important optimization for training. However, this would also be more consistent with the dense implementation, `out_features_shape=(len(self.activations), self.intermediate_dim),` - which actually uses a (K, 2, N) format...

Collaborator: That's a good point. The checkpoint conversion will need a small corresponding update in the future for fine-tuning workloads. I will help create a bug to track this feature and put it in the backlog.

Contributor (Author): Done

out = gmm_fn(x, w_fused, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes)
n = w0.shape[-1]
layer_w0, layer_w1 = out[:, :n], out[:, n:]
if self.get_tensor_transpose_parallelism_size() > 1:
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
if self.config.mlp_bias:
layer_w0 = layer_w0 + w0_bias
layer_w1 = layer_w1 + w1_bias
layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
else:
layer_w0 = gmm_fn(
x,
w0,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
)
if self.get_tensor_transpose_parallelism_size() > 1:
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
if self.config.mlp_bias:
layer_w0 = layer_w0 + w0_bias
layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")

layer_w1 = gmm_fn(
x,
w1,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
)
if self.get_tensor_transpose_parallelism_size() > 1:
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
if self.config.mlp_bias:
layer_w1 = layer_w1 + w1_bias
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1)

intermediate_output = gmm_fn(
@@ -2144,6 +2158,11 @@ def __call__(
w1_kernel = None
if cfg.prefuse_moe_weights and cfg.attention == "vllm_rpa":
fused_kernel = jnp.asarray(self.wi[...], self.dtype)
elif cfg.fused_mlp:
wi = jnp.asarray(self.wi[...], self.dtype)
n = wi.shape[-1] // 2
w0_kernel = wi[..., :n]
w1_kernel = wi[..., n:]
else:
w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
144 changes: 102 additions & 42 deletions tests/unit/moe_test.py
@@ -68,42 +68,38 @@ def setUp(self):
def test_generate_masks(self):
# expert_capacity = (tokens_per_batch / num_experts) * capacity_factor
# expert_capacity_in_batch = (4 * 2 / 8) * 2 = 2
top_k_indices = jnp.array([
    [[0, 5], [0, 4], [1, 0], [3, 5]],
    [[1, 2], [4, 1], [5, 0], [7, 1]],
    [[6, 2], [2, 3], [4, 2], [1, 2]],
    [[4, 1], [0, 7], [5, 0], [4, 7]],
])
softmax_probs = jnp.array([
    [
        [0.20, 0, 0, 0, 0, 0.80, 0, 0],
        [0.68, 0, 0, 0, 0.32, 0, 0, 0],
        [0.22, 0.78, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0.32, 0, 0.68, 0, 0],
    ],
    [
        [0, 0.26, 0.74, 0, 0, 0, 0, 0],
        [0, 0.79, 0, 0, 0.21, 0, 0, 0],
        [0.89, 0, 0, 0, 0, 0.11, 0, 0],
        [0, 0.11, 0, 0, 0, 0, 0, 0.89],
    ],
    [
        [0, 0, 0.26, 0, 0, 0, 0.74, 0],
        [0, 0, 0.88, 0.12, 0, 0, 0, 0],
        [0, 0, 0.17, 0, 0.83, 0, 0, 0],
        [0, 0.35, 0.65, 0, 0, 0, 0, 0],
    ],
    [
        [0, 0.47, 0, 0, 0.53, 0, 0, 0],
        [0.36, 0, 0, 0, 0, 0, 0, 0.64],
        [0.15, 0, 0, 0, 0, 0.85, 0, 0],
        [0, 0, 0, 0, 0.18, 0, 0, 0.82],
    ],
])

# As expert_capacity_in_batch=2, so updated softmax_probs become (4 tokens were dropped):
# softmax_probs = jnp.array([[[0.20, 0, 0, 0, 0, 0.80, 0, 0],
@@ -238,14 +234,10 @@ def setUp(self):

def test_deepseek_routing(self):
# shape as [batch, sequence, num_experts] = [1,2,16]
gate_logits = jnp.array([[
    [0.20, 0.10, 0.05, 0.10, 0.10, 0.60, 0.30, 0.10, 0.80, 0.01, 0.01, 0.01, 0.05, 0.80, 0.20, 0.10],
    [0.68, 0.20, 0.06, 0.03, 0.32, 0.10, 0.05, 0.02, 0.65, 0.20, 0.04, 0.01, 0.32, 0.10, 0.05, 0.02],
]])
pre_bias_logits = gate_logits - 0.5

# 4 groups of 1st token:
@@ -1402,5 +1394,73 @@ def test_prefused_vs_sparse_softmax(self):
self.assertIsNone(bias_updates)


@pytest.mark.tpu_only
class FusedMlpMoETest(unittest.TestCase):
"""Tests that fused_mlp=True and fused_mlp=False produce identical outputs for MoE."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._B = 1
self._S = 16

def setUp(self):
super().setUp()
self.rng = jax.random.PRNGKey(0)
extra_args = get_decoupled_parallelism_overrides()
self.ref_cfg = pyconfig.initialize(
[None, get_test_config_path()],
run_name="fused_mlp_moe_ref",
enable_checkpointing=False,
model_name="mixtral-8x7b",
dtype="bfloat16",
sparse_matmul=True,
megablox=True,
fused_mlp=False,
ici_expert_parallelism=jax.device_count(),
max_target_length=self._S,
per_device_batch_size=self._B,
**extra_args,
)
ref_devices = maxtext_utils.create_device_mesh(self.ref_cfg)
self.ref_mesh = Mesh(ref_devices, self.ref_cfg.mesh_axes)
self.ref_model = make_moe(self.ref_cfg, self.ref_mesh)

def _inputs(self):
return jax.random.normal(self.rng, (self._B, self._S, self.ref_cfg.base_emb_dim), dtype=jnp.bfloat16)

def test_fused_mlp_matches_unfused(self):
"""fused_mlp=True output matches fused_mlp=False with sparse_matmul (Megablox)."""
extra_args = get_decoupled_parallelism_overrides()
fused_cfg = pyconfig.initialize(
[None, get_test_config_path()],
run_name="fused_mlp_moe_fused",
enable_checkpointing=False,
model_name="mixtral-8x7b",
dtype="bfloat16",
sparse_matmul=True,
megablox=True,
fused_mlp=True,
ici_expert_parallelism=jax.device_count(),
max_target_length=self._S,
per_device_batch_size=self._B,
**extra_args,
)
fused_devices = maxtext_utils.create_device_mesh(fused_cfg)
fused_mesh = Mesh(fused_devices, fused_cfg.mesh_axes)
fused_model = make_moe(fused_cfg, fused_mesh)
copy_weights_prefused(self.ref_model, fused_model)

inputs = self._inputs()
ref_out, _, _ = self.ref_model(inputs)
fused_out, _, _ = fused_model(inputs)

np.testing.assert_allclose(
np.array(ref_out, dtype=np.float32),
np.array(fused_out, dtype=np.float32),
rtol=1e-2,
atol=1e-2,
)


if __name__ == "__main__":
unittest.main()