Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 66 additions & 29 deletions tests/integration/pipeline_parallelism_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from flax import linen as nn
from flax import nnx
from flax.core import meta
from flax.linen import partitioning as nn_partitioning
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
Expand Down Expand Up @@ -154,23 +155,24 @@ def stage_factory(stage_rngs):
return raw_stage_class(config=config, mesh=mesh, model_mode=model_mode, rngs=stage_rngs)

my_pipeline = pipeline.create_pipeline(config=config, layers=stage_factory, mesh=mesh)
init_pipeline_params = my_pipeline.init(
jax.random.PRNGKey(0), inputs, inputs_position, inputs_segmentation, deterministic, model_mode
)
# `get_weight_sharding` is a compact method on `PipelineLinen` (callable as
# `my_pipeline.get_weight_sharding(...)` directly) but on the ToLinen-wrapped NNX
# pipeline it must be invoked inside a bound module context. Use `bind` so the
# same call shape works on both paths.
logical_partition_spec = my_pipeline.bind(init_pipeline_params).get_weight_sharding(
inputs, inputs_position, inputs_segmentation, deterministic, model_mode
)
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
init_pipeline_params = my_pipeline.init(
jax.random.PRNGKey(0), inputs, inputs_segmentation, inputs_position, deterministic, model_mode
)
# `get_weight_sharding` is a compact method on `PipelineLinen` (callable as
# `my_pipeline.get_weight_sharding(...)` directly) but on the ToLinen-wrapped NNX
# pipeline it must be invoked inside a bound module context. Use `bind` so the
# same call shape works on both paths.
logical_partition_spec = my_pipeline.bind(init_pipeline_params).get_weight_sharding(
inputs, inputs_segmentation, inputs_position, deterministic, model_mode
)

# Create a dummy scalar loss function so we may take the gradient wrt weights
def pipeline_parallelism_dummy_loss_extra(
params,
inputs,
inputs_position,
inputs_segmentation,
inputs_position,
deterministic,
model_mode,
dummy_targets,
Expand All @@ -179,8 +181,8 @@ def pipeline_parallelism_dummy_loss_extra(
outputs = my_pipeline.apply(
params,
inputs,
inputs_position,
inputs_segmentation,
inputs_position,
deterministic,
model_mode,
logical_partition_spec=logical_partition_spec,
Expand All @@ -192,7 +194,7 @@ def pipeline_parallelism_dummy_loss_extra(
pipeline_parallelism_dummy_loss_extra, logical_partition_spec=logical_partition_spec
)

def regular_sequential_layers(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode):
def regular_sequential_layers(params, inputs, inputs_segmentation, inputs_position, deterministic, model_mode):
def get_cur_layer_params(params, layer_idx):
def get_cur_layer_params_arr(leaf):
# Reshape layers into a linear list of layers, e.g. [repeat, stage] into [layers]
Expand All @@ -213,34 +215,35 @@ def get_cur_layer_params_arr(leaf):
for layer in range(config.num_decoder_layers):
cur_layer_params = get_cur_layer_params(params, layer)
cur_layer_params["params"] = cur_layer_params["params"]["layers"]
if config.num_pipeline_repeats > 1 and config.num_layers_per_pipeline_stage > 1:
if config.num_pipeline_repeats > 1:
cur_layer_params["params"] = meta.remove_axis(
cur_layer_params["params"], 0, {nn.PARTITION_NAME: "circular_repeats"}
)
cur_layer_params["params"] = meta.remove_axis(cur_layer_params["params"], 0, {nn.PARTITION_NAME: "layers"})
cur_layer_params["params"] = meta.remove_axis(cur_layer_params["params"], 0, {nn.PARTITION_NAME: "layers"})
reg_layer_activations, _ = single_pipeline_stage.apply(
cur_layer_params, reg_layer_activations, inputs_position, inputs_segmentation, deterministic, model_mode
cur_layer_params, reg_layer_activations, inputs_segmentation, inputs_position, deterministic, model_mode
)
return reg_layer_activations

def regular_sequential_layers_dummy_loss(
params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode, dummy_targets
params, inputs, inputs_segmentation, inputs_position, deterministic, model_mode, dummy_targets
):
outputs = regular_sequential_layers(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode)
outputs = regular_sequential_layers(params, inputs, inputs_segmentation, inputs_position, deterministic, model_mode)
loss = jnp.linalg.norm(outputs - dummy_targets)
return loss

assert_same_output_and_grad(
regular_sequential_layers_dummy_loss,
pipeline_parallelism_dummy_loss,
init_pipeline_params,
inputs,
inputs_segmentation,
inputs_position,
deterministic,
model_mode,
dummy_targets,
)
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
assert_same_output_and_grad(
regular_sequential_layers_dummy_loss,
pipeline_parallelism_dummy_loss,
init_pipeline_params,
inputs,
inputs_segmentation,
inputs_position,
deterministic,
model_mode,
dummy_targets,
)

def assert_pipeline_same_output_and_grad(self, config, single_pipeline_stage_class=None):
"""
Expand Down Expand Up @@ -342,6 +345,40 @@ def test_circular_deepseek_megablox_same_output_and_grad(self):
config, single_pipeline_stage_class=deepseek.DeepSeekMoELayerToLinen
)

@pytest.mark.tpu_only
def test_deepseek_ragged_a2a_ep_same_output_and_grad(self):
config = pyconfig.initialize(
[sys.argv[0], get_test_config_path()],
enable_checkpointing=False,
enable_goodput_recording=False,
run_name="deepseek_ragged_a2a_ep",
dtype="bfloat16",
weight_dtype="bfloat16",
matmul_precision="high",
max_target_length=128,
base_emb_dim=256,
ici_pipeline_parallelism=2,
ici_expert_parallelism=2,
allow_split_physical_axes=True,
base_num_decoder_layers=2,
num_pipeline_microbatches=4,
per_device_batch_size=4,
num_experts=4,
num_experts_per_tok=2,
shared_experts=1,
megablox=True,
sparse_matmul=True,
capacity_factor=-1,
decoder_block="deepseek",
attention_type="mla",
base_moe_mlp_dim=256,
base_mlp_dim=256,
)
self.assert_pipeline_same_output_and_grad(
config,
single_pipeline_stage_class=deepseek.DeepSeekMoELayerToLinen,
)

@pytest.mark.tpu_only
def test_circular_ag_once(self):
# 2 stages, 8 microbatches, all gather once
Expand Down
Loading