Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ num_experts: 1
num_experts_per_tok: 1
megablox: True
sparse_matmul: True
use_turbo_grouped_gemm: false # Use Primus Turbo grouped GEMM for MoE sparse matmul. Requires sparse_matmul=True, megablox=False, and primus_turbo installed.
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
load_balance_loss_weight: 0.01 # weight for the load balance loss
expert_balance: False # whether or not to do expert balancing
Expand Down
16 changes: 15 additions & 1 deletion src/MaxText/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,12 @@ class MoEKernels(BaseModel):

megablox: bool = Field(True, description="Whether to use Megablox kernels for MoE.")
sparse_matmul: bool = Field(True, description="Whether to use sparse matmul kernels for MoE.")
use_turbo_grouped_gemm: bool = Field(
False,
description="Use Primus Turbo grouped GEMM for MoE sparse matmul. "
"Requires sparse_matmul=True and megablox=False. "
"Requires the primus_turbo package to be installed.",
)
wi_tile_fwd_batch_seq: int = Field(512, description="forward pass tiling dimension for batch/sequence in GMM for wi.")
wi_tile_fwd_embed_dim: int = Field(1024, description="forward pass tiling dimension for embedding in GMM for wi.")
wi_tile_fwd_mlp_dim: int = Field(1024, description="forward pass tiling dimension for MLP in GMM for wi.")
Expand Down Expand Up @@ -1094,7 +1100,8 @@ class DevelopmentAndDebugging(BaseModel):
)
jax_distributed_initialization_timeout: int = Field(300, description="Timeout for jax.distributed.initialize.")
jax_distributed_heartbeat_timeout_seconds: int = Field(
100, description="How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores."
100,
description="How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores.",
)
jax_debug_log_modules: str = Field("", description="Set to 'jax' for verbose JAX logging.")
skip_jax_distributed_system: bool = Field(False, description="If True, do not initialize the jax distributed system.")
Expand Down Expand Up @@ -1899,6 +1906,13 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
)
if self.decoder_block == DecoderBlockType.GPT_OSS and not self.sparse_matmul and self.capacity_factor != -1:
raise ValueError("GPT-OSS MoE only supports dropless (capacity_factor=-1) with dense matmul.")
if self.use_turbo_grouped_gemm:
if self.quantization:
raise ValueError("use_turbo_grouped_gemm is not compatible with quantization.")
if not self.sparse_matmul:
raise ValueError("use_turbo_grouped_gemm requires sparse_matmul=True.")
if self.megablox:
raise ValueError("use_turbo_grouped_gemm requires megablox=False.")
if self.use_multimodal:
valid_mm_models = ("gemma3-4b", "gemma3-12b", "gemma3-27b", "llama4-17b-16e", "llama4-17b-128e")
if self.model_name not in valid_mm_models and self.model_name != "default":
Expand Down
27 changes: 26 additions & 1 deletion src/MaxText/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,28 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
use_tokamax_backend=self.config.use_tokamax_gmm,
is_fsdp_shard_on_exp=self.config.fsdp_shard_on_exp,
)
elif self.config.use_turbo_grouped_gemm:
try:
from primus_turbo.jax.lax.grouped_gemm import grouped_gemm as turbo_grouped_gemm
except ImportError:
raise ImportError("use_turbo_grouped_gemm=True requires the primus_turbo package to be installed.")
if not getattr(turbo_grouped_gemm, "_logged", False):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If getattr(turbo_grouped_gemm, "_logged", False) returns True, don't we still want to call max_logging.log("Using primus_turbo grouped_gemm in MoE sparse matmul")?

max_logging.log("Using primus_turbo grouped_gemm in MoE sparse matmul")
turbo_grouped_gemm._logged = True
# Thread-local x64: CK kernel requires int64 group_sizes, but
# global x64 breaks argsort (XLA-ROCm s32/s64 scatter mismatch).
# Use jax.experimental.enable_x64() for thread-local scope,
# safe for concurrent shard_map threads.
# Remove this once primus_turbo accepts int32 group_lens natively.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a GitHub issue for this in the Primus-Turbo repo? If not, should we create one?

with jax.experimental.enable_x64():
output = turbo_grouped_gemm(
inputs,
kernel,
group_sizes.astype(jnp.int64),
transA=False,
transB=False,
num_cu=-1,

@yeandy yeandy Apr 1, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@qianghan-amd Do we want to allow users to tune this, maybe via an additional flag in the config?

I think the optimal CU count depends on the number of experts and the GEMM sizes, right?

)
else:
rhs_inputs = kernel
if isinstance(kernel, aqt.QTensor):
Expand Down Expand Up @@ -1445,7 +1467,10 @@ def get_einsum(
def aqt_einsum(*args, **kwargs): # pylint: disable=unused-argument
# simply skip kwargs, since aqt einsum doesn't support any kwargs
# like precision
is_aqt = not ( isinstance(self.quant, quantizations.Fp8Quantization) or isinstance(self.quant, quantizations.NANOOFp8Quantization) )
is_aqt = not (
isinstance(self.quant, quantizations.Fp8Quantization)
or isinstance(self.quant, quantizations.NANOOFp8Quantization)
)
kw = {"mesh_axes": rhs_mesh_axes} if is_aqt else {"dtype": self.dtype}
return self.quant.einsum(**kw)(*args) # pytype: disable=attribute-error

Expand Down
Loading