Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
# Image used for the Pathways server containers; no default is provided.
_SERVER_IMAGE = flags.DEFINE_string(
    name="server_image",
    default=None,
    help="Full path to the server Docker image",
)
# Image used for the colocated Python sidecar. The default tag encodes the
# Python and JAX versions the sidecar was built with.
_SIDECAR_IMAGE = flags.DEFINE_string(
    name="sidecar_image",
    default="us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/sidecar:20260423-python_3.12-jax_0.10.0",
    help="Full path to the sidecar Docker image",
)
# TPU generation to deploy on; restricted to the listed values.
_TPU_TYPE = flags.DEFINE_enum(
    name="tpu_type",
    default="v6e",
    enum_values=["v5e", "v5p", "v6e", "tpu7x"],
    help="TPU type",
)
Expand All @@ -52,6 +57,7 @@
False,
"If true, only print the generated YAML without deploying.",
)
# Directory substituted for the SIDECAR_SHM_DIR placeholder in the deployment
# template context built by run_deployment.
_SIDECAR_SHM_DIR = "/tmp/sidecar_dir"


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -191,6 +197,7 @@ def run_deployment(
jobset_name,
gcs_bucket,
server_image,
sidecar_image,
template_file,
dry_run,
deploy_func: Callable[[dict[str, Any]], None] = deploy_jobset,
Expand All @@ -202,6 +209,8 @@ def run_deployment(
context = {
"JOBSET_NAME": jobset_name,
"SERVER_IMAGE": server_image,
"SIDECAR_IMAGE": sidecar_image,
"SIDECAR_SHM_DIR": _SIDECAR_SHM_DIR,
"GCS_SCRATCH_LOCATION": gcs_bucket,
"NUM_SLICES": num_slices,
"INSTANCE_TYPE": f"{tpu_config.instance_prefix}:{topology}",
Expand Down Expand Up @@ -246,6 +255,7 @@ def main(argv: Sequence[str]) -> None:
jobset_name=_JOBSET_NAME.value,
gcs_bucket=_GCS_BUCKET.value,
server_image=server_image,
sidecar_image=_SIDECAR_IMAGE.value,
template_file=_TEMPLATE_FILE.value,
dry_run=_DRY_RUN.value,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Base image: the JAX image that ships the custom-built colocated Python
# sidecar (tag encodes build date, Python 3.12 and JAX 0.10.0).
FROM us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/sidecar:20260423-python_3.12-jax_0.10.0

# All following relative paths (e.g. the MaxText clone) resolve under /app.
WORKDIR /app

# 1. Upgrade pip and build tools. The cache mount lets uv reuse its cache
#    directory instead of storing it in the image layer.
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --upgrade pip setuptools wheel

# 2. Clone MaxText into /app/maxtext (default branch, not pinned to a commit).
RUN git clone https://github.com/google/maxtext.git

# 3. Install MaxText's base requirements, then pin JAX and jaxlib back to the
#    version shipped in the base image (0.10.0) in case the requirements
#    changed it.
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install -r maxtext/src/dependencies/requirements/base_requirements/requirements.txt && \
uv pip install --upgrade jax==0.10.0 jaxlib==0.10.0

# 4. (optional) Copy your local edits to MaxText requirements and src, if any.
# Make sure you're running this docker build from the root of your local MaxText
# checkout.
# COPY maxtext/src/dependencies/requirements/base_requirements/requirements.txt ./requirements.txt
# COPY maxtext/src /app/maxtext/src

# Put MaxText src and the Orbax checkpoint path on PYTHONPATH.
# NOTE(review): nothing in this Dockerfile creates /app/orbax/checkpoint;
# presumably it is provided by the base image or mounted later -- confirm.
ENV PYTHONPATH=/app/maxtext/src:/app/orbax/checkpoint:$PYTHONPATH
56 changes: 56 additions & 0 deletions pathwaysutils/experimental/shared_pathways_service/gke_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""GKE utils for deploying and managing the Pathways proxy."""

import json
import logging
import re
import socket
Expand Down Expand Up @@ -475,3 +476,58 @@ def is_local_port_free(port: int) -> bool:
"""Checks if a local port is free."""
return portpicker.is_port_free(port)


def get_worker_sidecar_image(
    pathways_service: str, namespace: str = "default"
) -> str | None:
  """Gets the image of the sidecar container used by the workers.

  The jobset name is taken from the Pathways service hostname (the part before
  "-pathways-head"). Pods in `namespace` are listed with kubectl, and the first
  pod of that jobset that has a "colocated-python-sidecar" container (init or
  regular) with a non-empty image determines the result.

  Args:
    pathways_service: Pathways service address, either "<hostname>" or
      "<hostname>:<port>".
    namespace: Kubernetes namespace in which to look for the worker pods.

  Returns:
    The sidecar image string, or None if the jobset name cannot be derived
    from the hostname, kubectl fails, its output is not valid JSON, or no
    matching sidecar container is found.
  """
  pathways_head_hostname = pathways_service.split(":")[0]
  _validate_k8s_name(namespace)

  # Try to extract the jobset name from the Pathways service hostname. Without
  # it no pod can be matched, so skip the kubectl round trip entirely.
  jobset_name, separator, _ = pathways_head_hostname.partition(
      "-pathways-head"
  )
  if not separator or not jobset_name:
    return None

  command = ["kubectl", "get", "pods", "-n", namespace, "-o", "json"]
  try:
    result = subprocess.run(
        command,
        check=True,
        capture_output=True,
        text=True,
    )
  except subprocess.CalledProcessError as e:
    _logger.exception("Failed to get pods. kubectl output:\n%r", e.stderr)
    return None

  try:
    pods_data = json.loads(result.stdout)
  except json.JSONDecodeError as e:
    _logger.exception("Failed to parse kubectl get pods output: %r", e)
    return None

  # Pods without the jobset label are matched by name. The trailing dash keeps
  # jobset "foo" from matching pods of an unrelated jobset such as "foo2".
  pod_name_prefix = f"{jobset_name}-"

  # `or` fallbacks guard against fields that are present but null in the JSON.
  for pod in pods_data.get("items") or []:
    metadata = pod.get("metadata") or {}
    labels = metadata.get("labels") or {}
    pod_jobset_name = labels.get("jobset.sigs.k8s.io/jobset-name")

    if pod_jobset_name is not None:
      # The label is authoritative: a pod labelled for another jobset must not
      # match merely because its name shares a prefix.
      belongs_to_jobset = pod_jobset_name == jobset_name
    else:
      pod_name = metadata.get("name") or ""
      belongs_to_jobset = pod_name.startswith(pod_name_prefix)
    if not belongs_to_jobset:
      continue

    spec = pod.get("spec") or {}
    containers = (spec.get("initContainers") or []) + (
        spec.get("containers") or []
    )
    for container in containers:
      if container.get("name") == "colocated-python-sidecar":
        image = container.get("image")
        if image:
          return image

  return None


31 changes: 25 additions & 6 deletions pathwaysutils/experimental/shared_pathways_service/isc_pathways.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,26 @@ class ProxyOptions:
use_insecure_credentials: Whether to use insecure gRPC credentials for the
proxy server.
xla_flags: A list of XLA flags to pass to the proxy server.
sidecar: Whether to use the worker sidecar or not.
"""
use_insecure_credentials: bool = False
xla_flags: list[str] = dataclasses.field(default_factory=list)
sidecar: bool = False

@classmethod
def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions":
"""Creates a ProxyOptions object from a list of 'key:value' strings."""
use_insecure = False
use_sidecar = False
xla_flags = []
for option in options or []:
if ":" in option:
key, value = option.split(":", 1)
key_strip = key.strip().lower()
if key_strip == "use_insecure_credentials":
use_insecure = value.strip().lower() == "true"
elif key_strip == "sidecar":
use_sidecar = value.strip().lower() == "true"
elif key_strip == "xla_flags":
val_strip = value.strip()
if (
Expand All @@ -78,7 +83,11 @@ def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions":
if xla_flags:
validators.validate_xla_flags(xla_flags)

return cls(use_insecure_credentials=use_insecure, xla_flags=xla_flags)
return cls(
use_insecure_credentials=use_insecure,
xla_flags=xla_flags,
sidecar=use_sidecar,
)


def _deploy_pathways_proxy_server(
Expand Down Expand Up @@ -134,6 +143,9 @@ def _deploy_pathways_proxy_server(
)
proxy_args_str = "\n" + proxy_args_str

if proxy_options.sidecar:
proxy_args_str += "\n - --sidecar_name=external"

template = string.Template(yaml_template)
substituted_yaml = template.substitute(
PROXY_JOB_NAME=proxy_job_name,
Expand Down Expand Up @@ -455,13 +467,20 @@ def connect(
gke_utils.fetch_cluster_credentials(
cluster_name=cluster, project_id=project, location=region
)
proxy_job_name = (
proxy_job_name or f"isc-proxy-{os.environ.get('USER', 'user')}-{''.join(
random.choices(string.ascii_lowercase + string.digits, k=5)
)}"
)

proxy_options_obj = ProxyOptions.from_list(proxy_options)
if proxy_options_obj.sidecar:
sidecar_image = gke_utils.get_worker_sidecar_image(
pathways_service=pathways_service
)
if sidecar_image:
validators.validate_sidecar_image_versions(sidecar_image)

proxy_job_name = (
proxy_job_name
or f"isc-proxy-{os.environ.get('USER', 'user')}-"
f"{''.join(random.choices(string.ascii_lowercase + string.digits, k=5))}"
)

_logger.info("Starting ISCPathways context.")
with _ISCPathways(
Expand Down
69 changes: 69 additions & 0 deletions pathwaysutils/experimental/shared_pathways_service/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from collections.abc import Iterable, Mapping
import logging
import re
import sys
from typing import Any
from absl import flags
import jax

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -133,3 +135,70 @@ def validate_xla_flags(xla_flags: Iterable[str] | None) -> None:
raise flags.ValidationError(
f"XLA flag '{flag}' must start with '--xla_'."
)


def validate_sidecar_image_versions(sidecar_image: str) -> None:
  """Checks compatibility of sidecar image versions with user environment.

  Compares the Python and JAX versions in the sidecar image tag with the user
  environment's Python and JAX versions. Only the version components present
  on both sides are compared, so a tag of "python_3.12" is compatible with any
  Python 3.12.x. A version that does not appear in the tag is not checked, and
  an image reference without a tag is accepted as is.

  Args:
    sidecar_image: The sidecar image string, e.g.,
      "us-docker.pkg.dev/.../sidecar:20260423-python_3.12-jax_0.10.0".

  Raises:
    ValueError: If the sidecar image Python or JAX versions do not match the
      user environment.
  """
  _logger.info(
      "Checking sidecar image version compatibility: %s", sidecar_image
  )

  # Drop any "@sha256:..." digest, then look for the tag only in the last path
  # component so that a registry port ("host:5000/repo/image") or the digest
  # itself is never mistaken for the tag.
  reference = sidecar_image.split("@", 1)[0]
  last_component = reference.rpartition("/")[2]
  _, separator, tag = last_component.rpartition(":")
  if not separator:
    return

  def clean_version(version_str: str) -> str:
    """Keeps the leading dotted-numeric part, e.g. "0.10.0.dev1" -> "0.10.0"."""
    match = re.match(r"\d+(?:\.\d+)*", version_str)
    return match.group(0) if match else version_str

  def versions_match(sidecar_ver: str, env_ver: str) -> bool:
    """Compares component-wise over the shorter of the two versions."""
    return all(
        sidecar_part == env_part
        for sidecar_part, env_part in zip(
            sidecar_ver.split("."), env_ver.split(".")
        )
    )

  # Checked in this order so a Python mismatch is reported before a JAX one.
  env_versions = (
      ("Python", ".".join(str(part) for part in sys.version_info[:3])),
      ("JAX", clean_version(jax.__version__)),
  )
  for name, env_version in env_versions:
    found = re.search(
        rf"{name}[-_]?(\d+\.\d+(?:\.\d+)*)", tag, re.IGNORECASE
    )
    if not found:
      continue
    sidecar_version = found.group(1)
    if not versions_match(sidecar_version, env_version):
      raise ValueError(
          f"{name} version mismatch: sidecar image matches {name} version "
          f"{sidecar_version}, but the user environment is running {name} "
          f"{env_version}."
      )

Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ spec:
- --server_port=29005
- --resource_manager_address=$$(PATHWAYS_HEAD):29001
- --gcs_scratch_location=${GCS_SCRATCH_LOCATION}
- --cloud_pathways_sidecar_shm_directory=${SIDECAR_SHM_DIR}
env:
- name: TPU_MIN_LOG_LEVEL
value: "0"
Expand Down Expand Up @@ -133,8 +134,47 @@ spec:
limits:
google.com/tpu: "${CHIPS_PER_VM}"
volumeMounts:
- mountPath: /tmp
name: shared-tmp
- name: shared-tmp
mountPath: /tmp
- name: sidecar-shared-memory
mountPath: ${SIDECAR_SHM_DIR}
initContainers:
- name: colocated-python-sidecar
image: ${SIDECAR_IMAGE}
imagePullPolicy: Always
env:
- name: GRPC_SERVER_ADDRESS
value: '''0.0.0.0:50051'''
- name: CLOUD_PATHWAYS_SIDECAR_SHM_DIRECTORY
value: ${SIDECAR_SHM_DIR}
- name: PYTHONUNBUFFERED
value: '1'
# --- High Verbosity Logging Variables ---
- name: LOGLEVEL
value: 'DEBUG'
- name: GLOG_minloglevel
value: '0' # 0 = INFO level base
- name: GLOG_v
value: '5' # Extreme verbosity for all C++ modules
- name: TF_CPP_MIN_LOG_LEVEL
value: '0'
- name: TF_CPP_MIN_VLOG_LEVEL
value: '5' # TF/XLA verbose logging
- name: TPU_MIN_LOG_LEVEL
value: '0'
- name: GLOG_vmodule
value: 'jax_array_handlers=5,type_handlers=5,tensorstore_utils=5'
# ----------------------------------------
ports:
- containerPort: 50051
protocol: TCP
resources: {}
restartPolicy: Always
volumeMounts:
- name: shared-tmp
mountPath: /tmp
- name: sidecar-shared-memory
mountPath: ${SIDECAR_SHM_DIR}
dnsPolicy: ClusterFirstWithHostNet
hostNetwork: true
nodeSelector:
Expand All @@ -146,6 +186,9 @@ spec:
hostPath:
path: /tmp
type: DirectoryOrCreate
- name: sidecar-shared-memory
emptyDir:
medium: Memory
startupPolicy:
startupPolicyOrder: InOrder
successPolicy:
Expand Down
Loading