3 changes: 2 additions & 1 deletion src/maxtext/configs/decoupled_base_test.yml
@@ -1,6 +1,7 @@
# Decoupled base test config: used when DECOUPLE_GCLOUD=TRUE for tests that previously relied on base.yml.
# Inherit all model defaults (Pydantic config inheritance already does this) but override any cloud-coupled
# paths and disable optional cloud features.
base_config: base.yml

# Output goes to a local relative directory so tests do not require GCS.
base_output_directory: ./maxtext_local_output/gcloud_decoupled_test_logs
@@ -44,7 +45,7 @@ ici_context_parallelism: 1
ici_tensor_transpose_parallelism: 1
ici_tensor_sequence_parallelism: 1
ici_autoregressive_parallelism: 1
ici_fsdp_parallelism: 1
ici_fsdp_parallelism: -1
ici_fsdp_transpose_parallelism: 1
# Allow a higher unsharded-parameter percentage for small device counts
sharding_tolerance: 0.3
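For context: in MaxText configs a parallelism value of -1 asks the runtime to absorb all remaining devices on that axis, which is why the per-device-count FSDP overrides are deleted from the tests below. A minimal sketch of that resolution rule, assuming the usual at-most-one-unspecified-axis convention (the helper name is illustrative, not MaxText's actual implementation):

```python
def fill_unspecified_axis(parallelism: dict[str, int], num_devices: int) -> dict[str, int]:
    """Resolve a single -1 entry so the product of all axes equals num_devices."""
    specified = 1
    for value in parallelism.values():
        if value != -1:
            specified *= value
    unspecified = [name for name, value in parallelism.items() if value == -1]
    assert len(unspecified) <= 1, "at most one axis may be -1"
    if unspecified:
        assert num_devices % specified == 0, "fixed axes must divide the device count"
        parallelism = {**parallelism, unspecified[0]: num_devices // specified}
    return parallelism

# With 8 local devices and tensor parallelism fixed at 1, fsdp resolves to 8.
print(fill_unspecified_axis({"ici_fsdp_parallelism": -1, "ici_tensor_parallelism": 1}, 8))
```

This is what lets one config serve any local device count instead of hard-coding the FSDP axis per test.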
2 changes: 2 additions & 0 deletions src/maxtext/input_pipeline/tfds_data_processing.py
@@ -256,6 +256,7 @@ def make_tfds_eval_iterator(
if not config.colocated_python_data_input:
eval_ds = get_datasets(
dataset_name=config.eval_dataset_name,
dataset_path=config.dataset_path,
data_split=config.eval_split,
shuffle_files=False,
shuffle_seed=config.data_shuffle_seed,
@@ -285,6 +286,7 @@ def make_tfds_eval_iterator(
get_ds_fn = functools.partial(
get_datasets,
dataset_name=config.eval_dataset_name,
dataset_path=config.dataset_path,
data_split=config.eval_split,
shuffle_files=False,
shuffle_seed=config.data_shuffle_seed,
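Both eval-side get_datasets calls now pass dataset_path, matching the train pipeline. The sketch below illustrates the assumed failure mode this fixes: without the argument, TFDS resolves the dataset against its default data directory rather than the configured test path (the get_datasets signature is abbreviated and simplified here):

```python
import functools

def get_datasets(dataset_name, data_split, dataset_path=None, **kwargs):
    # Stand-in for the real pipeline, which forwards dataset_path as the
    # TFDS data_dir; when unset, TFDS uses its default directory.
    data_dir = dataset_path or "~/tensorflow_datasets"
    return f"load {dataset_name}:{data_split} from {data_dir}"

# Before this change the eval partial omitted dataset_path and silently
# fell back to the TFDS default directory:
eval_before = functools.partial(get_datasets, dataset_name="c4", data_split="validation")
eval_after = functools.partial(
    get_datasets, dataset_name="c4", data_split="validation", dataset_path="/tmp/test_data"
)
print(eval_before())  # load c4:validation from ~/tensorflow_datasets
print(eval_after())   # load c4:validation from /tmp/test_data
```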
@@ -323,4 +323,4 @@ def load_weights(self, rng_key: jax.Array) -> None:
model = model_creation_utils.from_pretrained(
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
)
self.model = nnx.data(model)
self.model = nnx.data(model)
21 changes: 15 additions & 6 deletions src/maxtext/layers/attention_op.py
@@ -1555,14 +1555,23 @@ def cudnn_flash_attention(
qkv_layout = "THD_THD_THD" # Packed format: 'T3HD', 'THD_T2HD' or 'THD_THD_THD'
if decoder_segment_ids is None:
decoder_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32)
attn_mask = SequenceDescriptor.from_segment_ids_and_pos(
segment_ids=decoder_segment_ids, segment_pos=segment_positions
)

# TE 2.12+ requires THD metadata; older TE versions infer it.
def _sequence_descriptor(segment_ids):
try:
return SequenceDescriptor.from_segment_ids_and_pos(
segment_ids=segment_ids,
segment_pos=segment_positions,
is_thd=True,
is_segment_ids_reordered=False,
)
except TypeError:
return SequenceDescriptor.from_segment_ids_and_pos(segment_ids=segment_ids, segment_pos=segment_positions)

attn_mask = _sequence_descriptor(decoder_segment_ids)
# Create dummy SequenceDescriptor for lazy_init
dummy_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32)
dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos(
segment_ids=dummy_segment_ids, segment_pos=segment_positions
)
dummy_attn_mask = _sequence_descriptor(dummy_segment_ids)
max_segments_per_seq = self.config.max_segments_per_seq
elif using_context_parallelism:
if self.attention_type == AttentionType.LOCAL_SLIDING:
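The try/except TypeError probe above is a common feature-detection pattern for surviving signature changes across library versions: attempt the newer keyword first, then fall back to the old call. A self-contained sketch of the same pattern (stand-in functions, not the real TE API):

```python
def describe_v2(segment_ids, segment_pos, is_thd):
    # Newer API: caller must state the layout explicitly.
    return ("v2", segment_ids, segment_pos, is_thd)

def describe_v1(segment_ids, segment_pos):
    # Older API: layout is inferred; unknown kwargs raise TypeError.
    return ("v1", segment_ids, segment_pos)

def make_descriptor(describe, segment_ids, segment_pos):
    try:
        return describe(segment_ids, segment_pos, is_thd=True)
    except TypeError:
        # Older versions reject the new kwarg; retry with the old signature.
        return describe(segment_ids, segment_pos)

print(make_descriptor(describe_v2, [1, 1, 2], [0, 1, 0])[0])  # v2
print(make_descriptor(describe_v1, [1, 1, 2], [0, 1, 0])[0])  # v1
```

One caveat of this pattern: a TypeError raised inside the new-style call body also triggers the fallback, so it is best kept around a single constructor call, as the diff does.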
6 changes: 0 additions & 6 deletions tests/integration/checkpointing_test.py
@@ -38,7 +38,6 @@
from tests.utils.test_helpers import (
get_test_config_path,
get_test_base_output_directory,
get_decoupled_parallelism_overrides,
)


@@ -73,10 +72,6 @@ def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention
"checkpoint_storage_use_zarr3=False",
]

extra_parallelism = []
if is_decoupled(): # Match device topology in decoupled/local mode
extra_parallelism.extend(get_decoupled_parallelism_overrides(as_argv=True))

return (
[
None,
@@ -96,7 +91,6 @@ def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention
]
+ model_params
+ pathways_command
+ extra_parallelism
)


99 changes: 56 additions & 43 deletions tests/integration/train_tests.py
@@ -26,7 +26,6 @@
get_test_config_path,
get_test_dataset_path,
get_test_base_output_directory,
get_decoupled_parallelism_overrides,
is_rocm_backend,
)

@@ -39,14 +38,6 @@ class TrainTests(unittest.TestCase):
_base_output_directory = get_test_base_output_directory()
dataset_path = get_test_dataset_path()

# FSDP override logic for tensor-parallel=4 configs: provide an axis only when cleanly divisible.
_fsdp_tp4_override = []
if decoupled:
if dev_count >= 4 and dev_count % 4 == 0:
_fsdp_tp4_override = get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count // 4, as_argv=True)
elif dev_count < 4:
_fsdp_tp4_override = get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True)

CONFIGS = {
"base": [ # short test for train.py with TFDS c4
None,
@@ -58,8 +49,7 @@ class TrainTests(unittest.TestCase):
"enable_checkpointing=False",
"enable_goodput_recording=False",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
],
"synthetic": [ # tests base config with synthetic dataset
None,
get_test_config_path(),
@@ -70,8 +60,7 @@ class TrainTests(unittest.TestCase):
"enable_goodput_recording=False",
"dataset_type=synthetic",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
],
"pdb_lt_1": [ # tests base config with per_device_batch_size < 1
None,
get_test_config_path(),
@@ -84,8 +73,7 @@ class TrainTests(unittest.TestCase):
"per_device_batch_size=0.25",
"ici_tensor_parallelism=4",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
],
"tp_transpose": [ # tests base config with ici_tensor_transpose_parallelism=4
None,
get_test_config_path(),
@@ -96,8 +84,7 @@ class TrainTests(unittest.TestCase):
"ici_tensor_transpose_parallelism=4",
"enable_goodput_recording=False",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
],
"int8": [ # tests base config with int8
None,
get_test_config_path(),
@@ -109,8 +96,7 @@ class TrainTests(unittest.TestCase):
"enable_checkpointing=False",
"enable_goodput_recording=False",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
],
"fp8": [ # tests base config with fp8
None,
get_test_config_path(),
@@ -122,8 +108,7 @@ class TrainTests(unittest.TestCase):
"enable_checkpointing=False",
"enable_goodput_recording=False",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
],
"nanoo_fp8": [ # tests base config with nanoo_fp8
None,
get_test_config_path(),
@@ -135,8 +120,7 @@ class TrainTests(unittest.TestCase):
"enable_checkpointing=False",
"enable_goodput_recording=False",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
],
"te_fp8_delayedscaling": [ # tests base config with te_fp8_delayedscaling
None,
get_test_config_path(),
@@ -148,8 +132,7 @@ class TrainTests(unittest.TestCase):
"enable_checkpointing=False",
"enable_goodput_recording=False",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
],
"te_fp8_currentscaling": [ # tests base config with te_fp8_currentscaling
None,
get_test_config_path(),
@@ -161,8 +144,7 @@ class TrainTests(unittest.TestCase):
"enable_checkpointing=False",
"enable_goodput_recording=False",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
],
"te_mxfp8": [ # tests base config with te_mxfp8
None,
get_test_config_path(),
@@ -174,8 +156,7 @@ class TrainTests(unittest.TestCase):
"enable_checkpointing=False",
"enable_goodput_recording=False",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
],
"dropout": [ # tests base config with dropout
None,
get_test_config_path(),
@@ -189,8 +170,7 @@ class TrainTests(unittest.TestCase):
"per_device_batch_size=1",
"dropout_rate=0.02",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
],
"hf_input_pipeline": [ # test for train.py with TFDS c4, using HF input pipeline
None,
get_test_config_path(),
@@ -203,8 +183,7 @@ class TrainTests(unittest.TestCase):
"hf_path=parquet",
f"hf_train_files={dataset_path}/hf/c4/c4-train-00000-of-01637.parquet",
"tokenizer_path=google-t5/t5-large",
]
+ get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True),
],
}

@pytest.mark.integration_test
@@ -432,11 +411,16 @@ def test_gpu_optimizer_offload(self):
"enable_goodput_recording=False",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
train_main(optimizer_offload + get_decoupled_parallelism_overrides(fsdp_parallelism=self.dev_count, as_argv=True))
train_main(optimizer_offload)

@pytest.mark.integration_test
@pytest.mark.gpu_only
def test_gpu_parameter_offload(self):
if is_rocm_backend():
# JAX 0.9.1 MSIT enforces memory_space type matching across the VJP;
# MaxText's pinned_host params combined with device compute leave the
# cotangent mismatched at the jit boundary.
pytest.skip("Parameter memory host offload: JAX MSIT VJP typematch fails for pinned_host params.")
os.environ["NVTE_FUSED_ATTN"] = "1" # Enable fused attention
parameter_offload = [ # tests base config on GPU with parameter offload
None,
Expand All @@ -453,7 +437,7 @@ def test_gpu_parameter_offload(self):
"enable_goodput_recording=False",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
train_main(parameter_offload + get_decoupled_parallelism_overrides(fsdp_parallelism=self.dev_count, as_argv=True))
train_main(parameter_offload)
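For background on the skip above: JAX expresses host offload through memory kinds on shardings, so offloaded parameters live in pinned_host while compute and cotangents live in device memory, which is the mismatch the comment describes. A minimal, MaxText-independent sketch of moving a buffer between memory kinds (assumes a backend with pinned_host support):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024))
device = jax.devices()[0]

# Shardings carry a memory_kind; "pinned_host" keeps the buffer in host RAM.
host_sharding = jax.sharding.SingleDeviceSharding(device, memory_kind="pinned_host")
x_host = jax.device_put(x, host_sharding)

# Compute happens in device memory, so the value must be moved back first.
device_sharding = jax.sharding.SingleDeviceSharding(device, memory_kind="device")
y = jnp.sum(jax.device_put(x_host, device_sharding) ** 2)
print(y)
```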

@pytest.mark.gpu_only
def test_gpu_cudnn_flash_jax(self):
@@ -573,8 +557,7 @@ def test_gpu_packed_attention(self):
@pytest.mark.gpu_only
@pytest.mark.skip(reason="b/489133823. Previously transient in b/462548581.")
def test_gpu_ring_attention(self):
if is_rocm_backend():
pytest.skip("TE ring attention context parallelism not supported on ROCm.")
rocm_backend = is_rocm_backend()
os.environ["NVTE_FUSED_ATTN"] = "1" # Enable fused attention
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0" # Disable scan for ring attention
ring_attention = [ # tests base config on GPU with ring attention
@@ -583,7 +566,7 @@ def test_gpu_ring_attention(self):
f"base_output_directory={self._base_output_directory}",
"run_name=runner_test",
"dataset_type=synthetic", # use synthetic dataset_type to decrease training time
"steps=10",
"steps=1" if rocm_backend else "steps=10",
"enable_checkpointing=False",
"enable_goodput_recording=False",
"attention=cudnn_flash_te",
@@ -595,15 +578,32 @@ def test_gpu_ring_attention(self):
"hardware=gpu",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
if rocm_backend:
# Keep the ROCm ring-attention smoke test small enough to avoid long TE/XLA compile times.
ring_attention.extend(
[
"max_target_length=512",
"base_emb_dim=1024",
"base_mlp_dim=4096",
"base_num_query_heads=8",
"base_num_kv_heads=8",
"base_num_decoder_layers=2",
]
)
train_main(ring_attention)
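The same shrink-for-ROCm override list reappears in the packed variant below (with max_segments_per_seq=2 added); if it spreads further, a shared helper along these lines could deduplicate it. This is a hypothetical refactor, not part of tests/utils/test_helpers.py:

```python
# Hypothetical helper; these names do not exist in the repo.
ROCM_SMOKE_MODEL_OVERRIDES = [
    "max_target_length=512",
    "base_emb_dim=1024",
    "base_mlp_dim=4096",
    "base_num_query_heads=8",
    "base_num_kv_heads=8",
    "base_num_decoder_layers=2",
]

def rocm_smoke_overrides(*extra: str) -> list[str]:
    """Small-model overrides that keep ROCm TE/XLA compile times short."""
    return [*extra, *ROCM_SMOKE_MODEL_OVERRIDES]

# Usage in the two ring-attention tests:
#   ring_attention.extend(rocm_smoke_overrides())
#   thd_ring_attention.extend(rocm_smoke_overrides("max_segments_per_seq=2"))
```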

@pytest.mark.integration_test
@pytest.mark.gpu_only
def test_gpu_ring_attention_with_packing(self):
gpu_device = jax.devices("gpu")[0]
compute_capability = gpu_device.compute_capability
if float(compute_capability) < 9.0:
pytest.skip("Ring attention with packing is only supported on sm90+!")
rocm_backend = is_rocm_backend()
if not rocm_backend:
gpu_device = jax.devices("gpu")[0]
compute_capability = getattr(gpu_device, "compute_capability", None)
try:
if float(compute_capability) < 9.0:
pytest.skip("Ring attention with packing is only supported on sm90+!")
except Exception: # pylint: disable=broad-exception-caught
pytest.skip("Ring attention with packing is only supported on sm90+!")
os.environ["NVTE_FUSED_ATTN"] = "1" # Enable fused attention
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0" # Disable scan for ring attention
thd_ring_attention = [ # tests base config on GPU with ring attention + packing
Expand All @@ -612,7 +612,7 @@ def test_gpu_ring_attention_with_packing(self):
f"base_output_directory={self._base_output_directory}",
"run_name=runner_test",
f"dataset_path={self.dataset_path}",
"steps=10",
"steps=1" if rocm_backend else "steps=10",
"enable_checkpointing=False",
"enable_goodput_recording=False",
"attention=cudnn_flash_te",
Expand All @@ -624,6 +624,19 @@ def test_gpu_ring_attention_with_packing(self):
"hardware=gpu",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
]
if rocm_backend:
# Keep the ROCm packed-ring smoke test small enough to avoid long TE/XLA compile times.
thd_ring_attention.extend(
[
"max_segments_per_seq=2",
"max_target_length=512",
"base_emb_dim=1024",
"base_mlp_dim=4096",
"base_num_query_heads=8",
"base_num_kv_heads=8",
"base_num_decoder_layers=2",
]
)
train_main(thd_ring_attention)
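The packed (THD) path exercised here feeds the attention op per-token segment_ids/segment_pos, the metadata SequenceDescriptor consumes in the attention_op.py change above, and max_segments_per_seq=2 bounds how many sequences one row may pack. A hand-built example of that layout and the causal-within-segment mask it implies (illustrative, not MaxText's packing code):

```python
import jax.numpy as jnp

# One row of length 8 packing two sequences (lengths 3 and 4) plus padding.
# segment_ids: 0 marks padding; tokens attend only within equal ids.
segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 2, 0]], dtype=jnp.int32)
# segment_pos restarts at 0 for each packed sequence.
segment_pos = jnp.array([[0, 1, 2, 0, 1, 2, 3, 0]], dtype=jnp.int32)

# The causal-within-segment mask semantics the fused THD kernel implements.
q_seg, kv_seg = segment_ids[:, :, None], segment_ids[:, None, :]
q_pos, kv_pos = segment_pos[:, :, None], segment_pos[:, None, :]
mask = (q_seg == kv_seg) & (q_seg != 0) & (kv_pos <= q_pos)
print(mask.shape)  # (1, 8, 8)
```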

