From 684de8f77f909f709f05372825029f1e08899c7b Mon Sep 17 00:00:00 2001 From: Alex Shraer Date: Tue, 12 May 2026 05:39:02 +0000 Subject: [PATCH] Optimize MaxText unit and integration test suite runtime --- src/maxtext/utils/maxtext_utils.py | 15 +++- tests/integration/aot_identical_test.py | 32 ++++++-- .../integration/checkpoint_resharding_test.py | 12 +-- tests/integration/checkpointing_test.py | 12 +-- .../deepseek_scan_engram_test.py | 74 +++++++++++------ tests/integration/determinism_test.py | 12 ++- tests/{unit => integration}/diloco_test.py | 15 ++++ .../generate_param_only_checkpoint_test.py | 14 ++-- .../integration/gradient_accumulation_test.py | 10 ++- tests/{unit => integration}/maxengine_test.py | 3 +- .../pipeline_parallelism_test.py | 5 +- tests/{ => integration}/sparsity_test.py | 19 ++--- tests/integration/standalone_dl_ckpt_test.py | 7 +- tests/integration/train_tests.py | 34 +++++++- tests/unit/attention_test.py | 36 +++++---- tests/unit/context_parallelism_test.py | 18 ++--- tests/unit/grain_data_processing_test.py | 31 ++++--- tests/unit/hf_data_processing_test.py | 12 ++- tests/unit/maxtext_utils_test.py | 35 ++++++++ tests/unit/model_creation_utils_test.py | 2 +- tests/unit/model_test.py | 4 +- tests/unit/multi_token_prediction_test.py | 13 +++ tests/unit/quantizations_test.py | 81 ++++++++++--------- tests/unit/tfds_data_processing_test.py | 29 ++++++- tests/unit/tiling_test.py | 7 +- tests/unit/train_compile_test.py | 10 ++- 26 files changed, 385 insertions(+), 157 deletions(-) rename tests/{unit => integration}/deepseek_scan_engram_test.py (76%) rename tests/{unit => integration}/diloco_test.py (95%) rename tests/{unit => integration}/maxengine_test.py (99%) rename tests/{unit => integration}/pipeline_parallelism_test.py (99%) rename tests/{ => integration}/sparsity_test.py (90%) diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index f4aa2cf18a..d8648c7cd3 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1936,8 +1936,19 @@ def maybe_dump_jaxpr(config, p_train_step, train_step_inputs): return max_logging.log("Tracing train_step to jaxpr...") - # We use the p_train_step (the JIT-decorated function) - p_train_jaxpr = jax.make_jaxpr(p_train_step)(*train_step_inputs) + # Trace the underlying un-jitted function via __wrapped__ to avoid heavy remote + # compilation/gRPC round-trips to the Pathways controller. + unwrapped_step = getattr(p_train_step, "__wrapped__", p_train_step) + + def to_abstract(x): + if hasattr(x, "shape") and hasattr(x, "dtype"): + return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype) + return x + + # Convert all input arguments recursively to purely local abstract ShapeDtypeStruct objects + # to completely bypass remote Array objects and proxy tracing overhead. 
+ abstract_inputs = jax.tree.map(to_abstract, train_step_inputs) + p_train_jaxpr = jax.make_jaxpr(unwrapped_step)(*abstract_inputs) local_filename = "train_step.jaxpr" local_path = os.path.join(config.dump_jaxpr_local_dir, local_filename) diff --git a/tests/integration/aot_identical_test.py b/tests/integration/aot_identical_test.py index ca95593cf3..a8e6e7b633 100644 --- a/tests/integration/aot_identical_test.py +++ b/tests/integration/aot_identical_test.py @@ -110,9 +110,13 @@ def assert_compile_and_real_match_hlo(self, test_name, *extra_args): "steps=1", "enable_checkpointing=False", "base_num_decoder_layers=1", - "max_target_length=512", - "base_emb_dim=256", - "base_mlp_dim=256", + "max_target_length=32", + "base_emb_dim=64", + "base_mlp_dim=64", + "base_num_query_heads=2", + "base_num_kv_heads=2", + "head_dim=16", + "vocab_size=128", ] + hlo_dump_args if extra_args: shared_args.extend(extra_args) @@ -179,6 +183,14 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args): "enable_checkpointing=False", "dump_jaxpr=True", "dump_jaxpr_delete_local_after=False", + "base_num_decoder_layers=1", + "max_target_length=32", + "base_emb_dim=64", + "base_mlp_dim=64", + "base_num_query_heads=2", + "base_num_kv_heads=2", + "head_dim=16", + "vocab_size=128", ] if extra_args: shared_args.extend(extra_args) @@ -218,5 +230,15 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args): ) @pytest.mark.tpu_only - def test_default_jaxpr_match(self): - self.assert_compile_and_real_match_jaxpr("default_run") + def test_default_jaxpr_match_mcjax(self): + if os.getenv("JAX_PLATFORMS") == "proxy": + pytest.skip("This is a McJAX test, skipping in Pathways environment.") + self.assert_compile_and_real_match_jaxpr("default_run_mcjax") + + @pytest.mark.tpu_only + @pytest.mark.scheduled_only + def test_default_jaxpr_match_pathways(self): + # Currently this test is extremely slow (b/512065615). 
+ if os.getenv("JAX_PLATFORMS") != "proxy": + pytest.skip("This is a Pathways test, skipping in McJAX environment.") + self.assert_compile_and_real_match_jaxpr("default_run_pathways", "enable_single_controller=True") diff --git a/tests/integration/checkpoint_resharding_test.py b/tests/integration/checkpoint_resharding_test.py index b01533e46c..0f0566ba92 100644 --- a/tests/integration/checkpoint_resharding_test.py +++ b/tests/integration/checkpoint_resharding_test.py @@ -35,12 +35,12 @@ def get_resharding_command(run_date, steps, metrics_file, base_output_directory, dataset_path, parallelism_args): """Generates a command list for the resharding test run.""" model_params = [ - "base_emb_dim=384", - "base_num_query_heads=8", - "base_num_kv_heads=8", - "base_mlp_dim=192", - "base_num_decoder_layers=8", - "head_dim=128", + "base_emb_dim=128", + "base_num_query_heads=2", + "base_num_kv_heads=2", + "base_mlp_dim=128", + "base_num_decoder_layers=2", + "head_dim=64", ] return ( diff --git a/tests/integration/checkpointing_test.py b/tests/integration/checkpointing_test.py index d81d98c2b3..d9cb23ba56 100644 --- a/tests/integration/checkpointing_test.py +++ b/tests/integration/checkpointing_test.py @@ -59,12 +59,12 @@ def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention """ base_output_directory = get_test_base_output_directory() model_params = [ - "base_emb_dim=384", - "base_num_query_heads=8", - "base_num_kv_heads=8", - "base_mlp_dim=192", - "base_num_decoder_layers=8", - "head_dim=128", + "base_emb_dim=128", + "base_num_query_heads=2", + "base_num_kv_heads=2", + "base_mlp_dim=128", + "base_num_decoder_layers=1", + "head_dim=64", ] pathways_command = [] if os.getenv("JAX_PLATFORMS") == "proxy": diff --git a/tests/unit/deepseek_scan_engram_test.py b/tests/integration/deepseek_scan_engram_test.py similarity index 76% rename from tests/unit/deepseek_scan_engram_test.py rename to tests/integration/deepseek_scan_engram_test.py index ff51c5c62f..b4116c23e4 100644 --- a/tests/unit/deepseek_scan_engram_test.py +++ b/tests/integration/deepseek_scan_engram_test.py @@ -41,6 +41,7 @@ def __call__(self, x, model_mode): return jnp.ones((x.shape[0], x.shape[1], self.emb_dim)) +@pytest.mark.integration_test class TestDeepSeekScanEngram(unittest.TestCase): """Test DeepSeek decoder block with scan_layers=True and engram_layers.""" @@ -53,16 +54,16 @@ class TestDeepSeekScanEngram(unittest.TestCase): "first_num_dense_layers=5", "base_num_decoder_layers=10", "num_decoder_layers=10", - "base_emb_dim=64", - "base_mlp_dim=64", - "base_moe_mlp_dim=64", - "base_num_query_heads=2", - "base_num_kv_heads=2", - "head_dim=32", - "indexer_head_dim=32", - "qk_nope_head_dim=32", - "qk_rope_head_dim=16", - "v_head_dim=32", + "base_emb_dim=8", + "base_mlp_dim=8", + "base_moe_mlp_dim=8", + "base_num_query_heads=1", + "base_num_kv_heads=1", + "head_dim=4", + "indexer_head_dim=4", + "qk_nope_head_dim=4", + "qk_rope_head_dim=4", + "v_head_dim=4", "vocab_size=128", "mhc_expansion_rate=4", "attention=dot_product", @@ -71,15 +72,24 @@ class TestDeepSeekScanEngram(unittest.TestCase): "max_prefill_predict_length=8", "enable_checkpointing=False", "engram_num_heads=1", - "engram_head_dim=8", + "engram_head_dim=4", "engram_vocab_bases=[128,128]", "engram_max_ngram_size=3", "engram_kernel_size=4", + "num_experts=2", + "num_experts_per_tok=1", "hf_access_token=dummy", "tokenizer_path=dummy", ] - def _test_engram_pattern(self, mock_from_pretrained, engram_layers_str, expected_keys): + def _test_engram_pattern( + self, + 
mock_from_pretrained, + engram_layers_str, + expected_keys, + first_num_dense_layers=5, + base_num_decoder_layers=10, + ): """Helper method to test different engram layer patterns.""" # Setup mock tokenizer @@ -106,7 +116,16 @@ def batch_decode(self, token_ids, *args, **kwargs): mock_from_pretrained.return_value = MockTokenizer() config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml") - config = pyconfig.initialize([None, config_path] + self._COMMON_CONFIG + [f"engram_layers=[{engram_layers_str}]"]) + config = pyconfig.initialize( + [None, config_path] + + self._COMMON_CONFIG + + [ + f"engram_layers=[{engram_layers_str}]", + f"first_num_dense_layers={first_num_dense_layers}", + f"base_num_decoder_layers={base_num_decoder_layers}", + f"num_decoder_layers={base_num_decoder_layers}", + ] + ) devices_array = maxtext_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) @@ -126,7 +145,7 @@ def batch_decode(self, token_ids, *args, **kwargs): shared_embedding = DummyEmbedding(emb_dim=config.emb_dim) - with mesh: + with mesh, jax.disable_jit(): variables = decoder.init( {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1), "aqt": jax.random.PRNGKey(2)}, shared_embedding=shared_embedding, @@ -154,15 +173,16 @@ def test_decoder_init_engram_2_8(self, mock_from_pretrained): """Test engram layers at indices 2 and 8.""" self._test_engram_pattern( mock_from_pretrained, - "2,8", + "1,4", [ - "dense_layers_0_1", - "dense_layers_engram_2", - "dense_layers_3_4", - "moe_layers_5_7", - "moe_layers_engram_8", - "moe_layers_9_9", + "dense_layers_0_0", + "dense_layers_engram_1", + "dense_layers_2_2", + "moe_layers_3_3", + "moe_layers_engram_4", ], + first_num_dense_layers=3, + base_num_decoder_layers=5, ) @pytest.mark.tpu_only @@ -171,8 +191,10 @@ def test_decoder_init_engram_0_5(self, mock_from_pretrained): """Test engram layers at indices 0 and 5 - first engram layer of block.""" self._test_engram_pattern( mock_from_pretrained, - "0,5", - ["dense_layers_engram_0", "dense_layers_1_4", "moe_layers_engram_5", "moe_layers_6_9"], + "0,1", + ["dense_layers_engram_0", "moe_layers_engram_1"], + first_num_dense_layers=1, + base_num_decoder_layers=2, ) @pytest.mark.tpu_only @@ -181,6 +203,8 @@ def test_decoder_init_engram_4_9(self, mock_from_pretrained): """Test engram layers at indices 4 and 9 - last engram layer of block.""" self._test_engram_pattern( mock_from_pretrained, - "4,9", - ["dense_layers_0_3", "dense_layers_engram_4", "moe_layers_5_8", "moe_layers_engram_9"], + "1,3", + ["dense_layers_0_0", "dense_layers_engram_1", "moe_layers_2_2", "moe_layers_engram_3"], + first_num_dense_layers=2, + base_num_decoder_layers=4, ) diff --git a/tests/integration/determinism_test.py b/tests/integration/determinism_test.py index 1dae6b1238..d08129be76 100644 --- a/tests/integration/determinism_test.py +++ b/tests/integration/determinism_test.py @@ -26,7 +26,7 @@ import pytest from maxtext.trainers.pre_train.train import main as train_main -from tests.utils.test_helpers import get_test_config_path +from tests.utils.test_helpers import get_test_config_path, get_test_base_output_directory, get_test_dataset_path pytestmark = pytest.mark.integration_test @@ -52,13 +52,17 @@ def test_determinism(self): common_config = [ None, get_test_config_path(), - "steps=5", + "steps=2", "enable_checkpointing=False", "enable_data_shuffling=True", "enable_dropout=False", - "base_output_directory=gs://runner-maxtext-logs", - "dataset_path=gs://maxtext-dataset", + 
f"base_output_directory={get_test_base_output_directory()}", + f"dataset_path={get_test_dataset_path()}", "skip_jax_distributed_system=True", + "base_emb_dim=128", + "base_mlp_dim=128", + "base_num_decoder_layers=1", + "head_dim=64", ] train_1_config = common_config + [ f"run_name={run_name}_1", diff --git a/tests/unit/diloco_test.py b/tests/integration/diloco_test.py similarity index 95% rename from tests/unit/diloco_test.py rename to tests/integration/diloco_test.py index 7a60b9acbd..a6c58a69d9 100644 --- a/tests/unit/diloco_test.py +++ b/tests/integration/diloco_test.py @@ -51,6 +51,7 @@ def __call__(self, x): return self.dense(x) +@pytest.mark.integration_test class DiLoCoTest(unittest.TestCase): @pytest.mark.tpu_only @@ -284,6 +285,13 @@ def test_diloco_qwen3_moe_two_slices(self): "dcn_diloco_parallelism=2", "enable_diloco=true", "model_name=qwen3-30b-a3b", + "override_model_config=True", + "base_emb_dim=32", + "base_num_decoder_layers=1", + "base_mlp_dim=64", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "head_dim=8", ) ) @@ -302,5 +310,12 @@ def test_diloco_two_slices(self): "dcn_diloco_parallelism=2", "enable_diloco=true", "model_name=gemma2-2b", + "override_model_config=True", + "base_emb_dim=32", + "base_num_decoder_layers=1", + "base_mlp_dim=64", + "base_num_query_heads=1", + "base_num_kv_heads=1", + "head_dim=4", ) ) diff --git a/tests/integration/generate_param_only_checkpoint_test.py b/tests/integration/generate_param_only_checkpoint_test.py index c44831f5d5..cfce481ba1 100644 --- a/tests/integration/generate_param_only_checkpoint_test.py +++ b/tests/integration/generate_param_only_checkpoint_test.py @@ -31,12 +31,12 @@ def get_model_params(quantization): return [ f"quantization={quantization}", - "base_emb_dim=384", - "base_num_query_heads=8", - "base_num_kv_heads=8", - "base_mlp_dim=192", - "base_num_decoder_layers=8", - "head_dim=128", + "base_emb_dim=128", + "base_num_query_heads=2", + "base_num_kv_heads=2", + "base_mlp_dim=128", + "base_num_decoder_layers=1", + "head_dim=64", ] @@ -69,7 +69,7 @@ def run_e2e_test_flow(hardware, model_config, attention_type="autoselected", sta steps=1, metrics_file="run_metrics.txt", attention_type=attention_type, - dataset_type="tfds", + dataset_type="synthetic", dataset_path=dataset_path, ) ) diff --git a/tests/integration/gradient_accumulation_test.py b/tests/integration/gradient_accumulation_test.py index 28523d9dc1..9e935d9c6e 100644 --- a/tests/integration/gradient_accumulation_test.py +++ b/tests/integration/gradient_accumulation_test.py @@ -154,13 +154,15 @@ def test_sft_grad_accumulate_same_loss(self): [ None, get_test_config_path(), - "base_output_directory=gs://runner-maxtext-logs", - "dataset_path=gs://maxtext-dataset", + f"base_output_directory={self.base_output_directory}", + f"dataset_path={self.dataset_path}", + "dataset_type=synthetic", + "max_target_length=128", "gradient_clipping_threshold=0", # Ensures we are testing raw scales of gradients (clipping off). 
"enable_checkpointing=False", "enable_goodput_recording=False", - "base_emb_dim=256", - "base_num_decoder_layers=4", + "base_emb_dim=128", + "base_num_decoder_layers=1", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", "steps=3", "gradient_accumulation_steps=2", diff --git a/tests/unit/maxengine_test.py b/tests/integration/maxengine_test.py similarity index 99% rename from tests/unit/maxengine_test.py rename to tests/integration/maxengine_test.py index 944d34bfef..eb4a7729d6 100644 --- a/tests/unit/maxengine_test.py +++ b/tests/integration/maxengine_test.py @@ -36,6 +36,7 @@ pytestmark = [pytest.mark.external_serving] +@pytest.mark.integration_test class MaxEngineTest(unittest.TestCase): """Tests for MaxEngine.""" @@ -55,7 +56,7 @@ def init_pyconfig(self, **kwargs): "base_num_decoder_layers": 2, "attention": "dot_product", "max_target_length": 16, - "base_emb_dim": 256, + "base_emb_dim": 32, "base_num_query_heads": 2, "base_num_kv_heads": 2, "max_prefill_predict_length": 4, diff --git a/tests/unit/pipeline_parallelism_test.py b/tests/integration/pipeline_parallelism_test.py similarity index 99% rename from tests/unit/pipeline_parallelism_test.py rename to tests/integration/pipeline_parallelism_test.py index a3041e9735..1c2a22822a 100644 --- a/tests/unit/pipeline_parallelism_test.py +++ b/tests/integration/pipeline_parallelism_test.py @@ -69,6 +69,7 @@ def pytree_ravel(pytree): assert jax.numpy.allclose(f1_grad, f2_grad, rtol=1e-1, equal_nan=False) +@pytest.mark.integration_test class PipelineParallelismTest(unittest.TestCase): decoupled = is_decoupled() base_output_directory = get_test_base_output_directory() @@ -481,7 +482,7 @@ def test_full_train_fp8(self): "base_num_decoder_layers=4", "head_dim=128", "per_device_batch_size=2", - "max_target_length=1024", + "max_target_length=128", "vocab_size=32", "dataset_type=synthetic", "steps=3", @@ -514,7 +515,7 @@ def test_full_train_nanoo_fp8(self): "base_num_decoder_layers=4", "head_dim=128", "per_device_batch_size=2", - "max_target_length=1024", + "max_target_length=128", "vocab_size=32", "dataset_type=synthetic", "steps=3", diff --git a/tests/sparsity_test.py b/tests/integration/sparsity_test.py similarity index 90% rename from tests/sparsity_test.py rename to tests/integration/sparsity_test.py index 5bf8e85d4a..a7f8bbfd3a 100644 --- a/tests/sparsity_test.py +++ b/tests/integration/sparsity_test.py @@ -26,6 +26,7 @@ gettempdir = tempfile.gettempdir +@pytest.mark.integration_test class Train(parameterized.TestCase): """Smoke test for sparsity in G3 only.""" @@ -55,13 +56,13 @@ def test_different_quant_sparsity_configs(self, quantization: str, use_sparsity: get_test_config_path(), f"base_output_directory={test_tmpdir}", "run_name=different_quant_sparsity_configs_test", - "base_emb_dim=128", - "base_num_query_heads=4", - "base_num_kv_heads=4", - "base_mlp_dim=128", - "base_moe_mlp_dim=128", - "base_num_decoder_layers=8", - "head_dim=128", + "base_emb_dim=16", + "base_num_query_heads=1", + "base_num_kv_heads=1", + "base_mlp_dim=16", + "base_moe_mlp_dim=16", + "base_num_decoder_layers=2", + "head_dim=64", "decoder_block=deepseek", "attention_type=mla", "num_experts=2", @@ -71,9 +72,9 @@ def test_different_quant_sparsity_configs(self, quantization: str, use_sparsity: f'quantization="{quantization}"', "use_qwix_quantization=True", "per_device_batch_size=2", - "max_target_length=1024", + "max_target_length=128", "dataset_type=synthetic", - "steps=10", + "steps=1", "enable_checkpointing=False", 
"enable_goodput_recording=False", "enable_checkpoint_cloud_logger=False", diff --git a/tests/integration/standalone_dl_ckpt_test.py b/tests/integration/standalone_dl_ckpt_test.py index f904112a12..47ecdb919c 100644 --- a/tests/integration/standalone_dl_ckpt_test.py +++ b/tests/integration/standalone_dl_ckpt_test.py @@ -60,7 +60,12 @@ def test_standalone_dataloader(self): f"run_name={random_run_name}", f"base_output_directory={self.base_output_directory}", f"dataset_path={self.dataset_path}", - "steps=100", + "base_emb_dim=128", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=128", + "base_num_decoder_layers=2", + "steps=10", "enable_checkpointing=false", "enable_goodput_recording=False", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", diff --git a/tests/integration/train_tests.py b/tests/integration/train_tests.py index fc27753abe..da440f7822 100644 --- a/tests/integration/train_tests.py +++ b/tests/integration/train_tests.py @@ -47,6 +47,19 @@ class TrainTests(unittest.TestCase): elif dev_count < 4: _fsdp_tp4_override = get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True) + _small_model_overrides = [ + "base_emb_dim=28", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=2", + "head_dim=128", + "max_target_length=128", + "vocab_size=32", + # Allow higher unsharded percentage because downscaled models make fixed-size FP8 history tensors relatively larger. + "sharding_tolerance=0.1", + ] + CONFIGS = { "base": [ # short test for train.py with TFDS c4 None, @@ -54,11 +67,13 @@ class TrainTests(unittest.TestCase): f"base_output_directory={_base_output_directory}", "run_name=runner_test", f"dataset_path={dataset_path}", + "max_target_length=128", "steps=2", "enable_checkpointing=False", "enable_goodput_recording=False", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", ] + + _small_model_overrides + get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True), "synthetic": [ # tests base config with synthetic dataset None, @@ -71,6 +86,7 @@ class TrainTests(unittest.TestCase): "dataset_type=synthetic", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", ] + + _small_model_overrides + get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True), "pdb_lt_1": [ # tests base config with per_device_batch_size < 1 None, @@ -85,6 +101,7 @@ class TrainTests(unittest.TestCase): "ici_tensor_parallelism=4", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", ] + + _small_model_overrides + get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True), "tp_transpose": [ # tests base config with ici_tensor_transpose_parallelism=4 None, @@ -97,6 +114,7 @@ class TrainTests(unittest.TestCase): "enable_goodput_recording=False", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", ] + + _small_model_overrides + get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True), "int8": [ # tests base config with int8 None, @@ -110,6 +128,7 @@ class TrainTests(unittest.TestCase): "enable_goodput_recording=False", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", ] + + _small_model_overrides + get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True), "fp8": [ # tests base config with fp8 None, @@ -123,6 +142,7 
@@ class TrainTests(unittest.TestCase): "enable_goodput_recording=False", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", ] + + _small_model_overrides + get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True), "nanoo_fp8": [ # tests base config with nanoo_fp8 None, @@ -136,6 +156,7 @@ class TrainTests(unittest.TestCase): "enable_goodput_recording=False", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", ] + + _small_model_overrides + get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True), "te_fp8_delayedscaling": [ # tests base config with te_fp8_delayedscaling None, @@ -149,6 +170,7 @@ class TrainTests(unittest.TestCase): "enable_goodput_recording=False", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", ] + + _small_model_overrides + get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True), "te_fp8_currentscaling": [ # tests base config with te_fp8_currentscaling None, @@ -162,6 +184,7 @@ class TrainTests(unittest.TestCase): "enable_goodput_recording=False", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", ] + + _small_model_overrides + get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True), "te_mxfp8": [ # tests base config with te_mxfp8 None, @@ -175,6 +198,7 @@ class TrainTests(unittest.TestCase): "enable_goodput_recording=False", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", ] + + _small_model_overrides + get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True), "dropout": [ # tests base config with dropout None, @@ -190,6 +214,7 @@ class TrainTests(unittest.TestCase): "dropout_rate=0.02", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", ] + + _small_model_overrides + get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True), "hf_input_pipeline": [ # test for train.py with TFDS c4, using HF input pipeline None, @@ -204,6 +229,7 @@ class TrainTests(unittest.TestCase): f"hf_train_files={dataset_path}/hf/c4/c4-train-00000-of-01637.parquet", "tokenizer_path=google-t5/t5-large", ] + + _small_model_overrides + get_decoupled_parallelism_overrides(fsdp_parallelism=dev_count, as_argv=True), } @@ -486,15 +512,15 @@ def test_tpu_zero1_gradient_accumulation(self): zero1_ga = [ # tests Zero-1 optimizer sharding with gradient accumulation None, get_test_config_path(), - "base_output_directory=gs://runner-maxtext-logs", + f"base_output_directory={self._base_output_directory}", "run_name=runner_test", - "dataset_path=gs://maxtext-dataset", - "steps=10", + f"dataset_path={self.dataset_path}", + "steps=3", "enable_checkpointing=False", "enable_goodput_recording=False", "dataset_type=synthetic", "remat_policy=minimal", - "max_target_length=8192", + "max_target_length=512", "per_device_batch_size=2", "ici_data_parallelism=-1", "dcn_data_parallelism=1", diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index 09d281c70a..57dfb85bdd 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -276,7 +276,6 @@ class AttentionTest(parameterized.TestCase): "per_device_batch_size": 1.0, "run_name": "test", "enable_checkpointing": False, - "max_prefill_predict_length": 16, "max_target_length": 512, "sa_block_q": 128, "sa_block_kv": 128, @@ -402,7 +401,7 @@ def test_autoregression(self): 
jax.numpy.allclose(mha_prefill, mha_full[:, :prefill_length, :], rtol=1e-02, atol=1e-02, equal_nan=False) ) - for idx in range(prefill_length, decode_total_length): + for idx in range(prefill_length, min(prefill_length + 3, decode_total_length)): lnx_idx = lnx[:, idx : idx + 1, :] decoder_positions_idx = decoder_positions[:, idx : idx + 1] mha_idx, _ = self._attention_as_mha_generic( @@ -765,7 +764,7 @@ def test_tpu_flash_attention_context_parallel( @pytest.mark.tpu_only def test_dot_product_cache_axis_order(self): all_axis_orders = tuple(itertools.permutations(range(4))) - for axis_order in random.choices(all_axis_orders, k=4): + for axis_order in random.choices(all_axis_orders, k=2): self.dot_product_attention_helper(prefill_cache_axis_order=axis_order, ar_cache_axis_order=axis_order) print(f"passed test for {axis_order=}") @@ -790,12 +789,7 @@ def _dot_product_attention( config = pyconfig.initialize( [sys.argv[0], get_test_config_path()], - per_device_batch_size=1.0, - run_name="test", - enable_checkpointing=False, - max_target_length=128, - max_prefill_predict_length=16, - attention="dot_product", + **{**self.config_arguments, "attention": "dot_product"}, ) prefill_length = config.max_prefill_predict_length @@ -881,12 +875,7 @@ def _dot_product_attention_reshape_q(self, compute_axis_order): config = pyconfig.initialize( [sys.argv[0], get_test_config_path()], - per_device_batch_size=1.0, - run_name="test", - enable_checkpointing=False, - max_target_length=128, - max_prefill_predict_length=16, - attention="dot_product", + **{**self.config_arguments, "attention": "dot_product"}, ) prefill_length = config.max_prefill_predict_length @@ -1236,6 +1225,23 @@ def test_forward_serve_vllm(self, mock_sharded_ragged_paged_attention): class MLATest(attention_test_util.MLATestBase): """Test for the Multi-Headed Latent Attention""" + config_arguments = { + "per_device_batch_size": 1.0, + "run_name": "test", + "enable_checkpointing": False, + "max_target_length": 32, + "max_prefill_predict_length": 16, + "attention_type": AttentionType.MLA.value, + "head_dim": 32, + "q_lora_rank": 4, + "kv_lora_rank": 8, + "qk_nope_head_dim": 16, + "qk_rope_head_dim": 8, + "v_head_dim": 32, + "dtype": "float32", + "mla_naive_kvcache": False, + } + @parameterized.named_parameters( {"testcase_name": "RoPE_Yarn_Autoregression", "rope_type": "yarn"}, {"testcase_name": "Default_Autoregression", "rope_type": "default"}, diff --git a/tests/unit/context_parallelism_test.py b/tests/unit/context_parallelism_test.py index 3dece67014..6045fc1f24 100644 --- a/tests/unit/context_parallelism_test.py +++ b/tests/unit/context_parallelism_test.py @@ -47,15 +47,15 @@ class ContextParallelismTest(unittest.TestCase): "run_name": "test", "enable_checkpointing": False, "max_prefill_predict_length": 16, - "max_target_length": 512, - "sa_block_q": 128, - "sa_block_kv": 128, - "sa_block_kv_compute": 128, - "sa_block_q_dkv": 128, - "sa_block_kv_dkv": 128, - "sa_block_kv_dkv_compute": 128, - "sa_block_q_dq": 128, - "sa_block_kv_dq": 128, + "max_target_length": 128, + "sa_block_q": 16, + "sa_block_kv": 16, + "sa_block_kv_compute": 16, + "sa_block_q_dkv": 16, + "sa_block_kv_dkv": 16, + "sa_block_kv_dkv_compute": 16, + "sa_block_q_dq": 16, + "sa_block_kv_dq": 16, } def setUp(self): diff --git a/tests/unit/grain_data_processing_test.py b/tests/unit/grain_data_processing_test.py index 93304195c4..49c41f8639 100644 --- a/tests/unit/grain_data_processing_test.py +++ b/tests/unit/grain_data_processing_test.py @@ -42,6 +42,17 @@ class 
GrainBaseProcessingTest: from unittest.TestCase (or a subclass thereof). """ + @property + def train_iter(self): + cache_key = f"_cached_train_iter_{self.__class__.__name__}" + if not hasattr(self.__class__, cache_key): + setattr( + self.__class__, + cache_key, + grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices), + ) + return getattr(self.__class__, cache_key) + def test_train_ds(self): expected_shape = [jax.device_count(), self.config.max_target_length] # For training we pack multiple short examples in one example. @@ -133,6 +144,7 @@ def setUp(self): grain_train_files=grain_train_files, tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"), enable_checkpointing=False, + max_target_length=128, ) self.mesh_shape_1d = (len(jax.devices()),) self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) @@ -143,7 +155,6 @@ def setUp(self): self.config.max_target_length, self.mesh, ) - self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) def _make_config(self, **overrides): """Re-initialize config with base params, applying any overrides.""" @@ -158,6 +169,7 @@ def _make_config(self, **overrides): "grain_train_files": self.config.grain_train_files, "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"), "enable_checkpointing": False, + "max_target_length": 128, **overrides, } return pyconfig.initialize([sys.argv[0], get_test_config_path()], **kwargs) @@ -173,7 +185,6 @@ def setUp(self): super().setUp() train_files_weighted = ";".join([f"{self.config.grain_train_files},0.3", f"{self.config.grain_train_files},0.7"]) self.config = self._make_config(grain_train_files=train_files_weighted) - self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) class GrainArrayRecordProcessingWithMixtureConfigTest(GrainArrayRecordProcessingTest): @@ -223,7 +234,6 @@ def setUp(self): json.dump(mixture_config, f) self.config = self._make_config(grain_train_mixture_config_path=self.mixture_config_path) - self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) # TODO(aireenmei): Migrate this test to XLML @@ -234,7 +244,6 @@ class GrainArrayRecordAutoTuneTest(GrainArrayRecordProcessingTest): def setUp(self): super().setUp() self.config = self._make_config(grain_ram_budget_mb=512, grain_worker_count=-1) # Enable auto-tuning - self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) @pytest.mark.skip( reason=( @@ -259,7 +268,6 @@ def setUp(self): tokenizer_type="tiktoken", tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer_llama3.tiktoken"), ) - self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) # Only runs test_train_ds from parent class, skip other tests @pytest.mark.skip(reason="skip for tokenizer testing") @@ -276,8 +284,10 @@ class GrainArrayRecordHFTokenizerTest(GrainArrayRecordProcessingTest): def setUp(self): super().setUp() - self.config = self._make_config(tokenizer_type="huggingface", tokenizer_path="deepseek-ai/DeepSeek-V3") - self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) + self.config = self._make_config( + tokenizer_type="huggingface", + tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", 
"qwen3-tokenizer"), + ) # Only runs test_train_ds from parent class, skip other tests @pytest.mark.skip(reason="skip for tokenizer testing") @@ -295,7 +305,6 @@ class GrainArrayRecordBestFitPackingTest(GrainArrayRecordProcessingTest): def setUp(self): super().setUp() self.config = self._make_config(grain_packing_type="best_fit") - self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) class GrainParquetProcessingTest(GrainBaseProcessingTest, unittest.TestCase): @@ -348,6 +357,7 @@ def setUp(self): grain_per_worker_buffer_size=1, tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"), enable_checkpointing=False, + max_target_length=128, ) self.mesh_shape_1d = (len(jax.devices()),) self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) @@ -358,7 +368,6 @@ def setUp(self): self.config.max_target_length, self.mesh, ) - self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) class GrainTFRecordProcessingTest(GrainBaseProcessingTest, unittest.TestCase): @@ -413,6 +422,7 @@ def setUp(self): grain_per_worker_buffer_size=1, tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"), enable_checkpointing=False, + max_target_length=128, ) self.mesh_shape_1d = (len(jax.devices()),) self.mesh = Mesh(mesh_utils.create_device_mesh(self.mesh_shape_1d), self.config.mesh_axes) @@ -423,7 +433,6 @@ def setUp(self): self.config.max_target_length, self.mesh, ) - self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) @pytest.mark.external_training @@ -452,7 +461,7 @@ def setUp(self): sft_train_on_completion_only=True, train_data_columns=["messages"], tokenizer_type="huggingface", - tokenizer_path="HuggingFaceH4/zephyr-7b-beta", # The ungated tokenizer + tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "qwen3-tokenizer"), max_target_length=128, packing=True, grain_worker_count=1, diff --git a/tests/unit/hf_data_processing_test.py b/tests/unit/hf_data_processing_test.py index a0e72f1431..262c56ff9b 100644 --- a/tests/unit/hf_data_processing_test.py +++ b/tests/unit/hf_data_processing_test.py @@ -26,6 +26,7 @@ from maxtext.input_pipeline import hf_data_processing from maxtext.input_pipeline import input_pipeline_interface from maxtext.common.gcloud_stub import is_decoupled +from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT from tests.utils.test_helpers import get_test_config_path, get_test_base_output_directory @@ -60,7 +61,7 @@ def setUp(self): if decoupled else "gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet" ), - tokenizer_path="google-t5/t5-large", + tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "qwen3-tokenizer"), enable_checkpointing=False, ) self.mesh_shape_1d = (len(jax.devices()),) @@ -73,7 +74,14 @@ def setUp(self): self.mesh, ) - self.train_iter = hf_data_processing.make_hf_train_iterator(self.config, self.mesh, self.process_indices) + @property + def train_iter(self): + # pylint: disable=protected-access + if not hasattr(self.__class__, "_cached_train_iter"): + self.__class__._cached_train_iter = hf_data_processing.make_hf_train_iterator( + self.config, self.mesh, self.process_indices + ) + return self.__class__._cached_train_iter def test_train_ds(self): expected_shape = [jax.device_count(), self.config.max_target_length] diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py 
index ca5dbe7a87..b1b0296f32 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -18,6 +18,7 @@ from collections.abc import Callable from typing import Any, Sequence import unittest +import pytest from unittest.mock import MagicMock, Mock, patch from dataclasses import dataclass, field import numpy as np @@ -1430,6 +1431,40 @@ def test_early_return_when_disabled(self): # Should return immediately without calling any JAX tracing (no exception raised) maxtext_utils.maybe_dump_jaxpr(cfg, p_train_step=None, train_step_inputs=None) + @pytest.mark.tpu_only + def test_traces_with_abstract_inputs_and_unwrapped_function(self): + cfg = MagicMock() + cfg.dump_jaxpr = True + cfg.dump_jaxpr_local_dir = "/tmp/nonexistent_jaxpr_test_dir" + cfg.dump_jaxpr_gcs_dir = "" + + @jax.jit + def dummy_func(x): + return x * 2 + + with ( + unittest.mock.patch("builtins.open", unittest.mock.mock_open()), + unittest.mock.patch("os.makedirs"), + unittest.mock.patch("jax.make_jaxpr") as mock_make_jaxpr, + ): + + mock_tracer = MagicMock() + mock_make_jaxpr.return_value = mock_tracer + mock_tracer.return_value = MagicMock() + + input_array = jnp.ones((4, 5), dtype=jnp.float32) + maxtext_utils.maybe_dump_jaxpr(cfg, dummy_func, (input_array,)) + + # 1. Verify make_jaxpr was called with the underlying unwrapped function + mock_make_jaxpr.assert_called_once_with(dummy_func.__wrapped__) + + # 2. Verify the input arguments passed to the tracer were mapped to abstract ShapeDtypeStruct + called_args, _ = mock_tracer.call_args + self.assertEqual(len(called_args), 1) + self.assertIsInstance(called_args[0], jax.ShapeDtypeStruct) + self.assertEqual(called_args[0].shape, (4, 5)) + self.assertEqual(called_args[0].dtype, jnp.float32) + class TestPrintShardingsParams(unittest.TestCase): """Tests for print_shardings_params — normalization branches.""" diff --git a/tests/unit/model_creation_utils_test.py b/tests/unit/model_creation_utils_test.py index 8e9926dcde..9451e2f86a 100644 --- a/tests/unit/model_creation_utils_test.py +++ b/tests/unit/model_creation_utils_test.py @@ -85,7 +85,7 @@ def _make_config(**kwargs): "base_num_decoder_layers": 2, "attention": "dot_product", "max_target_length": 16, - "base_emb_dim": 256, + "base_emb_dim": 32, "base_num_query_heads": 2, "base_num_kv_heads": 2, "max_prefill_predict_length": 4, diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index 07c0b1c30f..192bb63322 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -52,7 +52,7 @@ def init_pyconfig(self, **kwargs): base_num_decoder_layers=2, attention="dot_product", max_target_length=16, - base_emb_dim=256, + base_emb_dim=32, base_num_query_heads=2, base_num_kv_heads=2, max_prefill_predict_length=4, @@ -175,7 +175,7 @@ def test_train_vs_prefill_and_autoregress(self): equal_nan=False, ) - for idx in range(PREFILL_RANGE, self.cfg.max_target_length): + for idx in range(PREFILL_RANGE, min(PREFILL_RANGE + 3, self.cfg.max_target_length)): ids_idx = ids[:, idx : idx + 1] decoder_positions_idx = decoder_positions[:, idx : idx + 1] prefill_transformer_vars.update(partial_cache) diff --git a/tests/unit/multi_token_prediction_test.py b/tests/unit/multi_token_prediction_test.py index 99300b97df..01703546cf 100644 --- a/tests/unit/multi_token_prediction_test.py +++ b/tests/unit/multi_token_prediction_test.py @@ -47,6 +47,13 @@ def setUp(self): run_name="multi_token_prediction_layer_test", skip_jax_distributed_system=True, per_device_batch_size=8, + base_emb_dim=16, + base_mlp_dim=32, + 
base_num_query_heads=4, + base_num_kv_heads=4, + head_dim=8, + max_target_length=128, + vocab_size=128, **extra_args, ) self.rng = jax.random.PRNGKey(42) # Base RNG for setup @@ -205,6 +212,12 @@ def setUp(self): skip_jax_distributed_system=True, mtp_num_layers=2, base_emb_dim=16, + base_mlp_dim=32, + base_num_query_heads=4, + base_num_kv_heads=4, + head_dim=8, + max_target_length=128, + vocab_size=128, **extra_args, ) self.nnx_rngs = nnx.Rngs(params=0) diff --git a/tests/unit/quantizations_test.py b/tests/unit/quantizations_test.py index 19e37cea97..6b4a3c6295 100644 --- a/tests/unit/quantizations_test.py +++ b/tests/unit/quantizations_test.py @@ -348,11 +348,12 @@ def init_pyconfig(self, **kwargs): "per_device_batch_size": 1, "use_qwix_quantization": True, "skip_jax_distributed_system": True, - "base_emb_dim": 1024, - "base_num_query_heads": 8, - "base_num_kv_heads": 8, - "base_mlp_dim": 4096, - "base_num_decoder_layers": 12, + "base_emb_dim": 16, + "base_num_query_heads": 1, + "base_num_kv_heads": 1, + "base_mlp_dim": 16, + "base_num_decoder_layers": 1, + "max_target_length": 16, } | kwargs | extra_args @@ -393,19 +394,47 @@ def compare_fn(path, x, y): def quantization_config(self, quant, logits_tolerance=2e-1, grad_tolerance=5e-1): """Run forward pass and backward pass for quantized model and compare with base model.""" + # pylint: disable=protected-access cfg = self.init_pyconfig(quantization=quant) - model = model_creation_utils.create_model(self.cfg, self.mesh) qt_model = model_creation_utils.create_model(cfg, self.mesh) ids, decoder_segment_ids, decoder_positions = self.get_data() - var = model.init( - {"params": self.rng, "aqt": self.rng, "dropout": self.rng}, - ids, - decoder_positions, - decoder_segment_ids, - enable_dropout=False, - mutable=True, - ) + + if not hasattr(self.__class__, "_cached_base_results"): + model = model_creation_utils.create_model(self.cfg, self.mesh) + var = model.init( + {"params": self.rng, "aqt": self.rng, "dropout": self.rng}, + ids, + decoder_positions, + decoder_segment_ids, + enable_dropout=False, + mutable=True, + ) + + def loss_base(all_vars, inputs): + logits, _ = model.apply( + all_vars, + *inputs, + enable_dropout=False, + rngs={"params": self.rng}, + mutable=True, + ) + return jnp.mean((logits) ** 2) + + grads_base = jax.grad(loss_base)(var, (ids, decoder_positions, decoder_segment_ids)) + logits, _ = model.apply( + var, + ids, + decoder_positions, + decoder_segment_ids, + enable_dropout=False, + rngs={"params": self.rng}, + mutable=True, + ) + self.__class__._cached_base_results = (grads_base, logits) + + grads_base, logits = self.__class__._cached_base_results + quantized_vars = qt_model.init( {"params": self.rng, "aqt": self.rng, "dropout": self.rng}, ids, @@ -415,16 +444,6 @@ def quantization_config(self, quant, logits_tolerance=2e-1, grad_tolerance=5e-1) mutable=True, ) - def loss_base(all_vars, inputs): - logits, _ = model.apply( - all_vars, - *inputs, - enable_dropout=False, - rngs={"params": self.rng}, - mutable=True, - ) - return jnp.mean((logits) ** 2) - def loss_quant(all_vars, inputs): logits, _ = qt_model.apply( all_vars, @@ -436,18 +455,8 @@ def loss_quant(all_vars, inputs): return jnp.mean((logits) ** 2) # Compute gradients w.r.t. 
both models - grads_base = jax.grad(loss_base)(var, (ids, decoder_positions, decoder_segment_ids)) grads_quant = jax.grad(loss_quant)(quantized_vars, (ids, decoder_positions, decoder_segment_ids)) - logits, _ = model.apply( - var, - ids, - decoder_positions, - decoder_segment_ids, - enable_dropout=False, - rngs={"params": self.rng}, - mutable=True, - ) quant_logits, _ = qt_model.apply( quantized_vars, ids, @@ -483,12 +492,12 @@ def test_fp8_full_quantization(self): @pytest.mark.gpu_only @pytest.mark.external_serving def test_fp8_gpu_quantization(self): - self.quantization_config("fp8_gpu", grad_tolerance=1.0) + self.quantization_config("fp8_gpu", grad_tolerance=1.5) @pytest.mark.gpu_only @pytest.mark.external_serving def test_fp8_nanoo_quantization(self): - self.quantization_config("fp8_nanoo", grad_tolerance=1.0) + self.quantization_config("fp8_nanoo", grad_tolerance=1.5) @pytest.mark.skip(reason="No runner with GPU arch >= 89 is available") @pytest.mark.gpu_only diff --git a/tests/unit/tfds_data_processing_test.py b/tests/unit/tfds_data_processing_test.py index ab76c818ae..7700524157 100644 --- a/tests/unit/tfds_data_processing_test.py +++ b/tests/unit/tfds_data_processing_test.py @@ -55,6 +55,7 @@ def setUp(self): "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"), "enable_checkpointing": False, "eval_interval": 10, + "max_target_length": 128, } if decoupled and local_dataset_name: @@ -75,9 +76,31 @@ def setUp(self): shuffle_seed=self.config.data_shuffle_seed, ) self.read_config.add_tfds_id = True - self.train_ds = self._get_datasets() - self.train_iter = tfds_data_processing.make_tfds_train_iterator(self.config, self.mesh, self.process_indices) - self.eval_iter = tfds_data_processing.make_tfds_eval_iterator(self.config, self.mesh, self.process_indices) + + @property + def train_ds(self): + # pylint: disable=protected-access + if not hasattr(self.__class__, "_cached_train_ds"): + self.__class__._cached_train_ds = self._get_datasets() + return self.__class__._cached_train_ds + + @property + def train_iter(self): + # pylint: disable=protected-access + if not hasattr(self.__class__, "_cached_train_iter"): + self.__class__._cached_train_iter = tfds_data_processing.make_tfds_train_iterator( + self.config, self.mesh, self.process_indices + ) + return self.__class__._cached_train_iter + + @property + def eval_iter(self): + # pylint: disable=protected-access + if not hasattr(self.__class__, "_cached_eval_iter"): + self.__class__._cached_eval_iter = tfds_data_processing.make_tfds_eval_iterator( + self.config, self.mesh, self.process_indices + ) + return self.__class__._cached_eval_iter def _get_datasets(self): ds_builder = tfds.builder(self.config.dataset_name) diff --git a/tests/unit/tiling_test.py b/tests/unit/tiling_test.py index 6ed33c3c67..816a5b54a9 100644 --- a/tests/unit/tiling_test.py +++ b/tests/unit/tiling_test.py @@ -67,7 +67,12 @@ def setUp(self): """ Set up common configurations and dummy data for the tests. 
""" - self.base_config = [None, get_test_config_path()] + self.base_config = [ + None, + get_test_config_path(), + "base_emb_dim=32", + "vocab_size=128", + ] self.rng = jax.random.PRNGKey(1234) self.batch_size = 1 self.seq_len = 64 diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index 6be542d1ff..fb7993aa3e 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -23,6 +23,7 @@ import os.path from tempfile import gettempdir +from jax.experimental.compilation_cache import compilation_cache import pytest import transformers @@ -32,6 +33,13 @@ from maxtext.trainers.pre_train.train_compile import main as train_compile_main from tests.utils.test_helpers import get_test_config_path +# Enable JAX compilation cache for testing to speed up AOT compilation +try: + if os.getenv("JAX_PLATFORMS") != "proxy": + compilation_cache.set_cache_dir(os.path.join(gettempdir(), "jax_compile_test_cache")) +except Exception: # pylint: disable=broad-exception-caught + pass + @pytest.mark.tpu_backend class TrainCompile(parameterized.TestCase): @@ -975,11 +983,11 @@ def test_qk_clip(self, attention): ) ) - @pytest.mark.cpu_only @parameterized.named_parameters( {"testcase_name": "consistent_rms_scaling", "muon_consistent_rms": 0.2}, {"testcase_name": "width_scaling", "muon_consistent_rms": None}, ) + @pytest.mark.cpu_only def test_muon(self, muon_consistent_rms): """AOT test for Muon optimizer for DeepSeek3 Tiny model""" compiled_trainstep_file = "/tmp/test_muon.pickle"