From 7d49b8dd80c80d7a75cd39b9c0eefe48f904fd1d Mon Sep 17 00:00:00 2001 From: Darisoy Date: Wed, 29 Apr 2026 23:16:40 +0000 Subject: [PATCH] Support specifying custom tile sizes for forward and backward passes of Tokamax GMM in MaxText --- src/maxtext/configs/types.py | 1 + src/maxtext/layers/moe.py | 39 +++++++++++++++---- src/maxtext/models/deepseek_batchsplit_fp8.py | 39 +++++++++++++++---- 3 files changed, 63 insertions(+), 16 deletions(-) diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 5db6067dde..3f44a41d40 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -557,6 +557,7 @@ class Attention(BaseModel): False, description="Whether to use the Tokamax library for GMM kernel implementation.", ) + tokamax_gmm_autotune: bool = Field(False, description="Whether to use tokamax auto-tuner for GMM.") ragged_block_size: int = Field(256, description="Block size for ragged attention.") enable_padding_causal_mask: bool = Field(True, description="Temporary flag for TE padding.") use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.") diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index a08c1d10ff..6930097f83 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -43,6 +43,8 @@ from qwix.contrib.sparsity import sparsity_module import qwix.pallas as qpl import tokamax +from tokamax import config as tokamax_config +from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config set_xla_metadata = xla_metadata.set_xla_metadata @@ -1104,14 +1106,35 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a weight_gather_axes=weight_gather_axes, ) else: # tokamax (unquantized) - output = tokamax.ragged_dot( - lhs=inputs, - rhs=kernel, - group_sizes=tokamax_group_sizes, - precision=jax.lax.Precision.DEFAULT, - preferred_element_type=self.dtype, - implementation="mosaic", - ) + if self.config.tokamax_gmm_autotune: + with tokamax_config.autotuning_cache_miss_fallback("heuristics"): + output = tokamax.ragged_dot( + lhs=inputs, + rhs=kernel, + group_sizes=tokamax_group_sizes, + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=self.dtype, + implementation="mosaic", + ) + else: + custom_impl = PallasMosaicTpuRaggedDot( + config=Config( + # Forward Pass + gmm_tiling=(tiling[0], tiling[1], tiling[2]), + # Backward DLHS Pass + gmm_rhs_transpose_tiling=(tiling[3], tiling[4], tiling[5]), + # Backward DRHS Pass + tgmm_tiling=(tiling[6], tiling[7], tiling[8]), + ) + ) + output = tokamax.ragged_dot( + lhs=inputs, + rhs=kernel, + group_sizes=tokamax_group_sizes, + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=self.dtype, + implementation=custom_impl, + ) elif self.config.megablox: # Older forked megablox output = mblx.gmm( lhs=inputs, diff --git a/src/maxtext/models/deepseek_batchsplit_fp8.py b/src/maxtext/models/deepseek_batchsplit_fp8.py index 2d55536440..91c2e38a84 100644 --- a/src/maxtext/models/deepseek_batchsplit_fp8.py +++ b/src/maxtext/models/deepseek_batchsplit_fp8.py @@ -29,6 +29,8 @@ from maxtext.layers import quantizations import qwix.pallas as qpl import tokamax +from tokamax import config as tokamax_config +from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config @functools.partial( @@ -962,14 +964,35 @@ def gmm( qwix_rule=quantizations.get_fp8_full_qwix_rule_w_sparsity(config), ) else: - output = tokamax.ragged_dot( - lhs=inputs, - rhs=kernel, - group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)), - precision=jax.lax.Precision.DEFAULT, - preferred_element_type=preferred_element_type, - implementation="mosaic", - ) + if config.tokamax_gmm_autotune: + with tokamax_config.autotuning_cache_miss_fallback("heuristics"): + output = tokamax.ragged_dot( + lhs=inputs, + rhs=kernel, + group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)), + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=preferred_element_type, + implementation="mosaic", + ) + else: + custom_impl = PallasMosaicTpuRaggedDot( + config=Config( + # Forward Pass + gmm_tiling=(tiling[0], tiling[1], tiling[2]), + # Backward DLHS Pass + gmm_rhs_transpose_tiling=(tiling[3], tiling[4], tiling[5]), + # Backward DRHS Pass + tgmm_tiling=(tiling[6], tiling[7], tiling[8]), + ) + ) + output = tokamax.ragged_dot( + lhs=inputs, + rhs=kernel, + group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)), + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=preferred_element_type, + implementation=custom_impl, + ) return output gmm_fn = functools.partial(gmm, group_sizes=group_sizes, preferred_element_type=config.dtype)