Skip to content

Commit ffcb08d

Browse files
author
Dylan Huang
committed
Merge branch 'main' into show-aggregated-metrics-in-ui
# Conflicts:
#	eval_protocol/pytest/evaluation_test.py
2 parents 949acba + 55005a1 commit ffcb08d

File tree

16 files changed

+1066
-42
lines changed

16 files changed

+1066
-42
lines changed

eval_protocol/common_utils.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import re
33
from typing import Any, Dict, List
44

5+
import requests
6+
57

68
def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
79
"""
@@ -12,19 +14,42 @@ def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
1214
1315
Returns:
1416
A list of dictionaries, where each dictionary is a parsed JSON object from a line.
15-
Returns an empty list if the file is not found or if errors occur during parsing.
17+
Returns an empty list if the file is not found or if errors occur during parsing. Supports HTTP urls and local file paths.
1618
"""
1719
data: List[Dict[str, Any]] = []
18-
with open(file_path, "r", encoding="utf-8") as f:
19-
for line_number, line in enumerate(f):
20+
if file_path.startswith("http://") or file_path.startswith("https://"):
21+
resp = requests.get(file_path, stream=True, timeout=30)
22+
resp.raise_for_status()
23+
for line_number, raw in enumerate(resp.iter_lines(decode_unicode=True), start=1):
24+
if raw is None:
25+
continue
26+
stripped = raw.strip()
27+
if not stripped:
28+
continue
2029
try:
21-
data.append(json.loads(line.strip()))
30+
data.append(json.loads(stripped))
2231
except json.JSONDecodeError as e:
23-
print(f"Error parsing JSON line for file {file_path} at line {line_number}")
24-
# attempt to find "row_id" in the line by finding index of "row_id" and performing regex of `"row_id": (.*),`
25-
row_id_index = line.find("row_id")
32+
print(f"Error parsing JSON line for URL {file_path} at line {line_number}")
33+
row_id_index = stripped.find("row_id")
2634
if row_id_index != -1:
27-
row_id = re.search(r'"row_id": (.*),', line[row_id_index:])
28-
raise ValueError(f"{e.msg} at line {line_number}: {line} ({row_id})")
35+
row_id = re.search(r'"row_id": (.*),', stripped[row_id_index:])
36+
raise ValueError(f"{e.msg} at line {line_number}: {stripped} ({row_id})") from e
2937
raise e
38+
else:
39+
with open(file_path, "r", encoding="utf-8") as f:
40+
for line_number, line in enumerate(f, start=1):
41+
# Skip entirely blank or whitespace-only lines to be robust to trailing newlines
42+
stripped = line.strip()
43+
if not stripped:
44+
continue
45+
try:
46+
data.append(json.loads(stripped))
47+
except json.JSONDecodeError as e:
48+
print(f"Error parsing JSON line for file {file_path} at line {line_number}")
49+
# attempt to find "row_id" in the line by finding index of "row_id" and performing regex of `"row_id": (.*),`
50+
row_id_index = line.find("row_id")
51+
if row_id_index != -1:
52+
row_id = re.search(r'"row_id": (.*),', line[row_id_index:])
53+
raise ValueError(f"{e.msg} at line {line_number}: {line} ({row_id})") from e
54+
raise e
3055
return data

eval_protocol/generation/clients.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import aiohttp
1313
from omegaconf import DictConfig
14-
from pydantic import BaseModel, Field # Added for new models
14+
from pydantic import BaseModel # Added for new models
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -83,6 +83,9 @@ async def generate(
8383
}
8484
if self.top_p is not None:
8585
payload["top_p"] = self.top_p
86+
# Include reasoning settings if configured (for reasoning-capable models)
87+
if self.reasoning_effort:
88+
payload["reasoning_effort"] = self.reasoning_effort
8689

8790
if tools:
8891
payload["tools"] = tools

eval_protocol/mcp/execution/manager.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ async def _execute_with_semaphore(idx):
163163
evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id
164164
evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx])
165165
evaluation_rows[idx].tools = shared_tool_schema
166-
evaluation_rows[idx].usage = trajectory.usage
166+
evaluation_rows[idx].usage = CompletionUsage(**trajectory.usage)
167167
evaluation_rows[idx].input_metadata.completion_params = CompletionParams(
168168
model=policy.model_id,
169169
temperature=getattr(policy, "temperature", None),
@@ -260,8 +260,6 @@ async def _execute_rollout(
260260
{"role": "user", "content": user_prompt},
261261
]
262262

263-
usage_stats_list: List[CompletionUsage] = []
264-
265263
logger.info(f"🎯 Starting rollout {rollout_idx} in thread {threading.current_thread().name}")
266264

267265
# Run rollout loop for this specific environment
@@ -299,6 +297,12 @@ async def _execute_rollout(
299297
while not turn_completed and not trajectory.terminated:
300298
tool_calls, usage_stats = await policy(tool_schema, rollout_idx, conversation_history)
301299

300+
# calc llm usage stats happened in this turn if there is any
301+
if usage_stats:
302+
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
303+
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
304+
trajectory.usage["total_tokens"] += usage_stats.total_tokens
305+
302306
# If no tool call is generated, turn is finished
303307
if len(tool_calls) == 1:
304308
# If there's a user simulator, no tool call means the policy is ready to provide final response on this turn
@@ -308,6 +312,8 @@ async def _execute_rollout(
308312
# If there's no user simulator, no tool call means policy failed and we should terminate the rollout
309313
elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]:
310314
trajectory.terminated = True
315+
trajectory.termination_reason = TerminationReason.ERROR
316+
trajectory.control_plane_summary.update({"error_message": "No expected tool call"})
311317
break
312318

313319
# Execute each tool call sequentially
@@ -373,10 +379,6 @@ async def _execute_rollout(
373379
if observation is not None:
374380
current_observation = observation
375381

376-
# calc llm usage stats happened in this turn if there is aany
377-
if usage_stats:
378-
usage_stats_list.append(usage_stats)
379-
380382
# With user simulator, increment step after an entire conversation step
381383
if user_simulator is not None:
382384
step += 1
@@ -409,7 +411,9 @@ async def _execute_rollout(
409411
# tool indicates rollout should be terminated, call policy one last time to get the final response
410412
_, usage_stats = await policy(tool_schema, rollout_idx, conversation_history)
411413
if usage_stats:
412-
usage_stats_list.append(usage_stats)
414+
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
415+
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
416+
trajectory.usage["total_tokens"] += usage_stats.total_tokens
413417

414418
# Add final control plane summary
415419
trajectory.control_plane_summary.update(
@@ -460,11 +464,6 @@ async def _execute_rollout(
460464
msg["control_plane_step"]["termination_reason"] = trajectory.termination_reason
461465
break
462466

463-
for usage_stats in usage_stats_list:
464-
trajectory.usage["prompt_tokens"] += usage_stats.prompt_tokens
465-
trajectory.usage["completion_tokens"] += usage_stats.completion_tokens
466-
trajectory.usage["total_tokens"] += usage_stats.total_tokens
467-
468467
logger.info(
469468
f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}"
470469
)

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import asyncio
22
from typing import List
33

4-
from litellm import acompletion
5-
from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
4+
import logging
5+
import os
66

77
from eval_protocol.dataset_logger import default_logger
8-
from eval_protocol.models import EvaluationRow, Message
8+
from eval_protocol.models import EvaluationRow, Message, ChatCompletionMessageToolCall
99
from eval_protocol.pytest.types import RolloutProcessorConfig
1010

1111

@@ -14,6 +14,20 @@ async def default_single_turn_rollout_processor(
1414
) -> List[EvaluationRow]:
1515
"""Generate a single response from any supported model provider using LiteLLM."""
1616

17+
# Quiet LiteLLM logs in test runs unless user overrode
18+
try:
19+
if os.environ.get("LITELLM_LOG") is None:
20+
os.environ["LITELLM_LOG"] = "ERROR"
21+
_llog = logging.getLogger("LiteLLM")
22+
_llog.setLevel(logging.CRITICAL)
23+
_llog.propagate = False
24+
for _h in list(_llog.handlers):
25+
_llog.removeHandler(_h)
26+
except Exception:
27+
pass
28+
29+
# Do not modify global LiteLLM cache. Disable caching per-request instead.
30+
1731
async def process_row(row: EvaluationRow) -> EvaluationRow:
1832
"""Process a single row asynchronously."""
1933
if len(row.messages) == 0:
@@ -22,10 +36,21 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
2236
messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
2337

2438
request_params = {"model": config.model, "messages": messages_payload, **config.input_params}
39+
# Ensure caching is disabled only for this request (review feedback)
40+
request_params["cache"] = {"no-cache": True}
41+
# Allow passing reasoning effort to Fireworks via LiteLLM using extra_body
42+
# Expected: config.input_params may contain {"reasoning": {"effort": "low|medium|high"}}
43+
if "reasoning" in config.input_params:
44+
request_params.setdefault("extra_body", {})
45+
request_params["extra_body"]["reasoning"] = config.input_params["reasoning"]
2546

2647
if row.tools is not None:
2748
request_params["tools"] = row.tools
2849

50+
# Dynamic import to avoid static dependency/lint errors if LiteLLM isn't installed yet
51+
import importlib
52+
_litellm = importlib.import_module("litellm")
53+
acompletion = getattr(_litellm, "acompletion")
2954
response = await acompletion(**request_params)
3055

3156
assistant_content = response.choices[0].message.content or ""
@@ -57,8 +82,15 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
5782
default_logger.log(row)
5883
return row
5984

60-
# Process all rows concurrently
61-
tasks = [process_row(row) for row in rows]
85+
# Process rows with bounded concurrency if configured
86+
max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8
87+
semaphore = asyncio.Semaphore(max_concurrent)
88+
89+
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
90+
async with semaphore:
91+
return await process_row(r)
92+
93+
tasks = [_sem_wrapper(row) for row in rows]
6294
dataset = list(await asyncio.gather(*tasks))
6395

6496
return dataset

0 commit comments

Comments
 (0)