From 0a683a8e53e517735b0407d3f3e03560b4d6e458 Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Fri, 20 Mar 2026 11:05:46 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 886881258 --- .../evaluations/model_sizes/llama-405b.yaml | 1 + .../evaluations/model_sizes/llama-70b.yaml | 2 +- .../evaluations/model_sizes/llama-8b.yaml | 1 + .../evaluations/pathways/llama-405b.yaml | 26 ++ .../evaluations/pathways/llama-70b.yaml | 26 ++ .../evaluations/pathways/llama-8b.yaml | 25 ++ .../llama-405b_generate_1-64-1.yaml | 27 ++ .../resharding/llama-70b_generate_1-16-1.yaml | 27 ++ .../resharding/llama-8b_generate_1-4-1.yaml | 27 ++ .../create_large_cluster.sh | 415 ++++++++++++++++++ .../llama-70b_replicas_16.yaml | 28 ++ .../llama-70b_replicas_16_no_broadcast.yaml | 28 ++ .../llama-70b_replicas_2.yaml | 29 ++ .../llama-70b_replicas_2_no_broadcast.yaml | 28 ++ .../llama-70b_replicas_32.yaml | 28 ++ .../llama-70b_replicas_32_no_broadcast.yaml | 28 ++ .../llama-70b_replicas_4.yaml | 29 ++ .../llama-70b_replicas_4_no_broadcast.yaml | 28 ++ .../testing/benchmarks/core/device_mesh.py | 16 +- .../benchmarks/core/device_mesh_test.py | 8 +- .../_src/testing/benchmarks/run_benchmarks.py | 28 +- .../_src/testing/benchmarks/v1/benchmark.py | 3 + .../benchmarks/v1/resharding_benchmark.py | 2 +- .../v1/restore_and_broadcast_benchmark.py | 196 +++++++++ .../restore_and_broadcast_benchmark_test.py | 216 +++++++++ .../_src/testing/benchmarks/xpk/Dockerfile | 18 +- .../testing/benchmarks/xpk/build_image.sh | 11 +- .../_src/testing/benchmarks/xpk/launch_xpk.py | 93 +++- .../testing/benchmarks/xpk/launch_xpk_test.py | 182 ++++++++ .../experimental/v1/_src/context/options.py | 4 +- .../v1/_src/serialization/registration.py | 7 +- 31 files changed, 1569 insertions(+), 18 deletions(-) create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-405b.yaml create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-70b.yaml create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-8b.yaml create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-405b_generate_1-64-1.yaml create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-70b_generate_1-16-1.yaml create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-8b_generate_1-4-1.yaml create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/create_large_cluster.sh create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_16.yaml create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_16_no_broadcast.yaml create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_2.yaml create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_2_no_broadcast.yaml create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_32.yaml create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_32_no_broadcast.yaml create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_4.yaml create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_4_no_broadcast.yaml create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark.py create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark_test.py create mode 100644 checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk_test.py diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/model_sizes/llama-405b.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/model_sizes/llama-405b.yaml index d6f770ed2..ac1fefffb 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/model_sizes/llama-405b.yaml +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/model_sizes/llama-405b.yaml @@ -11,6 +11,7 @@ mesh_config: # The checkpoint configuration, shared across all generated benchmarks. checkpoint_config: path: "gs://orbax-benchmarks/checkpoints/llama-3.1-405B-checkpoints/0/items" + sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-405b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json" benchmarks: - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark" diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/model_sizes/llama-70b.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/model_sizes/llama-70b.yaml index c8ff2acc7..932336758 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/model_sizes/llama-70b.yaml +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/model_sizes/llama-70b.yaml @@ -11,6 +11,7 @@ mesh_config: # The checkpoint configuration, shared across all generated benchmarks. checkpoint_config: path: "gs://orbax-benchmarks/checkpoints/llama-3.1-70B-checkpoints/0/items" + sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-32-data-1-fsdp-16-tensor-1/abstract_state.json" benchmarks: - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark" @@ -23,4 +24,3 @@ benchmarks: use_zarr3: true use_replica_parallel: false use_compression: true - enable_trace: true diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/model_sizes/llama-8b.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/model_sizes/llama-8b.yaml index 96483f150..f2fde1535 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/model_sizes/llama-8b.yaml +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/model_sizes/llama-8b.yaml @@ -10,6 +10,7 @@ mesh_config: # The checkpoint configuration, shared across all generated benchmarks. checkpoint_config: path: "gs://orbax-benchmarks/checkpoints/llama-3.1-8B-checkpoints/0/items" + sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-8b-v5p-8-data-1-fsdp-4-tensor-1/abstract_state.json" benchmarks: - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark" diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-405b.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-405b.yaml new file mode 100644 index 000000000..760e569d8 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-405b.yaml @@ -0,0 +1,26 @@ +# The name for the entire test suite run. +# Assumes v5p-128 (64 chips) +suite_name: "Llama 3.1 405B" +num_repeats: 20 + + +mesh_config: + mesh_axes: ["data", "fsdp", "tensor"] + ici_parallelism: {"data": 1, "fsdp": 64, "tensor": 1} + +# The checkpoint configuration, shared across all generated benchmarks. +checkpoint_config: + path: "gs://orbax-benchmarks/checkpoints/llama-405b_generate_1-64-1" + sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-405b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json" + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `BenchmarkOptions` class + # associated with the `Benchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-70b.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-70b.yaml new file mode 100644 index 000000000..5fa44c864 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-70b.yaml @@ -0,0 +1,26 @@ +# The name for the entire test suite run. +# Assumes v5p-32 (16 chips) +suite_name: "Llama 3.1 70B" +num_repeats: 20 + + +mesh_config: + mesh_axes: ["data", "fsdp", "tensor"] + ici_parallelism: {"data": 1, "fsdp": 16, "tensor": 1} + +# The checkpoint configuration, shared across all generated benchmarks. +checkpoint_config: + path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_1-16-1" + sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-32-data-1-fsdp-16-tensor-1/abstract_state.json" + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `BenchmarkOptions` class + # associated with the `Benchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-8b.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-8b.yaml new file mode 100644 index 000000000..37fd4bfc2 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-8b.yaml @@ -0,0 +1,25 @@ +# The name for the entire test suite run. +# Assumes v5p-8 (4 chips) +suite_name: "Llama 3.1 8B" +num_repeats: 20 + +mesh_config: + mesh_axes: ["data", "fsdp", "tensor"] + ici_parallelism: {"data": 1, "fsdp": 4, "tensor": 1} + +# The checkpoint configuration, shared across all generated benchmarks. +checkpoint_config: + path: "gs://orbax-benchmarks/checkpoints/llama-8b_generate_1-4-1" + sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-8b-v5p-8-data-1-fsdp-4-tensor-1/abstract_state.json" + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `BenchmarkOptions` class + # associated with the `Benchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-405b_generate_1-64-1.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-405b_generate_1-64-1.yaml new file mode 100644 index 000000000..8492da058 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-405b_generate_1-64-1.yaml @@ -0,0 +1,27 @@ +# The name for the entire test suite run. +# Assumes v5p-128 (64 chips) +suite_name: "Llama 3.1 405B" +num_repeats: 1 + + +mesh_config: + mesh_axes: ["data", "fsdp", "tensor"] + # Should match reference_sharding_path. + ici_parallelism: {"data": 1, "fsdp": 64, "tensor": 1} + +# Note: checkpoint_config field not specified. +checkpoint_config: + path: "gs://orbax-benchmarks/checkpoints/llama-3.1-405B-checkpoints/0/items" + sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-405b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json" + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `BenchmarkOptions` class + # associated with the `Benchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-70b_generate_1-16-1.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-70b_generate_1-16-1.yaml new file mode 100644 index 000000000..daa3f5171 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-70b_generate_1-16-1.yaml @@ -0,0 +1,27 @@ +# The name for the entire test suite run. +# Assumes v5p-32 (16 chips) +suite_name: "Llama 3.1 70B" +num_repeats: 1 + + +mesh_config: + mesh_axes: ["data", "fsdp", "tensor"] + # Should match reference_sharding_path. + ici_parallelism: {"data": 1, "fsdp": 16, "tensor": 1} + +# Note: checkpoint_config field not specified. +checkpoint_config: + path: "gs://orbax-benchmarks/checkpoints/llama-3.1-70B-checkpoints/0/items" + sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-32-data-1-fsdp-16-tensor-1/abstract_state.json" + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `BenchmarkOptions` class + # associated with the `Benchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-8b_generate_1-4-1.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-8b_generate_1-4-1.yaml new file mode 100644 index 000000000..a3e3e4228 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-8b_generate_1-4-1.yaml @@ -0,0 +1,27 @@ +# The name for the entire test suite run. +# Assumes v5p-8 (4 chips) +suite_name: "Llama 3.1 8B" +num_repeats: 1 + + +mesh_config: + mesh_axes: ["data", "fsdp", "tensor"] + # Should match reference_sharding_path. + ici_parallelism: {"data": 1, "fsdp": 4, "tensor": 1} + +# Note: checkpoint_config field not specified. +checkpoint_config: + path: "gs://orbax-benchmarks/checkpoints/llama-3.1-8B-checkpoints/0/items" + sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-8b-v5p-8-data-1-fsdp-4-tensor-1/abstract_state.json" + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `BenchmarkOptions` class + # associated with the `Benchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/create_large_cluster.sh b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/create_large_cluster.sh new file mode 100644 index 000000000..c37d42f6a --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/create_large_cluster.sh @@ -0,0 +1,415 @@ +#!/bin/bash + +# Note: this is not intended to be run E2E. Run the commands manually. +# Borrows from XPK docs: https://github.com/AI-Hypercomputer/xpk/blob/main/xpk-large-scale-guide.sh + +######### OVERVIEW OF SCRIPT ######### +## This script is intended to guide one on the steps needed to create large scale +## (>2k VMs) with xpk and GKE. +## This script was run by manually copying commands per step and verifying the +## output of each step. +## We recommend you manually copy commands per step and verify the outputs of +## each step. + +## Step Summary is: +## Step 1: Cluster Networking setup. +## Step 2: Create your cluster with xpk. +## Step 3: Move from KubeDNS to CoreDNS. This is necessary past 1000 VMs. +## (SKIPPED) Step 4: Pass Cluster name and Project ID to Google POCs to setup your cluster +## for large scale and high throughput. This is necessary past 5000 VMs. +## Step 5: Scale up your cluster. +######### OVERVIEW OF SCRIPT ######### + +### USER VARIABLES: +# TODO(USER): ADJUST PROJECT_NAME, CLUSTER NAME at least. +export PREFIX=${USER} +export PROJECT="orbax-checkpoint" +export REGION=us-west1 +export ZONE=us-west1-a +export CLUSTER=cpgaffney-n2-standard-4-64-32 +export DEVICE_TYPE=n2-standard-4-64 + +gcloud config set project $PROJECT +gcloud config set compute/zone $ZONE +gcloud config set compute/region $REGION + +### INTERNAL VARIABLES: +export NETWORK_NAME=${PREFIX}-privatenetwork +export SUBNET_NAME=${PREFIX}-privatesubnet +export FIREWALL_RULE_NAME=${PREFIX}-privatefirewall +export ROUTER_NAME=${PREFIX}-network +export NAT_CONFIG=${PREFIX}-natconfig + +##### STEP 1 ################# +##### Cluster Networking ##### +############################## + +##### 1A ##################### +# Create network for cluster. +##### 1A ##################### + +gcloud compute networks create "${NETWORK_NAME}" --mtu=8896 --bgp-routing-mode=regional --subnet-mode=custom --project="${PROJECT}" + +# Created [https://www.googleapis.com/compute/v1/projects/PROJECT/global/networks/PREFIX-privatenetwork]. +# NAME SUBNET_MODE BGP_ROUTING_MODE IPV4_RANGE GATEWAY_IPV4 +# PREFIX-privatenetwork CUSTOM REGIONAL + +# Instances on this network will not be reachable until firewall rules +# are created. As an example, you can allow all internal traffic between +# instances as well as SSH, RDP, and ICMP by running: + +##### 1B ##################### +# Create subnetwork for cluster. +##### 1B ##################### + +gcloud compute networks subnets create "${SUBNET_NAME}" --network="${NETWORK_NAME}" --range=10.10.0.0/18 --region="${REGION}" --project="${PROJECT}" + +# Created [https://www.googleapis.com/compute/v1/projects/PROJECT/regions/us-central2/subnetworks/PREFIX-privatesubnet]. +# NAME REGION NETWORK RANGE STACK_TYPE IPV6_ACCESS_TYPE INTERNAL_IPV6_PREFIX EXTERNAL_IPV6_PREFIX +# PREFIX-privatesubnet us-central2 PREFIX-privatenetwork 10.10.0.0/18 IPV4_ONLY + +##### 1C ##################### +# Create firewall rules for private network. +##### 1C ##################### + +gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" --network "${NETWORK_NAME}" --allow tcp,icmp,udp --project="${PROJECT}" + +# Creating firewall...â ıCreated [https://www.googleapis.com/compute/v1/projects/PROJECT/global/firewalls/PREFIX-privatefirewall]. +# Creating firewall...done. +# NAME NETWORK DIRECTION PRIORITY ALLOW DENY DISABLED +# PREFIX-privatefirewall PREFIX-privatenetwork INGRESS 1000 tcp,icmp,udp False + +##### 1D ##################### +# Routers for network and region. +##### 1D ##################### + +gcloud compute routers create "${ROUTER_NAME}" \ + --project="${PROJECT}" \ + --network="${NETWORK_NAME}" \ + --region="${REGION}" + +# Creating router [PREFIX-network]...done. +# NAME REGION NETWORK +# PREFIX-network us-central2 PREFIX-privatenetwork + +##### 1E ##################### +# Router nats for the region +##### 1E ##################### + +gcloud compute routers nats create "${NAT_CONFIG}" \ + --router="${ROUTER_NAME}" \ + --region="${REGION}" \ + --project="${PROJECT}" \ + --auto-allocate-nat-external-ips \ + --nat-all-subnet-ip-ranges \ + --enable-logging + +# Creating NAT [PREFIX-natconfig] in router [PREFIX-network]...done. + +##### STEP 2 ############################ +##### Create your cluster with xpk. ##### +######################################### + +##### 2A ##################### +# Export cluster and node pool arguments +##### 2A ##################### + +export CLUSTER_ARGUMENTS="\ + --network=${NETWORK_NAME} \ + --subnetwork=${SUBNET_NAME} \ + --scopes=storage-full,gke-default \ + --enable-ip-alias \ + --master-ipv4-cidr 172.16.0.32/28 \ + --cluster-ipv4-cidr=10.224.0.0/12 \ + --enable-private-nodes \ + --no-enable-master-authorized-networks \ +" + +export TPU_NODEPOOL_ARGUMENTS="\ + --scopes=storage-full,gke-default \ + --enable-gvnic \ + --max-pods-per-node 15 \ + --disk-type=pd-standard \ + --disk-size=50 \ +" +export NUMSLICES=4 # Use fewer slices than ultimately desired at first. + +##### 2B ##################### +# Activate a local XPK build +##### 2B ##################### + +source ~/xpk/bin/activate + +##### 2C ##################### +# Confirm that variables are correctly set: +##### 2C ##################### +echo xpk cluster create \ + --cluster="${CLUSTER}" \ + --device-type="${DEVICE_TYPE}" \ + --num-slices="${NUMSLICES}" \ + --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ + --custom-nodepool-arguments="${TPU_NODEPOOL_ARGUMENTS}" \ + --spot + +# example output ... +# xpk cluster create --cluster NAME \ +# --tpu-type=v5litepod-256 --num-slices=4 \ +# --host-maintenance-interval=PERIODIC \ # Use this argument if using on-demand rather than spot. +# --custom-cluster-arguments=" --network=NETWORK --subnetwork=SUBNET --scopes=storage-full,gke-default --enable-ip-alias --master-ipv4-cidr 172.16.0.32/28 --cluster-ipv4-cidr=10.224.0.0/12" +# --custom-tpu-nodepool-arguments=" --scopes=storage-full,gke-default --enable-gvnic --max-pods-per-node 15 --disk-size=50" + + +##### 2D ##################### +# Run Cluster Create. +##### 2D ##################### + +# Rerun create command to update the cluster (with a new slice size) or if the create command fails. +xpk cluster create \ + --cluster="${CLUSTER}" \ + --device-type="${DEVICE_TYPE}" \ + --num-slices="${NUMSLICES}" \ + --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ + --custom-nodepool-arguments="${TPU_NODEPOOL_ARGUMENTS}" \ + --spot + +# This process takes around 4 minutes with 4 slices of v5e-256. + +############################### +##### 2D - TIPS ############### +############################### + +# 1) View other examples of xpk here: https://github.com/google/maxtext/blob/main/xpk/README.md +# 2) xpk create command will update the cluster. If you adjust the num-slices and call create again, +# xpk will intelligently adjust the # number of node pools and execute the number of create / delete commands. +# 3) If xpk create command fails, the first step is to try running create again. + +##### STEP 3 ############################## +##### MOVE From KubeDNS to CoreDNS ######## +########################################### + +##### 3A ##################### +# Install jq command-line JSON processor +##### 3A ##################### + +sudo apt install jq -y + +##### 3B ##################### +# git clone coredns deployment repo. +##### 3B ##################### + +git clone https://github.com/coredns/deployment.git + +##### 3C ##################### +# Go to repo and deploy coredns. +##### 3C ##################### + +cd deployment/kubernetes +./deploy.sh | kubectl apply -f - + +# serviceaccount/coredns created +# clusterrole.rbac.authorization.k8s.io/system:coredns created +# clusterrolebinding.rbac.authorization.k8s.io/system:coredns created +# configmap/coredns created +# deployment.apps/coredns created +# Warning: resource services/kube-dns is missing the kubectl.kubernetes.io/last-applied-configuration annotation which is required by kubectl apply. kubectl apply should only be used on resources created declaratively by either kubectl create --save-config or kubectl apply. The missing annotation will be patched automatically. +# service/kube-dns configured + +##### 3D ##################### +# Scale down kube-dns-autoscaler +##### 3D ##################### + +kubectl scale deployment --replicas=0 kube-dns-autoscaler --namespace=kube-system + +# deployment.apps/kube-dns-autoscaler scaled + +##### 3E ##################### +# Scale down kube-dns +##### 3E ##################### + +kubectl scale deployment --replicas=0 kube-dns --namespace=kube-system + +# Warning: spec.template.metadata.annotations[scheduler.alpha.kubernetes.io/critical-pod]: non-functional in v1.16+; use the "priorityClassName" field instead +# Warning: spec.template.metadata.annotations[seccomp.security.alpha.kubernetes.io/pod]: non-functional in v1.27+; use the "seccompProfile" field instead +# deployment.apps/kube-dns scaled + +##### 3F ##################### +# Scale up core-dns +##### 3F ##################### + +# We recommend 15+ replicas +kubectl scale deployment coredns --replicas=15 -n kube-system + +# deployment.apps/coredns scaled + +##### 3G ##################### +# Verify that kubedns pods have stopped. +# Verify that coredns pods have started. +##### 3G ##################### + +watch 'kubectl get pods -n kube-system -o=wide | grep dns | grep -i kube' +# These should be terminated / disappear soon. +watch 'kubectl get pods -n kube-system -o=wide | grep dns | grep -i core' +# These should create at least one coredns pod. + +##### 3H ##################### +# Rerun xpk cluster create to plumb coredns changes to the cluster. +##### 3H ##################### + +# Go to the correct directory. +cd ../.. + +# Cluster create is the same command as run previously in step 2D. It will +# not recreate the cluster but just update it. +xpk cluster create \ + --cluster="${CLUSTER}" \ + --device-type="${DEVICE_TYPE}" \ + --num-slices="${NUMSLICES}" \ + --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ + --custom-nodepool-arguments="${TPU_NODEPOOL_ARGUMENTS}" \ + --spot + +##### STEP 5 ############################### +##### Begin Scale Up of GKE Cluster ######## +############################################ + +##### 5A ##################### +# TODO(USER): Set NUMSLICES to what you wish to scale to +##### 5A ##################### + +# Remember it is ok to incrementally scale if you wish. You can run cluster create +# repeatedly and adjust `--num-slices`. +export NUMSLICES=32 + +##### 5B ##################### +# Confirm that variables are correctly set: +##### 5B ##################### + +echo xpk cluster create \ + --cluster="${CLUSTER}" \ + --device-type="${DEVICE_TYPE}" \ + --num-slices="${NUMSLICES}" \ + --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ + --custom-nodepool-arguments="${TPU_NODEPOOL_ARGUMENTS}" \ + --spot + +# example output ... +# xpk cluster create --cluster NAME \ +# --tpu-type=v5litepod-256 --num-slices=64 \ +# --host-maintenance-interval=PERIODIC \ +# --custom-cluster-arguments=" --network=NETWORK --subnetwork=SUBNET --scopes=storage-full,gke-default --enable-ip-alias --enable-private-nodes --master-ipv4-cidr 172.16.0.32/28 --cluster-ipv4-cidr=10.224.0.0/12 --no-enable-master-authorized-networks" +# --custom-nodepool-arguments=" --scopes=storage-full,gke-default --enable-gvnic --max-pods-per-node 15 --disk-size=50" + +##### 5C ##################### +# Scale up to NUMSLICES (64 in the provided case) V5e-256s. +##### 5C ##################### + +xpk cluster create \ + --cluster="${CLUSTER}" \ + --device-type="${DEVICE_TYPE}" \ + --num-slices="${NUMSLICES}" \ + --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ + --custom-nodepool-arguments="${TPU_NODEPOOL_ARGUMENTS}" \ + --spot + +############################### +##### 5C - POTENTIAL ERRORS ### +############################### +# If you see failures, the first step is to retry the cluster create command. +# [XPK] Terminating all Create and Delete Nodepools processes since at least one failed. +# [XPK] Failure is NodepoolCreate-PREFIX-CLUSTER-np-53 and logfile /tmp/NodepoolCreate-PREFIX-CLUSTER-np-53-pqfbm5nl +# [XPK] Create and Delete Nodepools returned ERROR 1 +# [XPK] XPK failed, error code 1 + +# Auto-repairing nodepools on creation. +# Node pools can go into auto repair if there is an issue with their creation. For example, +# in the DEADLINE error, the node pool will automatically repair itself. This took ~20 minutes for the +# node pool to repair in the above example. You can continue with rerun cluster create +# commands while it is auto-repairing so that the rest of the cluster continues to +# be created while the repair occurs. + +# It took 20 minutes for the above internal example to go from 4 to 64 NPs with +# a series of increment cluster create steps. + +##### (OPTIONAL) STEP 6 (OPTIONAL) ######################################### +##### Run a simple multislice sample job ################################### +############################################################################ + +##### 6A ##################### +# Verify in Cloud Console that you have NUMSLICES node pools. +##### 6A ##################### + +# Verify in Cloud Console that you have NUMSLICES node pools in the NODES tab of Cloud Console. +echo "https://console.cloud.google.com/kubernetes/clusters/details/${REGION}/${CLUSTER}/details?project=${PROJECT}" + +##### 6B ##################### +# Run a multislice workload on all slices. +##### 6B ##################### + +# Set --scheduler=gke.io/high-throughput-scheduler to use the high throughput scheduler. + +xpk workload create \ + --scheduler=gke.io/high-throughput-scheduler \ + --workload xpk-test-workload --command "echo hello world" --cluster ${CLUSTER} \ + --device-type="${DEVICE_TYPE}" --num-slices=${NUMSLICES} + +# [XPK] Starting xpk +# [XPK] Working on args.project='PROJECT' and us-central2-b +# [XPK] Task: `Set Cluster` is implemented by `gcloud container clusters get-credentials CLUSTER --region=us-central2 --project=PROJECT && kubectl config view`, hiding output unless there is an error. +# [XPK] Task: `Set Cluster` succeeded. +# [XPK] Task: `Check if Workload Already Exists` is implemented by `kubectl get workloads -o=custom-columns='Jobset:.metadata.ownerReferences[0].name'`, hiding output unless there is an error. +# [XPK] Starting workload create +# [XPK] Task: `Creating Workload` is implemented by `kubectl apply -f /tmp/tmpk7599zd9`, streaming output live. +# [XPK] Waiting for `Creating Workload`, for 0 seconds +# [XPK] Waiting for `Creating Workload`, for 1 seconds +# jobset.jobset.x-k8s.io/xpk-test-workload created +# [XPK] Task: `Creating Workload` terminated with code `0` +# [XPK] Follow your workload here: WORKLOAD_LOGS_LINK +# [XPK] Exiting XPK cleanly + +# ########################################### +# #### Logs expected from the above link #### +# ########################################### +# 2023-10-03 11:34:54.621 PDT +# XPK Start: Tue Oct 3 18:34:54 UTC 2023 +# 2023-10-03 11:34:54.622 PDT +# hello world +# 2023-10-03 11:34:54.622 PDT +# XPK End: Tue Oct 3 18:34:54 UTC 2023 +# 2023-10-03 11:34:54.622 PDT +# EXIT_CODE=0 +# 2023-10-03 11:34:54.779 PDT +# XPK Start: Tue Oct 3 18:34:54 UTC 2023 +# 2023-10-03 11:34:54.779 PDT +# hello world +# ... + +##### 6C ##################### +# Verify workload. +# Use the link in the above "WORKLOAD_LOGS_LINK" view logs +# Run xpk workload list to view all workloads on the cluster. +##### 6C ##################### + +# Use the link in the above "WORKLOAD_LOGS_LINK" view logs. You should see +# the echo command in cloud logs. +xpk workload list \ + --cluster ${CLUSTER} + +############################### +##### 6C - TIPS ############### +############################### +# If you see `Not all pods are ready or succeeded` then the workload is still running. +# If you see `JobSet finished successfully` then the workload is finished successfully. + +# [XPK] Starting xpk +# Namespace(xpk_subcommands='workload', func=, xpk_workload_subcommands='list', cluster='CLUSTER', project=None, zone=None, dry_run=False) +# [XPK] Starting workload list +# [XPK] Working on args.project='PROJECT' and us-central2-b +# [XPK] Task: `Set Cluster` is implemented by `gcloud container clusters get-credentials CLUSTER --region=us-central2 --project=PROJECT && kubectl config view`, hiding output unless there is an error. +# [XPK] Task: `Set Cluster` succeeded. +# [XPK] Task: `List Jobs` is implemented by `kubectl get workloads -o=custom-columns='Jobset:.metadata.ownerReferences[0].name,Created Time:.metadata.creationTimestamp,Priority:.spec.priorityClassName,TPU VMs Needed:.spec.podSets[0].count,Last Status Verbose:.status.conditions[-1].message,Last Status:.status.conditions[-1].status,Last Transition:.status.conditions[-1].lastTransitionTime,Current Queue:.status.admission.clusterQueue,All Done:.status.reclaimablePods[0].count'`, streaming output live. +# [XPK] Waiting for `List Jobs`, for 0 seconds +# [XPK] Waiting for `List Jobs`, for 1 seconds +# Jobset Created Time Priority TPU VMs Needed Last Status Verbose Last Status Last Transition Current Queue All Done +# xpk-test-workload 2023-09-27T18:30:47Z medium 5120 Not all pods are ready or succeeded False 2023-09-27T18:30:47Z cluster-queue 192 +# [XPK] Task: `List Jobs` terminated with code `0` +# [XPK] Exiting XPK cleanly \ No newline at end of file diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_16.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_16.yaml new file mode 100644 index 000000000..b25fa4c62 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_16.yaml @@ -0,0 +1,28 @@ +# The name for the entire test suite run. +# Assumes n2-standard-4-64 x 16 replicas +suite_name: "llama-70b_replicas_16" +num_repeats: 20 + + +mesh_config: + mesh_axes: ["replica", "model"] + # Should match reference_sharding_path. + ici_parallelism: {"replica": 1, "model": 64} + dcn_parallelism: {"replica": 16} + +# Note: checkpoint_config field not specified. + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `RestoreAndBroadcastBenchmarkOptions` class + # associated with the `RestoreAndBroadcastBenchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true + reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt" + reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json" + use_load_and_broadcast: true \ No newline at end of file diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_16_no_broadcast.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_16_no_broadcast.yaml new file mode 100644 index 000000000..4aa2ec1f6 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_16_no_broadcast.yaml @@ -0,0 +1,28 @@ +# The name for the entire test suite run. +# Assumes n2-standard-4-64 x 16 replicas +suite_name: "llama-70b_replicas_16_no_broadcast" +num_repeats: 20 + + +mesh_config: + mesh_axes: ["replica", "model"] + # Should match reference_sharding_path. + ici_parallelism: {"replica": 1, "model": 64} + dcn_parallelism: {"replica": 16} + +# Note: checkpoint_config field not specified. + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `RestoreAndBroadcastBenchmarkOptions` class + # associated with the `RestoreAndBroadcastBenchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true + reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt" + reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json" + use_load_and_broadcast: false diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_2.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_2.yaml new file mode 100644 index 000000000..e82ef0b8e --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_2.yaml @@ -0,0 +1,29 @@ +# The name for the entire test suite run. +# Assumes n2-standard-4-64 x 2 replicas +suite_name: "llama-70b_replicas_2" +num_repeats: 20 + + +mesh_config: + mesh_axes: ["replica", "model"] + # Should match reference_sharding_path. + ici_parallelism: {"replica": 1, "model": 64} + dcn_parallelism: {"replica": 2} + +# Note: checkpoint_config field not specified. + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `RestoreAndBroadcastBenchmarkOptions` class + # associated with the `RestoreAndBroadcastBenchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true + reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt" + reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json" + use_load_and_broadcast: true + # enable_trace: true \ No newline at end of file diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_2_no_broadcast.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_2_no_broadcast.yaml new file mode 100644 index 000000000..e224abdd6 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_2_no_broadcast.yaml @@ -0,0 +1,28 @@ +# The name for the entire test suite run. +# Assumes n2-standard-4-64 x 2 replicas +suite_name: "llama-70b_replicas_2_no_broadcast" +num_repeats: 20 + + +mesh_config: + mesh_axes: ["replica", "model"] + # Should match reference_sharding_path. + ici_parallelism: {"replica": 1, "model": 64} + dcn_parallelism: {"replica": 2} + +# Note: checkpoint_config field not specified. + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `RestoreAndBroadcastBenchmarkOptions` class + # associated with the `RestoreAndBroadcastBenchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true + reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt" + reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json" + use_load_and_broadcast: false diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_32.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_32.yaml new file mode 100644 index 000000000..04994d343 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_32.yaml @@ -0,0 +1,28 @@ +# The name for the entire test suite run. +# Assumes n2-standard-4-64 x 32 replicas +suite_name: "llama-70b_replicas_32" +num_repeats: 20 + + +mesh_config: + mesh_axes: ["replica", "model"] + # Should match reference_sharding_path. + ici_parallelism: {"replica": 1, "model": 64} + dcn_parallelism: {"replica": 32} + +# Note: checkpoint_config field not specified. + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `RestoreAndBroadcastBenchmarkOptions` class + # associated with the `RestoreAndBroadcastBenchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true + reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt" + reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json" + use_load_and_broadcast: true diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_32_no_broadcast.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_32_no_broadcast.yaml new file mode 100644 index 000000000..ec8a63eef --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_32_no_broadcast.yaml @@ -0,0 +1,28 @@ +# The name for the entire test suite run. +# Assumes n2-standard-4-64 x 32 replicas +suite_name: "llama-70b_replicas_32_no_broadcast" +num_repeats: 20 + + +mesh_config: + mesh_axes: ["replica", "model"] + # Should match reference_sharding_path. + ici_parallelism: {"replica": 1, "model": 64} + dcn_parallelism: {"replica": 32} + +# Note: checkpoint_config field not specified. + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `RestoreAndBroadcastBenchmarkOptions` class + # associated with the `RestoreAndBroadcastBenchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true + reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt" + reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json" + use_load_and_broadcast: false diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_4.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_4.yaml new file mode 100644 index 000000000..425ba3d0a --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_4.yaml @@ -0,0 +1,29 @@ +# The name for the entire test suite run. +# Assumes n2-standard-4-64 x 4 replicas +suite_name: "llama-70b_replicas_4" +num_repeats: 20 + + +mesh_config: + mesh_axes: ["replica", "model"] + # Should match reference_sharding_path. + ici_parallelism: {"replica": 1, "model": 64} + dcn_parallelism: {"replica": 4} + +# Note: checkpoint_config field not specified. + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `RestoreAndBroadcastBenchmarkOptions` class + # associated with the `RestoreAndBroadcastBenchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true + reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt" + reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json" + use_load_and_broadcast: true + enable_trace: true \ No newline at end of file diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_4_no_broadcast.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_4_no_broadcast.yaml new file mode 100644 index 000000000..809371b55 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/restore_and_broadcast/llama-70b_replicas_4_no_broadcast.yaml @@ -0,0 +1,28 @@ +# The name for the entire test suite run. +# Assumes n2-standard-4-64 x 4 replicas +suite_name: "llama-70b_replicas_4_no_broadcast" +num_repeats: 20 + + +mesh_config: + mesh_axes: ["replica", "model"] + # Should match reference_sharding_path. + ici_parallelism: {"replica": 1, "model": 64} + dcn_parallelism: {"replica": 4} + +# Note: checkpoint_config field not specified. + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `RestoreAndBroadcastBenchmarkOptions` class + # associated with the `RestoreAndBroadcastBenchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true + reference_checkpoint_path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_4-8-4_subchunked/ckpt" + reference_sharding_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json" + use_load_and_broadcast: false \ No newline at end of file diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/device_mesh.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/device_mesh.py index a2d4fde66..2ed1ddceb 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/device_mesh.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/device_mesh.py @@ -49,13 +49,25 @@ def create_mesh(config: configs.MeshConfig) -> jax.sharding.Mesh: devices_array = mesh_utils.create_device_mesh(ici_shape, devices) logging.info( 'Creating mesh with axes: %s', - {axis: dim for axis, dim in zip(config.mesh_axes, devices_array.shape)}, + dict(zip(config.mesh_axes, devices_array.shape)), ) return jax.sharding.Mesh(devices_array, config.mesh_axes) else: logging.info('Creating hybrid mesh.') dcn_shape = [dcn_parallelism.get(axis, 1) for axis in config.mesh_axes] + if jax.default_backend() == 'cpu': + devices = jax.devices() + # Sort devices by process index to ensure a predictable global grid + devices = sorted(devices, key=lambda d: d.process_index) + global_shape = tuple(d * i for d, i in zip(dcn_shape, ici_shape)) + devices_array = np.array(devices).reshape(global_shape) + logging.info( + 'Creating CPU-only hybrid mesh with axes: %s', + dict(zip(config.mesh_axes, devices_array.shape)), + ) + return jax.sharding.Mesh(devices_array, config.mesh_axes) + # --- Validation --- if config.process_is_granule: process_count = jax.process_count() @@ -99,6 +111,6 @@ def create_mesh(config: configs.MeshConfig) -> jax.sharding.Mesh: ) logging.info( 'Creating mesh with axes: %s', - {axis: dim for axis, dim in zip(config.mesh_axes, devices_array.shape)}, + dict(zip(config.mesh_axes, devices_array.shape)), ) return jax.sharding.Mesh(devices_array, config.mesh_axes) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/device_mesh_test.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/device_mesh_test.py index 533840b74..ce3f69dff 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/device_mesh_test.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/device_mesh_test.py @@ -23,8 +23,9 @@ class FakeDevice: - def __init__(self, device_id): + def __init__(self, device_id, process_index=0): self.id = device_id + self.process_index = process_index def __str__(self): return f"FakeDevice(id={self.id})" @@ -48,6 +49,11 @@ def setUp(self): self.mock_process_count = self.enter_context( mock.patch.object(jax, "process_count", autospec=True) ) + # Force a non-CPU backend so tests exercise the hybrid-mesh path rather + # than the CPU shortcut that bypasses all validation logic. + self.mock_default_backend = self.enter_context( + mock.patch.object(jax, "default_backend", return_value="tpu") + ) def test_create_mesh_success(self): self.mock_devices.return_value = [FakeDevice(device_id=i) for i in range(8)] diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py index b2bc0a651..c470b0f85 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py @@ -63,7 +63,29 @@ def _init_jax_distributed(): """Initializes JAX distributed system if not managed by XManager.""" try: - jax.distributed.initialize() + jax_platforms = os.environ.get('JAX_PLATFORMS') + jax_coordinator_address = os.environ.get('JAX_COORDINATOR_ADDRESS') + jax_process_id = os.environ.get('JAX_PROCESS_ID') + jax_num_processes = os.environ.get('JAX_NUM_PROCESSES') + jax_coordinator_port = os.environ.get('JAX_COORDINATOR_PORT') + logging.info('JAX_PLATFORMS: %s', jax_platforms) + logging.info( + 'JAX_COORDINATOR_ADDRESS: %s', + jax_coordinator_address, + ) + logging.info('JAX_PROCESS_ID: %s', jax_process_id) + logging.info('JAX_NUM_PROCESSES: %s', jax_num_processes) + logging.info('JAX_COORDINATOR_PORT: %s', jax_coordinator_port) + if jax_num_processes is not None: + jax_num_processes = int(jax_num_processes) + if jax_process_id is not None: + jax_process_id = int(jax_process_id) + jax.distributed.initialize( + coordinator_address=jax_coordinator_address, + num_processes=jax_num_processes, + process_id=jax_process_id, + initialization_timeout=600, + ) logging.info('JAX distributed system initialized.') except Exception as e: # pylint: disable=broad-exception-caught logging.warning( @@ -71,9 +93,11 @@ def _init_jax_distributed(): 'This is expected if running in a single-process environment. ' 'Continuing as single-process.', e, - exc_info=False, + exc_info=True, ) + logging.info('Default JAX backend: %s', jax.default_backend()) + logging.info('Available devices: %s', jax.devices()) logging.info('JAX process index: %d', jax.process_index()) logging.info('JAX process count: %d', jax.process_count()) logging.info('JAX device count: %d', jax.device_count()) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py index 16cfc272e..93dfeaf69 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py @@ -57,6 +57,7 @@ class BenchmarkOptions(benchmarks_core.BenchmarkOptions): restore_concurrent_gb: The number of concurrent GB to use for restoring. metric_tracemalloc_enabled: Whether to enable tracemalloc metric. metric_tensorstore_enabled: Whether to enable tensorstore metric. + use_load_and_broadcast: Whether to use load and broadcast. use_replica_parallel: Whether to use replica parallel. enable_replica_parallel_separate_folder: Whether to enable replica parallel separate folder. @@ -71,6 +72,7 @@ class BenchmarkOptions(benchmarks_core.BenchmarkOptions): restore_concurrent_gb: int | None | Sequence[int | None] = None metric_tracemalloc_enabled: bool = False metric_tensorstore_enabled: bool = False + use_load_and_broadcast: bool | Sequence[bool] = False use_replica_parallel: bool | Sequence[bool] = False enable_replica_parallel_separate_folder: bool | Sequence[bool] = False chunk_byte_size: int | None | Sequence[int | None] = None @@ -103,6 +105,7 @@ def context(self) -> ocp.Context: else None, ), loading=ocp.options.ArrayOptions.Loading( + use_load_and_broadcast=self.use_load_and_broadcast, concurrent_bytes=self.restore_concurrent_gb * 1024**3 if self.restore_concurrent_gb is not None else None, diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/resharding_benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/resharding_benchmark.py index 1baf226f4..936df3e76 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/resharding_benchmark.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/resharding_benchmark.py @@ -101,7 +101,7 @@ def test_fn( checkpoint_generation.get_abstract_state_from_sharding_config( reference_sharding_path, metadata.metadata, - devices=context.mesh.devices, + devices=jax.devices(), ) ) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark.py new file mode 100644 index 000000000..296569850 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark.py @@ -0,0 +1,196 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmarks for V1 free functions.""" + +from __future__ import annotations + +import dataclasses +import pprint +from typing import Any + +from absl import logging +from etils import epath +import jax +from orbax.checkpoint import v1 as ocp +from orbax.checkpoint._src.multihost import multislice +from orbax.checkpoint._src.testing.benchmarks.core import checkpoint_generation +from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core +from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib +from orbax.checkpoint._src.testing.benchmarks.v1 import benchmark + + +# ============================================================================== +# 1. Define the Options Dataclass for this specific benchmark +# ============================================================================== +@dataclasses.dataclass(frozen=True) +class RestoreAndBroadcastBenchmarkOptions(benchmark.BenchmarkOptions): + """Configuration options for benchmarks targeting restore-and-broadcast. + + See parent class. + + Attributes: + reference_checkpoint_path: The path to the reference checkpoint. This + dictates the structure of the checkpoint to be restored. + reference_sharding_path: The path to the reference sharding config. This + dictates the shardings used for restoration. Note that this sharding + config is for a *single replica*. The benchmark should be configured with + DCN parallelism, and the test harness will replicate the sharding config + to the multiple replicas dictated by the mesh. + use_load_and_broadcast: Whether to use the load_and_broadcast API. + """ + + reference_checkpoint_path: str | None = None + reference_sharding_path: str | None = None + use_load_and_broadcast: bool = True + + def is_valid(self) -> bool: + if self.reference_checkpoint_path is None: + return False + if self.reference_sharding_path is None: + return False + return super().is_valid() + + +def _get_single_replica_abstract_state( + context: ocp.Context, + global_mesh: jax.sharding.Mesh, + *, + reference_checkpoint_path: epath.Path, + reference_sharding_path: epath.Path, +): + """Returns the abstract state for a single replica.""" + with ocp.Context(context=context): + metadata = ocp.pytree_metadata(reference_checkpoint_path) + # Abstract tree has shardings on a single replica. + return checkpoint_generation.get_abstract_state_from_sharding_config( + reference_sharding_path, + metadata.metadata, + devices=multislice.replica_devices( + global_mesh, replica_id=0, replica_axis_index=0 + ).tolist(), + ) + + +def _get_abstract_state( + context: ocp.Context, + global_mesh: jax.sharding.Mesh, + single_replica_abstract_state: Any, +): + """Returns the abstract state for all replicas.""" + with ocp.Context(context=context): + # Blow shardings up to all replicas. + def _multi_replica_sharding(abstract_arr: jax.ShapeDtypeStruct): + logging.info( + "Original (single-replica) sharding: %s", abstract_arr.sharding + ) + assert isinstance(abstract_arr.sharding, jax.sharding.NamedSharding) + single_replica_mesh = abstract_arr.sharding.mesh + single_replica_partition_spec = abstract_arr.sharding.spec + multi_replica_sharding = jax.sharding.NamedSharding( + jax.sharding.Mesh( + devices=global_mesh.devices.reshape( + -1, *single_replica_mesh.devices.shape + ), + axis_names=["replica", *single_replica_mesh.axis_names], + ), + spec=jax.sharding.PartitionSpec(*single_replica_partition_spec), + ) + logging.info("Multi-replica sharding: %s", multi_replica_sharding) + return jax.ShapeDtypeStruct( + shape=abstract_arr.shape, + dtype=abstract_arr.dtype, + sharding=multi_replica_sharding, + ) + + return jax.tree.map( + _multi_replica_sharding, + single_replica_abstract_state, + ) + + +# ============================================================================== +# 2. Implement the Benchmark Generator +# ============================================================================== +@benchmarks_core.benchmark_options(RestoreAndBroadcastBenchmarkOptions) +class RestoreAndBroadcastBenchmark(benchmarks_core.BenchmarksGenerator): + """A concrete generator for restore and broadcast benchmarks.""" + + def test_fn( + self, context: benchmarks_core.TestContext + ) -> benchmarks_core.TestResult: + """The core test logic for a single save/restore cycle. + + This function is called for each combination of options generated by the + framework. It uses the `context.options` to configure the handler + dynamically for each run. + + Args: + context: The test context containing the pytree, path, and options. + + Returns: + The test result containing the metrics. + """ + metrics = metric_lib.Metrics() + assert context.pytree is None + options = context.options + assert isinstance(options, RestoreAndBroadcastBenchmarkOptions) + assert options.reference_checkpoint_path is not None + assert options.reference_sharding_path is not None + assert context.mesh is not None + + logging.info("Benchmark options: %s", pprint.pformat(options)) + metrics_to_measure = benchmark.get_metrics_to_measure(options) + + reference_checkpoint_path = epath.Path( + options.reference_checkpoint_path + ) + reference_sharding_path = epath.Path( + options.reference_sharding_path + ) + + if context.mesh.devices.ndim != 2: + raise ValueError( + "Found mesh with unexpected number of dimensions:" + f" {context.mesh.ndim}" + ) + if [str(axis) for axis in context.mesh.axis_names] != ["replica", "model"]: + raise ValueError( + f"Found mesh with unexpected axis names: {context.mesh.axis_names}" + ) + + single_replica_abstract_pytree = _get_single_replica_abstract_state( + context=options.context, + global_mesh=context.mesh, + reference_checkpoint_path=reference_checkpoint_path, + reference_sharding_path=reference_sharding_path, + ) + abstract_pytree = _get_abstract_state( + context=options.context, + global_mesh=context.mesh, + single_replica_abstract_state=single_replica_abstract_pytree, + ) + + with ocp.Context(context=options.context): + if options.enable_trace: + jax.profiler.start_trace(context.path / "trace_load") + with metrics.measure("load", metrics_to_measure): + restored_pytree = ocp.load_pytree( + reference_checkpoint_path, abstract_pytree + ) + benchmark.clear_pytree(restored_pytree) + if options.enable_trace: + jax.profiler.stop_trace() + + return benchmarks_core.TestResult(metrics=metrics) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark_test.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark_test.py new file mode 100644 index 000000000..40d3fe49a --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/restore_and_broadcast_benchmark_test.py @@ -0,0 +1,216 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +from absl.testing import absltest +from absl.testing import parameterized +from etils import epath +import jax +import jax.numpy as jnp +import numpy as np +from orbax.checkpoint import v1 as ocp +from orbax.checkpoint._src.testing.benchmarks.core import configs as benchmarks_configs +from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core +from orbax.checkpoint._src.testing.benchmarks.v1 import restore_and_broadcast_benchmark + + +RestoreAndBroadcastBenchmarkOptions = ( + restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmarkOptions +) +RestoreAndBroadcastBenchmark = ( + restore_and_broadcast_benchmark.RestoreAndBroadcastBenchmark +) + +_REQUIRED_DEVICE_COUNT = 16 + + +class RestoreAndBroadcastBenchmarkTest(parameterized.TestCase): + + def setUp(self): + self._prev_xla_flags = os.environ.get('XLA_FLAGS') + os.environ['XLA_FLAGS'] = ( + self._prev_xla_flags or '' + ) + ' --xla_force_host_platform_device_count=16' + super().setUp() + if jax.local_device_count() != _REQUIRED_DEVICE_COUNT: + self.skipTest( + f'Test requires {_REQUIRED_DEVICE_COUNT} local devices, but only' + f' {jax.local_device_count()} are available. Set XLA_FLAGS=' + f'"--xla_force_host_platform_device_count={_REQUIRED_DEVICE_COUNT}"' + ' before JAX initializes.' + ) + self.directory = epath.Path(self.create_tempdir().full_path) + + def tearDown(self): + if self._prev_xla_flags is None: + os.environ.pop('XLA_FLAGS', None) + else: + os.environ['XLA_FLAGS'] = self._prev_xla_flags + super().tearDown() + + @parameterized.parameters( + dict( + options=RestoreAndBroadcastBenchmarkOptions( + reference_checkpoint_path='ckpt_path', + reference_sharding_path='sharding_path', + ), + expected_len=1, + ), + ) + def test_generate_benchmarks(self, options, expected_len): + generator = RestoreAndBroadcastBenchmark( + checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})], + options=options, + ) + benchmarks = generator.generate() + self.assertLen(benchmarks, expected_len) + for benchmark in benchmarks: + self.assertIsInstance( + benchmark.options, RestoreAndBroadcastBenchmarkOptions + ) + + def test_get_abstract_state(self): + # Setup real checkpoint and sharding config + pytree = {'a': jnp.arange(32), 'b': {'c': jnp.ones((8, 8))}} + ref_ckpt_path = self.directory / 'ref_ckpt' + ocp.save_pytree(ref_ckpt_path, pytree) + + sharding_config = { + 'a': { + 'shape': [32], + 'dtype': 'int32', + 'sharding': { + 'mesh': {'shape': [4], 'axes': ['model']}, + 'spec': ['model'], + }, + }, + 'b.c': { + 'shape': [8, 8], + 'dtype': 'float32', + 'sharding': { + 'mesh': {'shape': [4], 'axes': ['model']}, + 'spec': [None, 'model'], + }, + }, + } + sharding_config_path = self.directory / 'sharding_config.json' + sharding_config_path.write_text(json.dumps(sharding_config)) + global_mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape((4, 4)), ('replica', 'model') + ) + + single_replica_abstract_pytree = ( + restore_and_broadcast_benchmark._get_single_replica_abstract_state( + context=ocp.Context(), + global_mesh=global_mesh, + reference_checkpoint_path=ref_ckpt_path, + reference_sharding_path=sharding_config_path, + ) + ) + self.assertEqual( + {'model': 4}, single_replica_abstract_pytree['a'].sharding.mesh.shape + ) + self.assertEqual( + jax.sharding.PartitionSpec('model'), + single_replica_abstract_pytree['a'].sharding.spec, + ) + self.assertEqual( + {'model': 4}, + single_replica_abstract_pytree['b']['c'].sharding.mesh.shape, + ) + self.assertEqual( + jax.sharding.PartitionSpec(None, 'model'), + single_replica_abstract_pytree['b']['c'].sharding.spec, + ) + + abstract_pytree = restore_and_broadcast_benchmark._get_abstract_state( + context=ocp.Context(), + global_mesh=global_mesh, + single_replica_abstract_state=single_replica_abstract_pytree, + ) + self.assertEqual( + {'replica': 4, 'model': 4}, abstract_pytree['a'].sharding.mesh.shape + ) + self.assertEqual( + jax.sharding.PartitionSpec('model'), abstract_pytree['a'].sharding.spec + ) + self.assertEqual( + {'replica': 4, 'model': 4}, + abstract_pytree['b']['c'].sharding.mesh.shape, + ) + self.assertEqual( + jax.sharding.PartitionSpec(None, 'model'), + abstract_pytree['b']['c'].sharding.spec, + ) + + def test_benchmark_test_fn(self): + # Setup real checkpoint and sharding config + pytree = {'a': jnp.arange(32), 'b': {'c': jnp.ones((8, 8))}} + ref_ckpt_path = self.directory / 'ref_ckpt' + ocp.save_pytree(ref_ckpt_path, pytree) + + sharding_config = { + 'a': { + 'shape': [32], + 'dtype': 'int32', + 'sharding': { + 'mesh': {'shape': [4], 'axes': ['model']}, + 'spec': ['model'], + }, + }, + 'b.c': { + 'shape': [8, 8], + 'dtype': 'float32', + 'sharding': { + 'mesh': {'shape': [4], 'axes': ['model']}, + 'spec': [None, 'model'], + }, + }, + } + sharding_config_path = self.directory / 'sharding_config.json' + sharding_config_path.write_text(json.dumps(sharding_config)) + + options = RestoreAndBroadcastBenchmarkOptions( + reference_checkpoint_path=str(ref_ckpt_path), + reference_sharding_path=str(sharding_config_path), + ) + global_mesh = jax.sharding.Mesh( + np.array(jax.devices()).reshape((4, 4)), ('replica', 'model') + ) + + context = benchmarks_core.TestContext( + pytree=None, + path=self.directory / 'test_run', + options=options, + mesh=global_mesh, + ) + self.assertTrue(options.use_load_and_broadcast) + self.assertTrue( + options.context.array_options.loading.use_load_and_broadcast + ) + + generator = RestoreAndBroadcastBenchmark( + checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})], + options=options, + ) + result = generator.test_fn(context) + self.assertIsInstance(result, benchmarks_core.TestResult) + metrics = result.metrics.results + self.assertIn('load_time_duration', metrics) + + +if __name__ == '__main__': + absltest.main() diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/Dockerfile b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/Dockerfile index c88b5c52a..d5d0d795a 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/Dockerfile +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/Dockerfile @@ -1,5 +1,5 @@ # Base image argument (defaulting to slim python image) -ARG BASE_IMAGE=python:3.11-slim +ARG BASE_IMAGE=python:3.13-slim FROM $BASE_IMAGE WORKDIR /app @@ -56,7 +56,15 @@ ARG JAX_VERSION=newest ARG DEVICE=tpu # Install GCSFS and Portpicker -RUN pip install --no-cache-dir gcsfs portpicker clu tensorflow +RUN pip install --no-cache-dir gcsfs portpicker clu + +RUN if [ "$DEVICE" = "gpu" ]; then \ + pip install --no-cache-dir tensorflow; \ + elif [ "$DEVICE" = "tpu" ]; then \ + pip install --no-cache-dir tensorflow -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \ + else \ + pip install --no-cache-dir tensorflow-cpu; \ + fi # Install requirements from repo root if it exists RUN if [ -f "requirements.txt" ]; then pip install --no-cache-dir -r requirements.txt; fi @@ -68,13 +76,15 @@ RUN if [ "$JAX_VERSION" = "newest" ]; then \ elif [ "$DEVICE" = "tpu" ]; then \ pip install --no-cache-dir -U "jax[k8s,tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \ else \ - pip install --no-cache-dir -U "jax[k8s]" jaxlib; \ + pip install --no-cache-dir -U "jax[k8s,cpu]" jaxlib; \ fi \ elif [ "$JAX_VERSION" = "nightly" ]; then \ if [ "$DEVICE" = "gpu" ]; then \ pip install --no-cache-dir -U --pre "jax[k8s,cuda12]" jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/; \ elif [ "$DEVICE" = "tpu" ]; then \ pip install --no-cache-dir -U --pre "jax[k8s,tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/; \ + else \ + pip install --no-cache-dir -U --pre "jax[k8s,cpu]" jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/; \ fi \ else \ # Specific version @@ -83,7 +93,7 @@ RUN if [ "$JAX_VERSION" = "newest" ]; then \ elif [ "$DEVICE" = "tpu" ]; then \ pip install --no-cache-dir "jax[k8s,tpu]==${JAX_VERSION}" "jaxlib==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html; \ else \ - pip install --no-cache-dir "jax[k8s]==${JAX_VERSION}" "jaxlib==${JAX_VERSION}"; \ + pip install --no-cache-dir "jax[k8s,cpu]==${JAX_VERSION}" "jaxlib==${JAX_VERSION}"; \ fi \ fi diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/build_image.sh b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/build_image.sh index 486b95983..02e01aa8c 100755 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/build_image.sh +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/build_image.sh @@ -58,7 +58,7 @@ fi # Set default base image if not provided if [[ -z "$BASE_IMAGE" ]]; then - BASE_IMAGE="python:3.11-slim" + BASE_IMAGE="python:3.13-slim" fi SCRIPT_DIR="$(dirname "$(realpath "$0")")" @@ -118,6 +118,13 @@ for t in "${tags[@]}"; do build_tag_args+=(-t "${IMAGE_REPO}:${t}") done +# Create a temporary directory to act as the clean build context +BUILD_CONTEXT=$(mktemp -d) +# Ensure the temporary directory is cleaned up when the script exits (success or fail) +trap 'rm -rf "$BUILD_CONTEXT"' EXIT +cp "${DOCKERFILE_PATH}" "$BUILD_CONTEXT/Dockerfile" +cd "$BUILD_CONTEXT" + # Build with local Docker echo "Building with previously installed Docker..." declare -a build_args=() @@ -133,7 +140,7 @@ build_args+=( ) build_args+=("${build_tag_args[@]}") build_args+=( - "-f" "${DOCKERFILE_PATH}" + "-f" "Dockerfile" "." ) docker build "${build_args[@]}" diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py index 9987ad7b3..0ea62316d 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py @@ -31,8 +31,10 @@ from collections.abc import Sequence import datetime +import enum import itertools import os +import re import subprocess import sys import uuid @@ -552,25 +554,98 @@ def create_cluster() -> None: Console.print_success(f'Cluster {_CLUSTER_NAME.value} created.') +class HardwareType(enum.Enum): + TPU = 'tpu' + GPU = 'gpu' + CPU = 'cpu' + UNKNOWN = 'unknown' + + +def get_hardware_type( + tpu_type: str | None, device_type: str | None +) -> HardwareType: + """Categorizes a compute instance string into a HardwareType enum.""" + tpu_type = tpu_type.lower().strip() if tpu_type else '' + device_type = device_type.lower().strip() if device_type else '' + device_type = tpu_type or device_type + + # 1. Check for TPU + # Matches GCP TPU names like v2, v3, v4, v5e, v5p, or explicit 'tpu' + if re.search(r'\bv[2-5][a-z]?\b', device_type) or 'tpu' in device_type: + return HardwareType.TPU + + # 2. Check for GPU + # Matches common accelerator names (h100, a100, l4) or GPU (a2, a3, g2, p4) + gpu_chips = ['h100', 'a100', 'v100', 'p100', 't4', 'l4', 'k80'] + gpu_instances = [r'^a[2-3]-', r'^g[2]-', r'^p[3-5]\.', r'^g[4-5]\.'] + + if ( + any(chip in device_type for chip in gpu_chips) + or any(re.match(pattern, device_type) for pattern in gpu_instances) + or 'gpu' in device_type + ): + return HardwareType.GPU + + # 3. Check for CPU + # Matches GCP (n1, n2, e2, c2, c3, m1) and AWS (t2, m5, c5) + cpu_instances = [r'^[necm][1-4]-', r'^[tcmri][2-8]\.'] + if ( + any(re.match(pattern, device_type) for pattern in cpu_instances) + or 'cpu' in device_type + ): + return HardwareType.CPU + + return HardwareType.UNKNOWN + + def construct_workload_command( *, + workload_name: str, config_file: str, output_directory: str, run_id: str, enable_pathways: bool, benchmark_binary_path: str, + hardware_type: HardwareType, v_level: int | None, ) -> str: """Constructs the command to run inside the workload.""" - # Environment variables + # Environment variables. if enable_pathways: + # Pathways (create-pathways) runs user Python on the head pod, which has no + # direct TPU hardware. We must use JAX_PLATFORMS=proxy so JAX routes through + # the Pathways server that manages the TPU workers. + # NOTE: We deliberately do NOT set ENABLE_PATHWAYS_PERSISTENCE=1 because + # that flag causes pathwaysutils to intercept all Orbax checkpoint reads + # via orbax_handler.py, which expects raw zarr3 files but our checkpoints + # use OCDBT wrapping. Orbax handles checkpoint I/O perfectly well on its + # own. env_vars = [ 'export JAX_PLATFORMS=proxy', - 'export ENABLE_PATHWAYS_PERSISTENCE=1', 'export ENABLE_PJRT_COMPATIBILITY=true', ] else: - env_vars = ['export JAX_PLATFORMS=tpu,cpu'] + if hardware_type == HardwareType.TPU: + env_vars = ['export JAX_PLATFORMS=tpu,cpu'] + elif hardware_type == HardwareType.GPU: + env_vars = ['export JAX_PLATFORMS=gpu,cpu'] + elif hardware_type == HardwareType.CPU: + fqdn_address = f'{workload_name}-slice-job-0-0.{workload_name}.default.svc.cluster.local' + env_vars = [ + 'export JAX_PLATFORMS=cpu', + 'export JAX_NUM_PROCESSES=$JAX_PROCESS_COUNT', + ( + 'export JAX_PROCESS_ID=$(($JOB_INDEX * $PROCESSES_IN_JOB +' + ' $JOB_COMPLETION_INDEX))' + ), + ( + 'export JAX_COORDINATOR_ADDRESS=$(if [ "$JAX_PROCESS_ID" = "0" ];' + f' then echo "localhost"; else echo "{fqdn_address}"; fi):1234' + ), + 'export XLA_FLAGS="--xla_cpu_collective_timeout_seconds=600"', + ] + else: + raise ValueError(f'Unsupported hardware type: {hardware_type}') env_cmd = ' && '.join(env_vars) + ' && ' if env_vars else '' @@ -586,6 +661,8 @@ def construct_workload_command( python_args.append(f'--v={v_level}') python_cmd = ' '.join(python_args) + if hardware_type == HardwareType.CPU: + python_cmd += ' --jax_cpu_collectives_implementation=gloo' if enable_pathways: python_cmd = ( 'python3 -c "import pathwaysutils;' @@ -600,6 +677,13 @@ def construct_xpk_command( workload_name: str, workload_command: str ) -> Sequence[str]: """Constructs the XPK CLI command.""" + # In colocated Python mode (enable_pathways=True), we still use + # 'create-pathways' because the cluster was provisioned with Pathways and + # XPK needs this to schedule through the Pathways resource manager. + # However we don't pass --server-image/--proxy-server-image (those are for + # the full proxy sidecar architecture). Instead we run with JAX_PLATFORMS=tpu + # and just a plain --docker-image, which colocates Python directly on the + # TPU pods without going through the IFRT proxy server. base_cmd = [ _XPK_PATH.value, 'workload', @@ -847,12 +931,15 @@ def main(argv: Sequence[str]) -> None: # 5. Construct Commands Console.print_step(4, 6, 'Constructing Commands') + hardware_type = get_hardware_type(_TPU_TYPE.value, _DEVICE_TYPE.value) workload_cmd = construct_workload_command( + workload_name=workload_name_base, config_file=remote_config_path, output_directory=_OUTPUT_DIRECTORY.value, run_id=run_id, enable_pathways=_ENABLE_PATHWAYS.value, benchmark_binary_path=_BENCHMARK_BINARY_PATH.value, + hardware_type=hardware_type, v_level=_V_LEVEL.value, ) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk_test.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk_test.py new file mode 100644 index 000000000..602ed9486 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk_test.py @@ -0,0 +1,182 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl import flags +from absl.testing import absltest +from absl.testing import parameterized +from orbax.checkpoint._src.testing.benchmarks.xpk import launch_xpk + +HardwareType = launch_xpk.HardwareType + +# satisfy required flags and validators +flags.FLAGS.set_default('cluster_name', 'dummy-cluster') +flags.FLAGS.set_default('config_file', __file__) +flags.FLAGS.set_default('output_directory', 'gs://dummy-bucket') + + +class LaunchXpkTest(parameterized.TestCase): + + @parameterized.named_parameters( + # TPU Cases + dict( + testcase_name='tpu_v2', + tpu='v2', + device=None, + expected=HardwareType.TPU, + ), + dict( + testcase_name='tpu_v3', tpu='v3', device='', expected=HardwareType.TPU + ), + dict( + testcase_name='tpu_v4', + tpu='v4-8', + device=None, + expected=HardwareType.TPU, + ), + dict( + testcase_name='tpu_v5e', + tpu='v5e-4', + device=None, + expected=HardwareType.TPU, + ), + dict( + testcase_name='tpu_v5p', + tpu='v5p-8', + device=None, + expected=HardwareType.TPU, + ), + dict( + testcase_name='tpu_explicit', + tpu=None, + device='tpu-v5-litepod-8', + expected=HardwareType.TPU, + ), + dict( + testcase_name='tpu_mixed', + tpu='v5p', + device='something-else', + expected=HardwareType.TPU, + ), + # GPU Cases + dict( + testcase_name='gpu_h100', + tpu=None, + device='h100', + expected=HardwareType.GPU, + ), + dict( + testcase_name='gpu_a100', + tpu='', + device='a100', + expected=HardwareType.GPU, + ), + dict( + testcase_name='gpu_v100', + tpu=None, + device='v100', + expected=HardwareType.GPU, + ), + dict( + testcase_name='gpu_l4', + tpu=None, + device='l4', + expected=HardwareType.GPU, + ), + dict( + testcase_name='gpu_a2_instance', + tpu=None, + device='a2-highgpu-1g', + expected=HardwareType.GPU, + ), + dict( + testcase_name='gpu_g2_instance', + tpu=None, + device='g2-standard-4', + expected=HardwareType.GPU, + ), + dict( + testcase_name='gpu_p4_instance', + tpu=None, + device='p4.2xlarge', + expected=HardwareType.GPU, + ), + dict( + testcase_name='gpu_explicit', + tpu=None, + device='nvidia-gpu', + expected=HardwareType.GPU, + ), + # CPU Cases + dict( + testcase_name='cpu_n1', + tpu=None, + device='n1-standard-4', + expected=HardwareType.CPU, + ), + dict( + testcase_name='cpu_n2', + tpu=None, + device='n2-standard-32', + expected=HardwareType.CPU, + ), + dict( + testcase_name='cpu_c3', + tpu=None, + device='c3-standard-4', + expected=HardwareType.CPU, + ), + dict( + testcase_name='cpu_m1', + tpu=None, + device='m1-ultramem-40', + expected=HardwareType.CPU, + ), + dict( + testcase_name='cpu_t2_aws', + tpu=None, + device='t2.micro', + expected=HardwareType.CPU, + ), + dict( + testcase_name='cpu_explicit', + tpu=None, + device='google-cpu', + expected=HardwareType.CPU, + ), + # Unknown Cases + dict( + testcase_name='unknown_junk', + tpu=None, + device='junk-string', + expected=HardwareType.UNKNOWN, + ), + dict( + testcase_name='unknown_empty', + tpu='', + device='', + expected=HardwareType.UNKNOWN, + ), + dict( + testcase_name='unknown_none', + tpu=None, + device=None, + expected=HardwareType.UNKNOWN, + ), + ) + def test_get_hardware_type(self, tpu, device, expected): + self.assertEqual(launch_xpk.get_hardware_type(tpu, device), expected) + + +if __name__ == '__main__': + absltest.main() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py index 545f30ca0..2d0c1c43b 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py @@ -469,7 +469,9 @@ class PathwaysOptions: checkpointing_impl: The implementation to use for Pathways checkpointing. """ - checkpointing_impl: pathways_types.CheckpointingImpl | None = None + checkpointing_impl: pathways_types.CheckpointingImpl | None = ( + pathways_types.CheckpointingImpl.COLOCATED_PYTHON + ) class CheckpointLayout(enum.Enum): diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registration.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registration.py index 022a22e12..8b573c407 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registration.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registration.py @@ -26,6 +26,7 @@ Pathways dependencies should not be added to this file. """ +from absl import logging from orbax.checkpoint._src.serialization import jax_array_handlers from orbax.checkpoint._src.serialization import pathways_handler_registry from orbax.checkpoint._src.serialization import pathways_types @@ -57,11 +58,15 @@ def resolve_pathways_checkpointing_impl( except ImportError as e: raise ImportError(_PATHWAYS_IMPORT_ERROR_MSG) from e checkpointing_impl = context.pathways_options.checkpointing_impl - return checkpointing_impl or pathways_types.CheckpointingImpl.from_options( + resolved_checkpointing_impl = checkpointing_impl or pathways_types.CheckpointingImpl.from_options( use_colocated_python=False, # Not enabled unless explicitly requested. use_remote_python=rp.available(), use_persistence_array_handler=True, # Only used as a fallback. ) + logging.info( + 'Resolved Pathways implementation: %s', resolved_checkpointing_impl + ) + return resolved_checkpointing_impl def get_array_handler(