From c37ca459fa06ee0cee22de07e0a18d4efaac92e4 Mon Sep 17 00:00:00 2001 From: "Yufei (Benny) Chen" <1585539+benjibc@users.noreply.github.com> Date: Thu, 7 Aug 2025 11:57:46 -0700 Subject: [PATCH 1/2] Add test for rollout processor tool calls --- .../default_single_turn_rollout_process.py | 31 +++- ...t_default_single_turn_rollout_processor.py | 171 ++++++++++++++++++ 2 files changed, 198 insertions(+), 4 deletions(-) create mode 100644 tests/test_default_single_turn_rollout_processor.py diff --git a/eval_protocol/pytest/default_single_turn_rollout_process.py b/eval_protocol/pytest/default_single_turn_rollout_process.py index f8e7a23e..9e85d744 100644 --- a/eval_protocol/pytest/default_single_turn_rollout_process.py +++ b/eval_protocol/pytest/default_single_turn_rollout_process.py @@ -2,7 +2,10 @@ from typing import List from litellm import acompletion -from openai.types.chat.chat_completion_message import ChatCompletionMessageToolCall +from openai.types.chat.chat_completion_message import ( + ChatCompletionMessageToolCall, + FunctionCall, +) from eval_protocol.models import EvaluationRow, Message from eval_protocol.pytest.types import RolloutProcessorConfig @@ -18,7 +21,24 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: if len(row.messages) == 0: raise ValueError("Messages is empty. Please provide a non-empty dataset") - messages_payload = [{"role": m.role, "content": m.content} for m in row.messages] + messages_payload = [] + for m in row.messages: + payload = {"role": m.role} + if m.content is not None: + payload["content"] = m.content + if m.name is not None: + payload["name"] = m.name + if m.tool_call_id is not None: + payload["tool_call_id"] = m.tool_call_id + if m.tool_calls is not None: + payload["tool_calls"] = [ + tc.model_dump(exclude_none=True) for tc in m.tool_calls + ] + if m.function_call is not None: + payload["function_call"] = m.function_call.model_dump( + exclude_none=True + ) + messages_payload.append(payload) request_params = {"model": config.model, "messages": messages_payload, **config.input_params} @@ -27,8 +47,10 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: response = await acompletion(**request_params) - assistant_content = response.choices[0].message.content or "" - tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None + assistant_message = response.choices[0].message + assistant_content = assistant_message.content or "" + tool_calls = assistant_message.tool_calls if assistant_message.tool_calls else None + function_call = assistant_message.function_call converted_tool_calls = None if tool_calls: @@ -49,6 +71,7 @@ async def process_row(row: EvaluationRow) -> EvaluationRow: role="assistant", content=assistant_content, tool_calls=converted_tool_calls, + function_call=function_call, ) ] diff --git a/tests/test_default_single_turn_rollout_processor.py b/tests/test_default_single_turn_rollout_processor.py new file mode 100644 index 00000000..a6f5fdb8 --- /dev/null +++ b/tests/test_default_single_turn_rollout_processor.py @@ -0,0 +1,171 @@ +import sys +import types +from dataclasses import dataclass +from typing import Any, Dict, List + +import asyncio +import pytest +from pydantic import BaseModel +from unittest import mock + + +# ---- Stub external dependencies ---- +openai = types.ModuleType("openai") +types_mod = types.ModuleType("openai.types") +chat_mod = types.ModuleType("openai.types.chat") +chat_msg_mod = types.ModuleType("openai.types.chat.chat_completion_message") + + +class FunctionCall(BaseModel): + name: str + arguments: str + + +class ToolFunction(BaseModel): + name: str + arguments: str + + +class ChatCompletionMessageToolCall(BaseModel): + id: str + type: str + function: ToolFunction + + +class CompletionUsage(BaseModel): + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + + +chat_msg_mod.FunctionCall = FunctionCall +chat_msg_mod.ChatCompletionMessageToolCall = ChatCompletionMessageToolCall +chat_mod.chat_completion_message = chat_msg_mod +openai.types = types_mod +types_mod.chat = chat_mod +types_mod.CompletionUsage = CompletionUsage +sys.modules["openai"] = openai +sys.modules["openai.types"] = types_mod +sys.modules["openai.types.chat"] = chat_mod +sys.modules["openai.types.chat.chat_completion_message"] = chat_msg_mod + + +# Stub litellm +litellm = types.ModuleType("litellm") + + +async def acompletion(**kwargs): + raise NotImplementedError + + +litellm.acompletion = acompletion +sys.modules["litellm"] = litellm + + +# Stub eval_protocol models and types +class Message(BaseModel): + role: str + content: Any = "" + name: str | None = None + tool_call_id: str | None = None + tool_calls: List[ChatCompletionMessageToolCall] | None = None + function_call: FunctionCall | None = None + + +class EvaluationRow(BaseModel): + messages: List[Message] + tools: Any = None + ground_truth: Any = None + + +@dataclass +class RolloutProcessorConfig: + model: str + input_params: Dict[str, Any] + mcp_config_path: str + server_script_path: str | None = None + max_concurrent_rollouts: int = 8 + steps: int = 30 + + +# Register stub modules +import_path = "/workspace/python-sdk/eval_protocol" +eval_protocol_pkg = types.ModuleType("eval_protocol") +eval_protocol_pkg.__path__ = [import_path] +models_module = types.ModuleType("eval_protocol.models") +models_module.Message = Message +models_module.EvaluationRow = EvaluationRow +pytest_pkg = types.ModuleType("eval_protocol.pytest") +pytest_pkg.__path__ = [f"{import_path}/pytest"] +types_module = types.ModuleType("eval_protocol.pytest.types") +types_module.RolloutProcessorConfig = RolloutProcessorConfig + +sys.modules["eval_protocol"] = eval_protocol_pkg +sys.modules["eval_protocol.models"] = models_module +sys.modules["eval_protocol.pytest"] = pytest_pkg +sys.modules["eval_protocol.pytest.types"] = types_module + + +# Now we can import the rollout processor +from eval_protocol.pytest.default_single_turn_rollout_process import ( + default_single_turn_rollout_processor, +) + + +def test_handles_function_call_messages(): + async def run_test(): + tool_call = ChatCompletionMessageToolCall( + id="call_1", + type="function", + function=ToolFunction(name="get_weather", arguments="{}"), + ) + row = EvaluationRow( + messages=[ + Message(role="user", content="Hi"), + Message(role="assistant", tool_calls=[tool_call], content=""), + Message(role="tool", tool_call_id="call_1", content="sunny"), + ], + tools=[{"type": "function", "function": {"name": "get_weather"}}], + ) + config = RolloutProcessorConfig( + model="gpt-4o-mini", input_params={}, mcp_config_path="" + ) + + captured_messages: List[Dict[str, Any]] = [] + + async def fake_acompletion(**kwargs): + nonlocal captured_messages + captured_messages = kwargs["messages"] + return types.SimpleNamespace( + choices=[ + types.SimpleNamespace( + message=types.SimpleNamespace( + content="done", + tool_calls=[ + ChatCompletionMessageToolCall( + id="call_2", + type="function", + function=ToolFunction(name="foo", arguments="{}"), + ) + ], + function_call=None, + ) + ) + ] + ) + + with pytest.raises(NotImplementedError): + await acompletion() + + with mock.patch( + "eval_protocol.pytest.default_single_turn_rollout_process.acompletion", + side_effect=fake_acompletion, + ): + dataset = await default_single_turn_rollout_processor([row], config) + + assert captured_messages[1]["tool_calls"][0]["id"] == "call_1" + assert captured_messages[2]["tool_call_id"] == "call_1" + result_row = dataset[0] + assert result_row.messages[-1].tool_calls[0].id == "call_2" + + asyncio.run(run_test()) From bfb5c2dd41a0fc772d3b54206f15fb77913fa258 Mon Sep 17 00:00:00 2001 From: "Yufei (Benny) Chen" <1585539+benjibc@users.noreply.github.com> Date: Thu, 7 Aug 2025 17:44:53 -0700 Subject: [PATCH 2/2] test: use real dependencies for rollout processor --- ...t_default_single_turn_rollout_processor.py | 137 +++--------------- 1 file changed, 20 insertions(+), 117 deletions(-) diff --git a/tests/test_default_single_turn_rollout_processor.py b/tests/test_default_single_turn_rollout_processor.py index a6f5fdb8..28e61c45 100644 --- a/tests/test_default_single_turn_rollout_processor.py +++ b/tests/test_default_single_turn_rollout_processor.py @@ -1,123 +1,27 @@ -import sys -import types -from dataclasses import dataclass -from typing import Any, Dict, List - import asyncio -import pytest -from pydantic import BaseModel +from types import SimpleNamespace +from typing import Any, Dict, List from unittest import mock +from openai.types.chat.chat_completion_message import ( + ChatCompletionMessageToolCall, + ChatCompletionMessageToolCallFunction, +) - -# ---- Stub external dependencies ---- -openai = types.ModuleType("openai") -types_mod = types.ModuleType("openai.types") -chat_mod = types.ModuleType("openai.types.chat") -chat_msg_mod = types.ModuleType("openai.types.chat.chat_completion_message") - - -class FunctionCall(BaseModel): - name: str - arguments: str - - -class ToolFunction(BaseModel): - name: str - arguments: str - - -class ChatCompletionMessageToolCall(BaseModel): - id: str - type: str - function: ToolFunction - - -class CompletionUsage(BaseModel): - prompt_tokens: int = 0 - completion_tokens: int = 0 - total_tokens: int = 0 - - -chat_msg_mod.FunctionCall = FunctionCall -chat_msg_mod.ChatCompletionMessageToolCall = ChatCompletionMessageToolCall -chat_mod.chat_completion_message = chat_msg_mod -openai.types = types_mod -types_mod.chat = chat_mod -types_mod.CompletionUsage = CompletionUsage -sys.modules["openai"] = openai -sys.modules["openai.types"] = types_mod -sys.modules["openai.types.chat"] = chat_mod -sys.modules["openai.types.chat.chat_completion_message"] = chat_msg_mod - - -# Stub litellm -litellm = types.ModuleType("litellm") - - -async def acompletion(**kwargs): - raise NotImplementedError - - -litellm.acompletion = acompletion -sys.modules["litellm"] = litellm - - -# Stub eval_protocol models and types -class Message(BaseModel): - role: str - content: Any = "" - name: str | None = None - tool_call_id: str | None = None - tool_calls: List[ChatCompletionMessageToolCall] | None = None - function_call: FunctionCall | None = None - - -class EvaluationRow(BaseModel): - messages: List[Message] - tools: Any = None - ground_truth: Any = None - - -@dataclass -class RolloutProcessorConfig: - model: str - input_params: Dict[str, Any] - mcp_config_path: str - server_script_path: str | None = None - max_concurrent_rollouts: int = 8 - steps: int = 30 - - -# Register stub modules -import_path = "/workspace/python-sdk/eval_protocol" -eval_protocol_pkg = types.ModuleType("eval_protocol") -eval_protocol_pkg.__path__ = [import_path] -models_module = types.ModuleType("eval_protocol.models") -models_module.Message = Message -models_module.EvaluationRow = EvaluationRow -pytest_pkg = types.ModuleType("eval_protocol.pytest") -pytest_pkg.__path__ = [f"{import_path}/pytest"] -types_module = types.ModuleType("eval_protocol.pytest.types") -types_module.RolloutProcessorConfig = RolloutProcessorConfig - -sys.modules["eval_protocol"] = eval_protocol_pkg -sys.modules["eval_protocol.models"] = models_module -sys.modules["eval_protocol.pytest"] = pytest_pkg -sys.modules["eval_protocol.pytest.types"] = types_module - - -# Now we can import the rollout processor +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest.types import RolloutProcessorConfig from eval_protocol.pytest.default_single_turn_rollout_process import ( default_single_turn_rollout_processor, ) -def test_handles_function_call_messages(): - async def run_test(): +def test_handles_function_call_messages() -> None: + async def run_test() -> None: tool_call = ChatCompletionMessageToolCall( id="call_1", type="function", - function=ToolFunction(name="get_weather", arguments="{}"), + function=ChatCompletionMessageToolCallFunction( + name="get_weather", arguments="{}" + ), ) row = EvaluationRow( messages=[ @@ -133,19 +37,21 @@ async def run_test(): captured_messages: List[Dict[str, Any]] = [] - async def fake_acompletion(**kwargs): + async def fake_acompletion(**kwargs: Any) -> Any: nonlocal captured_messages captured_messages = kwargs["messages"] - return types.SimpleNamespace( + return SimpleNamespace( choices=[ - types.SimpleNamespace( - message=types.SimpleNamespace( + SimpleNamespace( + message=SimpleNamespace( content="done", tool_calls=[ ChatCompletionMessageToolCall( id="call_2", type="function", - function=ToolFunction(name="foo", arguments="{}"), + function=ChatCompletionMessageToolCallFunction( + name="foo", arguments="{}" + ), ) ], function_call=None, @@ -154,9 +60,6 @@ async def fake_acompletion(**kwargs): ] ) - with pytest.raises(NotImplementedError): - await acompletion() - with mock.patch( "eval_protocol.pytest.default_single_turn_rollout_process.acompletion", side_effect=fake_acompletion,