From 444f1ff1f7a5e0b2fc993d2d9320fb8e8ccb6d3f Mon Sep 17 00:00:00 2001 From: Akanksha Gupta Date: Mon, 15 Jun 2026 14:56:20 -0700 Subject: [PATCH] Add sidecar image version validation to ISC Pathways connection PiperOrigin-RevId: 932682609 --- .../deploy_pathways_service.py | 10 +++ .../dockerfiles/sample_sidecar.dockerfile | 27 ++++++++ .../shared_pathways_service/gke_utils.py | 56 +++++++++++++++ .../shared_pathways_service/isc_pathways.py | 31 +++++++-- .../shared_pathways_service/validators.py | 69 +++++++++++++++++++ .../yamls/pw-service.yaml | 47 ++++++++++++- 6 files changed, 232 insertions(+), 8 deletions(-) create mode 100644 pathwaysutils/experimental/shared_pathways_service/dockerfiles/sample_sidecar.dockerfile diff --git a/pathwaysutils/experimental/shared_pathways_service/deploy_pathways_service.py b/pathwaysutils/experimental/shared_pathways_service/deploy_pathways_service.py index a4b78d3..bfd6979 100644 --- a/pathwaysutils/experimental/shared_pathways_service/deploy_pathways_service.py +++ b/pathwaysutils/experimental/shared_pathways_service/deploy_pathways_service.py @@ -26,6 +26,11 @@ _SERVER_IMAGE = flags.DEFINE_string( "server_image", None, "Full path to the server Docker image" ) +_SIDECAR_IMAGE = flags.DEFINE_string( + "sidecar_image", + "us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/sidecar:20260423-python_3.12-jax_0.10.0", + "Full path to the sidecar Docker image", +) _TPU_TYPE = flags.DEFINE_enum( "tpu_type", "v6e", ["v5e", "v5p", "v6e", "tpu7x"], "TPU type" ) @@ -52,6 +57,7 @@ False, "If true, only print the generated YAML without deploying.", ) +_SIDECAR_SHM_DIR = "/tmp/sidecar_dir" @dataclasses.dataclass(frozen=True) @@ -191,6 +197,7 @@ def run_deployment( jobset_name, gcs_bucket, server_image, + sidecar_image, template_file, dry_run, deploy_func: Callable[[dict[str, Any]], None] = deploy_jobset, @@ -202,6 +209,8 @@ def run_deployment( context = { "JOBSET_NAME": jobset_name, "SERVER_IMAGE": server_image, + 
"SIDECAR_IMAGE": sidecar_image, + "SIDECAR_SHM_DIR": _SIDECAR_SHM_DIR, "GCS_SCRATCH_LOCATION": gcs_bucket, "NUM_SLICES": num_slices, "INSTANCE_TYPE": f"{tpu_config.instance_prefix}:{topology}", @@ -246,6 +255,7 @@ def main(argv: Sequence[str]) -> None: jobset_name=_JOBSET_NAME.value, gcs_bucket=_GCS_BUCKET.value, server_image=server_image, + sidecar_image=_SIDECAR_IMAGE.value, template_file=_TEMPLATE_FILE.value, dry_run=_DRY_RUN.value, ) diff --git a/pathwaysutils/experimental/shared_pathways_service/dockerfiles/sample_sidecar.dockerfile b/pathwaysutils/experimental/shared_pathways_service/dockerfiles/sample_sidecar.dockerfile new file mode 100644 index 0000000..a50892c --- /dev/null +++ b/pathwaysutils/experimental/shared_pathways_service/dockerfiles/sample_sidecar.dockerfile @@ -0,0 +1,27 @@ +# Use the JAX image with the custom-built sidecar as the base. +FROM us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/sidecar:20260423-python_3.12-jax_0.10.0 + +# Set the working directory +WORKDIR /app + +# 1. Upgrade pip and build tools +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --upgrade pip setuptools wheel + +# 2. Clone MaxText +RUN git clone https://github.com/google/maxtext.git + +# ADD THE CACHE MOUNT HERE +# Install the same version of JAX and JAXlib as the base image. +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install -r maxtext/src/dependencies/requirements/base_requirements/requirements.txt && \ + uv pip install --upgrade jax==0.10.0 jaxlib==0.10.0 + +# 3. (optional) Copy your local edits to MaxText requirements and src, if any. +# Make sure you're running this docker build from the root of your local MaxText +# checkout. 
def get_worker_sidecar_image(
    pathways_service: str, namespace: str = "default"
) -> str | None:
  """Gets the image of the sidecar container used by the workers.

  The jobset name is taken from the Pathways service hostname (the text before
  "-pathways-head"). Pods in `namespace` are listed with kubectl, and the first
  pod of that jobset that has a container or init container named
  "colocated-python-sidecar" with a non-empty image determines the result.

  Args:
    pathways_service: The Pathways service address, "<hostname>[:<port>]".
    namespace: The Kubernetes namespace in which to list pods.

  Returns:
    The sidecar image string, or None if the jobset name cannot be derived from
    the hostname, kubectl fails, its output is not valid JSON, or no matching
    sidecar container is found.
  """
  _validate_k8s_name(namespace)

  # Try to extract the jobset name from the Pathways service hostname.
  pathways_head_hostname = pathways_service.split(":")[0]
  jobset_name, marker, _ = pathways_head_hostname.partition("-pathways-head")
  if not marker or not jobset_name:
    # Without a jobset name no pod can be matched, so skip the kubectl call.
    return None

  command = ["kubectl", "get", "pods", "-n", namespace, "-o", "json"]
  try:
    result = subprocess.run(
        command,
        check=True,
        capture_output=True,
        text=True,
    )
  except subprocess.CalledProcessError as e:
    _logger.exception("Failed to get pods. kubectl output:\n%r", e.stderr)
    return None

  try:
    pods_data = json.loads(result.stdout)
  except json.JSONDecodeError as e:
    _logger.exception("Failed to parse kubectl get pods output: %r", e)
    return None

  jobset_label = "jobset.sigs.k8s.io/jobset-name"
  sidecar_container_name = "colocated-python-sidecar"

  # Look for pods belonging to the jobset and having the sidecar
  # container/initContainer.
  for pod in pods_data.get("items") or []:
    metadata = pod.get("metadata") or {}
    labels = metadata.get("labels") or {}
    pod_jobset_name = labels.get(jobset_label)

    if pod_jobset_name is not None:
      # The label is authoritative: a pod labelled for another jobset must not
      # be picked up just because its name shares a prefix with this jobset.
      belongs_to_jobset = pod_jobset_name == jobset_name
    else:
      # Unlabelled pod: fall back to matching on the pod name prefix.
      pod_name = metadata.get("name") or ""
      belongs_to_jobset = pod_name.startswith(jobset_name)

    if not belongs_to_jobset:
      continue

    spec = pod.get("spec") or {}
    containers = (spec.get("initContainers") or []) + (
        spec.get("containers") or []
    )
    for container in containers:
      if container.get("name") != sidecar_container_name:
        continue
      image = container.get("image")
      if image:
        return image

  return None
""" use_insecure_credentials: bool = False xla_flags: list[str] = dataclasses.field(default_factory=list) + sidecar: bool = False @classmethod def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions": """Creates a ProxyOptions object from a list of 'key:value' strings.""" use_insecure = False + use_sidecar = False xla_flags = [] for option in options or []: if ":" in option: @@ -63,6 +66,8 @@ def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions": key_strip = key.strip().lower() if key_strip == "use_insecure_credentials": use_insecure = value.strip().lower() == "true" + elif key_strip == "sidecar": + use_sidecar = value.strip().lower() == "true" elif key_strip == "xla_flags": val_strip = value.strip() if ( @@ -78,7 +83,11 @@ def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions": if xla_flags: validators.validate_xla_flags(xla_flags) - return cls(use_insecure_credentials=use_insecure, xla_flags=xla_flags) + return cls( + use_insecure_credentials=use_insecure, + xla_flags=xla_flags, + sidecar=use_sidecar, + ) def _deploy_pathways_proxy_server( @@ -134,6 +143,9 @@ def _deploy_pathways_proxy_server( ) proxy_args_str = "\n" + proxy_args_str + if proxy_options.sidecar: + proxy_args_str += "\n - --sidecar_name=external" + template = string.Template(yaml_template) substituted_yaml = template.substitute( PROXY_JOB_NAME=proxy_job_name, @@ -455,13 +467,20 @@ def connect( gke_utils.fetch_cluster_credentials( cluster_name=cluster, project_id=project, location=region ) - proxy_job_name = ( - proxy_job_name or f"isc-proxy-{os.environ.get('USER', 'user')}-{''.join( - random.choices(string.ascii_lowercase + string.digits, k=5) - )}" - ) proxy_options_obj = ProxyOptions.from_list(proxy_options) + if proxy_options_obj.sidecar: + sidecar_image = gke_utils.get_worker_sidecar_image( + pathways_service=pathways_service + ) + if sidecar_image: + validators.validate_sidecar_image_versions(sidecar_image) + + proxy_job_name = ( + proxy_job_name + or 
f"isc-proxy-{os.environ.get('USER', 'user')}-" + f"{''.join(random.choices(string.ascii_lowercase + string.digits, k=5))}" + ) _logger.info("Starting ISCPathways context.") with _ISCPathways( diff --git a/pathwaysutils/experimental/shared_pathways_service/validators.py b/pathwaysutils/experimental/shared_pathways_service/validators.py index 7b28440..420d7d6 100644 --- a/pathwaysutils/experimental/shared_pathways_service/validators.py +++ b/pathwaysutils/experimental/shared_pathways_service/validators.py @@ -3,8 +3,10 @@ from collections.abc import Iterable, Mapping import logging import re +import sys from typing import Any from absl import flags +import jax _logger = logging.getLogger(__name__) @@ -133,3 +135,70 @@ def validate_xla_flags(xla_flags: Iterable[str] | None) -> None: raise flags.ValidationError( f"XLA flag '{flag}' must start with '--xla_'." ) + + +def validate_sidecar_image_versions(sidecar_image: str) -> None: + """Checks compatibility of sidecar image versions with user environment. + + Compares the Python and JAX versions in the sidecar image tag with the user + environment's Python and JAX versions. + + Args: + sidecar_image: The sidecar image string, e.g., + "us-docker.pkg.dev/.../sidecar:20260423-python_3.12-jax_0.10.0". + + Raises: + ValueError: If the sidecar image Python or JAX versions do not match the + user environment. 
+ """ + _logger.info( + "Checking sidecar image version compatibility: %s", sidecar_image + ) + + parts = sidecar_image.rsplit(":", 1) + if len(parts) < 2: + return + tag = parts[1] + + sidecar_python_match = re.search( + r"python[-_]?(\d+\.\d+(?:\.\d+)*)", tag, re.IGNORECASE + ) + + def clean_version(version_str: str) -> str: + match = re.match(r"^(\d+(?:\.\d+)*)", version_str) + return match.group(1) if match else version_str + + def versions_match(sidecar_ver: str, env_ver: str) -> bool: + sidecar_parts = sidecar_ver.split(".") + env_parts = env_ver.split(".") + compare_len = min(len(sidecar_parts), len(env_parts)) + if compare_len == 0: + return False + return sidecar_parts[:compare_len] == env_parts[:compare_len] + + sidecar_jax_match = re.search( + r"jax[-_]?(\d+\.\d+(?:\.\d+)*)", tag, re.IGNORECASE + ) + if sidecar_python_match: + sidecar_python = clean_version(sidecar_python_match.group(1)) + env_python = ( + f"{sys.version_info.major}.{sys.version_info.minor}." + f"{sys.version_info.micro}" + ) + if not versions_match(sidecar_python, env_python): + raise ValueError( + f"Python version mismatch: sidecar image matches Python version " + f"{sidecar_python}, but the user environment is running Python " + f"{env_python}." + ) + + if sidecar_jax_match: + sidecar_jax = clean_version(sidecar_jax_match.group(1)) + env_jax = clean_version(jax.__version__) + if not versions_match(sidecar_jax, env_jax): + raise ValueError( + f"JAX version mismatch: sidecar image matches JAX version " + f"{sidecar_jax}, but the user environment is running JAX " + f"{env_jax}." 
+ ) + diff --git a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-service.yaml b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-service.yaml index 19769db..a02750e 100644 --- a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-service.yaml +++ b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-service.yaml @@ -87,6 +87,7 @@ spec: - --server_port=29005 - --resource_manager_address=$$(PATHWAYS_HEAD):29001 - --gcs_scratch_location=${GCS_SCRATCH_LOCATION} + - --cloud_pathways_sidecar_shm_directory=${SIDECAR_SHM_DIR} env: - name: TPU_MIN_LOG_LEVEL value: "0" @@ -133,8 +134,47 @@ spec: limits: google.com/tpu: "${CHIPS_PER_VM}" volumeMounts: - - mountPath: /tmp - name: shared-tmp + - name: shared-tmp + mountPath: /tmp + - name: sidecar-shared-memory + mountPath: ${SIDECAR_SHM_DIR} + initContainers: + - name: colocated-python-sidecar + image: ${SIDECAR_IMAGE} + imagePullPolicy: Always + env: + - name: GRPC_SERVER_ADDRESS + value: '''0.0.0.0:50051''' + - name: CLOUD_PATHWAYS_SIDECAR_SHM_DIRECTORY + value: ${SIDECAR_SHM_DIR} + - name: PYTHONUNBUFFERED + value: '1' + # --- High Verbosity Logging Variables --- + - name: LOGLEVEL + value: 'DEBUG' + - name: GLOG_minloglevel + value: '0' # 0 = INFO level base + - name: GLOG_v + value: '5' # Extreme verbosity for all C++ modules + - name: TF_CPP_MIN_LOG_LEVEL + value: '0' + - name: TF_CPP_MIN_VLOG_LEVEL + value: '5' # TF/XLA verbose logging + - name: TPU_MIN_LOG_LEVEL + value: '0' + - name: GLOG_vmodule + value: 'jax_array_handlers=5,type_handlers=5,tensorstore_utils=5' + # ---------------------------------------- + ports: + - containerPort: 50051 + protocol: TCP + resources: {} + restartPolicy: Always + volumeMounts: + - name: shared-tmp + mountPath: /tmp + - name: sidecar-shared-memory + mountPath: ${SIDECAR_SHM_DIR} dnsPolicy: ClusterFirstWithHostNet hostNetwork: true nodeSelector: @@ -146,6 +186,9 @@ spec: hostPath: path: /tmp type: DirectoryOrCreate + - name: 
sidecar-shared-memory + emptyDir: + medium: Memory startupPolicy: startupPolicyOrder: InOrder successPolicy: