Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions eval_protocol/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import re
from typing import Any, Dict, List

import requests


def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
"""
Expand All @@ -12,19 +14,42 @@ def load_jsonl(file_path: str) -> List[Dict[str, Any]]:

Returns:
A list of dictionaries, where each dictionary is a parsed JSON object from a line.
Returns an empty list if the file is not found or if errors occur during parsing.
Raises an exception if the file cannot be opened, the HTTP request fails, or a line cannot be parsed as JSON; blank lines are skipped. Supports HTTP(S) URLs and local file paths.
"""
data: List[Dict[str, Any]] = []
with open(file_path, "r", encoding="utf-8") as f:
for line_number, line in enumerate(f):
if file_path.startswith("http://") or file_path.startswith("https://"):
resp = requests.get(file_path, stream=True, timeout=30)
resp.raise_for_status()
for line_number, raw in enumerate(resp.iter_lines(decode_unicode=True), start=1):
if raw is None:
continue
stripped = raw.strip()
if not stripped:
continue
try:
data.append(json.loads(line.strip()))
data.append(json.loads(stripped))
except json.JSONDecodeError as e:
print(f"Error parsing JSON line for file {file_path} at line {line_number}")
# attempt to find "row_id" in the line by finding index of "row_id" and performing regex of `"row_id": (.*),`
row_id_index = line.find("row_id")
print(f"Error parsing JSON line for URL {file_path} at line {line_number}")
row_id_index = stripped.find("row_id")
if row_id_index != -1:
row_id = re.search(r'"row_id": (.*),', line[row_id_index:])
raise ValueError(f"{e.msg} at line {line_number}: {line} ({row_id})")
row_id = re.search(r'"row_id": (.*),', stripped[row_id_index:])
raise ValueError(f"{e.msg} at line {line_number}: {stripped} ({row_id})") from e
raise e
else:
with open(file_path, "r", encoding="utf-8") as f:
for line_number, line in enumerate(f, start=1):
# Skip entirely blank or whitespace-only lines to be robust to trailing newlines
stripped = line.strip()
if not stripped:
continue
try:
data.append(json.loads(stripped))
except json.JSONDecodeError as e:
print(f"Error parsing JSON line for file {file_path} at line {line_number}")
# attempt to find "row_id" in the line by finding index of "row_id" and performing regex of `"row_id": (.*),`
row_id_index = line.find("row_id")
if row_id_index != -1:
row_id = re.search(r'"row_id": (.*),', line[row_id_index:])
raise ValueError(f"{e.msg} at line {line_number}: {line} ({row_id})") from e
raise e
return data
5 changes: 4 additions & 1 deletion eval_protocol/generation/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import aiohttp
from omegaconf import DictConfig
from pydantic import BaseModel, Field # Added for new models
from pydantic import BaseModel # Added for new models

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -83,6 +83,9 @@ async def generate(
}
if self.top_p is not None:
payload["top_p"] = self.top_p
# Include reasoning settings if configured (for reasoning-capable models)
if self.reasoning_effort:
payload["reasoning_effort"] = self.reasoning_effort

if tools:
payload["tools"] = tools
Expand Down
42 changes: 37 additions & 5 deletions eval_protocol/pytest/default_single_turn_rollout_process.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
from typing import List

from litellm import acompletion
from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
import logging
import os

from eval_protocol.dataset_logger import default_logger
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.models import EvaluationRow, Message, ChatCompletionMessageToolCall
from eval_protocol.pytest.types import RolloutProcessorConfig


Expand All @@ -14,6 +14,20 @@ async def default_single_turn_rollout_processor(
) -> List[EvaluationRow]:
"""Generate a single response from any supported model provider using LiteLLM."""

# Quiet LiteLLM logs in test runs unless user overrode
try:
if os.environ.get("LITELLM_LOG") is None:
os.environ["LITELLM_LOG"] = "ERROR"
_llog = logging.getLogger("LiteLLM")
_llog.setLevel(logging.CRITICAL)
_llog.propagate = False
for _h in list(_llog.handlers):
_llog.removeHandler(_h)
except Exception:
pass

# Do not modify global LiteLLM cache. Disable caching per-request instead.

async def process_row(row: EvaluationRow) -> EvaluationRow:
"""Process a single row asynchronously."""
if len(row.messages) == 0:
Expand All @@ -22,10 +36,21 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]

request_params = {"model": config.model, "messages": messages_payload, **config.input_params}
# Ensure caching is disabled only for this request (review feedback)
request_params["cache"] = {"no-cache": True}
# Allow passing reasoning effort to Fireworks via LiteLLM using extra_body
# Expected: config.input_params may contain {"reasoning": {"effort": "low|medium|high"}}
if "reasoning" in config.input_params:
request_params.setdefault("extra_body", {})
request_params["extra_body"]["reasoning"] = config.input_params["reasoning"]

if row.tools is not None:
request_params["tools"] = row.tools

# Dynamic import to avoid static dependency/lint errors if LiteLLM isn't installed yet
import importlib
_litellm = importlib.import_module("litellm")
acompletion = getattr(_litellm, "acompletion")
response = await acompletion(**request_params)

assistant_content = response.choices[0].message.content or ""
Expand Down Expand Up @@ -57,8 +82,15 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
default_logger.log(row)
return row

# Process all rows concurrently
tasks = [process_row(row) for row in rows]
# Process rows with bounded concurrency if configured
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
semaphore = asyncio.Semaphore(max_concurrent)

async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
async with semaphore:
return await process_row(r)

tasks = [_sem_wrapper(row) for row in rows]
dataset = list(await asyncio.gather(*tasks))

return dataset
Loading
Loading