From e41a72c0524e12d33b5d1158d0d36c6b1223e768 Mon Sep 17 00:00:00 2001 From: continuousml Date: Sun, 21 Jun 2026 17:06:01 -0700 Subject: [PATCH] Add ragged all-to-all pipeline correctness coverage --- .../integration/pipeline_parallelism_test.py | 95 +++++++++++++------ 1 file changed, 66 insertions(+), 29 deletions(-) diff --git a/tests/integration/pipeline_parallelism_test.py b/tests/integration/pipeline_parallelism_test.py index 3de6c0ce96..5de813e0cf 100644 --- a/tests/integration/pipeline_parallelism_test.py +++ b/tests/integration/pipeline_parallelism_test.py @@ -24,6 +24,7 @@ from flax import linen as nn from flax import nnx from flax.core import meta +from flax.linen import partitioning as nn_partitioning import jax import jax.numpy as jnp from jax.sharding import Mesh @@ -154,23 +155,24 @@ def stage_factory(stage_rngs): return raw_stage_class(config=config, mesh=mesh, model_mode=model_mode, rngs=stage_rngs) my_pipeline = pipeline.create_pipeline(config=config, layers=stage_factory, mesh=mesh) - init_pipeline_params = my_pipeline.init( - jax.random.PRNGKey(0), inputs, inputs_position, inputs_segmentation, deterministic, model_mode - ) - # `get_weight_sharding` is a compact method on `PipelineLinen` (callable as - # `my_pipeline.get_weight_sharding(...)` directly) but on the ToLinen-wrapped NNX - # pipeline it must be invoked inside a bound module context. Use `bind` so the - # same call shape works on both paths. - logical_partition_spec = my_pipeline.bind(init_pipeline_params).get_weight_sharding( - inputs, inputs_position, inputs_segmentation, deterministic, model_mode - ) + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): + init_pipeline_params = my_pipeline.init( + jax.random.PRNGKey(0), inputs, inputs_segmentation, inputs_position, deterministic, model_mode + ) + # `get_weight_sharding` is a compact method on `PipelineLinen` (callable as + # `my_pipeline.get_weight_sharding(...)` directly) but on the ToLinen-wrapped NNX + # pipeline it must be invoked inside a bound module context. Use `bind` so the + # same call shape works on both paths. + logical_partition_spec = my_pipeline.bind(init_pipeline_params).get_weight_sharding( + inputs, inputs_segmentation, inputs_position, deterministic, model_mode + ) # Create a dummy scalar loss function so we may take the gradient wrt weights def pipeline_parallelism_dummy_loss_extra( params, inputs, - inputs_position, inputs_segmentation, + inputs_position, deterministic, model_mode, dummy_targets, @@ -179,8 +181,8 @@ def pipeline_parallelism_dummy_loss_extra( outputs = my_pipeline.apply( params, inputs, - inputs_position, inputs_segmentation, + inputs_position, deterministic, model_mode, logical_partition_spec=logical_partition_spec, @@ -192,7 +194,7 @@ def pipeline_parallelism_dummy_loss_extra( pipeline_parallelism_dummy_loss_extra, logical_partition_spec=logical_partition_spec ) - def regular_sequential_layers(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode): + def regular_sequential_layers(params, inputs, inputs_segmentation, inputs_position, deterministic, model_mode): def get_cur_layer_params(params, layer_idx): def get_cur_layer_params_arr(leaf): # Reshape layers into a linear list of layers, e.g. [repeat, stage] into [layers] @@ -213,34 +215,35 @@ def get_cur_layer_params_arr(leaf): for layer in range(config.num_decoder_layers): cur_layer_params = get_cur_layer_params(params, layer) cur_layer_params["params"] = cur_layer_params["params"]["layers"] - if config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage > 1: + if config.num_pipeline_repeats > 1: cur_layer_params["params"] = meta.remove_axis( cur_layer_params["params"], 0, {nn.PARTITION_NAME: "circular_repeats"} ) - cur_layer_params["params"] = meta.remove_axis(cur_layer_params["params"], 0, {nn.PARTITION_NAME: "layers"}) + cur_layer_params["params"] = meta.remove_axis(cur_layer_params["params"], 0, {nn.PARTITION_NAME: "layers"}) reg_layer_activations, _ = single_pipeline_stage.apply( - cur_layer_params, reg_layer_activations, inputs_position, inputs_segmentation, deterministic, model_mode + cur_layer_params, reg_layer_activations, inputs_segmentation, inputs_position, deterministic, model_mode ) return reg_layer_activations def regular_sequential_layers_dummy_loss( - params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode, dummy_targets + params, inputs, inputs_segmentation, inputs_position, deterministic, model_mode, dummy_targets ): - outputs = regular_sequential_layers(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode) + outputs = regular_sequential_layers(params, inputs, inputs_segmentation, inputs_position, deterministic, model_mode) loss = jnp.linalg.norm(outputs - dummy_targets) return loss - assert_same_output_and_grad( - regular_sequential_layers_dummy_loss, - pipeline_parallelism_dummy_loss, - init_pipeline_params, - inputs, - inputs_segmentation, - inputs_position, - deterministic, - model_mode, - dummy_targets, - ) + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): + assert_same_output_and_grad( + regular_sequential_layers_dummy_loss, + pipeline_parallelism_dummy_loss, + init_pipeline_params, + inputs, + inputs_segmentation, + inputs_position, + deterministic, + model_mode, + dummy_targets, + ) def assert_pipeline_same_output_and_grad(self, config, single_pipeline_stage_class=None): """ @@ -342,6 +345,40 @@ def test_circular_deepseek_megablox_same_output_and_grad(self): config, single_pipeline_stage_class=deepseek.DeepSeekMoELayerToLinen ) + @pytest.mark.tpu_only + def test_deepseek_ragged_a2a_ep_same_output_and_grad(self): + config = pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + enable_checkpointing=False, + enable_goodput_recording=False, + run_name="deepseek_ragged_a2a_ep", + dtype="bfloat16", + weight_dtype="bfloat16", + matmul_precision="high", + max_target_length=128, + base_emb_dim=256, + ici_pipeline_parallelism=2, + ici_expert_parallelism=2, + allow_split_physical_axes=True, + base_num_decoder_layers=2, + num_pipeline_microbatches=4, + per_device_batch_size=4, + num_experts=4, + num_experts_per_tok=2, + shared_experts=1, + megablox=True, + sparse_matmul=True, + capacity_factor=-1, + decoder_block="deepseek", + attention_type="mla", + base_moe_mlp_dim=256, + base_mlp_dim=256, + ) + self.assert_pipeline_same_output_and_grad( + config, + single_pipeline_stage_class=deepseek.DeepSeekMoELayerToLinen, + ) + @pytest.mark.tpu_only def test_circular_ag_once(self): # 2 stages, 8 microbatches, all gather once