Skip to content

Commit 9f67172

Browse files
author
Dylan Huang
committed
Merge branch 'main' into aggregated-metrics-part-5
2 parents 050a0a1 + fd1c7c9 commit 9f67172

13 files changed

+129
-83
lines changed

eval_protocol/mcp/execution/manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ async def _execute_with_semaphore(idx):
158158
messages.append(Message.model_validate(msg_dict))
159159

160160
evaluation_rows[idx].messages = messages
161-
evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id
162-
evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx])
161+
# evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id
162+
# evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx])
163163
evaluation_rows[idx].tools = shared_tool_schema
164164
evaluation_rows[idx].usage = CompletionUsage(**trajectory.usage)
165165
evaluation_rows[idx].input_metadata.completion_params = CompletionParams(
@@ -482,11 +482,11 @@ async def _execute_rollout(
482482
trajectory.control_plane_summary.update({"error_message": f"{failure_reason}"})
483483
try:
484484
await envs.connection_manager.reset_session(session)
485-
except:
485+
except: # noqa: E722
486486
logger.error(f"Error resetting session {session.session_id}")
487487
try:
488488
await envs.connection_manager.close_session(session)
489-
except:
489+
except: # noqa: E722
490490
logger.error(f"Error closing session {session.session_id}")
491491
return trajectory
492492

eval_protocol/models.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,21 @@ class InputMetadata(BaseModel):
202202
)
203203

204204

205+
class EvaluationThreshold(BaseModel):
206+
"""Threshold configuration for evaluation tests.
207+
208+
The success field is required - tests must specify a minimum success rate.
209+
The standard_deviation field is optional; if provided, the run-level score standard deviation must not exceed it for the test to pass.
210+
"""
211+
212+
success: float = Field(
213+
..., description="Minimum success rate threshold (fraction of total score, 0.0 to 1.0)", ge=0.0, le=1.0
214+
)
215+
standard_deviation: Optional[float] = Field(
216+
None, description="Maximum standard deviation threshold (fraction of total score, 0.0 to 1.0)", ge=0.0, le=1.0
217+
)
218+
219+
205220
class EvalMetadata(BaseModel):
206221
"""Metadata about the evaluation that was run."""
207222

@@ -216,7 +231,9 @@ class EvalMetadata(BaseModel):
216231
)
217232
num_runs: int = Field(..., description="Number of times the evaluation was repeated")
218233
aggregation_method: str = Field(..., description="Method used to aggregate scores across runs")
219-
threshold_of_success: Optional[float] = Field(None, description="Threshold score for test success")
234+
passed_threshold: Optional[EvaluationThreshold] = Field(
235+
None, description="Threshold configuration for test success"
236+
)
220237
passed: Optional[bool] = Field(None, description="Whether the evaluation passed based on the threshold")
221238

222239

eval_protocol/pytest/evaluation_test.py

Lines changed: 57 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,21 @@
33
import math
44
import os
55
import statistics
6-
from typing import Any, Callable, Dict, List, Literal, Optional
6+
from typing import Any, Callable, Dict, List, Literal, Optional, Union
77

88
import pytest
99

1010
from eval_protocol.dataset_logger import default_logger
1111
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
1212
from eval_protocol.human_id import generate_id
13-
from eval_protocol.models import CompletionParams, EvalMetadata, EvaluationRow, InputMetadata, Message
13+
from eval_protocol.models import (
14+
CompletionParams,
15+
EvalMetadata,
16+
EvaluationRow,
17+
EvaluationThreshold,
18+
InputMetadata,
19+
Message,
20+
)
1421
from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
1522
from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor
1623
from eval_protocol.pytest.types import (
@@ -47,7 +54,7 @@ def evaluation_test( # noqa: C901
4754
rollout_processor: RolloutProcessor = default_no_op_rollout_processor,
4855
evaluation_test_kwargs: Optional[List[EvaluationInputParam]] = None,
4956
aggregation_method: AggregationMethod = "mean",
50-
threshold_of_success: Optional[float] = None,
57+
passed_threshold: Optional[Union[EvaluationThreshold, float]] = None,
5158
num_runs: int = 1,
5259
max_dataset_rows: Optional[int] = None,
5360
mcp_config_path: Optional[str] = None,
@@ -108,8 +115,8 @@ def evaluation_test( # noqa: C901
108115
rollout_processor: Function used to perform the rollout.
109116
evaluation_test_kwargs: Kwargs for the evaluation function.
110117
aggregation_method: How to aggregate scores across rows.
111-
threshold_of_success: If set, fail the test if the aggregated score is
112-
below this threshold.
118+
passed_threshold: Threshold configuration for test success.
119+
The aggregated success rate must be at least ``success``; if ``standard_deviation`` is set, the observed standard deviation must not exceed it.
113120
num_runs: Number of times to repeat the rollout and evaluations.
114121
max_dataset_rows: Limit dataset to the first N rows.
115122
mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema
@@ -127,6 +134,14 @@ def evaluation_test( # noqa: C901
127134
def decorator(
128135
test_func: TestFunction,
129136
):
137+
if passed_threshold is not None:
138+
if isinstance(passed_threshold, float):
139+
threshold = EvaluationThreshold(success=passed_threshold)
140+
else:
141+
threshold = EvaluationThreshold(**passed_threshold)
142+
else:
143+
threshold = None
144+
130145
sig = inspect.signature(test_func)
131146

132147
# For pointwise/rowwise mode, we expect a different signature
@@ -285,7 +300,7 @@ def create_wrapper_with_signature() -> Callable:
285300
def wrapper_body(**kwargs):
286301
model_name = kwargs["model"]
287302
eval_metadata = None
288-
all_results: List[EvaluationRow] = []
303+
all_results: List[List[EvaluationRow]] = [[] for _ in range(num_runs)]
289304

290305
cohort_id = generate_id()
291306

@@ -346,7 +361,7 @@ def _log_eval_error(
346361
status="running",
347362
num_runs=num_runs,
348363
aggregation_method=aggregation_method,
349-
threshold_of_success=threshold_of_success,
364+
passed_threshold=threshold,
350365
passed=None,
351366
)
352367

@@ -386,11 +401,11 @@ def _log_eval_error(
386401
logger=active_logger,
387402
)
388403

389-
for _ in range(num_runs):
404+
for i in range(num_runs):
390405
# Regenerate outputs each run by deep-copying the pristine dataset
391406
# so model responses are not reused across runs.
392407
run_id = generate_id()
393-
fresh_dataset = [copy.deepcopy(r) for r in data]
408+
fresh_dataset = [r.model_copy(deep=True) for r in data]
394409

395410
# apply new run_id to fresh_dataset
396411
for row in fresh_dataset:
@@ -418,7 +433,7 @@ def _log_eval_error(
418433
raise ValueError(
419434
f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
420435
)
421-
all_results.append(result)
436+
all_results[i].append(result)
422437
else:
423438
# Batch mode: call the test function with the full dataset
424439
results = execute_with_params(
@@ -442,17 +457,21 @@ def _log_eval_error(
442457
raise ValueError(
443458
f"Test function {test_func.__name__} returned a list containing non-EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
444459
)
445-
all_results.extend(results)
460+
all_results[i] = results
446461

447-
scores = [r.evaluation_result.score for r in all_results if r.evaluation_result]
462+
scores = [
463+
sum([r.evaluation_result.score for r in result if r.evaluation_result]) / len(result)
464+
for result in all_results
465+
]
448466
agg_score = aggregate(scores, aggregation_method)
467+
score_std = statistics.stdev(scores) if len(scores) > 1 else 0.0
449468

450469
# Compute 95% confidence interval for the fixed-set mean μ (by-question, using repeats)
451470
ci_low: float | None = None
452471
ci_high: float | None = None
453472
if aggregation_method == "mean":
454473
try:
455-
result_ci = compute_fixed_set_mu_ci(all_results)
474+
result_ci = compute_fixed_set_mu_ci([item for sublist in all_results for item in sublist])
456475
mu_ci_low, mu_ci_high = result_ci[1], result_ci[2]
457476
if mu_ci_low is not None and mu_ci_high is not None:
458477
ci_low = float(mu_ci_low)
@@ -464,23 +483,32 @@ def _log_eval_error(
464483

465484
# Determine if the evaluation passed based on threshold
466485
passed = None
467-
if threshold_of_success is not None:
468-
passed = agg_score >= threshold_of_success
486+
487+
if threshold is not None:
488+
success_passed, std_passed = True, True
489+
490+
success_passed = agg_score >= threshold.success
491+
492+
if threshold.standard_deviation is not None:
493+
std_passed = score_std <= threshold.standard_deviation
494+
495+
passed = success_passed and std_passed
469496

470497
# Update eval metadata status and passed field for all results
471-
for r in all_results:
472-
if r.eval_metadata is not None:
473-
r.eval_metadata.status = "finished"
474-
r.eval_metadata.passed = passed
475-
active_logger.log(r)
498+
for result in all_results:
499+
for r in result:
500+
if r.eval_metadata is not None:
501+
r.eval_metadata.status = "finished"
502+
r.eval_metadata.passed = passed
503+
default_logger.log(r)
476504

477505
# Optional: print and/or persist a summary artifact for CI
478506
try:
479507
should_print = os.getenv("EP_PRINT_SUMMARY") == "1"
480508
summary_path = os.getenv("EP_SUMMARY_JSON")
481509
suite_name = test_func.__name__
482510
model_used = model_name
483-
total_rows = len(all_results)
511+
total_rows = len([item for sublist in all_results for item in sublist])
484512
summary_obj = {
485513
"suite": suite_name,
486514
"model": model_used,
@@ -497,7 +525,7 @@ def _log_eval_error(
497525
from collections import defaultdict
498526

499527
metric_scores: Dict[str, list] = defaultdict(list)
500-
for r in all_results:
528+
for r in [item for sublist in all_results for item in sublist]:
501529
if r.evaluation_result and r.evaluation_result.metrics:
502530
for m_name, m_res in r.evaluation_result.metrics.items():
503531
if m_res is not None and getattr(m_res, "score", None) is not None:
@@ -614,10 +642,14 @@ def _extract_effort_tag(params: dict) -> str | None:
614642
# pass
615643

616644
# Check threshold after logging
617-
if threshold_of_success is not None and not passed:
645+
if threshold is not None and not passed:
618646
assert (
619-
agg_score >= threshold_of_success
620-
), f"Aggregated score {agg_score:.3f} below threshold {threshold_of_success}"
647+
agg_score >= threshold.success
648+
), f"Aggregated score {agg_score:.3f} below threshold {threshold.success}"
649+
if threshold.standard_deviation is not None:
650+
assert (
651+
score_std <= threshold.standard_deviation
652+
), f"Standard deviation {score_std:.3f} above threshold {threshold.standard_deviation}"
621653

622654
except AssertionError:
623655
_log_eval_error("finished", data if "data" in locals() else None, passed=False)

tests/pytest/test_apps_coding.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@ def apps_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluatio
1818
Convert entries from APPS dataset to EvaluationRow objects.
1919
"""
2020
return [
21-
EvaluationRow(
22-
messages=[Message(role="user", content=row["question"])],
23-
ground_truth=row["input_output"]
24-
)
21+
EvaluationRow(messages=[Message(role="user", content=row["question"])], ground_truth=row["input_output"])
2522
for row in data
2623
]
2724

@@ -31,7 +28,7 @@ def apps_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluatio
3128
dataset_adapter=apps_dataset_to_evaluation_row,
3229
model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
3330
rollout_input_params=[{"temperature": 0.0, "max_tokens": 4096}],
34-
threshold_of_success=0.33,
31+
passed_threshold=0.33,
3532
rollout_processor=default_single_turn_rollout_processor,
3633
num_runs=1,
3734
mode="pointwise",
@@ -42,7 +39,7 @@ def test_apps_code_evaluation(row: EvaluationRow) -> EvaluationRow:
4239
4340
Args:
4441
row: EvaluationRow containing the conversation messages and ground_truth as JSON string
45-
42+
4643
Returns:
4744
EvaluationRow with the evaluation result
4845
"""
@@ -51,8 +48,8 @@ def test_apps_code_evaluation(row: EvaluationRow) -> EvaluationRow:
5148
messages=row.messages,
5249
ground_truth=row.ground_truth,
5350
)
54-
51+
5552
# Set the evaluation result on the row
5653
row.evaluation_result = result
57-
58-
return row
54+
55+
return row

tests/pytest/test_basic_coding.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from eval_protocol.models import EvaluateResult, EvaluationRow, Message
1111
from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
12-
from eval_protocol.rewards.code_execution import extract_code_blocks, execute_python_code
12+
from eval_protocol.rewards.code_execution import execute_python_code, extract_code_blocks
1313

1414

1515
def coding_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
@@ -18,8 +18,8 @@ def coding_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluat
1818
"""
1919
return [
2020
EvaluationRow(
21-
messages=[Message(role="user", content=f"{row['prompt']} Input: {row['input']}")],
22-
ground_truth=row["expected_output"]
21+
messages=[Message(role="user", content=f"{row['prompt']} Input: {row['input']}")],
22+
ground_truth=row["expected_output"],
2323
)
2424
for row in data
2525
]
@@ -30,63 +30,59 @@ def coding_dataset_to_evaluation_row(data: List[Dict[str, Any]]) -> List[Evaluat
3030
dataset_adapter=coding_dataset_to_evaluation_row,
3131
model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
3232
rollout_input_params=[{"temperature": 0.0, "max_tokens": 4096}],
33-
threshold_of_success=0.8,
33+
passed_threshold=0.8,
3434
rollout_processor=default_single_turn_rollout_processor,
3535
num_runs=1,
3636
mode="pointwise",
3737
)
3838
def test_coding_code_evaluation(row: EvaluationRow) -> EvaluationRow:
3939
"""
4040
Evaluation function that tests code correctness by executing it locally.
41-
41+
4242
This function:
4343
1. Extracts Python code from the assistant's response
4444
2. Executes the code locally with timeout=10
4545
3. Compares the output to ground_truth
4646
4. Returns a score of 1.0 if output matches, 0.0 otherwise
47-
47+
4848
Args:
4949
row: EvaluationRow containing the conversation messages and expected_output in ground_truth
50-
50+
5151
Returns:
5252
EvaluationRow with the evaluation result
5353
"""
5454
# Check if we have an assistant response
5555
if len(row.messages) < 2 or row.messages[-1].role != "assistant":
5656
row.evaluation_result = EvaluateResult(score=0.0, reason="No assistant response found")
5757
return row
58-
58+
5959
assistant_content = row.messages[-1].content or ""
6060
expected_output = (row.ground_truth or "").strip()
61-
61+
6262
# Extract Python code blocks
6363
code_blocks = extract_code_blocks(assistant_content, language="python")
6464
if not code_blocks:
6565
row.evaluation_result = EvaluateResult(score=0.0, reason="No Python code block found")
6666
return row
67-
67+
6868
code = code_blocks[0]["code"]
69-
69+
7070
# Execute the code locally
7171
execution_result = execute_python_code(code, timeout=10)
72-
72+
7373
if not execution_result.get("success", False):
7474
error_msg = execution_result.get("error", "Code execution failed")
7575
row.evaluation_result = EvaluateResult(score=0.0, reason=f"Execution error: {error_msg}")
7676
return row
77-
77+
7878
# Compare output with expected
7979
actual_output = (execution_result.get("output", "") or "").strip()
80-
80+
8181
if actual_output == expected_output:
82-
row.evaluation_result = EvaluateResult(
83-
score=1.0,
84-
reason=f"✅ Output matches: '{actual_output}'"
85-
)
82+
row.evaluation_result = EvaluateResult(score=1.0, reason=f"✅ Output matches: '{actual_output}'")
8683
else:
8784
row.evaluation_result = EvaluateResult(
88-
score=0.0,
89-
reason=f"❌ Expected: '{expected_output}', Got: '{actual_output}'"
85+
score=0.0, reason=f"❌ Expected: '{expected_output}', Got: '{actual_output}'"
9086
)
91-
87+
9288
return row

0 commit comments

Comments
 (0)