From 46439743923a1c613c13a9553e58db4f7c4879de Mon Sep 17 00:00:00 2001 From: Alex Shraer Date: Sat, 2 May 2026 16:54:50 +0000 Subject: [PATCH 01/11] Add integration test for checkpoint resharding --- .../integration/checkpoint_resharding_test.py | 142 ++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 tests/integration/checkpoint_resharding_test.py diff --git a/tests/integration/checkpoint_resharding_test.py b/tests/integration/checkpoint_resharding_test.py new file mode 100644 index 0000000000..55d75869c4 --- /dev/null +++ b/tests/integration/checkpoint_resharding_test.py @@ -0,0 +1,142 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for checkpoint resharding functionality. + +These tests verify that a training run saves a checkpoint using one mesh topology, +and that a subsequent run can successfully restore and continue training using a +different mesh topology (resharding). +""" + +from datetime import datetime +import json +from math import isclose +import pytest + +from maxtext.trainers.pre_train.train import main as train_main +from tests.utils.test_helpers import ( + get_test_config_path, + get_test_base_output_directory, + get_test_dataset_path, +) + + +def get_resharding_command(run_date, steps, metrics_file, base_output_directory, dataset_path, parallelism_args): + """Generates a command list for the resharding test run.""" + model_params = [ + "base_emb_dim=384", + "base_num_query_heads=8", + "base_num_kv_heads=8", + "base_mlp_dim=192", + "base_num_decoder_layers=8", + "head_dim=128", + ] + + return ( + [ + None, + get_test_config_path(), + f"run_name=runner_{run_date}", + f"steps={steps}", + f"metrics_file={metrics_file}", + f"base_output_directory={base_output_directory}", + f"dataset_path={dataset_path}", + "dataset_type=grain", + "grain_worker_count=0", + "collect_stack_trace=False", + ] + + model_params + + parallelism_args + ) + + +def check_loss(metrics_file, target): + """Asserts that loss values match between saved and restored checkpoints. + + Verifies the resharding restoration is mathematically consistent by comparing + the final logged loss of the initial (saved) run against the initial logged + loss of the resumed (restored) run within a relative tolerance. + """ + metrics_file_saved = "saved_" + metrics_file + metrics_file_restored = "restored_" + metrics_file + + with ( + open(metrics_file_saved, "rt", encoding="utf8") as saved, + open(metrics_file_restored, "rt", encoding="utf8") as restored, + ): + # Read the last line of the saved metrics to get the final pre-checkpoint loss + saved_loss = json.loads(saved.readlines()[-1])[target] + # Read the first line of the restored metrics to get the initial post-restoration loss + restored_loss = json.loads(restored.readlines()[0])[target] + + print("Saved loss: ", saved_loss) + print("Restored loss: ", restored_loss) + # Checks that checkpoint restore was successful by comparing loss of last + # step in saved checkpoint to loss of first step in restored checkpoint + assert isclose(saved_loss, restored_loss, rel_tol=0.1) + + +@pytest.mark.integration_test +@pytest.mark.tpu_only +@pytest.mark.scheduled_only +def test_checkpoint_resharding(): + """Tests checkpoint resharding by saving and restoring with different mesh topologies.""" + run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + base_output_directory = get_test_base_output_directory() + dataset_path = get_test_dataset_path() + + # Phase 1: Train and Save Checkpoint + # Topology: FSDP=4, Tensor=1 + save_parallelism = [ + "checkpoint_period=10", + "save_checkpoint_on_completion=True", # Saves Checkpoint 0 upon job completion (model state after step 0) + "dcn_data_parallelism=2", + "dcn_fsdp_parallelism=1", + "ici_fsdp_parallelism=4", + "ici_tensor_parallelism=1", + ] + train_main( + get_resharding_command( + run_date, + steps=1, # Executes Step 0 + metrics_file="saved_metrics.txt", + base_output_directory=base_output_directory, + dataset_path=dataset_path, + parallelism_args=save_parallelism, + ) + ) + + # Phase 2: Restore and Continue + # Topology: FSDP=2, Tensor=2 + restore_parallelism = [ + "dcn_data_parallelism=2", + "dcn_fsdp_parallelism=1", + "ici_fsdp_parallelism=2", + "ici_tensor_parallelism=2", + ] + train_main( + get_resharding_command( + run_date, + # 'steps' defines the target global step. + # Restores Checkpoint 0 (state after step 0), sets start_step=1, and executes Step 1 to reach global step 2. + steps=2, + metrics_file="restored_metrics.txt", + base_output_directory=base_output_directory, + dataset_path=dataset_path, + parallelism_args=restore_parallelism, + ) + ) + + # Phase 3: Verify Loss Consistency + check_loss("metrics.txt", "learning/loss") From f90d592f481da1c50b6f28a767545d4054bfb9c9 Mon Sep 17 00:00:00 2001 From: Alex Shraer Date: Sat, 2 May 2026 20:37:31 +0000 Subject: [PATCH 02/11] Enable checkpoint resharding integration test in normal runs --- tests/integration/checkpoint_resharding_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration/checkpoint_resharding_test.py b/tests/integration/checkpoint_resharding_test.py index 55d75869c4..73575a5733 100644 --- a/tests/integration/checkpoint_resharding_test.py +++ b/tests/integration/checkpoint_resharding_test.py @@ -89,7 +89,6 @@ def check_loss(metrics_file, target): @pytest.mark.integration_test @pytest.mark.tpu_only -@pytest.mark.scheduled_only def test_checkpoint_resharding(): """Tests checkpoint resharding by saving and restoring with different mesh topologies.""" run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") From 6e40f9350bbefb5ce4f498cdc08235405f4c3f89 Mon Sep 17 00:00:00 2001 From: Alex Shraer Date: Sat, 2 May 2026 21:33:27 +0000 Subject: [PATCH 03/11] Adjust resharding test parallelism for single slice TPU VM --- tests/integration/checkpoint_resharding_test.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/integration/checkpoint_resharding_test.py b/tests/integration/checkpoint_resharding_test.py index 73575a5733..4bab07ea54 100644 --- a/tests/integration/checkpoint_resharding_test.py +++ b/tests/integration/checkpoint_resharding_test.py @@ -96,19 +96,19 @@ def test_checkpoint_resharding(): dataset_path = get_test_dataset_path() # Phase 1: Train and Save Checkpoint - # Topology: FSDP=4, Tensor=1 + # Topology: FSDP=8, Tensor=1 save_parallelism = [ "checkpoint_period=10", - "save_checkpoint_on_completion=True", # Saves Checkpoint 0 upon job completion (model state after step 0) - "dcn_data_parallelism=2", + "save_checkpoint_on_completion=True", + "dcn_data_parallelism=1", "dcn_fsdp_parallelism=1", - "ici_fsdp_parallelism=4", + "ici_fsdp_parallelism=8", "ici_tensor_parallelism=1", ] train_main( get_resharding_command( run_date, - steps=1, # Executes Step 0 + steps=1, metrics_file="saved_metrics.txt", base_output_directory=base_output_directory, dataset_path=dataset_path, @@ -117,11 +117,11 @@ def test_checkpoint_resharding(): ) # Phase 2: Restore and Continue - # Topology: FSDP=2, Tensor=2 + # Topology: FSDP=4, Tensor=2 restore_parallelism = [ - "dcn_data_parallelism=2", + "dcn_data_parallelism=1", "dcn_fsdp_parallelism=1", - "ici_fsdp_parallelism=2", + "ici_fsdp_parallelism=4", "ici_tensor_parallelism=2", ] train_main( From 464425f31fd1186cc468ded4a4ed1f157f9eef73 Mon Sep 17 00:00:00 2001 From: Alex Shraer Date: Sat, 2 May 2026 21:33:35 +0000 Subject: [PATCH 04/11] Adjust resharding test parallelism for 4 devices --- tests/integration/checkpoint_resharding_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/integration/checkpoint_resharding_test.py b/tests/integration/checkpoint_resharding_test.py index 4bab07ea54..25a0453f72 100644 --- a/tests/integration/checkpoint_resharding_test.py +++ b/tests/integration/checkpoint_resharding_test.py @@ -96,13 +96,13 @@ def test_checkpoint_resharding(): dataset_path = get_test_dataset_path() # Phase 1: Train and Save Checkpoint - # Topology: FSDP=8, Tensor=1 + # Topology: FSDP=4, Tensor=1 save_parallelism = [ "checkpoint_period=10", "save_checkpoint_on_completion=True", "dcn_data_parallelism=1", "dcn_fsdp_parallelism=1", - "ici_fsdp_parallelism=8", + "ici_fsdp_parallelism=4", "ici_tensor_parallelism=1", ] train_main( @@ -117,11 +117,11 @@ def test_checkpoint_resharding(): ) # Phase 2: Restore and Continue - # Topology: FSDP=4, Tensor=2 + # Topology: FSDP=2, Tensor=2 restore_parallelism = [ "dcn_data_parallelism=1", "dcn_fsdp_parallelism=1", - "ici_fsdp_parallelism=4", + "ici_fsdp_parallelism=2", "ici_tensor_parallelism=2", ] train_main( From 10a6734169979cdea24e9fe76e18acf745aed668 Mon Sep 17 00:00:00 2001 From: Alex Shraer Date: Sat, 2 May 2026 21:34:22 +0000 Subject: [PATCH 05/11] Use synthetic dataset in resharding test --- tests/integration/checkpoint_resharding_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/checkpoint_resharding_test.py b/tests/integration/checkpoint_resharding_test.py index 25a0453f72..3dd87d0509 100644 --- a/tests/integration/checkpoint_resharding_test.py +++ b/tests/integration/checkpoint_resharding_test.py @@ -52,7 +52,7 @@ def get_resharding_command(run_date, steps, metrics_file, base_output_directory, f"metrics_file={metrics_file}", f"base_output_directory={base_output_directory}", f"dataset_path={dataset_path}", - "dataset_type=grain", + "dataset_type=synthetic", "grain_worker_count=0", "collect_stack_trace=False", ] From 8db40f224e0dccf03f0c9335fa61b9288c7c94d3 Mon Sep 17 00:00:00 2001 From: Alex Shraer Date: Sat, 2 May 2026 21:36:12 +0000 Subject: [PATCH 06/11] Restore step=1 comment in resharding test --- tests/integration/checkpoint_resharding_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/checkpoint_resharding_test.py b/tests/integration/checkpoint_resharding_test.py index 3dd87d0509..00377877db 100644 --- a/tests/integration/checkpoint_resharding_test.py +++ b/tests/integration/checkpoint_resharding_test.py @@ -108,7 +108,7 @@ def test_checkpoint_resharding(): train_main( get_resharding_command( run_date, - steps=1, + steps=1, # Executes Step 0 metrics_file="saved_metrics.txt", base_output_directory=base_output_directory, dataset_path=dataset_path, From 2781b58b9e0b7850eb1f644fd2c5b85df25c69e8 Mon Sep 17 00:00:00 2001 From: Alex Shraer Date: Sat, 2 May 2026 21:36:33 +0000 Subject: [PATCH 07/11] Restore save_checkpoint_on_completion comment --- tests/integration/checkpoint_resharding_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/checkpoint_resharding_test.py b/tests/integration/checkpoint_resharding_test.py index 00377877db..23dac0c8f8 100644 --- a/tests/integration/checkpoint_resharding_test.py +++ b/tests/integration/checkpoint_resharding_test.py @@ -99,7 +99,7 @@ def test_checkpoint_resharding(): # Topology: FSDP=4, Tensor=1 save_parallelism = [ "checkpoint_period=10", - "save_checkpoint_on_completion=True", + "save_checkpoint_on_completion=True", # Saves Checkpoint 0 upon job completion (model state after step 0) "dcn_data_parallelism=1", "dcn_fsdp_parallelism=1", "ici_fsdp_parallelism=4", From dcafa78b100a6b9a75a27ad8ddfd569b251aa923 Mon Sep 17 00:00:00 2001 From: Alex Shraer Date: Sat, 2 May 2026 21:41:50 +0000 Subject: [PATCH 08/11] Fix trailing whitespace in docstring --- tests/integration/checkpoint_resharding_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/checkpoint_resharding_test.py b/tests/integration/checkpoint_resharding_test.py index 23dac0c8f8..fcfbc2d35d 100644 --- a/tests/integration/checkpoint_resharding_test.py +++ b/tests/integration/checkpoint_resharding_test.py @@ -65,7 +65,7 @@ def check_loss(metrics_file, target): """Asserts that loss values match between saved and restored checkpoints. Verifies the resharding restoration is mathematically consistent by comparing - the final logged loss of the initial (saved) run against the initial logged + the final logged loss of the initial (saved) run against the initial logged loss of the resumed (restored) run within a relative tolerance. """ metrics_file_saved = "saved_" + metrics_file From 71575df5c1084ac3128f793cfd41c10a49a7067e Mon Sep 17 00:00:00 2001 From: Alex Shraer Date: Sat, 2 May 2026 22:07:33 +0000 Subject: [PATCH 09/11] Re-enable scheduled_only tag on resharding test --- tests/integration/checkpoint_resharding_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/checkpoint_resharding_test.py b/tests/integration/checkpoint_resharding_test.py index fcfbc2d35d..5db5ac444e 100644 --- a/tests/integration/checkpoint_resharding_test.py +++ b/tests/integration/checkpoint_resharding_test.py @@ -89,6 +89,7 @@ def check_loss(metrics_file, target): @pytest.mark.integration_test @pytest.mark.tpu_only +@pytest.mark.scheduled_only def test_checkpoint_resharding(): """Tests checkpoint resharding by saving and restoring with different mesh topologies.""" run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") From b05830f301afddfb9d71f5186208b0539f0a95a0 Mon Sep 17 00:00:00 2001 From: Alex Shraer Date: Sat, 2 May 2026 22:11:04 +0000 Subject: [PATCH 10/11] Enable modified tests to ignore scheduled_only flag in CI --- .github/workflows/build_and_test_maxtext.yml | 21 +++++++------------ .../workflows/run_tests_against_package.yml | 7 ++++++- .github/workflows/run_tests_coordinator.yml | 6 ++++++ tests/conftest.py | 13 ++++++++++++ 4 files changed, 32 insertions(+), 15 deletions(-) diff --git a/.github/workflows/build_and_test_maxtext.yml b/.github/workflows/build_and_test_maxtext.yml index 6949bf8d7e..3adf141e33 100644 --- a/.github/workflows/build_and_test_maxtext.yml +++ b/.github/workflows/build_and_test_maxtext.yml @@ -43,6 +43,7 @@ jobs: outputs: run_tests: ${{ steps.check.outputs.run_tests }} run_notebooks: ${{ steps.check.outputs.run_notebooks }} + modified_tests: ${{ steps.check.outputs.modified_tests }} steps: - uses: actions/checkout@v4 with: @@ -63,24 +64,13 @@ jobs: echo "Changed files:" echo "$CHANGED_FILES" - if [ -z "$CHANGED_FILES" ]; then - echo "No files detected or diff failed. Running everything as a fail-safe." - echo "run_tests=true" >> $GITHUB_OUTPUT - echo "run_notebooks=true" >> $GITHUB_OUTPUT - exit 0 - fi - # default to running everything if something goes wrong with the checks echo "run_tests=true" >> $GITHUB_OUTPUT echo "run_notebooks=true" >> $GITHUB_OUTPUT + echo "modified_tests=" >> $GITHUB_OUTPUT - # 1. Check if only documentation files (.md) were changed - if ! echo "$CHANGED_FILES" | grep -v -E '\.md$' > /dev/null; then - echo "Documentation-only files changed, skipping all tests and notebooks." - echo "run_tests=false" >> $GITHUB_OUTPUT - echo "run_notebooks=false" >> $GITHUB_OUTPUT - exit 0 - fi + MODIFIED_TESTS=$(echo "$CHANGED_FILES" | grep -E '^tests/.*\.py$' | paste -sd ";" - || echo "") + echo "modified_tests=$MODIFIED_TESTS" >> $GITHUB_OUTPUT # 2. Check if dependencies or Github workflows were changed if echo "$CHANGED_FILES" | grep -E '(^|/)(src/dependencies/|\.github/workflows/)' > /dev/null; then @@ -156,6 +146,7 @@ jobs: base_image: maxtext-unit-test-tpu:py312 is_scheduled_run: ${{ github.event_name == 'schedule' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + modified_tests: ${{ needs.analyze_code_changes.outputs.modified_tests }} gpu-tests: name: ${{ matrix.flavor }} tests @@ -171,6 +162,7 @@ jobs: base_image: maxtext-unit-test-cuda12:py312 is_scheduled_run: ${{ github.event_name == 'schedule' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + modified_tests: ${{ needs.analyze_code_changes.outputs.modified_tests }} cpu-tests: name: ${{ matrix.flavor }} tests @@ -186,6 +178,7 @@ jobs: base_image: maxtext-unit-test-tpu:py312 is_scheduled_run: ${{ github.event_name == 'schedule' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + modified_tests: ${{ needs.analyze_code_changes.outputs.modified_tests }} maxtext_tpu_pathways_unit_tests: needs: build_and_upload_maxtext_package diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index 2955e082f0..868240b9b0 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -74,6 +74,10 @@ on: required: false type: boolean default: false + modified_tests: + required: false + type: string + default: '' permissions: contents: read @@ -86,6 +90,7 @@ jobs: XLA_PYTHON_CLIENT_MEM_FRACTION: ${{ inputs.xla_python_client_mem_fraction }} TF_FORCE_GPU_ALLOW_GROWTH: ${{ inputs.tf_force_gpu_allow_growth }} TPU_SKIP_MDS_QUERY: ${{ inputs.device_type == 'cpu' && '1' || '' }} + MODIFIED_TESTS: ${{ inputs.modified_tests }} MAXTEXT_PACKAGE_EXTRA: >- ${{ !contains(inputs.pytest_marker, 'not post_training') && 'tpu-post-train' @@ -146,7 +151,7 @@ jobs: if [ "${INPUTS_IS_SCHEDULED_RUN}" == "true" ]; then FINAL_PYTEST_MARKER="${INPUTS_PYTEST_MARKER}" else - FINAL_PYTEST_MARKER="${INPUTS_PYTEST_MARKER} and not scheduled_only" + FINAL_PYTEST_MARKER="${INPUTS_PYTEST_MARKER} and (not scheduled_only or force_run)" fi # TODO: Use package data for testing and remove the env vars export MAXTEXT_REPO_ROOT=$(pwd) diff --git a/.github/workflows/run_tests_coordinator.yml b/.github/workflows/run_tests_coordinator.yml index 935b7fbda0..d9b35e7f11 100644 --- a/.github/workflows/run_tests_coordinator.yml +++ b/.github/workflows/run_tests_coordinator.yml @@ -56,6 +56,11 @@ on: required: false type: string default: '' + modified_tests: + description: 'List of modified tests separated by semicolon' + required: false + type: string + default: '' permissions: contents: read @@ -161,3 +166,4 @@ jobs: total_workers: ${{ contains(inputs.flavor, 'cpu-unit') && 2 || 1 }} maxtext_sha: ${{ inputs.maxtext_sha }} is_update_hlo: ${{ inputs.is_update_hlo }} + modified_tests: ${{ inputs.modified_tests }} diff --git a/tests/conftest.py b/tests/conftest.py index 1bd442c8ed..fc55fa1a41 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -127,6 +127,18 @@ def pytest_collection_modifyitems(config, items): for item in remaining: item.add_marker(pytest.mark.decoupled) + import os + modified_tests = os.environ.get("MODIFIED_TESTS", "").split(";") + if modified_tests and modified_tests[0]: + import os.path + for item in remaining: + try: + rel_path = os.path.relpath(str(item.fspath), os.getcwd()) + if rel_path in modified_tests: + item.add_marker(pytest.mark.force_run) + except Exception: # pragma: no cover + pass + def pytest_configure(config): for m in [ @@ -136,5 +148,6 @@ def pytest_configure(config): "external_serving: JetStream / serving / decode server components", "external_training: goodput integrations", "decoupled: marked on tests that are not skipped due to GCP deps, when DECOUPLE_GCLOUD=TRUE", + "force_run: force execution of scheduled tests if their file was modified", ]: config.addinivalue_line("markers", m) From a211a6bbd0d48ba185040cf802cb21bd96d777c6 Mon Sep 17 00:00:00 2001 From: Alex Shraer Date: Sat, 2 May 2026 22:12:31 +0000 Subject: [PATCH 11/11] Revert "Enable modified tests to ignore scheduled_only flag in CI" This reverts commit b05830f301afddfb9d71f5186208b0539f0a95a0. --- .github/workflows/build_and_test_maxtext.yml | 21 ++++++++++++------- .../workflows/run_tests_against_package.yml | 7 +------ .github/workflows/run_tests_coordinator.yml | 6 ------ tests/conftest.py | 13 ------------ 4 files changed, 15 insertions(+), 32 deletions(-) diff --git a/.github/workflows/build_and_test_maxtext.yml b/.github/workflows/build_and_test_maxtext.yml index 3adf141e33..6949bf8d7e 100644 --- a/.github/workflows/build_and_test_maxtext.yml +++ b/.github/workflows/build_and_test_maxtext.yml @@ -43,7 +43,6 @@ jobs: outputs: run_tests: ${{ steps.check.outputs.run_tests }} run_notebooks: ${{ steps.check.outputs.run_notebooks }} - modified_tests: ${{ steps.check.outputs.modified_tests }} steps: - uses: actions/checkout@v4 with: @@ -64,13 +63,24 @@ jobs: echo "Changed files:" echo "$CHANGED_FILES" + if [ -z "$CHANGED_FILES" ]; then + echo "No files detected or diff failed. Running everything as a fail-safe." + echo "run_tests=true" >> $GITHUB_OUTPUT + echo "run_notebooks=true" >> $GITHUB_OUTPUT + exit 0 + fi + # default to running everything if something goes wrong with the checks echo "run_tests=true" >> $GITHUB_OUTPUT echo "run_notebooks=true" >> $GITHUB_OUTPUT - echo "modified_tests=" >> $GITHUB_OUTPUT - MODIFIED_TESTS=$(echo "$CHANGED_FILES" | grep -E '^tests/.*\.py$' | paste -sd ";" - || echo "") - echo "modified_tests=$MODIFIED_TESTS" >> $GITHUB_OUTPUT + # 1. Check if only documentation files (.md) were changed + if ! echo "$CHANGED_FILES" | grep -v -E '\.md$' > /dev/null; then + echo "Documentation-only files changed, skipping all tests and notebooks." + echo "run_tests=false" >> $GITHUB_OUTPUT + echo "run_notebooks=false" >> $GITHUB_OUTPUT + exit 0 + fi # 2. Check if dependencies or Github workflows were changed if echo "$CHANGED_FILES" | grep -E '(^|/)(src/dependencies/|\.github/workflows/)' > /dev/null; then @@ -146,7 +156,6 @@ jobs: base_image: maxtext-unit-test-tpu:py312 is_scheduled_run: ${{ github.event_name == 'schedule' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} - modified_tests: ${{ needs.analyze_code_changes.outputs.modified_tests }} gpu-tests: name: ${{ matrix.flavor }} tests @@ -162,7 +171,6 @@ jobs: base_image: maxtext-unit-test-cuda12:py312 is_scheduled_run: ${{ github.event_name == 'schedule' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} - modified_tests: ${{ needs.analyze_code_changes.outputs.modified_tests }} cpu-tests: name: ${{ matrix.flavor }} tests @@ -178,7 +186,6 @@ jobs: base_image: maxtext-unit-test-tpu:py312 is_scheduled_run: ${{ github.event_name == 'schedule' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} - modified_tests: ${{ needs.analyze_code_changes.outputs.modified_tests }} maxtext_tpu_pathways_unit_tests: needs: build_and_upload_maxtext_package diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index 868240b9b0..2955e082f0 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -74,10 +74,6 @@ on: required: false type: boolean default: false - modified_tests: - required: false - type: string - default: '' permissions: contents: read @@ -90,7 +86,6 @@ jobs: XLA_PYTHON_CLIENT_MEM_FRACTION: ${{ inputs.xla_python_client_mem_fraction }} TF_FORCE_GPU_ALLOW_GROWTH: ${{ inputs.tf_force_gpu_allow_growth }} TPU_SKIP_MDS_QUERY: ${{ inputs.device_type == 'cpu' && '1' || '' }} - MODIFIED_TESTS: ${{ inputs.modified_tests }} MAXTEXT_PACKAGE_EXTRA: >- ${{ !contains(inputs.pytest_marker, 'not post_training') && 'tpu-post-train' @@ -151,7 +146,7 @@ jobs: if [ "${INPUTS_IS_SCHEDULED_RUN}" == "true" ]; then FINAL_PYTEST_MARKER="${INPUTS_PYTEST_MARKER}" else - FINAL_PYTEST_MARKER="${INPUTS_PYTEST_MARKER} and (not scheduled_only or force_run)" + FINAL_PYTEST_MARKER="${INPUTS_PYTEST_MARKER} and not scheduled_only" fi # TODO: Use package data for testing and remove the env vars export MAXTEXT_REPO_ROOT=$(pwd) diff --git a/.github/workflows/run_tests_coordinator.yml b/.github/workflows/run_tests_coordinator.yml index d9b35e7f11..935b7fbda0 100644 --- a/.github/workflows/run_tests_coordinator.yml +++ b/.github/workflows/run_tests_coordinator.yml @@ -56,11 +56,6 @@ on: required: false type: string default: '' - modified_tests: - description: 'List of modified tests separated by semicolon' - required: false - type: string - default: '' permissions: contents: read @@ -166,4 +161,3 @@ jobs: total_workers: ${{ contains(inputs.flavor, 'cpu-unit') && 2 || 1 }} maxtext_sha: ${{ inputs.maxtext_sha }} is_update_hlo: ${{ inputs.is_update_hlo }} - modified_tests: ${{ inputs.modified_tests }} diff --git a/tests/conftest.py b/tests/conftest.py index fc55fa1a41..1bd442c8ed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -127,18 +127,6 @@ def pytest_collection_modifyitems(config, items): for item in remaining: item.add_marker(pytest.mark.decoupled) - import os - modified_tests = os.environ.get("MODIFIED_TESTS", "").split(";") - if modified_tests and modified_tests[0]: - import os.path - for item in remaining: - try: - rel_path = os.path.relpath(str(item.fspath), os.getcwd()) - if rel_path in modified_tests: - item.add_marker(pytest.mark.force_run) - except Exception: # pragma: no cover - pass - def pytest_configure(config): for m in [ @@ -148,6 +136,5 @@ def pytest_configure(config): "external_serving: JetStream / serving / decode server components", "external_training: goodput integrations", "decoupled: marked on tests that are not skipped due to GCP deps, when DECOUPLE_GCLOUD=TRUE", - "force_run: force execution of scheduled tests if their file was modified", ]: config.addinivalue_line("markers", m)