Add fused_moe_mlp: fuse wi_0 and wi_1 into one grouped GEMM for MoE FFN1 #3736
abhinavgoel95 wants to merge 1 commit into AI-Hypercomputer:main from
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Force-pushed from a346298 to 84b4290
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
if self.config.fused_mlp:
  # Fuse wi_0 and wi_1: [G,K,N] + [G,K,N] -> [G,K,2N], one GEMM, split result.
  w_fused = jnp.concatenate([w0, w1], axis=-1)
We may want an implementation that initializes the weight in the concatenated shape (G, K, 2N), which guarantees there are no expensive HBM weight movements for the concat. This is the format vllm-tpu assumes for concatenated weights; it gets higher performance by avoiding the HBM concat. That may be a less important optimization for training, but it would also be more consistent with the dense implementation:
maxtext/src/maxtext/layers/linears.py
Line 413 in b117f50
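A minimal sketch of that suggestion, assuming pure JAX with illustrative shapes (this is not the MaxText initializer API; `w_fused` and the sizes below are hypothetical):

```python
import jax
import jax.numpy as jnp

G, K, N = 8, 512, 2048  # num_experts, embed_dim, mlp_dim (illustrative sizes)

# Create the FFN1 expert weight directly in the fused [G, K, 2N] layout,
# so no HBM concat is needed at matmul time.
key = jax.random.PRNGKey(0)
w_fused = jax.random.normal(key, (G, K, 2 * N), dtype=jnp.bfloat16)

# The gate (wi_0) and up (wi_1) halves are adjacent slices of the fused buffer.
w0 = w_fused[..., :N]
w1 = w_fused[..., N:]
```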
That's a good point. The checkpoint conversion will need a small update accordingly for fine-tuning workloads. I will help create a bug to track this feature and put it in the backlog.
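For illustration, converting an existing (unfused) checkpoint into the fused layout could look roughly like the sketch below; `fuse_ffn1_weights` is a hypothetical helper, not part of MaxText's conversion tooling:

```python
import numpy as np

def fuse_ffn1_weights(wi_0: np.ndarray, wi_1: np.ndarray) -> np.ndarray:
  """Stack per-expert gate and up projections into the fused layout.

  wi_0, wi_1: [num_experts, embed_dim, mlp_dim]
  returns:    [num_experts, embed_dim, 2 * mlp_dim]
  """
  assert wi_0.shape == wi_1.shape
  return np.concatenate([wi_0, wi_1], axis=-1)
```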
RissyRan
left a comment
Thanks for the feature! Could you help add a test to verify the outputs are the same with and without this flag?
maxtext/tests/unit/moe_test.py
Line 4 in b88ea63
Also, could you help add this config in
When fused_mlp=True, the two FFN1 grouped GEMMs (wi_0 gate and wi_1 up projection) are fused into a single call. Expert weights are stored in a concatenated (num_experts, embed_dim, 2*mlp_dim) shape so input activations are loaded from HBM once instead of twice. The concat in sparse_matmul operates on adjacent slices of the stored buffer, which XLA can elide. This is backend-agnostic and analogous to fused_mlp for dense models. Also adds a correctness test (FusedMlpMoETest) and documents the flag in docs/reference/core_concepts/moe_configuration.md.
Force-pushed from d8b3334 to f6c0c96
RissyRan
left a comment
LGTM! Thanks! I see a few unit tests are failing. Could you try rebasing and see if that fixes them? You may also need to fix formatting, which can be addressed with pre-commit.
Analogous to the existing fused_mlp flag for dense models. When enabled, concatenates the wi_0 and wi_1 expert weight matrices along the output dimension ([G,K,N] + [G,K,N] -> [G,K,2N]), issues a single grouped GEMM call, then splits the result. This halves FFN1 kernel launches and reads the input activations from HBM once instead of twice.
Works with all grouped GEMM backends (megablox, tokamax, ragged_dot). Off by default; requires sparse_matmul=True.
Description
Adds a `fused_moe_mlp` boolean config flag (default `false`) for MoE models, analogous to the existing `fused_mlp` flag for dense models. When enabled, the two FFN1 input projections (`wi_0` gate and `wi_1` up) are fused into a single grouped GEMM call instead of two separate ones.
Why this change:
In a gated MoE FFN, `wi_0` and `wi_1` share the same input `x` but are currently dispatched as two independent grouped GEMMs. This means `x` is loaded from HBM twice and two kernel launches are issued back-to-back for every MoE layer.
How it works:
The two expert weight matrices (`[G, K, N]` each) are concatenated along the output dimension to form a single fused weight `[G, K, 2N]`. One grouped GEMM call produces `[M, 2N]`, which is then split into `layer_w0` and `layer_w1`. This is the same approach `fused_mlp` uses for dense models (stacking `wi_0`/`wi_1` into a single `dot_general`), extended to the grouped GEMM case.
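As a rough, self-contained illustration of the fused versus unfused FFN1 path (toy shapes, random weights, and a recent JAX version providing `jax.lax.ragged_dot` are assumed; this is not the MaxText code itself):

```python
import jax
import jax.numpy as jnp

G, K, N, M = 4, 64, 128, 32  # experts, embed, mlp, tokens (toy sizes)
kx, k0, k1 = jax.random.split(jax.random.PRNGKey(0), 3)

x = jax.random.normal(kx, (M, K))                       # tokens already sorted by expert
w0 = jax.random.normal(k0, (G, K, N))                   # wi_0 (gate)
w1 = jax.random.normal(k1, (G, K, N))                   # wi_1 (up)
group_sizes = jnp.array([8, 8, 8, 8], dtype=jnp.int32)  # tokens per expert, sums to M

# Unfused: two grouped GEMMs over the same x.
y0 = jax.lax.ragged_dot(x, w0, group_sizes)
y1 = jax.lax.ragged_dot(x, w1, group_sizes)

# Fused: one [G, K, 2N] weight, one grouped GEMM, then split the result.
w_fused = jnp.concatenate([w0, w1], axis=-1)
y_fused = jax.lax.ragged_dot(x, w_fused, group_sizes)
layer_w0, layer_w1 = jnp.split(y_fused, 2, axis=-1)

assert jnp.allclose(y0, layer_w0, atol=1e-5)
assert jnp.allclose(y1, layer_w1, atol=1e-5)
```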
Benefits:
- Loads `x` from HBM once instead of twice
- Issues one wider `N = 2 × mlp_dim` GEMM, which typically achieves better hardware utilization than two narrower back-to-back GEMMs
- Works with all grouped GEMM backends, including `jax.lax.ragged_dot`
Implementation:
The change is entirely in `sparse_matmul` in `moe.py`: a conditional branch on `config.fused_moe_mlp` wraps the existing `gmm_fn` calls. The unfused path is unchanged. A config-load-time validation ensures `sparse_matmul=True` (the dense `einsum` path does not use `gmm_fn` and is unaffected). Off by default (`fused_moe_mlp: false`).
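A minimal sketch of what that config-load-time check amounts to (`validate_moe_fusion` is a hypothetical name; the real check lives in MaxText's config validation):

```python
def validate_moe_fusion(fused_moe_mlp: bool, sparse_matmul: bool) -> None:
  # The fused path wraps the grouped GEMM calls, so it only applies to the
  # sparse_matmul path; the dense einsum path never reaches gmm_fn.
  if fused_moe_mlp and not sparse_matmul:
    raise ValueError(
        "fused_moe_mlp=True requires sparse_matmul=True; "
        "the dense einsum MoE path does not use grouped GEMMs."
    )
```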
Tests
Forward-pass correctness: run with `fused_moe_mlp=false` and `fused_moe_mlp=true` on identical synthetic inputs and verify `intermediate_layer` outputs match numerically.
Gradient correctness: verify gradients w.r.t. `wi_0` and `wi_1` match the unfused path (JAX AD traces cleanly through `jnp.concatenate` + slicing with no custom VJP needed).
Validation: confirm `fused_moe_mlp=true` with `sparse_matmul=false` raises `ValueError` at config load.
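The gradient-correctness claim can be illustrated with a stripped-down, single-expert stand-in (plain matmuls instead of grouped GEMMs; function names are hypothetical):

```python
import jax
import jax.numpy as jnp

def ffn1_unfused(w0, w1, x):
  # Two independent GEMMs over the same input.
  return jnp.sum(x @ w0) + jnp.sum(x @ w1)

def ffn1_fused(w0, w1, x):
  # One GEMM on the concatenated weight, then split back into gate / up halves.
  y = x @ jnp.concatenate([w0, w1], axis=-1)
  a, b = jnp.split(y, 2, axis=-1)
  return jnp.sum(a) + jnp.sum(b)

kx, k0, k1 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (16, 32))
w0 = jax.random.normal(k0, (32, 64))
w1 = jax.random.normal(k1, (32, 64))

g_unfused = jax.grad(ffn1_unfused, argnums=(0, 1))(w0, w1, x)
g_fused = jax.grad(ffn1_fused, argnums=(0, 1))(w0, w1, x)
assert all(jnp.allclose(a, b, atol=1e-5) for a, b in zip(g_unfused, g_fused))
```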
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.