Add Primus Turbo grouped GEMM support for MoE sparse matmul #83
Merged
Integrate primus_turbo's grouped_gemm kernel as an alternative backend for MoE sparse matrix multiplication. Controlled via the new use_turbo_grouped_gemm config flag (requires sparse_matmul=True, megablox=False, and the primus_turbo package). Includes a temporary workaround to enable JAX x64 mode locally during the grouped_gemm call, since the CK kernel requires int64 group_sizes but global x64 mode breaks argsort in the MoE routing path. Made-with: Cursor
yeandy
reviewed
Apr 1, 2026
# global x64 breaks argsort (XLA-ROCm s32/s64 scatter mismatch).
# Use jax.experimental.enable_x64() for thread-local scope,
# safe for concurrent shard_map threads.
# Remove this once primus_turbo accepts int32 group_lens natively.
Collaborator
Is there a GitHub issue for this in the Primus-Turbo repo? If not, should we create one?
group_sizes.astype(jnp.int64),
transA=False,
transB=False,
num_cu=-1,
Collaborator
@qianghan-amd Do we want to let users tune this, perhaps through an additional config flag?
I think the optimal CU count depends on the number of experts and the GEMM sizes, right?
    from primus_turbo.jax.lax.grouped_gemm import grouped_gemm as turbo_grouped_gemm
except ImportError:
    raise ImportError("use_turbo_grouped_gemm=True requires the primus_turbo package to be installed.")
if not getattr(turbo_grouped_gemm, "_logged", False):
Collaborator
If getattr(turbo_grouped_gemm, "_logged", False) returns True, don't we still want to print max_logging.log("Using primus_turbo grouped_gemm in MoE sparse matmul")?
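For context on the question above, the guard in the diff looks like a log-once pattern: a flag stored as an attribute on the function object, so the message is emitted the first time the path is taken and skipped afterwards. The sketch below shows that pattern in isolation. `log_once` and `fake_grouped_gemm` are hypothetical names, and the standard `logging` module stands in for `max_logging.log`; this is not the code from the PR.

```python
import logging

logger = logging.getLogger("moe_sketch")


def log_once(fn, message):
    """Log `message` only the first time this is called for `fn`.

    Returns True if the message was logged on this call, False otherwise.
    The "already logged" flag lives on the function object itself.
    """
    if getattr(fn, "_logged", False):
        return False
    logger.info(message)
    fn._logged = True
    return True


def fake_grouped_gemm(*args, **kwargs):
    """Hypothetical stand-in for the imported kernel; never actually called here."""
    raise NotImplementedError
```

With this shape, repeated calls such as `log_once(fake_grouped_gemm, "Using primus_turbo grouped_gemm in MoE sparse matmul")` log a single line per process, which may be the intent if the surrounding function is traced or invoked once per layer.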
Integrate primus_turbo's grouped_gemm kernel as an alternative backend for MoE sparse matrix multiplication. Controlled via the new use_turbo_grouped_gemm config flag.
Made-with: Cursor
Usage
Requires the primus_turbo package installed in the container image.
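The commit message states that the flag requires `sparse_matmul=True`, `megablox=False`, and the `primus_turbo` package. A minimal sketch of checking those three constraints up front is shown below. `MoEConfig` and `validate_turbo_config` are hypothetical names used only for illustration; the PR itself may enforce these constraints differently.

```python
import importlib.util
from dataclasses import dataclass


@dataclass
class MoEConfig:
    """Hypothetical config holder; field names mirror the flags named in the PR."""
    sparse_matmul: bool = True
    megablox: bool = False
    use_turbo_grouped_gemm: bool = False


def validate_turbo_config(cfg, package="primus_turbo"):
    """Raise early if use_turbo_grouped_gemm is set without its prerequisites."""
    if not cfg.use_turbo_grouped_gemm:
        return  # nothing to check when the Turbo backend is not requested
    if not cfg.sparse_matmul:
        raise ValueError("use_turbo_grouped_gemm=True requires sparse_matmul=True.")
    if cfg.megablox:
        raise ValueError("use_turbo_grouped_gemm=True requires megablox=False.")
    # find_spec returns None when a top-level package cannot be located.
    if importlib.util.find_spec(package) is None:
        raise ImportError(
            f"use_turbo_grouped_gemm=True requires the {package} package to be installed."
        )
```

Calling `validate_turbo_config(MoEConfig(use_turbo_grouped_gemm=True))` at startup would surface a missing package or a conflicting flag before any model code is traced.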
Tests
Hardware: 1x MI355X node (8 GPUs, 288 GB VRAM each), ROCm 7.0.1, JAX 0.8.2
Turbo vs ragged_dot — cross-model comparison
ragged_dot is only viable on the smallest expert config (128 experts, top-2). For 256+ experts, it either OOMs or hits an XLA-ROCm kernel generation bug (HIP_ERROR_InvalidValue).
Memory efficiency
Convergence (2000 steps, C4 real data)
Turbo converges to the same loss cluster as ragged_dot, confirming correctness over long training runs.
Run labels: 1-sparse is the ragged_dot run, and 5-gmm is the Turbo grouped GEMM run.