Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/maxtext/utils/maxtext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1766,8 +1766,19 @@ def maybe_dump_jaxpr(config, p_train_step, train_step_inputs):
return
max_logging.log("Tracing train_step to jaxpr...")

# We use the p_train_step (the JIT-decorated function)
p_train_jaxpr = jax.make_jaxpr(p_train_step)(*train_step_inputs)
# Trace the underlying un-jitted function via __wrapped__ to avoid heavy remote
# compilation/gRPC round-trips to the Pathways controller.
unwrapped_step = getattr(p_train_step, "__wrapped__", p_train_step)

def to_abstract(x):
    """Map an array-like pytree leaf to a purely local jax.ShapeDtypeStruct.

    Leaves exposing both ``.shape`` and ``.dtype`` (concrete or remote arrays)
    are replaced by an abstract ShapeDtypeStruct so tracing never touches the
    leaf's data; any other leaf is returned unchanged.
    """
    looks_like_array = hasattr(x, "shape") and hasattr(x, "dtype")
    if not looks_like_array:
        return x
    return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)

# Convert all input arguments recursively to purely local abstract ShapeDtypeStruct objects
# to completely bypass remote Array objects and proxy tracing overhead.
abstract_inputs = jax.tree.map(to_abstract, train_step_inputs)
p_train_jaxpr = jax.make_jaxpr(unwrapped_step)(*abstract_inputs)

local_filename = "train_step.jaxpr"
local_path = os.path.join(config.dump_jaxpr_local_dir, local_filename)
Expand Down
32 changes: 27 additions & 5 deletions tests/integration/aot_identical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,13 @@ def assert_compile_and_real_match_hlo(self, test_name, *extra_args):
"steps=1",
"enable_checkpointing=False",
"base_num_decoder_layers=1",
"max_target_length=512",
"base_emb_dim=256",
"base_mlp_dim=256",
"max_target_length=32",
"base_emb_dim=64",
"base_mlp_dim=64",
"base_num_query_heads=2",
"base_num_kv_heads=2",
"head_dim=16",
"vocab_size=128",
] + hlo_dump_args
if extra_args:
shared_args.extend(extra_args)
Expand Down Expand Up @@ -179,6 +183,14 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
"enable_checkpointing=False",
"dump_jaxpr=True",
"dump_jaxpr_delete_local_after=False",
"base_num_decoder_layers=1",
"max_target_length=32",
"base_emb_dim=64",
"base_mlp_dim=64",
"base_num_query_heads=2",
"base_num_kv_heads=2",
"head_dim=16",
"vocab_size=128",
]
if extra_args:
shared_args.extend(extra_args)
Expand Down Expand Up @@ -218,5 +230,15 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
)

@pytest.mark.tpu_only
def test_default_jaxpr_match(self):
self.assert_compile_and_real_match_jaxpr("default_run")
def test_default_jaxpr_match_mcjax(self):
    """Check compiled vs. real jaxpr equality on the McJAX backend.

    Skipped when JAX_PLATFORMS=proxy, since that indicates a Pathways
    environment where the companion Pathways test runs instead.
    """
    platform = os.getenv("JAX_PLATFORMS")
    if platform == "proxy":
        pytest.skip("This is a McJAX test, skipping in Pathways environment.")
    self.assert_compile_and_real_match_jaxpr("default_run_mcjax")

@pytest.mark.tpu_only
@pytest.mark.scheduled_only
def test_default_jaxpr_match_pathways(self):
    """Check compiled vs. real jaxpr equality on the Pathways backend.

    Runs only when JAX_PLATFORMS=proxy (a Pathways environment); otherwise
    skipped so the McJAX companion test covers the default case.
    Currently this test is extremely slow (b/512065615).
    """
    platform = os.getenv("JAX_PLATFORMS")
    if platform != "proxy":
        pytest.skip("This is a Pathways test, skipping in McJAX environment.")
    self.assert_compile_and_real_match_jaxpr("default_run_pathways", "enable_single_controller=True")
12 changes: 6 additions & 6 deletions tests/integration/checkpoint_resharding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@
def get_resharding_command(run_date, steps, metrics_file, base_output_directory, dataset_path, parallelism_args):
"""Generates a command list for the resharding test run."""
model_params = [
"base_emb_dim=384",
"base_num_query_heads=8",
"base_num_kv_heads=8",
"base_mlp_dim=192",
"base_num_decoder_layers=8",
"head_dim=128",
"base_emb_dim=128",
"base_num_query_heads=2",
"base_num_kv_heads=2",
"base_mlp_dim=128",
"base_num_decoder_layers=2",
"head_dim=64",
]

return (
Expand Down
12 changes: 6 additions & 6 deletions tests/integration/checkpointing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention
"""
base_output_directory = get_test_base_output_directory()
model_params = [
"base_emb_dim=384",
"base_num_query_heads=8",
"base_num_kv_heads=8",
"base_mlp_dim=192",
"base_num_decoder_layers=8",
"head_dim=128",
"base_emb_dim=128",
"base_num_query_heads=2",
"base_num_kv_heads=2",
"base_mlp_dim=128",
"base_num_decoder_layers=1",
"head_dim=64",
]
pathways_command = []
if os.getenv("JAX_PLATFORMS") == "proxy":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __call__(self, x, model_mode):
return jnp.ones((x.shape[0], x.shape[1], self.emb_dim))


@pytest.mark.integration_test
class TestDeepSeekScanEngram(unittest.TestCase):
"""Test DeepSeek decoder block with scan_layers=True and engram_layers."""

Expand All @@ -53,16 +54,16 @@ class TestDeepSeekScanEngram(unittest.TestCase):
"first_num_dense_layers=5",
"base_num_decoder_layers=10",
"num_decoder_layers=10",
"base_emb_dim=64",
"base_mlp_dim=64",
"base_moe_mlp_dim=64",
"base_num_query_heads=2",
"base_num_kv_heads=2",
"head_dim=32",
"indexer_head_dim=32",
"qk_nope_head_dim=32",
"qk_rope_head_dim=16",
"v_head_dim=32",
"base_emb_dim=8",
"base_mlp_dim=8",
"base_moe_mlp_dim=8",
"base_num_query_heads=1",
"base_num_kv_heads=1",
"head_dim=4",
"indexer_head_dim=4",
"qk_nope_head_dim=4",
"qk_rope_head_dim=4",
"v_head_dim=4",
"vocab_size=128",
"mhc_expansion_rate=4",
"attention=dot_product",
Expand All @@ -71,15 +72,24 @@ class TestDeepSeekScanEngram(unittest.TestCase):
"max_prefill_predict_length=8",
"enable_checkpointing=False",
"engram_num_heads=1",
"engram_head_dim=8",
"engram_head_dim=4",
"engram_vocab_bases=[128,128]",
"engram_max_ngram_size=3",
"engram_kernel_size=4",
"num_experts=2",
"num_experts_per_tok=1",
"hf_access_token=dummy",
"tokenizer_path=dummy",
]

def _test_engram_pattern(self, mock_from_pretrained, engram_layers_str, expected_keys):
def _test_engram_pattern(
self,
mock_from_pretrained,
engram_layers_str,
expected_keys,
first_num_dense_layers=5,
base_num_decoder_layers=10,
):
"""Helper method to test different engram layer patterns."""

# Setup mock tokenizer
Expand All @@ -106,7 +116,16 @@ def batch_decode(self, token_ids, *args, **kwargs):
mock_from_pretrained.return_value = MockTokenizer()

config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")
config = pyconfig.initialize([None, config_path] + self._COMMON_CONFIG + [f"engram_layers=[{engram_layers_str}]"])
config = pyconfig.initialize(
[None, config_path]
+ self._COMMON_CONFIG
+ [
f"engram_layers=[{engram_layers_str}]",
f"first_num_dense_layers={first_num_dense_layers}",
f"base_num_decoder_layers={base_num_decoder_layers}",
f"num_decoder_layers={base_num_decoder_layers}",
]
)

devices_array = maxtext_utils.create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)
Expand All @@ -126,7 +145,7 @@ def batch_decode(self, token_ids, *args, **kwargs):

shared_embedding = DummyEmbedding(emb_dim=config.emb_dim)

with mesh:
with mesh, jax.disable_jit():
variables = decoder.init(
{"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1), "aqt": jax.random.PRNGKey(2)},
shared_embedding=shared_embedding,
Expand Down Expand Up @@ -154,15 +173,16 @@ def test_decoder_init_engram_2_8(self, mock_from_pretrained):
"""Test engram layers at indices 2 and 8."""
self._test_engram_pattern(
mock_from_pretrained,
"2,8",
"1,4",
[
"dense_layers_0_1",
"dense_layers_engram_2",
"dense_layers_3_4",
"moe_layers_5_7",
"moe_layers_engram_8",
"moe_layers_9_9",
"dense_layers_0_0",
"dense_layers_engram_1",
"dense_layers_2_2",
"moe_layers_3_3",
"moe_layers_engram_4",
],
first_num_dense_layers=3,
base_num_decoder_layers=5,
)

@pytest.mark.tpu_only
Expand All @@ -171,8 +191,10 @@ def test_decoder_init_engram_0_5(self, mock_from_pretrained):
"""Test engram layers at indices 0 and 5 - first engram layer of block."""
self._test_engram_pattern(
mock_from_pretrained,
"0,5",
["dense_layers_engram_0", "dense_layers_1_4", "moe_layers_engram_5", "moe_layers_6_9"],
"0,1",
["dense_layers_engram_0", "moe_layers_engram_1"],
first_num_dense_layers=1,
base_num_decoder_layers=2,
)

@pytest.mark.tpu_only
Expand All @@ -181,6 +203,8 @@ def test_decoder_init_engram_4_9(self, mock_from_pretrained):
"""Test engram layers at indices 4 and 9 - last engram layer of block."""
self._test_engram_pattern(
mock_from_pretrained,
"4,9",
["dense_layers_0_3", "dense_layers_engram_4", "moe_layers_5_8", "moe_layers_engram_9"],
"1,3",
["dense_layers_0_0", "dense_layers_engram_1", "moe_layers_2_2", "moe_layers_engram_3"],
first_num_dense_layers=2,
base_num_decoder_layers=4,
)
12 changes: 8 additions & 4 deletions tests/integration/determinism_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import pytest

from maxtext.trainers.pre_train.train import main as train_main
from tests.utils.test_helpers import get_test_config_path
from tests.utils.test_helpers import get_test_config_path, get_test_base_output_directory, get_test_dataset_path

pytestmark = pytest.mark.integration_test

Expand All @@ -52,13 +52,17 @@ def test_determinism(self):
common_config = [
None,
get_test_config_path(),
"steps=5",
"steps=2",
"enable_checkpointing=False",
"enable_data_shuffling=True",
"enable_dropout=False",
"base_output_directory=gs://runner-maxtext-logs",
"dataset_path=gs://maxtext-dataset",
f"base_output_directory={get_test_base_output_directory()}",
f"dataset_path={get_test_dataset_path()}",
"skip_jax_distributed_system=True",
"base_emb_dim=128",
"base_mlp_dim=128",
"base_num_decoder_layers=1",
"head_dim=64",
]
train_1_config = common_config + [
f"run_name={run_name}_1",
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/diloco_test.py → tests/integration/diloco_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __call__(self, x):
return self.dense(x)


@pytest.mark.integration_test
class DiLoCoTest(unittest.TestCase):

@pytest.mark.tpu_only
Expand Down Expand Up @@ -284,6 +285,13 @@ def test_diloco_qwen3_moe_two_slices(self):
"dcn_diloco_parallelism=2",
"enable_diloco=true",
"model_name=qwen3-30b-a3b",
"override_model_config=True",
"base_emb_dim=32",
"base_num_decoder_layers=1",
"base_mlp_dim=64",
"base_num_query_heads=4",
"base_num_kv_heads=4",
"head_dim=8",
)
)

Expand All @@ -302,5 +310,12 @@ def test_diloco_two_slices(self):
"dcn_diloco_parallelism=2",
"enable_diloco=true",
"model_name=gemma2-2b",
"override_model_config=True",
"base_emb_dim=32",
"base_num_decoder_layers=1",
"base_mlp_dim=64",
"base_num_query_heads=1",
"base_num_kv_heads=1",
"head_dim=4",
)
)
14 changes: 7 additions & 7 deletions tests/integration/generate_param_only_checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
def get_model_params(quantization):
return [
f"quantization={quantization}",
"base_emb_dim=384",
"base_num_query_heads=8",
"base_num_kv_heads=8",
"base_mlp_dim=192",
"base_num_decoder_layers=8",
"head_dim=128",
"base_emb_dim=128",
"base_num_query_heads=2",
"base_num_kv_heads=2",
"base_mlp_dim=128",
"base_num_decoder_layers=1",
"head_dim=64",
]


Expand Down Expand Up @@ -69,7 +69,7 @@ def run_e2e_test_flow(hardware, model_config, attention_type="autoselected", sta
steps=1,
metrics_file="run_metrics.txt",
attention_type=attention_type,
dataset_type="tfds",
dataset_type="synthetic",
dataset_path=dataset_path,
)
)
Expand Down
10 changes: 6 additions & 4 deletions tests/integration/gradient_accumulation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,15 @@ def test_sft_grad_accumulate_same_loss(self):
[
None,
get_test_config_path(),
"base_output_directory=gs://runner-maxtext-logs",
"dataset_path=gs://maxtext-dataset",
f"base_output_directory={self.base_output_directory}",
f"dataset_path={self.dataset_path}",
"dataset_type=synthetic",
"max_target_length=128",
"gradient_clipping_threshold=0", # Ensures we are testing raw scales of gradients (clipping off).
"enable_checkpointing=False",
"enable_goodput_recording=False",
"base_emb_dim=256",
"base_num_decoder_layers=4",
"base_emb_dim=128",
"base_num_decoder_layers=1",
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
"steps=3",
"gradient_accumulation_steps=2",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
pytestmark = [pytest.mark.external_serving]


@pytest.mark.integration_test
class MaxEngineTest(unittest.TestCase):
"""Tests for MaxEngine."""

Expand All @@ -55,7 +56,7 @@ def init_pyconfig(self, **kwargs):
"base_num_decoder_layers": 2,
"attention": "dot_product",
"max_target_length": 16,
"base_emb_dim": 256,
"base_emb_dim": 32,
"base_num_query_heads": 2,
"base_num_kv_heads": 2,
"max_prefill_predict_length": 4,
Expand Down
Loading
Loading