Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,13 @@ jobs:
FIREWORKS_ACCOUNT_ID: ${{ secrets.FIREWORKS_ACCOUNT_ID }}
PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning"
run: |
# Run most tests in parallel, but explicitly ignore tests that manage their own servers
# Run most tests in parallel, but explicitly ignore tests that manage their own servers or are slow
uv run pytest \
-n auto \
--ignore=tests/test_batch_evaluation.py \
--ignore=tests/pytest/test_frozen_lake.py \
--ignore=tests/pytest/test_lunar_lander.py \
--ignore=tests/pytest/test_tau_bench_airline.py \
--cov=eval_protocol --cov-append --cov-report=xml --cov-report=term-missing -v --durations=10

- name: Store coverage file
Expand Down
7 changes: 3 additions & 4 deletions eval_protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from .mcp_env import (
AnthropicPolicy,
OpenAIPolicy,
LiteLLMPolicy,
FireworksPolicy,
make,
rollout,
test_mcp,
Expand Down Expand Up @@ -60,6 +62,7 @@
# MCP Environment API
"make",
"rollout",
"LiteLLMPolicy",
"AnthropicPolicy",
"FireworksPolicy",
"OpenAIPolicy",
Expand All @@ -73,10 +76,6 @@
"mcp",
]

# Add FireworksPolicy to exports if available
if _FIREWORKS_AVAILABLE:
__all__.insert(__all__.index("OpenAIPolicy") + 1, "FireworksPolicy")

from . import _version

__version__ = _version.get_versions()["version"]
19 changes: 17 additions & 2 deletions eval_protocol/mcp/execution/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,25 @@ async def _execute_with_semaphore(idx):
# Convert trajectories to unified EvaluationRow format
evaluation_rows = []
for trajectory in trajectories:
messages = [Message.model_validate(msg) for msg in trajectory.conversation_history]
# Handle multimodal content by extracting text from complex content structures
messages = []
for msg in trajectory.conversation_history:
# Create a copy to avoid modifying the original
msg_dict = dict(msg)

# Handle multimodal content (list of content blocks) by extracting text
if isinstance(msg_dict.get("content"), list):
text_content = None
for content_block in msg_dict["content"]:
if isinstance(content_block, dict) and content_block.get("type") == "text":
text_content = content_block.get("text")
break
msg_dict["content"] = text_content or ""

messages.append(Message.model_validate(msg_dict))

input_metadata = InputMetadata(
row_id=trajectory.session.session_id,
row_id=trajectory.session.dataset_row.id if trajectory.session.dataset_row else None,
dataset_info=asdict(trajectory.session.dataset_row) if trajectory.session.dataset_row else {},
completion_params=CompletionParams(
model=policy.model_id,
Expand Down
74 changes: 52 additions & 22 deletions eval_protocol/mcp_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,18 @@
Usage remains the same:
import eval_protocol as ep

# Load dataset with environment configuration and prompts
dataset = load_jsonl("dataset.jsonl")

# Create general policy (environment-agnostic)
policy = ep.FireworksPolicy(model_id="accounts/fireworks/models/qwen3-235b-a22b")

# Create environments with dataset-driven configuration
envs = ep.make("http://localhost:8000/mcp", dataset=dataset)
# Create environments with evaluation_rows configuration
envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)

# Execute tool-calling rollouts
evaluation_rows = await ep.rollout(envs, policy=policy, steps=512)

Key Features:
- General tool-calling interface that works with any MCP environment
- Dataset-driven configuration with system prompts and user prompt templates
- EvaluationRow-driven configuration with system prompts and user prompt templates
- Automatic MCP tool discovery from servers
- **PROPER MCP PATTERN**: Initial state obtained from MCP resources during session establishment
- Tools used only for actions/interactions, not for getting initial state
Expand All @@ -50,7 +47,7 @@

# Import all functionality from the new modular components
from .mcp.execution.manager import ExecutionManager
from .mcp.execution.policy import AnthropicPolicy, FireworksPolicy, LLMBasePolicy, OpenAIPolicy
from .mcp.execution.policy import AnthropicPolicy, FireworksPolicy, LLMBasePolicy, OpenAIPolicy, LiteLLMPolicy
from .mcp.session.manager import GeneralMCPVectorEnv
from .models import EvaluationRow
from .types import DatasetRow, MCPSession, MCPToolCall
Expand All @@ -60,18 +57,20 @@

def make(
env_spec: str,
evaluation_rows: Optional[List[EvaluationRow]] = None,
dataset: Optional[List[Dict]] = None,
n: Optional[int] = None,
seeds: Optional[List[int]] = None,
model_id: str = "unknown",
user_prompt_formatter: Optional[Callable] = None,
) -> GeneralMCPVectorEnv:
"""
Create general MCP environments driven by dataset configuration.
Create general MCP environments driven by evaluation_rows configuration.

Args:
env_spec: MCP server URL
dataset: List of dataset rows with prompts and context (preferred)
evaluation_rows: List of EvaluationRow objects containing messages and metadata (preferred)
dataset: List of dataset entries (for backward compatibility)
n: Number of environments (for backward compatibility)
seeds: List of seeds (for backward compatibility)
model_id: Model identifier
Expand All @@ -81,8 +80,10 @@ def make(
General MCP environment that works with any MCP server

Example:
# New dataset-driven approach (preferred)
dataset = load_jsonl("dataset.jsonl")
# EvaluationRow approach (preferred)
envs = ep.make("http://localhost:8000/mcp", evaluation_rows=evaluation_rows)

# Dataset approach (backward compatibility)
envs = ep.make("http://localhost:8000/mcp", dataset=dataset)

# Legacy approach (backward compatibility)
Expand All @@ -97,13 +98,39 @@ def make(
if not base_url.endswith("/"):
base_url += "/"

# Handle dataset-driven vs legacy approaches
if dataset is not None:
# New dataset-driven approach
# Convert evaluation_rows to dataset format if provided
internal_dataset = []

if evaluation_rows:
for i, row in enumerate(evaluation_rows):
dataset_info = row.input_metadata.dataset_info if row.input_metadata else {}

system_message = row.get_system_message()
system_prompt = system_message.content or ""

dataset_entry = {
"id": row.input_metadata.row_id if row.input_metadata and row.input_metadata.row_id else f"task_{i}",
"system_prompt": system_prompt,
"user_prompt_template": dataset_info.get("user_prompt_template", ""),
"environment_context": dataset_info.get("environment_context", {}),
"user_simulation": dataset_info.get("user_simulation", {}),
"evaluation_criteria": dataset_info.get("evaluation_criteria", {})
}
internal_dataset.append(dataset_entry)
elif dataset:
# Use provided dataset directly for backward compatibility
internal_dataset = dataset

dataset_rows = []
sessions = []

# Handle evaluation_rows vs legacy approaches
if internal_dataset:
# New evaluation_rows approach
dataset_rows = []
sessions = []

for row in dataset:
for row in internal_dataset:
# Parse dataset row
if isinstance(row, dict):
# Handle seed from both old location (backward compatibility) and new location
Expand Down Expand Up @@ -138,7 +165,7 @@ def make(
else:
# Legacy approach for backward compatibility
if n is None:
raise ValueError("Either 'dataset' or 'n' must be provided")
raise ValueError("Either 'evaluation_rows' or 'n' must be provided")

# Generate seeds if not provided
if seeds is None:
Expand Down Expand Up @@ -178,6 +205,7 @@ async def rollout(
envs: GeneralMCPVectorEnv,
policy: Union[FireworksPolicy, LLMBasePolicy, Callable],
*,
evaluation_rows: Optional[List[EvaluationRow]] = None,
dataset: Optional[List[Dict]] = None,
model_id: Optional[str] = None,
steps: int = 512,
Expand All @@ -191,13 +219,14 @@ async def rollout(

This works with ANY MCP environment because:
1. Policy receives tool schemas and makes tool calls
2. Environment prompts come from dataset
2. Environment prompts come from evaluation_rows
3. No hardcoded environment logic

Args:
envs: Either a GeneralMCPVectorEnv instance or the MCP server URL
policy: Policy that takes tool schemas, observations, prompts and returns tool calls
dataset: Dataset used when envs is a URL (required for automatic env creation)
evaluation_rows: EvaluationRow list used when envs is a URL (for automatic env creation)
dataset: Dataset list used for backward compatibility when envs is a URL
model_id: Model identifier used when creating environments. Defaults to ``policy.model_id`` when available.
steps: Maximum steps per rollout
openai_format_log_file: Optional file to log clean OpenAI format for terminated trajectories only
Expand All @@ -220,7 +249,7 @@ async def rollout(
trajectories = await ep.rollout(
"http://localhost:8000/mcp/",
policy,
dataset=my_dataset,
evaluation_rows=my_evaluation_rows,
model_id=policy.model_id,
)

Expand All @@ -233,11 +262,11 @@ async def rollout(
"""
# Automatically create environments if a base URL is provided
if isinstance(envs, str):
if dataset is None:
raise ValueError("'dataset' must be provided when envs is a URL")
if evaluation_rows is None and dataset is None:
raise ValueError("Either 'evaluation_rows' or 'dataset' must be provided when envs is a URL")

auto_model_id = model_id or getattr(policy, "model_id", "unknown")
envs = make(envs, dataset=dataset, model_id=auto_model_id)
envs = make(envs, evaluation_rows=evaluation_rows, dataset=dataset, model_id=auto_model_id)

# Use the new ExecutionManager for execution
execution_manager = ExecutionManager()
Expand Down Expand Up @@ -304,6 +333,7 @@ async def test_mcp(base_url: str, seeds: List[int]) -> Dict[str, Any]:
"AnthropicPolicy",
"FireworksPolicy",
"OpenAIPolicy",
"LiteLLMPolicy",
"LLMBasePolicy", # New base class for OpenAI integration
"GeneralMCPVectorEnv",
"MCPToolCall",
Expand Down
7 changes: 7 additions & 0 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,13 @@ def get_conversation_length(self) -> int:
"""Returns the number of messages in the conversation."""
return len(self.messages)

def get_system_message(self) -> Message:
    """Return the first system-role message in the conversation.

    Falls back to a fresh system Message with empty content when the
    conversation contains no system message, so callers never get None.
    """
    for msg in self.messages:
        if msg.role == "system":
            return msg
    return Message(role="system", content="")

def get_assistant_messages(self) -> List[Message]:
    """Return every assistant-role message, preserving conversation order."""
    return list(filter(lambda msg: msg.role == "assistant", self.messages))
Expand Down
Loading
Loading