Add Primus Turbo grouped GEMM support for MoE sparse matmul #83

Merged
qianghan-amd merged 2 commits into release/v26.4 from qianghan/add-turbo-grouped-gemm on Apr 1, 2026

@qianghan-amd

Integrate primus_turbo's grouped_gemm kernel as an alternative backend for MoE sparse matrix multiplication. Controlled via the new use_turbo_grouped_gemm config flag.
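
For readers new to the operation, here is a minimal NumPy sketch of what a grouped GEMM computes in the MoE sparse matmul path. It follows the documented semantics of jax.lax.ragged_dot: token rows are assumed to be already sorted by expert, and group_sizes[e] is the number of consecutive rows routed to expert e. The name grouped_gemm_reference is hypothetical, and this is a readable reference, not the primus_turbo kernel.

```python
import numpy as np


def grouped_gemm_reference(lhs, rhs, group_sizes):
    """Reference grouped GEMM: lhs (m, k), rhs (g, k, n), group_sizes (g,) -> (m, n).

    Rows of lhs are assumed to be sorted by expert (assumption of this sketch).
    """
    m = lhs.shape[0]
    n = rhs.shape[2]
    out = np.zeros((m, n), dtype=lhs.dtype)
    start = 0
    for expert, size in enumerate(group_sizes):
        end = start + int(size)
        # Each contiguous block of rows is multiplied by its own expert's weights.
        out[start:end] = lhs[start:end] @ rhs[expert]
        start = end
    return out


# Two experts: expert 0 is the identity, expert 1 doubles its input.
lhs = np.arange(12, dtype=np.float32).reshape(4, 3)
rhs = np.stack([np.eye(3, dtype=np.float32), 2 * np.eye(3, dtype=np.float32)])
group_sizes = np.array([1, 3])  # 1 row for expert 0, 3 rows for expert 1
result = grouped_gemm_reference(lhs, rhs, group_sizes)
```

A fused kernel avoids the Python loop and handles all experts in one launch; the loop here only pins down the expected output.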

Made-with: Cursor

Usage

sparse_matmul: true
use_turbo_grouped_gemm: true

Requires the primus_turbo package installed in the container image.
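
The commit message for this PR lists three preconditions for the flag: sparse_matmul=True, megablox=False, and an importable primus_turbo package. A minimal sketch of checking them up front is below. MoEConfig and turbo_config_errors are hypothetical names for illustration, not MaxText APIs, and the dataclass defaults are arbitrary.

```python
import importlib.util
from dataclasses import dataclass


@dataclass
class MoEConfig:
    """Hypothetical stand-in for the relevant config fields."""
    sparse_matmul: bool = False
    megablox: bool = False
    use_turbo_grouped_gemm: bool = False


def turbo_config_errors(cfg, check_package=True):
    """Return a list of human-readable problems; an empty list means the config is usable."""
    errors = []
    if not cfg.use_turbo_grouped_gemm:
        return errors  # Nothing to validate when the flag is off.
    if not cfg.sparse_matmul:
        errors.append("use_turbo_grouped_gemm requires sparse_matmul=True")
    if cfg.megablox:
        errors.append("use_turbo_grouped_gemm requires megablox=False")
    if check_package and importlib.util.find_spec("primus_turbo") is None:
        errors.append("use_turbo_grouped_gemm requires the primus_turbo package")
    return errors


bad = MoEConfig(sparse_matmul=False, megablox=True, use_turbo_grouped_gemm=True)
problems = turbo_config_errors(bad, check_package=False)
```

Collecting every problem in a list, instead of raising on the first, lets a user fix a config in one pass.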

Tests

Hardware: 1x MI355X node (8 GPUs, 288 GB VRAM each), ROCm 7.0.1, JAX 0.8.2

Turbo vs ragged_dot — cross-model comparison

| Model | Experts | Top-K | Turbo GMM TGS (pdbs=1) | ragged_dot TGS (pdbs=1) | Speedup | ragged_dot status |
|---|---|---|---|---|---|---|
| ds-proxy-e128 | 128 | 2 | 1,306 | 492 | 2.65x | Works (pdbs=1 only, OOM at pdbs>=2) |
| ds-proxy-se0-e256 | 256 | 4 | 4,217 | CRASH | N/A | HIP_ERROR_InvalidValue at all layer counts |
| DeepSeek V3 (scaled) | 256 | 8 | 1,536 | OOM/CRASH | N/A | OOM at layers>=3, HIP crash at layers<=2 |
| Kimi K2 (scaled) | 384 | 8 | 1,502 | OOM | N/A | 398 GB temp at 1 layer (1.5x GPU capacity) |

ragged_dot is only viable on the smallest expert config (128 experts, top-2). For 256+ experts, it either OOMs or hits an XLA-ROCm kernel generation bug (HIP_ERROR_InvalidValue).

Memory efficiency

| Model | Turbo GMM Mem | ragged_dot Mem | Turbo advantage |
|---|---|---|---|
| ds-proxy-e128 (128E, top-2) | 102.3 GB | 119.5 GB | 1.2x |
| ds-proxy-se0-e256 (256E, top-4) | 58.6 GB | 148–172 GB | 2.7x |
| DeepSeek V3 (256E, top-8) | 91.9 GB | 271.5 GB | 3.0x |
| Kimi K2 (384E, top-8) | 121.5 GB | 397.9 GB | 3.3x |
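
As a sanity check on the ratio columns, the arithmetic can be reproduced from the reported numbers. This sketch assumes the 2.7x figure for ds-proxy-se0-e256 uses the midpoint of the 148–172 GB range; the PR does not say how that ratio was derived.

```python
# (Turbo GMM GB, ragged_dot GB) per model, copied from the memory table.
memory_gb = {
    "ds-proxy-e128": (102.3, 119.5),
    # Assumption: use the midpoint of the reported 148-172 GB range.
    "ds-proxy-se0-e256": (58.6, (148 + 172) / 2),
    "DeepSeek V3": (91.9, 271.5),
    "Kimi K2": (121.5, 397.9),
}

# Memory advantage = ragged_dot memory / Turbo memory, rounded to one decimal.
advantage = {
    model: round(ragged / turbo, 1) for model, (turbo, ragged) in memory_gb.items()
}

# Throughput speedup for the one model where ragged_dot runs (TGS at pdbs=1).
speedup_e128 = round(1306 / 492, 2)

for model, ratio in advantage.items():
    print(f"{model}: {ratio}x")
print(f"ds-proxy-e128 speedup: {speedup_e128}x")
```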

Convergence (2000 steps, C4 real data)

| Kernel | Dropless? | Loss @ step 1999 | TGS/device |
|---|---|---|---|
| Sparse ragged_dot | Yes | 5.930 | 490 |
| GMM (turbo) | Yes | 5.951 | 1,132 |

Turbo ends within 0.021 of the ragged_dot loss at step 1999 (5.951 vs. 5.930), which supports numerical correctness over a 2000-step run.

[Image] In the legend, 1-sparse is the ragged_dot run and 5-gmm is the Turbo grouped GEMM run.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

Integrate primus_turbo's grouped_gemm kernel as an alternative backend
for MoE sparse matrix multiplication. Controlled via the new
use_turbo_grouped_gemm config flag (requires sparse_matmul=True,
megablox=False, and the primus_turbo package).

Includes a temporary workaround to enable JAX x64 mode locally during
the grouped_gemm call, since the CK kernel requires int64 group_sizes
but global x64 mode breaks argsort in the MoE routing path.

Made-with: Cursor
qianghan-amd merged commit c81153a into release/v26.4 on Apr 1, 2026
2 of 4 checks passed
Comment thread src/MaxText/layers/moe.py
# global x64 breaks argsort (XLA-ROCm s32/s64 scatter mismatch).
# Use jax.experimental.enable_x64() for thread-local scope,
# safe for concurrent shard_map threads.
# Remove this once primus_turbo accepts int32 group_lens natively.

Collaborator

Is there a GitHub issue for this in the Primus-Turbo repo? If not, should we create one?

Comment thread src/MaxText/layers/moe.py
group_sizes.astype(jnp.int64),
transA=False,
transB=False,
num_cu=-1,

@yeandy commented on Apr 1, 2026

Collaborator

@qianghan-amd Do we want to let users tune this, perhaps through an additional config flag?

I think the optimal CU count depends on the number of experts and the GEMM sizes, right?

Comment thread src/MaxText/layers/moe.py
  from primus_turbo.jax.lax.grouped_gemm import grouped_gemm as turbo_grouped_gemm
except ImportError:
  raise ImportError("use_turbo_grouped_gemm=True requires the primus_turbo package to be installed.")
if not getattr(turbo_grouped_gemm, "_logged", False):

Collaborator

If getattr(turbo_grouped_gemm, "_logged", False) returns True, don't we still want to print max_logging.log("Using primus_turbo grouped_gemm in MoE sparse matmul")?
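
The lines after the guard are not shown in this thread, so the following is one plausible reading, not the confirmed implementation: a log-once pattern in which the body sets _logged = True after the first message. In this sketch, log stands in for max_logging.log and turbo_grouped_gemm is a placeholder for the imported kernel.

```python
logged_messages = []


def log(message):
    """Stand-in for max_logging.log; records messages so the behavior is visible."""
    logged_messages.append(message)


def turbo_grouped_gemm(*args, **kwargs):
    """Placeholder for the imported primus_turbo kernel."""
    return None


def log_backend_once():
    # The function attribute acts as a process-wide "already logged" marker.
    if not getattr(turbo_grouped_gemm, "_logged", False):
        log("Using primus_turbo grouped_gemm in MoE sparse matmul")
        turbo_grouped_gemm._logged = True  # Assumed to be what the guarded body does.


# Simulate the MoE layer being called several times.
for _ in range(3):
    log_backend_once()
```

Under this reading, the guard keeps the message from repeating on every layer invocation, so a True value means the message was already printed.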

2 participants