diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index f8e7a23e..9e85d744 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -2,7 +2,10 @@ from typing import List from litellm import acompletion -from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall +from openai.types.chat.chat_completion_message import ( + ChatCompletionMessageToolCall, + FunctionCall, +) from eval_protocol.models import EvaluationRow, Message from eval_protocol.pytest.types import RolloutProcessorConfig @@ -18,7 +21,24 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: if len(row.messages) == 0: raise ValueError("Messages is empty. Please provide a non-empty dataset") - messages_payload = [{"role": m.role, "content": m.content} for m in row.messages] + messages_payload = [] + for m in row.messages: + payload = {"role": m.role} + if m.content is not None: + payload["content"] = m.content + if m.name is not None: + payload["name"] = m.name + if m.tool_call_id is not None: + payload["tool_call_id"] = m.tool_call_id + if m.tool_calls is not None: + payload["tool_calls"] = [ + tc.model_dump(exclude_none=True) for tc in m.tool_calls + ] + if m.function_call is not None: + payload["function_call"] = m.function_call.model_dump( + exclude_none=True + ) + messages_payload.append(payload) request_params = {"model": config.model, "messages": messages_payload, **config.input_params} @@ -27,8 +47,10 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: response = await acompletion(**request_params) - assistant_content = response.choices[0].message.content or "" - tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None + assistant_message = response.choices[0].message + assistant_content = assistant_message.content or "" + tool_calls = assistant_message.tool_calls if assistant_message.tool_calls else None + function_call = assistant_message.function_call converted_tool_calls = None if tool_calls: @@ -49,6 +71,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: role="assistant", content=assistant_content, tool_calls=converted_tool_calls, + function_call=function_call, ) ] diff --git a/tests/test_default_single_turn_rollout_processor.py b/tests/test_default_single_turn_rollout_processor.py new file mode 100644 index 00000000..28e61c45 --- /dev/null +++ b/tests/test_default_single_turn_rollout_processor.py @@ -0,0 +1,74 @@ +import asyncio +from types import SimpleNamespace +from typing import Any, Dict, List +from unittest import mock +from openai.types.chat.chat_completion_message import ( + ChatCompletionMessageToolCall, + ChatCompletionMessageToolCallFunction, +) + +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest.types import RolloutProcessorConfig +from eval_protocol.pytest.default_single_turn_rollout_process import ( + default_single_turn_rollout_processor, +) + + +def test_handles_function_call_messages() -> None: + async def run_test() -> None: + tool_call = ChatCompletionMessageToolCall( + id="call_1", + type="function", + function=ChatCompletionMessageToolCallFunction( + name="get_weather", arguments="{}" + ), + ) + row = EvaluationRow( + messages=[ + Message(role="user", content="Hi"), + Message(role="assistant", tool_calls=[tool_call], content=""), + Message(role="tool", tool_call_id="call_1", content="sunny"), + ], + tools=[{"type": "function", "function": {"name": "get_weather"}}], + ) + config = RolloutProcessorConfig( + model="gpt-4o-mini", input_params={}, mcp_config_path="" + ) + + captured_messages: List[Dict[str, Any]] = [] + + async def fake_acompletion(**kwargs: Any) -> Any: + nonlocal captured_messages + captured_messages = kwargs["messages"] + return SimpleNamespace( + choices=[ + SimpleNamespace( + message=SimpleNamespace( + content="done", + tool_calls=[ + ChatCompletionMessageToolCall( + id="call_2", + type="function", + function=ChatCompletionMessageToolCallFunction( + name="foo", arguments="{}" + ), + ) + ], + function_call=None, + ) + ) + ] + ) + + with mock.patch( + "eval_protocol.pytest.default_single_turn_rollout_process.acompletion", + side_effect=fake_acompletion, + ): + dataset = await default_single_turn_rollout_processor([row], config) + + assert captured_messages[1]["tool_calls"][0]["id"] == "call_1" + assert captured_messages[2]["tool_call_id"] == "call_1" + result_row = dataset[0] + assert result_row.messages[-1].tool_calls[0].id == "call_2" + + asyncio.run(run_test())