Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion eval_protocol/pytest/default_agent_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
try:
return await process_row(r)
except Exception as e:
logger.exception(f"Error processing row {r.input_metadata.row_id}: {e}")
r.rollout_status.status = "error"
r.rollout_status.error_message = str(e)
return r

# Create all tasks
Expand Down
2 changes: 2 additions & 0 deletions eval_protocol/pytest/default_single_turn_rollout_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
try:
return await process_row(r)
except Exception as e:
r.rollout_status.status = "error"
r.rollout_status.error_message = str(e)
return r

# Create all tasks
Expand Down
92 changes: 50 additions & 42 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,26 +415,29 @@ def _log_eval_error(
kwargs=rollout_processor_kwargs,
)

max_retry = int(os.getenv("EP_MAX_RETRY", "0"))

for i in range(num_runs):
run_id = generate_id()
retry_attempt = 0
current_data = data

# Regenerate outputs each run by deep-copying the pristine dataset
# so model responses are not reused across runs.
run_id = generate_id()
fresh_dataset = [r.model_copy(deep=True) for r in data]
fresh_dataset = [r.model_copy(deep=True) for r in current_data]

# apply new run_id to fresh_dataset
for row in fresh_dataset:
row.execution_metadata.run_id = run_id
row.run_id = run_id

# generate new rollout_id for each row
for row in fresh_dataset:
row.execution_metadata.rollout_id = generate_id()
row.rollout_id = generate_id()

# log the fresh_dataset
for row in fresh_dataset:
active_logger.log(row)

rollout_result = rollout_processor(fresh_dataset, config)

if mode == "pointwise":
# Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution
semaphore = asyncio.Semaphore(max_concurrent_rollouts)
Expand All @@ -461,7 +464,7 @@ async def _execute_with_semaphore(row):
else:
# Batch mode: collect all results first, then evaluate (no pipelining)
input_dataset = []
async for row in rollout_result:
async for row in rollout_processor(fresh_dataset, config):
input_dataset.append(row)

results = await execute_with_params(
Expand All @@ -487,47 +490,52 @@ async def _execute_with_semaphore(row):
)
all_results[i] = results

scores = [
sum([r.evaluation_result.score for r in result if r.evaluation_result]) / len(result)
for result in all_results
]
agg_score = aggregate(scores, aggregation_method)
score_std = statistics.stdev(scores) if len(scores) > 1 else 0.0

# Compute 95% confidence interval for the fixed-set mean μ (by-question, using repeats)
ci_low: float | None = None
ci_high: float | None = None
if aggregation_method == "mean":
try:
result_ci = compute_fixed_set_mu_ci([item for sublist in all_results for item in sublist])
mu_ci_low, mu_ci_high = result_ci[1], result_ci[2]
if mu_ci_low is not None and mu_ci_high is not None:
ci_low = float(mu_ci_low)
ci_high = float(mu_ci_high)
# Keep agg_score as-is (mean over scores). For equal repeats per question these match.
except Exception:
ci_low = None
ci_high = None

# Determine if the evaluation passed based on threshold
passed = None

if threshold is not None:
success_passed, std_passed = True, True

success_passed = agg_score >= threshold.success
retry_attempt += 1

scores = [
sum([r.evaluation_result.score for r in result if r.evaluation_result]) / len(result)
for result in all_results
]
agg_score = aggregate(scores, aggregation_method)
score_std = statistics.stdev(scores) if len(scores) > 1 else 0.0

# Compute 95% confidence interval for the fixed-set mean μ (by-question, using repeats)
ci_low: float | None = None
ci_high: float | None = None
if aggregation_method == "mean":
try:
result_ci = compute_fixed_set_mu_ci(
[item for sublist in all_results for item in sublist]
)
mu_ci_low, mu_ci_high = result_ci[1], result_ci[2]
if mu_ci_low is not None and mu_ci_high is not None:
ci_low = float(mu_ci_low)
ci_high = float(mu_ci_high)
# Keep agg_score as-is (mean over scores). For equal repeats per question these match.
except Exception:
ci_low = None
ci_high = None

if threshold.standard_deviation is not None:
std_passed = score_std <= threshold.standard_deviation
# Determine if the evaluation passed based on threshold
passed = None

if threshold is not None:
success_passed, std_passed = True, True

success_passed = agg_score >= threshold.success

if threshold.standard_deviation is not None:
std_passed = score_std <= threshold.standard_deviation

passed = success_passed and std_passed
passed = success_passed and std_passed

# Update eval metadata status and passed field for all results
for result in all_results:
for r in result:
if r.eval_metadata is not None:
r.eval_metadata.status = "finished"
r.eval_metadata.passed = passed
if r.rollout_status is not None:
if r.rollout_status.status != "error":
r.rollout_status.status = "finished"
r.rollout_status.passed = passed
active_logger.log(r)

# Optional: print and/or persist a summary artifact for CI
Expand Down
11 changes: 11 additions & 0 deletions eval_protocol/pytest/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ def pytest_addoption(parser) -> None:
"Values: low|medium|high"
),
)
group.addoption(
"--ep-max-retry",
action="store",
type=int,
default=None,
help=("Failed rollouts (with rollout_status.status == 'error') will be retried up to this many times."),
)


def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
Expand Down Expand Up @@ -100,6 +107,10 @@ def pytest_configure(config) -> None:
if summary_json_path:
os.environ["EP_SUMMARY_JSON"] = summary_json_path

max_retry = config.getoption("--ep-max-retry")
if max_retry is not None:
os.environ["EP_MAX_RETRY"] = str(max_retry)

# Allow ad-hoc overrides of input params via CLI flags
try:
import json as _json
Expand Down
Loading
Loading