From e96c03bc71ea966c5f462cef370f9c1b53b13978 Mon Sep 17 00:00:00 2001 From: Yi Huang <234278504+yihuang-amd@users.noreply.github.com> Date: Wed, 1 Apr 2026 10:41:48 -0700 Subject: [PATCH] Revert "Add Primus Turbo grouped GEMM support for MoE sparse matmul" --- src/MaxText/configs/base.yml | 1 - src/MaxText/configs/types.py | 16 +--------------- src/MaxText/layers/moe.py | 27 +-------------------------- 3 files changed, 2 insertions(+), 42 deletions(-) diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index ecc308122f..89e0d823ca 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -174,7 +174,6 @@ num_experts: 1 num_experts_per_tok: 1 megablox: True sparse_matmul: True -use_turbo_grouped_gemm: false # Use Primus Turbo grouped GEMM for MoE sparse matmul. Requires sparse_matmul=True, megablox=False, and primus_turbo installed. capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default load_balance_loss_weight: 0.01 # weight for the load balance loss expert_balance: False # whether or not to do expert balancing diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index 4021c87f88..9e00e13e84 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -556,12 +556,6 @@ class MoEKernels(BaseModel): megablox: bool = Field(True, description="Whether to use Megablox kernels for MoE.") sparse_matmul: bool = Field(True, description="Whether to use sparse matmul kernels for MoE.") - use_turbo_grouped_gemm: bool = Field( - False, - description="Use Primus Turbo grouped GEMM for MoE sparse matmul. " - "Requires sparse_matmul=True and megablox=False. " - "Requires the primus_turbo package to be installed.", - ) wi_tile_fwd_batch_seq: int = Field(512, description="forward pass tiling dimension for batch/sequence in GMM for wi.") wi_tile_fwd_embed_dim: int = Field(1024, description="forward pass tiling dimension for embedding in GMM for wi.") wi_tile_fwd_mlp_dim: int = Field(1024, description="forward pass tiling dimension for MLP in GMM for wi.") @@ -1100,8 +1094,7 @@ class DevelopmentAndDebugging(BaseModel): ) jax_distributed_initialization_timeout: int = Field(300, description="Timeout for jax.distributed.initialize.") jax_distributed_heartbeat_timeout_seconds: int = Field( - 100, - description="How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores.", + 100, description="How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores." ) jax_debug_log_modules: str = Field("", description="Set to 'jax' for verbose JAX logging.") skip_jax_distributed_system: bool = Field(False, description="If True, do not initialize the jax distributed system.") @@ -1906,13 +1899,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de ) if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1: raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.") - if self.use_turbo_grouped_gemm: - if self.quantization: - raise ValueError("use_turbo_grouped_gemm is not compatible with quantization.") - if not self.sparse_matmul: - raise ValueError("use_turbo_grouped_gemm requires sparse_matmul=True.") - if self.megablox: - raise ValueError("use_turbo_grouped_gemm requires megablox=False.") if self.use_multimodal: valid_mm_models = ("gemma3-4b", "gemma3-12b", "gemma3-27b", "llama4-17b-16e", "llama4-17b-128e") if self.model_name not in valid_mm_models and self.model_name != "default": diff --git a/src/MaxText/layers/moe.py b/src/MaxText/layers/moe.py index f5d55585e3..11c4a71f27 100644 --- a/src/MaxText/layers/moe.py +++ b/src/MaxText/layers/moe.py @@ -855,28 +855,6 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments): use_tokamax_backend=self.config.use_tokamax_gmm, is_fsdp_shard_on_exp=self.config.fsdp_shard_on_exp, ) - elif self.config.use_turbo_grouped_gemm: - try: - from primus_turbo.jax.lax.grouped_gemm import grouped_gemm as turbo_grouped_gemm - except ImportError: - raise ImportError("use_turbo_grouped_gemm=True requires the primus_turbo package to be installed.") - if not getattr(turbo_grouped_gemm, "_logged", False): - max_logging.log("Using primus_turbo grouped_gemm in MoE sparse matmul") - turbo_grouped_gemm._logged = True - # Thread-local x64: CK kernel requires int64 group_sizes, but - # global x64 breaks argsort (XLA-ROCm s32/s64 scatter mismatch). - # Use jax.experimental.enable_x64() for thread-local scope, - # safe for concurrent shard_map threads. - # Remove this once primus_turbo accepts int32 group_lens natively. - with jax.experimental.enable_x64(): - output = turbo_grouped_gemm( - inputs, - kernel, - group_sizes.astype(jnp.int64), - transA=False, - transB=False, - num_cu=-1, - ) else: rhs_inputs = kernel if isinstance(kernel, aqt.QTensor): @@ -1467,10 +1445,7 @@ def get_einsum( def aqt_einsum(*args, **kwargs): # pylint: disable=unused-argument # simply skip kwargs, since aqt einsum doesn't support any kwargs # like precision - is_aqt = not ( - isinstance(self.quant, quantizations.Fp8Quantization) - or isinstance(self.quant, quantizations.NANOOFp8Quantization) - ) + is_aqt = not ( isinstance(self.quant, quantizations.Fp8Quantization) or isinstance(self.quant, quantizations.NANOOFp8Quantization) ) kw = {"mesh_axes": rhs_mesh_axes} if is_aqt else {"dtype": self.dtype} return self.quant.einsum(**kw)(*args) # pytype: disable=attribute-error