Skip to content

Commit 5bc39d7

Browse files
committed
Add AIME2025, GPQA, HealthBench evaluation_test suites; unify row-limiting via pytest flag; clean up examples
1 parent c35d7f0 commit 5bc39d7

File tree

11 files changed

+579
-12
lines changed

11 files changed

+579
-12
lines changed

eval_protocol/common_utils.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import re
33
from typing import Any, Dict, List
44

5+
import requests
6+
57

68
def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
79
"""
@@ -15,16 +17,39 @@ def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
1517
Returns an empty list if the file is not found or if errors occur during parsing.
1618
"""
1719
data: List[Dict[str, Any]] = []
18-
with open(file_path, "r", encoding="utf-8") as f:
19-
for line_number, line in enumerate(f):
20+
if file_path.startswith("http://") or file_path.startswith("https://"):
21+
resp = requests.get(file_path, stream=True, timeout=30)
22+
resp.raise_for_status()
23+
for line_number, raw in enumerate(resp.iter_lines(decode_unicode=True), start=1):
24+
if raw is None:
25+
continue
26+
stripped = raw.strip()
27+
if not stripped:
28+
continue
2029
try:
21-
data.append(json.loads(line.strip()))
30+
data.append(json.loads(stripped))
2231
except json.JSONDecodeError as e:
23-
print(f"Error parsing JSON line for file {file_path} at line {line_number}")
24-
# attempt to find "row_id" in the line by finding index of "row_id" and performing regex of `"row_id": (.*),`
25-
row_id_index = line.find("row_id")
32+
print(f"Error parsing JSON line for URL {file_path} at line {line_number}")
33+
row_id_index = stripped.find("row_id")
2634
if row_id_index != -1:
27-
row_id = re.search(r'"row_id": (.*),', line[row_id_index:])
28-
raise ValueError(f"{e.msg} at line {line_number}: {line} ({row_id})")
35+
row_id = re.search(r'"row_id": (.*),', stripped[row_id_index:])
36+
raise ValueError(f"{e.msg} at line {line_number}: {stripped} ({row_id})")
2937
raise e
38+
else:
39+
with open(file_path, "r", encoding="utf-8") as f:
40+
for line_number, line in enumerate(f, start=1):
41+
# Skip entirely blank or whitespace-only lines to be robust to trailing newlines
42+
stripped = line.strip()
43+
if not stripped:
44+
continue
45+
try:
46+
data.append(json.loads(stripped))
47+
except json.JSONDecodeError as e:
48+
print(f"Error parsing JSON line for file {file_path} at line {line_number}")
49+
# attempt to find "row_id" in the line by finding index of "row_id" and performing regex of `"row_id": (.*),`
50+
row_id_index = line.find("row_id")
51+
if row_id_index != -1:
52+
row_id = re.search(r'"row_id": (.*),', line[row_id_index:])
53+
raise ValueError(f"{e.msg} at line {line_number}: {line} ({row_id})")
54+
raise e
3055
return data

eval_protocol/generation/clients.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import aiohttp
1313
from omegaconf import DictConfig
14-
from pydantic import BaseModel, Field # Added for new models
14+
from pydantic import BaseModel # Added for new models
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -83,6 +83,9 @@ async def generate(
8383
}
8484
if self.top_p is not None:
8585
payload["top_p"] = self.top_p
86+
# Include reasoning settings if configured (for reasoning-capable models)
87+
if self.reasoning_effort:
88+
payload["reasoning_effort"] = self.reasoning_effort
8689

8790
if tools:
8891
payload["tools"] = tools

eval_protocol/pytest/evaluation_test.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import inspect
22
import os
3+
import os
34
from typing import Any, Callable, Dict, List, Optional
45

56
import pytest
@@ -132,13 +133,34 @@ def execute_with_params(
132133
return execute_function(test_func, **kwargs)
133134

134135
# Calculate all possible combinations of parameters
136+
def _parse_ep_max_rows(default_value: int | None) -> int | None:
137+
"""Read EP_MAX_DATASET_ROWS env override as int or None."""
138+
raw = os.getenv("EP_MAX_DATASET_ROWS")
139+
if raw is None:
140+
return default_value
141+
s = raw.strip().lower()
142+
if s == "none":
143+
return None
144+
try:
145+
return int(s)
146+
except ValueError:
147+
return default_value
148+
135149
def generate_combinations():
136150
combinations = []
137151

138152
# Handle optional parameters with defaults
139153
datasets: List[Optional[DatasetPathParam]] = input_dataset if input_dataset is not None else [None] # type: ignore
140154
params: List[Optional[RolloutInputParam]] = rollout_input_params if rollout_input_params is not None else [None] # type: ignore
141-
messages: List[Optional[InputMessagesParam]] = input_messages if input_messages is not None else [None] # type: ignore
155+
# Apply EP_MAX_DATASET_ROWS to input_messages to uniformly control row count when messages are provided
156+
if input_messages is not None and isinstance(input_messages, list):
157+
effective_max_rows = _parse_ep_max_rows(max_dataset_rows)
158+
if effective_max_rows is not None:
159+
messages: List[Optional[InputMessagesParam]] = input_messages[:effective_max_rows] # type: ignore
160+
else:
161+
messages = input_messages # type: ignore
162+
else:
163+
messages = [None] # type: ignore
142164
kwargs: List[Optional[EvaluationInputParam]] = evaluation_test_kwargs if evaluation_test_kwargs is not None else [None] # type: ignore
143165

144166
# Generate all combinations
@@ -201,8 +223,10 @@ def wrapper_body(**kwargs):
201223
data: List[EvaluationRow] = []
202224
if "dataset_path" in kwargs and kwargs["dataset_path"] is not None:
203225
data_jsonl = load_jsonl(kwargs["dataset_path"])
204-
if max_dataset_rows is not None:
205-
data_jsonl = data_jsonl[:max_dataset_rows]
226+
# Apply env override for max rows if present
227+
effective_max_rows = _parse_ep_max_rows(max_dataset_rows)
228+
if effective_max_rows is not None:
229+
data_jsonl = data_jsonl[:effective_max_rows]
206230
data = dataset_adapter(data_jsonl)
207231
elif "input_messages" in kwargs and kwargs["input_messages"] is not None:
208232
data: List[EvaluationRow] = [EvaluationRow(messages=kwargs["input_messages"])]

eval_protocol/pytest/plugin.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""
2+
Pytest plugin for Eval Protocol developer ergonomics.
3+
4+
Adds a discoverable CLI flag `--ep-max-rows` to control how many rows
5+
evaluation_test processes. This sets the environment variable
6+
`EP_MAX_DATASET_ROWS` so the core decorator can apply it uniformly to
7+
both URL datasets and in-memory input_messages.
8+
9+
Usage:
10+
- CLI: pytest --ep-max-rows=2 # or --ep-max-rows=all for no limit
11+
- Defaults: If not provided, no override is applied (tests use the
12+
max_dataset_rows value set in the decorator).
13+
"""
14+
15+
import os
16+
from typing import Optional
17+
18+
import pytest
19+
20+
21+
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register the --ep-max-rows CLI flag under the eval-protocol option group."""
    ep_group = parser.getgroup("eval-protocol")
    ep_group.addoption(
        "--ep-max-rows",
        action="store",
        default=None,
        help=(
            "Limit number of dataset rows processed by evaluation_test. "
            "Pass an integer (e.g., 2, 50) or 'all' for no limit."
        ),
    )
32+
33+
34+
def _normalize_max_rows(val: Optional[str]) -> Optional[str]:
35+
if val is None:
36+
return None
37+
s = val.strip().lower()
38+
if s == "all":
39+
return "None"
40+
# Validate int; if invalid, ignore and return None (no override)
41+
try:
42+
int(s)
43+
return s
44+
except ValueError:
45+
return None
46+
47+
48+
def pytest_configure(config: pytest.Config) -> None:
    """Propagate a validated --ep-max-rows value into EP_MAX_DATASET_ROWS."""
    normalized = _normalize_max_rows(config.getoption("--ep-max-rows"))
    if normalized is not None:
        os.environ["EP_MAX_DATASET_ROWS"] = normalized
53+
54+
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
## AIME2025 Chat Completion Example
2+
3+
This example reproduces gpt-oss's AIME2025 chat completion evaluation inside Eval Protocol.
4+
5+
### What it does
6+
- Loads AIME2025 questions from Hugging Face
7+
- Prompts a reasoning-capable chat-completions model
8+
- Extracts the final integer answer from \boxed{...}
9+
- Scores exact-match vs. the ground-truth integer
10+
11+
### Quick run (pytest, CI-friendly)
12+
The evaluation is implemented as a pytest `evaluation_test` under `tests/`. Run it directly:
13+
14+
```bash
15+
pytest -q examples/aime2025_chat_completion/tests/test_evaluation.py
16+
```
17+
18+
Environment variables expected:
19+
- `FIREWORKS_API_KEY`
20+
21+
To scale up, adjust parameters in the decorator (e.g., `threshold_of_success`, `max_dataset_rows`).
22+
23+
24+
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# NOTE(review): __all__ exports "main", but no `main` is defined or imported in
# this file as shown — `from <package> import *` would raise AttributeError.
# Verify that a `from .main import main` (or equivalent) accompanies this.
__all__ = ["main"]
2+
3+
4+
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""
2+
Eval Protocol example: AIME2025 chat completion evaluation
3+
4+
This example mirrors gpt-oss's AIME 2025 evaluation using OpenAI-compatible
5+
chat completions. It evaluates whether the assistant's final answer matches the
6+
ground-truth integer, extracting answers from \\boxed{...} or fallback digits.
7+
"""
8+
9+
import re
10+
from typing import Any, Dict, List, Optional, Union
11+
12+
from eval_protocol import EvaluateResult, MetricResult, reward_function
13+
from eval_protocol.models import Message
14+
15+
16+
def _extract_boxed_text(text: str) -> str:
17+
"""
18+
Extract the last occurrence of a boxed answer (\\boxed{...} or \\framebox{...}).
19+
If none found, fall back to the last integer found in the text.
20+
"""
21+
if not text:
22+
return ""
23+
24+
pattern_boxed = r"boxed{(.*?)}|framebox{(.*?)}"
25+
matches = re.findall(pattern_boxed, text, re.DOTALL)
26+
if matches:
27+
# Iterate from the end to prioritize the final boxed answer
28+
for match in matches[::-1]:
29+
for group in match:
30+
if group:
31+
return group.split(",")[-1].strip()
32+
33+
# Fallback: last integer in the text
34+
matches_digits = re.findall(r"\d+", text, re.DOTALL)
35+
if matches_digits:
36+
return matches_digits[-1]
37+
return ""
38+
39+
40+
def _normalize_to_int_or_none(s: str) -> Optional[int]:
41+
if s is None:
42+
return None
43+
# Only take leading digits
44+
m = re.match(r"\d+", str(s).strip())
45+
if not m:
46+
return None
47+
try:
48+
return int(m.group(0))
49+
except ValueError:
50+
return None
51+
52+
53+
@reward_function(id="aime2025_exact_match")
def evaluate(
    messages: Union[List[Message], List[Dict[str, Any]]],
    ground_truth: Optional[str] = None,
    **kwargs,
) -> EvaluateResult:
    """
    Score 1.0 if extracted final answer equals the ground-truth integer, else 0.0.
    """
    # Guard: nothing to grade without a conversation.
    if not messages:
        empty_metric = MetricResult(score=0.0, is_score_valid=False, reason="empty messages")
        return EvaluateResult(
            score=0.0,
            reason="No messages provided",
            is_score_valid=False,
            metrics={"parse_status": empty_metric},
        )

    # Grade only the final message; it may be a dict or a Message object.
    final = messages[-1]
    if isinstance(final, dict):
        reply_text = final["content"]
    else:
        reply_text = final.content or ""

    answer_text = _extract_boxed_text(reply_text)
    predicted = _normalize_to_int_or_none(answer_text)
    expected = _normalize_to_int_or_none(ground_truth if ground_truth is not None else "")

    # The score is only meaningful when both sides parsed to integers.
    parsed_ok = predicted is not None and expected is not None
    score = 1.0 if parsed_ok and predicted == expected else 0.0

    if score == 1.0:
        match_reason = "Parsed both integers and they matched"
    elif parsed_ok:
        match_reason = "Parsed integers did not match"
    else:
        match_reason = "Failed to parse integer from prediction or ground truth"

    metrics: Dict[str, MetricResult] = {
        "exact_match": MetricResult(
            score=score,
            is_score_valid=parsed_ok,
            reason=match_reason,
            data={
                "extracted_text": answer_text,
                "extracted_int": predicted,
                "ground_truth_int": expected,
            },
        )
    }

    return EvaluateResult(
        score=score,
        reason="Answer correct" if score == 1.0 else "Answer incorrect",
        is_score_valid=parsed_ok,
        metrics=metrics,
    )
109+
110+

0 commit comments

Comments
 (0)