From b257963f68b5fcf1921671fd55b44eeda2b0be47 Mon Sep 17 00:00:00 2001 From: continuousml Date: Sun, 21 Jun 2026 20:04:46 -0700 Subject: [PATCH] Allow Tokamax GMM with pipeline FSDP AG per repeat --- src/maxtext/layers/moe.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 7e69d0b938..7294835ead 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -1247,10 +1247,7 @@ def jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments, return output def get_tokamax_group_sizes(group_sizes, inputs, kernel): - # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm - if self.config.use_qwix_quantization or ( - self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat - ): + if self.config.use_qwix_quantization: return group_sizes elif self.config.attention == "vllm_rpa": return group_sizes