From b749f9cb16fdb3e6fe933a86f2be7c2df95b5a84 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Sun, 3 Aug 2025 13:34:40 -0700 Subject: [PATCH 1/2] refactor rollout processor to accept entire input dataset --- .../pytest/default_agent_rollout_processor.py | 17 ++++++---- .../pytest/default_no_op_rollout_process.py | 6 ++-- .../default_single_turn_rollout_process.py | 34 +++++++++++-------- eval_protocol/pytest/evaluation_test.py | 5 +-- eval_protocol/pytest/types.py | 3 +- 5 files changed, 36 insertions(+), 29 deletions(-) diff --git a/eval_protocol/pytest/default_agent_rollout_processor.py b/eval_protocol/pytest/default_agent_rollout_processor.py index bd68facf..ffdd0c7b 100644 --- a/eval_protocol/pytest/default_agent_rollout_processor.py +++ b/eval_protocol/pytest/default_agent_rollout_processor.py @@ -9,7 +9,7 @@ from eval_protocol.mcp.execution.policy import LiteLLMPolicy from eval_protocol.mcp.mcp_multi_client import MCPMultiClient from eval_protocol.models import EvaluationRow, Message -from eval_protocol.pytest.types import RolloutProcessorConfig +from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig class Agent: @@ -73,8 +73,13 @@ def _get_content_from_tool_result(self, tool_result: CallToolResult) -> str: return first_content.text -async def default_agent_rollout_processor(row: EvaluationRow, config: RolloutProcessorConfig) -> List[EvaluationRow]: - agent = Agent(model=config.model, initial_messages=config.initial_messages, config_path=config.mcp_config_path) - await agent.setup() - await agent.call_agent() - return [EvaluationRow(messages=agent.messages)] +async def default_agent_rollout_processor( + rows: List[EvaluationRow], config: RolloutProcessorConfig +) -> List[EvaluationRow]: + dataset: Dataset = [] + for row in rows: + agent = Agent(model=config.model, initial_messages=row.messages, config_path=config.mcp_config_path) + await agent.setup() + await agent.call_agent() + dataset.append(EvaluationRow(messages=agent.messages, ground_truth=row.ground_truth)) + return dataset diff --git a/eval_protocol/pytest/default_no_op_rollout_process.py b/eval_protocol/pytest/default_no_op_rollout_process.py index 99c4e875..bae733c3 100644 --- a/eval_protocol/pytest/default_no_op_rollout_process.py +++ b/eval_protocol/pytest/default_no_op_rollout_process.py @@ -1,12 +1,12 @@ from typing import List from eval_protocol.models import EvaluationRow -from eval_protocol.pytest.types import ModelParam, RolloutProcessorConfig +from eval_protocol.pytest.types import RolloutProcessorConfig -def default_no_op_rollout_processor(row: EvaluationRow, config: RolloutProcessorConfig) -> List[EvaluationRow]: +def default_no_op_rollout_processor(rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[EvaluationRow]: """ Simply passes input dataset through to the test function. This can be useful if you want to run the rollout yourself. """ - return [row] + return rows diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index 4b2a0877..ff349b71 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -4,27 +4,33 @@ from eval_protocol.auth import get_fireworks_api_base, get_fireworks_api_key from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message -from eval_protocol.pytest.types import ModelParam, RolloutProcessorConfig +from eval_protocol.pytest.types import Dataset, ModelParam, RolloutProcessorConfig -def default_single_turn_rollout_processor(row: EvaluationRow, config: RolloutProcessorConfig) -> List[EvaluationRow]: +def default_single_turn_rollout_processor( + rows: List[EvaluationRow], config: RolloutProcessorConfig +) -> List[EvaluationRow]: """Generate a single response from a Fireworks model.""" api_key = get_fireworks_api_key() api_base = get_fireworks_api_base() client = OpenAI(api_key=api_key, base_url=f"{api_base}/inference/v1") - if len(row.messages) == 0: - raise ValueError("Messages is empty. Please provide a non-empty dataset") + dataset: Dataset = [] + for row in rows: + if len(row.messages) == 0: + raise ValueError("Messages is empty. Please provide a non-empty dataset") - messages_payload = [{"role": m.role, "content": m.content} for m in row.messages] + messages_payload = [{"role": m.role, "content": m.content} for m in row.messages] - response = client.chat.completions.create(model=config.model, messages=messages_payload, **config.input_params) - assistant_content = response.choices[0].message.content or "" - messages = list(row.messages) + [Message(role="assistant", content=assistant_content)] - processed = EvaluationRow( - messages=messages, - ground_truth=row.ground_truth, - input_metadata=InputMetadata(completion_params=CompletionParams(model=config.model)), - ) - return [processed] + response = client.chat.completions.create(model=config.model, messages=messages_payload, **config.input_params) + assistant_content = response.choices[0].message.content or "" + messages = list(row.messages) + [Message(role="assistant", content=assistant_content)] + processed = EvaluationRow( + messages=messages, + ground_truth=row.ground_truth, + input_metadata=InputMetadata(completion_params=CompletionParams(model=config.model)), + ) + + dataset.append(processed) + return dataset diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 8cfd2e1a..c4b222d2 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -182,11 +182,8 @@ def wrapper_body(**kwargs): model=model_name, input_params=kwargs.get("input_params") or {}, mcp_config_path=mcp_config_path or "", - initial_messages=kwargs.get("input_messages") if "input_messages" in kwargs else [], ) - for row in data: - processed: List[EvaluationRow] = execute_function(rollout_processor, row=row, config=config) - input_dataset.extend(processed) + input_dataset = execute_function(rollout_processor, rows=data, config=config) all_results: List[EvaluationRow] = [] for _ in range(num_runs): diff --git a/eval_protocol/pytest/types.py b/eval_protocol/pytest/types.py index a1e124c8..57bef1cc 100644 --- a/eval_protocol/pytest/types.py +++ b/eval_protocol/pytest/types.py @@ -39,7 +39,6 @@ class RolloutProcessorConfig: model: ModelParam input_params: InputParam # optional input parameters for inference mcp_config_path: str # for agent rollout processor - initial_messages: list[Message] # for agent rollout processor -RolloutProcessor = Callable[[EvaluationRow, RolloutProcessorConfig], List[EvaluationRow]] +RolloutProcessor = Callable[[List[EvaluationRow], RolloutProcessorConfig], List[EvaluationRow]] From aa31b8245789e31296aa0056527867181f5e39cc Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Sun, 3 Aug 2025 13:38:02 -0700 Subject: [PATCH 2/2] run single turn rollouts in parallel --- .../default_single_turn_rollout_process.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index ff349b71..dbc8fb68 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -1,36 +1,42 @@ +import asyncio from typing import List -from openai import OpenAI +from openai import AsyncOpenAI from eval_protocol.auth import get_fireworks_api_base, get_fireworks_api_key -from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message -from eval_protocol.pytest.types import Dataset, ModelParam, RolloutProcessorConfig +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest.types import RolloutProcessorConfig -def default_single_turn_rollout_processor( +async def default_single_turn_rollout_processor( rows: List[EvaluationRow], config: RolloutProcessorConfig ) -> List[EvaluationRow]: - """Generate a single response from a Fireworks model.""" + """Generate a single response from a Fireworks model concurrently.""" api_key = get_fireworks_api_key() api_base = get_fireworks_api_base() - client = OpenAI(api_key=api_key, base_url=f"{api_base}/inference/v1") + client = AsyncOpenAI(api_key=api_key, base_url=f"{api_base}/inference/v1") - dataset: Dataset = [] - for row in rows: + async def process_row(row: EvaluationRow) -> EvaluationRow: + """Process a single row asynchronously.""" if len(row.messages) == 0: raise ValueError("Messages is empty. Please provide a non-empty dataset") messages_payload = [{"role": m.role, "content": m.content} for m in row.messages] - response = client.chat.completions.create(model=config.model, messages=messages_payload, **config.input_params) + response = await client.chat.completions.create( + model=config.model, messages=messages_payload, **config.input_params + ) assistant_content = response.choices[0].message.content or "" messages = list(row.messages) + [Message(role="assistant", content=assistant_content)] - processed = EvaluationRow( + + return EvaluationRow( messages=messages, - ground_truth=row.ground_truth, - input_metadata=InputMetadata(completion_params=CompletionParams(model=config.model)), + **row.model_dump(exclude={"messages"}), ) - dataset.append(processed) + # Process all rows concurrently + tasks = [process_row(row) for row in rows] + dataset = await asyncio.gather(*tasks) + return dataset