1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
@@ -557,6 +557,7 @@ class Attention(BaseModel):
False,
description="Whether to use the Tokamax library for GMM kernel implementation.",
)
tokamax_gmm_autotune: bool = Field(False, description="Whether to use tokamax auto-tuner for GMM.")
ragged_block_size: int = Field(256, description="Block size for ragged attention.")
enable_padding_causal_mask: bool = Field(True, description="Temporary flag for TE padding.")
use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.")
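A minimal sketch of how a boolean Pydantic Field like the new tokamax_gmm_autotune flag is declared and then consumed when choosing a GMM path. The AttentionSketch model and select_gmm_path helper below are illustrative stand-ins, not the actual MaxText Attention class or any function added in this PR.

from pydantic import BaseModel, Field

class AttentionSketch(BaseModel):
  """Illustrative stand-in for the MaxText Attention config model."""
  tokamax_gmm_autotune: bool = Field(False, description="Whether to use tokamax auto-tuner for GMM.")

def select_gmm_path(cfg: AttentionSketch) -> str:
  # Autotuned path when the flag is set; otherwise the explicit-tiling path.
  return "autotune" if cfg.tokamax_gmm_autotune else "explicit-tiling"

assert select_gmm_path(AttentionSketch()) == "explicit-tiling"
assert select_gmm_path(AttentionSketch(tokamax_gmm_autotune=True)) == "autotune"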
39 changes: 31 additions & 8 deletions src/maxtext/layers/moe.py
@@ -43,6 +43,8 @@
from qwix.contrib.sparsity import sparsity_module
import qwix.pallas as qpl
import tokamax
from tokamax import config as tokamax_config
from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config

set_xla_metadata = xla_metadata.set_xla_metadata

@@ -1104,14 +1106,35 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
weight_gather_axes=weight_gather_axes,
)
else: # tokamax (unquantized)
output = tokamax.ragged_dot(
lhs=inputs,
rhs=kernel,
group_sizes=tokamax_group_sizes,
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=self.dtype,
implementation="mosaic",
)
if self.config.tokamax_gmm_autotune:
with tokamax_config.autotuning_cache_miss_fallback("heuristics"):
output = tokamax.ragged_dot(
lhs=inputs,
rhs=kernel,
group_sizes=tokamax_group_sizes,
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=self.dtype,
implementation="mosaic",
)
else:
custom_impl = PallasMosaicTpuRaggedDot(
config=Config(
# Forward Pass
gmm_tiling=(tiling[0], tiling[1], tiling[2]),
# Backward DLHS Pass
gmm_rhs_transpose_tiling=(tiling[3], tiling[4], tiling[5]),
# Backward DRHS Pass
tgmm_tiling=(tiling[6], tiling[7], tiling[8]),
)
)
output = tokamax.ragged_dot(
lhs=inputs,
rhs=kernel,
group_sizes=tokamax_group_sizes,
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=self.dtype,
implementation=custom_impl,
)
elif self.config.megablox: # Older forked megablox
output = mblx.gmm(
lhs=inputs,
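A short sketch, assuming tokamax is installed and that the internal pallas_mosaic_tpu module path imported above remains stable, of how the 9-element tiling tuple maps onto tokamax's ragged-dot Config: the first three entries tile the forward GMM, the next three the backward pass w.r.t. the LHS (gmm_rhs_transpose_tiling), and the last three the backward pass w.r.t. the RHS (tgmm_tiling). The tile values below are placeholders, not tuned settings.

from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config

# Placeholder tiling; a real run would pass tuned values in via the config.
tiling = (
    512, 512, 512,  # forward GMM tiles
    512, 512, 512,  # backward dLHS (gmm_rhs_transpose) tiles
    512, 512, 512,  # backward dRHS (tgmm) tiles
)

custom_impl = PallasMosaicTpuRaggedDot(
    config=Config(
        gmm_tiling=tiling[0:3],                # forward pass
        gmm_rhs_transpose_tiling=tiling[3:6],  # backward pass w.r.t. LHS
        tgmm_tiling=tiling[6:9],               # backward pass w.r.t. RHS
    )
)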
39 changes: 31 additions & 8 deletions src/maxtext/models/deepseek_batchsplit_fp8.py
@@ -29,6 +29,8 @@
from maxtext.layers import quantizations
import qwix.pallas as qpl
import tokamax
from tokamax import config as tokamax_config
from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config


@functools.partial(
@@ -962,14 +964,35 @@ def gmm(
qwix_rule=quantizations.get_fp8_full_qwix_rule_w_sparsity(config),
)
else:
output = tokamax.ragged_dot(
lhs=inputs,
rhs=kernel,
group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=preferred_element_type,
implementation="mosaic",
)
if config.tokamax_gmm_autotune:
with tokamax_config.autotuning_cache_miss_fallback("heuristics"):
output = tokamax.ragged_dot(
lhs=inputs,
rhs=kernel,
group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=preferred_element_type,
implementation="mosaic",
)
else:
custom_impl = PallasMosaicTpuRaggedDot(
config=Config(
# Forward Pass
gmm_tiling=(tiling[0], tiling[1], tiling[2]),
# Backward DLHS Pass
gmm_rhs_transpose_tiling=(tiling[3], tiling[4], tiling[5]),
# Backward DRHS Pass
tgmm_tiling=(tiling[6], tiling[7], tiling[8]),
)
)
output = tokamax.ragged_dot(
lhs=inputs,
rhs=kernel,
group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=preferred_element_type,
implementation=custom_impl,
)
return output

gmm_fn = functools.partial(gmm, group_sizes=group_sizes, preferred_element_type=config.dtype)
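For the autotuned branch, a sketch of the call pattern, assuming tokamax is installed and a TPU backend is targeted; the shapes and dtype are illustrative and assume tokamax.ragged_dot follows the jax.lax.ragged_dot layout (lhs [m, k], rhs [groups, k, n], group_sizes [groups]). As in the branches above, autotuning_cache_miss_fallback("heuristics") makes an autotuning-cache miss fall back to tokamax's built-in heuristics instead of erroring.

import jax
import jax.numpy as jnp
import tokamax
from tokamax import config as tokamax_config

# Illustrative shapes: 16 tokens, 64 input features, 4 experts, 32 output features.
inputs = jnp.ones((16, 64), dtype=jnp.bfloat16)
kernel = jnp.ones((4, 64, 32), dtype=jnp.bfloat16)
group_sizes = jnp.array([4, 4, 4, 4], dtype=jnp.int32)  # tokens routed to each expert

with tokamax_config.autotuning_cache_miss_fallback("heuristics"):
  output = tokamax.ragged_dot(
      lhs=inputs,
      rhs=kernel,
      group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
      precision=jax.lax.Precision.DEFAULT,
      preferred_element_type=jnp.bfloat16,
      implementation="mosaic",
  )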