diff --git a/.github/workflows/build_image.yml b/.github/workflows/build_image.yml new file mode 100644 index 000000000..86329ca9f --- /dev/null +++ b/.github/workflows/build_image.yml @@ -0,0 +1,47 @@ +name: Build and Push Orbax-checkpoint IT runner Docker Images +on: + schedule: + # Run the job daily at 12AM UTC + - cron: '0 0 * * *' + push: + branches: + - 'test_*' +permissions: + contents: read +jobs: + build_and_push: + runs-on: linux-x86-n2-16-buildkit + container: google/cloud-sdk:524.0.0 + steps: + - name: Checkout Orbax-checkpoint + uses: actions/checkout@v5 + - name: Mark git repositories as safe + run: git config --global --add safe.directory '*' + - name: Configure Docker + run: gcloud auth configure-docker us-docker.pkg.dev,gcr.io -q + - name: Set up Docker BuildX + uses: docker/setup-buildx-action@v3.11.1 + with: + driver: remote + endpoint: tcp://localhost:1234 + - name: Build and push Docker image + uses: docker/build-push-action@v6 + with: + push: true + context: . + file: ./checkpoint/orbax/checkpoint/_src/testing/oss/Dockerfile + tags: gcr.io/orbax-checkpoint/orbax-benchmarks-integration_tests-runner:latest + cache-from: type=gha + outputs: type=image,compression=zstd,force-compression=true + build-args: | + DEVICE=tpu + JAX_VERSION=newest + BRANCH=main + GITHUB_RUNNER=true + - name: Add tags to Docker images + shell: bash + run: | + SOURCE_IMAGE="gcr.io/orbax-checkpoint/orbax-benchmarks-integration_tests-runner" + # Add Orbax-checkpoint tag + orbax_hash=$(git rev-parse --short HEAD) + gcloud container images add-tag "$SOURCE_IMAGE:latest" "$SOURCE_IMAGE:orbax_${orbax_hash}" --quiet diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/README.md b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/README.md index 8fabcb861..5743c52c0 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/README.md +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/README.md @@ -264,6 +264,7 @@ the complexities of cluster management and config propagation. | `--benchmark_binary_path` | `/app/...` | Path to the benchmark runner script within the Docker image. | | `--use_vertex_tensorboard` | `False` | Use Vertex AI Tensorboard for the workload. | | `--experiment_name` | `None` | Name of the Vertex AI experiment. | +| `--skip_validation` | `False` | Skip dependency validation checks. | #### 🔌 Networking & Security | Flag | Default | Description | diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py index 9987ad7b3..5c5405496 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py @@ -88,6 +88,7 @@ _NUM_SLICES = flags.DEFINE_integer('num_slices', 1, 'Number of slices.') _PROJECT = flags.DEFINE_string('project', 'orbax-checkpoint', 'GCP Project ID.') _ZONE = flags.DEFINE_string('zone', 'europe-west4-a', 'GCP Zone.') +_REGION = flags.DEFINE_string('region', 'europe-west4', 'GCP Region.') _WORKLOAD_NAME = flags.DEFINE_string( 'workload_name', None, 'Name of the workload. Defaults to generated name.' ) @@ -274,6 +275,11 @@ False, 'If True, run workload creation and execution twice to test restart.', ) +_SKIP_VALIDATION = flags.DEFINE_boolean( + 'skip_validation', + False, + 'If True, skip validation of the benchmark results.', +) # --- Pathways Flags --- # Pathways uses a "Sidecar" architecture on XPK: @@ -459,7 +465,7 @@ def check_preconditions() -> bool: Console.print_warning('Could not list clusters to verify existence.') return False else: - if _CLUSTER_NAME.value in clusters: + if clusters and _CLUSTER_NAME.value in clusters.split(): Console.print_success(f'Cluster found: {_CLUSTER_NAME.value}') return True @@ -479,6 +485,24 @@ def check_preconditions() -> bool: ) +def get_credentials() -> None: + """Gets credentials for the project.""" + try: + cmd = [ + 'gcloud', + 'container', + 'clusters', + 'get-credentials', + _CLUSTER_NAME.value, + '--region', + _REGION.value + ] + run_command(cmd, suppress_output=not _VERBOSE.value) + except subprocess.CalledProcessError as e: + print(f'Failed to get cluster credentials: {e}') + return False + + def create_cluster() -> None: """Creates the XPK cluster.""" Console.print_info(f'Creating cluster {_CLUSTER_NAME.value}...') @@ -646,6 +670,8 @@ def construct_xpk_command( base_cmd.append('--enable-ops-agent') if _RAMDISK_DIRECTORY.value is not None: base_cmd.append('--mtc-enabled') + if _SKIP_VALIDATION.value: + base_cmd.append('--skip-validation') if _ENABLE_PATHWAYS.value: if not _PATHWAYS_SERVER_IMAGE.value: @@ -804,6 +830,7 @@ def main(argv: Sequence[str]) -> None: if _RAMDISK_DIRECTORY.value is not None: # Delete CSI driver before running any workloads, to delete any previous # checkpoint files. + get_credentials() update_bucket_csi_driver(mount_csi_driver=False) # Mount CSI driver for the workload. update_bucket_csi_driver(mount_csi_driver=True) diff --git a/checkpoint/orbax/checkpoint/_src/testing/oss/Dockerfile b/checkpoint/orbax/checkpoint/_src/testing/oss/Dockerfile new file mode 100644 index 000000000..3c3997012 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/oss/Dockerfile @@ -0,0 +1,58 @@ +# Base image argument (defaulting to slim python image) +ARG BASE_IMAGE=python:3.11-slim +FROM $BASE_IMAGE + +WORKDIR /app + +# 1. Install System Dependencies +# common utils + git (needed for checkout) +# --no-install-recommends limits bloat +# python3-pip is standard in python images, no need to install +RUN apt-get update && apt-get install -y --no-install-recommends \ + git \ + dnsutils \ + && rm -rf /var/lib/apt/lists/* + +RUN apt-get update && apt-get install -y --no-install-recommends \ + git \ + dnsutils \ + curl \ + ca-certificates \ + gnupg \ + apt-transport-https \ + gettext-base \ + gawk \ + && curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg \ + && echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee /etc/apt/sources.list.d/google-cloud-sdk.list \ + && apt-get update \ + && apt-get install -y --no-install-recommends \ + google-cloud-cli \ + google-cloud-cli-gke-gcloud-auth-plugin \ + kubectl \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +COPY ./checkpoint/orbax/checkpoint/_src/testin[g] /app/orbax_repo/checkpoint/orbax/checkpoint/_src/testing + +# 3. Setup Python Environment & Dependencies +# Uninstall pre-installed orbax if present in base image to avoid conflicts +RUN pip uninstall -y orbax-checkpoint orbax || true + +# # Create a fake docker binary that always returns success (exit 0) +# RUN echo '#!/bin/bash\nexit 0' > /usr/local/bin/docker && chmod +x /usr/local/bin/docker + +# Install requirements from repo root if it exists +RUN if [ -f "requirements.txt" ]; then pip install --no-cache-dir -r requirements.txt; fi + +RUN pip install --no-cache-dir gcsfs portpicker clu tensorflow pyyaml + +# 4. Install Orbax from Source +WORKDIR /app/orbax_repo/checkpoint +RUN pip install xpk + +# 5. Environment Setup +# Set PYTHONPATH so 'import orbax' works from the correctly mapped directory +ENV PYTHONPATH=/app/orbax_repo/checkpoint + + +CMD ["python3", "orbax/checkpoint/_src/testing/oss/cloud_run_integration_tests.py"] diff --git a/checkpoint/orbax/checkpoint/_src/testing/oss/cloud_run_integration_tests.py b/checkpoint/orbax/checkpoint/_src/testing/oss/cloud_run_integration_tests.py new file mode 100644 index 000000000..67fb8f3d9 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/oss/cloud_run_integration_tests.py @@ -0,0 +1,101 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runs Orbax benchmarks on GCP.""" + +import datetime +import os +import subprocess +import sys +import yaml + + +def run_benchmark(test_config): + """Runs a single benchmark test based on the given config. + + Args: + test_config: A dictionary containing the test configuration. + + Returns: + True if benchmark ran successfully, False otherwise. + """ + print(f"Running benchmark: {test_config['name']}") + + # Build command + output_dir = os.path.join( + test_config['output_directory'], + datetime.datetime.now().strftime('%Y%m%d'), + ) + + cmd = [ + 'python3', + 'orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py', + '--cluster_name', + test_config['cluster_name'], + '--tpu_type', + test_config['tpu_type'], + '--zone', + test_config['zone'], + '--config_file', + test_config['config_file'], + '--docker_image', + test_config['docker_image'], + '--output_directory', + output_dir, + '--num_slices', + str(test_config['num_slices']), + ] + if test_config.get('nodelete_cluster_on_completion'): + cmd.append('--nodelete_cluster_on_completion') + if test_config.get('ramdisk_directory'): + cmd.extend(['--ramdisk_directory', test_config['ramdisk_directory']]) + if test_config.get('test_restart_workflow'): + cmd.append('--test_restart_workflow') + if test_config.get('verbose'): + cmd.append('--verbose') + if test_config.get('skip_validation'): + cmd.append('--skip_validation') + if test_config.get('enable_pathways'): + cmd.append('--enable_pathways') + if test_config.get('gcp_region'): + cmd.extend(['--region', test_config['gcp_region']]) + + print(f"Executing command: {' '.join(cmd)}") + try: + subprocess.run(cmd, check=True) + except subprocess.CalledProcessError as e: + print(f'Benchmark script failed: {e}') + return False + + return True + + +def main(): + """Loads test configurations and runs benchmarks.""" + config_path = 'orbax/checkpoint/_src/testing/oss/cloud_run_integration_tests.yaml' + with open(config_path, 'r') as f: + config = yaml.safe_load(f) + + failures = 0 + for test in config.get('tests', []): + if not run_benchmark(test): + failures += 1 + + if failures: + print(f'{failures} benchmarks failed.') + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/checkpoint/orbax/checkpoint/_src/testing/oss/cloud_run_integration_tests.yaml b/checkpoint/orbax/checkpoint/_src/testing/oss/cloud_run_integration_tests.yaml new file mode 100644 index 000000000..592234523 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/oss/cloud_run_integration_tests.yaml @@ -0,0 +1,17 @@ +tests: + - name: emergency_checkpoint_manager_benchmark + cluster_name: orbax-cluster-test-mtc + tpu_type: v5litepod-16 + zone: us-west1-c + gcp_region: us-west1 + gcp_project: orbax-checkpoint + config_file: orbax/checkpoint/_src/testing/benchmarks/configs/emergency_checkpoint_manager_benchmark.yaml + docker_image: gcr.io/orbax-checkpoint/orbax-benchmarks-runner:latest + output_directory: gs://orbax-benchmarks/cloud_runs/ + nodelete_cluster_on_completion: true + ramdisk_directory: /local/test + num_slices: 2 + test_restart_workflow: false + verbose: true + skip_validation: true + enable_pathways: false