Skip to content

Commit e355931

Browse files
author
Dylan Huang
authored
Show aggregated metrics in UI (Part 1) (#43)
* save * vite build use model_dump(mode="json") Add run_id to EvalMetadataSchema for unique run identification - Introduced run_id as an optional string to the EvalMetadataSchema to uniquely identify evaluation runs. - Updated description to clarify the purpose of the run_id field. Add run_id field to EvalMetadata for unique run identification - Added run_id as an optional string to the EvalMetadata class to uniquely identify groups of evaluation rows. - Updated the field description to clarify its purpose in relation to evaluation tests. Fix evaluation result assignment in markdown highlighting test - Updated the test_markdown_highlighting_evaluation function to assign the evaluation result directly to the row when no assistant response is found, ensuring proper handling of evaluation results. Add run_id generation in evaluation_test for unique identification - Integrated the generate_id function to create a run_id within the evaluation_test function. - Passed the generated run_id to the evaluation function, ensuring unique identification of evaluation runs. * Wrap logo image in a link to the Eval Protocol website for improved navigation. * TODO: test the pivot table logic * Refactor WebSocket log initialization message handling in logs_server.py - Simplified the construction of the log initialization message by creating a data dictionary before sending it over the WebSocket, improving code readability. * flatten json test * refine pivot.ts * Add support for composite columns in computePivot tests - Introduced a new test case to validate the handling of multiple column fields in the computePivot function. - Verified correct computation of cell values, row totals, column totals, and grand total for the pivot table with composite columns. * assertion error means finished * Refactor EvalMetadata and EvaluationRow models; add cohort_id, rollout_id, and run_id fields. Update evaluation_test to handle new identifiers and improve documentation on evaluation concepts. 
* Add invocation_id field to EvaluationRow model and update corresponding schema in eval-protocol types. This enhances tracking of invocation context for evaluation rows. * rename as its causing issues in pytest collection * square up all the id madness and add a test
1 parent 55005a1 commit e355931

28 files changed

+10814
-101
lines changed

eval_protocol/models.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,8 @@ class EvaluationRow(BaseModel):
245245
supporting both row-wise batch evaluation and trajectory-based RL evaluation.
246246
"""
247247

248-
# Core conversation data
249-
messages: List[Message] = Field(description="List of messages in the conversation/trajectory.")
248+
# Core OpenAI ChatCompletion compatible conversation data
249+
messages: List[Message] = Field(description="List of messages in the conversation. Also known as a trajectory.")
250250

251251
# Tool and function call information
252252
tools: Optional[List[Dict[str, Any]]] = Field(
@@ -264,6 +264,26 @@ class EvaluationRow(BaseModel):
264264
description="The status of the rollout.",
265265
)
266266

267+
invocation_id: Optional[str] = Field(
268+
default_factory=generate_id,
269+
description="The ID of the invocation that this row belongs to.",
270+
)
271+
272+
cohort_id: Optional[str] = Field(
273+
default_factory=generate_id,
274+
description="The ID of the cohort that this row belongs to.",
275+
)
276+
277+
rollout_id: Optional[str] = Field(
278+
default_factory=generate_id,
279+
description="The ID of the rollout that this row belongs to.",
280+
)
281+
282+
run_id: Optional[str] = Field(
283+
None,
284+
description=("The ID of the run that this row belongs to."),
285+
)
286+
267287
# Ground truth reference (moved from EvaluateResult to top level)
268288
ground_truth: Optional[str] = Field(
269289
default=None, description="Optional ground truth reference for this evaluation."

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ async def default_mcp_gym_rollout_processor(
200200
Returns:
201201
List of EvaluationRow objects with completed conversations
202202
"""
203+
if config.server_script_path is None:
204+
raise ValueError("server_script_path is required for default_mcp_gym_rollout_processor")
203205
server = MCPServerManager(config.server_script_path, port=9700)
204206

205207
try:

eval_protocol/pytest/evaluation_test.py

Lines changed: 92 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
import inspect
2-
import os
31
import copy
2+
import inspect
43
import math
4+
import os
55
import statistics
6-
from typing import Any, Callable, Dict, List, Optional
6+
from typing import Any, Callable, Dict, List, Literal, Optional
77

88
import pytest
99

1010
from eval_protocol.dataset_logger import default_logger
11+
from eval_protocol.human_id import generate_id
1112
from eval_protocol.models import CompletionParams, EvalMetadata, EvaluationRow, InputMetadata, Message
1213
from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
1314
from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor
@@ -28,13 +29,14 @@
2829
aggregate,
2930
create_dynamically_parameterized_wrapper,
3031
execute_function,
32+
log_eval_status_and_rows,
3133
)
34+
from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci
3235

3336
from ..common_utils import load_jsonl
34-
from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci
3537

3638

37-
def evaluation_test(
39+
def evaluation_test( # noqa: C901
3840
*,
3941
model: List[ModelParam],
4042
input_messages: Optional[List[InputMessagesParam]] = None,
@@ -59,6 +61,37 @@ def evaluation_test(
5961
]:
6062
"""Decorator to create pytest-based evaluation tests.
6163
64+
Here are some key concepts to understand the terminology in EP:
65+
66+
- "invocation" is a single execution of a test function. An invocation can
67+
generate 1 or more cohorts. Grouping by invocation might be useful to
68+
aggregate eval scores across multiple invocations when you want to aggregate
69+
scores across multiple datasets.
70+
- "cohort" is a group of runs for a combination of parameters. A single
71+
cohort will have multiple runs if num_runs > 1.
72+
1. If your evaluation_test has combinations of parameters, it will generate
73+
multiple cohorts per combination of parameters.
74+
2. A new execution of a test function will generate a new cohort.
75+
- "run" is a group of rollouts. When num_runs > 1, there will be
76+
multiple "run_id"s.
77+
- "rollout" is the execution/process that produces a "trajectory". You
78+
"execute" multiple rollouts to generate a dataset of trajectories.
79+
- "trajectory" is the result produced by a rollout — a list of OpenAI Chat
80+
Completion messages (e.g. the "messages" field in EvaluationRow).
81+
- "row" is both the input and output of an evaluation. For example, in
82+
tau-bench, a row is a task within the dataset that can be identified as
83+
"airline_task_0" or "airline_task_1" etc. The "row_id" can be populated from
84+
the dataset itself to identify a particular task you want to evaluate. If
85+
not provided, EP will generate a "row_id" for each row whenever you call the
86+
evaluation test.
87+
- "dataset" is a collection of rows (e.g. List[EvaluationRow])
88+
- "eval" is a rubric implemented in the body of an @evaluation_test
89+
decorated test. It simply produces a score from 0 to 1 and attaches it
90+
to the row as the "evaluation_result" field.
91+
92+
"invocation", "cohort", "run", "rollout", and "row" each have a unique ID
93+
which can be used to easily group and identify your dataset by.
94+
6295
Args:
6396
model: Model identifiers to query.
6497
input_messages: Messages to send to the model. This is useful if you
@@ -75,7 +108,7 @@ def evaluation_test(
75108
aggregation_method: How to aggregate scores across rows.
76109
threshold_of_success: If set, fail the test if the aggregated score is
77110
below this threshold.
78-
num_runs: Number of times to repeat the evaluation.
111+
num_runs: Number of times to repeat the rollout and evaluations.
79112
max_dataset_rows: Limit dataset to the first N rows.
80113
mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema
81114
max_concurrent_rollouts: Maximum number of concurrent rollouts to run in parallel.
@@ -119,15 +152,15 @@ def decorator(
119152

120153
def execute_with_params(
121154
test_func: TestFunction,
122-
row: EvaluationRow | None = None,
123-
input_dataset: List[EvaluationRow] | None = None,
155+
processed_row: EvaluationRow | None = None,
156+
processed_dataset: List[EvaluationRow] | None = None,
124157
evaluation_test_kwargs: Optional[EvaluationInputParam] = None,
125158
):
126159
kwargs = {}
127-
if input_dataset is not None:
128-
kwargs["rows"] = input_dataset
129-
if row is not None:
130-
kwargs["row"] = row
160+
if processed_dataset is not None:
161+
kwargs["rows"] = processed_dataset
162+
if processed_row is not None:
163+
kwargs["row"] = processed_row
131164
if evaluation_test_kwargs is not None:
132165
if "row" in evaluation_test_kwargs:
133166
raise ValueError("'row' is a reserved parameter for the evaluation function")
@@ -176,7 +209,7 @@ def generate_combinations():
176209
datasets = [[input_dataset]] # type: ignore
177210
else:
178211
datasets = [None]
179-
params: List[Optional[RolloutInputParam]] = rollout_input_params if rollout_input_params is not None else [None] # type: ignore
212+
rips: List[Optional[RolloutInputParam]] = rollout_input_params if rollout_input_params is not None else [None] # type: ignore
180213
# Apply EP_MAX_DATASET_ROWS to input_messages, but do NOT parameterize over
181214
# each row. Instead, pass the entire sliced list through in a single test run
182215
# so summaries aggregate all rows together (AIME-style behavior).
@@ -195,15 +228,15 @@ def generate_combinations():
195228
# Generate all combinations
196229
for m in model:
197230
for ds in datasets:
198-
for ip in params:
231+
for rip in rips:
199232
for im in messages:
200233
for etk in kwargs:
201234
# if no dataset and no messages, raise an error
202235
if ds is None and im is None:
203236
raise ValueError(
204237
"No dataset or messages provided. Please provide at least one of input_dataset or input_messages."
205238
)
206-
combinations.append((m, ds, ip, im, etk))
239+
combinations.append((m, ds, rip, im, etk))
207240

208241
return combinations
209242

@@ -216,12 +249,12 @@ def generate_combinations():
216249
# Create parameter tuples for pytest.mark.parametrize
217250
param_tuples = []
218251
for combo in combinations:
219-
model_name, dataset, params, messages, etk = combo
252+
model_name, dataset, rip, messages, etk = combo
220253
param_tuple = [model_name]
221254
if input_dataset is not None:
222255
param_tuple.append(dataset)
223256
if rollout_input_params is not None:
224-
param_tuple.append(params)
257+
param_tuple.append(rip)
225258
if input_messages is not None:
226259
param_tuple.append(messages)
227260
if evaluation_test_kwargs is not None:
@@ -242,11 +275,20 @@ def generate_combinations():
242275
# Create wrapper function with exact signature that pytest expects
243276
def create_wrapper_with_signature() -> Callable:
244277
# Create the function body that will be used
278+
invocation_id = generate_id()
279+
245280
def wrapper_body(**kwargs):
246281
model_name = kwargs["model"]
247282
eval_metadata = None
248283
all_results: List[EvaluationRow] = []
249284

285+
cohort_id = generate_id()
286+
287+
def _log_eval_error(
288+
status: Literal["finished", "error"], rows: Optional[List[EvaluationRow]] | None, passed: bool
289+
) -> None:
290+
log_eval_status_and_rows(eval_metadata, rows, status, passed, default_logger)
291+
250292
try:
251293
# Handle dataset loading
252294
data: List[EvaluationRow] = []
@@ -283,6 +325,7 @@ def wrapper_body(**kwargs):
283325
# into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}').
284326
try:
285327
import json as _json
328+
286329
_env_override = os.getenv("EP_INPUT_PARAMS_JSON")
287330
if _env_override:
288331
override_obj = _json.loads(_env_override)
@@ -320,6 +363,8 @@ def wrapper_body(**kwargs):
320363
row.input_metadata.session_data["mode"] = mode
321364
# Initialize eval_metadata for each row
322365
row.eval_metadata = eval_metadata
366+
row.cohort_id = cohort_id
367+
row.invocation_id = invocation_id
323368

324369
# has to be done in the pytest main process since it's
325370
# used to determine whether this eval has stopped
@@ -339,14 +384,25 @@ def wrapper_body(**kwargs):
339384
for _ in range(num_runs):
340385
# Regenerate outputs each run by deep-copying the pristine dataset
341386
# so model responses are not reused across runs.
342-
fresh_rows = [copy.deepcopy(r) for r in data]
343-
input_dataset = execute_function(rollout_processor, rows=fresh_rows, config=config)
387+
run_id = generate_id()
388+
fresh_dataset = [copy.deepcopy(r) for r in data]
389+
390+
# apply new run_id to fresh_dataset
391+
for row in fresh_dataset:
392+
row.run_id = run_id
393+
394+
# generate new rollout_id for each row
395+
for row in fresh_dataset:
396+
row.rollout_id = generate_id()
397+
398+
processed_dataset = execute_function(rollout_processor, rows=fresh_dataset, config=config)
399+
344400
if mode == "pointwise":
345401
# Pointwise mode: apply the evaluator function to each row
346-
for row in input_dataset:
402+
for row in processed_dataset:
347403
result = execute_with_params(
348404
test_func,
349-
row=row,
405+
processed_row=row,
350406
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
351407
)
352408
if result is None or not isinstance(result, EvaluationRow):
@@ -358,7 +414,7 @@ def wrapper_body(**kwargs):
358414
# Batch mode: call the test function with the full dataset
359415
results = execute_with_params(
360416
test_func,
361-
input_dataset=input_dataset,
417+
processed_dataset=processed_dataset,
362418
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
363419
)
364420
if results is None:
@@ -430,6 +486,7 @@ def wrapper_body(**kwargs):
430486
# Aggregate per-metric mean and 95% CI when available
431487
metrics_summary: Dict[str, Dict[str, float]] = {}
432488
from collections import defaultdict
489+
433490
metric_scores: Dict[str, list] = defaultdict(list)
434491
for r in all_results:
435492
if r.evaluation_result and r.evaluation_result.metrics:
@@ -470,7 +527,10 @@ def wrapper_body(**kwargs):
470527
)
471528
# As per project convention, avoid printing per-metric CI lines to reduce noise
472529
if summary_path:
473-
import json, pathlib, time, re
530+
import json
531+
import pathlib
532+
import re
533+
import time
474534

475535
def _sanitize_filename(text: str) -> str:
476536
safe = re.sub(r"[^A-Za-z0-9._-]+", "-", text.strip())
@@ -487,7 +547,11 @@ def _extract_effort_tag(params: dict) -> str | None:
487547
return str(eb["reasoning"]["effort"]).lower()
488548
if "reasoning_effort" in eb:
489549
return str(eb["reasoning_effort"]).lower()
490-
if "reasoning" in params and isinstance(params["reasoning"], dict) and "effort" in params["reasoning"]:
550+
if (
551+
"reasoning" in params
552+
and isinstance(params["reasoning"], dict)
553+
and "effort" in params["reasoning"]
554+
):
491555
return str(params["reasoning"]["effort"]).lower()
492556
except Exception:
493557
return None
@@ -529,25 +593,11 @@ def _extract_effort_tag(params: dict) -> str | None:
529593
agg_score >= threshold_of_success
530594
), f"Aggregated score {agg_score:.3f} below threshold {threshold_of_success}"
531595

596+
except AssertionError:
597+
_log_eval_error("finished", data if "data" in locals() else None, passed=False)
598+
raise
532599
except Exception:
533-
# Update eval metadata status to error and log it
534-
if eval_metadata is not None:
535-
eval_metadata.status = "error"
536-
eval_metadata.passed = False
537-
538-
# Create a minimal result row to log the error if we don't have any results yet
539-
if not data:
540-
error_row = EvaluationRow(messages=[], eval_metadata=eval_metadata, evaluation_result=None)
541-
default_logger.log(error_row)
542-
else:
543-
# Update existing results with error status
544-
for r in data:
545-
if r.eval_metadata is not None:
546-
r.eval_metadata.status = "error"
547-
r.eval_metadata.passed = False
548-
default_logger.log(r)
549-
550-
# Re-raise the exception to maintain pytest behavior
600+
_log_eval_error("error", data if "data" in locals() else None, passed=False)
551601
raise
552602

553603
return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names)

eval_protocol/pytest/utils.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import asyncio
22
import inspect
3-
from typing import Any, Callable, List, Literal
3+
from typing import Any, Callable, List, Literal, Optional
4+
5+
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
6+
from eval_protocol.models import EvalMetadata, EvaluationRow
47

58

69
def execute_function(func: Callable, **kwargs) -> Any:
@@ -92,3 +95,32 @@ def wrapper(**kwargs):
9295
wrapper.__signature__ = inspect.Signature(parameters)
9396

9497
return wrapper
98+
99+
100+
def log_eval_status_and_rows(
101+
eval_metadata: Optional[EvalMetadata],
102+
rows: Optional[List[EvaluationRow]] | None,
103+
status: Literal["finished", "error"],
104+
passed: bool,
105+
logger: DatasetLogger,
106+
) -> None:
107+
"""Update eval status and emit rows to the given logger.
108+
109+
If no rows are provided, emits a minimal placeholder row so downstream
110+
consumers still observe a terminal status.
111+
"""
112+
if eval_metadata is None:
113+
return
114+
115+
eval_metadata.status = status
116+
eval_metadata.passed = passed
117+
118+
rows_to_log: List[EvaluationRow] = rows or []
119+
if not rows_to_log:
120+
error_row = EvaluationRow(messages=[], eval_metadata=eval_metadata, evaluation_result=None)
121+
logger.log(error_row)
122+
else:
123+
for r in rows_to_log:
124+
if r.eval_metadata is not None:
125+
r.eval_metadata.status = status
126+
logger.log(r)

0 commit comments

Comments
 (0)