From ec66806cd5a1fbb37acb456ec2ad7128d41eac8b Mon Sep 17 00:00:00 2001 From: Qiang Han Date: Tue, 17 Mar 2026 23:40:28 +0000 Subject: [PATCH 1/2] Add Primus Turbo grouped GEMM support for MoE sparse matmul Integrate primus_turbo's grouped_gemm kernel as an alternative backend for MoE sparse matrix multiplication. Controlled via the new use_turbo_grouped_gemm config flag (requires sparse_matmul=True, megablox=False, no quantization, and the primus_turbo package). Includes a temporary workaround to enable JAX x64 mode locally during the grouped_gemm call, since the CK kernel requires int64 group_sizes but global x64 mode breaks argsort in the MoE routing path. Made-with: Cursor --- src/MaxText/configs/base.yml | 1 + src/MaxText/configs/types.py | 13 +++++++++++++ src/MaxText/layers/moe.py | 24 ++++++++++++++++++++++++ 3 files changed, 38 insertions(+) diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 89e0d823ca..ecc308122f 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -174,6 +174,7 @@ num_experts: 1 num_experts_per_tok: 1 megablox: True sparse_matmul: True +use_turbo_grouped_gemm: false # Use Primus Turbo grouped GEMM for MoE sparse matmul. Requires sparse_matmul=True, megablox=False, and primus_turbo installed. 
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default load_balance_loss_weight: 0.01 # weight for the load balance loss expert_balance: False # whether or not to do expert balancing diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index 9e00e13e84..d3ec1bd85f 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -556,6 +556,12 @@ class MoEKernels(BaseModel): megablox: bool = Field(True, description="Whether to use Megablox kernels for MoE.") sparse_matmul: bool = Field(True, description="Whether to use sparse matmul kernels for MoE.") + use_turbo_grouped_gemm: bool = Field( + False, + description="Use Primus Turbo grouped GEMM for MoE sparse matmul. " + "Requires sparse_matmul=True and megablox=False. " + "Requires the primus_turbo package to be installed.", + ) wi_tile_fwd_batch_seq: int = Field(512, description="forward pass tiling dimension for batch/sequence in GMM for wi.") wi_tile_fwd_embed_dim: int = Field(1024, description="forward pass tiling dimension for embedding in GMM for wi.") wi_tile_fwd_mlp_dim: int = Field(1024, description="forward pass tiling dimension for MLP in GMM for wi.") @@ -1899,6 +1905,13 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de ) if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1: raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.") + if self.use_turbo_grouped_gemm: + if self.quantization: + raise ValueError("use_turbo_grouped_gemm is not compatible with quantization.") + if not self.sparse_matmul: + raise ValueError("use_turbo_grouped_gemm requires sparse_matmul=True.") + if self.megablox: + raise ValueError("use_turbo_grouped_gemm requires megablox=False.") if self.use_multimodal: valid_mm_models = ("gemma3-4b", "gemma3-12b", "gemma3-27b", "llama4-17b-16e", "llama4-17b-128e") if self.model_name 
not in valid_mm_models and self.model_name != "default": diff --git a/src/MaxText/layers/moe.py b/src/MaxText/layers/moe.py index 11c4a71f27..b3858f6a31 100644 --- a/src/MaxText/layers/moe.py +++ b/src/MaxText/layers/moe.py @@ -855,6 +855,30 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments): use_tokamax_backend=self.config.use_tokamax_gmm, is_fsdp_shard_on_exp=self.config.fsdp_shard_on_exp, ) + elif self.config.use_turbo_grouped_gemm: + try: + from primus_turbo.jax.lax.grouped_gemm import grouped_gemm as turbo_grouped_gemm + except ImportError: + raise ImportError( + "use_turbo_grouped_gemm=True requires the primus_turbo package to be installed." + ) + if not getattr(turbo_grouped_gemm, "_logged", False): + max_logging.log("Using primus_turbo grouped_gemm in MoE sparse matmul") + turbo_grouped_gemm._logged = True + # Thread-local x64: CK kernel requires int64 group_sizes, but + # global x64 breaks argsort (XLA-ROCm s32/s64 scatter mismatch). + # Use jax.experimental.enable_x64() for thread-local scope, + # safe for concurrent shard_map threads. + # Remove this once primus_turbo accepts int32 group_lens natively. 
+ with jax.experimental.enable_x64(): + output = turbo_grouped_gemm( + inputs, + kernel, + group_sizes.astype(jnp.int64), + transA=False, + transB=False, + num_cu=-1, + ) else: rhs_inputs = kernel if isinstance(kernel, aqt.QTensor): From fe25475ce971052f3558a57085e1f774f1a7bb6b Mon Sep 17 00:00:00 2001 From: Qiang Han Date: Wed, 1 Apr 2026 17:24:40 +0000 Subject: [PATCH 2/2] Fix pyink formatting --- src/MaxText/configs/types.py | 3 ++- src/MaxText/layers/moe.py | 9 +++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index d3ec1bd85f..4021c87f88 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -1100,7 +1100,8 @@ class DevelopmentAndDebugging(BaseModel): ) jax_distributed_initialization_timeout: int = Field(300, description="Timeout for jax.distributed.initialize.") jax_distributed_heartbeat_timeout_seconds: int = Field( - 100, description="How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores." + 100, + description="How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores.", ) jax_debug_log_modules: str = Field("", description="Set to 'jax' for verbose JAX logging.") skip_jax_distributed_system: bool = Field(False, description="If True, do not initialize the jax distributed system.") diff --git a/src/MaxText/layers/moe.py b/src/MaxText/layers/moe.py index b3858f6a31..f5d55585e3 100644 --- a/src/MaxText/layers/moe.py +++ b/src/MaxText/layers/moe.py @@ -859,9 +859,7 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments): try: from primus_turbo.jax.lax.grouped_gemm import grouped_gemm as turbo_grouped_gemm except ImportError: - raise ImportError( - "use_turbo_grouped_gemm=True requires the primus_turbo package to be installed." 
- ) + raise ImportError("use_turbo_grouped_gemm=True requires the primus_turbo package to be installed.") if not getattr(turbo_grouped_gemm, "_logged", False): max_logging.log("Using primus_turbo grouped_gemm in MoE sparse matmul") turbo_grouped_gemm._logged = True @@ -1469,7 +1467,10 @@ def get_einsum( def aqt_einsum(*args, **kwargs): # pylint: disable=unused-argument # simply skip kwargs, since aqt einsum doesn't support any kwargs # like precision - is_aqt = not ( isinstance(self.quant, quantizations.Fp8Quantization) or isinstance(self.quant, quantizations.NANOOFp8Quantization) ) + is_aqt = not ( + isinstance(self.quant, quantizations.Fp8Quantization) + or isinstance(self.quant, quantizations.NANOOFp8Quantization) + ) kw = {"mesh_axes": rhs_mesh_axes} if is_aqt else {"dtype": self.dtype} return self.quant.einsum(**kw)(*args) # pytype: disable=attribute-error