Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions .github/workflows/build_image.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Workflow: builds the Orbax-checkpoint integration-test runner Docker image
# and pushes it to GCR, then tags the pushed image with the current git hash.
name: Build and Push Orbax-checkpoint IT runner Docker Images
on:
  schedule:
    # Run the job daily at 12AM UTC
    - cron: '0 0 * * *'
  push:
    branches:
      # Also build on pushes to any test_* branch (manual testing of this workflow).
      - 'test_*'
permissions:
  contents: read
jobs:
  build_and_push:
    runs-on: linux-x86-n2-16-buildkit
    # cloud-sdk container provides gcloud for docker auth and image tagging.
    container: google/cloud-sdk:524.0.0
    steps:
      - name: Checkout Orbax-checkpoint
        uses: actions/checkout@v5
      - name: Mark git repositories as safe
        # Required because the checkout is owned by a different uid inside the container.
        run: git config --global --add safe.directory '*'
      - name: Configure Docker
        run: gcloud auth configure-docker us-docker.pkg.dev,gcr.io -q
      - name: Set up Docker BuildX
        uses: docker/setup-buildx-action@v3.11.1
        with:
          # Build happens on a remote buildkit daemon exposed by the runner.
          driver: remote
          endpoint: tcp://localhost:1234
      - name: Build and push Docker image
        uses: docker/build-push-action@v6
        with:
          push: true
          context: .
          file: ./checkpoint/orbax/checkpoint/_src/testing/oss/Dockerfile
          tags: gcr.io/orbax-checkpoint/orbax-benchmarks-integration_tests-runner:latest
          # Reuse layer cache stored in GitHub Actions cache.
          cache-from: type=gha
          outputs: type=image,compression=zstd,force-compression=true
          build-args: |
            DEVICE=tpu
            JAX_VERSION=newest
            BRANCH=main
            GITHUB_RUNNER=true
      - name: Add tags to Docker images
        shell: bash
        run: |
          SOURCE_IMAGE="gcr.io/orbax-checkpoint/orbax-benchmarks-integration_tests-runner"
          # Add Orbax-checkpoint tag
          orbax_hash=$(git rev-parse --short HEAD)
          gcloud container images add-tag "$SOURCE_IMAGE:latest" "$SOURCE_IMAGE:orbax_${orbax_hash}" --quiet
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ the complexities of cluster management and config propagation.
| `--benchmark_binary_path` | `/app/...` | Path to the benchmark runner script within the Docker image. |
| `--use_vertex_tensorboard` | `False` | Use Vertex AI Tensorboard for the workload. |
| `--experiment_name` | `None` | Name of the Vertex AI experiment. |
| `--skip_validation` | `False` | Skip validation of the benchmark results. |

#### 🔌 Networking & Security
| Flag | Default | Description |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
_NUM_SLICES = flags.DEFINE_integer('num_slices', 1, 'Number of slices.')
_PROJECT = flags.DEFINE_string('project', 'orbax-checkpoint', 'GCP Project ID.')
_ZONE = flags.DEFINE_string('zone', 'europe-west4-a', 'GCP Zone.')
_REGION = flags.DEFINE_string('region', 'europe-west4', 'GCP Region.')
_WORKLOAD_NAME = flags.DEFINE_string(
'workload_name', None, 'Name of the workload. Defaults to generated name.'
)
Expand Down Expand Up @@ -274,6 +275,11 @@
False,
'If True, run workload creation and execution twice to test restart.',
)
_SKIP_VALIDATION = flags.DEFINE_boolean(
'skip_validation',
False,
'If True, skip validation of the benchmark results.',
)

# --- Pathways Flags ---
# Pathways uses a "Sidecar" architecture on XPK:
Expand Down Expand Up @@ -459,7 +465,7 @@ def check_preconditions() -> bool:
Console.print_warning('Could not list clusters to verify existence.')
return False
else:
if _CLUSTER_NAME.value in clusters:
if clusters and _CLUSTER_NAME.value in clusters.split():
Console.print_success(f'Cluster found: {_CLUSTER_NAME.value}')
return True

Expand All @@ -479,6 +485,24 @@ def check_preconditions() -> bool:
)


def get_credentials() -> bool:
  """Fetches GKE credentials for the target cluster into kubeconfig.

  Runs `gcloud container clusters get-credentials` so that subsequent
  cluster operations (e.g. CSI driver updates) can reach the cluster
  named by --cluster_name in the region named by --region.

  Returns:
    True if credentials were fetched successfully, False otherwise.
    (The original was annotated ``-> None`` yet returned ``False`` on
    failure and ``None`` on success; the annotation and success path are
    now consistent. Callers that ignore the result are unaffected.)
  """
  cmd = [
      'gcloud',
      'container',
      'clusters',
      'get-credentials',
      _CLUSTER_NAME.value,
      '--region',
      _REGION.value,
  ]
  # Keep the try body minimal: only the subprocess invocation can raise.
  try:
    run_command(cmd, suppress_output=not _VERBOSE.value)
  except subprocess.CalledProcessError as e:
    # Best-effort, matching original behavior: report and signal failure
    # rather than aborting the launcher.
    print(f'Failed to get cluster credentials: {e}')
    return False
  return True


def create_cluster() -> None:
"""Creates the XPK cluster."""
Console.print_info(f'Creating cluster {_CLUSTER_NAME.value}...')
Expand Down Expand Up @@ -646,6 +670,8 @@ def construct_xpk_command(
base_cmd.append('--enable-ops-agent')
if _RAMDISK_DIRECTORY.value is not None:
base_cmd.append('--mtc-enabled')
if _SKIP_VALIDATION.value:
base_cmd.append('--skip-validation')

if _ENABLE_PATHWAYS.value:
if not _PATHWAYS_SERVER_IMAGE.value:
Expand Down Expand Up @@ -804,6 +830,7 @@ def main(argv: Sequence[str]) -> None:
if _RAMDISK_DIRECTORY.value is not None:
# Delete CSI driver before running any workloads, to delete any previous
# checkpoint files.
get_credentials()
update_bucket_csi_driver(mount_csi_driver=False)
# Mount CSI driver for the workload.
update_bucket_csi_driver(mount_csi_driver=True)
Expand Down
58 changes: 58 additions & 0 deletions checkpoint/orbax/checkpoint/_src/testing/oss/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Image for running Orbax-checkpoint OSS integration tests: a slim Python
# base plus git, the Google Cloud CLI (gcloud/kubectl), and the Python
# dependencies needed by cloud_run_integration_tests.py.

# Base image argument (defaulting to slim python image)
ARG BASE_IMAGE=python:3.11-slim
FROM $BASE_IMAGE

WORKDIR /app

# 1. Install System Dependencies + Google Cloud CLI in a single layer.
# (The original had a redundant first apt-get layer installing git and
# dnsutils that this layer repeated; merging them avoids an extra layer
# and a duplicate `apt-get update`.)
# --no-install-recommends limits bloat; python3-pip is already in the base.
RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    dnsutils \
    curl \
    ca-certificates \
    gnupg \
    apt-transport-https \
    gettext-base \
    gawk \
    && curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg \
    && echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee /etc/apt/sources.list.d/google-cloud-sdk.list \
    && apt-get update \
    && apt-get install -y --no-install-recommends \
    google-cloud-cli \
    google-cloud-cli-gke-gcloud-auth-plugin \
    kubectl \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# 2. Copy the testing sources. The [g] glob makes the COPY succeed (as a
# no-op) even when the source directory is absent from the build context.
COPY ./checkpoint/orbax/checkpoint/_src/testin[g] /app/orbax_repo/checkpoint/orbax/checkpoint/_src/testing

# 3. Setup Python Environment & Dependencies
# Uninstall pre-installed orbax if present in base image to avoid conflicts
RUN pip uninstall -y orbax-checkpoint orbax || true

# Install requirements from repo root if it exists
RUN if [ -f "requirements.txt" ]; then pip install --no-cache-dir -r requirements.txt; fi

RUN pip install --no-cache-dir gcsfs portpicker clu tensorflow pyyaml

# 4. Install launch tooling (xpk) from the checkpoint working directory.
WORKDIR /app/orbax_repo/checkpoint
RUN pip install --no-cache-dir xpk

# 5. Environment Setup
# Set PYTHONPATH so 'import orbax' works from the correctly mapped directory
ENV PYTHONPATH=/app/orbax_repo/checkpoint


CMD ["python3", "orbax/checkpoint/_src/testing/oss/cloud_run_integration_tests.py"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Runs Orbax benchmarks on GCP."""

import datetime
import os
import subprocess
import sys
import yaml


def run_benchmark(test_config):
  """Runs a single benchmark test based on the given config.

  Builds the launch_xpk.py command line from the required fields of
  ``test_config`` plus any optional flags present, then executes it as a
  subprocess.

  Args:
    test_config: A dictionary containing the test configuration. Required
      keys: name, output_directory, cluster_name, tpu_type, zone,
      config_file, docker_image, num_slices. Optional keys map to
      launch_xpk.py flags (see ``optional_flags`` below).

  Returns:
    True if benchmark ran successfully, False otherwise.
  """
  print(f"Running benchmark: {test_config['name']}")

  # Results for each run go in a date-stamped subdirectory.
  output_dir = os.path.join(
      test_config['output_directory'],
      datetime.datetime.now().strftime('%Y%m%d'),
  )

  cmd = [
      'python3',
      'orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py',
      '--cluster_name',
      test_config['cluster_name'],
      '--tpu_type',
      test_config['tpu_type'],
      '--zone',
      test_config['zone'],
      '--config_file',
      test_config['config_file'],
      '--docker_image',
      test_config['docker_image'],
      '--output_directory',
      output_dir,
      '--num_slices',
      str(test_config['num_slices']),
  ]

  # Optional passthrough flags: (config key, CLI flag, takes a value).
  # A truthy config value enables the flag; value-taking flags append the
  # config value as the flag argument.
  optional_flags = (
      ('nodelete_cluster_on_completion', '--nodelete_cluster_on_completion',
       False),
      ('ramdisk_directory', '--ramdisk_directory', True),
      ('test_restart_workflow', '--test_restart_workflow', False),
      ('verbose', '--verbose', False),
      ('skip_validation', '--skip_validation', False),
      ('enable_pathways', '--enable_pathways', False),
      ('gcp_region', '--region', True),
  )
  for key, flag, takes_value in optional_flags:
    value = test_config.get(key)
    if value:
      cmd.append(flag)
      if takes_value:
        cmd.append(str(value))

  print(f"Executing command: {' '.join(cmd)}")
  try:
    subprocess.run(cmd, check=True)
  except subprocess.CalledProcessError as e:
    print(f'Benchmark script failed: {e}')
    return False

  return True


def main():
  """Loads test configurations and runs benchmarks."""
  config_path = (
      'orbax/checkpoint/_src/testing/oss/cloud_run_integration_tests.yaml'
  )
  with open(config_path, 'r') as config_file:
    config = yaml.safe_load(config_file)

  # Run every configured test, tallying how many did not succeed.
  failures = 0
  for test_config in config.get('tests', []):
    succeeded = run_benchmark(test_config)
    if not succeeded:
      failures += 1

  if failures:
    print(f'{failures} benchmarks failed.')
    sys.exit(1)


if __name__ == '__main__':
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Test matrix consumed by cloud_run_integration_tests.py: each entry under
# `tests` becomes one invocation of launch_xpk.py (see run_benchmark for
# how keys map to flags).
tests:
  - name: emergency_checkpoint_manager_benchmark
    cluster_name: orbax-cluster-test-mtc
    tpu_type: v5litepod-16
    zone: us-west1-c
    gcp_region: us-west1
    gcp_project: orbax-checkpoint
    config_file: orbax/checkpoint/_src/testing/benchmarks/configs/emergency_checkpoint_manager_benchmark.yaml
    docker_image: gcr.io/orbax-checkpoint/orbax-benchmarks-runner:latest
    # Date-stamped subdirectories are appended at run time.
    output_directory: gs://orbax-benchmarks/cloud_runs/
    nodelete_cluster_on_completion: true
    # Ramdisk path enables the MTC/CSI-driver checkpoint flow.
    ramdisk_directory: /local/test
    num_slices: 2
    test_restart_workflow: false
    verbose: true
    skip_validation: true
    enable_pathways: false
Loading