Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions eval_protocol/pytest/default_agent_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from eval_protocol.mcp.execution.policy import LiteLLMPolicy
from eval_protocol.mcp.mcp_multi_client import MCPMultiClient
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.types import RolloutProcessorConfig
from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig


class Agent:
Expand Down Expand Up @@ -73,8 +73,13 @@ def _get_content_from_tool_result(self, tool_result: CallToolResult) -> str:
return first_content.text


async def default_agent_rollout_processor(
    rows: List[EvaluationRow], config: RolloutProcessorConfig
) -> List[EvaluationRow]:
    """Run an MCP agent rollout for each input row.

    For every row, an Agent is constructed from that row's messages and the
    shared config, set up, and run to completion; the resulting transcript is
    collected into a new EvaluationRow. Rows are processed sequentially.

    Args:
        rows: Input dataset; each row's messages seed one agent conversation.
        config: Shared rollout settings (model, MCP config path, input params).

    Returns:
        One EvaluationRow per input row, in input order, containing the full
        agent transcript. All non-message fields of the source row
        (e.g. ground_truth, input_metadata) are carried over unchanged.
    """
    dataset: Dataset = []
    for row in rows:
        agent = Agent(model=config.model, initial_messages=row.messages, config_path=config.mcp_config_path)
        await agent.setup()
        await agent.call_agent()
        # Preserve every field of the source row except messages, which are
        # replaced by the agent's transcript — consistent with
        # default_single_turn_rollout_processor.
        dataset.append(EvaluationRow(messages=agent.messages, **row.model_dump(exclude={"messages"})))
    return dataset
6 changes: 3 additions & 3 deletions eval_protocol/pytest/default_no_op_rollout_process.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import List

from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.types import ModelParam, RolloutProcessorConfig
from eval_protocol.pytest.types import RolloutProcessorConfig


def default_no_op_rollout_processor(rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[EvaluationRow]:
    """
    Simply passes input dataset through to the test function. This can be useful
    if you want to run the rollout yourself.

    Args:
        rows: Input dataset, returned as-is.
        config: Rollout settings; unused here, accepted to satisfy the
            RolloutProcessor interface.

    Returns:
        The input rows, unchanged.
    """
    return rows
54 changes: 33 additions & 21 deletions eval_protocol/pytest/default_single_turn_rollout_process.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,42 @@
import asyncio
from typing import List

from openai import OpenAI
from openai import AsyncOpenAI

from eval_protocol.auth import get_fireworks_api_base, get_fireworks_api_key
from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message
from eval_protocol.pytest.types import ModelParam, RolloutProcessorConfig
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.types import RolloutProcessorConfig


async def default_single_turn_rollout_processor(
    rows: List[EvaluationRow], config: RolloutProcessorConfig
) -> List[EvaluationRow]:
    """Generate a single assistant response per row, concurrently.

    Each row's messages are sent as one chat-completion request to a Fireworks
    model; the assistant reply is appended to that row's conversation. All
    requests are issued concurrently via asyncio.gather.

    Args:
        rows: Input dataset; every row must have a non-empty messages list.
        config: Rollout settings (model name and extra completion params).

    Returns:
        One EvaluationRow per input row, in input order, with the assistant
        reply appended. All non-message fields of the source row are preserved.

    Raises:
        ValueError: If any row has an empty messages list.
    """
    api_key = get_fireworks_api_key()
    api_base = get_fireworks_api_base()
    client = AsyncOpenAI(api_key=api_key, base_url=f"{api_base}/inference/v1")

    async def process_row(row: EvaluationRow) -> EvaluationRow:
        """Run one chat completion for a single row asynchronously."""
        if len(row.messages) == 0:
            raise ValueError("Messages is empty. Please provide a non-empty dataset")

        messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]

        response = await client.chat.completions.create(
            model=config.model, messages=messages_payload, **config.input_params
        )
        # Some responses may have empty content (e.g. tool-only turns);
        # normalize None to "" so Message validation doesn't fail.
        assistant_content = response.choices[0].message.content or ""
        messages = list(row.messages) + [Message(role="assistant", content=assistant_content)]

        # Preserve all non-message fields (ground_truth, input_metadata, ...).
        return EvaluationRow(
            messages=messages,
            **row.model_dump(exclude={"messages"}),
        )

    # Fan all rows out concurrently; gather preserves input order.
    tasks = [process_row(row) for row in rows]
    dataset = await asyncio.gather(*tasks)

    return dataset
5 changes: 1 addition & 4 deletions eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,8 @@ def wrapper_body(**kwargs):
model=model_name,
input_params=kwargs.get("input_params") or {},
mcp_config_path=mcp_config_path or "",
initial_messages=kwargs.get("input_messages") if "input_messages" in kwargs else [],
)
for row in data:
processed: List[EvaluationRow] = execute_function(rollout_processor, row=row, config=config)
input_dataset.extend(processed)
input_dataset = execute_function(rollout_processor, rows=data, config=config)

all_results: List[EvaluationRow] = []
for _ in range(num_runs):
Expand Down
3 changes: 1 addition & 2 deletions eval_protocol/pytest/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class RolloutProcessorConfig:
model: ModelParam
input_params: InputParam # optional input parameters for inference
mcp_config_path: str # for agent rollout processor
initial_messages: list[Message] # for agent rollout processor


# Signature shared by all rollout processors: take the full input dataset plus
# a shared config, return the processed dataset.
RolloutProcessor = Callable[[List[EvaluationRow], RolloutProcessorConfig], List[EvaluationRow]]
Loading