Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 160 additions & 1 deletion eval_protocol/benchmarks/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
server_script_path = ep_config.get("server_script_path")
steps = ep_config.get("steps")
mode = ep_config.get("mode")
combine_datasets = ep_config.get("combine_datasets")
# combine_datasets captured but not used here

# Choose the first rollout param set by default
rollout_params = None
Expand Down Expand Up @@ -169,3 +169,162 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]:
return test_wrapper

return _decorator


def register_composite_benchmark(name: str, children: List[str]) -> None:
    """
    Register a composite benchmark that runs multiple exported benchmarks and aggregates results.

    The composite runner forwards common overrides to each child benchmark and aggregates
    a combined score as a rows-weighted mean of each child's aggregated score.

    Args:
        name: Name of the composite benchmark to register.
        children: List of child benchmark names previously registered via export_benchmark.
    """

    def _composite_runner(
        *,
        model: Optional[str] = None,
        print_summary: bool = False,
        out: Optional[str] = None,
        reasoning_effort: Optional[str] = None,
        max_rows: Optional[int | str] = None,
        num_runs: Optional[int] = None,
        input_params_override: Optional[Dict[str, Any]] = None,
        max_concurrency: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Run every child benchmark with the given overrides and return a combined summary.

        Returns a dict of the form ``{"summary": {...}}`` whose payload carries the
        rows-weighted aggregate score, per-metric means, and each child's own summary.
        """
        # Resolve child runners at call-time to ensure all suites are imported
        # Local import avoided to prevent circular import at module import time
        _get_benchmark_runner = get_benchmark_runner
        import pathlib as _pathlib
        import time as _time
        _json = json

        child_summaries: List[Dict[str, Any]] = []
        total_rows = 0          # rows counted only for children with a numeric agg_score
        weighted_sum = 0.0      # sum of (child agg_score * child rows)
        # For per-metric aggregation across children
        metric_weighted_sums: Dict[str, float] = {}
        metric_total_rows: Dict[str, int] = {}
        # Raw per-row results pooled across children, used for the combined 95% CI below.
        combined_rows: List[Any] = []

        # If 'out' is a file path, also compute a directory for child artifacts
        child_out_dir: Optional[str] = None
        if out:
            p = _pathlib.Path(out)
            if p.suffix.lower() == ".json" and not str(out).endswith("/"):
                # Use parent directory for child artifacts
                child_out_dir = str(p.parent)
            else:
                child_out_dir = out

        for child_name in children:
            runner = _get_benchmark_runner(child_name)
            # Forward every override verbatim; children receive the derived artifact
            # directory rather than the composite's own output path.
            result = runner(
                model=model,
                print_summary=print_summary,
                out=child_out_dir,
                reasoning_effort=reasoning_effort,
                max_rows=max_rows,
                num_runs=num_runs,
                input_params_override=input_params_override,
                max_concurrency=max_concurrency,
            )
            summary = (result or {}).get("summary") if isinstance(result, dict) else None
            if not summary:
                # Child produced nothing usable; it contributes neither rows nor score.
                continue
            # Gather underlying rows to recompute CI across children
            try:
                rows_obj = result.get("results") if isinstance(result, dict) else None
                if isinstance(rows_obj, list):
                    combined_rows.extend(rows_obj)
            except Exception:
                # Best-effort: a child without per-row results simply doesn't feed the CI.
                pass
            child_summaries.append(summary)
            rows = int(summary.get("rows", 0) or 0)
            agg = summary.get("agg_score")
            if isinstance(agg, (int, float)) and rows > 0:
                total_rows += rows
                weighted_sum += float(agg) * rows
            # Combine per-metric means if available
            metrics_agg = summary.get("metrics_agg") or {}
            if isinstance(metrics_agg, dict):
                for m_name, m_vals in metrics_agg.items():
                    m_mean = m_vals.get("mean")
                    if isinstance(m_mean, (int, float)) and rows > 0:
                        metric_weighted_sums[m_name] = metric_weighted_sums.get(m_name, 0.0) + float(m_mean) * rows
                        metric_total_rows[m_name] = metric_total_rows.get(m_name, 0) + rows

        # Rows-weighted mean across children; None when no child reported a numeric score.
        combined_agg = (weighted_sum / total_rows) if total_rows > 0 else None
        # Compute 95% CI for combined rows if available
        ci_low: Optional[float] = None
        ci_high: Optional[float] = None
        if combined_rows:
            try:
                from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci as _compute_ci

                # NOTE(review): assumes _compute_ci returns a sequence whose elements
                # 1 and 2 are the low/high CI bounds — confirm against its definition.
                r = _compute_ci(combined_rows)
                if r and len(r) >= 3 and r[1] is not None and r[2] is not None:
                    ci_low = float(r[1])
                    ci_high = float(r[2])
            except Exception:
                # CI is optional decoration; any failure leaves it unset.
                ci_low = None
                ci_high = None
        combined_metrics: Dict[str, Dict[str, float]] = {}
        for m_name, wsum in metric_weighted_sums.items():
            denom = metric_total_rows.get(m_name, 0)
            if denom > 0:
                combined_metrics[m_name] = {"mean": float(wsum / denom)}
        # Final summary payload; optional keys are spliced in only when present so
        # consumers can rely on their absence meaning "not computed".
        combined = {
            "suite": name,
            "model": model,
            "agg_score": float(combined_agg) if combined_agg is not None else None,
            "rows": total_rows,
            "children": child_summaries,
            "num_runs": num_runs,
            **({"metrics_agg": combined_metrics} if combined_metrics else {}),
            **({"agg_ci_low": ci_low, "agg_ci_high": ci_high} if (ci_low is not None and ci_high is not None) else {}),
        }

        # Optional print and persist
        # Respect either function arg or EP_PRINT_SUMMARY env
        _should_print = print_summary or (os.getenv("EP_PRINT_SUMMARY") == "1")
        if _should_print:
            try:
                if combined_agg is not None:
                    if ci_low is not None and ci_high is not None:
                        print(
                            f"EP Summary | suite={name} model={model} agg={combined['agg_score']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}] rows={total_rows}"
                        )
                    else:
                        print(
                            f"EP Summary | suite={name} model={model} agg={combined['agg_score']:.3f} rows={total_rows}"
                        )
                else:
                    print(
                        f"EP Summary | suite={name} model={model} agg=None rows={total_rows}"
                    )
            except Exception:
                # Printing is purely informational; never let it break the run.
                pass

        if out:
            out_path = _pathlib.Path(out)
            if out_path.suffix.lower() == ".json" and not str(out).endswith("/"):
                # Write to the specified file
                out_path.parent.mkdir(parents=True, exist_ok=True)
                with open(out_path, "w", encoding="utf-8") as f:
                    _json.dump({**combined, "timestamp": int(_time.time())}, f)
            else:
                # Treat as directory
                dir_path = out_path
                dir_path.mkdir(parents=True, exist_ok=True)
                # Benchmark names may contain '/', which is not filename-safe.
                safe_name = name.replace("/", "__")
                file_path = dir_path / f"{safe_name}__composite.json"
                with open(file_path, "w", encoding="utf-8") as f:
                    _json.dump({**combined, "timestamp": int(_time.time())}, f)

        return {"summary": combined}

    # Register (overwrite if exists)
    _BENCHMARK_REGISTRY[name] = _composite_runner
1 change: 1 addition & 0 deletions eval_protocol/benchmarks/suites/aime25.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]:
rollout_input_params=[{"max_tokens": 131000, "extra_body": {"reasoning_effort": "low"}}],
rollout_processor=default_single_turn_rollout_processor,
aggregation_method="mean",
passed_threshold=None,
num_runs=8,
max_dataset_rows=2,
max_concurrent_rollouts=4,
Expand Down
26 changes: 20 additions & 6 deletions eval_protocol/benchmarks/suites/gpqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ def _load_gpqa_messages_from_csv() -> List[List[Message]]:
[
Message(role="system", content=SYSTEM_PROMPT),
Message(role="user", content=user_content),
# Correct answer is always option A by construction
Message(role="system", content="__GT__:A"),
]
)
if not messages_list:
Expand All @@ -57,14 +55,31 @@ def _extract_abcd_letter(text: str) -> str | None:

_GPQA_INPUT_MESSAGES = _load_gpqa_messages_from_csv()

def _strip_gt_messages(msgs: List[Message]) -> List[Message]:
    """Return *msgs* without the sentinel system messages that embed a ground truth.

    A sentinel is any message with role ``"system"`` whose content starts with
    the ``__GT__:`` marker; every other message is kept in its original order.
    """
    kept: List[Message] = []
    for msg in msgs:
        is_gt_marker = msg.role == "system" and (msg.content or "").startswith("__GT__:")
        if not is_gt_marker:
            kept.append(msg)
    return kept


async def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) -> List[EvaluationRow]:
    """Preprocess rows to set ground_truth and remove __GT__ messages, then delegate to default processor.

    For each row, any system message of the form ``__GT__:<answer>`` is treated as an
    embedded ground-truth marker: the answer is copied onto ``row.ground_truth`` and the
    marker messages are stripped before the row is handed to the default rollout processor.

    Args:
        rows: Evaluation rows whose messages may carry ``__GT__:`` sentinel system messages.
        config: Rollout configuration, forwarded unchanged to the default processor.

    Returns:
        The rows as processed by ``default_single_turn_rollout_processor``.
    """
    processed: List[EvaluationRow] = []
    for r in rows:
        gt_tokens = [m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:")]
        if gt_tokens:
            # If several markers exist, the last one wins; the payload follows the first ':'.
            r.ground_truth = gt_tokens[-1].split(":", 1)[1].strip()
            # Reuse the shared helper instead of re-inlining the same predicate,
            # so the stripping rule lives in exactly one place.
            r.messages = _strip_gt_messages(r.messages)
        processed.append(r)
    return await default_single_turn_rollout_processor(processed, config)


@export_benchmark("gpqa")
@evaluation_test(
model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"],
input_messages=_GPQA_INPUT_MESSAGES,
rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}],
rollout_processor=default_single_turn_rollout_processor,
rollout_processor=gpqa_strip_gt_rollout_processor,
aggregation_method="mean",
passed_threshold=None,
num_runs=8,
mode="pointwise",
)
Expand All @@ -73,9 +88,8 @@ def gpqa_pointwise(row: EvaluationRow) -> EvaluationRow:
content = assistant_msgs[-1].content if assistant_msgs else ""

pred = _extract_abcd_letter(content or "")
# Retrieve GT from the trailing system message we appended
gt_tokens = [m.content for m in row.messages if m.role == "system" and (m.content or "").startswith("__GT__:")]
gt = gt_tokens[-1].split(":", 1)[1].strip() if gt_tokens else None
# GPQA diamond CSV constructs options so that the correct answer is always A
gt = "A"

is_valid = pred is not None and gt in {"A", "B", "C", "D"}
score = 1.0 if (is_valid and pred == gt) else 0.0
Expand Down
Loading
Loading