diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 7e69d0b938..7294835ead 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -1247,10 +1247,7 @@ def jax_ragged_dot_gmm(inputs, kernel, tiling, group_sizes, expert_assignments, return output def get_tokamax_group_sizes(group_sizes, inputs, kernel): - # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm - if self.config.use_qwix_quantization or ( - self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat - ): + if self.config.use_qwix_quantization: return group_sizes elif self.config.attention == "vllm_rpa": return group_sizes