Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/maxtext/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,7 @@ def tpu_flash_attention(
) -> tuple[Array, Array]:
"""TPU Flash Attention."""

- cp_size = self.config.context_parallel_size
+ cp_size = int(self.mesh.shape.get(self.config.context_sharding, 1))
load_balanced_context_parallel = self.config.context_parallel_load_balance

# Transpose to ('batch', 'heads', 'length', 'kv')
Expand Down
7 changes: 4 additions & 3 deletions src/maxtext/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def setup_train_loop(config, recorder, devices=None):
is_training = True
init_rng = jax.random.PRNGKey(config.init_weights_seed)
mesh = maxtext_utils.get_mesh_from_config(config, devices)
+ context_parallel_size = int(mesh.shape.get(config.context_sharding, 1))
if config.pure_nnx:
# Create abstract NNX model.
_create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, mesh, devices)
Expand Down Expand Up @@ -241,7 +242,7 @@ def create_train_state_fn():
data_iterator, eval_data_iterator = create_data_iterator(config, mesh)
rampup_manager = create_rampup_manager(config, checkpoint_manager)
# Validate context parallelism with packing configuration
- if config.context_parallel_size > 1 and config.packing:
+ if context_parallel_size > 1 and config.packing:
if config.dataset_type == "synthetic":
raise ValueError(
"Context parallelism with sequence packing is not supported with synthetic data. "
Expand All @@ -255,7 +256,7 @@ def create_train_state_fn():

# Apply reordering wrapper to data iterators if context parallelism is enabled
with jax.set_mesh(mesh):
- if config.context_parallel_size > 1 and config.context_parallel_load_balance:
+ if context_parallel_size > 1 and config.context_parallel_load_balance:

# Determine load balancing reorder strategy based on whether packing is enabled
if config.context_parallel_reorder_strategy == ReorderStrategy.AUTO:
Expand All @@ -264,7 +265,7 @@ def create_train_state_fn():
reorder_strategy = config.context_parallel_reorder_strategy

reorder_fn = maxtext_utils.get_reorder_callable(
- config.context_parallel_size, config.shard_mode, reorder_strategy, config.hardware
+ context_parallel_size, config.shard_mode, reorder_strategy, config.hardware
)
data_iterator = map(reorder_fn, data_iterator)
if eval_data_iterator:
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/attention_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def forward_with_context_expert_parallelism(
"""Get logits from attention under context/expert parallelism."""
# If load balanced cp, shuffle along seq dim for input
# This corresponds to the pre-shuffle step in training
- context_parallel_size = cfg_cp.context_parallel_size
+ context_parallel_size = int(mesh_cp.shape.get(cfg_cp.context_sharding, 1))
# This helper is TPU-oriented and uses the TPU-compatible DUAL_CHUNK_SWAP reorder path.
# It does not model GPU-specific packed/striped reorder behavior.
if context_parallel_size > 1 and cfg_cp.context_parallel_load_balance:
Expand Down
Loading