From 8ef48c9f9cc5c0e80aa018e52580f812cbdc36cc Mon Sep 17 00:00:00 2001 From: continuousml Date: Sat, 20 Jun 2026 22:35:18 -0700 Subject: [PATCH] Use resolved mesh size for context parallel sharding --- src/maxtext/layers/attention_op.py | 2 +- src/maxtext/utils/train_utils.py | 7 ++++--- tests/utils/attention_test_util.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index b3c3f296f4..4caddf5b97 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -1168,7 +1168,7 @@ def tpu_flash_attention( ) -> tuple[Array, Array]: """TPU Flash Attention.""" - cp_size = self.config.context_parallel_size + cp_size = int(self.mesh.shape.get(self.config.context_sharding, 1)) load_balanced_context_parallel = self.config.context_parallel_load_balance # Transpose to ('batch', 'heads', 'length', 'kv') diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index eb429f5446..9b13c8c309 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -214,6 +214,7 @@ def setup_train_loop(config, recorder, devices=None): is_training = True init_rng = jax.random.PRNGKey(config.init_weights_seed) mesh = maxtext_utils.get_mesh_from_config(config, devices) + context_parallel_size = int(mesh.shape.get(config.context_sharding, 1)) if config.pure_nnx: # Create abstract NNX model. _create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, mesh, devices) @@ -241,7 +242,7 @@ def create_train_state_fn(): data_iterator, eval_data_iterator = create_data_iterator(config, mesh) rampup_manager = create_rampup_manager(config, checkpoint_manager) # Validate context parallelism with packing configuration - if config.context_parallel_size > 1 and config.packing: + if context_parallel_size > 1 and config.packing: if config.dataset_type == "synthetic": raise ValueError( "Context parallelism with sequence packing is not supported with synthetic data. " @@ -255,7 +256,7 @@ def create_train_state_fn(): # Apply reordering wrapper to data iterators if context parallelism is enabled with jax.set_mesh(mesh): - if config.context_parallel_size > 1 and config.context_parallel_load_balance: + if context_parallel_size > 1 and config.context_parallel_load_balance: # Determine load balancing reorder strategy based on whether packing is enabled if config.context_parallel_reorder_strategy == ReorderStrategy.AUTO: @@ -264,7 +265,7 @@ def create_train_state_fn(): reorder_strategy = config.context_parallel_reorder_strategy reorder_fn = maxtext_utils.get_reorder_callable( - config.context_parallel_size, config.shard_mode, reorder_strategy, config.hardware + context_parallel_size, config.shard_mode, reorder_strategy, config.hardware ) data_iterator = map(reorder_fn, data_iterator) if eval_data_iterator: diff --git a/tests/utils/attention_test_util.py b/tests/utils/attention_test_util.py index 23188caf0c..ba22a9e839 100644 --- a/tests/utils/attention_test_util.py +++ b/tests/utils/attention_test_util.py @@ -187,7 +187,7 @@ def forward_with_context_expert_parallelism( """Get logits from attention under context/expert parallelism.""" # If load balanced cp, shuffle along seq dim for input # This corresponds to the pre-shuffle step in training - context_parallel_size = cfg_cp.context_parallel_size + context_parallel_size = int(mesh_cp.shape.get(cfg_cp.context_sharding, 1)) # This helper is TPU-oriented and uses the TPU-compatible DUAL_CHUNK_SWAP reorder path. # It does not model GPU-specific packed/striped reorder behavior. if context_parallel_size > 1 and cfg_cp.context_parallel_load_balance: