diff --git a/openhands-agent-server/openhands/agent_server/__main__.py b/openhands-agent-server/openhands/agent_server/__main__.py index f216e18f4d..12285b4f8b 100644 --- a/openhands-agent-server/openhands/agent_server/__main__.py +++ b/openhands-agent-server/openhands/agent_server/__main__.py @@ -1,6 +1,7 @@ import argparse import atexit import faulthandler +import importlib import signal import sys from types import FrameType @@ -15,6 +16,30 @@ logger = get_logger(__name__) +def preload_modules(modules_arg: str | None) -> None: + """Import user-specified modules so their top-level side effects run. + + Used to register custom tools before any conversation is created, avoiding + a race with dynamic `tool_module_qualnames` import in conversation_service. + """ + if not modules_arg: + return + for module_name in modules_arg.split(","): + module_name = module_name.strip() + if not module_name: + continue + try: + importlib.import_module(module_name) + logger.info("Imported module: %s", module_name) + except ImportError as e: + logger.error( + "Failed to import module '%s' specified in --import-modules: %s", + module_name, + e, + ) + raise + + def check_browser(): """Check if browser functionality can render about:blank.""" executor = None @@ -110,16 +135,28 @@ def main() -> None: action="store_true", help="Check if browser functionality works and exit", ) + parser.add_argument( + "--import-modules", + type=str, + default=None, + help=( + "Comma-separated list of modules to import at startup " + "(e.g. 
"""Tests for the --import-modules preloading helper."""

import logging
import sys
import textwrap
from unittest.mock import MagicMock, patch

import pytest

from openhands.agent_server.__main__ import preload_modules


# Dotted path of the importer that `preload_modules` calls; patched so the
# argument-parsing tests never touch the real import system.
_IMPORT_MODULE = "openhands.agent_server.__main__.importlib.import_module"


def _requested_modules(import_spy: MagicMock) -> list[str]:
    """Return the module names passed to the patched importer, in call order."""
    return [call.args[0] for call in import_spy.call_args_list]


class TestPreloadModules:
    """Unit tests for parsing, importing and error reporting."""

    def test_none_is_noop(self):
        with patch(_IMPORT_MODULE) as import_spy:
            preload_modules(None)
        import_spy.assert_not_called()

    def test_empty_string_is_noop(self):
        with patch(_IMPORT_MODULE) as import_spy:
            preload_modules("")
        import_spy.assert_not_called()

    def test_single_module(self):
        with patch(_IMPORT_MODULE) as import_spy:
            preload_modules("myapp.tools")
        import_spy.assert_called_once_with("myapp.tools")

    def test_comma_separated_strips_whitespace(self):
        with patch(_IMPORT_MODULE) as import_spy:
            preload_modules(" myapp.tools , myapp.plugins ")
        assert _requested_modules(import_spy) == ["myapp.tools", "myapp.plugins"]

    def test_empty_segments_skipped(self):
        with patch(_IMPORT_MODULE) as import_spy:
            preload_modules("myapp.tools,,myapp.plugins, ")
        assert _requested_modules(import_spy) == ["myapp.tools", "myapp.plugins"]

    def test_missing_module_raises(self):
        # Follow project convention: don't swallow import errors.
        with pytest.raises(ModuleNotFoundError):
            preload_modules("definitely_not_a_real_module_xyz_2771")

    @pytest.fixture
    def fake_tool_module(self, tmp_path, monkeypatch):
        """Write an importable package under ``tmp_path`` and yield the dotted
        name of a module in it whose top-level body has an observable side
        effect (analogous to a `register_tool(...)` call)."""
        package = "preload_modules_test_pkg"
        dotted_name = f"{package}.my_tool"
        source = textwrap.dedent(
            """\
            REGISTRY = []
            REGISTRY.append("MyCustomTool")
            """
        )

        package_dir = tmp_path / package
        package_dir.mkdir()
        (package_dir / "__init__.py").write_text("")
        (package_dir / "my_tool.py").write_text(source)
        monkeypatch.syspath_prepend(str(tmp_path))

        def forget_modules():
            # Drop any cached copies so each test sees a fresh import state.
            for name in (package, dotted_name):
                sys.modules.pop(name, None)

        forget_modules()
        yield dotted_name
        forget_modules()

    def test_module_side_effects_execute(self, fake_tool_module):
        """With the flag: the module body runs, so its side effects land
        before conversations are served — the race this flag exists to fix."""
        preload_modules(fake_tool_module)

        loaded = sys.modules[fake_tool_module]
        assert loaded.REGISTRY == ["MyCustomTool"]

    def test_module_not_imported_without_flag(self, fake_tool_module):
        """Contract companion: when no modules are passed (i.e. the operator
        forgot `--import-modules`), the module stays unimported and its
        `register_tool`-style side effects never run. This is exactly the
        broken state the CLI flag exists to prevent."""
        preload_modules(None)

        assert fake_tool_module not in sys.modules

    def test_import_error_is_logged_before_raising(self, caplog):
        """A failed import logs the module name and the flag name for
        operator diagnostics before the error is re-raised."""
        missing = "no_such_module_xyz_2771"

        with caplog.at_level(logging.ERROR), pytest.raises(ModuleNotFoundError):
            preload_modules(missing)

        matching = [
            record
            for record in caplog.records
            if missing in record.message and "--import-modules" in record.message
        ]
        assert matching


class TestMainCheckBrowserOrdering:
    """Verify --check-browser runs independently of --import-modules."""

    def test_check_browser_exits_before_preload(self):
        """--check-browser should short-circuit before preload_modules
        runs, so a broken user module cannot mask the browser check."""
        browser_result = MagicMock()
        browser_result.is_error = False
        fake_executor = MagicMock(return_value=browser_result)

        argv = ["prog", "--check-browser", "--import-modules", "boom"]
        with (
            patch("sys.argv", argv),
            patch("openhands.tools.preset.default.register_default_tools"),
            patch(
                "openhands.tools.browser_use.impl.BrowserToolExecutor",
                return_value=fake_executor,
            ),
            patch("openhands.agent_server.__main__.preload_modules") as preload_spy,
        ):
            from openhands.agent_server.__main__ import main

            with pytest.raises(SystemExit) as exit_info:
                main()

        # Browser check succeeded → exit 0
        assert exit_info.value.code == 0
        # preload_modules must NOT have been called
        preload_spy.assert_not_called()
@pytest.fixture
def server_env(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Generator[dict]:
    """Yield the environment dict of a live server started without preloads."""
    with live_server_env(tmp_path, monkeypatch) as live_env:
        yield live_env


def test_preloaded_custom_tool_resolves_in_live_server(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
):
    """A startup-preloaded tool is available during live conversation creation."""
    from openhands.sdk.tool import Tool, registry as tool_registry

    package_name = "preload_live_server_tools_2771"
    module_qualname = f"{package_name}.tools"
    injected_modules = (package_name, module_qualname)

    # Source of the throwaway tool module; importing it registers the tool.
    tool_source = textwrap.dedent(
        """\
        from __future__ import annotations

        from collections.abc import Sequence
        from typing import ClassVar

        from openhands.sdk.tool import (
            Action,
            Observation,
            ToolDefinition,
            ToolExecutor,
            register_tool,
        )


        class PreloadedAction(Action):
            pass


        class PreloadedObservation(Observation):
            pass


        class PreloadedExecutor(
            ToolExecutor[PreloadedAction, PreloadedObservation]
        ):
            def __call__(
                self,
                action: PreloadedAction,
                conversation=None,
            ) -> PreloadedObservation:
                return PreloadedObservation.from_text("preloaded")


        class PreloadedLiveServerTool(
            ToolDefinition[PreloadedAction, PreloadedObservation]
        ):
            name: ClassVar[str] = "preloaded_live_server_tool"

            @classmethod
            def create(
                cls, conv_state=None, **params
            ) -> Sequence[PreloadedLiveServerTool]:
                return [
                    cls(
                        description="Tool registered by startup preload.",
                        action_type=PreloadedAction,
                        observation_type=PreloadedObservation,
                        executor=PreloadedExecutor(),
                    )
                ]


        register_tool(PreloadedLiveServerTool.name, PreloadedLiveServerTool)
        """
    )

    package_dir = tmp_path / package_name
    package_dir.mkdir()
    (package_dir / "__init__.py").write_text("")
    (package_dir / "tools.py").write_text(tool_source)

    # Copy the registry dicts so the registration done by the preloaded
    # module can be undone once the test finishes.
    registry_attrs = ("_REG", "_USABILITY_REG", "_MODULE_QUALNAMES")
    saved_registries = {
        attr: dict(getattr(tool_registry, attr)) for attr in registry_attrs
    }

    monkeypatch.syspath_prepend(str(tmp_path))
    for cached in injected_modules:
        sys.modules.pop(cached, None)

    try:
        with live_server_env(
            tmp_path, monkeypatch, import_modules=module_qualname
        ) as env:
            agent = Agent(
                llm=LLM(model="gpt-4o-mini", api_key=SecretStr("test")),
                tools=[Tool(name="preloaded_live_server_tool")],
                include_default_tools=[],
            )
            first_message = {
                "role": "user",
                "content": [{"type": "text", "text": "Initialize tools."}],
            }
            payload = {
                "agent": agent.model_dump(
                    mode="json", context={"expose_secrets": True}
                ),
                "workspace": {"working_dir": "/tmp/workspace/project"},
                "initial_message": first_message,
                # Left empty: the request itself names no tool modules.
                "tool_module_qualnames": {},
            }

            with httpx.Client(base_url=env["host"]) as client:
                response = client.post("/api/conversations", json=payload, timeout=10)

            assert response.status_code == 201, response.text

            conversation_id = UUID(response.json()["id"])
            services = env["conversation_service"]._event_services
            conversation = services[conversation_id]._conversation
            assert conversation is not None
            assert "preloaded_live_server_tool" in conversation.agent.tools_map
    finally:
        for cached in injected_modules:
            sys.modules.pop(cached, None)
        for attr, saved in saved_registries.items():
            live_registry = getattr(tool_registry, attr)
            live_registry.clear()
            live_registry.update(saved)