From 4d201068448d76d2787ab3ac61327f2cf01b2142 Mon Sep 17 00:00:00 2001
From: Joshua Towner
Date: Thu, 14 May 2026 13:12:38 -0700
Subject: [PATCH] feat: replace eval Rich spinner with plain print observability (matching training format)

---
 .../src/sagemaker/train/evaluate/execution.py | 170 +++++++++---------
 .../evaluate/test_execution_observability.py  | 107 +++++++++++
 2 files changed, 196 insertions(+), 81 deletions(-)
 create mode 100644 sagemaker-train/tests/unit/train/evaluate/test_execution_observability.py

diff --git a/sagemaker-train/src/sagemaker/train/evaluate/execution.py b/sagemaker-train/src/sagemaker/train/evaluate/execution.py
index 6dd7043406..e37cbddc84 100644
--- a/sagemaker-train/src/sagemaker/train/evaluate/execution.py
+++ b/sagemaker-train/src/sagemaker/train/evaluate/execution.py
@@ -1161,87 +1161,95 @@ def get_cached_mlflow_url():
                 time.sleep(poll)
         else:
-            # Terminal experience with rich library
-            try:
-                from rich.live import Live
-                from rich.panel import Panel
-                from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
-                from rich.console import Group
-                from rich.status import Status
-                from rich.style import Style
-
-                progress = Progress(
-                    SpinnerColumn("bouncingBar"),
-                    TextColumn("{task.description}"),
-                    TimeElapsedColumn(),
-                )
-                progress.add_task(f"Waiting for PipelineExecution to reach [bold]{target_status}[/bold] status...")
-                status = Status("Current status:")
-
-                with Live(
-                    Panel(
-                        Group(progress, status),
-                        title="Wait Log Panel",
-                        border_style=Style(color="blue"),
-                    ),
-                    transient=True,
-                ):
-                    while True:
-                        self.refresh()
-                        current_status = self.status.overall_status
-                        status.update(f"Current status: [bold]{current_status}[/bold]")
-
-                        if target_status == current_status:
-                            logger.info(f"Final Resource Status: [bold]{current_status}[/bold]")
-                            return
-
-                        if "failed" in current_status.lower():
-                            from sagemaker.core.utils.exceptions import FailedStatusError
-                            raise FailedStatusError(
-                                resource_type="PipelineExecution",
-                                status=current_status,
-                                reason=self.status.failure_reason,
-                            )
-
-                        if timeout is not None and time.time() - start_time >= timeout:
-                            from sagemaker.core.utils.exceptions import TimeoutExceededError
-                            raise TimeoutExceededError(
-                                resource_type="EvaluationJob",
-                                status=current_status,
-                                message="Your evaluation job is still running. Call .refresh() to check its current status.",
-                            )
-
-                        time.sleep(poll)
-            except ImportError:
-                # Fallback to simple print-based progress if rich is not available
-                logger.info(f"Waiting for PipelineExecution to reach {target_status} status...")
-                while True:
-                    self.refresh()
-                    current_status = self.status.overall_status
-                    elapsed = time.time() - start_time
-                    print(f"Current status: {current_status} (Elapsed: {elapsed:.1f}s)")
-
-                    if target_status == current_status:
-                        logger.info(f"Final Resource Status: {current_status}")
-                        return
-
-                    if "failed" in current_status.lower():
-                        from sagemaker.core.utils.exceptions import FailedStatusError
-                        raise FailedStatusError(
-                            resource_type="PipelineExecution",
-                            status=current_status,
-                            reason=self.status.failure_reason,
-                        )
-
-                    if timeout is not None and elapsed >= timeout:
-                        from sagemaker.core.utils.exceptions import TimeoutExceededError
-                        raise TimeoutExceededError(
-                            resource_type="EvaluationJob",
-                            status=current_status,
-                            message="Your evaluation job is still running. Call .refresh() to check its current status.",
-                        )
-
-                    time.sleep(poll)
+            # Terminal experience - plain print matching training format
+            pipeline_name = None
+            if self.arn:
+                arn_parts = self.arn.split('/')
+                if len(arn_parts) >= 3:
+                    pipeline_name = arn_parts[-3]
+            print(f"\nEvaluation started: {self.name}", flush=True)
+            if pipeline_name:
+                print(f"Pipeline: {pipeline_name}", flush=True)
+            print(f"Execution ARN: {self.arn}", flush=True)
+
+            # Imported once here rather than on every poll iteration
+            from datetime import datetime, timezone
+
+            while True:
+                self.refresh()
+                current_status = self.status.overall_status
+                elapsed = time.time() - start_time
+
+                # Print step transitions (matching training format)
+                if self.status.step_details:
+                    print("\n--------------------------------------\n", flush=True)
+                    print("Status Transitions:", flush=True)
+                    for step in self.status.step_details:
+                        duration = ""
+                        check = ""
+                        if step.start_time and step.end_time:
+                            try:
+                                start_dt = datetime.fromisoformat(step.start_time.replace('Z', '+00:00'))
+                                end_dt = datetime.fromisoformat(step.end_time.replace('Z', '+00:00'))
+                                duration = f"({(end_dt - start_dt).total_seconds():.1f}s)"
+                            except Exception:
+                                pass
+                            check = "✓"
+                        elif step.start_time:
+                            try:
+                                start_dt = datetime.fromisoformat(step.start_time.replace('Z', '+00:00'))
+                                running_secs = (datetime.now(timezone.utc) - start_dt).total_seconds()
+                                duration = f"(Running... {running_secs:.0f}s)"
+                            except Exception:
+                                duration = "(Running...)"
+                            check = "⋯"
+
+                        step_msg = f"  {check} {step.display_name or step.name}: {step.status} {duration}"
+                        print(step_msg, flush=True)
+                        if step.job_arn and ('executing' in step.status.lower() or 'failed' in step.status.lower()):
+                            print(f"    Job ARN: {step.job_arn}", flush=True)
+
+                print(f"\nStatus: {current_status} (Elapsed: {elapsed:.1f}s)", flush=True)
+
+                if target_status == current_status:
+                    if self.s3_output_path:
+                        print(f"\nResults S3: {self.s3_output_path}", flush=True)
+                    return
+
+                if "failed" in current_status.lower():
+                    if self.status.failure_reason:
+                        print(f"\nFailure reason: {self.status.failure_reason}", flush=True)
+                    if self.status.step_details:
+                        for step in self.status.step_details:
+                            if 'failed' in step.status.lower():
+                                print(f"\nFailed step: {step.display_name or step.name}", flush=True)
+                                if step.failure_reason:
+                                    print(f"Failure reason: {step.failure_reason}", flush=True)
+                                if step.job_arn:
+                                    print(f"Job ARN: {step.job_arn}", flush=True)
+                                    job_name = step.job_arn.split('/')[-1] if '/' in step.job_arn else ''
+                                    if job_name:
+                                        print("Log group: /aws/sagemaker/TrainingJobs", flush=True)
+                                        print(f"Log stream prefix: {job_name}", flush=True)
+                                    from sagemaker.train.common_utils.metrics_visualizer import get_cloudwatch_logs_url
+                                    cw_url = get_cloudwatch_logs_url(step.job_arn)
+                                    if cw_url:
+                                        print(f"CloudWatch Logs: {cw_url}", flush=True)
+                    from sagemaker.core.utils.exceptions import FailedStatusError
+                    raise FailedStatusError(
+                        resource_type="PipelineExecution",
+                        status=current_status,
+                        reason=self.status.failure_reason,
+                    )
+
+                if timeout is not None and elapsed >= timeout:
+                    from sagemaker.core.utils.exceptions import TimeoutExceededError
+                    raise TimeoutExceededError(
+                        resource_type="EvaluationJob",
+                        status=current_status,
+                        message="Your evaluation job is still running. Call .refresh() to check its current status.",
+                    )
+
+                time.sleep(poll)
 
     def _enrich_with_step_details(
         self,
diff --git a/sagemaker-train/tests/unit/train/evaluate/test_execution_observability.py b/sagemaker-train/tests/unit/train/evaluate/test_execution_observability.py
new file mode 100644
index 0000000000..09a842177f
--- /dev/null
+++ b/sagemaker-train/tests/unit/train/evaluate/test_execution_observability.py
@@ -0,0 +1,107 @@
+"""Tests for eval pipeline observability prints in terminal mode."""
+from unittest.mock import patch, MagicMock
+
+import pytest
+
+from sagemaker.core.utils.exceptions import FailedStatusError
+from sagemaker.train.evaluate.execution import (
+    EvaluationPipelineExecution,
+    PipelineExecutionStatus,
+    StepDetail,
+)
+
+
+def _make_execution(status="Succeeded", step_details=None, failure_reason=None, s3_output_path=None):
+    exec_obj = EvaluationPipelineExecution(
+        name="benchmark-eval-mmlu",
+        arn="arn:aws:sagemaker:us-west-2:123456789:pipeline/sm-eval-benchmark-abc/execution/exec-123",
+        status=PipelineExecutionStatus(
+            overall_status=status,
+            step_details=step_details or [],
+            failure_reason=failure_reason,
+        ),
+        s3_output_path=s3_output_path,
+    )
+    exec_obj._pipeline_execution = MagicMock()
+    return exec_obj
+
+
+class TestEvalObservabilityAtStart:
+    @patch("sagemaker.train.evaluate.execution.time.sleep")
+    @patch.object(EvaluationPipelineExecution, "refresh")
+    def test_prints_pipeline_info_at_start(self, mock_refresh, mock_sleep, capsys):
+        exec_obj = _make_execution(status="Succeeded")
+        exec_obj.wait(poll=0, timeout=1)
+        captured = capsys.readouterr()
+        assert "Evaluation started: benchmark-eval-mmlu" in captured.out
+        assert "Pipeline: sm-eval-benchmark-abc" in captured.out
+        assert "Execution ARN:" in captured.out
+
+
+class TestEvalObservabilityStepTransitions:
+    @patch("sagemaker.train.evaluate.execution.time.sleep")
+    @patch.object(EvaluationPipelineExecution, "refresh")
+    def test_prints_step_transitions(self, mock_refresh, mock_sleep, capsys):
+        steps = [
+            StepDetail(name="EvaluateBaseModel", status="Succeeded", display_name="EvaluateBaseModel",
+                       start_time="2026-01-01T00:00:00Z", end_time="2026-01-01T00:01:00Z",
+                       job_arn="arn:aws:sagemaker:us-west-2:123456789:training-job/eval-base-xyz"),
+        ]
+        exec_obj = _make_execution(status="Succeeded", step_details=steps)
+        exec_obj.wait(poll=0, timeout=1)
+        captured = capsys.readouterr()
+        assert "✓ EvaluateBaseModel: Succeeded" in captured.out
+        assert "(60.0s)" in captured.out
+
+    @patch("sagemaker.train.evaluate.execution.time.sleep")
+    @patch.object(EvaluationPipelineExecution, "refresh")
+    def test_prints_job_arn_for_executing_step(self, mock_refresh, mock_sleep, capsys):
+        steps = [
+            StepDetail(name="EvaluateCustomModel", status="Executing", display_name="EvaluateCustomModel",
+                       start_time="2026-01-01T00:00:00Z",
+                       job_arn="arn:aws:sagemaker:us-west-2:123456789:training-job/eval-custom-xyz"),
+        ]
+        exec_obj = _make_execution(status="Executing", step_details=steps)
+
+        # First poll shows Executing, then Succeeded
+        call_count = [0]
+
+        def side_effect():
+            call_count[0] += 1
+            if call_count[0] > 1:
+                exec_obj.status.overall_status = "Succeeded"
+                exec_obj.status.step_details[0].status = "Succeeded"
+                exec_obj.status.step_details[0].end_time = "2026-01-01T00:01:00Z"
+
+        mock_refresh.side_effect = side_effect
+        exec_obj.wait(poll=0, timeout=5)
+        captured = capsys.readouterr()
+        assert "Job ARN: arn:aws:sagemaker:us-west-2:123456789:training-job/eval-custom-xyz" in captured.out
+
+
+class TestEvalObservabilityOnSuccess:
@patch("sagemaker.train.evaluate.execution.time.sleep") + @patch.object(EvaluationPipelineExecution, "refresh") + def test_prints_s3_output_on_success(self, mock_refresh, mock_sleep, capsys): + exec_obj = _make_execution(status="Succeeded", s3_output_path="s3://bucket/eval-results/") + exec_obj.wait(poll=0, timeout=1) + captured = capsys.readouterr() + assert "Results S3: s3://bucket/eval-results/" in captured.out + + +class TestEvalObservabilityOnFailure: + @patch("sagemaker.train.evaluate.execution.time.sleep") + @patch.object(EvaluationPipelineExecution, "refresh") + def test_prints_failed_step_info(self, mock_refresh, mock_sleep, capsys): + steps = [ + StepDetail(name="EvaluateCustomModel", status="Failed", + display_name="EvaluateCustomModel", + failure_reason="ResourceLimitExceeded", + job_arn="arn:aws:sagemaker:us-west-2:123456789:training-job/eval-custom-xyz"), + ] + exec_obj = _make_execution(status="Failed", step_details=steps, failure_reason="Step failed") + with pytest.raises(Exception): + exec_obj.wait(poll=0, timeout=1) + captured = capsys.readouterr() + assert "Failed step: EvaluateCustomModel" in captured.out + assert "ResourceLimitExceeded" in captured.out + assert "Job ARN: arn:aws:sagemaker:us-west-2:123456789:training-job/eval-custom-xyz" in captured.out + assert "Log group: /aws/sagemaker/TrainingJobs" in captured.out + assert "Log stream prefix: eval-custom-xyz" in captured.out + assert "CloudWatch Logs:" in captured.out