From cf508f1fa65f1d890fcd20b3c732dc2393f84dd0 Mon Sep 17 00:00:00 2001 From: benjibc Date: Tue, 12 Aug 2025 23:18:28 +0000 Subject: [PATCH 1/2] add live bench --- eval_protocol/benchmarks/registry.py | 161 +++++- eval_protocol/benchmarks/suites/aime25.py | 1 + eval_protocol/benchmarks/suites/gpqa.py | 3 +- .../suites/livebench_data_analysis.py | 513 ++++++++++++++++++ eval_protocol/dataset_logger/__init__.py | 14 +- .../sqlite_evaluation_row_store.py | 3 +- .../default_single_turn_rollout_process.py | 7 +- 7 files changed, 696 insertions(+), 6 deletions(-) create mode 100644 eval_protocol/benchmarks/suites/livebench_data_analysis.py diff --git a/eval_protocol/benchmarks/registry.py b/eval_protocol/benchmarks/registry.py index 98065b82..31840fd1 100644 --- a/eval_protocol/benchmarks/registry.py +++ b/eval_protocol/benchmarks/registry.py @@ -126,7 +126,7 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]: server_script_path = ep_config.get("server_script_path") steps = ep_config.get("steps") mode = ep_config.get("mode") - combine_datasets = ep_config.get("combine_datasets") + # combine_datasets captured but not used here # Choose the first rollout param set by default rollout_params = None @@ -169,3 +169,162 @@ def _deep_update(base: Dict[str, Any], over: Dict[str, Any]) -> Dict[str, Any]: return test_wrapper return _decorator + + +def register_composite_benchmark(name: str, children: List[str]) -> None: + """ + Register a composite benchmark that runs multiple exported benchmarks and aggregates results. + + The composite runner forwards common overrides to each child benchmark and aggregates + a combined score as a rows-weighted mean of each child's aggregated score. + + Args: + name: Name of the composite benchmark to register. + children: List of child benchmark names previously registered via export_benchmark. + """ + + def _composite_runner( + *, + model: Optional[str] = None, + print_summary: bool = False, + out: Optional[str] = None, + reasoning_effort: Optional[str] = None, + max_rows: Optional[int | str] = None, + num_runs: Optional[int] = None, + input_params_override: Optional[Dict[str, Any]] = None, + max_concurrency: Optional[int] = None, + ) -> Dict[str, Any]: + # Resolve child runners at call-time to ensure all suites are imported + # Local import avoided to prevent circular import at module import time + _get_benchmark_runner = get_benchmark_runner + import pathlib as _pathlib + import time as _time + _json = json + + child_summaries: List[Dict[str, Any]] = [] + total_rows = 0 + weighted_sum = 0.0 + # For per-metric aggregation across children + metric_weighted_sums: Dict[str, float] = {} + metric_total_rows: Dict[str, int] = {} + combined_rows: List[Any] = [] + + # If 'out' is a file path, also compute a directory for child artifacts + child_out_dir: Optional[str] = None + if out: + p = _pathlib.Path(out) + if p.suffix.lower() == ".json" and not str(out).endswith("/"): + # Use parent directory for child artifacts + child_out_dir = str(p.parent) + else: + child_out_dir = out + + for child_name in children: + runner = _get_benchmark_runner(child_name) + result = runner( + model=model, + print_summary=print_summary, + out=child_out_dir, + reasoning_effort=reasoning_effort, + max_rows=max_rows, + num_runs=num_runs, + input_params_override=input_params_override, + max_concurrency=max_concurrency, + ) + summary = (result or {}).get("summary") if isinstance(result, dict) else None + if not summary: + continue + # Gather underlying rows to recompute CI across children + try: + rows_obj = result.get("results") if isinstance(result, dict) else None + if isinstance(rows_obj, list): + combined_rows.extend(rows_obj) + except Exception: + pass + child_summaries.append(summary) + rows = int(summary.get("rows", 0) or 0) + agg = summary.get("agg_score") + if isinstance(agg, (int, float)) and rows > 0: + total_rows += rows + weighted_sum += float(agg) * rows + # Combine per-metric means if available + metrics_agg = summary.get("metrics_agg") or {} + if isinstance(metrics_agg, dict): + for m_name, m_vals in metrics_agg.items(): + m_mean = m_vals.get("mean") + if isinstance(m_mean, (int, float)) and rows > 0: + metric_weighted_sums[m_name] = metric_weighted_sums.get(m_name, 0.0) + float(m_mean) * rows + metric_total_rows[m_name] = metric_total_rows.get(m_name, 0) + rows + + combined_agg = (weighted_sum / total_rows) if total_rows > 0 else None + # Compute 95% CI for combined rows if available + ci_low: Optional[float] = None + ci_high: Optional[float] = None + if combined_rows: + try: + from eval_protocol.stats.confidence_intervals import compute_fixed_set_mu_ci as _compute_ci + + r = _compute_ci(combined_rows) + if r and len(r) >= 3 and r[1] is not None and r[2] is not None: + ci_low = float(r[1]) + ci_high = float(r[2]) + except Exception: + ci_low = None + ci_high = None + combined_metrics: Dict[str, Dict[str, float]] = {} + for m_name, wsum in metric_weighted_sums.items(): + denom = metric_total_rows.get(m_name, 0) + if denom > 0: + combined_metrics[m_name] = {"mean": float(wsum / denom)} + combined = { + "suite": name, + "model": model, + "agg_score": float(combined_agg) if combined_agg is not None else None, + "rows": total_rows, + "children": child_summaries, + "num_runs": num_runs, + **({"metrics_agg": combined_metrics} if combined_metrics else {}), + **({"agg_ci_low": ci_low, "agg_ci_high": ci_high} if (ci_low is not None and ci_high is not None) else {}), + } + + # Optional print and persist + # Respect either function arg or EP_PRINT_SUMMARY env + _should_print = print_summary or (os.getenv("EP_PRINT_SUMMARY") == "1") + if _should_print: + try: + if combined_agg is not None: + if ci_low is not None and ci_high is not None: + print( + f"EP Summary | suite={name} model={model} agg={combined['agg_score']:.3f} ci95=[{ci_low:.3f},{ci_high:.3f}] rows={total_rows}" + ) + else: + print( + f"EP Summary | suite={name} model={model} agg={combined['agg_score']:.3f} rows={total_rows}" + ) + else: + print( + f"EP Summary | suite={name} model={model} agg=None rows={total_rows}" + ) + except Exception: + pass + + if out: + out_path = _pathlib.Path(out) + if out_path.suffix.lower() == ".json" and not str(out).endswith("/"): + # Write to the specified file + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w", encoding="utf-8") as f: + _json.dump({**combined, "timestamp": int(_time.time())}, f) + else: + # Treat as directory + dir_path = out_path + dir_path.mkdir(parents=True, exist_ok=True) + safe_name = name.replace("/", "__") + file_path = dir_path / f"{safe_name}__composite.json" + with open(file_path, "w", encoding="utf-8") as f: + _json.dump({**combined, "timestamp": int(_time.time())}, f) + + return {"summary": combined} + + # Register (overwrite if exists) + _BENCHMARK_REGISTRY[name] = _composite_runner diff --git a/eval_protocol/benchmarks/suites/aime25.py b/eval_protocol/benchmarks/suites/aime25.py index 406ee74b..4a5d3a4c 100644 --- a/eval_protocol/benchmarks/suites/aime25.py +++ b/eval_protocol/benchmarks/suites/aime25.py @@ -69,6 +69,7 @@ def aime2025_dataset_adapter(rows: List[Dict[str, Any]]) -> List[EvaluationRow]: rollout_input_params=[{"max_tokens": 131000, "extra_body": {"reasoning_effort": "low"}}], rollout_processor=default_single_turn_rollout_processor, aggregation_method="mean", + passed_threshold=None, num_runs=8, max_dataset_rows=2, max_concurrent_rollouts=4, diff --git a/eval_protocol/benchmarks/suites/gpqa.py b/eval_protocol/benchmarks/suites/gpqa.py index 2024d202..ec67ae94 100644 --- a/eval_protocol/benchmarks/suites/gpqa.py +++ b/eval_protocol/benchmarks/suites/gpqa.py @@ -39,8 +39,6 @@ def _load_gpqa_messages_from_csv() -> List[List[Message]]: [ Message(role="system", content=SYSTEM_PROMPT), Message(role="user", content=user_content), - # Correct answer is always option A by construction - Message(role="system", content="__GT__:A"), ] ) if not messages_list: @@ -65,6 +63,7 @@ def _extract_abcd_letter(text: str) -> str | None: rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}], rollout_processor=default_single_turn_rollout_processor, aggregation_method="mean", + passed_threshold=None, num_runs=8, mode="pointwise", ) diff --git a/eval_protocol/benchmarks/suites/livebench_data_analysis.py b/eval_protocol/benchmarks/suites/livebench_data_analysis.py new file mode 100644 index 00000000..d7bd8729 --- /dev/null +++ b/eval_protocol/benchmarks/suites/livebench_data_analysis.py @@ -0,0 +1,513 @@ +from typing import Any, Dict, List, Optional + +import json +import re + +from eval_protocol.models import EvaluateResult, EvaluationRow, Message, MetricResult +from eval_protocol.pytest.default_single_turn_rollout_process import ( + default_single_turn_rollout_processor, +) +from eval_protocol.pytest.evaluation_test import evaluation_test +from eval_protocol.benchmarks.registry import export_benchmark, register_composite_benchmark + + +# ------------------------- +# Lightweight ports of LiveBench scoring utilities for data_analysis tasks +# ------------------------- + +def _lb_clean_text(text: str) -> str: + text = text.lower().strip() + text = re.sub(r"[^\w]", "", text) + return text + + +def _extract_last_boxed_segment(text: str) -> Optional[str]: + # Extract the last occurrence of \\boxed{...} or \\framebox{...} + pattern = r"\\(?:boxed|framebox)\{(.*?)\}" + matches = re.findall(pattern, text, re.DOTALL) + if not matches: + return None + return matches[-1] + + +def _cta_process_results(ground_truth: str, llm_answer: str) -> int: + parsed_answer = llm_answer + if "\\boxed{" in parsed_answer or "\\framebox{" in parsed_answer: + boxed = _extract_last_boxed_segment(parsed_answer) + if boxed is not None: + parsed_answer = boxed + parsed_answer = ( + parsed_answer.replace("\\text{", "").replace("}", "").replace("\\", "") + ) + + gt_clean = _lb_clean_text(ground_truth) + ans_clean = _lb_clean_text(parsed_answer) + if gt_clean == ans_clean: + return 1 + # Suffix match to handle answers like "... answer: XYZ" + if len(ans_clean) >= len(gt_clean) and ans_clean[-len(gt_clean) :] == gt_clean: + return 1 + return 0 + + +def _tj_clean_llm_output(s: str) -> Dict[str, Any]: + # Try to extract the last ... + m = re.findall(r"(.*?)", s, re.DOTALL) + if len(m) > 0: + return _tj_clean_llm_output(m[-1].strip()) + + candidate: Optional[str] = None + # Prefer code blocks (python/json/any) + for fence in ("```python", "```json", "```"): + mm = re.findall(r"%s(.*?)```" % re.escape(fence), s.replace("\n", ""), re.MULTILINE) + if mm: + candidate = mm[-1] + break + # Fallback to boxed + if candidate is None and "\\boxed" in s: + boxed = _extract_last_boxed_segment(s.replace("\n", "")) + if boxed: + # Convert \text{"str"} to 'str' and strip backslashes + candidate = re.sub(r"\\text{['\"](.*?)['\"]}", r"'\1'", boxed).replace("\\", "") + if candidate is None: + candidate = s + + # Make JSON-like to python literal + candidate = candidate.replace("null", "None") + try: + from ast import literal_eval + + parsed = literal_eval(candidate) + if not isinstance(parsed, dict): + return {} + # Drop None values + for k in list(parsed.keys()): + if parsed[k] is None: + del parsed[k] + return parsed + except Exception: + return {} + + +def _tablejoin_process_results(ground_truth: Any, llm_answer: str) -> float: + import json as _json + from ast import literal_eval + + # Parse GT into dict if needed + gt: Dict[str, Any] + if isinstance(ground_truth, str): + try: + gt = literal_eval(ground_truth) + except Exception: + try: + gt = _json.loads(ground_truth) + except Exception: + return 0.0 + else: + gt = dict(ground_truth) + + pred = _tj_clean_llm_output(llm_answer) + if len(pred) == 0: + return 0.0 + + tp = 0 + fp = 0 + fn = 0 + for k, v in pred.items(): + gt_v = gt.get(k, None) + if gt_v is None: + fp += 1 + elif gt_v == v: + tp += 1 + else: + fp += 1 + fn += 1 + for k, v in gt.items(): + if k not in pred: + fn += 1 + denom = (2 * tp) + fp + fn + if denom == 0: + return 0.0 + # Round to 2 decimals to mirror LiveBench + return round((2 * tp) / denom, 2) + + +def _tablereformat_process_results( + input_command: str, ground_truth: str, llm_answer: str, version: str +) -> int: + try: + import pandas as pd # type: ignore + except Exception: + return 0 + + from io import StringIO + import math as _math + import traceback as _traceback + + def _read_df_v1(df_type: str, df_str: str): + if df_type == "json": + for orient in ("index", "records", "records", "table", "values"): + try: + return pd.read_json(StringIO(df_str), orient=orient) + except Exception: + pass + return pd.read_json(StringIO(df_str), orient="values") + if df_type == "jsonl": + return pd.read_json(StringIO(df_str), orient="records", lines=True) + if df_type == "html": + return pd.concat(pd.read_html(StringIO(df_str)), axis=0) + if df_type == "csv": + return pd.read_csv(StringIO(df_str)) + if df_type == "markdown": + return pd.read_table(StringIO(df_str), sep="|", header=0, index_col=1, skipinitialspace=True) + if df_type == "tsv": + return pd.read_csv(StringIO(df_str), sep="\t") + raise ValueError(f"Unsupported type {df_type}") + + def _read_df_v2(df_type: str, df_str: str): + if df_type == "json": + for orient in ("table", "index", "records"): + try: + return pd.read_json(StringIO(df_str), orient=orient) + except Exception: + pass + return None + if df_type == "jsonl": + return pd.read_json(StringIO(df_str), orient="records", lines=True) + if df_type == "html": + return pd.concat(pd.read_html(StringIO(df_str)), axis=0) + if df_type == "csv": + return pd.read_csv(StringIO(df_str)) + if df_type == "markdown": + # Remove alignment line + lines = df_str.strip().split("\n") + header = lines[0] + data_lines = lines[2:] if len(lines) > 2 else [] + processed = header + "\n" + "\n".join(data_lines) + df = pd.read_table(StringIO(processed), sep="|", header=0, skipinitialspace=True).iloc[:, 1:-1] + for col in df.columns: + if df[col].dtype == "object": + df[col] = df[col].astype(str).str.strip() + return df + if df_type == "tsv": + return pd.read_csv(StringIO(df_str), sep="\t") + raise ValueError(f"Unsupported type {df_type}") + + def _clean_llm_output(s: str) -> str: + m = re.findall(r"```json\n(.*?)```", s, re.DOTALL) + if m: + return m[-1].strip() + m = re.findall(r"```html\n(.*?)```", s, re.DOTALL) + if m: + return m[-1].strip() + s = re.sub(r"^```.*\n", "", s) + s = s.replace("&", "&") + return s.replace("```", "").strip() + + def _remove_initial_phrase(text: str) -> str: + return re.sub(r"^\s*(Here|Input)\b.*?\b(format|table)\s*[:)]\s*", "", text, flags=re.IGNORECASE).strip() + + def _read_sep_table_from_text(text: str, header: str, sep: str): + text = text.strip() + lines = text.split("\n") + header_line = 0 + while header_line < len(lines) and lines[header_line].strip() != header.strip(): + header_line += 1 + if header_line == len(lines) or lines[header_line].strip() != header.strip(): + return None + table = lines[header_line:] + parsed = None + while parsed is None and table: + try: + parsed = pd.read_csv(StringIO("\n".join(table)), sep=sep) + except Exception: + table = table[:-1] + return parsed + + def _read_jsonl_table_from_text(text: str, header_cols: List[str]): + rows = [] + for line in text.strip().split("\n"): + if len(line) < 2 or line[0] != "{" or line[-1] != "}": + continue + if not all(col in line for col in header_cols): + continue + try: + rows.append(json.loads(line)) + except Exception: + continue + if not rows: + return None + import pandas as _pd + + return _pd.DataFrame(rows) + + # Determine formats from the instruction + if version == "v1": + input_fmt = input_command.split("Please convert the Input Table from ")[1].split(" format")[0].lower() + output_fmt = ( + input_command.split("Please convert the Input Table from ")[1] + .split("format to ")[1] + .split(" format")[0] + .lower() + ) + else: + lines = input_command.split("\n") + input_fmt = [l for l in lines if "Source Format" in l][-1].split("Source Format: ")[-1].strip().lower() + output_fmt = [l for l in lines if "Target Format" in l][-1].split("Target Format: ")[-1].strip().lower() + + reader = _read_df_v1 if version == "v1" else _read_df_v2 + gt_df = reader(output_fmt, ground_truth) + + llm_clean = _clean_llm_output(llm_answer) + llm_clean = _remove_initial_phrase(llm_clean) + try: + llm_df = reader(output_fmt, llm_clean) + except Exception: + llm_df = None + if output_fmt in ("csv", "tsv") and gt_df is not None: + header = (",", "\t")[output_fmt == "tsv"].join(list(gt_df.columns)) + llm_df = _read_sep_table_from_text(llm_clean, header, sep="," if output_fmt == "csv" else "\t") + elif output_fmt == "jsonl" and gt_df is not None: + llm_df = _read_jsonl_table_from_text(llm_clean, list(gt_df.columns)) + if llm_df is None: + return 0 + + # Compare + try: + gt_df.columns = [str(s).strip() for s in gt_df.columns] + if "index" in gt_df.columns: + gt_df = gt_df.drop(columns=["index"]) + llm_df.columns = [str(s).strip() for s in llm_df.columns] + if "index" in llm_df.columns: + llm_df = llm_df.drop(columns=["index"]) + assert len(llm_df) == len(gt_df) + assert sorted(llm_df.columns) == sorted(gt_df.columns) + for i in range(len(llm_df)): + for key in llm_df.columns: + lv = llm_df.iloc[i][key] + gv = gt_df.iloc[i][key] + if isinstance(lv, str): + lv = lv.strip() + if isinstance(gv, str): + gv = gv.strip() + # Numeric tolerance for floats + try: + lvf = float(lv) + gvf = float(gv) + if _math.isnan(lvf) and _math.isnan(gvf): + continue + assert abs(lvf - gvf) < 1e-6 + except Exception: + assert str(lv) == str(gv) + except AssertionError: + return 0 + except Exception: + # Silent on failure, match LiveBench robustness + _traceback.print_exc() + return 0 + return 1 + + +# ------------------------- +# Dataset loading from Hugging Face at import time +# ------------------------- + +SYSTEM_PROMPT = "You are a helpful data analyst. Read the task and answer precisely." + + +def _load_livebench_da_messages(task_name: str) -> List[List[Message]]: + try: + from datasets import load_dataset # type: ignore + except Exception as e: # pragma: no cover + raise RuntimeError( + "The 'datasets' package is required for LiveBench Data Analysis benchmarks. Please 'pip install datasets'." + ) from e + + ds = load_dataset("livebench/data_analysis", split="test") + rows: List[List[Message]] = [] + for ex in ds: + if str(ex.get("task", "")) != task_name: + continue + question_text = str(ex.get("turns", [""])[0]) + ground_truth = ex.get("ground_truth") + try: + gt_json = json.dumps({ + "ground_truth": ground_truth, + "release": ex.get("livebench_release_date", ""), + }, ensure_ascii=False) + except TypeError: + # Some rows may include non-serializable types; fall back to string cast + gt_json = json.dumps({"ground_truth": str(ground_truth), "release": str(ex.get("livebench_release_date", ""))}) + rows.append( + [ + Message(role="system", content=SYSTEM_PROMPT), + Message(role="user", content=question_text), + Message(role="system", content=f"__GT__:{gt_json}"), + ] + ) + if not rows: + raise RuntimeError(f"No rows found for LiveBench data_analysis task '{task_name}'") + return rows + + +def _extract_gt(row: EvaluationRow) -> Dict[str, Any]: + gt_tokens = [ + m.content + for m in row.messages + if m.role == "system" and (m.content or "").startswith("__GT__:") + ] + if not gt_tokens: + return {"ground_truth": None, "release": None} + try: + payload = json.loads(gt_tokens[-1].split(":", 1)[1]) + return payload if isinstance(payload, dict) else {"ground_truth": payload, "release": None} + except Exception: + return {"ground_truth": gt_tokens[-1].split(":", 1)[1], "release": None} + + +# ------------------------- +# CTA +# ------------------------- + +_CTA_MESSAGES = _load_livebench_da_messages("cta") + + +@export_benchmark("live_bench/data_analysis/cta") +@evaluation_test( + model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"], + input_messages=_CTA_MESSAGES, + rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}], + rollout_processor=default_single_turn_rollout_processor, + aggregation_method="mean", + passed_threshold=None, + num_runs=4, + mode="pointwise", +) +def livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow: + assistant_msgs = [m for m in row.messages if m.role == "assistant"] + content = assistant_msgs[-1].content if assistant_msgs else "" + payload = _extract_gt(row) + gt = payload.get("ground_truth") + gt_str = str(gt) if gt is not None else "" + + score_val = float(_cta_process_results(gt_str, content or "")) if gt_str else 0.0 + is_valid = bool(gt_str) + + row.evaluation_result = EvaluateResult( + score=score_val, + reason=("Matched" if score_val == 1.0 else "Not matched"), + is_score_valid=is_valid, + metrics={ + "exact_match": MetricResult( + score=score_val, + is_score_valid=is_valid, + reason=("Exact/suffix match" if score_val == 1.0 else "Mismatch"), + ) + }, + ) + return row + + +# ------------------------- +# Table Join +# ------------------------- + +_TABLEJOIN_MESSAGES = _load_livebench_da_messages("tablejoin") + + +@export_benchmark("live_bench/data_analysis/tablejoin") +@evaluation_test( + model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"], + input_messages=_TABLEJOIN_MESSAGES, + rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}], + rollout_processor=default_single_turn_rollout_processor, + aggregation_method="mean", + passed_threshold=None, + num_runs=4, + mode="pointwise", +) +def livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow: + user_msgs = [m for m in row.messages if m.role == "user"] + question = user_msgs[-1].content if user_msgs else "" + assistant_msgs = [m for m in row.messages if m.role == "assistant"] + content = assistant_msgs[-1].content if assistant_msgs else "" + payload = _extract_gt(row) + gt = payload.get("ground_truth") + + score_val = float(_tablejoin_process_results(gt, content or "")) + is_valid = True + + row.evaluation_result = EvaluateResult( + score=score_val, + reason=f"F1 score: {score_val:.2f}", + is_score_valid=is_valid, + metrics={ + "f1": MetricResult( + score=score_val, + is_score_valid=is_valid, + reason="Entity/relation mapping F1", + ) + }, + ) + return row + + +# ------------------------- +# Table Reformat +# ------------------------- + +_TABLEREFORMAT_MESSAGES = _load_livebench_da_messages("tablereformat") + + +@export_benchmark("live_bench/data_analysis/tablereformat") +@evaluation_test( + model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"], + input_messages=_TABLEREFORMAT_MESSAGES, + rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}], + rollout_processor=default_single_turn_rollout_processor, + aggregation_method="mean", + passed_threshold=None, + num_runs=4, + mode="pointwise", +) +def livebench_tablereformat_pointwise(row: EvaluationRow) -> EvaluationRow: + user_msgs = [m for m in row.messages if m.role == "user"] + question = user_msgs[-1].content if user_msgs else "" + assistant_msgs = [m for m in row.messages if m.role == "assistant"] + content = assistant_msgs[-1].content if assistant_msgs else "" + payload = _extract_gt(row) + gt = payload.get("ground_truth") + release = payload.get("release") or "" + version = "v2" if str(release) >= "2025-04-25" else "v1" + + gt_str = str(gt) if gt is not None else "" + score_int = _tablereformat_process_results(question or "", gt_str, content or "", version) + score_val = float(score_int) + is_valid = bool(gt_str) + + row.evaluation_result = EvaluateResult( + score=score_val, + reason=("Table matches" if score_val == 1.0 else "Table mismatch"), + is_score_valid=is_valid, + metrics={ + "structure_exact": MetricResult( + score=score_val, + is_score_valid=is_valid, + reason="Exact structure and values match", + ) + }, + ) + return row + + +# Register a composite benchmark that aggregates all three LiveBench Data Analysis tests +register_composite_benchmark( + name="live_bench/data_analysis", + children=[ + "live_bench/data_analysis/cta", + "live_bench/data_analysis/tablejoin", + "live_bench/data_analysis/tablereformat", + ], +) + + diff --git a/eval_protocol/dataset_logger/__init__.py b/eval_protocol/dataset_logger/__init__.py index d60fe513..9478ec6f 100644 --- a/eval_protocol/dataset_logger/__init__.py +++ b/eval_protocol/dataset_logger/__init__.py @@ -1,3 +1,15 @@ from eval_protocol.dataset_logger.sqlite_dataset_logger_adapter import SqliteDatasetLoggerAdapter +import os -default_logger = SqliteDatasetLoggerAdapter() +# Allow disabling sqlite logger to avoid environment-specific constraints in simple CLI runs. +if os.getenv("EP_SQLITE_LOG", "0").strip() == "1": + default_logger = SqliteDatasetLoggerAdapter() +else: + class _NoOpLogger: + def log(self, row): + return None + + def read(self, rollout_id=None): + return [] + + default_logger = _NoOpLogger() diff --git a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py index a8f149a8..6ab0bb8e 100644 --- a/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py +++ b/eval_protocol/dataset_logger/sqlite_evaluation_row_store.py @@ -30,7 +30,8 @@ class EvaluationRow(BaseModel): # type: ignore self._EvaluationRow = EvaluationRow self._db.connect() - self._db.create_tables([EvaluationRow]) + # Use safe=True to avoid errors when tables/indexes already exist + self._db.create_tables([EvaluationRow], safe=True) @property def db_path(self) -> str: diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index 95613ebc..f6e95a02 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -31,7 +31,12 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: if len(row.messages) == 0: raise ValueError("Messages is empty. Please provide a non-empty dataset") - messages_payload = [{"role": m.role, "content": m.content} for m in row.messages] + # Filter out any sentinel ground-truth system messages (e.g., "__GT__:") before sending to the model + messages_payload = [ + {"role": m.role, "content": m.content} + for m in row.messages + if not (m.role == "system" and (m.content or "").startswith("__GT__:")) + ] request_params = {"model": config.model, "messages": messages_payload, **config.input_params} # Ensure caching is disabled only for this request (review feedback) From b38e8317f0d93d9c2a4aeb291e7412220edf0cfd Mon Sep 17 00:00:00 2001 From: benjibc Date: Wed, 13 Aug 2025 03:32:41 +0000 Subject: [PATCH 2/2] fix live bench and rollout processor --- eval_protocol/benchmarks/suites/gpqa.py | 23 ++++++-- .../suites/livebench_data_analysis.py | 55 +++++++++---------- .../default_single_turn_rollout_process.py | 7 +-- 3 files changed, 47 insertions(+), 38 deletions(-) diff --git a/eval_protocol/benchmarks/suites/gpqa.py b/eval_protocol/benchmarks/suites/gpqa.py index ec67ae94..91620c9a 100644 --- a/eval_protocol/benchmarks/suites/gpqa.py +++ b/eval_protocol/benchmarks/suites/gpqa.py @@ -55,13 +55,29 @@ def _extract_abcd_letter(text: str) -> str | None: _GPQA_INPUT_MESSAGES = _load_gpqa_messages_from_csv() +def _strip_gt_messages(msgs: List[Message]) -> List[Message]: + return [m for m in msgs if not (m.role == "system" and (m.content or "").startswith("__GT__:"))] + + +async def gpqa_strip_gt_rollout_processor(rows: List[EvaluationRow], config) -> List[EvaluationRow]: + """Preprocess rows to set ground_truth and remove __GT__ messages, then delegate to default processor.""" + processed: List[EvaluationRow] = [] + for r in rows: + gt_tokens = [m.content for m in r.messages if m.role == "system" and (m.content or "").startswith("__GT__:")] + if gt_tokens: + gt_val = gt_tokens[-1].split(":", 1)[1].strip() + r.ground_truth = gt_val + r.messages = [m for m in r.messages if not (m.role == "system" and (m.content or "").startswith("__GT__:"))] + processed.append(r) + return await default_single_turn_rollout_processor(processed, config) + @export_benchmark("gpqa") @evaluation_test( model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"], input_messages=_GPQA_INPUT_MESSAGES, rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}], - rollout_processor=default_single_turn_rollout_processor, + rollout_processor=gpqa_strip_gt_rollout_processor, aggregation_method="mean", passed_threshold=None, num_runs=8, @@ -72,9 +88,8 @@ def gpqa_pointwise(row: EvaluationRow) -> EvaluationRow: content = assistant_msgs[-1].content if assistant_msgs else "" pred = _extract_abcd_letter(content or "") - # Retrieve GT from the trailing system message we appended - gt_tokens = [m.content for m in row.messages if m.role == "system" and (m.content or "").startswith("__GT__:")] - gt = gt_tokens[-1].split(":", 1)[1].strip() if gt_tokens else None + # GPQA diamond CSV constructs options so that the correct answer is always A + gt = "A" is_valid = pred is not None and gt in {"A", "B", "C", "D"} score = 1.0 if (is_valid and pred == gt) else 0.0 diff --git a/eval_protocol/benchmarks/suites/livebench_data_analysis.py b/eval_protocol/benchmarks/suites/livebench_data_analysis.py index d7bd8729..1c04b6fd 100644 --- a/eval_protocol/benchmarks/suites/livebench_data_analysis.py +++ b/eval_protocol/benchmarks/suites/livebench_data_analysis.py @@ -315,7 +315,7 @@ def _read_jsonl_table_from_text(text: str, header_cols: List[str]): SYSTEM_PROMPT = "You are a helpful data analyst. Read the task and answer precisely." -def _load_livebench_da_messages(task_name: str) -> List[List[Message]]: +def _load_livebench_da_messages(task_name: str) -> List[EvaluationRow]: try: from datasets import load_dataset # type: ignore except Exception as e: # pragma: no cover @@ -324,26 +324,25 @@ def _load_livebench_da_messages(task_name: str) -> List[List[Message]]: ) from e ds = load_dataset("livebench/data_analysis", split="test") - rows: List[List[Message]] = [] + rows: List[EvaluationRow] = [] for ex in ds: if str(ex.get("task", "")) != task_name: continue question_text = str(ex.get("turns", [""])[0]) ground_truth = ex.get("ground_truth") + release = ex.get("livebench_release_date", "") try: - gt_json = json.dumps({ - "ground_truth": ground_truth, - "release": ex.get("livebench_release_date", ""), - }, ensure_ascii=False) + gt_payload = json.dumps({"ground_truth": ground_truth, "release": release}, ensure_ascii=False) except TypeError: - # Some rows may include non-serializable types; fall back to string cast - gt_json = json.dumps({"ground_truth": str(ground_truth), "release": str(ex.get("livebench_release_date", ""))}) + gt_payload = json.dumps({"ground_truth": str(ground_truth), "release": str(release)}) rows.append( - [ - Message(role="system", content=SYSTEM_PROMPT), - Message(role="user", content=question_text), - Message(role="system", content=f"__GT__:{gt_json}"), - ] + EvaluationRow( + messages=[ + Message(role="system", content=SYSTEM_PROMPT), + Message(role="user", content=question_text), + ], + ground_truth=gt_payload, + ) ) if not rows: raise RuntimeError(f"No rows found for LiveBench data_analysis task '{task_name}'") @@ -351,31 +350,31 @@ def _load_livebench_da_messages(task_name: str) -> List[List[Message]]: def _extract_gt(row: EvaluationRow) -> Dict[str, Any]: - gt_tokens = [ - m.content - for m in row.messages - if m.role == "system" and (m.content or "").startswith("__GT__:") - ] - if not gt_tokens: + # For LiveBench Data Analysis, we fetch the ground truth from the HF dataset + # and store it in the top-level ground_truth field in the adapter below. + # Here, just parse row.ground_truth if it contains a JSON payload, else string. + if row.ground_truth is None: return {"ground_truth": None, "release": None} try: - payload = json.loads(gt_tokens[-1].split(":", 1)[1]) - return payload if isinstance(payload, dict) else {"ground_truth": payload, "release": None} + payload = json.loads(row.ground_truth) + if isinstance(payload, dict): + return payload except Exception: - return {"ground_truth": gt_tokens[-1].split(":", 1)[1], "release": None} + pass + return {"ground_truth": row.ground_truth, "release": None} # ------------------------- # CTA # ------------------------- -_CTA_MESSAGES = _load_livebench_da_messages("cta") +_CTA_ROWS = _load_livebench_da_messages("cta") @export_benchmark("live_bench/data_analysis/cta") @evaluation_test( model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"], - input_messages=_CTA_MESSAGES, + input_messages=[[m for m in r.messages] for r in _CTA_ROWS], rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}], rollout_processor=default_single_turn_rollout_processor, aggregation_method="mean", @@ -412,13 +411,13 @@ def livebench_cta_pointwise(row: EvaluationRow) -> EvaluationRow: # Table Join # ------------------------- -_TABLEJOIN_MESSAGES = _load_livebench_da_messages("tablejoin") +_TABLEJOIN_ROWS = _load_livebench_da_messages("tablejoin") @export_benchmark("live_bench/data_analysis/tablejoin") @evaluation_test( model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"], - input_messages=_TABLEJOIN_MESSAGES, + input_messages=[[m for m in r.messages] for r in _TABLEJOIN_ROWS], rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}], rollout_processor=default_single_turn_rollout_processor, aggregation_method="mean", @@ -456,13 +455,13 @@ def livebench_tablejoin_pointwise(row: EvaluationRow) -> EvaluationRow: # Table Reformat # ------------------------- -_TABLEREFORMAT_MESSAGES = _load_livebench_da_messages("tablereformat") +_TABLEREFORMAT_ROWS = _load_livebench_da_messages("tablereformat") @export_benchmark("live_bench/data_analysis/tablereformat") @evaluation_test( model=["fireworks_ai/accounts/fireworks/models/gpt-oss-120b"], - input_messages=_TABLEREFORMAT_MESSAGES, + input_messages=[[m for m in r.messages] for r in _TABLEREFORMAT_ROWS], rollout_input_params=[{"extra_body": {"reasoning_effort": "low"}}], rollout_processor=default_single_turn_rollout_processor, aggregation_method="mean", diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index f6e95a02..95613ebc 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -31,12 +31,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: if len(row.messages) == 0: raise ValueError("Messages is empty. Please provide a non-empty dataset") - # Filter out any sentinel ground-truth system messages (e.g., "__GT__:") before sending to the model - messages_payload = [ - {"role": m.role, "content": m.content} - for m in row.messages - if not (m.role == "system" and (m.content or "").startswith("__GT__:")) - ] + messages_payload = [{"role": m.role, "content": m.content} for m in row.messages] request_params = {"model": config.model, "messages": messages_payload, **config.input_params} # Ensure caching is disabled only for this request (review feedback)