diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index 5664e5ac..e0d101a7 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -12,7 +12,7 @@ import threading import time from dataclasses import asdict -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, List, Optional, Union import anyio from openai.types import CompletionUsage @@ -43,7 +43,7 @@ async def execute_rollouts( openai_format_log_file: Optional[str] = None, max_concurrent_rollouts: int = 8, evaluation_rows: Optional[List[EvaluationRow]] = None, - ) -> List[EvaluationRow]: + ) -> AsyncIterator[EvaluationRow]: """ Execute general rollouts using tool calling interface with automatic record/playback. @@ -66,7 +66,7 @@ async def execute_rollouts( - Set and file exists: Playback mode (uses recorded data) Returns: - List of EvaluationRow objects with unified evaluation data format + AsyncIterator of EvaluationRow objects with unified evaluation data format """ start_time = time.time() @@ -92,96 +92,77 @@ async def execute_rollouts( logger.info(f"๐Ÿงต Starting {envs.n} rollouts with max {max_concurrent_rollouts} concurrent threads...") - results = {} + if evaluation_rows is None: + evaluation_rows = [EvaluationRow(messages=[], input_metadata=InputMetadata()) for _ in range(envs.n)] + + shared_tool_schema = envs.tool_schemas semaphore = asyncio.Semaphore(max_concurrent_rollouts) async def _execute_with_semaphore(idx): async with semaphore: - result = await self._execute_rollout( + trajectory = await self._execute_rollout( envs, policy, idx, steps, openai_logger, recording_mode, playback_mode, start_time ) - return result - - tasks = [_execute_with_semaphore(i) for i in range(envs.n)] - # exceptions will be try catched inside single _execute_rollout - trajectories = await asyncio.gather(*tasks) - - # Calculate durations - total_duration = time.time() - start_time - for trajectory in trajectories: - trajectory.duration = total_duration - - shared_tool_schema = envs.tool_schemas - - # Enhanced reporting with control plane info - successful = sum(1 for traj in trajectories if traj.total_reward > 0) - terminated_by_control_plane = sum( - 1 - for traj in trajectories - if traj.control_plane_summary.get("termination_reason") == "control_plane_signal" - ) + # Convert trajectory to EvaluationRow immediately + evaluation_row = evaluation_rows[idx] + + # Handle multimodal content by extracting text from complex content structures + messages = [] + for msg in trajectory.conversation_history: + # Create a copy to avoid modifying the original + msg_dict = dict(msg) + + # Handle multimodal content (list of content blocks) by extracting text + if isinstance(msg_dict.get("content"), list): + text_content = None + for content_block in msg_dict["content"]: + if isinstance(content_block, dict) and content_block.get("type") == "text": + text_content = content_block.get("text") + break + msg_dict["content"] = text_content or "" + + messages.append(Message.model_validate(msg_dict)) + + evaluation_row.messages = messages + evaluation_row.tools = shared_tool_schema + evaluation_row.usage = CompletionUsage(**trajectory.usage) + evaluation_row.input_metadata.completion_params = CompletionParams( + model=policy.model_id, + temperature=getattr(policy, "temperature", None), + max_tokens=getattr(policy, "max_tokens", None), + max_tool_calls=getattr(policy, "max_tools_per_turn", None), + ) - logger.info(f"๐Ÿ“Š Rollout complete: {successful}/{len(trajectories)} reached goal") - logger.info(f"๐ŸŽ›๏ธ Control plane terminations: {terminated_by_control_plane}/{len(trajectories)}") - logger.info(f"โฑ๏ธ Total duration: {total_duration:.2f}s") - logger.info(f"๐Ÿงต Used {max_concurrent_rollouts} concurrent threads") + if trajectory.terminated: + if trajectory.termination_reason == TerminationReason.ERROR: + evaluation_row.rollout_status.status = "error" + evaluation_row.rollout_status.error_message = trajectory.control_plane_summary.get( + "error_message", None + ) + else: + evaluation_row.rollout_status.status = "finished" + evaluation_row.rollout_status.termination_reason = trajectory.termination_reason + else: + evaluation_row.rollout_status.status = "running" - # Print log file locations if created - if openai_format_log_file: - logger.info(f"๐Ÿ’ฌ OpenAI format log: {openai_format_log_file}") - if recording_mode: - logger.info(f"๐Ÿ“ Recorded trajectory: {playback_file}") - # Add note about control plane separation - logger.info(f"๐ŸŽ›๏ธ Trajectories include control plane separation") + return evaluation_row - # Convert trajectories to unified EvaluationRow format. If no evaluation_rows are provided, create empty ones for backwards compatibility. - if evaluation_rows is None: - evaluation_rows = [EvaluationRow(messages=[], input_metadata=InputMetadata()) for _ in trajectories] - - for idx, trajectory in enumerate(trajectories): - # Handle multimodal content by extracting text from complex content structures - messages = [] - for msg in trajectory.conversation_history: - # Create a copy to avoid modifying the original - msg_dict = dict(msg) - - # Handle multimodal content (list of content blocks) by extracting text - if isinstance(msg_dict.get("content"), list): - text_content = None - for content_block in msg_dict["content"]: - if isinstance(content_block, dict) and content_block.get("type") == "text": - text_content = content_block.get("text") - break - msg_dict["content"] = text_content or "" - - messages.append(Message.model_validate(msg_dict)) - - evaluation_rows[idx].messages = messages - # evaluation_rows[idx].input_metadata.row_id = envs.dataset_rows[idx].id - # evaluation_rows[idx].input_metadata.dataset_info = asdict(envs.dataset_rows[idx]) - evaluation_rows[idx].tools = shared_tool_schema - evaluation_rows[idx].usage = CompletionUsage(**trajectory.usage) - evaluation_rows[idx].input_metadata.completion_params = CompletionParams( - model=policy.model_id, - temperature=getattr(policy, "temperature", None), - max_tokens=getattr(policy, "max_tokens", None), - max_tool_calls=getattr(policy, "max_tools_per_turn", None), - ) - if trajectory.terminated: - if trajectory.termination_reason == TerminationReason.ERROR: - evaluation_rows[idx].rollout_status.status = "error" - evaluation_rows[idx].rollout_status.termination_reason = trajectory.control_plane_summary.get( - "error_message", None - ) - else: - evaluation_rows[idx].rollout_status.status = "finished" - evaluation_rows[idx].rollout_status.termination_reason = trajectory.termination_reason - else: - evaluation_rows[idx].rollout_status.status = "running" + # Create all tasks + tasks = [asyncio.create_task(_execute_with_semaphore(i)) for i in range(envs.n)] - return evaluation_rows + # Yield results as they complete (note that they're not necessarily in original order) + try: + for task in asyncio.as_completed(tasks): + try: + yield await task + except Exception: + logger.exception("Error processing rollout") + finally: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) async def _execute_rollout( self, diff --git a/eval_protocol/mcp_env.py b/eval_protocol/mcp_env.py index 5ec67658..5d930a4e 100644 --- a/eval_protocol/mcp_env.py +++ b/eval_protocol/mcp_env.py @@ -41,11 +41,13 @@ """ import asyncio +import hashlib +import json # For legacy compatibility - import the facade functions import logging import random -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union # Import all functionality from the new modular components from .mcp.execution.manager import ExecutionManager @@ -53,9 +55,6 @@ from .mcp.session.manager import GeneralMCPVectorEnv from .models import EvaluationRow from .types import DatasetRow, MCPSession, MCPToolCall -import asyncio -import hashlib -import json logger = logging.getLogger(__name__) @@ -247,7 +246,7 @@ async def rollout( steps: int = 512, openai_format_log_file: Optional[str] = None, max_concurrent_rollouts: int = 8, -) -> List[EvaluationRow]: +) -> AsyncIterator[EvaluationRow]: """ Execute general rollouts using tool calling interface with automatic record/playback. @@ -307,9 +306,10 @@ async def rollout( # Use the new ExecutionManager for execution execution_manager = ExecutionManager() - return await execution_manager.execute_rollouts( + async for evaluation_row in execution_manager.execute_rollouts( envs, policy, steps, openai_format_log_file, max_concurrent_rollouts, evaluation_rows - ) + ): + yield evaluation_row async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]: diff --git a/eval_protocol/pytest/__init__.py b/eval_protocol/pytest/__init__.py index ce881ccc..2d2576d6 100644 --- a/eval_protocol/pytest/__init__.py +++ b/eval_protocol/pytest/__init__.py @@ -1,12 +1,14 @@ from .default_agent_rollout_processor import default_agent_rollout_processor +from .default_dataset_adapter import default_dataset_adapter +from .default_mcp_gym_rollout_processor import default_mcp_gym_rollout_processor from .default_no_op_rollout_process import default_no_op_rollout_processor from .default_single_turn_rollout_process import default_single_turn_rollout_processor from .evaluation_test import evaluation_test from .types import RolloutProcessor, RolloutProcessorConfig -from .default_dataset_adapter import default_dataset_adapter __all__ = [ "default_agent_rollout_processor", + "default_mcp_gym_rollout_processor", "default_no_op_rollout_processor", "default_single_turn_rollout_processor", "default_dataset_adapter", diff --git a/eval_protocol/pytest/default_agent_rollout_processor.py b/eval_protocol/pytest/default_agent_rollout_processor.py index bd7c62c2..6a158b54 100644 --- a/eval_protocol/pytest/default_agent_rollout_processor.py +++ b/eval_protocol/pytest/default_agent_rollout_processor.py @@ -1,7 +1,8 @@ import asyncio import json +import logging import os -from typing import Any, List, Optional, Union +from typing import Any, AsyncIterator, List, Optional, Union from mcp.types import CallToolResult, TextContent from openai import NOT_GIVEN, NotGiven @@ -14,6 +15,8 @@ from eval_protocol.models import EvaluationRow, Message from eval_protocol.pytest.types import Dataset, RolloutProcessorConfig +logger = logging.getLogger(__name__) + class Agent: """ @@ -114,13 +117,42 @@ def _get_content_from_tool_result(self, tool_result: CallToolResult) -> List[Tex async def default_agent_rollout_processor( rows: List[EvaluationRow], config: RolloutProcessorConfig -) -> List[EvaluationRow]: - dataset: Dataset = [] - for row in rows: +) -> AsyncIterator[EvaluationRow]: + """Process agent rollouts with bounded concurrency and yield as they complete.""" + + max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8 + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_row(row: EvaluationRow) -> EvaluationRow: + """Process a single row with agent rollout.""" agent = Agent(model=config.model, row=row, config_path=config.mcp_config_path, logger=config.logger) - await agent.setup() - await agent.call_agent() - dataset.append(agent.evaluation_row) - if agent.mcp_client: - await agent.mcp_client.cleanup() - return dataset + try: + await agent.setup() + await agent.call_agent() + return agent.evaluation_row + finally: + if agent.mcp_client: + await agent.mcp_client.cleanup() + + async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: + async with semaphore: + try: + return await process_row(r) + except Exception as e: + logger.exception(f"Error processing row {r.input_metadata.row_id}: {e}") + return r + + # Create all tasks + tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] + + # Yield results as they complete (note that they're not necessarily in original order) + try: + for task in asyncio.as_completed(tasks): + try: + yield await task + except Exception: + logger.exception("Error processing row") + finally: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) diff --git a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py index 5037cbad..de9d8ca1 100644 --- a/eval_protocol/pytest/default_mcp_gym_rollout_processor.py +++ b/eval_protocol/pytest/default_mcp_gym_rollout_processor.py @@ -6,7 +6,7 @@ import subprocess import time from pathlib import Path -from typing import List, Optional +from typing import AsyncIterator, List, Optional import eval_protocol as ep from eval_protocol.models import EvaluationRow, Message @@ -194,22 +194,19 @@ def __exit__(self, exc_type, exc_val, exc_tb): async def default_mcp_gym_rollout_processor( rows: List[EvaluationRow], config: RolloutProcessorConfig -) -> List[EvaluationRow]: +) -> AsyncIterator[EvaluationRow]: """ Rollout processor for tau bench environments. - This processor starts an MCP server, creates tau bench environments, and runs rollouts - using the eval_protocol framework, following the pattern from test_tau2_e2e.py. - + using the eval_protocol framework, yielding results as they complete. Args: rows: List of EvaluationRow objects containing messages and dataset info in input_metadata config: RolloutProcessorConfig with model and other parameters - Returns: - List of EvaluationRow objects with completed conversations + AsyncIterator of EvaluationRow objects with completed conversations """ if config.server_script_path is None: raise ValueError("server_script_path is required for default_mcp_gym_rollout_processor") @@ -233,15 +230,14 @@ async def default_mcp_gym_rollout_processor( ) # Run rollout with environments and policy - evaluation_rows = await ep.rollout( + async for evaluation_row in ep.rollout( envs, policy=policy, evaluation_rows=rows, steps=config.steps, max_concurrent_rollouts=config.max_concurrent_rollouts, - ) - - return evaluation_rows + ): + yield evaluation_row finally: # Always clean up the server diff --git a/eval_protocol/pytest/default_no_op_rollout_process.py b/eval_protocol/pytest/default_no_op_rollout_process.py index bae733c3..47cb17be 100644 --- a/eval_protocol/pytest/default_no_op_rollout_process.py +++ b/eval_protocol/pytest/default_no_op_rollout_process.py @@ -1,12 +1,15 @@ -from typing import List +from typing import AsyncIterator, List from eval_protocol.models import EvaluationRow from eval_protocol.pytest.types import RolloutProcessorConfig -def default_no_op_rollout_processor(rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[EvaluationRow]: +async def default_no_op_rollout_processor( + rows: List[EvaluationRow], config: RolloutProcessorConfig +) -> AsyncIterator[EvaluationRow]: """ Simply passes input dataset through to the test function. This can be useful if you want to run the rollout yourself. """ - return rows + for row in rows: + yield row diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index 95613ebc..424347cd 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -1,15 +1,23 @@ import asyncio import logging import os -from typing import List +import time +from typing import AsyncIterator, List -from eval_protocol.models import ChatCompletionMessageToolCall, EvaluationRow, Message +import litellm +from litellm import acompletion +from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall + +from eval_protocol.dataset_logger import default_logger +from eval_protocol.models import EvaluationRow, Message from eval_protocol.pytest.types import RolloutProcessorConfig +logger = logging.getLogger(__name__) + async def default_single_turn_rollout_processor( rows: List[EvaluationRow], config: RolloutProcessorConfig -) -> List[EvaluationRow]: +) -> AsyncIterator[EvaluationRow]: """Generate a single response from any supported model provider using LiteLLM.""" # Quiet LiteLLM logs in test runs unless user overrode @@ -41,7 +49,10 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: if isinstance(config.input_params, dict): if "reasoning_effort" in config.input_params: effort_val = str(config.input_params["reasoning_effort"]) # flat shape - elif isinstance(config.input_params.get("extra_body"), dict) and "reasoning_effort" in config.input_params["extra_body"]: + elif ( + isinstance(config.input_params.get("extra_body"), dict) + and "reasoning_effort" in config.input_params["extra_body"] + ): # Accept if user passed it directly inside extra_body effort_val = str(config.input_params["extra_body"]["reasoning_effort"]) # already in extra_body @@ -89,10 +100,10 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: ] row.messages = messages - config.logger.log(row) + default_logger.log(row) return row - # Process rows with bounded concurrency if configured + # Process rows with bounded concurrency and yield as they complete max_concurrent = getattr(config, "max_concurrent_rollouts", 8) or 8 semaphore = asyncio.Semaphore(max_concurrent) @@ -103,7 +114,17 @@ async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow: except Exception as e: return r - tasks = [_sem_wrapper(row) for row in rows] - dataset = list(await asyncio.gather(*tasks)) + # Create all tasks + tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows] - return dataset + # Yield results as they complete (note that they're not necessarily in original order) + try: + for task in asyncio.as_completed(tasks): + try: + yield await task + except Exception: + logger.exception("Error processing row") + finally: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index f1d9af50..81856ff6 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -1,8 +1,13 @@ +import asyncio import copy import inspect +import json import math import os +import pathlib +import re import statistics +import time from typing import Any, Callable, Dict, List, Literal, Optional, Union import pytest @@ -173,7 +178,7 @@ def decorator( if sig.return_annotation is not List[EvaluationRow]: raise ValueError("In batch mode, your eval function must return a list of EvaluationRow instances") - def execute_with_params( + async def execute_with_params( test_func: TestFunction, processed_row: EvaluationRow | None = None, processed_dataset: List[EvaluationRow] | None = None, @@ -190,7 +195,12 @@ def execute_with_params( if "rows" in evaluation_test_kwargs: raise ValueError("'rows' is a reserved parameter for the evaluation function") kwargs.update(evaluation_test_kwargs) - return execute_function(test_func, **kwargs) + + # Handle both sync and async test functions + if asyncio.iscoroutinefunction(test_func): + return await test_func(**kwargs) + else: + return test_func(**kwargs) # Calculate all possible combinations of parameters def _parse_ep_max_rows(default_value: int | None) -> int | None: @@ -300,7 +310,7 @@ def create_wrapper_with_signature() -> Callable: # Create the function body that will be used invocation_id = generate_id() - def wrapper_body(**kwargs): + async def wrapper_body(**kwargs): model_name = kwargs["model"] eval_metadata = None all_results: List[List[EvaluationRow]] = [[] for _ in range(num_runs)] @@ -423,26 +433,40 @@ def _log_eval_error( for row in fresh_dataset: active_logger.log(row) - processed_dataset = execute_function(rollout_processor, rows=fresh_dataset, config=config) + rollout_result = rollout_processor(fresh_dataset, config) if mode == "pointwise": - # Pointwise mode: apply the evaluator function to each row - for row in processed_dataset: - result = execute_with_params( - test_func, - processed_row=row, - evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, - ) - if result is None or not isinstance(result, EvaluationRow): - raise ValueError( - f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test." + # Pointwise mode, rollouts will return as they complete so we can pipeline evaluation_test execution + semaphore = asyncio.Semaphore(max_concurrent_rollouts) + tasks = [] + + async def _execute_with_semaphore(row): + async with semaphore: + result = await execute_with_params( + test_func, + processed_row=row, + evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, ) - all_results[i].append(result) + if result is None or not isinstance(result, EvaluationRow): + raise ValueError( + f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test." + ) + return result + + async for row in rollout_processor(fresh_dataset, config): + tasks.append(asyncio.create_task(_execute_with_semaphore(row))) + + all_results[i] = await asyncio.gather(*tasks) + else: - # Batch mode: call the test function with the full dataset - results = execute_with_params( + # Batch mode: collect all results first, then evaluate (no pipelining) + input_dataset = [] + async for row in rollout_result: + input_dataset.append(row) + + results = await execute_with_params( test_func, - processed_dataset=processed_dataset, + processed_dataset=input_dataset, evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {}, ) if results is None: @@ -568,10 +592,6 @@ def _log_eval_error( ) # As per project convention, avoid printing per-metric CI lines to reduce noise if summary_path: - import json - import pathlib - import re - import time def _sanitize_filename(text: str) -> str: safe = re.sub(r"[^A-Za-z0-9._-]+", "-", text.strip()) @@ -667,6 +687,7 @@ def _extract_effort_tag(params: dict) -> str | None: # Create the pytest wrapper pytest_wrapper = create_wrapper_with_signature() pytest_wrapper = pytest.mark.parametrize(test_param_names, param_tuples)(pytest_wrapper) + pytest_wrapper = pytest.mark.asyncio(pytest_wrapper) def create_dual_mode_wrapper() -> Callable: """ @@ -687,66 +708,39 @@ def create_dual_mode_wrapper() -> Callable: # Check if the test function is async is_async = asyncio.iscoroutinefunction(test_func) - if is_async: - - async def dual_mode_wrapper(*args, **kwargs): - # Check if this is a direct call with the expected signature - if mode == "pointwise": - # For pointwise mode, check if called with a single row argument - if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs: - return await test_func(row=args[0]) - else: - # For batch mode, check if called with rows argument - if ( - len(args) == 1 - and isinstance(args[0], list) - and all(isinstance(r, EvaluationRow) for r in args[0]) - and not kwargs - ): - return await test_func(rows=args[0]) - # Also check if called with keyword argument 'rows' - if ( - len(args) == 0 - and "rows" in kwargs - and isinstance(kwargs["rows"], list) - and all(isinstance(r, EvaluationRow) for r in kwargs["rows"]) - ): - return await test_func(**kwargs) - - # If not a direct call, use the pytest wrapper - return pytest_wrapper(*args, **kwargs) - - else: - - def dual_mode_wrapper(*args, **kwargs): - # Check if this is a direct call with the expected signature - if mode == "pointwise": - # For pointwise mode, check if called with a single row argument - if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs: - return test_func(row=args[0]) - - if len(args) == 0 and "row" in kwargs and isinstance(kwargs["row"], EvaluationRow): - return test_func(**kwargs) - else: - # For batch mode, check if called with rows argument - if ( - len(args) == 1 - and isinstance(args[0], list) - and all(isinstance(r, EvaluationRow) for r in args[0]) - and not kwargs - ): - return test_func(rows=args[0]) - # Also check if called with keyword argument 'rows' - if ( - len(args) == 0 - and "rows" in kwargs - and isinstance(kwargs["rows"], list) - and all(isinstance(r, EvaluationRow) for r in kwargs["rows"]) - ): - return test_func(**kwargs) - - # If not a direct call, use the pytest wrapper - return pytest_wrapper(*args, **kwargs) + async def call_test_func(**call_kwargs): + """Helper to call test_func with proper async/sync handling""" + if is_async: + return await test_func(**call_kwargs) + else: + return test_func(**call_kwargs) + + async def dual_mode_wrapper(*args, **kwargs): + # Check if this is a direct call with the expected signature + if mode == "pointwise": + # For pointwise mode, check if called with a single row argument + if len(args) == 1 and isinstance(args[0], EvaluationRow) and not kwargs: + return await call_test_func(row=args[0]) + else: + # For batch mode, check if called with rows argument + if ( + len(args) == 1 + and isinstance(args[0], list) + and all(isinstance(r, EvaluationRow) for r in args[0]) + and not kwargs + ): + return await call_test_func(rows=args[0]) + # Also check if called with keyword argument 'rows' + if ( + len(args) == 0 + and "rows" in kwargs + and isinstance(kwargs["rows"], list) + and all(isinstance(r, EvaluationRow) for r in kwargs["rows"]) + ): + return await call_test_func(**kwargs) + + # If not a direct call, use the pytest wrapper + return await pytest_wrapper(*args, **kwargs) # Copy all attributes from the pytest wrapper to our dual mode wrapper import functools diff --git a/eval_protocol/pytest/plugin.py b/eval_protocol/pytest/plugin.py index 6c58d1e2..3a5ec0e2 100644 --- a/eval_protocol/pytest/plugin.py +++ b/eval_protocol/pytest/plugin.py @@ -12,8 +12,8 @@ max_dataset_rows value set in the decorator). """ -import os import logging +import os from typing import Optional @@ -32,17 +32,13 @@ def pytest_addoption(parser) -> None: "--ep-print-summary", action="store_true", default=False, - help=( - "Print a concise summary line (suite/model/effort/agg score) at the end of each evaluation_test." - ), + help=("Print a concise summary line (suite/model/effort/agg score) at the end of each evaluation_test."), ) group.addoption( "--ep-summary-json", action="store", default=None, - help=( - "Write a JSON summary artifact at the given path (e.g., ./outputs/aime_low.json)." - ), + help=("Write a JSON summary artifact at the given path (e.g., ./outputs/aime_low.json)."), ) group.addoption( "--ep-input-param", @@ -108,6 +104,7 @@ def pytest_configure(config) -> None: try: import json as _json import pathlib as _pathlib + merged: dict = {} input_params_opts = config.getoption("--ep-input-param") if input_params_opts: @@ -139,5 +136,3 @@ def pytest_configure(config) -> None: except Exception: # best effort, do not crash pytest session pass - - diff --git a/eval_protocol/pytest/types.py b/eval_protocol/pytest/types.py index c6de681e..9f564ce1 100644 --- a/eval_protocol/pytest/types.py +++ b/eval_protocol/pytest/types.py @@ -3,7 +3,7 @@ """ from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Literal, Optional +from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional from eval_protocol.dataset_logger import default_logger from eval_protocol.dataset_logger.dataset_logger import DatasetLogger @@ -53,4 +53,4 @@ class RolloutProcessorConfig: kwargs: Dict[str, Any] = field(default_factory=dict) # any additional kwargs to pass to the rollout processor -RolloutProcessor = Callable[[List[EvaluationRow], RolloutProcessorConfig], List[EvaluationRow]] +RolloutProcessor = Callable[[List[EvaluationRow], RolloutProcessorConfig], AsyncIterator[EvaluationRow]] diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index 981c1ed3..23a5722d 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -88,8 +88,8 @@ def create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param from functools import wraps @wraps(test_func) - def wrapper(**kwargs): - return wrapper_body(**kwargs) + async def wrapper(**kwargs): + return await wrapper_body(**kwargs) parameters = [inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD) for name in test_param_names] wrapper.__signature__ = inspect.Signature(parameters) diff --git a/tests/pytest/test_pytest_ids.py b/tests/pytest/test_pytest_ids.py index 0131bcbe..24ba3baf 100644 --- a/tests/pytest/test_pytest_ids.py +++ b/tests/pytest/test_pytest_ids.py @@ -19,7 +19,7 @@ def read(self): return list(self._rows.values()) -def test_evaluation_test_decorator(monkeypatch): +async def test_evaluation_test_decorator(monkeypatch): from eval_protocol.pytest.evaluation_test import evaluation_test logger = InMemoryLogger() @@ -45,13 +45,13 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow: # Manually invoke all parameter combinations within a single test for ds_path in dataset_paths: - eval_fn(model="dummy/local-model", dataset_path=[ds_path]) + await eval_fn(model="dummy/local-model", dataset_path=[ds_path]) # Assertions on IDs generated by the decorator logic assert len(logger.read()) == 38 -def test_evaluation_test_decorator_ids_single(monkeypatch): +async def test_evaluation_test_decorator_ids_single(monkeypatch): in_memory_logger = InMemoryLogger() unique_run_ids = set() unique_experiment_ids = set() @@ -92,7 +92,7 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow: # Manually invoke all parameter combinations within a single test for ds_path in dataset_paths: for params in input_params_list: - eval_fn(model="dummy/local-model", dataset_path=[ds_path], input_params=params) + await eval_fn(model="dummy/local-model", dataset_path=[ds_path], input_params=params) # Assertions on IDs generated by the decorator logic assert len(unique_invocation_ids) == 1 diff --git a/tests/test_rollout_control_plane_integration.py b/tests/test_rollout_control_plane_integration.py index dcaac0e9..1b92d5aa 100644 --- a/tests/test_rollout_control_plane_integration.py +++ b/tests/test_rollout_control_plane_integration.py @@ -239,7 +239,9 @@ def mock_step_side_effect(env_index, tool_call): policy = MockPolicy(["right", "down", "right"]) # Execute rollout - evaluation_rows = await self.execution_manager.execute_rollouts(mock_env, policy, steps=10) + evaluation_rows = [] + async for row in self.execution_manager.execute_rollouts(mock_env, policy, steps=10): + evaluation_rows.append(row) # Validate results assert len(evaluation_rows) == 1, "Should have one evaluation row" @@ -457,7 +459,9 @@ async def test_rollout_handles_control_plane_failure_gracefully(self): # Execute rollout with control plane failure policy = MockPolicy(["right"]) - evaluation_rows = await self.execution_manager.execute_rollouts(mock_env, policy, steps=1) + evaluation_rows = [] + async for row in self.execution_manager.execute_rollouts(mock_env, policy, steps=1): + evaluation_rows.append(row) # Should still work, but without control plane info assert len(evaluation_rows) == 1 @@ -500,15 +504,26 @@ async def test_rollout_creates_envs_from_url(self): mock_make.return_value = mock_env manager_instance = MockManager.return_value - manager_instance.execute_rollouts = AsyncMock(return_value=["ok"]) - result = await ep.rollout( + # Mock execute_rollouts to return an async generator and track calls + call_args = [] + + async def mock_execute_rollouts(*args, **kwargs): + call_args.append((args, kwargs)) + for item in ["ok"]: + yield item + + manager_instance.execute_rollouts = mock_execute_rollouts + + result = [] + async for row in ep.rollout( "http://localhost:1234/mcp/", policy, dataset=dataset, model_id="test_model", steps=5, - ) + ): + result.append(row) mock_make.assert_called_once_with( "http://localhost:1234/mcp/", @@ -517,14 +532,12 @@ async def test_rollout_creates_envs_from_url(self): model_id="test_model", ) - manager_instance.execute_rollouts.assert_called_once_with( - mock_make.return_value, - policy, - 5, - None, - 8, - None, - ) + # Verify execute_rollouts was called with correct arguments + assert len(call_args) == 1, "execute_rollouts should be called once" + args, kwargs = call_args[0] + assert args[0] == mock_make.return_value, "First arg should be mock env" + assert args[1] == policy, "Second arg should be policy" + assert args[2] == 5, "Third arg should be steps" assert result == ["ok"]