Skip to content

Commit b73b5aa

Browse files
authored
Merge branch 'main' into derekx/standard-deviation
2 parents 9a9088e + 38a4444 commit b73b5aa

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

47 files changed

+1150
-461
lines changed

eval_protocol/dataset_logger/sqlite_dataset_logger_adapter.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,17 @@ def __init__(self, db_path: Optional[str] = None, store: Optional[SqliteEvaluati
2222
self._store = SqliteEvaluationRowStore(self.db_path)
2323

2424
def log(self, row: "EvaluationRow") -> None:
25-
row_id = row.input_metadata.row_id
2625
data = row.model_dump(exclude_none=True, mode="json")
27-
self._store.upsert_row(row_id=row_id, data=data)
26+
self._store.upsert_row(data=data)
2827
try:
2928
event_bus.emit(LOG_EVENT_TYPE, EvaluationRow(**data))
3029
except Exception as e:
3130
# Avoid breaking storage due to event emission issues
3231
logger.error(f"Failed to emit row_upserted event: {e}")
3332
pass
3433

35-
def read(self, row_id: Optional[str] = None) -> List["EvaluationRow"]:
34+
def read(self, rollout_id: Optional[str] = None) -> List["EvaluationRow"]:
3635
from eval_protocol.models import EvaluationRow
3736

38-
results = self._store.read_rows(row_id=row_id)
37+
results = self._store.read_rows(rollout_id=rollout_id)
3938
return [EvaluationRow(**data) for data in results]

eval_protocol/dataset_logger/sqlite_evaluation_row_store.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class SqliteEvaluationRowStore:
1111
"""
1212
Lightweight reusable SQLite store for evaluation rows.
1313
14-
Stores arbitrary row data as JSON keyed by a unique string `row_id`.
14+
Stores arbitrary row data as JSON keyed by a unique string `rollout_id`.
1515
"""
1616

1717
def __init__(self, db_path: str):
@@ -24,7 +24,7 @@ class Meta:
2424
database = self._db
2525

2626
class EvaluationRow(BaseModel): # type: ignore
27-
row_id = CharField(unique=True)
27+
rollout_id = CharField(unique=True)
2828
data = JSONField()
2929

3030
self._EvaluationRow = EvaluationRow
@@ -36,22 +36,25 @@ class EvaluationRow(BaseModel): # type: ignore
3636
def db_path(self) -> str:
3737
return self._db_path
3838

39-
def upsert_row(self, row_id: str, data: dict) -> None:
40-
if self._EvaluationRow.select().where(self._EvaluationRow.row_id == row_id).exists():
41-
self._EvaluationRow.update(data=data).where(self._EvaluationRow.row_id == row_id).execute()
39+
def upsert_row(self, data: dict) -> None:
40+
rollout_id = data["rollout_id"]
41+
if "rollout_id" not in data:
42+
raise ValueError("rollout_id is required to upsert a row")
43+
if self._EvaluationRow.select().where(self._EvaluationRow.rollout_id == rollout_id).exists():
44+
self._EvaluationRow.update(data=data).where(self._EvaluationRow.rollout_id == rollout_id).execute()
4245
else:
43-
self._EvaluationRow.create(row_id=row_id, data=data)
46+
self._EvaluationRow.create(rollout_id=rollout_id, data=data)
4447

45-
def read_rows(self, row_id: Optional[str] = None) -> List[dict]:
46-
if row_id is None:
48+
def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]:
49+
if rollout_id is None:
4750
query = self._EvaluationRow.select().dicts()
4851
else:
49-
query = self._EvaluationRow.select().dicts().where(self._EvaluationRow.row_id == row_id)
52+
query = self._EvaluationRow.select().dicts().where(self._EvaluationRow.rollout_id == rollout_id)
5053
results = list(query)
5154
return [result["data"] for result in results]
5255

53-
def delete_row(self, row_id: str) -> int:
54-
return self._EvaluationRow.delete().where(self._EvaluationRow.row_id == row_id).execute()
56+
def delete_row(self, rollout_id: str) -> int:
57+
return self._EvaluationRow.delete().where(self._EvaluationRow.rollout_id == rollout_id).execute()
5558

5659
def delete_all_rows(self) -> int:
5760
return self._EvaluationRow.delete().execute()

eval_protocol/mcp/client/connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -539,10 +539,10 @@ async def close_session(self, session: MCPSession) -> None:
539539
await session._exit_stack.aclose()
540540
except asyncio.CancelledError:
541541
# Handle cancellation gracefully (especially important for Python 3.12)
542-
logger.debug(f"Session {session.session_id} close was cancelled")
542+
logger.error(f"Session {session.session_id} close was cancelled")
543543
except Exception as e:
544544
# Hitting this error, probably because of use of threads: "Attempted to exit cancel scope in a different task than it was entered in"
545-
logger.debug(f"Error closing session {session.session_id}: {e}")
545+
logger.error(f"Error closing session {session.session_id}: {e}")
546546
finally:
547547
session._exit_stack = None
548548
session._mcp_session = None

eval_protocol/mcp/execution/base_policy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ async def _generate_live_tool_calls(
220220
return mcp_tool_calls, usage_stats
221221
else:
222222
# No tool calls in response - this is normal when episode ends or LLM provides only text
223-
logger.info(f"No tool calls in response for env {env_index}, message content: {message.get('content')}")
223+
logger.debug(f"No tool calls in response for env {env_index}, message content: {message.get('content')}")
224224
return [
225225
MCPToolCall(
226226
tool_name="_no_tool_call",

eval_protocol/mcp/execution/manager.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,12 @@ async def execute_rollouts(
9797

9898
async def _execute_with_semaphore(idx):
9999
async with semaphore:
100-
return await self._execute_rollout(
100+
result = await self._execute_rollout(
101101
envs, policy, idx, steps, openai_logger, recording_mode, playback_mode, start_time
102102
)
103103

104+
return result
105+
104106
tasks = [_execute_with_semaphore(i) for i in range(envs.n)]
105107
# exceptions will be try catched inside single _execute_rollout
106108
trajectories = await asyncio.gather(*tasks)
@@ -112,9 +114,6 @@ async def _execute_with_semaphore(idx):
112114

113115
shared_tool_schema = envs.tool_schemas
114116

115-
# Clean up
116-
await envs.close()
117-
118117
# Enhanced reporting with control plane info
119118
successful = sum(1 for traj in trajectories if traj.total_reward > 0)
120119
terminated_by_control_plane = sum(
@@ -175,8 +174,11 @@ async def _execute_with_semaphore(idx):
175174
TerminationReason.USER_STOP,
176175
}:
177176
evaluation_rows[idx].rollout_status.status = "finished"
178-
elif trajectory.termination_reason == TerminationReason.MAX_STEPS:
177+
elif trajectory.termination_reason in {TerminationReason.MAX_STEPS, TerminationReason.INTERRUPTED}:
179178
evaluation_rows[idx].rollout_status.status = "stopped"
179+
evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get(
180+
"termination_reason", trajectory.termination_reason
181+
)
180182
else:
181183
evaluation_rows[idx].rollout_status.status = "error"
182184
evaluation_rows[idx].rollout_status.error_message = trajectory.control_plane_summary.get(
@@ -226,6 +228,7 @@ async def _execute_rollout(
226228
"total_tokens": 0,
227229
},
228230
)
231+
failure_reason = None
229232
try:
230233
current_observation, tool_schema = await envs.reset(session)
231234
system_prompt = dataset_row.system_prompt
@@ -311,8 +314,7 @@ async def _execute_rollout(
311314
# If there's no user simulator, no tool call means policy failed and we should terminate the rollout
312315
elif tool_calls[0].tool_name in ["_playback_terminate", "_no_tool_call"]:
313316
trajectory.terminated = True
314-
trajectory.termination_reason = TerminationReason.ERROR
315-
trajectory.control_plane_summary.update({"error_message": "No expected tool call"})
317+
trajectory.termination_reason = TerminationReason.INTERRUPTED
316318
break
317319

318320
# Execute each tool call sequentially
@@ -466,11 +468,26 @@ async def _execute_rollout(
466468
logger.info(
467469
f"✅ Rollout {rollout_idx} completed: {trajectory.steps} steps, reward: {trajectory.total_reward:.2f}, termination: {trajectory.termination_reason}, in thread {threading.current_thread().name}"
468470
)
471+
472+
except asyncio.CancelledError:
473+
logger.error(f"🚨 AsyncIO Cancel Error in roll out {rollout_idx}", exc_info=True)
474+
failure_reason = "asyncio context cancelled"
469475
except Exception as e:
470476
logger.error(f"🚨 Error in rollout {rollout_idx}: {e}", exc_info=True)
471-
trajectory.terminated = True
472-
trajectory.termination_reason = TerminationReason.ERROR
473-
trajectory.control_plane_summary.update({"error_message": str(e)})
477+
failure_reason = str(e)
478+
finally:
479+
if failure_reason:
480+
trajectory.terminated = True
481+
trajectory.termination_reason = TerminationReason.ERROR
482+
trajectory.control_plane_summary.update({"error_message": f"{failure_reason}"})
483+
try:
484+
await envs.connection_manager.reset_session(session)
485+
except:
486+
logger.error(f"Error resetting session {session.session_id}")
487+
try:
488+
await envs.connection_manager.close_session(session)
489+
except:
490+
logger.error(f"Error closing session {session.session_id}")
474491
return trajectory
475492

476493
async def _get_control_plane_status(self, session) -> Optional[Dict[str, Any]]:

eval_protocol/mcp/session/manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ async def reset(self, session: MCPSession) -> Tuple[Any, List[Dict]]:
5858
5959
This is thread-safe and can be called from worker threads.
6060
"""
61+
await self.connection_manager.initialize_session(session)
6162
# Get available tools from MCP server
6263
tool_schemas = await self.connection_manager.discover_tools(session)
6364

eval_protocol/mcp_env.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
policy = ep.FireworksPolicy(model_id="accounts/fireworks/models/qwen3-235b-a22b")
1818
1919
# Create environments with evaluation_rows configuration
20-
envs = await ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
20+
envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
2121
2222
# Execute tool-calling rollouts
2323
evaluation_rows = await ep.rollout(envs, policy=policy, steps=512)
@@ -86,18 +86,17 @@ async def reset_mcp_sessions(envs: GeneralMCPVectorEnv):
8686
Reset mcp server sessions
8787
"""
8888
tasks = [envs.connection_manager.reset_session(session) for session in envs.sessions]
89-
await asyncio.gather(*tasks)
89+
await asyncio.gather(*tasks, return_exceptions=True)
9090

9191

92-
async def make(
92+
def make(
9393
env_spec: str,
9494
evaluation_rows: Optional[List[EvaluationRow]] = None,
9595
dataset: Optional[List[Dict]] = None,
9696
n: Optional[int] = None,
9797
seeds: Optional[List[int]] = None,
9898
model_id: str = "unknown",
9999
user_prompt_formatter: Optional[Callable] = None,
100-
reset_sessions: bool = False,
101100
) -> GeneralMCPVectorEnv:
102101
"""
103102
Create general MCP environments driven by evaluation_rows configuration.
@@ -110,20 +109,19 @@ async def make(
110109
seeds: List of seeds (for backward compatibility)
111110
model_id: Model identifier
112111
user_prompt_formatter: Optional callback for formatting user prompts
113-
reset_sessions: Whether to reset sessions before returning the environment
114112
115113
Returns:
116114
General MCP environment that works with any MCP server
117115
118116
Example:
119117
# EvaluationRow approach (preferred)
120-
envs = await ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
118+
envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)
121119
122120
# Dataset approach (backward compatibility)
123-
envs = await ep.make("http://localhost:8000/mcp", dataset=dataset)
121+
envs = ep.make("http://localhost:8000/mcp", dataset=dataset)
124122
125123
# Legacy approach (backward compatibility)
126-
envs = await ep.make("http://localhost:8000/mcp", n=10, seeds=seeds)
124+
envs = ep.make("http://localhost:8000/mcp", n=10, seeds=seeds)
127125
"""
128126
# Parse environment specification - make sure URL format is correct
129127
base_url = env_spec
@@ -236,12 +234,6 @@ async def make(
236234
sessions.append(session)
237235

238236
mcp_envs = GeneralMCPVectorEnv(sessions, dataset_rows, user_prompt_formatter)
239-
tasks = [mcp_envs.connection_manager.initialize_session(session) for session in sessions]
240-
await asyncio.gather(*tasks)
241-
242-
if reset_sessions:
243-
await reset_mcp_sessions(mcp_envs)
244-
245237
return mcp_envs
246238

247239

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from openai.types.chat import ChatCompletionContentPartTextParam, ChatCompletionMessage, ChatCompletionToolParam
99
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
1010

11-
from eval_protocol.dataset_logger import default_logger
11+
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
1212
from eval_protocol.mcp.execution.policy import LiteLLMPolicy
1313
from eval_protocol.mcp.mcp_multi_client import MCPMultiClient
1414
from eval_protocol.models import EvaluationRow, Message
@@ -20,12 +20,13 @@ class Agent:
2020
A really simple agent that calls the model until no more tool calls are needed.
2121
"""
2222

23-
def __init__(self, model: str, row: EvaluationRow, config_path: str):
23+
def __init__(self, model: str, row: EvaluationRow, config_path: str, logger: DatasetLogger):
2424
self.model = model
2525
self.evaluation_row: EvaluationRow = row
2626
self._policy = LiteLLMPolicy(model_id=model)
2727
self.mcp_client = MCPMultiClient(config_path=config_path) if config_path else None
2828
self.tools: Union[List[ChatCompletionToolParam], NotGiven] = NOT_GIVEN
29+
self.logger: DatasetLogger = logger
2930

3031
async def setup(self):
3132
if self.mcp_client:
@@ -42,7 +43,7 @@ def messages(self) -> list[Message]:
4243

4344
def append_message_and_log(self, message: Message):
4445
self.messages.append(message)
45-
default_logger.log(self.evaluation_row)
46+
self.logger.log(self.evaluation_row)
4647

4748
async def call_agent(self) -> str:
4849
"""
@@ -116,7 +117,7 @@ async def default_agent_rollout_processor(
116117
) -> List[EvaluationRow]:
117118
dataset: Dataset = []
118119
for row in rows:
119-
agent = Agent(model=config.model, row=row, config_path=config.mcp_config_path)
120+
agent = Agent(model=config.model, row=row, config_path=config.mcp_config_path, logger=config.logger)
120121
await agent.setup()
121122
await agent.call_agent()
122123
dataset.append(agent.evaluation_row)

eval_protocol/pytest/default_mcp_gym_rollout_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ async def default_mcp_gym_rollout_processor(
226226
)
227227

228228
# Create MCP environments directly from evaluation_rows
229-
envs = await ep.make(
229+
envs = ep.make(
230230
"http://localhost:9700/mcp/",
231231
evaluation_rows=rows,
232232
model_id=policy.model_id,

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import asyncio
2-
from typing import List
3-
42
import logging
53
import os
4+
from typing import List
65

7-
from eval_protocol.dataset_logger import default_logger
8-
from eval_protocol.models import EvaluationRow, Message, ChatCompletionMessageToolCall
6+
from eval_protocol.models import ChatCompletionMessageToolCall, EvaluationRow, Message
97
from eval_protocol.pytest.types import RolloutProcessorConfig
108

119

@@ -49,6 +47,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
4947

5048
# Dynamic import to avoid static dependency/lint errors if LiteLLM isn't installed yet
5149
import importlib
50+
5251
_litellm = importlib.import_module("litellm")
5352
acompletion = getattr(_litellm, "acompletion")
5453
response = await acompletion(**request_params)
@@ -79,7 +78,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
7978
]
8079

8180
row.messages = messages
82-
default_logger.log(row)
81+
config.logger.log(row)
8382
return row
8483

8584
# Process rows with bounded concurrency if configured

0 commit comments

Comments (0)