170 changes: 89 additions & 81 deletions sagemaker-train/src/sagemaker/train/evaluate/execution.py
@@ -1161,87 +1161,95 @@ def get_cached_mlflow_url():

time.sleep(poll)
else:
# Terminal experience with rich library
try:
from rich.live import Live
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from rich.console import Group
from rich.status import Status
from rich.style import Style

progress = Progress(
SpinnerColumn("bouncingBar"),
TextColumn("{task.description}"),
TimeElapsedColumn(),
)
progress.add_task(f"Waiting for PipelineExecution to reach [bold]{target_status}[/bold] status...")
status = Status("Current status:")

with Live(
Panel(
Group(progress, status),
title="Wait Log Panel",
border_style=Style(color="blue"),
),
transient=True,
):
while True:
self.refresh()
current_status = self.status.overall_status
status.update(f"Current status: [bold]{current_status}[/bold]")

if target_status == current_status:
logger.info(f"Final Resource Status: [bold]{current_status}[/bold]")
return

if "failed" in current_status.lower():
from sagemaker.core.utils.exceptions import FailedStatusError
raise FailedStatusError(
resource_type="PipelineExecution",
status=current_status,
reason=self.status.failure_reason,
)

if timeout is not None and time.time() - start_time >= timeout:
from sagemaker.core.utils.exceptions import TimeoutExceededError
raise TimeoutExceededError(
resource_type="EvaluationJob",
status=current_status,
message="Your evaluation job is still running. Call .refresh() to check its current status.",
)

time.sleep(poll)
except ImportError:
# Fallback to simple print-based progress if rich is not available
logger.info(f"Waiting for PipelineExecution to reach {target_status} status...")
while True:
self.refresh()
current_status = self.status.overall_status
elapsed = time.time() - start_time
print(f"Current status: {current_status} (Elapsed: {elapsed:.1f}s)")

if target_status == current_status:
logger.info(f"Final Resource Status: {current_status}")
return

if "failed" in current_status.lower():
from sagemaker.core.utils.exceptions import FailedStatusError
raise FailedStatusError(
resource_type="PipelineExecution",
status=current_status,
reason=self.status.failure_reason,
)

if timeout is not None and elapsed >= timeout:
from sagemaker.core.utils.exceptions import TimeoutExceededError
raise TimeoutExceededError(
resource_type="EvaluationJob",
status=current_status,
message="Your evaluation job is still running. Call .refresh() to check its current status.",
)

time.sleep(poll)
# Terminal experience - plain print matching training format
pipeline_name = None
if self.arn:
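# Execution ARN format: arn:aws:sagemaker:<region>:<account>:pipeline/<pipeline-name>/execution/<execution-id>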
arn_parts = self.arn.split('/')
if len(arn_parts) >= 3:
pipeline_name = arn_parts[-3]
print(f"\nEvaluation started: {self.name}")
if pipeline_name:
print(f"Pipeline: {pipeline_name}")
print(f"Execution ARN: {self.arn}")

while True:
self.refresh()
current_status = self.status.overall_status
elapsed = time.time() - start_time

# Print step transitions (matching training format)
if self.status.step_details:
print("\n--------------------------------------\n")
print("Status Transitions:")
for step in self.status.step_details:
duration = ""
check = ""
if step.start_time and step.end_time:
try:
from datetime import datetime
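# fromisoformat() rejects a trailing 'Z' before Python 3.11, so normalize it to an explicit UTC offset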
start_dt = datetime.fromisoformat(step.start_time.replace('Z', '+00:00'))
end_dt = datetime.fromisoformat(step.end_time.replace('Z', '+00:00'))
duration = f"({(end_dt - start_dt).total_seconds():.1f}s)"
except Exception:
pass
check = "✓"
elif step.start_time:
try:
from datetime import datetime, timezone
start_dt = datetime.fromisoformat(step.start_time.replace('Z', '+00:00'))
running_secs = (datetime.now(timezone.utc) - start_dt).total_seconds()
duration = f"(Running... {running_secs:.0f}s)"
except Exception:
duration = "(Running...)"
check = "⋯"

step_msg = f" {check} {step.display_name or step.name}: {step.status} {duration}"
print(step_msg)
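# Surface the backing job ARN only for steps that are in flight or have failed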
if step.job_arn and ('executing' in step.status.lower() or 'failed' in step.status.lower()):
print(f" Job ARN: {step.job_arn}")

print(f"\nStatus: {current_status} (Elapsed: {elapsed:.1f}s)")

if target_status == current_status:
if self.s3_output_path:
print(f"\nResults S3: {self.s3_output_path}")
return

if "failed" in current_status.lower():
if self.status.failure_reason:
print(f"\nFailure reason: {self.status.failure_reason}")
if self.status.step_details:
for step in self.status.step_details:
if 'failed' in step.status.lower():
print(f"\nFailed step: {step.display_name or step.name}")
if step.failure_reason:
print(f"Failure reason: {step.failure_reason}")
if step.job_arn:
print(f"Job ARN: {step.job_arn}")
job_name = step.job_arn.split('/')[-1] if '/' in step.job_arn else ''
if job_name:
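# Eval steps run as SageMaker training jobs, so their logs land in /aws/sagemaker/TrainingJobs with the job name as the stream prefix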
print(f"Log group: /aws/sagemaker/TrainingJobs")
print(f"Log stream prefix: {job_name}")
from sagemaker.train.common_utils.metrics_visualizer import get_cloudwatch_logs_url
cw_url = get_cloudwatch_logs_url(step.job_arn)
if cw_url:
print(f"CloudWatch Logs: {cw_url}")
from sagemaker.core.utils.exceptions import FailedStatusError
raise FailedStatusError(
resource_type="PipelineExecution",
status=current_status,
reason=self.status.failure_reason,
)

if timeout is not None and time.time() - start_time >= timeout:
from sagemaker.core.utils.exceptions import TimeoutExceededError
raise TimeoutExceededError(
resource_type="EvaluationJob",
status=current_status,
message="Your evaluation job is still running. Call .refresh() to check its current status.",
)

time.sleep(poll)

def _enrich_with_step_details(
self,
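For review context, a minimal sketch of the terminal output the new loop produces. The constructor arguments are copied from the _make_execution helper in the tests below; hand-constructing the object and the printed timings are illustrative only, since in practice the execution comes back from the evaluate API:

from sagemaker.train.evaluate.execution import (
    EvaluationPipelineExecution,
    PipelineExecutionStatus,
    StepDetail,
)

execution = EvaluationPipelineExecution(
    name="benchmark-eval-mmlu",
    arn="arn:aws:sagemaker:us-west-2:123456789:pipeline/sm-eval-benchmark-abc/execution/exec-123",
    status=PipelineExecutionStatus(
        overall_status="Executing",
        step_details=[
            StepDetail(
                name="EvaluateBaseModel",
                status="Succeeded",
                start_time="2026-01-01T00:00:00Z",
                end_time="2026-01-01T00:01:00Z",
            ),
        ],
    ),
)
execution.wait(poll=30, timeout=3600)
# Evaluation started: benchmark-eval-mmlu
# Pipeline: sm-eval-benchmark-abc
# Execution ARN: arn:aws:sagemaker:us-west-2:123456789:pipeline/sm-eval-benchmark-abc/execution/exec-123
#
# --------------------------------------
#
# Status Transitions:
#   ✓ EvaluateBaseModel: Succeeded (60.0s)
#
# Status: Executing (Elapsed: 30.1s)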
@@ -0,0 +1,107 @@
"""Tests for eval pipeline observability prints in terminal mode."""
from unittest.mock import patch, MagicMock

import pytest

from sagemaker.train.evaluate.execution import (
EvaluationPipelineExecution,
PipelineExecutionStatus,
StepDetail,
)


def _make_execution(status="Succeeded", step_details=None, failure_reason=None, s3_output_path=None):
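"""Build an execution with a canned status; the backend handle is mocked so refresh() can be patched in tests."""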
exec_obj = EvaluationPipelineExecution(
name="benchmark-eval-mmlu",
arn="arn:aws:sagemaker:us-west-2:123456789:pipeline/sm-eval-benchmark-abc/execution/exec-123",
status=PipelineExecutionStatus(
overall_status=status,
step_details=step_details or [],
failure_reason=failure_reason,
),
s3_output_path=s3_output_path,
)
exec_obj._pipeline_execution = MagicMock()
return exec_obj


class TestEvalObservabilityAtStart:
@patch("sagemaker.train.evaluate.execution.time.sleep")
@patch.object(EvaluationPipelineExecution, "refresh")
def test_prints_pipeline_info_at_start(self, mock_refresh, mock_sleep, capsys):
exec_obj = _make_execution(status="Succeeded")
exec_obj.wait(poll=0, timeout=1)
captured = capsys.readouterr()
assert "Evaluation started: benchmark-eval-mmlu" in captured.out
assert "Pipeline: sm-eval-benchmark-abc" in captured.out
assert "Execution ARN:" in captured.out


class TestEvalObservabilityStepTransitions:
@patch("sagemaker.train.evaluate.execution.time.sleep")
@patch.object(EvaluationPipelineExecution, "refresh")
def test_prints_step_transitions(self, mock_refresh, mock_sleep, capsys):
steps = [
StepDetail(name="EvaluateBaseModel", status="Succeeded", display_name="EvaluateBaseModel",
start_time="2026-01-01T00:00:00Z", end_time="2026-01-01T00:01:00Z",
job_arn="arn:aws:sagemaker:us-west-2:123456789:training-job/eval-base-xyz"),
]
exec_obj = _make_execution(status="Succeeded", step_details=steps)
exec_obj.wait(poll=0, timeout=1)
captured = capsys.readouterr()
assert "✓ EvaluateBaseModel: Succeeded" in captured.out
assert "(60.0s)" in captured.out

@patch("sagemaker.train.evaluate.execution.time.sleep")
@patch.object(EvaluationPipelineExecution, "refresh")
def test_prints_job_arn_for_executing_step(self, mock_refresh, mock_sleep, capsys):
steps = [
StepDetail(name="EvaluateCustomModel", status="Executing", display_name="EvaluateCustomModel",
start_time="2026-01-01T00:00:00Z",
job_arn="arn:aws:sagemaker:us-west-2:123456789:training-job/eval-custom-xyz"),
]
exec_obj = _make_execution(status="Executing", step_details=steps)
# First refresh still reports Executing; later refreshes flip the step and the run to Succeeded
call_count = [0]
def side_effect():
    call_count[0] += 1
    if call_count[0] > 1:
        exec_obj.status.overall_status = "Succeeded"
        exec_obj.status.step_details[0].status = "Succeeded"
        exec_obj.status.step_details[0].end_time = "2026-01-01T00:01:00Z"
mock_refresh.side_effect = side_effect
exec_obj.wait(poll=0, timeout=5)
captured = capsys.readouterr()
assert "Job ARN: arn:aws:sagemaker:us-west-2:123456789:training-job/eval-custom-xyz" in captured.out


class TestEvalObservabilityOnSuccess:
@patch("sagemaker.train.evaluate.execution.time.sleep")
@patch.object(EvaluationPipelineExecution, "refresh")
def test_prints_s3_output_on_success(self, mock_refresh, mock_sleep, capsys):
exec_obj = _make_execution(status="Succeeded", s3_output_path="s3://bucket/eval-results/")
exec_obj.wait(poll=0, timeout=1)
captured = capsys.readouterr()
assert "Results S3: s3://bucket/eval-results/" in captured.out


class TestEvalObservabilityOnFailure:
@patch("sagemaker.train.evaluate.execution.time.sleep")
@patch.object(EvaluationPipelineExecution, "refresh")
def test_prints_failed_step_info(self, mock_refresh, mock_sleep, capsys):
steps = [
StepDetail(name="EvaluateCustomModel", status="Failed",
display_name="EvaluateCustomModel",
failure_reason="ResourceLimitExceeded",
job_arn="arn:aws:sagemaker:us-west-2:123456789:training-job/eval-custom-xyz"),
]
exec_obj = _make_execution(status="Failed", step_details=steps, failure_reason="Step failed")
from sagemaker.core.utils.exceptions import FailedStatusError
with pytest.raises(FailedStatusError):
exec_obj.wait(poll=0, timeout=1)
captured = capsys.readouterr()
assert "Failed step: EvaluateCustomModel" in captured.out
assert "ResourceLimitExceeded" in captured.out
assert "Job ARN: arn:aws:sagemaker:us-west-2:123456789:training-job/eval-custom-xyz" in captured.out
assert "Log group: /aws/sagemaker/TrainingJobs" in captured.out
assert "Log stream prefix: eval-custom-xyz" in captured.out
assert "CloudWatch Logs:" in captured.out