From 7be345a459e65328a1cfac5ce7ef99ee9a57af10 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 11 Aug 2025 10:41:29 +0000 Subject: [PATCH 1/4] checkpointing --- eval_protocol/pytest/evaluation_test.py | 29 ++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index fe0c8cf5..30e77348 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -404,7 +404,34 @@ def _log_eval_error( for row in fresh_dataset: active_logger.log(row) - processed_dataset = execute_function(rollout_processor, rows=fresh_dataset, config=config) + # filter out rows that already have completed rollouts via checkpointing + rows_to_process = [] + completed_rollout_ids = set() + + finished_logs = active_logger.read() + + for finished_row in finished_logs: + # need to add finished rows to all_results so that we can aggregate them later. + all_results.append(finished_row) + # TODO: need to also add the num_run to track which run the row belongs to. + # TODO: ask why we made row_id optional in the first place. checkpointing won't work without some ID. + if finished_row.input_metadata and finished_row.input_metadata.row_id: + completed_rollout_ids.add(finished_row.input_metadata.row_id) + + for row in fresh_dataset: + row_id = row.input_metadata.row_id if row.input_metadata else None + if row_id not in completed_rollout_ids: + rows_to_process.append(row) + + if len(rows_to_process) < len(fresh_dataset): + print( + f"Checkpointing: Found {len(fresh_dataset) - len(rows_to_process)} completed rows, processing {len(rows_to_process)} remaining rows" + ) + + if rows_to_process: + processed_dataset = execute_function( + rollout_processor, rows=rows_to_process, config=config + ) if mode == "pointwise": # Pointwise mode: apply the evaluator function to each row From b6a6ec5d293a1e6318457465b45cecef436eb51e Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Mon, 11 Aug 2025 22:49:45 -0700 Subject: [PATCH 2/4] wip --- .../pytest/default_agent_rollout_processor.py | 17 +- .../default_single_turn_rollout_process.py | 7 +- eval_protocol/pytest/evaluation_test.py | 223 +++++++++-------- eval_protocol/pytest/plugin.py | 24 +- tests/pytest/test_tau_bench_airline.py | 4 +- tests/test_rollout_error_handling.py | 229 ++++++++++++++++++ 6 files changed, 374 insertions(+), 130 deletions(-) create mode 100644 tests/test_rollout_error_handling.py diff --git a/eval_protocol/pytest/default_agent_rollout_processor.py b/eval_protocol/pytest/default_agent_rollout_processor.py index bd7c62c2..10085941 100644 --- a/eval_protocol/pytest/default_agent_rollout_processor.py +++ b/eval_protocol/pytest/default_agent_rollout_processor.py @@ -117,10 +117,15 @@ async def default_agent_rollout_processor( ) -> List[EvaluationRow]: dataset: Dataset = [] for row in rows: - agent = Agent(model=config.model, row=row, config_path=config.mcp_config_path, logger=config.logger) - await agent.setup() - await agent.call_agent() - dataset.append(agent.evaluation_row) - if agent.mcp_client: - await agent.mcp_client.cleanup() + try: + agent = Agent(model=config.model, row=row, config_path=config.mcp_config_path, logger=config.logger) + await agent.setup() + await agent.call_agent() + dataset.append(agent.evaluation_row) + if agent.mcp_client: + await agent.mcp_client.cleanup() + except Exception as e: + row.rollout_status.status = "error" + row.rollout_status.error_message = str(e) + dataset.append(row) return dataset diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index 69966b39..962ea26a 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -87,7 +87,12 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: async with semaphore: - return await process_row(r) + try: + return await process_row(r) + except Exception as e: + r.rollout_status.status = "error" + r.rollout_status.error_message = str(e) + return r tasks = [_sem_wrapper(row) for row in rows] dataset = list(await asyncio.gather(*tasks)) diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 9b68e0d2..03df28a0 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -401,133 +401,132 @@ def _log_eval_error( logger=active_logger, ) + max_retry = int(os.getenv("EP_MAX_RETRY", "0")) + for i in range(num_runs): - # Regenerate outputs each run by deep-copying the pristine dataset - # so model responses are not reused across runs. run_id = generate_id() - fresh_dataset = [r.model_copy(deep=True) for r in data] - - # apply new run_id to fresh_dataset - for row in fresh_dataset: - row.run_id = run_id - - # generate new rollout_id for each row - for row in fresh_dataset: - row.rollout_id = generate_id() - - # log the fresh_dataset - for row in fresh_dataset: - active_logger.log(row) - - # filter out rows that already have completed rollouts via checkpointing - rows_to_process = [] - completed_rollout_ids = set() - - finished_logs = active_logger.read() - - for finished_row in finished_logs: - # need to add finished rows to all_results so that we can aggregate them later. - all_results.append(finished_row) - # TODO: need to also add the num_run to track which run the row belongs to. - # TODO: ask why we made row_id optional in the first place. checkpointing won't work without some ID. - if finished_row.input_metadata and finished_row.input_metadata.row_id: - completed_rollout_ids.add(finished_row.input_metadata.row_id) - - for row in fresh_dataset: - row_id = row.input_metadata.row_id if row.input_metadata else None - if row_id not in completed_rollout_ids: - rows_to_process.append(row) - - if len(rows_to_process) < len(fresh_dataset): - print( - f"Checkpointing: Found {len(fresh_dataset) - len(rows_to_process)} completed rows, processing {len(rows_to_process)} remaining rows" - ) - - if rows_to_process: - processed_dataset = execute_function( - rollout_processor, rows=rows_to_process, config=config - ) - - if mode == "pointwise": - # Pointwise mode: apply the evaluator function to each row - for row in processed_dataset: - result = execute_with_params( + retry_attempt = 0 + current_data = data + + while retry_attempt <= max_retry: + if retry_attempt > 0: + logged_rows = active_logger.read() + failed_rows = [ + row + for row in logged_rows + if row.rollout_status + and row.rollout_status.status == "error" + and row.run_id == run_id + ] + if not failed_rows: + break + current_data = failed_rows + + # Regenerate outputs each run by deep-copying the pristine dataset + # so model responses are not reused across runs. + fresh_dataset = [r.model_copy(deep=True) for r in current_data] + + # apply new run_id to fresh_dataset + for row in fresh_dataset: + row.run_id = run_id + + # generate new rollout_id for each row + for row in fresh_dataset: + row.rollout_id = generate_id() + + # log the fresh_dataset + for row in fresh_dataset: + active_logger.log(row) + + processed_dataset = execute_function(rollout_processor, rows=fresh_dataset, config=config) + + if mode == "pointwise": + # Pointwise mode: apply the evaluator function to each row + for row in processed_dataset: + result = execute_with_params( + test_func, + processed_row=row, + evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, + ) + if result is None or not isinstance(result, EvaluationRow): + raise ValueError( + f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test." + ) + # TODO: not this simple, only append ones that are not error + all_results[i].append(result) + else: + # Batch mode: call the test function with the full dataset + results = execute_with_params( test_func, - processed_row=row, + processed_dataset=processed_dataset, evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, ) - if result is None or not isinstance(result, EvaluationRow): + if results is None: raise ValueError( f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test." ) - all_results[i].append(result) - else: - # Batch mode: call the test function with the full dataset - results = execute_with_params( - test_func, - processed_dataset=processed_dataset, - evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, - ) - if results is None: - raise ValueError( - f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test." - ) - if not isinstance(results, list): - raise ValueError( - f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test." - ) - if not results: - raise ValueError( - f"Test function {test_func.__name__} returned an empty list. You must return a non-empty list of EvaluationRow instances from your test function decorated with @evaluation_test." - ) - if not all(isinstance(r, EvaluationRow) for r in results): - raise ValueError( - f"Test function {test_func.__name__} returned a list containing non-EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test." + if not isinstance(results, list): + raise ValueError( + f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test." + ) + if not results: + raise ValueError( + f"Test function {test_func.__name__} returned an empty list. You must return a non-empty list of EvaluationRow instances from your test function decorated with @evaluation_test." + ) + if not all(isinstance(r, EvaluationRow) for r in results): + raise ValueError( + f"Test function {test_func.__name__} returned a list containing non-EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test." + ) + # TODO: not this simple, only append ones that are not error + all_results[i] = results + + retry_attempt += 1 + + scores = [ + sum([r.evaluation_result.score for r in result if r.evaluation_result]) / len(result) + for result in all_results + ] + agg_score = aggregate(scores, aggregation_method) + score_std = statistics.stdev(scores) if len(scores) > 1 else 0.0 + + # Compute 95% confidence interval for the fixed-set mean μ (by-question, using repeats) + ci_low: float | None = None + ci_high: float | None = None + if aggregation_method == "mean": + try: + result_ci = compute_fixed_set_mu_ci( + [item for sublist in all_results for item in sublist] ) - all_results[i] = results - - scores = [ - sum([r.evaluation_result.score for r in result if r.evaluation_result]) / len(result) - for result in all_results - ] - agg_score = aggregate(scores, aggregation_method) - score_std = statistics.stdev(scores) if len(scores) > 1 else 0.0 - - # Compute 95% confidence interval for the fixed-set mean μ (by-question, using repeats) - ci_low: float | None = None - ci_high: float | None = None - if aggregation_method == "mean": - try: - result_ci = compute_fixed_set_mu_ci([item for sublist in all_results for item in sublist]) - mu_ci_low, mu_ci_high = result_ci[1], result_ci[2] - if mu_ci_low is not None and mu_ci_high is not None: - ci_low = float(mu_ci_low) - ci_high = float(mu_ci_high) - # Keep agg_score as-is (mean over scores). For equal repeats per question these match. - except Exception: - ci_low = None - ci_high = None - - # Determine if the evaluation passed based on threshold - passed = None - - if threshold is not None: - success_passed, std_passed = True, True - - success_passed = agg_score >= threshold.success + mu_ci_low, mu_ci_high = result_ci[1], result_ci[2] + if mu_ci_low is not None and mu_ci_high is not None: + ci_low = float(mu_ci_low) + ci_high = float(mu_ci_high) + # Keep agg_score as-is (mean over scores). For equal repeats per question these match. + except Exception: + ci_low = None + ci_high = None - if threshold.standard_deviation is not None: - std_passed = score_std <= threshold.standard_deviation + # Determine if the evaluation passed based on threshold + passed = None + + if threshold is not None: + success_passed, std_passed = True, True + + success_passed = agg_score >= threshold.success + + if threshold.standard_deviation is not None: + std_passed = score_std <= threshold.standard_deviation - passed = success_passed and std_passed + passed = success_passed and std_passed # Update eval metadata status and passed field for all results for result in all_results: for r in result: - if r.eval_metadata is not None: - r.eval_metadata.status = "finished" - r.eval_metadata.passed = passed - default_logger.log(r) + if r.rollout_status is not None: + if r.rollout_status.status != "error": + r.rollout_status.status = "finished" + r.rollout_status.passed = passed + active_logger.log(r) # Optional: print and/or persist a summary artifact for CI try: diff --git a/eval_protocol/pytest/plugin.py b/eval_protocol/pytest/plugin.py index 5eb9a946..0926ca69 100644 --- a/eval_protocol/pytest/plugin.py +++ b/eval_protocol/pytest/plugin.py @@ -12,8 +12,8 @@ max_dataset_rows value set in the decorator). """ -import os import logging +import os from typing import Optional @@ -32,17 +32,13 @@ def pytest_addoption(parser) -> None: "--ep-print-summary", action="store_true", default=False, - help=( - "Print a concise summary line (suite/model/effort/agg score) at the end of each evaluation_test." - ), + help=("Print a concise summary line (suite/model/effort/agg score) at the end of each evaluation_test."), ) group.addoption( "--ep-summary-json", action="store", default=None, - help=( - "Write a JSON summary artifact at the given path (e.g., ./outputs/aime_low.json)." - ), + help=("Write a JSON summary artifact at the given path (e.g., ./outputs/aime_low.json)."), ) group.addoption( "--ep-input-param", @@ -63,6 +59,13 @@ def pytest_addoption(parser) -> None: "Values: low|medium|high" ), ) + group.addoption( + "--ep-max-retry", + action="store", + type=int, + default=None, + help=("Failed rollouts (with rollout_status.status == 'error') will be retried up to this many times."), + ) def _normalize_max_rows(val: Optional[str]) -> Optional[str]: @@ -104,10 +107,15 @@ def pytest_configure(config) -> None: if summary_json_path: os.environ["EP_SUMMARY_JSON"] = summary_json_path + max_retry = config.getoption("--ep-max-retry") + if max_retry is not None: + os.environ["EP_MAX_RETRY"] = str(max_retry) + # Allow ad-hoc overrides of input params via CLI flags try: import json as _json import pathlib as _pathlib + merged: dict = {} input_params_opts = config.getoption("--ep-input-param") if input_params_opts: @@ -140,5 +148,3 @@ def pytest_configure(config) -> None: except Exception: # best effort, do not crash pytest session pass - - diff --git a/tests/pytest/test_tau_bench_airline.py b/tests/pytest/test_tau_bench_airline.py index 80aadf14..97dae602 100644 --- a/tests/pytest/test_tau_bench_airline.py +++ b/tests/pytest/test_tau_bench_airline.py @@ -58,7 +58,7 @@ def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Eval rows.append(eval_row) - return rows + return rows[0:1] @evaluation_test( @@ -68,7 +68,7 @@ def tau_bench_airline_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Eval rollout_input_params=[{"temperature": 0.8, "max_tokens": 4096, "reasoning_effort": "low"}], rollout_processor=default_mcp_gym_rollout_processor, passed_threshold={"success": 0.4, "standard_deviation": 0.1}, - num_runs=8, + num_runs=1, mode="pointwise", max_concurrent_rollouts=50, server_script_path="examples/tau2_mcp/server.py", diff --git a/tests/test_rollout_error_handling.py b/tests/test_rollout_error_handling.py new file mode 100644 index 00000000..87694dc3 --- /dev/null +++ b/tests/test_rollout_error_handling.py @@ -0,0 +1,229 @@ +""" +Unit tests for rollout processor error handling. + +Tests that rollout processors properly set rollout_status.status = "error" when exceptions occur. +""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from eval_protocol.dataset_logger import default_logger +from eval_protocol.models import EvaluationRow, Message, RolloutStatus +from eval_protocol.pytest.default_agent_rollout_processor import default_agent_rollout_processor +from eval_protocol.pytest.default_single_turn_rollout_process import default_single_turn_rollout_processor +from eval_protocol.pytest.types import RolloutProcessorConfig + + +class TestRolloutErrorHandling: + """Test that rollout processors handle errors correctly.""" + + @pytest.mark.asyncio + async def test_agent_rollout_processor_429_error(self): + """Test that agent rollout processor handles 429 rate limit errors correctly.""" + + # Create test row with initialized rollout_status + test_row = EvaluationRow( + messages=[Message(role="user", content="Hello")], rollout_status=RolloutStatus(status="running") + ) + + config = RolloutProcessorConfig( + model="gpt-4", input_params={}, mcp_config_path="", logger=default_logger # Empty to avoid MCP setup + ) + + # Mock the LiteLLM policy to raise a 429 error + with patch("eval_protocol.pytest.default_agent_rollout_processor.LiteLLMPolicy") as mock_policy_class: + # Create a mock policy instance + mock_policy = AsyncMock() + mock_policy_class.return_value = mock_policy + + # Mock the _make_llm_call method to raise a 429 error + import litellm + + mock_policy._make_llm_call.side_effect = litellm.RateLimitError( + message="Rate limit exceeded: 429", llm_provider="openai", model="gpt-4" + ) + + # The agent rollout processor should catch the exception and set error status + result = await default_agent_rollout_processor([test_row], config) + + assert len(result) == 1 + assert result[0].rollout_status.status == "error" + assert result[0].rollout_status.error_message is not None + assert ( + "429" in result[0].rollout_status.error_message + or "rate limit" in result[0].rollout_status.error_message.lower() + ) + + @pytest.mark.asyncio + async def test_agent_rollout_processor_bad_request_error(self): + """Test that agent rollout processor handles BadRequest errors correctly.""" + + test_row = EvaluationRow( + messages=[Message(role="user", content="Hello")], rollout_status=RolloutStatus(status="running") + ) + + config = RolloutProcessorConfig(model="gpt-4", input_params={}, mcp_config_path="", logger=default_logger) + + # Mock the LiteLLM policy to raise a BadRequest error like the one in your example + with patch("eval_protocol.pytest.default_agent_rollout_processor.LiteLLMPolicy") as mock_policy_class: + mock_policy = AsyncMock() + mock_policy_class.return_value = mock_policy + + import openai + + mock_policy._make_llm_call.side_effect = openai.BadRequestError( + "Invalid value for 'content': expected a string, got null.", response=MagicMock(), body=None + ) + + result = await default_agent_rollout_processor([test_row], config) + + assert len(result) == 1 + assert result[0].rollout_status.status == "error" + assert result[0].rollout_status.error_message is not None + assert ( + "content" in result[0].rollout_status.error_message + or "BadRequest" in result[0].rollout_status.error_message + ) + + @pytest.mark.asyncio + async def test_single_turn_rollout_processor_429_error(self): + """Test that single turn rollout processor handles 429 rate limit errors correctly.""" + + test_row = EvaluationRow( + messages=[Message(role="user", content="Hello")], rollout_status=RolloutStatus(status="running") + ) + + config = RolloutProcessorConfig(model="gpt-4", input_params={}, mcp_config_path="", logger=default_logger) + + # Mock litellm.acompletion to raise a 429 error + with patch("importlib.import_module") as mock_import: + mock_litellm = MagicMock() + mock_import.return_value = mock_litellm + + import litellm + + mock_litellm.acompletion.side_effect = litellm.RateLimitError( + message="Rate limit exceeded: 429", llm_provider="openai", model="gpt-4" + ) + + result = await default_single_turn_rollout_processor([test_row], config) + + assert len(result) == 1 + assert result[0].rollout_status.status == "error" + assert result[0].rollout_status.error_message is not None + assert ( + "429" in result[0].rollout_status.error_message + or "rate limit" in result[0].rollout_status.error_message.lower() + ) + + @pytest.mark.asyncio + async def test_single_turn_rollout_processor_bad_request_error(self): + """Test that single turn rollout processor handles BadRequest errors correctly.""" + + test_row = EvaluationRow( + messages=[Message(role="user", content="Hello")], rollout_status=RolloutStatus(status="running") + ) + + config = RolloutProcessorConfig(model="gpt-4", input_params={}, mcp_config_path="", logger=default_logger) + + # Mock litellm.acompletion to raise a BadRequest error + with patch("importlib.import_module") as mock_import: + mock_litellm = MagicMock() + mock_import.return_value = mock_litellm + + import openai + + mock_litellm.acompletion.side_effect = openai.BadRequestError( + "Invalid value for 'content': expected a string, got null.", response=MagicMock(), body=None + ) + + result = await default_single_turn_rollout_processor([test_row], config) + + assert len(result) == 1 + assert result[0].rollout_status.status == "error" + assert result[0].rollout_status.error_message is not None + assert ( + "content" in result[0].rollout_status.error_message + or "BadRequest" in result[0].rollout_status.error_message + ) + + @pytest.mark.asyncio + async def test_multiple_rows_with_mixed_errors(self): + """Test that when some rows get 429 errors and some succeed, each gets the correct status.""" + + # Create test rows + row1 = EvaluationRow( + messages=[Message(role="user", content="Hello 1")], rollout_status=RolloutStatus(status="running") + ) + + row2 = EvaluationRow( + messages=[Message(role="user", content="Hello 2")], rollout_status=RolloutStatus(status="running") + ) + + config = RolloutProcessorConfig(model="gpt-4", input_params={}, mcp_config_path="", logger=default_logger) + + # Mock litellm.acompletion to raise 429 for both rows (simulating rate limiting) + with patch("importlib.import_module") as mock_import: + mock_litellm = MagicMock() + mock_import.return_value = mock_litellm + + import litellm + + mock_litellm.acompletion.side_effect = litellm.RateLimitError( + message="Rate limit exceeded: 429", llm_provider="openai", model="gpt-4" + ) + + result = await default_single_turn_rollout_processor([row1, row2], config) + + assert len(result) == 2 + # Both should have error status due to 429 errors + for row in result: + assert row.rollout_status.status == "error" + assert row.rollout_status.error_message is not None + assert ( + "429" in row.rollout_status.error_message + or "rate limit" in row.rollout_status.error_message.lower() + ) + + @pytest.mark.asyncio + async def test_rollout_status_preserves_original_row_data_on_api_error(self): + """Test that when API errors occur, the original row data is preserved.""" + + original_message = Message(role="user", content="Original message") + test_row = EvaluationRow(messages=[original_message], rollout_status=RolloutStatus(status="running")) + + config = RolloutProcessorConfig(model="gpt-4", input_params={}, mcp_config_path="", logger=default_logger) + + # Mock the LiteLLM policy to raise an API error + with patch("eval_protocol.pytest.default_agent_rollout_processor.LiteLLMPolicy") as mock_policy_class: + mock_policy = AsyncMock() + mock_policy_class.return_value = mock_policy + + import litellm + + mock_policy._make_llm_call.side_effect = litellm.RateLimitError( + message="Rate limit exceeded: 429", llm_provider="openai", model="gpt-4" + ) + + result = await default_agent_rollout_processor([test_row], config) + + assert len(result) == 1 + assert result[0].rollout_status.status == "error" + # Original message should be preserved + assert len(result[0].messages) == 1 + assert result[0].messages[0].content == "Original message" + + def test_rollout_status_initialization(self): + """Test that RolloutStatus initializes with correct default values.""" + + # Test default initialization + status = RolloutStatus() + assert status.status == "finished" # Default from the model + assert status.error_message is None + + # Test explicit initialization + status = RolloutStatus(status="error", error_message="Test error") + assert status.status == "error" + assert status.error_message == "Test error" From bdd630c7795b414ec046b0c85f989b9dbf9c0583 Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Wed, 13 Aug 2025 14:28:44 -0700 Subject: [PATCH 3/4] test --- tests/pytest/test_retry_logic.py | 342 +++++++++++++++++++++++++++++++ 1 file changed, 342 insertions(+) create mode 100644 tests/pytest/test_retry_logic.py diff --git a/tests/pytest/test_retry_logic.py b/tests/pytest/test_retry_logic.py new file mode 100644 index 00000000..a34ca05e --- /dev/null +++ b/tests/pytest/test_retry_logic.py @@ -0,0 +1,342 @@ +""" +Test suite for the individual rollout retry logic in evaluation_test. + +Tests the new efficient retry system that retries individual rollouts immediately +as they fail, rather than waiting for entire batches to complete. +""" + +import asyncio +import os +from typing import List +from unittest.mock import patch + +import pytest + +from eval_protocol.models import EvaluateResult, EvaluationRow, Message, RolloutStatus +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.types import RolloutProcessor, RolloutProcessorConfig + + +class MockRetryRolloutProcessor: + """ + Mock rollout processor that simulates different rollout statuses. + + On first call, returns rollouts with mixed statuses (finished, error, running). + On retry calls, converts error/running rollouts to finished status. + """ + + def __init__(self): + self.call_count = 0 + self.processed_rollout_ids = set() + + async def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig): + """Process rollouts with simulated statuses""" + self.call_count += 1 + + for row in rows: + # If this is a retry (rollout_id we've seen before), make it succeed + if row.execution_metadata.rollout_id in self.processed_rollout_ids: + row.rollout_status = RolloutStatus(status="finished") + row.messages.append( + Message(role="assistant", content=f"Retry success for {row.execution_metadata.rollout_id}") + ) + else: + # First time processing this logical rollout + self.processed_rollout_ids.add(row.execution_metadata.rollout_id) + + # Simulate different statuses based on content + content = row.messages[0].content if row.messages else "" + + if "should_finish" in content: + # This one succeeds immediately + row.rollout_status = RolloutStatus(status="finished") + row.messages.append(Message(role="assistant", content="Success on first try")) + elif "should_error" in content: + # This one errors on first try, should be retried + row.rollout_status = RolloutStatus(status="error", termination_reason="Simulated error") + row.messages.append(Message(role="assistant", content="Error on first try")) + elif "should_be_running" in content: + # This one is left in running state, should be retried + row.rollout_status = RolloutStatus(status="running") + row.messages.append(Message(role="assistant", content="Left running, needs retry")) + else: + # Default to finished + row.rollout_status = RolloutStatus(status="finished") + row.messages.append(Message(role="assistant", content="Default success")) + + yield row + + +class MockAlwaysFailRolloutProcessor: + """Mock rollout processor that always fails, to test retry exhaustion""" + + def __init__(self): + self.call_count = 0 + + async def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig): + """Always return error status to test retry exhaustion""" + self.call_count += 1 + + for row in rows: + row.rollout_status = RolloutStatus( + status="error", termination_reason=f"Persistent failure (attempt {self.call_count})" + ) + row.messages.append(Message(role="assistant", content=f"Failed attempt {self.call_count}")) + yield row + + +# Create instances that will be shared across test functions +mock_retry_processor = MockRetryRolloutProcessor() +mock_always_fail_processor = MockAlwaysFailRolloutProcessor() + + +# Set environment variable at module level for this test +@patch.dict(os.environ, {"EP_MAX_RETRY": "3"}) +@evaluation_test( + input_messages=[ + [Message(role="user", content="Test case that should_finish immediately")], + [Message(role="user", content="Test case that should_error on first try")], + [Message(role="user", content="Test case that should_be_running and need retry")], + ], + model=["dummy/local-model"], + rollout_processor=mock_retry_processor, + mode="batch", + num_runs=1, +) +def test_retry_mixed_statuses_batch_mode(rows: List[EvaluationRow]) -> List[EvaluationRow]: + """ + Test that retry logic works with mixed rollout statuses in batch mode. + + Tests: + - One rollout finishes immediately (should not retry) + - One rollout has error status (should retry and succeed) + - One rollout has running status (should retry and succeed) + """ + # Reset processor state at the beginning + mock_retry_processor.call_count = 0 + mock_retry_processor.processed_rollout_ids.clear() + + # Verify we got all our test cases + assert len(rows) == 3 + + # Verify all rollouts ended up in finished state after retries + for row in rows: + assert row.rollout_status is not None + assert row.rollout_status.status == "finished", f"Row should be finished but was {row.rollout_status.status}" + + # Check that retry cases got the retry response + content = row.messages[0].content + if "should_error" in content or "should_be_running" in content: + # These should have been retried + assistant_messages = [msg for msg in row.messages if msg.role == "assistant"] + assert len(assistant_messages) >= 1 + assert "Retry success" in assistant_messages[-1].content + + # Set evaluation results + for row in rows: + row.evaluation_result = EvaluateResult(score=1.0, reason="All rollouts completed successfully") + + return rows + + +@patch.dict(os.environ, {"EP_MAX_RETRY": "3"}) +@evaluation_test( + input_messages=[ + [Message(role="user", content="Test pointwise should_error")], + [Message(role="user", content="Test pointwise should_be_running")], + [Message(role="user", content="Test pointwise should_finish")], + ], + model=["dummy/local-model"], + rollout_processor=mock_retry_processor, + mode="pointwise", + num_runs=1, +) +def test_retry_mixed_statuses_pointwise_mode(row: EvaluationRow) -> EvaluationRow: + """ + Test that retry logic works with mixed rollout statuses in pointwise mode. + + Each rollout is processed individually and should retry if not finished. + """ + # Verify rollout ended up in finished state after any needed retries + assert row.rollout_status is not None + assert row.rollout_status.status == "finished", f"Row should be finished but was {row.rollout_status.status}" + + # Set evaluation result + row.evaluation_result = EvaluateResult(score=1.0, reason="Rollout completed successfully") + + return row + + +def test_retry_exhaustion_should_fail(): + """ + Test that rollout process fails when max retries are exceeded. + + Sets EP_MAX_RETRY=2 and uses a processor that always fails. + Should fail after 3 total attempts (initial + 2 retries). + """ + + # Set max retries environment variable + with patch.dict(os.environ, {"EP_MAX_RETRY": "2"}): + + @evaluation_test( + input_messages=[ + [Message(role="user", content="This will always fail")], + ], + model=["dummy/local-model"], + rollout_processor=mock_always_fail_processor, + mode="batch", + num_runs=1, + ) + def failing_evaluation_test(rows: List[EvaluationRow]) -> List[EvaluationRow]: + # This should never be reached due to rollout failures + for row in rows: + row.evaluation_result = EvaluateResult(score=1.0, reason="Should not reach here") + return rows + + # The evaluation_test should raise RuntimeError due to retry exhaustion + with pytest.raises(RuntimeError) as exc_info: + # Run the test directly to trigger the retry logic + import asyncio + + # Reset the processor call count + mock_always_fail_processor.call_count = 0 + + # Create test data + rows = [EvaluationRow(messages=[Message(role="user", content="This will always fail")])] + + # This should fail after 3 attempts (initial + 2 retries) + asyncio.run(failing_evaluation_test(rows)) + + # Verify the error message mentions retry exhaustion + error_msg = str(exc_info.value) + assert "failed after 2 retries" in error_msg.lower() or "retry" in error_msg.lower() + + # Verify the processor was called multiple times (initial + retries) + assert ( + mock_always_fail_processor.call_count >= 3 + ), f"Expected >= 3 calls, got {mock_always_fail_processor.call_count}" + + +def test_no_retries_when_max_retry_zero(): + """ + Test that no retries happen when EP_MAX_RETRY=0 (default). + + Even with failing rollouts, should fail immediately without retries. + """ + + # Ensure EP_MAX_RETRY is 0 (default) + with patch.dict(os.environ, {"EP_MAX_RETRY": "0"}): + + @evaluation_test( + input_messages=[ + [Message(role="user", content="This will fail once and not retry")], + ], + model=["dummy/local-model"], + rollout_processor=mock_always_fail_processor, + mode="batch", + num_runs=1, + ) + def no_retry_evaluation_test(rows: List[EvaluationRow]) -> List[EvaluationRow]: + # This should never be reached due to immediate failure + for row in rows: + row.evaluation_result = EvaluateResult(score=1.0, reason="Should not reach here") + return rows + + # Should fail immediately without retries + with pytest.raises(RuntimeError) as exc_info: + # Reset processor call count + mock_always_fail_processor.call_count = 0 + + # Create test data + rows = [EvaluationRow(messages=[Message(role="user", content="This will fail once and not retry")])] + + # Should fail after just 1 attempt + asyncio.run(no_retry_evaluation_test(rows)) + + # Verify only 1 attempt was made (no retries) + assert ( + mock_always_fail_processor.call_count == 1 + ), f"Expected 1 call, got {mock_always_fail_processor.call_count}" + + +@pytest.mark.asyncio +async def test_concurrent_retry_efficiency(): + """ + Test that retries happen efficiently with proper concurrency. + + Verifies that successful rollouts don't wait for failing ones, + and that retries start immediately as failures are detected. + """ + + class TimingMockProcessor: + """Mock processor that tracks timing of rollout processing""" + + def __init__(self): + self.processing_times = {} + self.start_times = {} + + async def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig): + import time + + for row in rows: + rollout_id = row.execution_metadata.rollout_id + self.start_times[rollout_id] = time.time() + + # Simulate different processing times + content = row.messages[0].content if row.messages else "" + + if "slow_success" in content: + # Slow but successful rollout + await asyncio.sleep(0.1) + row.rollout_status = RolloutStatus(status="finished") + row.messages.append(Message(role="assistant", content="Slow success")) + elif "fast_fail" in content: + # Fast failure that should retry quickly + await asyncio.sleep(0.01) + if rollout_id not in self.processing_times: + # First attempt - fail + row.rollout_status = RolloutStatus(status="error", termination_reason="Fast failure") + row.messages.append(Message(role="assistant", content="Fast failure")) + self.processing_times[rollout_id] = time.time() + else: + # Retry - succeed + row.rollout_status = RolloutStatus(status="finished") + row.messages.append(Message(role="assistant", content="Fast retry success")) + + yield row + + timing_processor = TimingMockProcessor() + + with patch.dict(os.environ, {"EP_MAX_RETRY": "3"}): + + @evaluation_test( + input_messages=[ + [Message(role="user", content="slow_success - this takes longer but succeeds")], + [Message(role="user", content="fast_fail - this fails fast then retries")], + ], + model=["dummy/local-model"], + rollout_processor=timing_processor, + mode="batch", + num_runs=1, + ) + def timing_test(rows: List[EvaluationRow]) -> List[EvaluationRow]: + # Both should succeed eventually + assert len(rows) == 2 + for row in rows: + assert row.rollout_status.status == "finished" + row.evaluation_result = EvaluateResult(score=1.0, reason="Success") + return rows + + # Create test data + rows = [ + EvaluationRow(messages=[Message(role="user", content="slow_success - this takes longer but succeeds")]), + EvaluationRow(messages=[Message(role="user", content="fast_fail - this fails fast then retries")]), + ] + + # Run the test - should complete successfully with proper retry timing + result = await timing_test(rows) + assert len(result) == 2 + + # Verify that the fast-failing rollout was processed multiple times due to retry + fast_fail_processed = any("fast_fail" in row.messages[0].content for row in result) + assert fast_fail_processed, "Fast-failing rollout should have been processed" From a0edb8fa57f9fe3294664683af5e98ac3af9cc8b Mon Sep 17 00:00:00 2001 From: Derek Xu Date: Wed, 13 Aug 2025 14:29:19 -0700 Subject: [PATCH 4/4] default --- eval_protocol/pytest/evaluation_test.py | 50 ++++++++----------------- 1 file changed, 16 insertions(+), 34 deletions(-) diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 71d1f7d1..d85bbc0e 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -418,43 +418,25 @@ def _log_eval_error( max_retry = int(os.getenv("EP_MAX_RETRY", "0")) for i in range(num_runs): - # Regenerate outputs each run by deep-copying the pristine dataset - # so model responses are not reused across runs. run_id = generate_id() retry_attempt = 0 current_data = data - while retry_attempt <= max_retry: - if retry_attempt > 0: - logged_rows = active_logger.read() - failed_rows = [ - row - for row in logged_rows - if row.rollout_status - and row.rollout_status.status == "error" - and row.run_id == run_id - ] - if not failed_rows: - break - current_data = failed_rows - - # Regenerate outputs each run by deep-copying the pristine dataset - # so model responses are not reused across runs. - fresh_dataset = [r.model_copy(deep=True) for r in current_data] - - # apply new run_id to fresh_dataset - for row in fresh_dataset: - row.run_id = run_id - - # generate new rollout_id for each row - for row in fresh_dataset: - row.rollout_id = generate_id() - - # log the fresh_dataset - for row in fresh_dataset: - active_logger.log(row) - - rollout_result = rollout_processor(fresh_dataset, config) + # Regenerate outputs each run by deep-copying the pristine dataset + # so model responses are not reused across runs. + fresh_dataset = [r.model_copy(deep=True) for r in current_data] + + # apply new run_id to fresh_dataset + for row in fresh_dataset: + row.run_id = run_id + + # generate new rollout_id for each row + for row in fresh_dataset: + row.rollout_id = generate_id() + + # log the fresh_dataset + for row in fresh_dataset: + active_logger.log(row) if mode == "pointwise": # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution @@ -482,7 +464,7 @@ async def _execute_with_semaphore(row): else: # Batch mode: collect all results first, then evaluate (no pipelining) input_dataset = [] - async for row in rollout_result: + async for row in rollout_processor(fresh_dataset, config): input_dataset.append(row) results = await execute_with_params(