Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions eval_protocol/pytest/default_single_turn_rollout_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
from typing import List

from litellm import acompletion
from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_message import (
ChatCompletionMessageToolCall,
FunctionCall,
)

from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.types import RolloutProcessorConfig
Expand All @@ -18,7 +21,24 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
if len(row.messages) == 0:
raise ValueError("Messages is empty. Please provide a non-empty dataset")

messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
messages_payload = []
for m in row.messages:
payload = {"role": m.role}
if m.content is not None:
payload["content"] = m.content
if m.name is not None:
payload["name"] = m.name
if m.tool_call_id is not None:
payload["tool_call_id"] = m.tool_call_id
if m.tool_calls is not None:
payload["tool_calls"] = [
tc.model_dump(exclude_none=True) for tc in m.tool_calls
]
if m.function_call is not None:
payload["function_call"] = m.function_call.model_dump(
exclude_none=True
)
messages_payload.append(payload)

request_params = {"model": config.model, "messages": messages_payload, **config.input_params}

Expand All @@ -27,8 +47,10 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:

response = await acompletion(**request_params)

assistant_content = response.choices[0].message.content or ""
tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None
assistant_message = response.choices[0].message
assistant_content = assistant_message.content or ""
tool_calls = assistant_message.tool_calls if assistant_message.tool_calls else None
function_call = assistant_message.function_call

converted_tool_calls = None
if tool_calls:
Expand All @@ -49,6 +71,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
role="assistant",
content=assistant_content,
tool_calls=converted_tool_calls,
function_call=function_call,
)
]

Expand Down
74 changes: 74 additions & 0 deletions tests/test_default_single_turn_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import asyncio
from types import SimpleNamespace
from typing import Any, Dict, List
from unittest import mock
from openai.types.chat.chat_completion_message import (
ChatCompletionMessageToolCall,
ChatCompletionMessageToolCallFunction,
)

from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.types import RolloutProcessorConfig
from eval_protocol.pytest.default_single_turn_rollout_process import (
default_single_turn_rollout_processor,
)


def test_handles_function_call_messages() -> None:
async def run_test() -> None:
tool_call = ChatCompletionMessageToolCall(
id="call_1",
type="function",
function=ChatCompletionMessageToolCallFunction(
name="get_weather", arguments="{}"
),
)
row = EvaluationRow(
messages=[
Message(role="user", content="Hi"),
Message(role="assistant", tool_calls=[tool_call], content=""),
Message(role="tool", tool_call_id="call_1", content="sunny"),
],
tools=[{"type": "function", "function": {"name": "get_weather"}}],
)
config = RolloutProcessorConfig(
model="gpt-4o-mini", input_params={}, mcp_config_path=""
)

captured_messages: List[Dict[str, Any]] = []

async def fake_acompletion(**kwargs: Any) -> Any:
nonlocal captured_messages
captured_messages = kwargs["messages"]
return SimpleNamespace(
choices=[
SimpleNamespace(
message=SimpleNamespace(
content="done",
tool_calls=[
ChatCompletionMessageToolCall(
id="call_2",
type="function",
function=ChatCompletionMessageToolCallFunction(
name="foo", arguments="{}"
),
)
],
function_call=None,
)
)
]
)

with mock.patch(
"eval_protocol.pytest.default_single_turn_rollout_process.acompletion",
side_effect=fake_acompletion,
):
dataset = await default_single_turn_rollout_processor([row], config)

assert captured_messages[1]["tool_calls"][0]["id"] == "call_1"
assert captured_messages[2]["tool_call_id"] == "call_1"
result_row = dataset[0]
assert result_row.messages[-1].tool_calls[0].id == "call_2"

asyncio.run(run_test())
Loading