From fb191aeaa70304cd29ab8732a3eed6de2654175a Mon Sep 17 00:00:00 2001 From: Pat O'Connor Date: Tue, 29 Jul 2025 12:33:39 +0100 Subject: [PATCH 1/2] feat(RHOAIENG-26487): Cluster lifecycling via RayJob Signed-off-by: Pat O'Connor --- src/codeflare_sdk/ray/rayjobs/pretty_print.py | 1 + src/codeflare_sdk/ray/rayjobs/rayjob.py | 168 +++++++-- .../ray/rayjobs/test_pretty_print.py | 3 + src/codeflare_sdk/ray/rayjobs/test_rayjob.py | 324 +++++++++++++++++- 4 files changed, 460 insertions(+), 36 deletions(-) diff --git a/src/codeflare_sdk/ray/rayjobs/pretty_print.py b/src/codeflare_sdk/ray/rayjobs/pretty_print.py index 9bc89b88..34e8dfa1 100644 --- a/src/codeflare_sdk/ray/rayjobs/pretty_print.py +++ b/src/codeflare_sdk/ray/rayjobs/pretty_print.py @@ -40,6 +40,7 @@ def print_job_status(job_info: RayJobInfo): # Add timing information if available if job_info.start_time: + table.add_row() table.add_row(f"[bold]Started:[/bold] {job_info.start_time}") # Add attempt counts if there are failures diff --git a/src/codeflare_sdk/ray/rayjobs/rayjob.py b/src/codeflare_sdk/ray/rayjobs/rayjob.py index ac2210a2..0d966b0e 100644 --- a/src/codeflare_sdk/ray/rayjobs/rayjob.py +++ b/src/codeflare_sdk/ray/rayjobs/rayjob.py @@ -20,6 +20,10 @@ from typing import Dict, Any, Optional, Tuple from odh_kuberay_client.kuberay_job_api import RayjobApi +from ..cluster.cluster import Cluster +from ..cluster.config import ClusterConfiguration +from ..cluster.build_ray_cluster import build_ray_cluster + from .status import ( RayJobDeploymentStatus, CodeflareRayJobStatus, @@ -27,7 +31,7 @@ ) from . import pretty_print -# Set up logging + logger = logging.getLogger(__name__) @@ -42,74 +46,110 @@ class RayJob: def __init__( self, job_name: str, - cluster_name: str, + cluster_name: Optional[str] = None, + cluster_config: Optional[ClusterConfiguration] = None, namespace: str = "default", - entrypoint: str = "None", + entrypoint: Optional[str] = None, runtime_env: Optional[Dict[str, Any]] = None, + shutdown_after_job_finishes: bool = True, + ttl_seconds_after_finished: int = 0, + active_deadline_seconds: Optional[int] = None, ): """ Initialize a RayJob instance. 
Args: - name: The name for the Ray job - namespace: The Kubernetes namespace to submit the job to (default: "default") - cluster_name: The name of the Ray cluster to submit the job to - **kwargs: Additional configuration options + job_name: The name for the Ray job + cluster_name: The name of an existing Ray cluster (optional if cluster_config provided) + cluster_config: Configuration for creating a new cluster (optional if cluster_name provided) + namespace: The Kubernetes namespace (default: "default") + entrypoint: The Python script or command to run (required for submission) + runtime_env: Ray runtime environment configuration (optional) + shutdown_after_job_finishes: Whether to automatically cleanup the cluster after job completion (default: True) + ttl_seconds_after_finished: Seconds to wait before cleanup after job finishes (default: 0) + active_deadline_seconds: Maximum time the job can run before being terminated (optional) """ + # Validate input parameters + if cluster_name is None and cluster_config is None: + raise ValueError("Either cluster_name or cluster_config must be provided") + + if cluster_name is not None and cluster_config is not None: + raise ValueError("Cannot specify both cluster_name and cluster_config") + self.name = job_name self.namespace = namespace - self.cluster_name = cluster_name self.entrypoint = entrypoint self.runtime_env = runtime_env + self.shutdown_after_job_finishes = shutdown_after_job_finishes + self.ttl_seconds_after_finished = ttl_seconds_after_finished + self.active_deadline_seconds = active_deadline_seconds + + # Cluster configuration + self._cluster_name = cluster_name + self._cluster_config = cluster_config + + # Determine cluster name for the job + if cluster_config is not None: + # Ensure cluster config has the same namespace as the job + if cluster_config.namespace is None: + cluster_config.namespace = namespace + elif cluster_config.namespace != namespace: + logger.warning( + f"Cluster config namespace ({cluster_config.namespace}) differs from job namespace ({namespace})" + ) + + self.cluster_name = cluster_config.name or f"{job_name}-cluster" + # Update the cluster config name if it wasn't set + if not cluster_config.name: + cluster_config.name = self.cluster_name + else: + self.cluster_name = cluster_name # Initialize the KubeRay job API client self._api = RayjobApi() logger.info(f"Initialized RayJob: {self.name} in namespace: {self.namespace}") - def submit( - self, - ) -> str: + def submit(self) -> str: """ Submit the Ray job to the Kubernetes cluster. 
- Args: - entrypoint: The Python script or command to run - runtime_env: Ray runtime environment configuration (optional) + The RayJob CRD will automatically: + - Create a new cluster if cluster_config was provided + - Use existing cluster if cluster_name was provided + - Clean up resources based on shutdown_after_job_finishes setting Returns: The job ID/name if submission was successful Raises: - RuntimeError: If the job has already been submitted or submission fails + ValueError: If entrypoint is not provided + RuntimeError: If job submission fails """ + # Validate required parameters + if not self.entrypoint: + raise ValueError("entrypoint must be provided to submit a RayJob") + # Build the RayJob custom resource - rayjob_cr = self._build_rayjob_cr( - entrypoint=self.entrypoint, - runtime_env=self.runtime_env, - ) + rayjob_cr = self._build_rayjob_cr() - # Submit the job - logger.info( - f"Submitting RayJob {self.name} to RayCluster {self.cluster_name} in namespace {self.namespace}" - ) + # Submit the job - KubeRay operator handles everything else + logger.info(f"Submitting RayJob {self.name} to KubeRay operator") result = self._api.submit_job(k8s_namespace=self.namespace, job=rayjob_cr) if result: logger.info(f"Successfully submitted RayJob {self.name}") + if self.shutdown_after_job_finishes: + logger.info( + f"Cluster will be automatically cleaned up {self.ttl_seconds_after_finished}s after job completion" + ) return self.name else: raise RuntimeError(f"Failed to submit RayJob {self.name}") - def _build_rayjob_cr( - self, - entrypoint: str, - runtime_env: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: + def _build_rayjob_cr(self) -> Dict[str, Any]: """ - Build the RayJob custom resource specification. - - This creates a minimal RayJob CR that can be extended later. + Build the RayJob custom resource specification using native RayJob capabilities. """ # Basic RayJob custom resource structure rayjob_cr = { @@ -120,17 +160,75 @@ def _build_rayjob_cr( "namespace": self.namespace, }, "spec": { - "entrypoint": entrypoint, - "clusterSelector": {"ray.io/cluster": self.cluster_name}, + "entrypoint": self.entrypoint, + "shutdownAfterJobFinishes": self.shutdown_after_job_finishes, + "ttlSecondsAfterFinished": self.ttl_seconds_after_finished, }, } + # Add active deadline if specified + if self.active_deadline_seconds: + rayjob_cr["spec"]["activeDeadlineSeconds"] = self.active_deadline_seconds + # Add runtime environment if specified - if runtime_env: - rayjob_cr["spec"]["runtimeEnvYAML"] = str(runtime_env) + if self.runtime_env: + rayjob_cr["spec"]["runtimeEnvYAML"] = str(self.runtime_env) + + # Configure cluster: either use existing or create new + if self._cluster_config is not None: + # Use rayClusterSpec to create a new cluster - leverage existing build logic + ray_cluster_spec = self._build_ray_cluster_spec() + rayjob_cr["spec"]["rayClusterSpec"] = ray_cluster_spec + logger.info(f"RayJob will create new cluster: {self.cluster_name}") + else: + # Use clusterSelector to reference existing cluster + rayjob_cr["spec"]["clusterSelector"] = {"ray.io/cluster": self.cluster_name} + logger.info(f"RayJob will use existing cluster: {self.cluster_name}") return rayjob_cr + def _build_ray_cluster_spec(self) -> Dict[str, Any]: + """ + Build the RayCluster spec from ClusterConfiguration using existing build_ray_cluster logic. 
+ + Returns: + Dict containing the RayCluster spec for embedding in RayJob + """ + if not self._cluster_config: + raise RuntimeError("No cluster configuration provided") + + # Create a shallow copy of the cluster config to avoid modifying the original + import copy + + temp_config = copy.copy(self._cluster_config) + + # Ensure we get a RayCluster (not AppWrapper) and don't write to file + temp_config.appwrapper = False + temp_config.write_to_file = False + + # Create a minimal Cluster object for the build process + from ..cluster.cluster import Cluster + + temp_cluster = Cluster.__new__(Cluster) # Create without calling __init__ + temp_cluster.config = temp_config + + """ + For now, RayJob with a new/auto-created cluster will not work with Kueue. + This is due to the Kueue label not being propagated to the RayCluster. + """ + + # Use the existing build_ray_cluster function to generate the RayCluster + ray_cluster_dict = build_ray_cluster(temp_cluster) + + # Extract just the RayCluster spec - RayJob CRD doesn't support metadata in rayClusterSpec + # Note: CodeFlare Operator should still create dashboard routes for the RayCluster + ray_cluster_spec = ray_cluster_dict["spec"] + + logger.info( + f"Built RayCluster spec using existing build logic for cluster: {self.cluster_name}" + ) + return ray_cluster_spec + def status( self, print_to_console: bool = True ) -> Tuple[CodeflareRayJobStatus, bool]: diff --git a/src/codeflare_sdk/ray/rayjobs/test_pretty_print.py b/src/codeflare_sdk/ray/rayjobs/test_pretty_print.py index dbfd7caf..3bbe8bee 100644 --- a/src/codeflare_sdk/ray/rayjobs/test_pretty_print.py +++ b/src/codeflare_sdk/ray/rayjobs/test_pretty_print.py @@ -106,6 +106,7 @@ def test_print_job_status_running_format(mocker): call("[bold]Status:[/bold] Running"), call("[bold]RayCluster:[/bold] test-cluster"), call("[bold]Namespace:[/bold] test-ns"), + call(), # Empty row before timing info call("[bold]Started:[/bold] 2025-07-28T11:37:07Z"), ] mock_inner_table.add_row.assert_has_calls(expected_calls) @@ -166,6 +167,7 @@ def test_print_job_status_complete_format(mocker): call("[bold]Status:[/bold] Complete"), call("[bold]RayCluster:[/bold] prod-cluster"), call("[bold]Namespace:[/bold] prod-ns"), + call(), # Empty row before timing info call("[bold]Started:[/bold] 2025-07-28T11:37:07Z"), ] mock_inner_table.add_row.assert_has_calls(expected_calls) @@ -215,6 +217,7 @@ def test_print_job_status_failed_with_attempts_format(mocker): call("[bold]Status:[/bold] Failed"), call("[bold]RayCluster:[/bold] test-cluster"), call("[bold]Namespace:[/bold] test-ns"), + call(), # Empty row before timing info call("[bold]Started:[/bold] 2025-07-28T11:37:07Z"), call("[bold]Failed Attempts:[/bold] 3"), # Failed attempts should be shown ] diff --git a/src/codeflare_sdk/ray/rayjobs/test_rayjob.py b/src/codeflare_sdk/ray/rayjobs/test_rayjob.py index 5429f303..7554ca4c 100644 --- a/src/codeflare_sdk/ray/rayjobs/test_rayjob.py +++ b/src/codeflare_sdk/ray/rayjobs/test_rayjob.py @@ -13,8 +13,9 @@ # limitations under the License. 
import pytest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from codeflare_sdk.ray.rayjobs.rayjob import RayJob +from codeflare_sdk.ray.cluster.config import ClusterConfiguration def test_rayjob_submit_success(mocker): @@ -86,3 +87,324 @@ def test_rayjob_submit_failure(mocker): # Test that RuntimeError is raised on failure with pytest.raises(RuntimeError, match="Failed to submit RayJob test-rayjob"): rayjob.submit() + + +def test_rayjob_init_validation_both_provided(mocker): + """Test that providing both cluster_name and cluster_config raises error.""" + # Mock kubernetes config loading + mocker.patch("kubernetes.config.load_kube_config") + + # Mock the RayjobApi class entirely + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + cluster_config = ClusterConfiguration(name="test-cluster", namespace="test") + + with pytest.raises( + ValueError, match="Cannot specify both cluster_name and cluster_config" + ): + RayJob( + job_name="test-job", + cluster_name="existing-cluster", + cluster_config=cluster_config, + entrypoint="python script.py", + ) + + +def test_rayjob_init_validation_neither_provided(mocker): + """Test that providing neither cluster_name nor cluster_config raises error.""" + # Mock kubernetes config loading (though this should fail before reaching it) + mocker.patch("kubernetes.config.load_kube_config") + + # Mock the RayjobApi class entirely (though this should fail before reaching it) + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + with pytest.raises( + ValueError, match="Either cluster_name or cluster_config must be provided" + ): + RayJob(job_name="test-job", entrypoint="python script.py") + + +def test_rayjob_init_with_cluster_config(mocker): + """Test RayJob initialization with cluster configuration for auto-creation.""" + # Mock kubernetes config loading + mocker.patch("kubernetes.config.load_kube_config") + + # Mock the RayjobApi class entirely + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + cluster_config = ClusterConfiguration( + name="auto-cluster", namespace="test-namespace", num_workers=2 + ) + + rayjob = RayJob( + job_name="test-job", + cluster_config=cluster_config, + entrypoint="python script.py", + ) + + assert rayjob.name == "test-job" + assert rayjob.cluster_name == "auto-cluster" + assert rayjob._cluster_config == cluster_config + assert rayjob._cluster_name is None + + +def test_rayjob_cluster_name_generation(mocker): + """Test that cluster names are generated when config has empty name.""" + # Mock kubernetes config loading + mocker.patch("kubernetes.config.load_kube_config") + + # Mock the RayjobApi class entirely + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + cluster_config = ClusterConfiguration( + name="", # Empty name should trigger generation + namespace="test-namespace", + num_workers=1, + ) + + rayjob = RayJob( + job_name="my-job", cluster_config=cluster_config, entrypoint="python script.py" + ) + + assert rayjob.cluster_name == "my-job-cluster" + assert cluster_config.name == "my-job-cluster" # Should be updated + + +def test_rayjob_cluster_config_namespace_none(mocker): + """Test that cluster config namespace is set when None.""" + # Mock kubernetes config loading + mocker.patch("kubernetes.config.load_kube_config") + + # Mock the RayjobApi class entirely + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + cluster_config = ClusterConfiguration( + name="test-cluster", + namespace=None, # This should be set to job namespace + 
num_workers=1, + ) + + rayjob = RayJob( + job_name="test-job", + cluster_config=cluster_config, + namespace="job-namespace", + entrypoint="python script.py", + ) + + assert cluster_config.namespace == "job-namespace" + assert rayjob.namespace == "job-namespace" + + +def test_rayjob_with_active_deadline_seconds(mocker): + """Test RayJob CR generation with active deadline seconds.""" + # Mock kubernetes config loading + mocker.patch("kubernetes.config.load_kube_config") + + # Mock the RayjobApi class entirely + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + rayjob = RayJob( + job_name="test-job", + cluster_name="test-cluster", + namespace="test-namespace", + entrypoint="python main.py", + active_deadline_seconds=30, + ) + + rayjob_cr = rayjob._build_rayjob_cr() + + assert rayjob_cr["spec"]["activeDeadlineSeconds"] == 30 + + +def test_build_ray_cluster_spec_no_config_error(mocker): + """Test _build_ray_cluster_spec raises error when no cluster config.""" + # Mock kubernetes config loading + mocker.patch("kubernetes.config.load_kube_config") + + # Mock the RayjobApi class entirely + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + # Create RayJob with cluster_name (no cluster_config) + rayjob = RayJob( + job_name="test-job", + cluster_name="existing-cluster", + entrypoint="python script.py", + ) + + # Line 198: Should raise RuntimeError when trying to build spec without config + with pytest.raises(RuntimeError, match="No cluster configuration provided"): + rayjob._build_ray_cluster_spec() + + +@patch("codeflare_sdk.ray.rayjobs.rayjob.build_ray_cluster") +def test_build_ray_cluster_spec(mock_build_ray_cluster, mocker): + """Test _build_ray_cluster_spec method.""" + mocker.patch("kubernetes.config.load_kube_config") + + # Mock the RayjobApi class entirely + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + mock_ray_cluster = { + "apiVersion": "ray.io/v1", + "kind": "RayCluster", + "metadata": {"name": "test-cluster", "namespace": "test"}, + "spec": { + "rayVersion": "2.9.0", + "headGroupSpec": {"replicas": 1}, + "workerGroupSpecs": [{"replicas": 2}], + }, + } + mock_build_ray_cluster.return_value = mock_ray_cluster + + cluster_config = ClusterConfiguration( + name="test-cluster", namespace="test", num_workers=2 + ) + + rayjob = RayJob( + job_name="test-job", + cluster_config=cluster_config, + entrypoint="python script.py", + ) + + spec = rayjob._build_ray_cluster_spec() + + # Should return only the spec part, not metadata + assert spec == mock_ray_cluster["spec"] + assert "metadata" not in spec + + # Verify build_ray_cluster was called with correct parameters + mock_build_ray_cluster.assert_called_once() + call_args = mock_build_ray_cluster.call_args[0][0] + assert call_args.config.appwrapper is False + assert call_args.config.write_to_file is False + + +def test_build_rayjob_cr_with_existing_cluster(mocker): + """Test _build_rayjob_cr method with existing cluster.""" + mocker.patch("kubernetes.config.load_kube_config") + + # Mock the RayjobApi class entirely + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + rayjob = RayJob( + job_name="test-job", + cluster_name="existing-cluster", + namespace="test-namespace", + entrypoint="python main.py", + shutdown_after_job_finishes=False, + ttl_seconds_after_finished=300, + ) + + rayjob_cr = rayjob._build_rayjob_cr() + + # Check basic structure + assert rayjob_cr["apiVersion"] == "ray.io/v1" + assert rayjob_cr["kind"] == "RayJob" + assert rayjob_cr["metadata"]["name"] == "test-job" + + # Check 
lifecycle parameters + spec = rayjob_cr["spec"] + assert spec["entrypoint"] == "python main.py" + assert spec["shutdownAfterJobFinishes"] is False + assert spec["ttlSecondsAfterFinished"] == 300 + + # Should use clusterSelector for existing cluster + assert spec["clusterSelector"]["ray.io/cluster"] == "existing-cluster" + assert "rayClusterSpec" not in spec + + +@patch("codeflare_sdk.ray.rayjobs.rayjob.build_ray_cluster") +def test_build_rayjob_cr_with_auto_cluster(mock_build_ray_cluster, mocker): + """Test _build_rayjob_cr method with auto-created cluster.""" + mocker.patch("kubernetes.config.load_kube_config") + + # Mock the RayjobApi class entirely + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + mock_ray_cluster = { + "apiVersion": "ray.io/v1", + "kind": "RayCluster", + "metadata": {"name": "auto-cluster", "namespace": "test"}, + "spec": { + "rayVersion": "2.9.0", + "headGroupSpec": {"replicas": 1}, + "workerGroupSpecs": [{"replicas": 2}], + }, + } + mock_build_ray_cluster.return_value = mock_ray_cluster + + cluster_config = ClusterConfiguration( + name="auto-cluster", namespace="test-namespace", num_workers=2 + ) + + rayjob = RayJob( + job_name="test-job", cluster_config=cluster_config, entrypoint="python main.py" + ) + + rayjob_cr = rayjob._build_rayjob_cr() + + # Should use rayClusterSpec for auto-created cluster + assert rayjob_cr["spec"]["rayClusterSpec"] == mock_ray_cluster["spec"] + assert "clusterSelector" not in rayjob_cr["spec"] + + +def test_submit_validation_no_entrypoint(mocker): + """Test that submit() raises error when entrypoint is None.""" + mocker.patch("kubernetes.config.load_kube_config") + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + rayjob = RayJob( + job_name="test-job", + cluster_name="test-cluster", + entrypoint=None, # No entrypoint provided + ) + + with pytest.raises( + ValueError, match="entrypoint must be provided to submit a RayJob" + ): + rayjob.submit() + + +@patch("codeflare_sdk.ray.rayjobs.rayjob.build_ray_cluster") +def test_submit_with_auto_cluster(mock_build_ray_cluster, mocker): + """Test successful submission with auto-created cluster.""" + mocker.patch("kubernetes.config.load_kube_config") + + mock_ray_cluster = { + "apiVersion": "ray.io/v1", + "kind": "RayCluster", + "spec": { + "rayVersion": "2.9.0", + "headGroupSpec": {"replicas": 1}, + "workerGroupSpecs": [{"replicas": 1}], + }, + } + mock_build_ray_cluster.return_value = mock_ray_cluster + + # Mock the RayjobApi + mock_api_class = mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + mock_api_instance = MagicMock() + mock_api_class.return_value = mock_api_instance + mock_api_instance.submit_job.return_value = True + + cluster_config = ClusterConfiguration( + name="auto-cluster", namespace="test", num_workers=1 + ) + + rayjob = RayJob( + job_name="test-job", + cluster_config=cluster_config, + entrypoint="python script.py", + ) + + result = rayjob.submit() + + assert result == "test-job" + + # Verify the correct RayJob CR was submitted + mock_api_instance.submit_job.assert_called_once() + call_args = mock_api_instance.submit_job.call_args + + job_cr = call_args.kwargs["job"] + assert "rayClusterSpec" in job_cr["spec"] + assert job_cr["spec"]["rayClusterSpec"] == mock_ray_cluster["spec"] From 6585567d6494f83fd3162a2a3fc386a582c48cc0 Mon Sep 17 00:00:00 2001 From: kryanbeane Date: Tue, 12 Aug 2025 20:46:01 +0100 Subject: [PATCH 2/2] feat(RHOAIENG-26487): rayjob lifecycled cluster improvements and tests --- poetry.lock | 42 +- pyproject.toml | 3 +- 
src/codeflare_sdk/__init__.py | 1 + src/codeflare_sdk/common/kueue/kueue.py | 3 +- src/codeflare_sdk/common/utils/__init__.py | 7 + src/codeflare_sdk/common/utils/k8s_utils.py | 37 + src/codeflare_sdk/common/utils/test_demos.py | 57 ++ .../common/utils/test_k8s_utils.py | 255 +++++++ .../common/widgets/test_widgets.py | 4 +- src/codeflare_sdk/common/widgets/widgets.py | 6 +- src/codeflare_sdk/ray/__init__.py | 1 + src/codeflare_sdk/ray/cluster/cluster.py | 28 +- src/codeflare_sdk/ray/rayjobs/__init__.py | 2 +- src/codeflare_sdk/ray/rayjobs/config.py | 455 +++++++++++++ src/codeflare_sdk/ray/rayjobs/rayjob.py | 165 ++--- src/codeflare_sdk/ray/rayjobs/test_config.py | 82 +++ src/codeflare_sdk/ray/rayjobs/test_rayjob.py | 641 ++++++++++++++++-- 17 files changed, 1604 insertions(+), 185 deletions(-) create mode 100644 src/codeflare_sdk/common/utils/k8s_utils.py create mode 100644 src/codeflare_sdk/common/utils/test_demos.py create mode 100644 src/codeflare_sdk/common/utils/test_k8s_utils.py create mode 100644 src/codeflare_sdk/ray/rayjobs/config.py create mode 100644 src/codeflare_sdk/ray/rayjobs/test_config.py diff --git a/poetry.lock b/poetry.lock index fca55833..381383d4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2399,26 +2399,6 @@ rsa = ["cryptography (>=3.0.0)"] signals = ["blinker (>=1.4.0)"] signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] -[[package]] -name = "odh-kuberay-client" -version = "0.0.0.dev40" -description = "Python SDK for Kuberay client" -optional = false -python-versions = ">=3.11,<4.0" -groups = ["main"] -files = [ - {file = "odh_kuberay_client-0.0.0.dev40-py3-none-any.whl", hash = "sha256:547daaa07ff3687b75dc844473b0897822d3aa4803aed865037ddf41da22f593"}, - {file = "odh_kuberay_client-0.0.0.dev40.tar.gz", hash = "sha256:a4ec11aff244099256cbca0628d8dbb4c5fe48e09966a6b75b412895aebd4834"}, -] - -[package.dependencies] -kubernetes = ">=25.0.0" - -[package.source] -type = "legacy" -url = "https://test.pypi.org/simple" -reference = "testpypi" - [[package]] name = "opencensus" version = "0.11.4" @@ -3324,6 +3304,26 @@ files = [ [package.dependencies] pytest = ">=7.0.0" +[[package]] +name = "python-client" +version = "0.0.0-dev" +description = "Python Client for Kuberay" +optional = false +python-versions = "^3.11" +groups = ["main"] +files = [] +develop = false + +[package.dependencies] +kubernetes = ">=25.0.0" + +[package.source] +type = "git" +url = "https://github.com/ray-project/kuberay.git" +reference = "d1e750d9beac612ad455b951c1a789f971409ab3" +resolved_reference = "d1e750d9beac612ad455b951c1a789f971409ab3" +subdirectory = "clients/python-client" + [[package]] name = "python-dateutil" version = "3.9.0" @@ -4696,4 +4696,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "a308f6517e3e8c5c14542421019a82056e8a81b43fb316a7862503ad799b9d49" +content-hash = "30c47f95bf1bf33682dd0bc107eef88f4e9226ca7ad5b33e929bfd3ab7030e95" diff --git a/pyproject.toml b/pyproject.toml index fe1ae401..df57c1f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ cryptography = "43.0.3" executing = "1.2.0" pydantic = "< 2" ipywidgets = "8.1.2" -odh-kuberay-client = {version = "0.0.0.dev40", source = "testpypi"} +python-client = { git = "https://github.com/ray-project/kuberay.git", subdirectory = "clients/python-client", rev = "d1e750d9beac612ad455b951c1a789f971409ab3" } [[tool.poetry.source]] name = "pypi" @@ -67,4 +67,3 @@ markers = [ ] addopts = "--timeout=900" testpaths = ["src/codeflare_sdk"] -collect_ignore = 
["src/codeflare_sdk/common/utils/unit_test_support.py"] diff --git a/src/codeflare_sdk/__init__.py b/src/codeflare_sdk/__init__.py index 95753a59..f9a06524 100644 --- a/src/codeflare_sdk/__init__.py +++ b/src/codeflare_sdk/__init__.py @@ -11,6 +11,7 @@ AppWrapperStatus, RayJobClient, RayJob, + RayJobClusterConfig, ) from .common.widgets import view_clusters diff --git a/src/codeflare_sdk/common/kueue/kueue.py b/src/codeflare_sdk/common/kueue/kueue.py index 00f3364a..a721713e 100644 --- a/src/codeflare_sdk/common/kueue/kueue.py +++ b/src/codeflare_sdk/common/kueue/kueue.py @@ -18,6 +18,8 @@ from kubernetes import client from kubernetes.client.exceptions import ApiException +from ...common.utils import get_current_namespace + def get_default_kueue_name(namespace: str) -> Optional[str]: """ @@ -81,7 +83,6 @@ def list_local_queues( List[dict]: A list of dictionaries containing the name of the local queue and the available flavors """ - from ...ray.cluster.cluster import get_current_namespace if namespace is None: # pragma: no cover namespace = get_current_namespace() diff --git a/src/codeflare_sdk/common/utils/__init__.py b/src/codeflare_sdk/common/utils/__init__.py index e69de29b..e662bf5e 100644 --- a/src/codeflare_sdk/common/utils/__init__.py +++ b/src/codeflare_sdk/common/utils/__init__.py @@ -0,0 +1,7 @@ +""" +Common utilities for the CodeFlare SDK. +""" + +from .k8s_utils import get_current_namespace + +__all__ = ["get_current_namespace"] diff --git a/src/codeflare_sdk/common/utils/k8s_utils.py b/src/codeflare_sdk/common/utils/k8s_utils.py new file mode 100644 index 00000000..57eccf2d --- /dev/null +++ b/src/codeflare_sdk/common/utils/k8s_utils.py @@ -0,0 +1,37 @@ +""" +Kubernetes utility functions for the CodeFlare SDK. +""" + +import os +from kubernetes import config +from ..kubernetes_cluster import config_check, _kube_api_error_handling + + +def get_current_namespace(): + """ + Retrieves the current Kubernetes namespace. + + This function attempts to detect the current namespace by: + 1. First checking if running inside a pod (reading from service account namespace file) + 2. Falling back to reading from the current kubeconfig context + + Returns: + str: + The current namespace or None if not found. + """ + if os.path.isfile("/var/run/secrets/kubernetes.io/serviceaccount/namespace"): + try: + file = open("/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r") + active_context = file.readline().strip("\n") + return active_context + except Exception as e: + print("Unable to find current namespace") + print("trying to gather from current context") + try: + _, active_context = config.list_kube_config_contexts(config_check()) + except Exception as e: + return _kube_api_error_handling(e) + try: + return active_context["context"]["namespace"] + except KeyError: + return None diff --git a/src/codeflare_sdk/common/utils/test_demos.py b/src/codeflare_sdk/common/utils/test_demos.py new file mode 100644 index 00000000..9124cbec --- /dev/null +++ b/src/codeflare_sdk/common/utils/test_demos.py @@ -0,0 +1,57 @@ +# Copyright 2025 IBM, Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for demos module. +""" + +import pytest +import tempfile +from pathlib import Path +from unittest.mock import patch, MagicMock +from codeflare_sdk.common.utils.demos import copy_demo_nbs + + +class TestCopyDemoNbs: + """Test cases for copy_demo_nbs function.""" + + def test_copy_demo_nbs_directory_exists_error(self): + """Test that FileExistsError is raised when directory exists and overwrite=False.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create a subdirectory that will conflict + conflict_dir = Path(temp_dir) / "demo-notebooks" + conflict_dir.mkdir() + + with pytest.raises(FileExistsError, match="Directory.*already exists"): + copy_demo_nbs(dir=str(conflict_dir), overwrite=False) + + def test_copy_demo_nbs_overwrite_true(self): + """Test that overwrite=True allows copying to existing directory.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create a subdirectory that will conflict + conflict_dir = Path(temp_dir) / "demo-notebooks" + conflict_dir.mkdir() + + # Mock the demo_dir to point to a real directory + with patch("codeflare_sdk.common.utils.demos.demo_dir", temp_dir): + # Should not raise an error with overwrite=True + copy_demo_nbs(dir=str(conflict_dir), overwrite=True) + + def test_copy_demo_nbs_default_parameters(self): + """Test copy_demo_nbs with default parameters.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Mock the demo_dir to point to a real directory + with patch("codeflare_sdk.common.utils.demos.demo_dir", temp_dir): + # Should work with default parameters + copy_demo_nbs(dir=temp_dir, overwrite=True) diff --git a/src/codeflare_sdk/common/utils/test_k8s_utils.py b/src/codeflare_sdk/common/utils/test_k8s_utils.py new file mode 100644 index 00000000..fcd0623d --- /dev/null +++ b/src/codeflare_sdk/common/utils/test_k8s_utils.py @@ -0,0 +1,255 @@ +# Copyright 2025 IBM, Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for k8s_utils module. 
+""" + +import pytest +from unittest.mock import mock_open, patch, MagicMock +from codeflare_sdk.common.utils.k8s_utils import get_current_namespace + + +class TestGetCurrentNamespace: + """Test cases for get_current_namespace function.""" + + def test_get_current_namespace_incluster_success(self): + """Test successful namespace detection from in-cluster service account.""" + mock_file_content = "test-namespace\n" + + with patch("os.path.isfile", return_value=True): + with patch("builtins.open", mock_open(read_data=mock_file_content)): + result = get_current_namespace() + + assert result == "test-namespace" + + def test_get_current_namespace_incluster_file_read_error(self): + """Test handling of file read errors when reading service account namespace.""" + with patch("os.path.isfile", return_value=True): + with patch("builtins.open", side_effect=IOError("File read error")): + with patch("builtins.print") as mock_print: + # Mock config_check to avoid kubeconfig fallback + with patch( + "codeflare_sdk.common.utils.k8s_utils.config_check", + side_effect=Exception("Config error"), + ): + with patch( + "codeflare_sdk.common.utils.k8s_utils._kube_api_error_handling", + return_value=None, + ): + result = get_current_namespace() + + assert result is None + # Should see both error messages: in-cluster failure and kubeconfig fallback + mock_print.assert_any_call("Unable to find current namespace") + mock_print.assert_any_call("trying to gather from current context") + + def test_get_current_namespace_incluster_file_open_error(self): + """Test handling of file open errors when reading service account namespace.""" + with patch("os.path.isfile", return_value=True): + with patch( + "builtins.open", side_effect=PermissionError("Permission denied") + ): + with patch("builtins.print") as mock_print: + # Mock config_check to avoid kubeconfig fallback + with patch( + "codeflare_sdk.common.utils.k8s_utils.config_check", + side_effect=Exception("Config error"), + ): + with patch( + "codeflare_sdk.common.utils.k8s_utils._kube_api_error_handling", + return_value=None, + ): + result = get_current_namespace() + + assert result is None + # Should see both error messages: in-cluster failure and kubeconfig fallback + mock_print.assert_any_call("Unable to find current namespace") + mock_print.assert_any_call("trying to gather from current context") + + def test_get_current_namespace_kubeconfig_success(self): + """Test successful namespace detection from kubeconfig context.""" + mock_contexts = [ + {"name": "context1", "context": {"namespace": "default"}}, + {"name": "context2", "context": {"namespace": "test-namespace"}}, + ] + mock_active_context = { + "name": "context2", + "context": {"namespace": "test-namespace"}, + } + + with patch("os.path.isfile", return_value=False): + with patch("builtins.print") as mock_print: + with patch( + "codeflare_sdk.common.utils.k8s_utils.config_check", + return_value="~/.kube/config", + ): + with patch( + "kubernetes.config.list_kube_config_contexts", + return_value=(mock_contexts, mock_active_context), + ): + result = get_current_namespace() + + assert result == "test-namespace" + mock_print.assert_called_with("trying to gather from current context") + + def test_get_current_namespace_kubeconfig_no_namespace_in_context(self): + """Test handling when kubeconfig context has no namespace field.""" + mock_contexts = [ + {"name": "context1", "context": {}}, + {"name": "context2", "context": {"cluster": "test-cluster"}}, + ] + mock_active_context = { + "name": "context2", + "context": 
{"cluster": "test-cluster"}, + } + + with patch("os.path.isfile", return_value=False): + with patch("builtins.print") as mock_print: + with patch( + "codeflare_sdk.common.utils.k8s_utils.config_check", + return_value="~/.kube/config", + ): + with patch( + "kubernetes.config.list_kube_config_contexts", + return_value=(mock_contexts, mock_active_context), + ): + result = get_current_namespace() + + assert result is None + mock_print.assert_called_with("trying to gather from current context") + + def test_get_current_namespace_kubeconfig_config_check_error(self): + """Test handling when config_check raises an exception.""" + with patch("os.path.isfile", return_value=False): + with patch("builtins.print") as mock_print: + with patch( + "codeflare_sdk.common.utils.k8s_utils.config_check", + side_effect=Exception("Config error"), + ): + with patch( + "codeflare_sdk.common.utils.k8s_utils._kube_api_error_handling", + return_value=None, + ) as mock_error_handler: + result = get_current_namespace() + + assert result is None + mock_print.assert_called_with("trying to gather from current context") + mock_error_handler.assert_called_once() + + def test_get_current_namespace_kubeconfig_list_contexts_error(self): + """Test handling when list_kube_config_contexts raises an exception.""" + with patch("os.path.isfile", return_value=False): + with patch("builtins.print") as mock_print: + with patch( + "codeflare_sdk.common.utils.k8s_utils.config_check", + return_value="~/.kube/config", + ): + with patch( + "kubernetes.config.list_kube_config_contexts", + side_effect=Exception("Context error"), + ): + with patch( + "codeflare_sdk.common.utils.k8s_utils._kube_api_error_handling", + return_value=None, + ) as mock_error_handler: + result = get_current_namespace() + + assert result is None + mock_print.assert_called_with("trying to gather from current context") + mock_error_handler.assert_called_once() + + def test_get_current_namespace_kubeconfig_key_error(self): + """Test handling when accessing context namespace raises KeyError.""" + mock_contexts = [{"name": "context1", "context": {"namespace": "default"}}] + mock_active_context = {"name": "context1"} # Missing 'context' key + + with patch("os.path.isfile", return_value=False): + with patch("builtins.print") as mock_print: + with patch( + "codeflare_sdk.common.utils.k8s_utils.config_check", + return_value="~/.kube/config", + ): + with patch( + "kubernetes.config.list_kube_config_contexts", + return_value=(mock_contexts, mock_active_context), + ): + result = get_current_namespace() + + assert result is None + mock_print.assert_called_with("trying to gather from current context") + + def test_get_current_namespace_fallback_flow(self): + """Test the complete fallback flow from in-cluster to kubeconfig.""" + # First attempt: in-cluster file doesn't exist + # Second attempt: kubeconfig context has namespace + mock_contexts = [ + {"name": "context1", "context": {"namespace": "fallback-namespace"}} + ] + mock_active_context = { + "name": "context1", + "context": {"namespace": "fallback-namespace"}, + } + + with patch("os.path.isfile", return_value=False): + with patch("builtins.print") as mock_print: + with patch( + "codeflare_sdk.common.utils.k8s_utils.config_check", + return_value="~/.kube/config", + ): + with patch( + "kubernetes.config.list_kube_config_contexts", + return_value=(mock_contexts, mock_active_context), + ): + result = get_current_namespace() + + assert result == "fallback-namespace" + mock_print.assert_called_with("trying to gather from current 
context") + + def test_get_current_namespace_complete_failure(self): + """Test complete failure scenario where no namespace can be detected.""" + with patch("os.path.isfile", return_value=False): + with patch("builtins.print") as mock_print: + with patch( + "codeflare_sdk.common.utils.k8s_utils.config_check", + side_effect=Exception("Config error"), + ): + with patch( + "codeflare_sdk.common.utils.k8s_utils._kube_api_error_handling", + return_value=None, + ): + result = get_current_namespace() + + assert result is None + mock_print.assert_called_with("trying to gather from current context") + + def test_get_current_namespace_mixed_errors(self): + """Test scenario with mixed error conditions.""" + # In-cluster file exists but read fails, then kubeconfig also fails + with patch("os.path.isfile", return_value=True): + with patch("builtins.open", side_effect=IOError("File read error")): + with patch("builtins.print") as mock_print: + with patch( + "codeflare_sdk.common.utils.k8s_utils.config_check", + side_effect=Exception("Config error"), + ): + with patch( + "codeflare_sdk.common.utils.k8s_utils._kube_api_error_handling", + return_value=None, + ): + result = get_current_namespace() + + assert result is None + # Should see both error messages + assert mock_print.call_count >= 2 diff --git a/src/codeflare_sdk/common/widgets/test_widgets.py b/src/codeflare_sdk/common/widgets/test_widgets.py index f88d8eb2..33beca5c 100644 --- a/src/codeflare_sdk/common/widgets/test_widgets.py +++ b/src/codeflare_sdk/common/widgets/test_widgets.py @@ -106,7 +106,7 @@ def test_view_clusters(mocker, capsys): # Prepare to run view_clusters when notebook environment is detected mocker.patch("codeflare_sdk.common.widgets.widgets.is_notebook", return_value=True) mock_get_current_namespace = mocker.patch( - "codeflare_sdk.ray.cluster.cluster.get_current_namespace", + "codeflare_sdk.common.utils.get_current_namespace", return_value="default", ) namespace = mock_get_current_namespace.return_value @@ -250,7 +250,7 @@ def test_ray_cluster_manager_widgets_init(mocker, capsys): return_value=test_ray_clusters_df, ) mocker.patch( - "codeflare_sdk.ray.cluster.cluster.get_current_namespace", + "codeflare_sdk.common.utils.get_current_namespace", return_value=namespace, ) mock_delete_cluster = mocker.patch( diff --git a/src/codeflare_sdk/common/widgets/widgets.py b/src/codeflare_sdk/common/widgets/widgets.py index 36d896e8..91295fa9 100644 --- a/src/codeflare_sdk/common/widgets/widgets.py +++ b/src/codeflare_sdk/common/widgets/widgets.py @@ -26,6 +26,8 @@ import ipywidgets as widgets from IPython.display import display, HTML, Javascript import pandas as pd + +from ...common.utils import get_current_namespace from ...ray.cluster.config import ClusterConfiguration from ...ray.cluster.status import RayClusterStatus from ..kubernetes_cluster import _kube_api_error_handling @@ -43,8 +45,6 @@ class RayClusterManagerWidgets: """ def __init__(self, ray_clusters_df: pd.DataFrame, namespace: str = None): - from ...ray.cluster.cluster import get_current_namespace - # Data self.ray_clusters_df = ray_clusters_df self.namespace = get_current_namespace() if not namespace else namespace @@ -353,7 +353,7 @@ def view_clusters(namespace: str = None): ) return # Exit function if not in Jupyter Notebook - from ...ray.cluster.cluster import get_current_namespace + from ...common.utils import get_current_namespace if not namespace: namespace = get_current_namespace() diff --git a/src/codeflare_sdk/ray/__init__.py b/src/codeflare_sdk/ray/__init__.py 
index b2278a05..806ed9a4 100644 --- a/src/codeflare_sdk/ray/__init__.py +++ b/src/codeflare_sdk/ray/__init__.py @@ -6,6 +6,7 @@ from .rayjobs import ( RayJob, + RayJobClusterConfig, RayJobDeploymentStatus, CodeflareRayJobStatus, RayJobInfo, diff --git a/src/codeflare_sdk/ray/cluster/cluster.py b/src/codeflare_sdk/ray/cluster/cluster.py index 9eaad39e..5c378efd 100644 --- a/src/codeflare_sdk/ray/cluster/cluster.py +++ b/src/codeflare_sdk/ray/cluster/cluster.py @@ -27,6 +27,8 @@ import uuid import warnings +from ...common.utils import get_current_namespace + from ...common.kubernetes_cluster.auth import ( config_check, get_api_client, @@ -638,32 +640,6 @@ def list_all_queued( return resources -def get_current_namespace(): # pragma: no cover - """ - Retrieves the current Kubernetes namespace. - - Returns: - str: - The current namespace or None if not found. - """ - if os.path.isfile("/var/run/secrets/kubernetes.io/serviceaccount/namespace"): - try: - file = open("/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r") - active_context = file.readline().strip("\n") - return active_context - except Exception as e: - print("Unable to find current namespace") - print("trying to gather from current context") - try: - _, active_context = config.list_kube_config_contexts(config_check()) - except Exception as e: - return _kube_api_error_handling(e) - try: - return active_context["context"]["namespace"] - except KeyError: - return None - - def get_cluster( cluster_name: str, namespace: str = "default", diff --git a/src/codeflare_sdk/ray/rayjobs/__init__.py b/src/codeflare_sdk/ray/rayjobs/__init__.py index 47b573af..756fad91 100644 --- a/src/codeflare_sdk/ray/rayjobs/__init__.py +++ b/src/codeflare_sdk/ray/rayjobs/__init__.py @@ -1,2 +1,2 @@ -from .rayjob import RayJob +from .rayjob import RayJob, RayJobClusterConfig from .status import RayJobDeploymentStatus, CodeflareRayJobStatus, RayJobInfo diff --git a/src/codeflare_sdk/ray/rayjobs/config.py b/src/codeflare_sdk/ray/rayjobs/config.py new file mode 100644 index 00000000..24c89a64 --- /dev/null +++ b/src/codeflare_sdk/ray/rayjobs/config.py @@ -0,0 +1,455 @@ +# Copyright 2022 IBM, Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +The config sub-module contains the definition of the RayJobClusterConfigV2 dataclass, +which is used to specify resource requirements and other details when creating a +Cluster object. 
+""" + +import pathlib +from dataclasses import dataclass, field, fields +from typing import Dict, List, Optional, Union, get_args, get_origin, Any +from kubernetes.client import ( + V1ConfigMapVolumeSource, + V1KeyToPath, + V1Toleration, + V1Volume, + V1VolumeMount, + V1ObjectMeta, + V1Container, + V1ContainerPort, + V1Lifecycle, + V1ExecAction, + V1LifecycleHandler, + V1EnvVar, + V1PodTemplateSpec, + V1PodSpec, + V1ResourceRequirements, +) + +import logging +from ...common.utils.constants import CUDA_RUNTIME_IMAGE, RAY_VERSION + +logger = logging.getLogger(__name__) + +dir = pathlib.Path(__file__).parent.parent.resolve() + +# https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html +DEFAULT_ACCELERATORS = { + "nvidia.com/gpu": "GPU", + "intel.com/gpu": "GPU", + "amd.com/gpu": "GPU", + "aws.amazon.com/neuroncore": "neuron_cores", + "google.com/tpu": "TPU", + "habana.ai/gaudi": "HPU", + "huawei.com/Ascend910": "NPU", + "huawei.com/Ascend310": "NPU", +} + +# Default volume mounts for CA certificates +DEFAULT_VOLUME_MOUNTS = [ + V1VolumeMount( + mount_path="/etc/pki/tls/certs/odh-trusted-ca-bundle.crt", + name="odh-trusted-ca-cert", + sub_path="odh-trusted-ca-bundle.crt", + ), + V1VolumeMount( + mount_path="/etc/ssl/certs/odh-trusted-ca-bundle.crt", + name="odh-trusted-ca-cert", + sub_path="odh-trusted-ca-bundle.crt", + ), + V1VolumeMount( + mount_path="/etc/pki/tls/certs/odh-ca-bundle.crt", + name="odh-ca-cert", + sub_path="odh-ca-bundle.crt", + ), + V1VolumeMount( + mount_path="/etc/ssl/certs/odh-ca-bundle.crt", + name="odh-ca-cert", + sub_path="odh-ca-bundle.crt", + ), +] + +# Default volumes for CA certificates +DEFAULT_VOLUMES = [ + V1Volume( + name="odh-trusted-ca-cert", + config_map=V1ConfigMapVolumeSource( + name="odh-trusted-ca-bundle", + items=[V1KeyToPath(key="ca-bundle.crt", path="odh-trusted-ca-bundle.crt")], + optional=True, + ), + ), + V1Volume( + name="odh-ca-cert", + config_map=V1ConfigMapVolumeSource( + name="odh-trusted-ca-bundle", + items=[V1KeyToPath(key="odh-ca-bundle.crt", path="odh-ca-bundle.crt")], + optional=True, + ), + ), +] + + +@dataclass +class RayJobClusterConfig: + """ + This dataclass is used to specify resource requirements and other details for RayJobs. + The cluster name and namespace are automatically derived from the RayJob configuration. + + Args: + head_accelerators: + A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1} + head_tolerations: + List of tolerations for head nodes. + num_workers: + The number of workers to create. + worker_tolerations: + List of tolerations for worker nodes. + envs: + A dictionary of environment variables to set for the cluster. + image: + The image to use for the cluster. + image_pull_secrets: + A list of image pull secrets to use for the cluster. + labels: + A dictionary of labels to apply to the cluster. + worker_accelerators: + A dictionary of extended resource requests for each worker. ex: {"nvidia.com/gpu": 1} + accelerator_configs: + A dictionary of custom resource mappings to map extended resource requests to RayCluster resource names. + Defaults to DEFAULT_ACCELERATORS but can be overridden with custom mappings. + local_queue: + The name of the queue to use for the cluster. + annotations: + A dictionary of annotations to apply to the cluster. 
+ volumes: + A list of V1Volume objects to add to the Cluster + volume_mounts: + A list of V1VolumeMount objects to add to the Cluster + """ + + head_cpu_requests: Union[int, str] = 2 + head_cpu_limits: Union[int, str] = 2 + head_memory_requests: Union[int, str] = 8 + head_memory_limits: Union[int, str] = 8 + head_accelerators: Dict[str, Union[str, int]] = field(default_factory=dict) + head_tolerations: Optional[List[V1Toleration]] = None + worker_cpu_requests: Union[int, str] = 1 + worker_cpu_limits: Union[int, str] = 1 + num_workers: int = 1 + worker_memory_requests: Union[int, str] = 2 + worker_memory_limits: Union[int, str] = 2 + worker_tolerations: Optional[List[V1Toleration]] = None + envs: Dict[str, str] = field(default_factory=dict) + image: str = "" + image_pull_secrets: List[str] = field(default_factory=list) + labels: Dict[str, str] = field(default_factory=dict) + worker_accelerators: Dict[str, Union[str, int]] = field(default_factory=dict) + accelerator_configs: Dict[str, str] = field( + default_factory=lambda: DEFAULT_ACCELERATORS.copy() + ) + local_queue: Optional[str] = None + annotations: Dict[str, str] = field(default_factory=dict) + volumes: list[V1Volume] = field(default_factory=list) + volume_mounts: list[V1VolumeMount] = field(default_factory=list) + + def __post_init__(self): + self._validate_types() + self._memory_to_string() + self._validate_gpu_config(self.head_accelerators) + self._validate_gpu_config(self.worker_accelerators) + + def _validate_gpu_config(self, gpu_config: Dict[str, int]): + for k in gpu_config.keys(): + if k not in self.accelerator_configs.keys(): + raise ValueError( + f"GPU configuration '{k}' not found in accelerator_configs, available resources are {list(self.accelerator_configs.keys())}, to add more supported resources use accelerator_configs. i.e. 
accelerator_configs = {{'{k}': 'FOO_BAR'}}" + ) + + def _memory_to_string(self): + if isinstance(self.head_memory_requests, int): + self.head_memory_requests = f"{self.head_memory_requests}G" + if isinstance(self.head_memory_limits, int): + self.head_memory_limits = f"{self.head_memory_limits}G" + if isinstance(self.worker_memory_requests, int): + self.worker_memory_requests = f"{self.worker_memory_requests}G" + if isinstance(self.worker_memory_limits, int): + self.worker_memory_limits = f"{self.worker_memory_limits}G" + + def _validate_types(self): + """Validate the types of all fields in the RayJobClusterConfig dataclass.""" + errors = [] + for field_info in fields(self): + value = getattr(self, field_info.name) + expected_type = field_info.type + if not self._is_type(value, expected_type): + errors.append(f"'{field_info.name}' should be of type {expected_type}.") + + if errors: + raise TypeError("Type validation failed:\n" + "\n".join(errors)) + + @staticmethod + def _is_type(value, expected_type): + """Check if the value matches the expected type.""" + + def check_type(value, expected_type): + origin_type = get_origin(expected_type) + args = get_args(expected_type) + if origin_type is Union: + return any(check_type(value, union_type) for union_type in args) + if origin_type is list: + if value is not None: + return all(check_type(elem, args[0]) for elem in (value or [])) + else: + return True + if origin_type is dict: + if value is not None: + return all( + check_type(k, args[0]) and check_type(v, args[1]) + for k, v in value.items() + ) + else: + return True + if origin_type is tuple: + return all(check_type(elem, etype) for elem, etype in zip(value, args)) + if expected_type is int: + return isinstance(value, int) and not isinstance(value, bool) + if expected_type is bool: + return isinstance(value, bool) + return isinstance(value, expected_type) + + return check_type(value, expected_type) + + def build_ray_cluster_spec(self, cluster_name: str) -> Dict[str, Any]: + """ + Build the RayCluster spec from RayJobClusterConfig for embedding in RayJob. 
+ + Args: + self: The cluster configuration object (RayJobClusterConfig) + cluster_name: The name for the cluster (derived from RayJob name) + + Returns: + Dict containing the RayCluster spec for embedding in RayJob + """ + ray_cluster_spec = { + "rayVersion": RAY_VERSION, + "enableInTreeAutoscaling": False, + "headGroupSpec": self._build_head_group_spec(), + "workerGroupSpecs": [self._build_worker_group_spec(cluster_name)], + } + + return ray_cluster_spec + + def _build_head_group_spec(self) -> Dict[str, Any]: + """Build the head group specification.""" + return { + "serviceType": "ClusterIP", + "enableIngress": False, + "rayStartParams": self._build_head_ray_params(), + "template": V1PodTemplateSpec( + metadata=V1ObjectMeta(annotations=self.annotations), + spec=self._build_pod_spec(self._build_head_container(), is_head=True), + ), + } + + def _build_worker_group_spec(self, cluster_name: str) -> Dict[str, Any]: + """Build the worker group specification.""" + return { + "replicas": self.num_workers, + "minReplicas": self.num_workers, + "maxReplicas": self.num_workers, + "groupName": f"worker-group-{cluster_name}", + "rayStartParams": self._build_worker_ray_params(), + "template": V1PodTemplateSpec( + metadata=V1ObjectMeta(annotations=self.annotations), + spec=self._build_pod_spec( + self._build_worker_container(), + is_head=False, + ), + ), + } + + def _build_head_ray_params(self) -> Dict[str, str]: + """Build Ray start parameters for head node.""" + params = { + "dashboard-host": "0.0.0.0", + "dashboard-port": "8265", + "block": "true", + } + + # Add GPU count if specified + if self.head_accelerators: + gpu_count = sum( + count + for resource_type, count in self.head_accelerators.items() + if "gpu" in resource_type.lower() + ) + if gpu_count > 0: + params["num-gpus"] = str(gpu_count) + + return params + + def _build_worker_ray_params(self) -> Dict[str, str]: + """Build Ray start parameters for worker nodes.""" + params = { + "block": "true", + } + + # Add GPU count if specified + if self.worker_accelerators: + gpu_count = sum( + count + for resource_type, count in self.worker_accelerators.items() + if "gpu" in resource_type.lower() + ) + if gpu_count > 0: + params["num-gpus"] = str(gpu_count) + + return params + + def _build_head_container(self) -> V1Container: + """Build the head container specification.""" + container = V1Container( + name="ray-head", + image=self.image or CUDA_RUNTIME_IMAGE, + image_pull_policy="IfNotPresent", # Always IfNotPresent for RayJobs + ports=[ + V1ContainerPort(name="gcs", container_port=6379), + V1ContainerPort(name="dashboard", container_port=8265), + V1ContainerPort(name="client", container_port=10001), + ], + lifecycle=V1Lifecycle( + pre_stop=V1LifecycleHandler( + _exec=V1ExecAction(command=["/bin/sh", "-c", "ray stop"]) + ) + ), + resources=self._build_resource_requirements( + self.head_cpu_requests, + self.head_cpu_limits, + self.head_memory_requests, + self.head_memory_limits, + self.head_accelerators, + ), + volume_mounts=self._generate_volume_mounts(), + ) + + # Add environment variables if specified + if hasattr(self, "envs") and self.envs: + container.env = self._build_env_vars() + + return container + + def _build_worker_container(self) -> V1Container: + """Build the worker container specification.""" + container = V1Container( + name="ray-worker", + image=self.image or CUDA_RUNTIME_IMAGE, + image_pull_policy="IfNotPresent", # Always IfNotPresent for RayJobs + lifecycle=V1Lifecycle( + pre_stop=V1LifecycleHandler( + 
_exec=V1ExecAction(command=["/bin/sh", "-c", "ray stop"]) + ) + ), + resources=self._build_resource_requirements( + self.worker_cpu_requests, + self.worker_cpu_limits, + self.worker_memory_requests, + self.worker_memory_limits, + self.worker_accelerators, + ), + volume_mounts=self._generate_volume_mounts(), + ) + + # Add environment variables if specified + if hasattr(self, "envs") and self.envs: + container.env = self._build_env_vars() + + return container + + def _build_resource_requirements( + self, + cpu_requests: Union[int, str], + cpu_limits: Union[int, str], + memory_requests: Union[int, str], + memory_limits: Union[int, str], + extended_resource_requests: Dict[str, Union[int, str]] = None, + ) -> V1ResourceRequirements: + """Build Kubernetes resource requirements.""" + resource_requirements = V1ResourceRequirements( + requests={"cpu": cpu_requests, "memory": memory_requests}, + limits={"cpu": cpu_limits, "memory": memory_limits}, + ) + + # Add extended resources (e.g., GPUs) + if extended_resource_requests: + for resource_type, amount in extended_resource_requests.items(): + resource_requirements.limits[resource_type] = amount + resource_requirements.requests[resource_type] = amount + + return resource_requirements + + def _build_pod_spec(self, container: V1Container, is_head: bool) -> V1PodSpec: + """Build the pod specification.""" + pod_spec = V1PodSpec( + containers=[container], + volumes=self._generate_volumes(), + restart_policy="Never", # RayJobs should not restart + ) + + # Add tolerations if specified + if is_head and hasattr(self, "head_tolerations") and self.head_tolerations: + pod_spec.tolerations = self.head_tolerations + elif ( + not is_head + and hasattr(self, "worker_tolerations") + and self.worker_tolerations + ): + pod_spec.tolerations = self.worker_tolerations + + # Add image pull secrets if specified + if hasattr(self, "image_pull_secrets") and self.image_pull_secrets: + from kubernetes.client import V1LocalObjectReference + + pod_spec.image_pull_secrets = [ + V1LocalObjectReference(name=secret) + for secret in self.image_pull_secrets + ] + + return pod_spec + + def _generate_volume_mounts(self) -> list: + """Generate volume mounts for the container.""" + volume_mounts = DEFAULT_VOLUME_MOUNTS.copy() + + # Add custom volume mounts if specified + if hasattr(self, "volume_mounts") and self.volume_mounts: + volume_mounts.extend(self.volume_mounts) + + return volume_mounts + + def _generate_volumes(self) -> list: + """Generate volumes for the pod.""" + volumes = DEFAULT_VOLUMES.copy() + + # Add custom volumes if specified + if hasattr(self, "volumes") and self.volumes: + volumes.extend(self.volumes) + + return volumes + + def _build_env_vars(self) -> list: + """Build environment variables list.""" + return [V1EnvVar(name=key, value=value) for key, value in self.envs.items()] diff --git a/src/codeflare_sdk/ray/rayjobs/rayjob.py b/src/codeflare_sdk/ray/rayjobs/rayjob.py index 0d966b0e..ab0899d2 100644 --- a/src/codeflare_sdk/ray/rayjobs/rayjob.py +++ b/src/codeflare_sdk/ray/rayjobs/rayjob.py @@ -13,16 +13,16 @@ # limitations under the License. """ -RayJob client for submitting and managing Ray jobs using the odh-kuberay-client. +RayJob client for submitting and managing Ray jobs using the kuberay python client. 
""" import logging from typing import Dict, Any, Optional, Tuple -from odh_kuberay_client.kuberay_job_api import RayjobApi +from python_client.kuberay_job_api import RayjobApi -from ..cluster.cluster import Cluster -from ..cluster.config import ClusterConfiguration -from ..cluster.build_ray_cluster import build_ray_cluster +from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig + +from ...common.utils import get_current_namespace from .status import ( RayJobDeploymentStatus, @@ -46,12 +46,12 @@ class RayJob: def __init__( self, job_name: str, + entrypoint: str, cluster_name: Optional[str] = None, - cluster_config: Optional[ClusterConfiguration] = None, - namespace: str = "default", - entrypoint: Optional[str] = None, + cluster_config: Optional[RayJobClusterConfig] = None, + namespace: Optional[str] = None, runtime_env: Optional[Dict[str, Any]] = None, - shutdown_after_job_finishes: bool = True, + shutdown_after_job_finishes: Optional[bool] = None, ttl_seconds_after_finished: int = 0, active_deadline_seconds: Optional[int] = None, ): @@ -60,50 +60,85 @@ def __init__( Args: job_name: The name for the Ray job + entrypoint: The Python script or command to run (required) cluster_name: The name of an existing Ray cluster (optional if cluster_config provided) cluster_config: Configuration for creating a new cluster (optional if cluster_name provided) - namespace: The Kubernetes namespace (default: "default") - entrypoint: The Python script or command to run (required for submission) + namespace: The Kubernetes namespace (auto-detected if not specified) runtime_env: Ray runtime environment configuration (optional) - shutdown_after_job_finishes: Whether to automatically cleanup the cluster after job completion (default: True) + shutdown_after_job_finishes: Whether to shut down cluster after job finishes (optional) ttl_seconds_after_finished: Seconds to wait before cleanup after job finishes (default: 0) active_deadline_seconds: Maximum time the job can run before being terminated (optional) + + Note: + shutdown_after_job_finishes is automatically detected but can be overridden: + - True if cluster_config is provided (new cluster will be cleaned up) + - False if cluster_name is provided (existing cluster will not be shut down) + - User can explicitly set this value to override auto-detection """ - # Validate input parameters if cluster_name is None and cluster_config is None: - raise ValueError("Either cluster_name or cluster_config must be provided") + raise ValueError( + "❌ Configuration Error: You must provide either 'cluster_name' (for existing cluster) " + "or 'cluster_config' (to create new cluster), but not both." + ) if cluster_name is not None and cluster_config is not None: - raise ValueError("Cannot specify both cluster_name and cluster_config") + raise ValueError( + "❌ Configuration Error: You cannot specify both 'cluster_name' and 'cluster_config'. " + "Choose one approach:\n" + "• Use 'cluster_name' to connect to an existing cluster\n" + "• Use 'cluster_config' to create a new cluster" + ) + + if cluster_config is None and cluster_name is None: + raise ValueError( + "❌ Configuration Error: When not providing 'cluster_config', 'cluster_name' is required " + "to specify which existing cluster to use." 
+ ) self.name = job_name - self.namespace = namespace self.entrypoint = entrypoint self.runtime_env = runtime_env - self.shutdown_after_job_finishes = shutdown_after_job_finishes self.ttl_seconds_after_finished = ttl_seconds_after_finished self.active_deadline_seconds = active_deadline_seconds - # Cluster configuration + # Auto-set shutdown_after_job_finishes based on cluster_config presence + # If cluster_config is provided, we want to clean up the cluster after job finishes + # If using existing cluster, we don't want to shut it down + # User can override this behavior by explicitly setting shutdown_after_job_finishes + if shutdown_after_job_finishes is not None: + self.shutdown_after_job_finishes = shutdown_after_job_finishes + elif cluster_config is not None: + self.shutdown_after_job_finishes = True + else: + self.shutdown_after_job_finishes = False + + if namespace is None: + detected_namespace = get_current_namespace() + if detected_namespace: + self.namespace = detected_namespace + logger.info(f"Auto-detected namespace: {self.namespace}") + else: + raise ValueError( + "❌ Configuration Error: Could not auto-detect Kubernetes namespace. " + "Please explicitly specify the 'namespace' parameter. " + ) + else: + self.namespace = namespace + self._cluster_name = cluster_name self._cluster_config = cluster_config - # Determine cluster name for the job if cluster_config is not None: - # Ensure cluster config has the same namespace as the job - if cluster_config.namespace is None: - cluster_config.namespace = namespace - elif cluster_config.namespace != namespace: - logger.warning( - f"Cluster config namespace ({cluster_config.namespace}) differs from job namespace ({namespace})" - ) - - self.cluster_name = cluster_config.name or f"{job_name}-cluster" - # Update the cluster config name if it wasn't set - if not cluster_config.name: - cluster_config.name = self.cluster_name + self.cluster_name = f"{job_name}-cluster" + logger.info(f"Creating new cluster: {self.cluster_name}") else: + # Using existing cluster: cluster_name must be provided + if cluster_name is None: + raise ValueError( + "❌ Configuration Error: a 'cluster_name' is required when not providing 'cluster_config'" + ) self.cluster_name = cluster_name + logger.info(f"Using existing cluster: {self.cluster_name}") # Initialize the KubeRay job API client self._api = RayjobApi() @@ -111,21 +146,6 @@ def __init__( logger.info(f"Initialized RayJob: {self.name} in namespace: {self.namespace}") def submit(self) -> str: - """ - Submit the Ray job to the Kubernetes cluster. 
- - The RayJob CRD will automatically: - - Create a new cluster if cluster_config was provided - - Use existing cluster if cluster_name was provided - - Clean up resources based on shutdown_after_job_finishes setting - - Returns: - The job ID/name if submission was successful - - Raises: - ValueError: If entrypoint is not provided - RuntimeError: If job submission fails - """ # Validate required parameters if not self.entrypoint: raise ValueError("entrypoint must be provided to submit a RayJob") @@ -176,9 +196,16 @@ def _build_rayjob_cr(self) -> Dict[str, Any]: # Configure cluster: either use existing or create new if self._cluster_config is not None: - # Use rayClusterSpec to create a new cluster - leverage existing build logic - ray_cluster_spec = self._build_ray_cluster_spec() + ray_cluster_spec = self._cluster_config.build_ray_cluster_spec( + cluster_name=self.cluster_name + ) + + logger.info( + f"Built RayCluster spec using RayJob-specific builder for cluster: {self.cluster_name}" + ) + rayjob_cr["spec"]["rayClusterSpec"] = ray_cluster_spec + logger.info(f"RayJob will create new cluster: {self.cluster_name}") else: # Use clusterSelector to reference existing cluster @@ -187,48 +214,6 @@ def _build_rayjob_cr(self) -> Dict[str, Any]: return rayjob_cr - def _build_ray_cluster_spec(self) -> Dict[str, Any]: - """ - Build the RayCluster spec from ClusterConfiguration using existing build_ray_cluster logic. - - Returns: - Dict containing the RayCluster spec for embedding in RayJob - """ - if not self._cluster_config: - raise RuntimeError("No cluster configuration provided") - - # Create a shallow copy of the cluster config to avoid modifying the original - import copy - - temp_config = copy.copy(self._cluster_config) - - # Ensure we get a RayCluster (not AppWrapper) and don't write to file - temp_config.appwrapper = False - temp_config.write_to_file = False - - # Create a minimal Cluster object for the build process - from ..cluster.cluster import Cluster - - temp_cluster = Cluster.__new__(Cluster) # Create without calling __init__ - temp_cluster.config = temp_config - - """ - For now, RayJob with a new/auto-created cluster will not work with Kueue. - This is due to the Kueue label not being propagated to the RayCluster. - """ - - # Use the existing build_ray_cluster function to generate the RayCluster - ray_cluster_dict = build_ray_cluster(temp_cluster) - - # Extract just the RayCluster spec - RayJob CRD doesn't support metadata in rayClusterSpec - # Note: CodeFlare Operator should still create dashboard routes for the RayCluster - ray_cluster_spec = ray_cluster_dict["spec"] - - logger.info( - f"Built RayCluster spec using existing build logic for cluster: {self.cluster_name}" - ) - return ray_cluster_spec - def status( self, print_to_console: bool = True ) -> Tuple[CodeflareRayJobStatus, bool]: diff --git a/src/codeflare_sdk/ray/rayjobs/test_config.py b/src/codeflare_sdk/ray/rayjobs/test_config.py new file mode 100644 index 00000000..cefe9606 --- /dev/null +++ b/src/codeflare_sdk/ray/rayjobs/test_config.py @@ -0,0 +1,82 @@ +""" +Tests for the simplified RayJobClusterConfig accelerator_configs behavior. 
+""" + +import pytest +from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig, DEFAULT_ACCELERATORS + + +def test_accelerator_configs_defaults_to_default_accelerators(): + """Test that accelerator_configs defaults to DEFAULT_ACCELERATORS.copy()""" + config = RayJobClusterConfig() + + # Should have all the default accelerators + assert "nvidia.com/gpu" in config.accelerator_configs + assert "intel.com/gpu" in config.accelerator_configs + assert "google.com/tpu" in config.accelerator_configs + + # Should be a copy, not the same object + assert config.accelerator_configs is not DEFAULT_ACCELERATORS + assert config.accelerator_configs == DEFAULT_ACCELERATORS + + +def test_accelerator_configs_can_be_overridden(): + """Test that users can override accelerator_configs with custom mappings""" + custom_configs = { + "nvidia.com/gpu": "GPU", + "custom.com/accelerator": "CUSTOM_ACCELERATOR", + } + + config = RayJobClusterConfig(accelerator_configs=custom_configs) + + # Should have custom configs + assert config.accelerator_configs == custom_configs + assert "custom.com/accelerator" in config.accelerator_configs + assert "nvidia.com/gpu" in config.accelerator_configs + + # Should NOT have other defaults + assert "intel.com/gpu" not in config.accelerator_configs + assert "google.com/tpu" not in config.accelerator_configs + + +def test_accelerator_configs_can_extend_defaults(): + """Test that users can extend defaults by providing additional configs""" + extended_configs = { + **DEFAULT_ACCELERATORS, + "custom.com/accelerator": "CUSTOM_ACCEL", + } + + config = RayJobClusterConfig(accelerator_configs=extended_configs) + + # Should have all defaults plus custom + assert "nvidia.com/gpu" in config.accelerator_configs + assert "intel.com/gpu" in config.accelerator_configs + assert "custom.com/accelerator" in config.accelerator_configs + assert config.accelerator_configs["custom.com/accelerator"] == "CUSTOM_ACCEL" + + +def test_gpu_validation_works_with_defaults(): + """Test that GPU validation works with default accelerator configs""" + config = RayJobClusterConfig(head_accelerators={"nvidia.com/gpu": 1}) + + # Should not raise any errors + assert config.head_accelerators == {"nvidia.com/gpu": 1} + + +def test_gpu_validation_works_with_custom_configs(): + """Test that GPU validation works with custom accelerator configs""" + config = RayJobClusterConfig( + accelerator_configs={"custom.com/accelerator": "CUSTOM_ACCEL"}, + head_accelerators={"custom.com/accelerator": 1}, + ) + + # Should not raise any errors + assert config.head_accelerators == {"custom.com/accelerator": 1} + + +def test_gpu_validation_fails_with_unsupported_accelerator(): + """Test that GPU validation fails with unsupported accelerators""" + with pytest.raises( + ValueError, match="GPU configuration 'unsupported.com/accelerator' not found" + ): + RayJobClusterConfig(head_accelerators={"unsupported.com/accelerator": 1}) diff --git a/src/codeflare_sdk/ray/rayjobs/test_rayjob.py b/src/codeflare_sdk/ray/rayjobs/test_rayjob.py index 7554ca4c..970f0159 100644 --- a/src/codeflare_sdk/ray/rayjobs/test_rayjob.py +++ b/src/codeflare_sdk/ray/rayjobs/test_rayjob.py @@ -14,6 +14,8 @@ import pytest from unittest.mock import MagicMock, patch +from codeflare_sdk.common.utils.constants import CUDA_RUNTIME_IMAGE, RAY_VERSION + from codeflare_sdk.ray.rayjobs.rayjob import RayJob from codeflare_sdk.ray.cluster.config import ClusterConfiguration @@ -100,7 +102,8 @@ def test_rayjob_init_validation_both_provided(mocker): cluster_config = 
ClusterConfiguration(name="test-cluster", namespace="test") with pytest.raises( - ValueError, match="Cannot specify both cluster_name and cluster_config" + ValueError, + match="❌ Configuration Error: You cannot specify both 'cluster_name' and 'cluster_config'", ): RayJob( job_name="test-job", @@ -119,7 +122,8 @@ def test_rayjob_init_validation_neither_provided(mocker): mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") with pytest.raises( - ValueError, match="Either cluster_name or cluster_config must be provided" + ValueError, + match="❌ Configuration Error: You must provide either 'cluster_name'", ): RayJob(job_name="test-job", entrypoint="python script.py") @@ -140,10 +144,11 @@ def test_rayjob_init_with_cluster_config(mocker): job_name="test-job", cluster_config=cluster_config, entrypoint="python script.py", + namespace="test-namespace", ) assert rayjob.name == "test-job" - assert rayjob.cluster_name == "auto-cluster" + assert rayjob.cluster_name == "test-job-cluster" # Generated from job name assert rayjob._cluster_config == cluster_config assert rayjob._cluster_name is None @@ -163,11 +168,13 @@ def test_rayjob_cluster_name_generation(mocker): ) rayjob = RayJob( - job_name="my-job", cluster_config=cluster_config, entrypoint="python script.py" + job_name="my-job", + cluster_config=cluster_config, + entrypoint="python script.py", + namespace="test-namespace", ) assert rayjob.cluster_name == "my-job-cluster" - assert cluster_config.name == "my-job-cluster" # Should be updated def test_rayjob_cluster_config_namespace_none(mocker): @@ -191,7 +198,6 @@ def test_rayjob_cluster_config_namespace_none(mocker): entrypoint="python script.py", ) - assert cluster_config.namespace == "job-namespace" assert rayjob.namespace == "job-namespace" @@ -229,15 +235,20 @@ def test_build_ray_cluster_spec_no_config_error(mocker): job_name="test-job", cluster_name="existing-cluster", entrypoint="python script.py", + namespace="test-namespace", ) - # Line 198: Should raise RuntimeError when trying to build spec without config - with pytest.raises(RuntimeError, match="No cluster configuration provided"): - rayjob._build_ray_cluster_spec() + # Since we removed _build_ray_cluster_spec method, this test is no longer applicable + # The method is now called internally by _build_rayjob_cr when needed + # We can test this by calling _build_rayjob_cr instead + rayjob_cr = rayjob._build_rayjob_cr() + + # Should use clusterSelector for existing cluster + assert rayjob_cr["spec"]["clusterSelector"]["ray.io/cluster"] == "existing-cluster" + assert "rayClusterSpec" not in rayjob_cr["spec"] -@patch("codeflare_sdk.ray.rayjobs.rayjob.build_ray_cluster") -def test_build_ray_cluster_spec(mock_build_ray_cluster, mocker): +def test_build_ray_cluster_spec(mocker): """Test _build_ray_cluster_spec method.""" mocker.patch("kubernetes.config.load_kube_config") @@ -249,34 +260,38 @@ def test_build_ray_cluster_spec(mock_build_ray_cluster, mocker): "kind": "RayCluster", "metadata": {"name": "test-cluster", "namespace": "test"}, "spec": { - "rayVersion": "2.9.0", + "rayVersion": RAY_VERSION, "headGroupSpec": {"replicas": 1}, "workerGroupSpecs": [{"replicas": 2}], }, } - mock_build_ray_cluster.return_value = mock_ray_cluster + # Use RayJobClusterConfig which has the build_ray_cluster_spec method + from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig - cluster_config = ClusterConfiguration( - name="test-cluster", namespace="test", num_workers=2 + cluster_config = RayJobClusterConfig(num_workers=2) + + # Mock the method 
that will be called + mocker.patch.object( + cluster_config, "build_ray_cluster_spec", return_value=mock_ray_cluster["spec"] ) rayjob = RayJob( job_name="test-job", cluster_config=cluster_config, entrypoint="python script.py", + namespace="test-namespace", ) - spec = rayjob._build_ray_cluster_spec() + # Test the integration through _build_rayjob_cr + rayjob_cr = rayjob._build_rayjob_cr() - # Should return only the spec part, not metadata - assert spec == mock_ray_cluster["spec"] - assert "metadata" not in spec + # Should have rayClusterSpec + assert "rayClusterSpec" in rayjob_cr["spec"] - # Verify build_ray_cluster was called with correct parameters - mock_build_ray_cluster.assert_called_once() - call_args = mock_build_ray_cluster.call_args[0][0] - assert call_args.config.appwrapper is False - assert call_args.config.write_to_file is False + # Verify build_ray_cluster_spec was called on the cluster config + cluster_config.build_ray_cluster_spec.assert_called_once_with( + cluster_name="test-job-cluster" + ) def test_build_rayjob_cr_with_existing_cluster(mocker): @@ -291,7 +306,6 @@ def test_build_rayjob_cr_with_existing_cluster(mocker): cluster_name="existing-cluster", namespace="test-namespace", entrypoint="python main.py", - shutdown_after_job_finishes=False, ttl_seconds_after_finished=300, ) @@ -305,6 +319,7 @@ def test_build_rayjob_cr_with_existing_cluster(mocker): # Check lifecycle parameters spec = rayjob_cr["spec"] assert spec["entrypoint"] == "python main.py" + # shutdownAfterJobFinishes should be False when using existing cluster (auto-set) assert spec["shutdownAfterJobFinishes"] is False assert spec["ttlSecondsAfterFinished"] == 300 @@ -313,8 +328,7 @@ def test_build_rayjob_cr_with_existing_cluster(mocker): assert "rayClusterSpec" not in spec -@patch("codeflare_sdk.ray.rayjobs.rayjob.build_ray_cluster") -def test_build_rayjob_cr_with_auto_cluster(mock_build_ray_cluster, mocker): +def test_build_rayjob_cr_with_auto_cluster(mocker): """Test _build_rayjob_cr method with auto-created cluster.""" mocker.patch("kubernetes.config.load_kube_config") @@ -326,19 +340,26 @@ def test_build_rayjob_cr_with_auto_cluster(mock_build_ray_cluster, mocker): "kind": "RayCluster", "metadata": {"name": "auto-cluster", "namespace": "test"}, "spec": { - "rayVersion": "2.9.0", + "rayVersion": RAY_VERSION, "headGroupSpec": {"replicas": 1}, "workerGroupSpecs": [{"replicas": 2}], }, } - mock_build_ray_cluster.return_value = mock_ray_cluster + # Use RayJobClusterConfig and mock its build_ray_cluster_spec method + from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig - cluster_config = ClusterConfiguration( - name="auto-cluster", namespace="test-namespace", num_workers=2 + cluster_config = RayJobClusterConfig(num_workers=2) + + # Mock the method that will be called + mocker.patch.object( + cluster_config, "build_ray_cluster_spec", return_value=mock_ray_cluster["spec"] ) rayjob = RayJob( - job_name="test-job", cluster_config=cluster_config, entrypoint="python main.py" + job_name="test-job", + cluster_config=cluster_config, + entrypoint="python main.py", + namespace="test-namespace", ) rayjob_cr = rayjob._build_rayjob_cr() @@ -357,6 +378,7 @@ def test_submit_validation_no_entrypoint(mocker): job_name="test-job", cluster_name="test-cluster", entrypoint=None, # No entrypoint provided + namespace="test-namespace", ) with pytest.raises( @@ -365,8 +387,7 @@ def test_submit_validation_no_entrypoint(mocker): rayjob.submit() -@patch("codeflare_sdk.ray.rayjobs.rayjob.build_ray_cluster") -def 
test_submit_with_auto_cluster(mock_build_ray_cluster, mocker): +def test_submit_with_auto_cluster(mocker): """Test successful submission with auto-created cluster.""" mocker.patch("kubernetes.config.load_kube_config") @@ -374,27 +395,32 @@ def test_submit_with_auto_cluster(mock_build_ray_cluster, mocker): "apiVersion": "ray.io/v1", "kind": "RayCluster", "spec": { - "rayVersion": "2.9.0", + "rayVersion": RAY_VERSION, "headGroupSpec": {"replicas": 1}, "workerGroupSpecs": [{"replicas": 1}], }, } - mock_build_ray_cluster.return_value = mock_ray_cluster - # Mock the RayjobApi mock_api_class = mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") mock_api_instance = MagicMock() mock_api_class.return_value = mock_api_instance mock_api_instance.submit_job.return_value = True - cluster_config = ClusterConfiguration( - name="auto-cluster", namespace="test", num_workers=1 + # Use RayJobClusterConfig and mock its build_ray_cluster_spec method + from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig + + cluster_config = RayJobClusterConfig(num_workers=1) + + # Mock the method that will be called + mocker.patch.object( + cluster_config, "build_ray_cluster_spec", return_value=mock_ray_cluster["spec"] ) rayjob = RayJob( job_name="test-job", cluster_config=cluster_config, entrypoint="python script.py", + namespace="test-namespace", ) result = rayjob.submit() @@ -408,3 +434,540 @@ def test_submit_with_auto_cluster(mock_build_ray_cluster, mocker): job_cr = call_args.kwargs["job"] assert "rayClusterSpec" in job_cr["spec"] assert job_cr["spec"]["rayClusterSpec"] == mock_ray_cluster["spec"] + + +def test_namespace_auto_detection_success(mocker): + """Test successful namespace auto-detection.""" + mocker.patch( + "codeflare_sdk.ray.rayjobs.rayjob.get_current_namespace", + return_value="detected-ns", + ) + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + rayjob = RayJob( + job_name="test-job", entrypoint="python script.py", cluster_name="test-cluster" + ) + + assert rayjob.namespace == "detected-ns" + + +def test_namespace_auto_detection_fallback(mocker): + """Test that namespace auto-detection failure raises an error.""" + mocker.patch( + "codeflare_sdk.ray.rayjobs.rayjob.get_current_namespace", return_value=None + ) + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + with pytest.raises(ValueError, match="Could not auto-detect Kubernetes namespace"): + RayJob( + job_name="test-job", + entrypoint="python script.py", + cluster_name="test-cluster", + ) + + +def test_namespace_explicit_override(mocker): + """Test that explicit namespace overrides auto-detection.""" + mocker.patch( + "codeflare_sdk.ray.rayjobs.rayjob.get_current_namespace", + return_value="detected-ns", + ) + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + rayjob = RayJob( + job_name="test-job", + entrypoint="python script.py", + cluster_name="test-cluster", + namespace="explicit-ns", + ) + + assert rayjob.namespace == "explicit-ns" + + +def test_shutdown_behavior_with_cluster_config(mocker): + """Test that shutdown_after_job_finishes is True when cluster_config is provided.""" + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig + + cluster_config = RayJobClusterConfig() + + rayjob = RayJob( + job_name="test-job", + entrypoint="python script.py", + cluster_config=cluster_config, + namespace="test-namespace", + ) + + assert rayjob.shutdown_after_job_finishes is True + + +def 
test_shutdown_behavior_with_existing_cluster(mocker): + """Test that shutdown_after_job_finishes is False when using existing cluster.""" + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + rayjob = RayJob( + job_name="test-job", + entrypoint="python script.py", + cluster_name="existing-cluster", + namespace="test-namespace", + ) + + assert rayjob.shutdown_after_job_finishes is False + + +def test_rayjob_with_rayjob_cluster_config(mocker): + """Test RayJob with the new RayJobClusterConfig.""" + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig + + cluster_config = RayJobClusterConfig( + num_workers=2, + head_cpu_requests="500m", + head_memory_requests="512Mi", + ) + + rayjob = RayJob( + job_name="test-job", + entrypoint="python script.py", + cluster_config=cluster_config, + namespace="test-namespace", + ) + + assert rayjob._cluster_config == cluster_config + assert rayjob.cluster_name == "test-job-cluster" # Generated from job name + + +def test_rayjob_cluster_config_validation(mocker): + """Test validation of RayJobClusterConfig parameters.""" + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig + + # Test with minimal valid config + cluster_config = RayJobClusterConfig() + + rayjob = RayJob( + job_name="test-job", + entrypoint="python script.py", + cluster_config=cluster_config, + namespace="test-namespace", + ) + + assert rayjob._cluster_config is not None + + +def test_rayjob_missing_entrypoint_validation(mocker): + """Test that RayJob requires entrypoint for submission.""" + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + # Should raise an error during construction + with pytest.raises( + TypeError, match="missing 1 required positional argument: 'entrypoint'" + ): + RayJob( + job_name="test-job", + cluster_name="test-cluster", + # No entrypoint provided + ) + + +def test_build_ray_cluster_spec_integration(mocker): + """Test integration with the new build_ray_cluster_spec method.""" + # Mock kubernetes config loading + mocker.patch("kubernetes.config.load_kube_config") + + # Mock the RayjobApi class entirely + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig + + cluster_config = RayJobClusterConfig() + + # Mock the build_ray_cluster_spec method on the cluster config + mock_spec = {"spec": "test-spec"} + mocker.patch.object( + cluster_config, "build_ray_cluster_spec", return_value=mock_spec + ) + + rayjob = RayJob( + job_name="test-job", + entrypoint="python script.py", + cluster_config=cluster_config, + namespace="test-namespace", + ) + + # Build the RayJob CR + rayjob_cr = rayjob._build_rayjob_cr() + + # Verify the method was called correctly + cluster_config.build_ray_cluster_spec.assert_called_once_with( + cluster_name="test-job-cluster" + ) + + # Verify the spec is included in the RayJob CR + assert "rayClusterSpec" in rayjob_cr["spec"] + assert rayjob_cr["spec"]["rayClusterSpec"] == mock_spec + + +def test_rayjob_with_runtime_env(mocker): + """Test RayJob with runtime environment configuration.""" + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + runtime_env = {"pip": ["numpy", "pandas"]} + + rayjob = RayJob( + job_name="test-job", + entrypoint="python script.py", + cluster_name="test-cluster", + runtime_env=runtime_env, + namespace="test-namespace", + ) + + assert rayjob.runtime_env == runtime_env + + # 
Verify runtime env is included in the CR + rayjob_cr = rayjob._build_rayjob_cr() + assert rayjob_cr["spec"]["runtimeEnvYAML"] == str(runtime_env) + + +def test_rayjob_with_active_deadline_and_ttl(mocker): + """Test RayJob with both active deadline and TTL settings.""" + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + rayjob = RayJob( + job_name="test-job", + entrypoint="python script.py", + cluster_name="test-cluster", + active_deadline_seconds=300, + ttl_seconds_after_finished=600, + namespace="test-namespace", + ) + + assert rayjob.active_deadline_seconds == 300 + assert rayjob.ttl_seconds_after_finished == 600 + + # Verify both are included in the CR + rayjob_cr = rayjob._build_rayjob_cr() + assert rayjob_cr["spec"]["activeDeadlineSeconds"] == 300 + assert rayjob_cr["spec"]["ttlSecondsAfterFinished"] == 600 + + +def test_rayjob_cluster_name_generation_with_config(mocker): + """Test cluster name generation when using cluster_config.""" + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig + + cluster_config = RayJobClusterConfig() + + rayjob = RayJob( + job_name="my-job", + entrypoint="python script.py", + cluster_config=cluster_config, + namespace="test-namespace", # Explicitly specify namespace + ) + + assert rayjob.cluster_name == "my-job-cluster" + # Note: cluster_config.name is not set in RayJob (it's only for resource config) + # The cluster name is generated independently for the RayJob + + +def test_rayjob_namespace_propagation_to_cluster_config(mocker): + """Test that job namespace is propagated to cluster config when None.""" + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + from codeflare_sdk.ray.rayjobs.rayjob import get_current_namespace + + mocker.patch( + "codeflare_sdk.ray.rayjobs.rayjob.get_current_namespace", + return_value="detected-ns", + ) + + from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig + + cluster_config = RayJobClusterConfig() + + rayjob = RayJob( + job_name="test-job", + entrypoint="python script.py", + cluster_config=cluster_config, + ) + + assert rayjob.namespace == "detected-ns" + + +def test_rayjob_error_handling_invalid_cluster_config(mocker): + """Test error handling with invalid cluster configuration.""" + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + with pytest.raises(ValueError): + RayJob( + job_name="test-job", + entrypoint="python script.py", + ) + + +def test_rayjob_constructor_parameter_validation(mocker): + """Test constructor parameter validation.""" + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + # Test with valid parameters + rayjob = RayJob( + job_name="test-job", + entrypoint="python script.py", + cluster_name="test-cluster", + namespace="test-ns", + runtime_env={"pip": ["numpy"]}, + ttl_seconds_after_finished=300, + active_deadline_seconds=600, + ) + + assert rayjob.name == "test-job" + assert rayjob.entrypoint == "python script.py" + assert rayjob.cluster_name == "test-cluster" + assert rayjob.namespace == "test-ns" + assert rayjob.runtime_env == {"pip": ["numpy"]} + assert rayjob.ttl_seconds_after_finished == 300 + assert rayjob.active_deadline_seconds == 600 + + +def test_build_ray_cluster_spec_function(mocker): + """Test the build_ray_cluster_spec method directly.""" + from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig + + # Create a test cluster config + cluster_config = RayJobClusterConfig( + num_workers=2, + head_cpu_requests="500m", + head_memory_requests="512Mi", 
+ worker_cpu_requests="250m", + worker_memory_requests="256Mi", + ) + + # Build the spec using the method on the cluster config + spec = cluster_config.build_ray_cluster_spec("test-cluster") + + # Verify basic structure + assert "rayVersion" in spec + assert "enableInTreeAutoscaling" in spec + assert "headGroupSpec" in spec + assert "workerGroupSpecs" in spec + + # Verify head group spec + head_spec = spec["headGroupSpec"] + assert head_spec["serviceType"] == "ClusterIP" + assert head_spec["enableIngress"] is False + assert "rayStartParams" in head_spec + assert "template" in head_spec + + # Verify worker group spec + worker_specs = spec["workerGroupSpecs"] + assert len(worker_specs) == 1 + worker_spec = worker_specs[0] + assert worker_spec["replicas"] == 2 + assert worker_spec["minReplicas"] == 2 + assert worker_spec["maxReplicas"] == 2 + assert worker_spec["groupName"] == "worker-group-test-cluster" + + +def test_build_ray_cluster_spec_with_accelerators(mocker): + """Test build_ray_cluster_spec with GPU accelerators.""" + from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig + + # Create a test cluster config with GPU accelerators + cluster_config = RayJobClusterConfig( + head_accelerators={"nvidia.com/gpu": 1}, + worker_accelerators={"nvidia.com/gpu": 2}, + ) + + # Build the spec using the method on the cluster config + spec = cluster_config.build_ray_cluster_spec("test-cluster") + + # Verify head group has GPU parameters + head_spec = spec["headGroupSpec"] + head_params = head_spec["rayStartParams"] + assert "num-gpus" in head_params + assert head_params["num-gpus"] == "1" + + # Verify worker group has GPU parameters + worker_specs = spec["workerGroupSpecs"] + worker_spec = worker_specs[0] + worker_params = worker_spec["rayStartParams"] + assert "num-gpus" in worker_params + assert worker_params["num-gpus"] == "2" + + +def test_build_ray_cluster_spec_with_custom_volumes(mocker): + """Test build_ray_cluster_spec with custom volumes and volume mounts.""" + from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig + from kubernetes.client import V1Volume, V1VolumeMount + + # Create custom volumes and volume mounts + custom_volume = V1Volume(name="custom-data", empty_dir={}) + custom_volume_mount = V1VolumeMount(name="custom-data", mount_path="/data") + + # Create a test cluster config with custom volumes + cluster_config = RayJobClusterConfig( + volumes=[custom_volume], + volume_mounts=[custom_volume_mount], + ) + + # Build the spec using the method on the cluster config + spec = cluster_config.build_ray_cluster_spec("test-cluster") + + # Verify custom volumes are included + head_spec = spec["headGroupSpec"] + head_pod_spec = head_spec["template"].spec # Access the spec attribute + # Note: We can't easily check DEFAULT_VOLUMES length since they're now part of the class + assert len(head_pod_spec.volumes) > 0 + + # Verify custom volume mounts are included + head_container = head_pod_spec.containers[0] # Access the containers attribute + # Note: We can't easily check DEFAULT_VOLUME_MOUNTS length since they're now part of the class + assert len(head_container.volume_mounts) > 0 + + +def test_build_ray_cluster_spec_with_environment_variables(mocker): + """Test build_ray_cluster_spec with environment variables.""" + from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig + + # Create a test cluster config with environment variables + cluster_config = RayJobClusterConfig( + envs={"CUDA_VISIBLE_DEVICES": "0", "RAY_DISABLE_IMPORT_WARNING": "1"}, + ) + + spec = 
cluster_config.build_ray_cluster_spec("test-cluster") + + # Verify environment variables are included in head container + head_spec = spec["headGroupSpec"] + head_pod_spec = head_spec["template"].spec + head_container = head_pod_spec.containers[0] + assert hasattr(head_container, "env") + env_vars = {env.name: env.value for env in head_container.env} + assert env_vars["CUDA_VISIBLE_DEVICES"] == "0" + assert env_vars["RAY_DISABLE_IMPORT_WARNING"] == "1" + + # Verify environment variables are included in worker container + worker_specs = spec["workerGroupSpecs"] + worker_spec = worker_specs[0] + worker_pod_spec = worker_spec["template"].spec + worker_container = worker_pod_spec.containers[0] + + assert hasattr(worker_container, "env") + worker_env_vars = {env.name: env.value for env in worker_container.env} + assert worker_env_vars["CUDA_VISIBLE_DEVICES"] == "0" + assert worker_env_vars["RAY_DISABLE_IMPORT_WARNING"] == "1" + + +def test_build_ray_cluster_spec_with_tolerations(mocker): + """Test build_ray_cluster_spec with tolerations.""" + from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig + from kubernetes.client import V1Toleration + + # Create test tolerations + head_toleration = V1Toleration( + key="node-role.kubernetes.io/master", operator="Exists", effect="NoSchedule" + ) + worker_toleration = V1Toleration( + key="nvidia.com/gpu", operator="Exists", effect="NoSchedule" + ) + + # Create a test cluster config with tolerations + cluster_config = RayJobClusterConfig( + head_tolerations=[head_toleration], + worker_tolerations=[worker_toleration], + ) + + spec = cluster_config.build_ray_cluster_spec("test-cluster") + + # Verify head tolerations + head_spec = spec["headGroupSpec"] + head_pod_spec = head_spec["template"].spec # Access the spec attribute + assert hasattr(head_pod_spec, "tolerations") + assert len(head_pod_spec.tolerations) == 1 + assert head_pod_spec.tolerations[0].key == "node-role.kubernetes.io/master" + + # Verify worker tolerations + worker_specs = spec["workerGroupSpecs"] + worker_spec = worker_specs[0] + worker_pod_spec = worker_spec["template"].spec # Access the spec attribute + assert hasattr(worker_pod_spec, "tolerations") + assert len(worker_pod_spec.tolerations) == 1 + assert worker_pod_spec.tolerations[0].key == "nvidia.com/gpu" + + +def test_build_ray_cluster_spec_with_image_pull_secrets(mocker): + """Test build_ray_cluster_spec with image pull secrets.""" + from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig + + # Create a test cluster config with image pull secrets + cluster_config = RayJobClusterConfig( + image_pull_secrets=["my-registry-secret", "another-secret"] + ) + + spec = cluster_config.build_ray_cluster_spec("test-cluster") + + # Verify image pull secrets are included in head pod + head_spec = spec["headGroupSpec"] + head_pod_spec = head_spec["template"].spec # Access the spec attribute + assert hasattr(head_pod_spec, "image_pull_secrets") + + head_secrets = head_pod_spec.image_pull_secrets + assert len(head_secrets) == 2 + assert head_secrets[0].name == "my-registry-secret" + assert head_secrets[1].name == "another-secret" + + # Verify image pull secrets are included in worker pod + worker_specs = spec["workerGroupSpecs"] + worker_spec = worker_specs[0] + worker_pod_spec = worker_spec["template"].spec + assert hasattr(worker_pod_spec, "image_pull_secrets") + + worker_secrets = worker_pod_spec.image_pull_secrets + assert len(worker_secrets) == 2 + assert worker_secrets[0].name == "my-registry-secret" + assert 
worker_secrets[1].name == "another-secret" + + +def test_rayjob_user_override_shutdown_behavior(mocker): + """Test that user can override the auto-detected shutdown behavior.""" + mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi") + + # Test 1: User overrides shutdown to True even when using existing cluster + rayjob_existing_override = RayJob( + job_name="test-job", + entrypoint="python script.py", + cluster_name="existing-cluster", + shutdown_after_job_finishes=True, # User override + namespace="test-namespace", # Explicitly specify namespace + ) + + assert rayjob_existing_override.shutdown_after_job_finishes is True + + # Test 2: User overrides shutdown to False even when creating new cluster + from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig + + cluster_config = RayJobClusterConfig() + + rayjob_new_override = RayJob( + job_name="test-job", + entrypoint="python script.py", + cluster_config=cluster_config, + shutdown_after_job_finishes=False, # User override + namespace="test-namespace", # Explicitly specify namespace + ) + + assert rayjob_new_override.shutdown_after_job_finishes is False + + # Test 3: User override takes precedence over auto-detection + rayjob_override_priority = RayJob( + job_name="test-job", + entrypoint="python script.py", + cluster_config=cluster_config, + shutdown_after_job_finishes=True, # Should override auto-detection + namespace="test-namespace", # Explicitly specify namespace + ) + + assert rayjob_override_priority.shutdown_after_job_finishes is True
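
Reviewer note (illustrative, not part of the diff): below is a minimal usage sketch of the API introduced in this patch, covering both modes: submitting against an existing RayCluster, and letting the RayJob create its own cluster from a RayJobClusterConfig. Import paths and parameter names follow the modules added above; the cluster, namespace, and script names are placeholders.

# Illustrative sketch only; assumes the modules added in this patch. Names below are placeholders.
from codeflare_sdk.ray.rayjobs.rayjob import RayJob
from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig

# Mode 1: run against an existing RayCluster. shutdown_after_job_finishes
# auto-detects to False, so the cluster is left running after the job.
job = RayJob(
    job_name="demo-job",
    entrypoint="python train.py",               # placeholder script
    cluster_name="my-existing-cluster",         # placeholder cluster
    namespace="my-namespace",                   # omit to auto-detect via get_current_namespace()
    runtime_env={"pip": ["numpy", "pandas"]},
    ttl_seconds_after_finished=300,
)
job.submit()

# Mode 2: let the RayJob create (and later tear down) its own cluster. The cluster
# is named "<job_name>-cluster" and shutdown_after_job_finishes auto-detects to True.
config = RayJobClusterConfig(
    num_workers=2,
    head_cpu_requests="500m",
    head_memory_requests="512Mi",
    worker_accelerators={"nvidia.com/gpu": 1},
)
auto_job = RayJob(
    job_name="demo-auto-job",
    entrypoint="python train.py",
    cluster_config=config,
    namespace="my-namespace",
)
auto_job.submit()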
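
For reference, a rough outline of the RayJob custom resource that _build_rayjob_cr produces in each mode, assembled from the assertions in the tests above. The apiVersion/kind/metadata envelope is assumed to follow the standard RayJob CRD; the values reuse the placeholders from the sketch.

# Existing-cluster mode: the job references the cluster via clusterSelector.
rayjob_cr_existing = {
    "apiVersion": "ray.io/v1",   # assumed standard RayJob CRD group/version
    "kind": "RayJob",
    "metadata": {"name": "demo-job", "namespace": "my-namespace"},
    "spec": {
        "entrypoint": "python train.py",
        "runtimeEnvYAML": str({"pip": ["numpy", "pandas"]}),  # str() of runtime_env, per the tests
        "shutdownAfterJobFinishes": False,   # auto-set when cluster_name is used
        "ttlSecondsAfterFinished": 300,
        "clusterSelector": {"ray.io/cluster": "my-existing-cluster"},
        # "activeDeadlineSeconds" is added when set on the RayJob
    },
}

# New-cluster mode: clusterSelector is replaced by an embedded RayCluster spec and
# shutdownAfterJobFinishes defaults to True:
#     spec["rayClusterSpec"] = config.build_ray_cluster_spec(cluster_name="demo-auto-job-cluster")
#     spec["shutdownAfterJobFinishes"] = True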