async def default_agent_rollout_processor(
    rows: List[EvaluationRow], config: RolloutProcessorConfig
) -> List[EvaluationRow]:
    """Run an MCP agent rollout for every input row, concurrently.

    Each row's messages seed a fresh ``Agent`` (model and MCP config taken
    from ``config``); the agent's final transcript becomes the output row's
    messages.  Rollouts run concurrently via ``asyncio.gather`` — matching
    the single-turn rollout processor — and output order matches input
    order.  An empty ``rows`` yields an empty dataset.

    Args:
        rows: Input dataset; each row supplies the agent's initial messages
            and its ``ground_truth`` is carried through to the output row.
        config: Rollout settings (``model``, ``mcp_config_path``, ...).

    Returns:
        One ``EvaluationRow`` per input row, in input order.
    """
    import asyncio  # local import: avoids touching the module's import surface

    async def _rollout(row: EvaluationRow) -> EvaluationRow:
        # One Agent per row: no mutable state is shared between rollouts.
        agent = Agent(model=config.model, initial_messages=row.messages, config_path=config.mcp_config_path)
        await agent.setup()
        await agent.call_agent()
        return EvaluationRow(messages=agent.messages, ground_truth=row.ground_truth)

    dataset: Dataset = list(await asyncio.gather(*(_rollout(row) for row in rows)))
    return dataset
def default_no_op_rollout_processor(rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[EvaluationRow]:
    """Pass the input dataset through to the test function unchanged.

    Useful when you want to run the rollout yourself inside the test.

    Args:
        rows: Input dataset; returned as-is (the same list object).
        config: Unused; present to satisfy the ``RolloutProcessor`` interface.

    Returns:
        The ``rows`` list, untouched.
    """
    return rows


async def default_single_turn_rollout_processor(
    rows: List[EvaluationRow], config: RolloutProcessorConfig
) -> List[EvaluationRow]:
    """Generate one assistant turn per row from a Fireworks model, concurrently.

    Each row's messages are sent as a single chat-completion request; the
    assistant reply is appended to the row's message list.  All rows are
    dispatched concurrently with ``asyncio.gather`` and results preserve
    input order.  The HTTP client is always closed, even on failure.

    Args:
        rows: Input dataset; every row must have at least one message.
        config: Supplies ``model`` and extra ``input_params`` for inference.

    Returns:
        One ``EvaluationRow`` per input row, in input order, with the
        assistant reply appended and all other fields carried over.

    Raises:
        ValueError: If any row has an empty message list.
    """
    api_key = get_fireworks_api_key()
    api_base = get_fireworks_api_base()
    client = AsyncOpenAI(api_key=api_key, base_url=f"{api_base}/inference/v1")

    async def process_row(row: EvaluationRow) -> EvaluationRow:
        """Run a single chat completion for one row."""
        if not row.messages:
            raise ValueError("Messages is empty. Please provide a non-empty dataset")

        messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
        response = await client.chat.completions.create(
            model=config.model, messages=messages_payload, **config.input_params
        )
        assistant_content = response.choices[0].message.content or ""
        # model_copy keeps every other field (ground_truth, metadata, ...) as
        # typed models instead of round-tripping them through plain dicts.
        new_messages = list(row.messages) + [Message(role="assistant", content=assistant_content)]
        return row.model_copy(update={"messages": new_messages})

    try:
        # gather preserves input order in its result list.
        return list(await asyncio.gather(*(process_row(row) for row in rows)))
    finally:
        # The async client holds an HTTP connection pool; close it explicitly.
        await client.close()
from typing import Awaitable, Union  # placed beside the alias: only it needs these

# A rollout processor maps the input dataset to the processed dataset handed
# to the evaluation.  Processors may be plain functions (the no-op processor)
# or coroutine functions (the agent and single-turn processors), so the
# declared result covers both the list itself and an awaitable of it; the
# caller resolves either form (see execute_function in evaluation_test).
RolloutProcessor = Callable[
    [List[EvaluationRow], RolloutProcessorConfig],
    Union[List[EvaluationRow], Awaitable[List[EvaluationRow]]],
]