
Add fused_moe_mlp: fuse wi_0 and wi_1 into one grouped GEMM for MoE FFN1 #3736

Open
abhinavgoel95 wants to merge 1 commit into AI-Hypercomputer:main from abhinavgoel95:abgoel/fused-moe-mlp

Conversation

@abhinavgoel95
Contributor

Analogous to the existing fused_mlp flag for dense models. When enabled, concatenates the wi_0 and wi_1 expert weight matrices along the output dimension ([G,K,N] + [G,K,N] -> [G,K,2N]), issues a single grouped GEMM call, then splits the result. This halves FFN1 kernel launches and reads the input activations from HBM once instead of twice.

Works with all grouped GEMM backends (megablox, tokamax, ragged_dot). Off by default; requires sparse_matmul=True.

Description

Adds a fused_moe_mlp boolean config flag (default false) for MoE models, analogous to the existing fused_mlp flag for dense models. When enabled, the two FFN1 input projections (wi_0 gate and wi_1 up) are fused into a single grouped GEMM call instead of two separate ones.

Why this change:
In a gated MoE FFN, wi_0 and wi_1 share the same input x but are currently dispatched as two independent grouped GEMMs. This means x is loaded from HBM twice and two kernel launches are issued back-to-back for every MoE layer.

How it works:
The two expert weight matrices ([G, K, N] each) are concatenated along the output dimension to form a single fused weight [G, K, 2N]. One grouped GEMM call produces [M, 2N], which is then split into layer_w0 and layer_w1. This is the same approach fused_mlp uses for dense models (stacking wi_0/wi_1 into a single dot_general), extended to the grouped GEMM case.
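
For concreteness, here is a runnable toy sketch of the fused path using the jax.lax.ragged_dot backend mentioned above. The shapes and tensors are illustrative, not the actual moe.py code:

import jax
import jax.numpy as jnp

G, M, K, N = 4, 32, 16, 8   # experts, tokens, embed_dim, mlp_dim
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (M, K))                       # tokens sorted by expert
w0 = jax.random.normal(k2, (G, K, N))                   # wi_0 (gate)
w1 = jax.random.normal(k3, (G, K, N))                   # wi_1 (up)
group_sizes = jnp.array([8, 8, 8, 8], dtype=jnp.int32)  # tokens per expert

# Unfused: two grouped GEMMs, x read from HBM twice.
out0 = jax.lax.ragged_dot(x, w0, group_sizes)
out1 = jax.lax.ragged_dot(x, w1, group_sizes)

# Fused: concat weights to [G, K, 2N], one grouped GEMM, split the [M, 2N] result.
w_fused = jnp.concatenate([w0, w1], axis=-1)
fused = jax.lax.ragged_dot(x, w_fused, group_sizes)
layer_w0, layer_w1 = jnp.split(fused, 2, axis=-1)

assert jnp.allclose(out0, layer_w0, atol=1e-5)
assert jnp.allclose(out1, layer_w1, atol=1e-5)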

Benefits:

  • Halves FFN1 kernel launches (2 → 1) per MoE layer
  • Reads input activations x from HBM once instead of twice
  • Presents the accelerator with a wider N = 2 × mlp_dim GEMM, which typically achieves better hardware utilization than two narrower back-to-back GEMMs
  • Backend-agnostic: works with megablox, tokamax, and jax.lax.ragged_dot

Implementation:
The change is entirely in sparse_matmul in moe.py: a conditional branch on config.fused_moe_mlp wraps the existing gmm_fn calls. The unfused path is unchanged. A config-load-time validation ensures sparse_matmul=True (the dense einsum path does not use gmm_fn and is unaffected). Off by default (fused_moe_mlp: false).
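
Schematically, the branch and the accompanying config check might look like the following; gmm_fn, the config fields, and the error message are placeholders rather than the exact MaxText code:

import jax.numpy as jnp

def ffn1_projections(x, w0, w1, group_sizes, gmm_fn, fused_moe_mlp):
  """Illustrative branch; gmm_fn stands in for megablox/tokamax/ragged_dot."""
  if fused_moe_mlp:
    w_fused = jnp.concatenate([w0, w1], axis=-1)   # [G, K, N] x2 -> [G, K, 2N]
    fused_out = gmm_fn(x, w_fused, group_sizes)    # one kernel launch
    return jnp.split(fused_out, 2, axis=-1)
  # Unfused path, unchanged: two launches, x read from HBM twice.
  return gmm_fn(x, w0, group_sizes), gmm_fn(x, w1, group_sizes)

def validate(config):
  # Illustrative config-load-time validation.
  if config.fused_moe_mlp and not config.sparse_matmul:
    raise ValueError("fused_moe_mlp=True requires sparse_matmul=True")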

Tests

  • Forward-pass correctness: run with fused_moe_mlp=false and fused_moe_mlp=true on identical synthetic inputs and verify the intermediate_layer outputs match numerically.
  • Gradient correctness: verify gradients w.r.t. wi_0 and wi_1 match the unfused path; JAX AD traces cleanly through jnp.concatenate + slicing with no custom VJP needed (see the sketch after this list).
  • Validation: confirm fused_moe_mlp=true with sparse_matmul=false raises ValueError at config load.
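
The gradient check can be reproduced with a small self-contained script. The gated-activation loss below is a stand-in for the real MoE layer, not the actual test code:

import jax
import jax.numpy as jnp

G, M, K, N = 4, 32, 16, 8
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (M, K))
w0 = jax.random.normal(k2, (G, K, N))
w1 = jax.random.normal(k3, (G, K, N))
sizes = jnp.array([8, 8, 8, 8], dtype=jnp.int32)

def unfused_loss(w0, w1):
  a = jax.lax.ragged_dot(x, w0, sizes)
  b = jax.lax.ragged_dot(x, w1, sizes)
  return jnp.sum(jax.nn.silu(a) * b)        # toy gated FFN1 loss

def fused_loss(w0, w1):
  out = jax.lax.ragged_dot(x, jnp.concatenate([w0, w1], axis=-1), sizes)
  a, b = jnp.split(out, 2, axis=-1)
  return jnp.sum(jax.nn.silu(a) * b)

# AD traces through concatenate + split with no custom VJP; gradients
# w.r.t. wi_0 and wi_1 should match the unfused path.
gu = jax.grad(unfused_loss, argnums=(0, 1))(w0, w1)
gf = jax.grad(fused_loss, argnums=(0, 1))(w0, w1)
assert all(jnp.allclose(a, b, atol=1e-4) for a, b in zip(gu, gf))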

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@google-cla

google-cla Bot commented Apr 23, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@abhinavgoel95 force-pushed the abgoel/fused-moe-mlp branch 3 times, most recently from a346298 to 84b4290 on April 23, 2026 at 23:09
Comment thread src/maxtext/layers/moe.py
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
if self.config.fused_mlp:
  # Fuse wi_0 and wi_1: [G,K,N] + [G,K,N] -> [G,K,2N], one GEMM, split result.
  w_fused = jnp.concatenate([w0, w1], axis=-1)
Collaborator

@gobbleturk commented Apr 24, 2026


We may want an implementation that initializes the weight in concatenated shape (G, K, 2N), which guarantees there are no expensive weight HBM movements for the concat. This is the format assumed by vllm-tpu for concatenated weights, which gets higher performance by avoiding the HBM concat. It may be a less important optimization for training, but it would also be more consistent with the dense implementation:

out_features_shape=(len(self.activations), self.intermediate_dim),
- which actually does a (K,2,N) format...
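
For illustration, a minimal Flax sketch of this suggestion, assuming a single pre-concatenated parameter; the module and parameter names here are hypothetical:

import flax.linen as nn
import jax
import jax.numpy as jnp

class FusedFFN1(nn.Module):
  """Stores wi_0/wi_1 as one pre-concatenated (G, K, 2N) buffer."""
  num_experts: int
  embed_dim: int
  mlp_dim: int

  @nn.compact
  def __call__(self, x, group_sizes):
    # One parameter in fused layout: no runtime concat, no extra HBM copy.
    w_fused = self.param(
        "wi_fused", nn.initializers.normal(stddev=0.02),
        (self.num_experts, self.embed_dim, 2 * self.mlp_dim))
    out = jax.lax.ragged_dot(x, w_fused, group_sizes)   # [M, 2N]
    return jnp.split(out, 2, axis=-1)                   # gate, up halves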

Collaborator


That's a good point. The checkpoint conversion will need a small corresponding update for fine-tuning workloads. I will create a bug to track this feature and put it in the backlog.

Contributor Author


Done

Collaborator

@RissyRan left a comment


Thanks for the feature! Could you help add a test to verify the outputs are the same with and without this flag?


Also, could you help add this config in


When fused_mlp=True, the two FFN1 grouped GEMMs (wi_0 gate and wi_1 up
projection) are fused into a single call. Expert weights are stored in a
concatenated (num_experts, embed_dim, 2*mlp_dim) shape so input
activations are loaded from HBM once instead of twice. The concat in
sparse_matmul operates on adjacent slices of the stored buffer, which
XLA can elide. This is backend-agnostic and analogous to fused_mlp for
dense models.

Also adds a correctness test (FusedMlpMoETest) and documents the flag
in docs/reference/core_concepts/moe_configuration.md.
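
The "concat of adjacent slices" point in this commit message can be seen in a toy example; shapes are illustrative:

import jax.numpy as jnp

w = jnp.zeros((4, 16, 16))                    # stored [G, K, 2N] buffer
half = w.shape[-1] // 2
w0, w1 = w[..., :half], w[..., half:]         # adjacent slices of one buffer
w_fused = jnp.concatenate([w0, w1], axis=-1)  # reconstructs w exactly
# Because w0 and w1 are contiguous halves of the same stored array, XLA
# can elide this concatenate instead of materializing a copy in HBM.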
@codecov

codecov Bot commented Apr 28, 2026

Codecov Report

❌ Patch coverage is 67.74194% with 10 lines in your changes missing coverage. Please review.

Files with missing lines          Patch %   Lines
src/maxtext/layers/moe.py         67.74%    6 Missing and 4 partials ⚠️


@gobbleturk mentioned this pull request Apr 29, 2026
Collaborator

@RissyRan left a comment


LGTM! Thanks! I see a few unit tests are failing. Could you try rebasing to see if that fixes them? You may also need to fix formatting, which pre-commit can handle.

