diff --git a/docs/source/environments.md b/docs/source/environments.md index a14564a1a..207df4e8c 100644 --- a/docs/source/environments.md +++ b/docs/source/environments.md @@ -549,10 +549,10 @@ AgentWorldModel-1K β€” 1,000 synthetic MCP tool-use environments with 10,000 tas ``` ```` -````{grid-item-card} Opencode +````{grid-item-card} OpenCode :class-card: sd-border-1 -`opencode_env` runs the OpenCode coding agent inside an isolated E2B sandbox against any OpenAI-compatible LLM endpoint, optionally capturing per-token logpr... +`opencode_env` runs the OpenCode coding agent inside an isolated E2B sandbox against any OpenAI-compatible LLM endpoint, with trainer-owned interception for RL workflows. +++ ```{button-link} environments/opencode.html diff --git a/envs/opencode_env/README.md b/envs/opencode_env/README.md index 79ebc6ed3..6840bd3fd 100644 --- a/envs/opencode_env/README.md +++ b/envs/opencode_env/README.md @@ -9,20 +9,21 @@ app_port: 8000 base_path: /web tags: - openenv -short_description: OpenCode coding agent in an E2B sandbox with logprob capture +short_description: OpenCode coding agent in an E2B sandbox --- # OpenCode Environment for OpenEnv -`opencode_env` runs the [OpenCode](https://opencode.ai) coding agent inside -an isolated [E2B](https://e2b.dev) sandbox against any OpenAI-compatible -LLM endpoint, optionally capturing per-token logprobs for GRPO training. +`opencode_env` runs the [OpenCode](https://opencode.ai) coding agent +inside an isolated [E2B](https://e2b.dev) sandbox against any OpenAI-compatible +LLM endpoint, optionally capturing per-token logprobs through a transparent +in-sandbox proxy for RL training data. **πŸš€ Try it live**: [`AdithyaSK/opencode-env`](https://huggingface.co/spaces/AdithyaSK/opencode-env) The deployed Space exposes: -- **Web UI** at [`/web`](https://adithyask-opencode-env.hf.space/web) β€” pick endpoint, write task, hit Run, watch live phase log + reward + logprobs. 
+- **Web UI** at [`/web`](https://adithyask-opencode-env.hf.space/web) — pick endpoint, write task, hit Run, watch live phase log + reward. - **MCP tool API** at [`/mcp`](https://adithyask-opencode-env.hf.space/mcp) — programmatic `run_rollout` calls. - **OpenAPI docs** at [`/docs`](https://adithyask-opencode-env.hf.space/docs). - **Health** at [`/health`](https://adithyask-opencode-env.hf.space/health). @@ -30,11 +31,11 @@ The deployed Space exposes: The env is **task-agnostic** — every rollout is configured at call-time with a uniform Task shape: - - **`instruction`** — prompt for the agent - - **`setup`** — list of bash commands run *before* the agent (pip + - **`instruction`** — prompt for OpenCode + - **`setup`** — list of bash commands run *before* OpenCode (pip install, git clone, file downloads — anything you need staged in the sandbox) - - **`verify`** — list of bash commands run *after* the agent (asserts, + - **`verify`** — list of bash commands run *after* OpenCode (asserts, pytest invocations, score-file writes) Reward = `passed_verify / total_verify` unless any `verify` command writes @@ -81,7 +82,6 @@ async def main(): result = RolloutResult.model_validate_json(_extract_text(raw)) print("reward:", result.reward) - print("turns:", len(result.proxy_turns)) print("files:", list(result.files.keys())) print("wall:", result.wall_s, "s") @@ -93,7 +93,6 @@ Expected output (~20s with the prebaked template): ``` reward: 1.0 -turns: 3 files: ['/home/user/workdir/binary_search.py', ...] 
wall: 19.8 s ``` @@ -132,11 +131,10 @@ factory = OpenCodeSessionFactory( model="gpt-4o-mini", ), sandbox_backend=E2BSandboxBackend(), - mode="transparent_proxy", # captures per-token logprobs + mode="interception_gate", # trainer-owned interception mode ) session = factory.create(task=OpenCodeTask(instruction="...")) session.wait_for_completion() -turns = session.fetch_proxy_trace() # per-turn (tokens, logprobs) session.close() ``` @@ -174,7 +172,7 @@ The image: ## The MCP Tool: `run_rollout` -Single tool, two ways to specify the LLM endpoint: +Single tool, with two ways to specify the LLM endpoint: **Option A — endpoint shorthand (recommended)**: pass `endpoint="vllm"` (or `"openai"` / `"hf_router"`). The server resolves @@ -188,27 +186,29 @@ directly. |---|---|---|---| | `endpoint` | `str` | `""` | One of `"vllm"` / `"openai"` / `"hf_router"`. | | `base_url` / `api_key` / `model` | `str` | `""` | Override / supply explicitly. | -| `instruction` | `str` | required | Prompt passed to `opencode run`. | -| `setup` | `list[str]` | `[]` | Bash commands run **before** the agent. | -| `verify` | `list[str]` | `[]` | Bash commands run **after** the agent. | +| `instruction` | `str` | required | Prompt passed to OpenCode. | +| `setup` | `list[str]` | `[]` | Bash commands run **before** OpenCode. | +| `verify` | `list[str]` | `[]` | Bash commands run **after** OpenCode. | | `task_id` | `str` | `""` | Echoed back in result. | -| `mode` | `str` | `"transparent_proxy"` | Or `"black_box"` (no logprobs). | +| `mode` | `str` | `"transparent_proxy"` | Or `"black_box"` for direct LLM calls. In-process trainers can also construct `OpenCodeSessionFactory(mode="interception_gate", ...)`. | | `disable_thinking` | `bool \| None` | `None` (catalog default) | Inject `chat_template_kwargs.enable_thinking=false`. | | `max_tokens_cap` | `int` | `4096` | Per-turn `max_tokens` clamp. | -| `top_logprobs` | `int` | `5` | HF Router cap is 5; OpenAI 0–20; vLLM unbounded. 
| -| `agent_timeout_s` | `float` | `600.0` | Hard wall budget for opencode. | +| `top_logprobs` | `int` | `5` | Per-token top-k logprobs requested in `transparent_proxy` mode. | +| `agent_timeout_s` | `float` | `600.0` | Hard wall budget for OpenCode. | | `template` | `str` | `""` | E2B template name; `"opencode-rl"` skips ~2 min of install per rollout. | Returns `RolloutResult` JSON with: `reward`, `setup_results[]`, -`verify_results[]`, `proxy_turns[]`, `files{}`, `agent_log_tail`, -`proxy_log_tail`, `wall_s`, `agent_exit_code`, `sandbox_id`, `error`. +`verify_results[]`, `proxy_turns[]` (logprob records in transparent-proxy +mode), `files{}`, `agent_log_tail`, `proxy_log_tail`, `wall_s`, +`agent_exit_code`, `sandbox_id`, `error`. ## Two Operating Modes | Mode | What it does | Best for | |---|---|---| -| **`transparent_proxy`** (default) | In-sandbox proxy at `localhost:7000` forwards opencode's LLM calls to `base_url`, injects `logprobs=true`, captures per-turn `(messages, completion_tokens, logprobs)` to `proxy_trace.jsonl`. | GRPO / RL training, observability, top-k distillation. | -| **`black_box`** | No proxy. opencode talks straight to `base_url`. | Smoke tests, eval, SFT data collection. | +| **`transparent_proxy`** (default) | OpenCode talks to an in-sandbox proxy. The proxy forwards to `base_url`, requests logprobs, strips them before returning to OpenCode, and records `proxy_turns`. | RL data collection, GRPO-style traces. | +| **`black_box`** | OpenCode talks directly to `base_url`. No logprob capture. | Smoke tests, eval, SFT data collection. | +| **`interception_gate`** | Available through the in-process `OpenCodeSessionFactory`; OpenCode calls are routed through trainer-host interception endpoints. | Trainer-owned generation. | ## Environment Variables @@ -227,21 +227,17 @@ sibling `.env` file; on HF Spaces, set them as **Space secrets**. | **OpenAI endpoint** | | | | `OPENAI_API_KEY` | required for `endpoint="openai"` | Standard OpenAI key. 
| | `OPENAI_BASE_URL` | no | Defaults to `https://api.openai.com/v1`. | -| `OPENAI_MODEL` | no | Defaults to `gpt-4o-mini` (gpt-5.x and o-series refuse logprobs). | +| `OPENAI_MODEL` | no | Defaults to `gpt-4o-mini`. | | **HF Router endpoint** | | | | `HF_ROUTER_API_KEY` | required for `endpoint="hf_router"` | HF user token. | | `HF_ROUTER_BASE_URL` | no | Defaults to `https://router.huggingface.co/v1`. | | `HF_ROUTER_MODEL` | no | Defaults to `Qwen/Qwen3-4B-Instruct-2507:nscale`. | -Pick `provider:` suffixes that actually return logprobs: -**Together / Nscale / Scaleway / SambaNova / Cerebras**. Avoid Novita / -Hyperbolic / Featherless (silent drop) and Groq (HTTP 400). ## Pre-baked E2B Template The first rollout in a fresh E2B sandbox spends ~2 min installing -opencode and the proxy's Python deps. Build a one-time template that -ships those pre-installed: +OpenCode tooling. Build a one-time template that ships it pre-installed: ```bash .venv/bin/python envs/opencode_env/sandbox/build_template.py @@ -263,7 +259,7 @@ opencode_env/ β”œβ”€β”€ __init__.py # re-exports primitive + client + models β”‚ β”œβ”€β”€ client.py # OpenCodeEnv(MCPToolClient) -β”œβ”€β”€ models.py # RolloutResult / RolloutTurn / OpenCodeState +β”œβ”€β”€ models.py # RolloutResult / OpenCodeState β”‚ β”œβ”€β”€ config.py # OpenCodeConfig (primitive) β”œβ”€β”€ harness.py # OpenCodeSession / OpenCodeSessionFactory (CLI-only) @@ -273,17 +269,22 @@ opencode_env/ β”œβ”€β”€ server/ β”‚ β”œβ”€β”€ __init__.py β”‚ β”œβ”€β”€ app.py # FastAPI factory; mounts Gradio at /web -β”‚ β”œβ”€β”€ opencode_environment.py # MCPEnvironment with single ``run_rollout`` tool +β”‚ β”œβ”€β”€ opencode_environment.py # MCPEnvironment with single ``run_rollout`` tool β”‚ β”œβ”€β”€ gradio_ui.py # the /web Gradio Blocks UI β”‚ β”œβ”€β”€ catalog.py # endpoint shorthand resolver β”‚ └── Dockerfile # multi-stage uv build (used by ``openenv build``) β”‚ └── sandbox/ β”œβ”€β”€ __init__.py - β”œβ”€β”€ base.py # SandboxBackend / 
SandboxHandle Protocols - β”œβ”€β”€ e2b.py # E2B implementation - β”œβ”€β”€ interception.py # in-sandbox FastAPI proxy (logprob capture) └── build_template.py # one-time E2B template builder + +# Shared sandbox runtime (moved to core): +src/openenv/core/harness/sandbox/ +β”œβ”€β”€ base.py # SandboxBackend / SandboxHandle protocols +β”œβ”€β”€ e2b_backend.py # E2B implementation +β”œβ”€β”€ docker_backend.py # local Docker backend +β”œβ”€β”€ hf_backend.py # HF sandbox backend +└── _util.py # shared sandbox shell utilities ``` ## References @@ -291,4 +292,4 @@ opencode_env/ - [OpenEnv docs](https://meta-pytorch.org/OpenEnv/) - [OpenCode CLI](https://opencode.ai/docs/cli/) - [E2B Python SDK](https://e2b.dev/docs) -- [HF Inference Providers logprob matrix](../../../DOCS/HF/hf_inference_providers_logprobs.md) + diff --git a/envs/opencode_env/__init__.py b/envs/opencode_env/__init__.py index 223be6f7b..ea72f4fe5 100644 --- a/envs/opencode_env/__init__.py +++ b/envs/opencode_env/__init__.py @@ -8,31 +8,31 @@ Two layers in this package: -1. **Harness primitive** β€” :class:`OpenCodeSessionFactory` / +1. **Harness primitive** -- :class:`OpenCodeSessionFactory` / :class:`OpenCodeSession` / :class:`OpenCodeConfig` / - :class:`E2BSandboxBackend`. Used in-process to drive one rollout - inside an E2B sandbox. See ``harness.py``. + :class:`E2BSandboxBackend`. Built on the generic + :class:`CLIAgentDriver` from ``openenv.core.harness.agents``. -2. **Deployable env** β€” :class:`OpenCodeEnv` (MCP client) talks to the +2. **Deployable env** -- :class:`OpenCodeEnv` (MCP client) talks to the FastAPI server at ``server/app.py`` over HTTP. Use this when the - sandbox + agent live behind an HTTP boundary (e.g. an HF Space). + sandbox + OpenCode live behind an HTTP boundary (e.g. an HF Space). See ``client.py`` and ``server/``. 
""" from openenv.core.env_server.mcp_types import CallToolAction, ListToolsAction +from openenv.core.harness.sandbox import SandboxBackend, SandboxHandle from .client import OpenCodeEnv from .config import OpenCodeConfig, Provider from .harness import OpenCodeSession, OpenCodeSessionFactory -from .models import ( - CommandResult, - OpenCodeState, - RolloutResult, - RolloutTurn, -) -from .sandbox import E2BSandboxBackend, SandboxBackend, SandboxHandle +from .models import CommandResult, OpenCodeState, RolloutResult, RolloutTurn from .task import OpenCodeTask +try: + from openenv.core.harness.sandbox import E2BSandboxBackend +except ImportError: # e2b not installed + E2BSandboxBackend = None # type: ignore[assignment,misc] + __all__ = [ # Deployed-env client "OpenCodeEnv", diff --git a/envs/opencode_env/client.py b/envs/opencode_env/client.py index a00afc4e1..e11599b5e 100644 --- a/envs/opencode_env/client.py +++ b/envs/opencode_env/client.py @@ -7,13 +7,14 @@ """Client for the deployed opencode_env server. The server exposes a single MCP tool ``run_rollout`` that runs one OpenCode -rollout in an E2B sandbox and returns a JSON-serialized :class:`RolloutResult`. +rollout in an E2B sandbox and returns a JSON-serialized +:class:`RolloutResult`. Example:: from opencode_env import OpenCodeEnv - with OpenCodeEnv(base_url="https://adithya-sk-opencode-env.hf.space") as env: + with OpenCodeEnv(base_url="https://your-space.hf.space") as env: env.reset() result = env.run_rollout( base_url="https://api.openai.com/v1", @@ -24,7 +25,7 @@ verify=["python /home/user/test.py"], task_id="binary_search_v1", ) - print(result.reward, len(result.proxy_turns)) + print(result.reward) """ from __future__ import annotations @@ -50,8 +51,8 @@ class OpenCodeEnv(MCPToolClient): def run_rollout( self, *, - # Endpoint β€” pass either the shorthand selector OR explicit fields. - endpoint: str = "", # "vllm" | "openai" | "hf_router" + # Endpoint β€” pass either shorthand endpoint or explicit fields. 
+ endpoint: str = "", # "vllm" | "openai" | "hf_router" base_url: str = "", api_key: str = "", model: str = "", @@ -68,7 +69,7 @@ def run_rollout( agent_timeout_s: float = 600.0, template: str = "", ) -> RolloutResult: - """Run one OpenCode rollout and return the typed result. + """Run one opencode rollout and return the typed result. Args: base_url: OpenAI-compatible LLM endpoint (with trailing /v1). @@ -77,30 +78,29 @@ def run_rollout( model: Model id understood by the LLM endpoint (e.g. ``"gpt-4o-mini"``, ``"Qwen/Qwen3.5-4B"``, ``"Qwen/Qwen3-4B-Instruct-2507:nscale"``). - instruction: Prompt passed to ``opencode run``. - setup: Bash commands run sequentially **before** the agent starts. + instruction: Prompt passed to OpenCode. + setup: Bash commands run sequentially **before** OpenCode starts. Each command runs in the sandbox; non-zero exit aborts setup. - verify: Bash commands run sequentially **after** the agent exits. + verify: Bash commands run sequentially **after** OpenCode exits. Reward = ``passed_count / total`` unless any command writes a float to ``/home/user/logs/verifier/reward.txt`` (override). task_id: Echoed back in the result for traceability. - mode: ``"transparent_proxy"`` (captures per-token logprobs via - an in-sandbox FastAPI proxy) or ``"black_box"`` (no proxy). + mode: ``"transparent_proxy"`` (default, captures logprobs) or + ``"black_box"`` (OpenCode talks directly to the LLM). disable_thinking: Inject ``chat_template_kwargs.enable_thinking=false`` on forwarded requests. Needed for Qwen3.5 vLLM; harmless on Instruct variants; rejected by OpenAI direct. - max_tokens_cap: Clamp on per-turn ``max_tokens``. OpenCode asks - for ~32k by default; gpt-4o-mini caps at 16k. - top_logprobs: Top-k logprobs requested upstream. HF Router caps - at 5; OpenAI accepts up to 20; vLLM is unbounded. - agent_timeout_s: Hard wall-clock budget for one ``opencode run``. + max_tokens_cap: Clamp on per-turn ``max_tokens``. 
+ top_logprobs: Per-token top-k logprobs requested in + ``transparent_proxy`` mode. + agent_timeout_s: Hard wall-clock budget for one OpenCode run. template: E2B template name (e.g. ``"opencode-rl"``). Empty string uses the default (slow) base image. Returns: - A :class:`RolloutResult` with reward, per-turn logprobs, file - outputs, setup/verify results, and diagnostic tails. + A :class:`RolloutResult` with reward, proxy_turns, file outputs, + setup/verify results, and diagnostic tails. """ raw = self.call_tool( "run_rollout", diff --git a/envs/opencode_env/config.py b/envs/opencode_env/config.py index 57273b9eb..1e1a8b167 100644 --- a/envs/opencode_env/config.py +++ b/envs/opencode_env/config.py @@ -34,9 +34,7 @@ class OpenCodeConfig(BaseModel): # --- OpenCode CLI --------------------------------------------------------- opencode_version: str = "latest" - disabled_tools: list[str] = Field( - default_factory=lambda: ["webfetch", "question"] - ) + disabled_tools: list[str] = Field(default_factory=lambda: ["webfetch", "question"]) enabled_tools: list[str] | None = None system_prompt: str | None = None extra_opencode_json: dict[str, Any] = Field(default_factory=dict) @@ -47,25 +45,25 @@ class OpenCodeConfig(BaseModel): extra_env: dict[str, str] = Field(default_factory=dict) extra_setup_shell: str | None = None + # --- Model behavior -------------------------------------------------------- + # Direct OpenCode config knobs (black_box / interception_gate). + disable_thinking: bool = False + max_tokens_cap: int | None = None + + # --- Transparent-proxy logprob capture ------------------------------------ + # Compatibility knobs for the HTTP env's logprob-capturing mode. The proxy + # requests OpenAI-compatible logprobs upstream, records them, and strips + # them before returning the response to OpenCode. 
+ proxy_max_tokens_cap: int | None = 16384 + proxy_top_logprobs: int = 5 + proxy_disable_thinking: bool = False + # --- Sandbox paths -------------------------------------------------------- # Root directory inside the sandbox where the primitive writes config, # task files, and logs. E2B's default user is ``user`` with home # ``/home/user``. Override when using a root-privileged backend (Docker). sandbox_home: str = "/home/user" - # --- Transparent-proxy tuning -------------------------------------------- - # Cap ``max_tokens`` / ``max_completion_tokens`` on forwarded requests. - # OpenCode defaults to a very large number (~32000) which exceeds some - # provider limits (e.g. gpt-4o-mini = 16384). Only used in - # ``mode="transparent_proxy"``. ``None`` disables the cap. - proxy_max_tokens_cap: int | None = 16384 - # Per-turn top-k logprobs the proxy requests from the upstream. - proxy_top_logprobs: int = 5 - # Disable reasoning/thinking mode for Qwen3 / Qwen3.5 models. Proxy sets - # ``extra_body.chat_template_kwargs.enable_thinking=false`` on forwarded - # requests. Ignored by providers that don't support the field. - proxy_disable_thinking: bool = False - _PROVIDER_NPM = { "openai_compatible": "@ai-sdk/openai-compatible", diff --git a/envs/opencode_env/harness.py b/envs/opencode_env/harness.py index da4410dd4..ca5c294c2 100644 --- a/envs/opencode_env/harness.py +++ b/envs/opencode_env/harness.py @@ -4,84 +4,47 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""OpenCode session factory + session implementation. - -Implements the :class:`ResourceSessionFactory` / :class:`ResourceSession` -contracts from ``openenv.core.harness`` (PR #471). The session wraps one -sandbox running the ``opencode`` CLI agent. - -Two operating modes: - - - ``mode="black_box"`` β€” opencode talks directly to ``config.base_url``. - No proxy, no logprob capture. Use for smoke tests / SFT / eval. 
- - ``mode="transparent_proxy"`` (default) β€” an in-sandbox FastAPI proxy - sits between opencode and the upstream LLM. It injects ``logprobs=true`` - on every request and writes per-turn ``(messages, completion_tokens, - per_token_logps)`` to ``proxy_trace.jsonl`` for GRPO consumption. - -Single driver path: opencode is started as a background subprocess via -``opencode run --format json --dangerously-skip-permissions ...`` and we -poll its exit code. The previous ``opencode serve`` driver was removed β€” -opencode CLI is the only path now. -""" +"""OpenCode session factory + session backed by CLIAgentDriver.""" from __future__ import annotations import json +import queue as _queue_mod import shlex +import uuid from pathlib import Path -from typing import Any, Callable, Literal - -from openenv.core.env_server.mcp_types import Tool -from openenv.core.harness import ( - Message, - ResourceSession, - ResourceSessionFactory, - ToolResult, - VerifyResult, +from typing import Any, Literal + +from openenv.core.harness import ResourceSessionFactory +from openenv.core.harness.agents.cli_driver import ( + CLIAgentDriver, + CLIAgentSession, + Verifier, + build_interception_rollout_url, ) +from openenv.core.harness.agents.interception_server import InterceptionServer +from openenv.core.harness.agents.opencode import OPENCODE_SPEC +from openenv.core.harness.sandbox import BgJob, SandboxBackend, SandboxHandle from .config import OpenCodeConfig from .opencode_runtime import ( agent_log_path, build_env_vars, - build_install_cmd, build_opencode_json, build_run_cmd, - instruction_path, opencode_config_path, - system_prompt_path, ) -from .sandbox.base import BgJob, SandboxBackend, SandboxHandle from .task import OpenCodeTask -# Inside-sandbox proxy paths (Mode B). +# Inside-sandbox transparent proxy paths. 
_PROXY_PORT = 7000 _PROXY_TRACE_PATH = "/home/user/logs/agent/proxy_trace.jsonl" _PROXY_LOG_PATH = "/home/user/logs/agent/proxy.log" - -# Where the proxy source lives on disk (in this repo). Uploaded into the -# sandbox at /home/user/proxy/interception.py before each rollout, unless -# the sandbox was created from a template that already has it baked in. _PROXY_SOURCE_PATH = Path(__file__).parent / "sandbox" / "interception.py" -Verifier = Callable[[SandboxHandle, OpenCodeTask], VerifyResult] - - -class OpenCodeSession(ResourceSession): - """One live OpenCode rollout inside a sandbox. - - The session is created already-running: :meth:`OpenCodeSessionFactory.create` - calls :meth:`start_agent` before returning. Typical usage:: - - session = factory.create(task) - session.wait_for_completion() - result = session.verify([]) - session.close() - """ - +class OpenCodeSession(CLIAgentSession): def __init__( self, *, @@ -90,89 +53,33 @@ def __init__( task: OpenCodeTask, verifier: Verifier | None = None, base_url_override: str | None = None, + agent_bg_job: BgJob | None = None, proxy_trace_path: str | None = None, proxy_bg_job: BgJob | None = None, + interception_server: InterceptionServer | None = None, + interception_rollout_id: str | None = None, + interception_queue: _queue_mod.Queue[str | None] | None = None, ) -> None: - self.sandbox = sandbox - self.config = config - self.task = task - self._verifier = verifier - self._base_url_override = base_url_override - self._bg_job: BgJob | None = None + super().__init__( + spec=OPENCODE_SPEC, + sandbox=sandbox, + task=task, + config=config, + verifier=verifier, + base_url_override=base_url_override, + agent_bg_job=agent_bg_job, + interception_server=interception_server, + interception_rollout_id=interception_rollout_id, + interception_queue=interception_queue, + ) self._proxy_trace_path = proxy_trace_path self._proxy_bg_job = proxy_bg_job - # ------------------------------------------------------------------ - # 
ResourceSession contract (PR #471) - # ------------------------------------------------------------------ - def initial_messages(self) -> list[Message]: - return [{"role": "user", "content": self.task.instruction}] - - def list_tools(self) -> list[Tool]: - # OpenCode owns its own tool loop β€” none are exposed to the harness. - return [] - - def call_tool(self, name: str, arguments: dict[str, Any]) -> ToolResult: - return ToolResult( - error=( - "OpenCodeSession does not expose external tool calls; the " - "CLI agent owns its own tool loop." - ) - ) - - def verify( - self, - transcript: list[Message], - final_state: Any | None = None, - ) -> VerifyResult: - if self._verifier is None: - return VerifyResult(env_reward=None, done=True) - return self._verifier(self.sandbox, self.task) - - def close(self) -> None: - if self._bg_job is not None: - try: - self._bg_job.kill() - except Exception: - pass - self._bg_job = None - if self._proxy_bg_job is not None: - try: - self._proxy_bg_job.kill() - except Exception: - pass - self._proxy_bg_job = None - self.sandbox.kill() - - # ------------------------------------------------------------------ - # OpenCode-specific session API - # ------------------------------------------------------------------ - def start_agent(self) -> None: - """Launch ``opencode run`` as a background subprocess in the sandbox.""" - if self._bg_job is not None: - return - cmd = build_run_cmd(self.config) - envs = build_env_vars(self.config, base_url_override=self._base_url_override) - self._bg_job = self.sandbox.start_bg(cmd, envs=envs) - - def wait_for_completion(self, timeout_s: float | None = None) -> int: - """Block until the agent exits, returning its exit code.""" - budget = timeout_s if timeout_s is not None else self.config.agent_timeout_s - if self._bg_job is None: - raise RuntimeError("Agent not started; call start_agent() first.") - return self._bg_job.wait(timeout=budget) - def fetch_trace(self) -> str: - """Return the raw ``opencode run`` 
log (JSON-lines when ``run_format=json``).""" return self.sandbox.read_text(agent_log_path(self.config)) def fetch_proxy_trace(self) -> list[dict[str, Any]]: - """Return per-turn proxy-captured records (Mode B only). - - Each entry has ``request``, ``response``, ``completion_tokens``, - ``completion_token_ids``, ``per_token_logps``, ``finish_reason``, - and ``latency_s``. Returns ``[]`` in Mode A. - """ + """Return per-turn proxy-captured records (transparent_proxy only).""" if self._proxy_trace_path is None: return [] try: @@ -187,33 +94,62 @@ def fetch_proxy_trace(self) -> list[dict[str, Any]]: records.append(json.loads(line)) return records + def close(self) -> None: + if self._proxy_bg_job is not None: + try: + self._proxy_bg_job.kill() + except Exception: + pass + self._proxy_bg_job = None + super().close() -class OpenCodeSessionFactory(ResourceSessionFactory): - """Produce isolated per-rollout :class:`OpenCodeSession` instances. + def wait_for_completion(self, timeout_s: float | None = None) -> int: + budget = timeout_s if timeout_s is not None else self.config.agent_timeout_s + if self._agent_bg_job is None: + raise RuntimeError("Agent not started.") + return self._agent_bg_job.wait(timeout=budget) + + def start_agent(self) -> None: + if self._agent_bg_job is not None: + return + cmd = build_run_cmd(self.config) + envs = build_env_vars(self.config, base_url_override=self._base_url_override) + self._agent_bg_job = self.sandbox.start_bg(cmd, envs=envs) - The factory owns sandbox provisioning, opencode install, config injection, - and (Mode B) proxy startup. Each :meth:`create` call returns a fresh - sandbox with a running agent. 
- """ +class OpenCodeSessionFactory(ResourceSessionFactory): def __init__( self, *, config: OpenCodeConfig, sandbox_backend: SandboxBackend, - mode: Literal["black_box", "transparent_proxy"] = "black_box", + mode: Literal[ + "black_box", "transparent_proxy", "interception_gate" + ] = "transparent_proxy", verifier: Verifier | None = None, install_timeout_s: int = 240, setup_timeout_s: int = 300, + interception_server: InterceptionServer | None = None, + interception_base_url: str | None = None, ) -> None: - if mode not in {"black_box", "transparent_proxy"}: + if mode not in {"black_box", "transparent_proxy", "interception_gate"}: raise ValueError(f"Unknown mode: {mode!r}") self._config = config self._backend = sandbox_backend self._mode = mode self._verifier = verifier - self._install_timeout_s = install_timeout_s - self._setup_timeout_s = setup_timeout_s + driver_mode: Literal["black_box", "interception_gate"] = ( + "black_box" if mode == "transparent_proxy" else mode + ) + self._driver = CLIAgentDriver( + spec=OPENCODE_SPEC, + sandbox_backend=sandbox_backend, + mode=driver_mode, + install_timeout_s=install_timeout_s, + setup_timeout_s=setup_timeout_s, + interception_server=interception_server, + interception_base_url=interception_base_url, + ) def create( self, @@ -222,24 +158,24 @@ def create( episode_id: str | None = None, ) -> OpenCodeSession: import logging - _log = logging.getLogger(__name__) + _log = logging.getLogger(__name__) oc_task = OpenCodeTask.coerce(task) - sandbox_timeout = int(self._config.agent_timeout_s) + 300 + setup_parts: list[str] = [] + if self._config.extra_setup_shell: + setup_parts.append(self._config.extra_setup_shell) + if oc_task.setup_shell: + setup_parts.append(oc_task.setup_shell) + if setup_parts: + oc_task = oc_task.model_copy( + update={"setup_shell": "set -e\n" + "\n".join(setup_parts)} + ) - _log.info( - "factory.create: creating sandbox timeout=%ds mode=%s", - sandbox_timeout, self._mode, - ) + sandbox_timeout = 
int(self._config.agent_timeout_s) + 300 sandbox = self._backend.create( timeout_s=sandbox_timeout, metadata={"episode_id": episode_id} if episode_id else None, ) - sid = ( - getattr(sandbox, "sandbox_id", None) - or getattr(getattr(sandbox, "raw", None), "sandbox_id", "?") - ) - _log.info("factory.create: sandbox=%s β€” bootstrapping…", sid) try: self._bootstrap_sandbox(sandbox, oc_task) except Exception as exc: @@ -248,205 +184,81 @@ def create( raise base_url_override: str | None = None + interception_rollout_id: str | None = None + interception_queue: _queue_mod.Queue[str | None] | None = None proxy_trace_path: str | None = None proxy_bg_job: BgJob | None = None - if self._mode == "transparent_proxy": - _log.info( - "factory.create: starting interception proxy on :%d β†’ %s", - _PROXY_PORT, self._config.base_url, + + if self._mode == "interception_gate": + interception_server = self._driver._interception_server + if interception_server is None: + raise RuntimeError( + "interception_gate mode requires an InterceptionServer" + ) + interception_base_url = self._driver._interception_base_url + if interception_base_url is None: + raise RuntimeError( + "interception_gate mode requires interception_base_url" + ) + rollout_id = episode_id or f"rollout_{uuid.uuid4().hex[:8]}" + interception_rollout_id = rollout_id + interception_queue = interception_server.register_rollout(rollout_id) + base_url_override = build_interception_rollout_url( + interception_base_url, + rollout_id, ) + elif self._mode == "transparent_proxy": proxy_bg_job, base_url_override, proxy_trace_path = self._start_proxy( sandbox ) - _log.info("factory.create: proxy up at %s", base_url_override) - # Rewrite opencode.json so opencode points at the proxy. Force - # ``openai_compatible`` so opencode hits ``/v1/chat/completions`` - # (which the proxy serves) rather than provider-specific paths. 
- from .config import OpenCodeConfig as _OCC - - proxy_cfg = _OCC( - **{ - **self._config.model_dump(), + + run_config = self._config + if base_url_override is not None: + api_key = self._config.api_key + if self._mode == "interception_gate": + assert self._driver._interception_server is not None + api_key = self._driver._interception_server.secret + run_config = self._config.model_copy( + update={ "provider": "openai_compatible", "base_url": base_url_override, + "api_key": api_key, } ) - sandbox.write_text( - opencode_config_path(self._config), - build_opencode_json(proxy_cfg), - ) + sandbox.write_text( + opencode_config_path(self._config), + build_opencode_json(run_config), + ) + agent_bg_job = self._driver._start_agent( + sandbox, + oc_task, + run_config, + base_url_override=base_url_override, + ) - session = OpenCodeSession( + return OpenCodeSession( sandbox=sandbox, - config=self._config, + config=run_config, task=oc_task, verifier=self._verifier, base_url_override=base_url_override, + agent_bg_job=agent_bg_job, proxy_trace_path=proxy_trace_path, proxy_bg_job=proxy_bg_job, + interception_server=self._driver._interception_server, + interception_rollout_id=interception_rollout_id, + interception_queue=interception_queue, ) - session.start_agent() - return session - - # ------------------------------------------------------------------ - def _wait_for_sandbox_ready( - self, - sandbox: SandboxHandle, - *, - attempts: int = 15, - delay_s: float = 1.0, - ) -> None: - """Probe the sandbox until ``echo ok`` succeeds. - - E2B (and other backends) sometimes return the handle before the - guest is fully ready. Issue ``echo ok`` with short timeouts until - it succeeds. Returns silently on success; raises ``RuntimeError`` - on prolonged failure. 
- """ - import time - - last_err = "" - for _ in range(attempts): - try: - r = sandbox.exec("echo ok", timeout=5) - if r.exit_code == 0 and "ok" in (r.stdout or ""): - return - last_err = (r.stderr or r.stdout or "").strip() or f"exit={r.exit_code}" - except Exception as exc: # noqa: BLE001 - last_err = f"{type(exc).__name__}: {exc}" - time.sleep(delay_s) - raise RuntimeError( - f"sandbox did not become ready within {attempts * delay_s:.0f}s " - f"(last error: {last_err})" - ) - - def _exec_with_retry( - self, - sandbox: SandboxHandle, - cmd: str, - *, - timeout: float, - attempts: int = 3, - backoff_s: float = 3.0, - label: str = "cmd", - ): - """Run ``sandbox.exec`` with exponential backoff on transient failure. - - Transient = ``exit_code != 0`` AND empty stderr (SIGKILL / network - blip signature) OR an exception during exec. Final failure is raised - as ``RuntimeError`` carrying the last exit code + stderr. - """ - import time - - last_stdout = "" - last_stderr = "" - last_exit = 0 - for i in range(attempts): - try: - r = sandbox.exec(cmd, timeout=timeout) - if r.exit_code == 0: - return r - last_stdout = r.stdout or "" - last_stderr = r.stderr or "" - last_exit = r.exit_code - if last_stderr.strip(): - break - except Exception as exc: # noqa: BLE001 - last_stderr = f"{type(exc).__name__}: {exc}" - last_exit = -1 - if i + 1 < attempts: - time.sleep(backoff_s * (2**i)) - raise RuntimeError( - f"{label} failed after {attempts} attempts " - f"(exit={last_exit}, stderr={last_stderr!r}, stdout_tail={last_stdout[-400:]!r})" - ) - - def _opencode_already_installed(self, sandbox: SandboxHandle) -> bool: - """Cheap probe β€” returns True if opencode is on disk in the sandbox. - - Used to skip the slow ``curl install`` step when running against a - prebaked template that already ships opencode. 
- """ - try: - r = sandbox.exec( - "/home/user/.opencode/bin/opencode --version", - timeout=10, - ) - return r.exit_code == 0 - except Exception: - return False - - def _bootstrap_sandbox( - self, - sandbox: SandboxHandle, - task: OpenCodeTask, - ) -> None: - """Install opencode, write config + task files, run optional setup.""" - - # Stage 1: wait for the sandbox to be responsive. - self._wait_for_sandbox_ready(sandbox) - - # Stage 2: install opencode (skipped if a prebaked template already - # has it). curl|bash is flaky β€” retry with backoff. - if not self._opencode_already_installed(sandbox): - self._exec_with_retry( - sandbox, - build_install_cmd(self._config), - timeout=self._install_timeout_s, - attempts=3, - backoff_s=3.0, - label="opencode install", - ) - - sandbox.write_text( - opencode_config_path(self._config), - build_opencode_json(self._config), - ) - sandbox.write_text(instruction_path(self._config), task.instruction) - - if self._config.system_prompt: - sandbox.write_text( - system_prompt_path(self._config), - self._config.system_prompt, - ) - - for remote_path, content in task.upload_files.items(): - sandbox.write_text(remote_path, content) - - if self._config.extra_setup_shell: - self._exec_with_retry( - sandbox, - self._config.extra_setup_shell, - timeout=self._setup_timeout_s, - attempts=2, - backoff_s=2.0, - label="extra_setup_shell", - ) - - if task.setup_shell: - r = sandbox.exec(task.setup_shell, timeout=self._setup_timeout_s) - if r.exit_code != 0: - raise RuntimeError( - f"task.setup_shell failed ({r.exit_code}): {r.stderr}" - ) def _start_proxy( self, sandbox: SandboxHandle, ) -> tuple[BgJob, str, str]: - """Install proxy deps + start the proxy as a bg job inside the sandbox. - - Returns ``(proxy_bg_job, base_url_override, proxy_trace_path)``. - Skips the pip install + source-upload steps when the prebaked - template already has them in place. 
- """ - proxy_already_present = sandbox.exists( - "/home/user/proxy/interception.py" - ) + """Start the in-sandbox logprob-capturing proxy.""" + proxy_already_present = sandbox.exists("/home/user/proxy/interception.py") if not proxy_already_present: - # Install proxy deps (idempotent on retries). - self._exec_with_retry( + self._driver._exec_with_retry( sandbox, "pip install --quiet 'fastapi>=0.104' 'uvicorn[standard]>=0.24' " "'httpx>=0.27' 2>&1 | tail -20", @@ -455,7 +267,6 @@ def _start_proxy( backoff_s=2.0, label="proxy deps install", ) - # Upload the proxy module into the sandbox. sandbox.write_text( "/home/user/proxy/interception.py", _PROXY_SOURCE_PATH.read_text(), @@ -480,8 +291,6 @@ def _start_proxy( ) if self._config.proxy_disable_thinking: proxy_args.append("--disable-thinking") - # Force the upstream model id on every forwarded request β€” opencode's - # internal title-gen call sometimes strips the provider prefix. if self._config.model: proxy_args.extend(["--model-override", self._config.model]) @@ -494,8 +303,6 @@ def _start_proxy( proxy_env = {"OPENCODE_UPSTREAM_API_KEY": self._config.api_key} proxy_job = sandbox.start_bg(proxy_cmd, envs=proxy_env) - # Wait for the proxy to start listening. Cold uvicorn boot inside - # E2B can take anywhere from <1s to ~30s depending on cache state. 
import time attempts = 120 @@ -523,6 +330,9 @@ def _start_proxy( base_url_override = f"http://127.0.0.1:{_PROXY_PORT}/v1" return proxy_job, base_url_override, _PROXY_TRACE_PATH + def _bootstrap_sandbox(self, sandbox: SandboxHandle, task: OpenCodeTask) -> None: + self._driver.bootstrap_sandbox(sandbox, task, self._config) + __all__ = [ "OpenCodeSession", diff --git a/envs/opencode_env/models.py b/envs/opencode_env/models.py index b218c5f78..d2b023839 100644 --- a/envs/opencode_env/models.py +++ b/envs/opencode_env/models.py @@ -21,7 +21,7 @@ class RolloutTurn(BaseModel): - """One intercepted LLM turn captured by the in-sandbox proxy (Mode B).""" + """One intercepted LLM turn captured by transparent-proxy mode.""" turn: int finish_reason: str | None = None @@ -35,21 +35,22 @@ class RolloutTurn(BaseModel): class CommandResult(BaseModel): - """Outcome of one bash command in setup/verify.""" + """Outcome of one bash command in setup/verify. + + When ``exit_code`` is ``None``, the command ran during sandbox bootstrap + and its individual exit code was not captured (bootstrap succeeds or fails + atomically). + """ cmd: str - exit_code: int + exit_code: int | None = None stdout: str = "" stderr: str = "" duration_s: float = 0.0 class RolloutResult(BaseModel): - """Full payload returned from one ``run_rollout`` invocation. - - The trainer (or any client) decodes this from the MCP tool result JSON - and feeds ``proxy_turns`` + ``reward`` into GRPO. 
- """ + """Full payload returned from one ``run_rollout`` invocation.""" # Identifiers task_id: str = "" @@ -59,13 +60,13 @@ class RolloutResult(BaseModel): reward: float | None = None agent_exit_code: int | None = None wall_s: float = 0.0 - mode: str = "transparent_proxy" + mode: str = "black_box" # Per-step results setup_results: list[CommandResult] = Field(default_factory=list) verify_results: list[CommandResult] = Field(default_factory=list) - # Per-turn LLM trajectory (empty in black_box mode) + # Per-turn LLM trajectory (empty outside transparent_proxy mode) proxy_turns: list[RolloutTurn] = Field(default_factory=list) # Filesystem the agent produced (path -> contents, truncated) diff --git a/envs/opencode_env/opencode_runtime.py b/envs/opencode_env/opencode_runtime.py index 07fd5322d..0f1484e3a 100644 --- a/envs/opencode_env/opencode_runtime.py +++ b/envs/opencode_env/opencode_runtime.py @@ -52,6 +52,12 @@ def build_opencode_json(config: OpenCodeConfig) -> str: """ provider_name = "intercepted" + model_key = config.model.split("/", 1)[-1] + + model_block: dict[str, Any] = {"name": "Intercepted Model"} + if config.max_tokens_cap is not None: + model_block["limit"] = {"output": config.max_tokens_cap} + provider_block: dict[str, Any] = { "npm": provider_npm_package(config.provider), "name": "Intercepted", @@ -61,16 +67,21 @@ def build_opencode_json(config: OpenCodeConfig) -> str: "timeout": config.request_timeout_ms, }, "models": { - config.model.split("/", 1)[-1]: {"name": "Intercepted Model"}, + model_key: model_block, }, } doc: dict[str, Any] = { "$schema": "https://opencode.ai/config.json", - "model": f"{provider_name}/{config.model.split('/', 1)[-1]}", + "model": f"{provider_name}/{model_key}", "provider": {provider_name: provider_block}, } + # Disable thinking/reasoning tokens when requested. AI SDK respects + # the top-level "reasoning" key to control reasoning token generation. 
+ if config.disable_thinking: + doc["reasoning"] = "none" + tools = _build_tools_block(config) if tools: doc["tools"] = tools @@ -111,7 +122,9 @@ def build_run_cmd(config: OpenCodeConfig) -> str: ).strip() -def build_env_vars(config: OpenCodeConfig, *, base_url_override: str | None = None) -> dict[str, str]: +def build_env_vars( + config: OpenCodeConfig, *, base_url_override: str | None = None +) -> dict[str, str]: """Return env vars to set on the OpenCode process. When a proxy is wrapping ``config.base_url`` the factory passes the proxy's diff --git a/envs/opencode_env/pyproject.toml b/envs/opencode_env/pyproject.toml index 50337baa2..a72ade07d 100644 --- a/envs/opencode_env/pyproject.toml +++ b/envs/opencode_env/pyproject.toml @@ -11,7 +11,7 @@ build-backend = "setuptools.build_meta" [project] name = "openenv-opencode-env" version = "0.1.0" -description = "OpenCode coding-agent environment for OpenEnv — runs the OpenCode CLI in an E2B sandbox against any OpenAI-compatible LLM, optionally capturing per-token logprobs." +description = "OpenCode environment for OpenEnv — runs the OpenCode CLI in an E2B sandbox against OpenAI-compatible LLM endpoints." requires-python = ">=3.10" dependencies = [ # Core OpenEnv (server + MCP). 0.3.0 ships the harness runtime. @@ -26,7 +26,7 @@ dependencies = [ # behavior drift on Space rebuilds.
"gradio>=6.0.0", - # OpenCode harness primitive β€” sandbox + proxy + agent driver + # OpenCode harness primitive β€” sandbox + agent driver "httpx>=0.27.0", "e2b>=1.0.0", ] @@ -48,9 +48,8 @@ packages = [ "opencode_env", "opencode_env.sandbox", "opencode_env.server", - "opencode_env.tests", ] -package-dir = { "opencode_env" = ".", "opencode_env.sandbox" = "sandbox", "opencode_env.server" = "server", "opencode_env.tests" = "tests" } +package-dir = { "opencode_env" = ".", "opencode_env.sandbox" = "sandbox", "opencode_env.server" = "server" } [tool.setuptools.package-data] opencode_env = ["**/*.md"] diff --git a/envs/opencode_env/sandbox/__init__.py b/envs/opencode_env/sandbox/__init__.py index 321f81547..8a2477104 100644 --- a/envs/opencode_env/sandbox/__init__.py +++ b/envs/opencode_env/sandbox/__init__.py @@ -4,50 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Sandbox backends for the OpenCode harness. +"""Sandbox backends live in ``openenv.core.harness.sandbox``. -The primitive ships with :class:`E2BSandboxBackend` as the default; any backend -that satisfies the :class:`SandboxBackend` / :class:`SandboxHandle` protocols -can be swapped in. - -The ``e2b`` import is wrapped in ``try/except`` so this package can be loaded -in environments where ``e2b`` isn't installed (CI smoke tests, lint runs). -Instantiating ``E2BSandboxBackend`` without ``e2b`` raises a clear error. +This package exists only for the ``build_template`` helper used by E2B +template builds. Import sandbox protocols and backends from +``openenv.core.harness.sandbox`` directly. """ - -from .base import BgJob, ExecResult, SandboxBackend, SandboxHandle - -try: - from .e2b import E2BBgJob, E2BSandboxBackend, E2BSandboxHandle # noqa: F401 -except ImportError as _e2b_err: # pragma: no cover - - class _RequiresE2B: - """Stub raised when ``e2b`` is not installed. 
- - Lets the package import cleanly so unit tests, ``openenv validate``, - and the docs build can run without the heavy ``e2b`` dependency. - Actually constructing one of these classes raises a clear ImportError. - """ - - _e2b_import_error = _e2b_err - - def __init__(self, *_args, **_kwargs): - raise ImportError( - "e2b is not installed; install it via " - "`pip install 'openenv-opencode-env[dev]'` or " - "`pip install e2b` to use E2BSandboxBackend. " - f"Original import error: {self._e2b_import_error}" - ) - - E2BBgJob = E2BSandboxBackend = E2BSandboxHandle = _RequiresE2B # type: ignore[assignment] - - -__all__ = [ - "BgJob", - "ExecResult", - "SandboxBackend", - "SandboxHandle", - "E2BBgJob", - "E2BSandboxBackend", - "E2BSandboxHandle", -] diff --git a/envs/opencode_env/sandbox/build_template.py b/envs/opencode_env/sandbox/build_template.py index 01c32d537..67cf0756d 100644 --- a/envs/opencode_env/sandbox/build_template.py +++ b/envs/opencode_env/sandbox/build_template.py @@ -4,35 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Build a pre-baked E2B template with opencode + proxy deps already installed. 
- -Run-time per rollout drops from ~3 min (cold install) to ~30s once the -template is built, because we skip: - - - ``curl https://opencode.ai/install | bash`` (~30-90s) - - ``pip install fastapi uvicorn httpx`` (~30-60s) - - directory layout setup - - copying the proxy source - -The template ships: - - - opencode CLI at ``/home/user/.opencode/bin/opencode`` - - Python deps for the in-sandbox proxy - - The proxy source at ``/home/user/proxy/interception.py`` - - Pre-created dirs: ``~/.config/opencode``, ``~/logs/{agent,verifier}``, - ``~/task``, ``~/workdir``, ``~/proxy`` - - Default workdir: ``/home/user/workdir`` - -Usage:: - - .venv/bin/python envs/opencode_env/tests/build_e2b_template.py - # → builds (or rebuilds) ``opencode-rl`` template, prints template id - -Then ``test_five_sorts_e2e.py`` will use it via ``--template opencode-rl``. - -Requires ``E2B_API_KEY`` in the environment. First build is ~3-8 min; -subsequent builds reuse the cache and can finish in <60s. -""" +"""Build a pre-baked E2B template with opencode already installed.""" from __future__ import annotations @@ -41,11 +13,9 @@ import sys from pathlib import Path -from e2b import Template, default_build_logger - +from e2b import default_build_logger, Template -_ENV_DIR = Path(__file__).resolve().parent -_PROXY_SOURCE = _ENV_DIR / "interception.py" +_REPO_ROOT = Path(__file__).resolve().parents[3] def _load_env(path: Path) -> None: @@ -63,25 +33,9 @@ def _load_env(path: Path) -> None: def build_template(name: str, *, skip_cache: bool = False) -> str: - if not _PROXY_SOURCE.exists(): - raise RuntimeError(f"proxy source missing at {_PROXY_SOURCE}") - - # Template.copy() resolves relative paths against the caller's source - # file directory. This script lives next to ``interception.py`` so the - # bare filename works. - - # Stage 1 (root): system-wide pip deps for the proxy. - # Stage 2 (user): opencode install + dir layout + proxy copy.
template = ( Template() .from_python_image("3.12") - .pip_install( - [ - "fastapi>=0.104", - "uvicorn[standard]>=0.24", - "httpx>=0.27", - ] - ) .set_user("user") .run_cmd("curl -fsSL https://opencode.ai/install | bash") .run_cmd("/home/user/.opencode/bin/opencode --version") @@ -90,13 +44,10 @@ def build_template(name: str, *, skip_cache: bool = False) -> str: .make_dir("/home/user/logs/verifier") .make_dir("/home/user/task") .make_dir("/home/user/workdir") - .make_dir("/home/user/proxy") - .copy("interception.py", "/home/user/proxy/interception.py") .set_workdir("/home/user/workdir") ) if skip_cache: template = template.skip_cache() - info = Template.build( template, name, @@ -109,32 +60,15 @@ def build_template(name: str, *, skip_cache: bool = False) -> str: def main(argv: list[str] | None = None) -> int: p = argparse.ArgumentParser(prog="build_e2b_template") - p.add_argument( - "--name", - default="opencode-rl", - help="Template name (default: opencode-rl).", - ) - p.add_argument( - "--skip-cache", - action="store_true", - help="Force a clean rebuild, ignoring cache.", - ) + p.add_argument("--name", default="opencode-rl") + p.add_argument("--skip-cache", action="store_true") args = p.parse_args(argv) - - _load_env(_ENV_DIR / ".env") + _load_env(_REPO_ROOT / "envs" / "opencode_env" / "sandbox" / ".env") if not os.environ.get("E2B_API_KEY"): print("ERROR: E2B_API_KEY required.", file=sys.stderr) return 2 - - print(f"Building template '{args.name}' " - f"(proxy source: {_PROXY_SOURCE})") - print(f"Skip cache: {args.skip_cache}") - print() - template_id = build_template(args.name, skip_cache=args.skip_cache) - print() print(f"Built. 
Template id/name: {template_id}") - print(f"Use in code: Sandbox.create(template='{args.name}')") return 0 diff --git a/envs/opencode_env/server/app.py b/envs/opencode_env/server/app.py index 200c7f2d7..0757ef229 100644 --- a/envs/opencode_env/server/app.py +++ b/envs/opencode_env/server/app.py @@ -56,19 +56,13 @@ def _load_env_file() -> None: try: from openenv.core.env_server.http_server import create_app - from openenv.core.env_server.mcp_types import ( - CallToolAction, - CallToolObservation, - ) + from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation from .gradio_ui import opencode_gradio_builder from .opencode_environment import OpenCodeEnvironment except ImportError: # pragma: no cover from openenv.core.env_server.http_server import create_app - from openenv.core.env_server.mcp_types import ( - CallToolAction, - CallToolObservation, - ) + from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation from server.gradio_ui import opencode_gradio_builder # type: ignore from server.opencode_environment import OpenCodeEnvironment # type: ignore diff --git a/envs/opencode_env/server/gradio_ui.py b/envs/opencode_env/server/gradio_ui.py index 79a696d75..bb4340aef 100644 --- a/envs/opencode_env/server/gradio_ui.py +++ b/envs/opencode_env/server/gradio_ui.py @@ -19,8 +19,7 @@ agent_timeout_s, template). - Preset buttons for the ready-made example tasks. - Run button β†’ result panel with reward, setup/verify per-command - results, file outputs, logprob stats, agent + proxy log tails, - and the raw RolloutResult JSON. + results, file outputs, proxy/OpenCode log tails, and the raw RolloutResult JSON. 
""" from __future__ import annotations @@ -31,10 +30,14 @@ import gradio as gr try: - from .catalog import ENDPOINT_KINDS, catalog_summary, resolve_endpoint + from .catalog import catalog_summary, ENDPOINT_KINDS, resolve_endpoint from .opencode_environment import OpenCodeEnvironment except ImportError: # pragma: no cover - from server.catalog import ENDPOINT_KINDS, catalog_summary, resolve_endpoint # type: ignore + from server.catalog import ( # type: ignore + catalog_summary, + ENDPOINT_KINDS, + resolve_endpoint, + ) from server.opencode_environment import OpenCodeEnvironment # type: ignore @@ -144,56 +147,14 @@ def _command_rows(items: list[dict[str, Any]]) -> list[list[str]]: cmd if len(cmd) <= 80 else cmd[:77] + "...", str(it.get("exit_code", "")), f"{it.get('duration_s', 0):.2f}s", - (it.get("stderr") or "").splitlines()[-1][:80] if it.get("exit_code") else "", + (it.get("stderr") or "").splitlines()[-1][:80] + if it.get("exit_code") + else "", ] ) return rows -def _logprobs_md(turns: list[dict[str, Any]]) -> str: - if not turns: - return "_No proxy turns captured._\n\nThis is normal in `black_box` mode. In `transparent_proxy` mode, an empty list usually means the agent never made an LLM call (check the agent log)." 
- n = len(turns) - productive = sum(1 for t in turns if t.get("completion_tokens")) - total_toks = sum(len(t.get("completion_tokens") or []) for t in turns) - all_lps = [ - float(x) - for t in turns - for x in (t.get("per_token_logps") or []) - if x is not None - ] - mean_lp = (sum(all_lps) / len(all_lps)) if all_lps else None - lines = [ - f"**turns**: `{n}` Β· **productive**: `{productive}` Β· " - f"**total_completion_tokens**: `{total_toks}`", - ] - if mean_lp is not None: - lines.append(f"**mean_logprob**: `{mean_lp:+.4f}`") - finishes: dict[str, int] = {} - for t in turns: - f = t.get("finish_reason") or "unknown" - finishes[f] = finishes.get(f, 0) + 1 - if finishes: - lines.append( - "**finish_reasons**: " + " ".join(f"`{k}={v}`" for k, v in finishes.items()) - ) - productive_rows = [t for t in turns if t.get("completion_tokens")] - if productive_rows: - first = productive_rows[0] - toks = first["completion_tokens"][:10] - lps = first.get("per_token_logps") or [] - lines.append( - f"\n**first productive turn (first 10 tokens)**\n\n" - f"```\n" - + "\n".join( - f" {tok!r:<14} {lp:+.3f}" if i < len(lps) else f" {tok!r:<14} -" - for i, (tok, lp) in enumerate(zip(toks, lps + [None] * len(toks))) - ) - + "\n```" - ) - return "\n\n".join(lines) - - def _live_status_md( endpoint_kind: str, model: str, @@ -249,12 +210,12 @@ def _catalog_banner() -> str: def opencode_gradio_builder( - web_manager, # noqa: ARG001 (unused: we instantiate the env directly) - action_fields, # noqa: ARG001 - metadata, # noqa: ARG001 - is_chat_env, # noqa: ARG001 + web_manager, # noqa: ARG001 (unused: we instantiate the env directly) + action_fields, # noqa: ARG001 + metadata, # noqa: ARG001 + is_chat_env, # noqa: ARG001 title, - quick_start_md, # noqa: ARG001 + quick_start_md, # noqa: ARG001 ) -> gr.Blocks: """Build the opencode_env console. @@ -283,9 +244,9 @@ def run( """Generator handler β€” yields incremental UI updates. 
Each ``yield`` is a tuple matching ``outputs=[...]``: - (summary_md, setup_table, verify_table, files_md, logprobs_md, - logs_md, raw_json). Early yields keep summary_md as a live phase - log while the rollout runs; the final yield populates everything. + (summary_md, setup_table, verify_table, files_md, logs_md, + raw_json). Early yields keep summary_md as a live phase log while + the rollout runs; the final yield populates everything. """ import queue import threading @@ -299,7 +260,7 @@ def run( ) except ValueError as exc: err = f"endpoint resolution failed: {exc}" - yield (f"### error\n\n```\n{err}\n```", [], [], "", "", "", {"error": err}) + yield (f"### error\n\n```\n{err}\n```", [], [], "", "", {"error": err}) return # Translate "auto" / "on" / "off" into bool / None. @@ -355,7 +316,11 @@ def _worker(): # First yield: announce we've started. Empty result panels. yield ( f"### running…\n\n_endpoint=`{resolved.kind}` model=`{resolved.model}` mode=`{mode}`_", - [], [], "", "", "", {}, + [], + [], + "", + "", + {}, ) status_lines: list[tuple[float, str]] = [] @@ -374,8 +339,14 @@ def _worker(): # Render the live status pane. elapsed = time.time() - t_start - md = _live_status_md(resolved.kind, resolved.model, mode, elapsed, status_lines) - yield (md, [], [], "", "", "", {}) + md = _live_status_md( + resolved.kind, + resolved.model, + mode, + elapsed, + status_lines, + ) + yield (md, [], [], "", "", {}) # Drain any final messages still in the queue. 
while not status_q.empty(): @@ -390,9 +361,16 @@ def _worker(): err = result_holder.get("error", "unknown error") yield ( f"### error\n\n```\n{err}\n```", - [], [], "", "", - _live_status_md(resolved.kind, resolved.model, mode, - time.time() - t_start, status_lines), + [], + [], + "", + _live_status_md( + resolved.kind, + resolved.model, + mode, + time.time() - t_start, + status_lines, + ), {"error": err}, ) return @@ -403,13 +381,17 @@ def _worker(): _command_rows(result.get("setup_results") or []), _command_rows(result.get("verify_results") or []), _files_md(result.get("files") or {}), - _logprobs_md(result.get("proxy_turns") or []), ( - f"### live phase log\n\n" - + _live_status_md(resolved.kind, resolved.model, mode, - time.time() - t_start, status_lines) - + f"\n\n### agent log (tail)\n```\n{result.get('agent_log_tail', '')[:4000]}\n```\n\n" - f"### proxy log (tail)\n```\n{result.get('proxy_log_tail', '')[:4000]}\n```" + "### live phase log\n\n" + + _live_status_md( + resolved.kind, + resolved.model, + mode, + time.time() - t_start, + status_lines, + ) + + f"\n\n### proxy log (tail)\n```\n{result.get('proxy_log_tail', '')[:3000]}\n```" + + f"\n\n### agent log (tail)\n```\n{result.get('agent_log_tail', '')[:4000]}\n```" ), result, ) @@ -422,8 +404,8 @@ def apply_preset(name: str) -> tuple[str, str, str]: gr.Markdown(f"# {title or 'opencode_env'}") gr.Markdown( "Run one OpenCode rollout in an E2B sandbox against your chosen " - "LLM endpoint. Pick an endpoint, write the task as `(instruction, " - "setup, verify)`, and inspect the reward + per-token logprobs." + "LLM endpoint. Pick an endpoint, write the task as " + "`(instruction, setup, verify)`, and inspect reward + logs." 
) gr.Markdown(_catalog_banner()) @@ -436,33 +418,37 @@ def apply_preset(name: str) -> tuple[str, str, str]: scale=1, ) model = gr.Textbox( - label="Model (blank β†’ catalog default)", placeholder="gpt-4o-mini", + label="Model (blank β†’ catalog default)", + placeholder="gpt-4o-mini", scale=2, ) with gr.Row(): base_url = gr.Textbox( label="Base URL (blank β†’ env / catalog default)", - placeholder="https://api.openai.com/v1", scale=2, + placeholder="https://api.openai.com/v1", + scale=2, ) api_key = gr.Textbox( label="API key (blank β†’ server env var)", - placeholder="(server env)", type="password", scale=1, + placeholder="(server env)", + type="password", + scale=1, ) instruction = gr.Textbox( - label="Instruction (the prompt opencode runs)", + label="Instruction (the prompt OpenCode runs)", lines=4, value=PRESETS["binary_search"]["instruction"], ) with gr.Row(): setup_text = gr.Textbox( - label="Setup (one bash command per line β€” runs BEFORE the agent)", + label="Setup (one bash command per line β€” runs BEFORE OpenCode)", lines=5, value=PRESETS["binary_search"]["setup"], ) verify_text = gr.Textbox( - label="Verify (one bash command per line β€” runs AFTER the agent)", + label="Verify (one bash command per line β€” runs AFTER OpenCode)", lines=5, value=PRESETS["binary_search"]["verify"], ) @@ -515,8 +501,6 @@ def apply_preset(name: str) -> tuple[str, str, str]: ) with gr.Tab("Files"): files_md = gr.Markdown("") - with gr.Tab("Logprobs"): - logprobs_md = gr.Markdown("") with gr.Tab("Logs"): logs_md = gr.Markdown("") with gr.Tab("Raw JSON"): @@ -536,14 +520,27 @@ def apply_preset(name: str) -> tuple[str, str, str]: run_btn.click( fn=run, inputs=[ - endpoint, model, base_url, api_key, - instruction, setup_text, verify_text, - mode, disable_thinking, template, - max_tokens_cap, top_logprobs, agent_timeout_s, + endpoint, + model, + base_url, + api_key, + instruction, + setup_text, + verify_text, + mode, + disable_thinking, + template, + max_tokens_cap, + 
top_logprobs, + agent_timeout_s, ], outputs=[ - summary_md, setup_table, verify_table, - files_md, logprobs_md, logs_md, raw_json, + summary_md, + setup_table, + verify_table, + files_md, + logs_md, + raw_json, ], ) diff --git a/envs/opencode_env/server/opencode_environment.py b/envs/opencode_env/server/opencode_environment.py index 07f0d69ed..52ae27b4d 100644 --- a/envs/opencode_env/server/opencode_environment.py +++ b/envs/opencode_env/server/opencode_environment.py @@ -6,22 +6,23 @@ """OpenCode MCP environment. -Single MCP tool ``run_rollout`` that takes a uniform Task shape: +Single MCP tool ``run_rollout`` with a uniform task shape: - - ``instruction`` — prompt for the agent - - ``setup`` — bash commands run BEFORE the agent (in the sandbox) - - ``verify`` — bash commands run AFTER the agent + - ``instruction`` — prompt for OpenCode + - ``setup`` — bash commands run BEFORE OpenCode (in the sandbox) + - ``verify`` — bash commands run AFTER OpenCode Reward = ``passed_verify_commands / total`` unless a verify command writes a float to ``/home/user/logs/verifier/reward.txt`` (override). -Returns a JSON-serialized :class:`RolloutResult` with reward + per-turn -logprobs (Mode B) + setup/verify command results + file outputs. +Returns a JSON-serialized :class:`RolloutResult` with reward, +setup/verify command results, transparent-proxy logprob turns, and file outputs. """ from __future__ import annotations import json +import logging import os import time from typing import Any, Optional @@ -40,7 +41,7 @@ from server.catalog import ENDPOINT_KINDS, resolve_endpoint # type: ignore -# One rollout (sandbox cold start + opencode install + opencode run + +# One rollout (sandbox cold start + OpenCode install + agent run + # verifier) typically takes 30-180s; can spike to ~600s under load. Override # OpenEnv's 30s MCP-tool default so the server doesn't cut us off.
_RUN_ROLLOUT_TIMEOUT_S = 900.0 @@ -49,6 +50,8 @@ HOME = "/home/user" WORKDIR = f"{HOME}/workdir" INSTRUCTION_PATH = f"{HOME}/task/instruction.md" +_log = logging.getLogger(__name__) + REWARD_FILE = f"{HOME}/logs/verifier/reward.txt" PROXY_LOG = f"{HOME}/logs/agent/proxy.log" AGENT_LOG = f"{HOME}/logs/agent/opencode.jsonl" @@ -64,25 +67,27 @@ def __init__(self) -> None: # Lazy imports so module import stays cheap and so tests can patch. try: from ..models import ( - CommandResult, OpenCodeState, + CommandResult, RolloutResult, RolloutTurn, ) except ImportError: # pragma: no cover from models import ( # type: ignore - CommandResult, OpenCodeState, + CommandResult, RolloutResult, RolloutTurn, ) - from opencode_env import ( - E2BSandboxBackend, - OpenCodeConfig, - OpenCodeSessionFactory, - OpenCodeTask, - ) + from opencode_env.config import OpenCodeConfig + from opencode_env.harness import OpenCodeSessionFactory + from opencode_env.task import OpenCodeTask + + try: + from openenv.core.harness.sandbox import E2BSandboxBackend + except ImportError: + E2BSandboxBackend = None # type: ignore[assignment,misc] self._CommandResult = CommandResult self._RolloutResult = RolloutResult @@ -283,78 +288,65 @@ def _emit(msg: str) -> None: _emit(f"resolving config (model={model}, mode={mode})") - # Build OpenCodeConfig + factory. We keep the proxy in charge of - # ``model_override`` / ``logprobs`` / ``max_tokens``-cap injection. 
- config = self._OpenCodeConfig( - provider="openai_compatible", - base_url=base_url.rstrip("/"), + config = self._build_agent_config( + base_url=base_url, api_key=api_key, model=model, agent_timeout_s=agent_timeout_s, - proxy_disable_thinking=disable_thinking, - proxy_top_logprobs=top_logprobs, - proxy_max_tokens_cap=max_tokens_cap if max_tokens_cap > 0 else None, + disable_thinking=disable_thinking, + top_logprobs=top_logprobs, + max_tokens_cap=max_tokens_cap, ) - # Concatenate setup commands into a single ``set -e`` script and let - # the primitive run it as ``task.setup_shell`` before the agent - # starts. The per-command tracking happens here too — we re-run - # each command in a wrapper that captures exit/stdout/stderr. - # That way the primitive still aborts on setup failure AND we get - # observability in the response. - instruction_payload = instruction - opencode_task = self._OpenCodeTask( - instruction=instruction_payload, + # Concatenate setup commands into a single ``set -e`` script so the + # primitive runs them inside _bootstrap_sandbox BEFORE the agent + # starts. This avoids the race where the agent's first tool call + # depends on files or packages that setup is still installing. + setup_shell: str | None = None + if setup: + # ``set -e`` makes the script abort on the first failing command.
+ setup_shell = "set -e\n" + "\n".join(setup) + + rollout_task = self._OpenCodeTask( + instruction=instruction, + setup_shell=setup_shell, metadata={"task_id": task_id}, ) - backend_kwargs: dict[str, Any] = {} - if template: - backend_kwargs["template"] = template - - factory = self._OpenCodeSessionFactory( - config=config, - sandbox_backend=self._E2BSandboxBackend(**backend_kwargs), - mode=mode, - verifier=None, - ) - session = None try: + factory = self._build_session_factory( + config=config, + mode=mode, + template=template, + ) _emit( f"creating E2B sandbox (template={template or 'default'}) β€” " "this is the slow phase (~5–60s cold, ~5s with template)" ) - session = factory.create(task=opencode_task) + session = factory.create(task=rollout_task) result.sandbox_id = session.sandbox.sandbox_id - _emit( - f"sandbox ready: {result.sandbox_id} β€” agent started " - f"({'proxy on :7000, logprobs capturing' if mode == 'transparent_proxy' else 'direct LLM, no logprobs'})" - ) - - # Run setup commands one at a time, *before* the agent starts. - # The factory has already started the agent in start_agent() - # during create(); to keep the order "setup β†’ agent β†’ verify" - # we'd need to restructure. As a pragmatic compromise we run - # setup IMMEDIATELY after create(), which races with the agent - # for ~1-2s but is fine for typical pip/git/download work - # because opencode itself takes >=20s to make its first model - # call. - for i, cmd in enumerate(setup, 1): - _emit(f"setup [{i}/{len(setup)}]: {cmd[:80]}") - cr = self._exec_command(session.sandbox, cmd) - result.setup_results.append(cr) - if cr.exit_code != 0: - result.error = ( - f"setup command failed (exit {cr.exit_code}): {cmd[:120]}" + _emit(f"sandbox ready: {result.sandbox_id} β€” agent started (mode={mode})") + + # setup commands already ran atomically during sandbox bootstrap. 
+ # Avoid re-running them here because many setup scripts are not + # idempotent (e.g., migrations, one-shot installs, destructive prep). + # We still surface per-command bookkeeping for callers. + for cmd in setup: + result.setup_results.append( + self._CommandResult( + cmd=cmd, + exit_code=None, + stdout="executed during bootstrap (individual exit code not captured)", + stderr="", + duration_s=0.0, ) - _emit(f"setup FAILED at [{i}]: exit={cr.exit_code}") - break + ) - # Block until the agent is done (or setup already failed). + # Block until the agent is done. if result.error is None: _emit( - f"agent running β€” opencode CLI in sandbox " + "agent running β€” OpenCode CLI in sandbox " f"(timeout {int(agent_timeout_s)}s)" ) try: @@ -384,12 +376,12 @@ def _emit(msg: str) -> None: else: result.reward = None - # Collect filesystem + proxy trace. - _emit("collecting workdir files + proxy trace + logs") + # Collect filesystem + logs + transparent-proxy trace. + _emit("collecting workdir files + logs") result.files, result.files_extra = self._collect_files(session.sandbox) result.proxy_turns = self._collect_proxy_turns(session) result.proxy_log_tail = self._safe_read(session.sandbox, PROXY_LOG)[-2000:] - result.agent_log_tail = self._safe_read(session.sandbox, AGENT_LOG)[-2000:] + result.agent_log_tail = self._collect_agent_log_tail(session) _emit( f"collected: {len(result.files)} file(s), " f"{len(result.proxy_turns)} proxy turn(s), " @@ -400,7 +392,7 @@ def _emit(msg: str) -> None: _emit(f"ERROR: {result.error}") if session is not None: result.proxy_log_tail = self._safe_read(session.sandbox, PROXY_LOG)[-2000:] - result.agent_log_tail = self._safe_read(session.sandbox, AGENT_LOG)[-2000:] + result.agent_log_tail = self._collect_agent_log_tail(session) finally: if session is not None: try: @@ -420,6 +412,68 @@ def _emit(msg: str) -> None: return result.model_dump_json() + def _build_agent_config( + self, + *, + base_url: str, + api_key: str, + model: str, + 
agent_timeout_s: float, + disable_thinking: bool, + top_logprobs: int, + max_tokens_cap: int, + ) -> Any: + cap = max_tokens_cap if max_tokens_cap > 0 else None + return self._OpenCodeConfig( + provider="openai_compatible", + base_url=base_url.rstrip("/"), + api_key=api_key, + model=model, + agent_timeout_s=agent_timeout_s, + disable_thinking=disable_thinking, + max_tokens_cap=cap, + proxy_disable_thinking=disable_thinking, + proxy_top_logprobs=max(0, int(top_logprobs)), + proxy_max_tokens_cap=cap, + ) + + def _build_session_factory( + self, + *, + config: Any, + mode: str, + template: str, + ) -> Any: + if self._E2BSandboxBackend is None: + raise RuntimeError( + "E2BSandboxBackend unavailable: install optional dependency 'e2b'." + ) + + backend_kwargs: dict[str, Any] = {} + if template: + backend_kwargs["template"] = template + backend = self._E2BSandboxBackend(**backend_kwargs) + + return self._OpenCodeSessionFactory( + config=config, + sandbox_backend=backend, + mode=mode, + verifier=None, + ) + + def _collect_agent_log_tail(self, session: Any) -> str: + if hasattr(session, "collect_artifacts"): + try: + artifacts = session.collect_artifacts() + if isinstance(artifacts, dict) and "agent_log" in artifacts: + val = artifacts["agent_log"] + if isinstance(val, str): + return val[-2000:] + return json.dumps(val, default=str)[-2000:] + except Exception: + pass + return self._safe_read(session.sandbox, AGENT_LOG)[-2000:] + # ── Helpers ──────────────────────────────────────────────────────────── def _exec_command(self, sandbox: Any, cmd: str) -> Any: @@ -450,9 +504,7 @@ def _read_reward(self, sandbox: Any) -> float | None: except ValueError: return None - def _collect_files( - self, sandbox: Any - ) -> tuple[dict[str, str], list[str]]: + def _collect_files(self, sandbox: Any) -> tuple[dict[str, str], list[str]]: listing = sandbox.exec( f"find {WORKDIR} -maxdepth 2 -type f -size -64k 2>/dev/null | head -32", timeout=10, @@ -471,18 +523,9 @@ def _collect_files( def 
_collect_proxy_turns(self, session: Any) -> list[Any]: turns: list[Any] = [] - proxy_trace_path = getattr(session, "_proxy_trace_path", None) - if not proxy_trace_path: + if not hasattr(session, "fetch_proxy_trace"): return turns - raw = self._safe_read(session.sandbox, proxy_trace_path) - for line in raw.splitlines(): - line = line.strip() - if not line: - continue - try: - rec = json.loads(line) - except Exception: - continue + for rec in session.fetch_proxy_trace(): response = rec.get("response") or {} turns.append( self._RolloutTurn( @@ -490,10 +533,7 @@ def _collect_proxy_turns(self, session: Any) -> list[Any]: finish_reason=rec.get("finish_reason"), completion_tokens=list(rec.get("completion_tokens") or []), completion_token_ids=list(rec.get("completion_token_ids") or []), - per_token_logps=[ - float(x) for x in (rec.get("per_token_logps") or []) - if x is not None - ], + per_token_logps=list(rec.get("per_token_logps") or []), latency_s=float(rec.get("latency_s") or 0.0), timestamp=float(rec.get("timestamp") or 0.0), upstream_status=response.get("upstream_status"), diff --git a/envs/opencode_env/uv.lock b/envs/opencode_env/uv.lock index 80dd00ba0..aa802ee9d 100644 --- a/envs/opencode_env/uv.lock +++ b/envs/opencode_env/uv.lock @@ -1663,37 +1663,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/12/cf/03675d8bd8ecbf4445504d8071adab19f5f993676795708e36402ab38263/openapi_pydantic-0.5.1-py3-none-any.whl", hash = "sha256:a3a09ef4586f5bd760a8df7f43028b60cafb6d9f61de2acba9574766255ab146", size = 96381, upload-time = "2025-01-08T19:29:25.275Z" }, ] -[[package]] -name = "openenv-core" -version = "0.2.3" -source = { git = "https://github.com/adithya-s-k/OpenEnv.git?rev=opencode-harness#aabcdbb9d52aa62a842ec69472b2a1106acb831a" } -dependencies = [ - { name = "fastapi" }, - { name = "fastmcp" }, - { name = "gradio" }, - { name = "httpx" }, - { name = "huggingface-hub" }, - { name = "openai" }, - { name = "pydantic" }, - { name = "pyyaml" }, - { name = 
"requests" }, - { name = "rich" }, - { name = "tomli" }, - { name = "tomli-w" }, - { name = "typer" }, - { name = "uvicorn" }, - { name = "websockets" }, -] - -[package.optional-dependencies] -core = [ - { name = "fastapi" }, - { name = "pydantic" }, - { name = "requests" }, - { name = "uvicorn" }, - { name = "websockets" }, -] - [[package]] name = "openenv-opencode-env" version = "0.1.0" @@ -1724,7 +1693,7 @@ requires-dist = [ { name = "fastmcp", specifier = ">=2.0.0" }, { name = "gradio", specifier = ">=6.0.0" }, { name = "httpx", specifier = ">=0.27.0" }, - { name = "openenv-core", extras = ["core"], git = "https://github.com/adithya-s-k/OpenEnv.git?rev=opencode-harness" }, + { name = "openenv-core", extras = ["core"], specifier = ">=0.3.0" }, { name = "pydantic", specifier = ">=2.0.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" }, @@ -1734,6 +1703,41 @@ requires-dist = [ ] provides-extras = ["dev"] +[[package]] +name = "openenv-core" +version = "0.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastapi" }, + { name = "fastmcp" }, + { name = "gradio" }, + { name = "httpx" }, + { name = "huggingface-hub" }, + { name = "openai" }, + { name = "pydantic" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "rich" }, + { name = "tomli" }, + { name = "tomli-w" }, + { name = "typer" }, + { name = "uvicorn" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ce/d6/3bebe8afb55fcc3ea9251c4c2dfbab2879e31089bc91a8fe9696e5ce019b/openenv_core-0.3.0.tar.gz", hash = "sha256:c7fee2035badab5be497eb6f4afb2cb417de000f82cc19afd72fb5ec332c431d", size = 164720, upload-time = "2026-05-11T11:37:57.274Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/f5/aafa43138589bfd5d369a8d02ea365aae9d6fe55ac0b3894368d6d69bd03/openenv_core-0.3.0-py3-none-any.whl", hash = 
"sha256:859e875c9d5211b157c30fb9abc681606fcf0bf1b6ffcdf404678992823a1df0", size = 194313, upload-time = "2026-05-11T11:37:55.537Z" }, +] + +[package.optional-dependencies] +core = [ + { name = "fastapi" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "uvicorn" }, + { name = "websockets" }, +] + [[package]] name = "opentelemetry-api" version = "1.41.1" diff --git a/examples/opencode_env_simple.py b/examples/opencode_env_simple.py index 1713880fb..660421fdd 100644 --- a/examples/opencode_env_simple.py +++ b/examples/opencode_env_simple.py @@ -14,12 +14,9 @@ 1. Spawns a fresh E2B sandbox (using the prebaked ``opencode-rl`` template β€” falls back to a cold install if the template isn't present in your E2B account). - 2. Bootstraps an in-sandbox FastAPI proxy that captures per-token - logprobs (``mode="transparent_proxy"``). - 3. Runs ``opencode run`` with the instruction. - 4. Executes the verify bash commands; reward = passed / total. - 5. Returns a ``RolloutResult`` with reward + per-turn logprobs + - the file contents the agent produced. + 2. Runs OpenCode with the instruction. + 3. Executes the verify bash commands; reward = passed / total. + 4. Returns a ``RolloutResult`` with reward + produced file contents. Prerequisites ------------- @@ -34,7 +31,6 @@ Expected output (~20s with the prebaked template):: reward: 1.0 - turns: 3 files: ['/home/user/workdir/binary_search.py', ...] 
wall: 19.8 s """ @@ -54,7 +50,9 @@ from opencode_env.models import RolloutResult # noqa: E402 -SPACE = os.environ.get("OPENCODE_ENV_SPACE", "https://adithyask-opencode-env.hf.space") +SPACE = os.environ.get( + "OPENCODE_ENV_SPACE", "https://adithyask-opencode-env.hf.space" +) INSTRUCTION = ( "Create a single Python file named `binary_search.py` in the current " @@ -109,8 +107,6 @@ async def main() -> int: print("--- result ---") print(f"reward: {result.reward}") - print(f"turns: {len(result.proxy_turns)}") - print(f"tokens: {sum(len(t.completion_tokens) for t in result.proxy_turns)}") print(f"sandbox: {result.sandbox_id}") print(f"wall_s: {result.wall_s}") print(f"files: {sorted(result.files)}") @@ -118,16 +114,6 @@ async def main() -> int: if result.error: print(f"error: {result.error}") - if result.proxy_turns: - first = next((t for t in result.proxy_turns if t.completion_tokens), None) - if first: - print() - print("--- first productive turn (first 8 tokens with logprobs) ---") - toks = first.completion_tokens[:8] - lps = first.per_token_logps[:8] - for tok, lp in zip(toks, lps): - print(f" {tok!r:<14} {lp:+.3f}") - return 0 if result.reward == 1.0 else 1 diff --git a/pyproject.toml b/pyproject.toml index 08f1bb6d3..e40b79c9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ # Web UI dependencies "gradio>=4.0.0", "httpx>=0.28.1", + "aiohttp>=3.13.5", ] [project.optional-dependencies] diff --git a/src/openenv/core/harness/agents/__init__.py b/src/openenv/core/harness/agents/__init__.py new file mode 100644 index 000000000..b715582a4 --- /dev/null +++ b/src/openenv/core/harness/agents/__init__.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Agent registry and public API for CLI-based agentic harnesses. 
+ +The registry maps agent names (``"opencode"``, ``"claude-code"``, etc.) to +their :class:`CLIAgentSpec` declarations. Each agent module registers itself +via :func:`register_agent` at import time. + +Usage:: + + from openenv.core.harness.agents import get_agent_spec, list_agents + + spec = get_agent_spec("opencode") + print(list_agents()) # ["opencode"] +""" + +from __future__ import annotations + +from .base import ( + AgentConfig, + AgentEvent, + AgentTask, + ArtifactSpec, + CLIAgentSpec, + MCPConfigSpec, +) +from .interception_server import deliver_response, InterceptionServer + +# Registry + +_REGISTRY: dict[str, CLIAgentSpec] = {} + + +def register_agent(spec: CLIAgentSpec) -> None: + """Register a :class:`CLIAgentSpec` under ``spec.name``. + + Raises :class:`ValueError` if the name is already registered with a + *different* spec object (re-registering the same object is a no-op, + which makes ``importlib.reload`` safe). + """ + existing = _REGISTRY.get(spec.name) + if existing is not None and existing is not spec: + raise ValueError( + f"Agent {spec.name!r} is already registered. " + "Use a unique name or call unregister_agent() first." + ) + _REGISTRY[spec.name] = spec + + +def unregister_agent(name: str) -> CLIAgentSpec | None: + """Remove a registered agent spec, returning it (or ``None``).""" + return _REGISTRY.pop(name, None) + + +def get_agent_spec(name: str) -> CLIAgentSpec: + """Look up a registered agent spec by name. + + If ``name`` is not registered yet, the matching built-in module (e.g. + ``openenv.core.harness.agents.opencode``) is imported first so it can + register itself. Raises :class:`KeyError` if the name is still unknown. + """ + if name not in _REGISTRY: + # Auto-import built-in agent modules to trigger registration. + _auto_import(name) + try: + return _REGISTRY[name] + except KeyError: + available = ", ".join(sorted(_REGISTRY)) or "(none)" + raise KeyError( + f"Unknown agent {name!r}.
Registered agents: {available}" + ) from None + + +def list_agents() -> list[str]: + """Return sorted names of all registered agents.""" + return sorted(_REGISTRY) + + +def _auto_import(name: str) -> None: + """Try to import the built-in module for ``name`` to trigger registration.""" + # Map agent names to module names (handles hyphens). + module_name = name.replace("-", "_") + try: + __import__(f"openenv.core.harness.agents.{module_name}", fromlist=["_"]) + except ImportError: + pass + + +# Convenience re-exports + +__all__ = [ + # Registry + "get_agent_spec", + "list_agents", + "register_agent", + "unregister_agent", + # Base types + "AgentConfig", + "AgentEvent", + "AgentTask", + "ArtifactSpec", + "CLIAgentSpec", + "MCPConfigSpec", + # Interception gate + "InterceptionServer", + "deliver_response", +] diff --git a/src/openenv/core/harness/agents/base.py b/src/openenv/core/harness/agents/base.py new file mode 100644 index 000000000..4ec1c297a --- /dev/null +++ b/src/openenv/core/harness/agents/base.py @@ -0,0 +1,256 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Agent spec and event protocols for CLI-based agentic harnesses. + +Defines the declarative :class:`CLIAgentSpec` data model that captures +*everything* a CLI harness needs — install commands, file uploads, MCP +config format, environment variables, artifacts to collect, and three +small callables (command builder, MCP config builder, event parser). + +The :class:`CLIAgentDriver` reads these fields mechanically without knowing +anything about the specific agent. Adding a new agent is filling in a +dataclass, not writing driver code. + +Pattern borrowed from the ``verifiers`` library +(Prime Intellect), where OpenCode, MiniSWEAgent, Pi, and RLM all express +their differences through constructor data passed to ``CLIHarness.__init__()``.
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Literal, Protocol + + +# MCP config injection + + +@dataclass(frozen=True) +class MCPConfigSpec: + """How a harness discovers MCP tools. + + ``method`` controls how the driver injects MCP server configuration: + + - ``"config_file"`` — write a JSON file at ``path_template`` (e.g. + ``"{workdir}/mcp.json"``). The template receives ``{workdir}`` + and ``{home}`` substitutions at runtime. + - ``"cli_flags"`` — the driver passes MCP configuration via CLI + flags built by :attr:`CLIAgentSpec.build_command`. + - ``"settings_file"`` — write into a global settings file (e.g. + ``~/.config/agent/settings.json``). + """ + + method: Literal["config_file", "cli_flags", "settings_file"] + path_template: str | None = None + + +# Artifacts + + +@dataclass(frozen=True) +class ArtifactSpec: + """Declares a file to collect from the sandbox after the agent exits. + + The driver iterates :attr:`CLIAgentSpec.artifacts` and calls + ``sandbox.read_text(spec.path)`` for each entry. No per-agent collection + methods needed — the spec declares *what* to collect, the driver collects + it. + """ + + path: str + format: Literal["text", "json", "jsonl"] = "text" + optional: bool = True + + +# Agent events (normalized across harnesses) + + +@dataclass +class AgentEvent: + """Normalized event from any CLI harness's stdout. + + The :attr:`CLIAgentSpec.parse_events` callable converts raw JSONL lines + into these events so the driver can log and observe the agent's progress + without knowing which agent is running. + """ + + type: Literal[ + "assistant", + "tool_call", + "tool_result", + "reasoning", + "error", + "done", + ] + data: dict[str, Any] = field(default_factory=dict) + raw: str = "" + + +# Task protocol + + +class AgentTask(Protocol): + """Minimal interface a task must satisfy for the CLI agent driver.""" + + @property + def instruction(self) -> str: ...
+ + @property + def setup_shell(self) -> str | None: ... + + @property + def upload_files(self) -> dict[str, str]: ... + + @property + def metadata(self) -> dict[str, Any]: ... + + +# Agent config protocol + + +class AgentConfig(Protocol): + """Minimal interface a config must satisfy for the CLI agent driver. + + This is intentionally thin β€” concrete configs like :class:`OpenCodeConfig` + carry much more, but the generic driver only accesses these. + """ + + @property + def base_url(self) -> str: ... + + @property + def api_key(self) -> str: ... + + @property + def model(self) -> str: ... + + @property + def agent_timeout_s(self) -> float: ... + + +# CLIAgentSpec β€” the core declarative data model + + +@dataclass +class CLIAgentSpec: + """Declarative specification for a CLI-based agentic harness. + + Following the pattern established by verifiers' ``CLIHarness`` (Prime + Intellect), as much per-agent knowledge as possible is expressed as + *data* rather than imperative code. The :class:`CLIAgentDriver` + iterates these fields mechanically β€” it never needs to know what + ``"pi"`` or ``"claude-code"`` means. + + Three callables cover the remaining agent-specific logic that can't + be expressed as pure data: + + - :attr:`build_command` β€” constructs the CLI argv + - :attr:`build_mcp_config` β€” serializes MCP server configuration + - :attr:`parse_events` β€” converts raw stdout lines to :class:`AgentEvent` + + Everything else β€” file uploads, env vars, install scripts, artifact + collection β€” is pure data. + """ + + name: str + """Unique identifier: ``"opencode"``, ``"claude-code"``, ``"codex"``, etc.""" + + install_check_cmd: list[str] + """Command to probe whether the agent is already installed. + + Example: ``["claude", "--version"]`` + """ + + base_command: list[str] + """Base CLI invocation (before task-specific flags). 
+ + Example: ``["claude", "--print", "--output-format", "stream-json"]`` + """ + + mcp_config: MCPConfigSpec + """How MCP tool configuration is injected.""" + + default_timeout_s: float = 600.0 + """Default per-rollout timeout in seconds.""" + + setup: str | list[str] | None = None + """Shell command(s) to install the agent CLI inside the sandbox. + + Run once after the sandbox is created, before any files are written. + Skipped when ``install_check_cmd`` succeeds (pre-baked template). + Can be a single string or a list of strings executed in order. + """ + + files: dict[str, str | Callable] | None = None + """Files to upload into the sandbox before the agent starts. + + Keys are absolute sandbox paths. Values are either literal strings or + callables ``(task, config) -> str`` resolved at rollout time. + """ + + artifacts: dict[str, ArtifactSpec] | None = None + """Files to collect from the sandbox after the agent exits. + + The driver iterates this dict and calls ``sandbox.read_text(spec.path)`` + for each entry. + """ + + env: dict[str, str] | None = None + """Environment variables for the agent process. + + Values can contain ``{model}``, ``{base_url}``, ``{api_key}`` placeholders + resolved from the rollout config at runtime. + """ + + extension_dir_template: str | None = None + """Optional extension install directory template. + + Receives ``{home}`` substitution at runtime (e.g. + ``"{home}/.pi/agent/extensions"``). Drivers may use this to create + extension directories in the correct sandbox user home. + """ + + build_command: Callable[..., str] | None = None + """``(spec, config, task, mcp_config_path) -> str`` + + Build the full shell command line for launching the agent. Returns a + string (not a list) because sandbox ``start_bg`` / ``exec`` take shell + strings. + """ + + build_mcp_config: Callable[..., str] | None = None + """``(spec, tools, workdir) -> str`` + + Serialize MCP server configuration in the format the agent expects. 
+ Returns the file content (for ``config_file``/``settings_file`` methods) + or empty string (for ``cli_flags``, where the command builder handles it). + """ + + parse_events: Callable[[str], AgentEvent | None] | None = None + """``(line: str) -> AgentEvent | None`` + + Parse one line of the agent's stdout into a normalized event. + Return ``None`` for lines that are not parseable events. + """ + + build_env_vars: Callable[..., dict[str, str]] | None = None + """``(spec, config) -> dict[str, str]`` + + Optional override for env var construction. When provided, this is + called *instead of* resolving placeholders in :attr:`env`. Prefer + the declarative :attr:`env` dict for new agents. + """ + + +__all__ = [ + "AgentConfig", + "AgentEvent", + "AgentTask", + "ArtifactSpec", + "CLIAgentSpec", + "MCPConfigSpec", +] diff --git a/src/openenv/core/harness/agents/cli_driver.py b/src/openenv/core/harness/agents/cli_driver.py new file mode 100644 index 000000000..80d482ef3 --- /dev/null +++ b/src/openenv/core/harness/agents/cli_driver.py @@ -0,0 +1,666 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Shared CLI agent driver, session, and session factory. + +Two modes are supported: + +- ``black_box`` β€” the agent talks directly to the upstream LLM. No logprob + capture. For eval and demos. +- ``interception_gate`` β€” the agent's LLM calls are routed to an + :class:`InterceptionServer` running on the trainer host. The training + loop owns the forward pass and delivers responses back. For RL training. 
+""" + +from __future__ import annotations + +import asyncio +import json +import logging +import queue as _queue_mod +import shlex +import time +import uuid +from typing import Any, Callable, Literal + +from openenv.core.env_server.mcp_types import Tool +from openenv.core.harness import ( + Message, + ResourceSession, + ResourceSessionFactory, + ToolResult, + VerifyResult, +) +from openenv.core.harness.sandbox import BgJob, SandboxBackend, SandboxHandle + +from .base import CLIAgentSpec +from .interception_server import deliver_response, InterceptionServer, ToolHandler + + +_log = logging.getLogger(__name__) + +Verifier = Callable[..., VerifyResult] + + +def build_interception_rollout_url(base_url: str, rollout_id: str) -> str: + """Build OpenAI-compatible interception endpoint for one rollout.""" + return f"{base_url.rstrip('/')}/rollout/{rollout_id}/v1" + + +class _ConfigOverrideView: + """Read-only attribute view with optional overrides.""" + + def __init__(self, base: Any, **overrides: Any) -> None: + self._base = base + self._overrides = overrides + + def __getattr__(self, name: str) -> Any: + if name in self._overrides: + return self._overrides[name] + return getattr(self._base, name) + + +class CLIAgentSession(ResourceSession): + """Per-rollout session wrapping one sandbox with one running agent CLI.""" + + def __init__( + self, + *, + spec: CLIAgentSpec, + sandbox: SandboxHandle, + task: Any, + config: Any, + verifier: Verifier | None = None, + base_url_override: str | None = None, + agent_bg_job: BgJob | None = None, + interception_server: InterceptionServer | None = None, + interception_rollout_id: str | None = None, + interception_queue: _queue_mod.Queue[str] | None = None, + ) -> None: + self.spec = spec + self.sandbox = sandbox + self.task = task + self.config = config + self._verifier = verifier + self._base_url_override = base_url_override + self._agent_bg_job = agent_bg_job + self._interception_server = interception_server + 
self._interception_rollout_id = interception_rollout_id + self._interception_queue = interception_queue + + def initial_messages(self) -> list[Message]: + instruction = ( + self.task.instruction + if hasattr(self.task, "instruction") + else str(self.task) + ) + return [{"role": "user", "content": instruction}] + + def list_tools(self) -> list[Tool]: + return [] + + def call_tool(self, name: str, arguments: dict[str, Any]) -> ToolResult: + return ToolResult( + error=( + f"{self.spec.name} session does not expose external tool calls; " + "the CLI agent owns its own tool loop." + ) + ) + + def verify( + self, + transcript: list[Message], + final_state: Any | None = None, + ) -> VerifyResult: + if self._verifier is None: + return VerifyResult(env_reward=None, done=True) + return self._verifier(self.sandbox, self.task) + + def close(self) -> None: + if self._agent_bg_job is not None: + try: + self._agent_bg_job.kill() + except Exception: + pass + self._agent_bg_job = None + if ( + self._interception_server is not None + and self._interception_rollout_id is not None + ): + self._interception_server.unregister_rollout(self._interception_rollout_id) + self.sandbox.kill() + + def wait_for_completion(self, timeout_s: float | None = None) -> int: + """Block until the agent exits, returning its exit code.""" + if self._agent_bg_job is None: + raise RuntimeError("Agent not started.") + default_timeout = ( + self.config.agent_timeout_s + if hasattr(self.config, "agent_timeout_s") + else self.spec.default_timeout_s + ) + budget = timeout_s if timeout_s is not None else default_timeout + return self._agent_bg_job.wait(timeout=budget) + + def collect_artifacts(self) -> dict[str, Any]: + """Collect all artifacts declared in ``spec.artifacts`` from the sandbox.""" + result: dict[str, Any] = {} + if not self.spec.artifacts: + return result + for name, artifact_spec in self.spec.artifacts.items(): + try: + content = self.sandbox.read_text(artifact_spec.path) + if artifact_spec.format 
== "json": + result[name] = json.loads(content) + elif artifact_spec.format == "jsonl": + records = [] + for line in content.splitlines(): + line = line.strip() + if not line: + continue + try: + records.append(json.loads(line)) + except json.JSONDecodeError: + _log.debug( + "Skipping non-JSON line in %s: %s", + artifact_spec.path, + line[:120], + ) + result[name] = records + else: + result[name] = content + except Exception: + if not artifact_spec.optional: + raise + _log.debug( + "Optional artifact %r (%s) not found, skipping", + name, + artifact_spec.path, + ) + return result + + # interception_gate API + + async def next_request( + self, timeout_s: float | None = None + ) -> dict[str, Any] | None: + """Await the next LLM request from the agent (interception_gate only). + + Returns the intercept dict, or ``None`` when the agent has exited. + """ + if self._interception_queue is None: + raise RuntimeError( + "next_request() is only available in interception_gate mode." + ) + server = self._interception_server + assert server is not None + + deadline = time.time() + (timeout_s or self.spec.default_timeout_s) + while True: + remaining = deadline - time.time() + if remaining <= 0: + raise TimeoutError( + f"{self.spec.name} interception_gate: no request within timeout" + ) + try: + request_id = await asyncio.to_thread( + self._interception_queue.get, + timeout=min(remaining, 1.0), + ) + # None sentinel = agent process exited (sent by /exit endpoint) + if request_id is None: + return None + intercept = server.get_intercept(request_id) + if intercept is not None: + return intercept + except _queue_mod.Empty: + pass + + if self._agent_bg_job is not None: + try: + self._agent_bg_job.wait(timeout=0) + return None + except TimeoutError: + pass + continue + + async def deliver( + self, intercept: dict[str, Any], response_dict: dict[str, Any] + ) -> None: + """Return a trainer-generated response to the waiting agent.""" + await deliver_response(intercept, response_dict) + + 
def register_tool_handler(
    self,
    tool_name: str,
    handler: ToolHandler,
    *,
    tool_definition: dict[str, Any] | None = None,
) -> None:
    """Attach a host-side tool handler to this session's rollout.

    Thin forwarder to the interception server's ``register_tool_handler``,
    keyed by this session's rollout id.

    Args:
        tool_name: Name of the tool to register.
        handler: Callable invoked for the tool.
        tool_definition: Optional tool schema passed through unchanged.

    Raises:
        RuntimeError: If the session has no interception server or no
            interception rollout id.
    """
    server = self._interception_server
    rollout_id = self._interception_rollout_id
    if server is not None and rollout_id is not None:
        server.register_tool_handler(
            rollout_id,
            tool_name,
            handler,
            tool_definition=tool_definition,
        )
        return
    raise RuntimeError(
        "register_tool_handler() is only available in interception_gate mode."
    )
def create_session(
    self,
    task: Any,
    config: Any,
    *,
    verifier: Verifier | None = None,
    seed: int | None = None,
    episode_id: str | None = None,
) -> CLIAgentSession:
    """Create a sandbox, bootstrap it, start the agent and wrap it in a session.

    Args:
        task: Task object handed to bootstrap and to the agent command.
        config: Rollout config; ``agent_timeout_s`` is read when present.
        verifier: Optional verifier forwarded to the session.
        seed: Accepted for interface compatibility; not used by this method.
        episode_id: Used as sandbox metadata and, in ``interception_gate``
            mode, as the rollout id (a random id is generated otherwise).

    Returns:
        A :class:`CLIAgentSession` bound to the started agent.

    Raises:
        Exception: Whatever bootstrap, rollout registration or agent start
            raised.  In every failure case the sandbox is killed first, and
            a rollout registered by this call is unregistered again.
    """
    timeout_s = (
        config.agent_timeout_s
        if hasattr(config, "agent_timeout_s")
        else self.spec.default_timeout_s
    )
    # Keep the sandbox alive a little longer than the agent itself.
    sandbox_timeout = int(timeout_s) + 300
    sandbox = self.sandbox_backend.create(
        timeout_s=sandbox_timeout,
        metadata={"episode_id": episode_id} if episode_id else None,
    )
    try:
        self._bootstrap_sandbox(sandbox, task, config)
    except Exception as exc:
        _log.error("%s driver: bootstrap failed: %r", self.spec.name, exc)
        sandbox.kill()
        raise

    base_url_override: str | None = None
    interception_rollout_id: str | None = None
    interception_queue: _queue_mod.Queue[str] | None = None

    # Everything after bootstrap can still fail; without this guard the
    # sandbox (and the registered rollout) would be leaked.
    try:
        if self.mode == "interception_gate":
            assert self._interception_server is not None
            assert self._interception_base_url is not None
            rollout_id = episode_id or f"rollout_{uuid.uuid4().hex[:8]}"
            interception_queue = self._interception_server.register_rollout(
                rollout_id
            )
            # Only remember the id once registration succeeded, so cleanup
            # never unregisters a rollout this call did not create.
            interception_rollout_id = rollout_id
            base_url_override = build_interception_rollout_url(
                self._interception_base_url,
                rollout_id,
            )

        agent_bg_job = self._start_agent(
            sandbox, task, config, base_url_override=base_url_override
        )
    except Exception as exc:
        _log.error("%s driver: agent start failed: %r", self.spec.name, exc)
        if (
            interception_rollout_id is not None
            and self._interception_server is not None
        ):
            self._interception_server.unregister_rollout(interception_rollout_id)
        sandbox.kill()
        raise

    return CLIAgentSession(
        spec=self.spec,
        sandbox=sandbox,
        task=task,
        config=config,
        verifier=verifier,
        base_url_override=base_url_override,
        agent_bg_job=agent_bg_job,
        interception_server=self._interception_server,
        interception_rollout_id=interception_rollout_id,
        interception_queue=interception_queue,
    )
def _wait_for_sandbox_ready(
    self, sandbox: SandboxHandle, *, attempts: int = 15, delay_s: float = 1.0
) -> None:
    """Poll the sandbox with ``echo ok`` until it answers.

    Args:
        sandbox: Sandbox handle whose ``exec`` is probed.
        attempts: Maximum number of probes.
        delay_s: Pause between two consecutive probes.

    Raises:
        RuntimeError: If no probe succeeded; the message carries the last
            observed error (stderr/stdout, exit code, or exception text).
    """
    last_err = ""
    for attempt in range(attempts):
        try:
            r = sandbox.exec("echo ok", timeout=5)
            if r.exit_code == 0 and "ok" in (r.stdout or ""):
                return
            last_err = (r.stderr or r.stdout or "").strip() or f"exit={r.exit_code}"
        except Exception as exc:
            # Any failure of exec itself counts as "not ready yet".
            last_err = f"{type(exc).__name__}: {exc}"
        # Sleep only between probes: waiting after the final failure just
        # delays the error (same policy as _exec_with_retry).
        if attempt + 1 < attempts:
            time.sleep(delay_s)
    raise RuntimeError(
        f"sandbox did not become ready within {attempts * delay_s:.0f}s "
        f"(last error: {last_err})"
    )
def _upload_files(self, sandbox: SandboxHandle, task: Any, config: Any) -> None:
    """Write spec-level and task-level files into the sandbox.

    Spec files (``self.spec.files``) map a path to either literal content
    or a callable ``(task, config) -> content``; a ``None`` content skips
    that file.  Task files (``task.upload_files``) are written afterwards,
    so a task file wins when both target the same path.

    Task files are uploaded even when the spec declares no files; an early
    return on an empty spec would silently drop them.
    """
    for path, content_or_fn in (self.spec.files or {}).items():
        content = (
            content_or_fn(task, config)
            if callable(content_or_fn)
            else content_or_fn
        )
        if content is not None:
            sandbox.write_text(path, content)

    # Tolerate both a missing attribute and an explicit None.
    task_files = getattr(task, "upload_files", None) or {}
    for path, content in task_files.items():
        sandbox.write_text(path, content)
if hasattr(config, "workdir") and getattr(config, "workdir") + else f"{home}/workdir" + ) + mcp_path = self.spec.mcp_config.path_template.format( + workdir=workdir, home=home + ) + mcp_content = self.spec.build_mcp_config(self.spec, [], workdir) + if mcp_content: + sandbox.write_text(mcp_path, mcp_content) + + def _start_agent( + self, + sandbox: SandboxHandle, + task: Any, + config: Any, + *, + base_url_override: str | None = None, + ) -> BgJob: + command_config = config + if ( + self.mode == "interception_gate" + and self._interception_server is not None + and self.spec.name == "pi" + and base_url_override + ): + self._write_pi_models_config( + sandbox, + config, + rollout_url=base_url_override, + api_key=self._interception_server.secret, + ) + command_config = _ConfigOverrideView(config, provider="openenv") + + if self.spec.build_command is not None: + cmd = self.spec.build_command(self.spec, command_config, task, None) + else: + cmd = " ".join(shlex.quote(c) for c in self.spec.base_command) + envs = self._resolve_env_vars(config, base_url_override=base_url_override) + if self.spec.name == "pi": + home = self._resolve_sandbox_home(sandbox, config) + # Make pi config discovery independent of the runtime user's $HOME. + envs["PI_CODING_AGENT_DIR"] = f"{home}/.pi/agent" + if self.mode == "interception_gate" and self._interception_server is not None: + envs["OPENAI_API_KEY"] = self._interception_server.secret + envs["ANTHROPIC_API_KEY"] = self._interception_server.secret + + # Append an exit notification so the InterceptionServer detects + # agent exit immediately instead of waiting for the full timeout. + # The /exit endpoint enqueues a None sentinel on the request queue, + # causing next_request() to return None. 
def _resolve_env_vars(
    self,
    config: Any,
    *,
    base_url_override: str | None = None,
) -> dict[str, str]:
    """Compute the environment variables for the agent process.

    When the spec supplies ``build_env_vars`` its result is returned as-is.
    Otherwise every value of ``self.spec.env`` is treated as a
    ``str.format`` template over ``base_url``, ``api_key`` and ``model``.

    Args:
        config: Rollout config; ``base_url``, ``api_key`` and ``model`` are
            read when present.
        base_url_override: Replaces ``config.base_url`` in the template
            substitutions when truthy.

    Returns:
        Mapping of variable name to resolved value.  A value that cannot be
        formatted (unknown field, positional field, unbalanced brace) is
        passed through unchanged.
    """
    if self.spec.build_env_vars is not None:
        # NOTE(review): base_url_override is not forwarded on this path, so
        # a custom builder only sees config.base_url — confirm intended.
        return self.spec.build_env_vars(self.spec, config)
    if not self.spec.env:
        return {}
    base_url = base_url_override or (
        config.base_url if hasattr(config, "base_url") else ""
    )
    api_key = config.api_key if hasattr(config, "api_key") else "intercepted"
    model = config.model if hasattr(config, "model") else ""
    substitutions = {"base_url": base_url, "api_key": api_key, "model": model}
    resolved: dict[str, str] = {}
    for key, value in self.spec.env.items():
        try:
            resolved[key] = value.format(**substitutions)
        except (KeyError, IndexError, ValueError):
            # KeyError: unknown named field; IndexError: positional field
            # such as "{0}"; ValueError: stray "{" / "}" (e.g. JSON values).
            resolved[key] = value
    return resolved
class CLIAgentSessionFactory(ResourceSessionFactory):
    """Session factory that binds one agent spec and config to a driver.

    All construction options are forwarded to a single
    :class:`CLIAgentDriver`; :meth:`create` then delegates to the driver's
    ``create_session`` with the stored config and verifier.
    """

    def __init__(
        self,
        *,
        spec: CLIAgentSpec,
        config: Any,
        sandbox_backend: SandboxBackend,
        mode: Literal["black_box", "interception_gate"] = "black_box",
        verifier: Verifier | None = None,
        install_timeout_s: int = 240,
        setup_timeout_s: int = 300,
        interception_server: InterceptionServer | None = None,
        interception_base_url: str | None = None,
    ) -> None:
        # Gather the driver options first so the constructor call reads as
        # a single forwarding step.
        driver_options = {
            "spec": spec,
            "sandbox_backend": sandbox_backend,
            "mode": mode,
            "install_timeout_s": install_timeout_s,
            "setup_timeout_s": setup_timeout_s,
            "interception_server": interception_server,
            "interception_base_url": interception_base_url,
        }
        self._spec = spec
        self._config = config
        self._verifier = verifier
        self._driver = CLIAgentDriver(**driver_options)

    def create(
        self,
        task: Any,
        seed: int | None = None,
        episode_id: str | None = None,
    ) -> CLIAgentSession:
        """Create a session for ``task`` using the stored config and verifier."""
        session = self._driver.create_session(
            task=task,
            config=self._config,
            verifier=self._verifier,
            seed=seed,
            episode_id=episode_id,
        )
        return session
The :class:`InterceptionServer` runs on the trainer node, outside any
sandbox.  Each sandbox's agent is pointed at::

    http://<host>:<port>/rollout/<rollout_id>/v1

When the agent makes an LLM call it blocks at this server.  The training
loop calls :meth:`~InterceptionServer.register_rollout` to get a
thread-safe ``queue.Queue`` of request ids, dequeues the pending request id
from it (``await asyncio.to_thread(queue.get)`` from async code), runs its
own vLLM forward pass, then calls :func:`deliver_response` to unblock the
agent.

The caller is responsible for making the server reachable from the sandbox.
For Docker sandboxes on the same machine, ``host.docker.internal:<port>``
works.  For remote sandboxes (E2B, HF Sandbox), set up your own tunnel
(ngrok, frp, public IP, VPN) and pass the URL as
``interception_base_url``.

Usage — training loop::
+ await deliver_response(intercept, response) + + server.unregister_rollout(rollout_id) + await server.stop() +""" + +from __future__ import annotations + +import asyncio +import hmac +import json +import logging +import queue as _queue_mod +import re +import secrets +import threading +import time +import uuid +from typing import Any, Awaitable, Callable + +from openenv.core.env_server.mcp_types import RESERVED_TOOL_NAMES + +from aiohttp import web + + +_log = logging.getLogger(__name__) + +_KEEPALIVE_INTERVAL_S = 3.0 +_MAX_REQUEST_BODY = 16 * 1024 * 1024 +_TOOL_NAME_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$") + +ToolHandler = Callable[[dict[str, Any]], Awaitable[dict[str, Any]]] + + +class InterceptionServer: + """Async HTTP server that gates every LLM call from sandboxed agents. + + One shared instance handles all concurrent rollouts. Each rollout is + identified by a ``rollout_id`` in the URL path. + """ + + def __init__( + self, + port: int = 0, + secret: str | None = None, + host: str = "127.0.0.1", + tool_name_allowlist: set[str] | None = None, + ) -> None: + self.port = port + self.host = host + self.secret = secret or secrets.token_urlsafe(32) + if not self.secret.strip(): + raise ValueError("InterceptionServer secret must not be blank.") + normalized_allowlist: set[str] = set() + for raw_name in tool_name_allowlist or set(): + name = raw_name.strip() + if not name: + raise ValueError("tool_name_allowlist must not include blank names") + if not _TOOL_NAME_RE.fullmatch(name): + raise ValueError( + "tool_name_allowlist entries must match " + f"^[A-Za-z0-9_-]{{1,64}}$ (got {raw_name!r})" + ) + normalized_allowlist.add(name) + self._tool_name_allowlist = frozenset(normalized_allowlist) + self._app: web.Application | None = None + self._runner: web.AppRunner | None = None + self._site: web.TCPSite | None = None + self._lock = asyncio.Lock() + self._state_lock = threading.RLock() + self.active_rollouts: dict[str, dict[str, Any]] = {} + self.intercepts: dict[str, 
async def start(self) -> None:
    """Register the HTTP routes and start listening.

    Idempotent: a second call while the server is running returns
    immediately.  When ``self.port`` is ``0`` the OS-assigned port is read
    back from the listening socket and stored on ``self.port``.

    Raises:
        RuntimeError: If the OS-assigned port cannot be resolved.
        Exception: Whatever ``site.start()`` raised (e.g. address in use).
            In every failure case the already set-up runner is cleaned up
            and the server is left in its "not started" state.
    """
    async with self._lock:
        if self._app is not None:
            return
        app = web.Application(client_max_size=_MAX_REQUEST_BODY)
        app.router.add_post(
            "/rollout/{rollout_id}/v1/chat/completions",
            self._handle_chat_completions,
        )
        app.router.add_post(
            "/rollout/{rollout_id}/v1/tools/{tool_name}",
            self._handle_tool_call,
        )
        app.router.add_post(
            "/rollout/{rollout_id}/v1/exit",
            self._handle_exit,
        )
        app.router.add_get("/health", self._handle_health)
        runner = web.AppRunner(app)
        await runner.setup()
        try:
            if self.host == "0.0.0.0":
                _log.warning(
                    "InterceptionServer exposed on all interfaces (0.0.0.0)."
                )
            site = web.TCPSite(runner, self.host, self.port)
            await site.start()
            if self.port == 0:
                # TCPSite has no public accessor for the bound port, so
                # read it from the underlying server's sockets.
                server = getattr(site, "_server", None)
                sockets = getattr(server, "sockets", None) if server else None
                if sockets:
                    self.port = sockets[0].getsockname()[1]
                if self.port == 0:
                    raise RuntimeError("Failed to resolve OS-assigned port")
        except BaseException:
            # Includes cancellation: never leave a set-up runner (and a
            # possibly bound socket) behind when start-up did not finish.
            await runner.cleanup()
            raise
        self._app = app
        self._runner = runner
        self._site = site
        _log.info("InterceptionServer listening on :%d", self.port)
def unregister_rollout(self, rollout_id: str) -> None:
    """Remove a rollout and release every request still pending for it.

    Pending intercepts of the rollout are dropped from ``self.intercepts``;
    their response futures are cancelled and their streaming queues receive
    a ``None`` end-of-stream marker so blocked HTTP handlers return.

    This method is synchronous and may be called from a thread other than
    the one running the futures' event loop.  ``asyncio`` futures and
    queues are not thread-safe, so the wake-ups are scheduled on the
    owning loop instead of being performed in the calling thread.

    Unknown ``rollout_id`` values are accepted and only logged
    (``removed=False``).
    """
    with self._state_lock:
        matching_ids = [
            request_id
            for request_id, intercept in self.intercepts.items()
            if intercept.get("rollout_id") == rollout_id
        ]
        matching_intercepts = [self.intercepts.pop(i) for i in matching_ids]
        removed = self.active_rollouts.pop(rollout_id, None) is not None
        active = len(self.active_rollouts)
        pending = len(self.intercepts)

    try:
        current_loop = asyncio.get_running_loop()
    except RuntimeError:
        current_loop = None

    def _signal_eof(chunk_queue: asyncio.Queue) -> None:
        # Best effort: a full queue means the stream is already saturated.
        try:
            chunk_queue.put_nowait(None)
        except asyncio.QueueFull:
            pass

    def _run_on(loop: Any, callback: Any, *args: Any) -> None:
        # Run inline when already on the owning loop (or when no loop is
        # known); otherwise hand the call over to that loop.
        if loop is None or loop is current_loop:
            callback(*args)
            return
        try:
            loop.call_soon_threadsafe(callback, *args)
        except RuntimeError:
            # The owning loop is closed, so nothing can be waiting on it.
            _log.debug(
                "interception loop closed while unregistering rollout_id=%s",
                rollout_id,
            )

    for intercept in matching_intercepts:
        fut: asyncio.Future | None = intercept.get("response_future")
        owner_loop = fut.get_loop() if fut is not None else None
        if fut and not fut.done():
            _run_on(owner_loop, fut.cancel)
        cq: asyncio.Queue | None = intercept.get("chunk_queue")
        if cq is not None:
            _run_on(owner_loop, _signal_eof, cq)

    _log.info(
        "interception_rollout_unregistered rollout_id=%s removed=%s "
        "active_rollouts=%d pending_intercepts=%d",
        rollout_id,
        removed,
        active,
        pending,
    )
def unregister_tool_handler(self, rollout_id: str, tool_name: str) -> None:
    """Remove a tool handler (and its schema) from a rollout.

    Registration stores handlers under the whitespace-stripped tool name,
    so the same normalization is applied here; otherwise unregistering
    with the exact string used at registration could silently do nothing.

    Unknown rollouts and unknown tool names are ignored.
    """
    normalized_name = tool_name.strip()
    with self._state_lock:
        context = self.active_rollouts.get(rollout_id)
        if context is None:
            return
        handlers: dict[str, ToolHandler] = context.get("tool_handlers", {})
        handlers.pop(normalized_name, None)
        tool_defs: dict[str, dict[str, Any]] = context.get("tool_defs", {})
        tool_defs.pop(normalized_name, None)
def _authorized(self, request: web.Request) -> bool:
    """Return whether the request carries the shared secret.

    The secret is accepted either as ``Authorization: Bearer <secret>`` or
    as an ``x-api-key`` header.  Comparison is constant-time.

    Values are compared as bytes: ``hmac.compare_digest`` rejects non-ASCII
    ``str`` arguments with ``TypeError``, which would turn a malformed
    header (or a non-ASCII secret) into an unhandled server error instead
    of a plain "unauthorized".
    """
    expected_bearer = f"Bearer {self.secret}".encode("utf-8", "surrogateescape")
    expected_key = self.secret.encode("utf-8", "surrogateescape")
    auth = request.headers.get("Authorization", "").encode(
        "utf-8", "surrogateescape"
    )
    api_key = request.headers.get("x-api-key", "").encode(
        "utf-8", "surrogateescape"
    )
    bearer_ok = hmac.compare_digest(auth, expected_bearer)
    key_ok = hmac.compare_digest(api_key, expected_key)
    return bearer_ok or key_ok


async def _handle_health(self, request: web.Request) -> web.Response:
    """Return ``{"status": "ok"}`` merged with the counters from ``stats()``."""
    return web.json_response({"status": "ok", **self.stats()})


async def _handle_exit(self, request: web.Request) -> web.Response:
    """Handle agent process exit notification.

    Called by the sandbox entrypoint after the agent process exits.
    Enqueues a sentinel ``None`` on the rollout's request queue so that
    ``next_request()`` returns immediately instead of waiting for the
    full timeout.

    The request must be authorized like every other rollout endpoint:
    the sentinel ends a rollout, so an unauthenticated caller must not be
    able to inject it.  Unknown rollout ids are acknowledged with
    ``{"status": "ignored"}``.
    """
    if not self._authorized(request):
        return web.json_response({"error": "Unauthorized"}, status=401)

    rollout_id = request.match_info["rollout_id"]
    with self._state_lock:
        rollout = self.active_rollouts.get(rollout_id)
    if rollout is None:
        return web.json_response(
            {"status": "ignored", "reason": "unknown rollout_id"}
        )

    request_queue = rollout.get("request_id_queue")
    if request_queue is not None:
        try:
            request_queue.put_nowait(None)  # sentinel: signals "agent exited"
        except _queue_mod.Full:
            # Best effort: the consumer falls back to its own polling.
            _log.warning(
                "interception_exit_signal dropped (queue full) rollout_id=%s",
                rollout_id,
            )

    _log.info(
        "interception_exit_signal rollout_id=%s",
        rollout_id,
    )
    return web.json_response({"status": "ok"})
async def _handle_chat_completions(
    self, request: web.Request
) -> web.StreamResponse | web.Response:
    """Park an agent's chat-completion request until the trainer answers it.

    The parsed request is stored as an *intercept*, its id is pushed onto
    the rollout's request queue, and the handler then blocks on the
    intercept's response future (or streams chunks when ``stream`` is set).

    Responses:
        401 when the shared secret is missing or wrong.
        404 when the rollout is unknown (also if it disappears while the
            body is being read).
        400 when the body is not valid JSON or not a JSON object.
        499 when the pending request is cancelled.
        500 when the response future fails with another exception.
    """
    if not self._authorized(request):
        return web.json_response({"error": "Unauthorized"}, status=401)

    rollout_id = request.match_info["rollout_id"]
    with self._state_lock:
        context = self.active_rollouts.get(rollout_id)
    if not context:
        return web.json_response({"error": "rollout not found"}, status=404)

    try:
        body = await request.json()
    except Exception as exc:
        return web.json_response({"error": f"invalid JSON: {exc}"}, status=400)
    # Valid JSON is not necessarily an object; everything below calls
    # body.get(), so reject arrays/scalars as a client error.
    if not isinstance(body, dict):
        return web.json_response(
            {"error": "request body must be a JSON object"},
            status=400,
        )

    # Snapshot under the state lock: tool registration mutates this dict
    # under the same lock, possibly from another thread.
    with self._state_lock:
        tool_defs: dict[str, dict[str, Any]] = dict(context.get("tool_defs", {}))
    merged_tools = self._merge_rollout_tools(body.get("tools"), tool_defs)
    if merged_tools is not None:
        body["tools"] = merged_tools

    is_streaming = bool(body.get("stream"))
    request_id = f"req_{uuid.uuid4().hex[:8]}"
    chunk_queue: asyncio.Queue | None = asyncio.Queue() if is_streaming else None

    intercept: dict[str, Any] = {
        "request_id": request_id,
        "rollout_id": rollout_id,
        "messages": body.get("messages"),
        "model": body.get("model"),
        "tools": body.get("tools"),
        "stream": is_streaming,
        "chunk_queue": chunk_queue,
        "response_future": asyncio.get_running_loop().create_future(),
        "body": body,
    }
    with self._state_lock:
        # Re-check: the rollout may have been unregistered while the body
        # was being read.
        context = self.active_rollouts.get(rollout_id)
        if context is None:
            return web.json_response({"error": "rollout not found"}, status=404)
        self.intercepts[request_id] = intercept
        request_queue: _queue_mod.Queue[str] = context["request_id_queue"]
        request_queue.put_nowait(request_id)

    if is_streaming:
        # The streaming helper removes the intercept itself when done.
        return await self._stream_response(request, intercept)

    try:
        response_dict = await intercept["response_future"]
    except asyncio.CancelledError:
        return web.json_response({"error": "rollout cancelled"}, status=499)
    except Exception:
        _log.exception("interception request %s failed", request_id)
        return web.json_response({"error": "internal error"}, status=500)
    finally:
        with self._state_lock:
            self.intercepts.pop(request_id, None)

    return web.json_response(response_dict)
def _put_queue_threadsafe(
    q: asyncio.Queue,
    item: Any,
    loop: asyncio.AbstractEventLoop | None = None,
) -> None:
    """Put an item on an asyncio.Queue from any thread.

    Args:
        q: Target queue.
        item: Item to enqueue.
        loop: Event loop that consumes the queue, when the caller knows it.
            If omitted, the queue's private ``_loop`` attribute is used as
            a hint; that attribute may be unset, in which case the item is
            enqueued directly in the calling thread.
    """
    target_loop = loop if loop is not None else getattr(q, "_loop", None)
    if target_loop is None:
        # Fallback: put_nowait which is simpler. Let QueueFull propagate —
        # silently dropping items would cause hard-to-debug streaming issues.
        q.put_nowait(item)
        return
    try:
        running = asyncio.get_running_loop()
    except RuntimeError:
        running = None
    if running is target_loop:
        q.put_nowait(item)
    else:
        target_loop.call_soon_threadsafe(q.put_nowait, item)


async def deliver_response(
    intercept: dict[str, Any], response_dict: dict[str, Any]
) -> None:
    """Unblock the agent's HTTP handler with ``response_dict``.

    For non-streaming requests, resolves the future directly.
    For streaming requests, synthesizes SSE chunks from the complete
    response and signals EOF: per choice one content chunk and one finish
    chunk, then a ``None`` end-of-stream marker.

    Tool calls copied into a streaming delta receive an ``index`` field
    when they lack one: complete chat-completion messages do not carry it,
    but the streaming chunk schema identifies tool-call deltas by index.
    The caller's ``response_dict`` is not modified.

    Thread-safe: can be called from any thread, not just the event loop
    that owns the future/queue. This is required because the rollout
    worker may run ``deliver_response`` from its own ``asyncio.run()``
    in a daemon thread while the ``InterceptionServer``'s aiohttp
    handler awaits the future on a different loop.

    Raises:
        RuntimeError: If the intercept is streaming but has no chunk queue.
    """
    is_streaming = intercept.get("stream", False)
    chunk_queue: asyncio.Queue | None = intercept.get("chunk_queue")
    future: asyncio.Future | None = intercept.get("response_future")

    if not is_streaming:
        if future:
            _resolve_future_threadsafe(future, response_dict)
        return

    if chunk_queue is None:
        raise RuntimeError("chunk_queue missing on streaming intercept")

    # The response future lives on the loop that serves the HTTP request,
    # which is also the loop draining the chunk queue; prefer it over the
    # queue's private loop attribute.
    owner_loop = future.get_loop() if future is not None else None

    def _with_index(tool_calls: Any) -> Any:
        # Add the positional index expected in streaming deltas, copying
        # each entry so the caller's response stays untouched.
        if not isinstance(tool_calls, list):
            return tool_calls
        indexed = []
        for position, call in enumerate(tool_calls):
            if isinstance(call, dict) and "index" not in call:
                call = {**call, "index": position}
            indexed.append(call)
        return indexed

    # Fields shared by every synthesized chunk of this response.
    envelope = {
        "id": response_dict.get("id", ""),
        "object": "chat.completion.chunk",
        "created": response_dict.get("created", int(time.time())),
        "model": response_dict.get("model", ""),
    }

    choices = response_dict.get("choices") or []
    for choice in choices:
        msg = choice.get("message") or {}
        content_chunk = {
            **envelope,
            "choices": [
                {
                    "index": choice.get("index", 0),
                    "delta": {
                        "role": "assistant",
                        "content": msg.get("content"),
                        "tool_calls": _with_index(msg.get("tool_calls")),
                    },
                    "finish_reason": None,
                }
            ],
        }
        _put_queue_threadsafe(chunk_queue, content_chunk, owner_loop)
        finish_chunk = {
            **envelope,
            "choices": [
                {
                    "index": choice.get("index", 0),
                    "delta": {},
                    "finish_reason": choice.get("finish_reason"),
                }
            ],
        }
        _put_queue_threadsafe(chunk_queue, finish_chunk, owner_loop)

    # End-of-stream marker first, then resolve the future.
    _put_queue_threadsafe(chunk_queue, None, owner_loop)
    if future:
        _resolve_future_threadsafe(future, response_dict)
+ +Expresses the OpenCode harness as a purely declarative :class:`CLIAgentSpec`. +All builders (command construction, config generation, env var resolution) +are self-contained with no imports from any environment package. + +Registered on import:: + + import openenv.core.harness.agents.opencode + # OPENCODE_SPEC is now in the registry +""" + +from __future__ import annotations + +import json +import shlex +from typing import Any + +from . import register_agent +from .base import AgentEvent, ArtifactSpec, CLIAgentSpec, MCPConfigSpec + + +def _build_opencode_command( + spec: CLIAgentSpec, + config: Any, + task: Any, + mcp_config_path: str | None, +) -> str: + """Build the ``opencode run`` shell command.""" + home = config.sandbox_home if hasattr(config, "sandbox_home") else "/home/user" + run_format = config.run_format if hasattr(config, "run_format") else "json" + format_flag = "--format json" if run_format == "json" else "" + instruction_file = f"{home}/task/instruction.md" + log_file = f"{home}/logs/agent/opencode.jsonl" + workdir = f"{home}/workdir" + + workdir_q = shlex.quote(workdir) + instruction_q = shlex.quote(instruction_file) + log_q = shlex.quote(log_file) + + return ( + f'export PATH="$HOME/.opencode/bin:$PATH" && ' + f"cd {workdir_q} && git init -q 2>/dev/null; " + f'opencode run {format_flag} "$(cat {instruction_q})" ' + f"2>&1 | tee {log_q}" + ).strip() + + +def _build_opencode_mcp_config( + spec: CLIAgentSpec, + tools: list[Any], + workdir: str, +) -> str: + """Build ``opencode.json`` content. + + Returns an empty string so the driver skips writing this file. + The actual config is written via ``spec.files`` using + ``_build_opencode_config_file`` which has access to the rollout + config (base_url, api_key, model). 
+ """ + return "" + + +def _build_opencode_config_file(task: Any, config: Any) -> str: + """Build the full ``opencode.json`` dynamically from config fields.""" + base_url = ( + config.base_url if hasattr(config, "base_url") else "http://127.0.0.1:7000/v1" + ) + api_key = config.api_key if hasattr(config, "api_key") else "intercepted" + model = config.model if hasattr(config, "model") else "model" + timeout = ( + int(config.agent_timeout_s * 1000) + if hasattr(config, "agent_timeout_s") + else 600000 + ) + + # Split model into provider_name/model_id for the opencode config format. + # e.g. "zai-org/GLM-5.1:zai-org" becomes provider "hf", model_id as-is. + provider_name = "default" + model_id = model + if hasattr(config, "provider_name") and config.provider_name: + provider_name = config.provider_name + + return json.dumps( + { + "$schema": "https://opencode.ai/config.json", + "model": f"{provider_name}/{model_id}", + "provider": { + provider_name: { + "npm": "@ai-sdk/openai-compatible", + "name": provider_name.title(), + "options": { + "baseURL": base_url, + "apiKey": api_key, + "timeout": timeout, + }, + "models": { + model_id: { + "name": model_id, + "id": model_id, + }, + }, + } + }, + }, + indent=2, + ) + + +def _build_opencode_env_vars( + spec: CLIAgentSpec, + config: Any, +) -> dict[str, str]: + """Build env vars for the OpenCode process.""" + home = config.sandbox_home if hasattr(config, "sandbox_home") else "/home/user" + base_url = config.base_url if hasattr(config, "base_url") else "" + api_key = config.api_key if hasattr(config, "api_key") else "intercepted" + extra_env = config.extra_env if hasattr(config, "extra_env") else {} + + env = dict(extra_env) + env["OPENAI_BASE_URL"] = base_url + env["OPENAI_API_KEY"] = api_key + env["OPENCODE_CONFIG"] = f"{home}/.config/opencode/opencode.json" + return env + + +def _parse_opencode_event(line: str) -> AgentEvent | None: + """Parse one line of OpenCode's JSONL stdout.""" + line = line.strip() + if not line: + 
return None + try: + data = json.loads(line) + except json.JSONDecodeError: + return None + + event_type = data.get("type", "") + if event_type in ("assistant", "message", "text"): + return AgentEvent(type="assistant", data=data, raw=line) + elif event_type in ("tool_call", "tool_use"): + return AgentEvent(type="tool_call", data=data, raw=line) + elif event_type in ("tool_result", "tool_response"): + return AgentEvent(type="tool_result", data=data, raw=line) + elif event_type in ("step_start",): + return AgentEvent(type="assistant", data=data, raw=line) + elif event_type in ("step_finish",): + return AgentEvent(type="done", data=data, raw=line) + elif event_type == "error": + return AgentEvent(type="error", data=data, raw=line) + elif event_type in ("done", "complete", "end"): + return AgentEvent(type="done", data=data, raw=line) + return AgentEvent(type="assistant", data=data, raw=line) + + +def _instruction_file_content(task: Any, config: Any) -> str: + return task.instruction if hasattr(task, "instruction") else str(task) + + +def _system_prompt_content(task: Any, config: Any) -> str | None: + if hasattr(config, "system_prompt") and config.system_prompt: + return config.system_prompt + return None + + +OPENCODE_SPEC = CLIAgentSpec( + name="opencode", + install_check_cmd=["/home/user/.opencode/bin/opencode", "--version"], + base_command=[ + "opencode", + "run", + "--format", + "json", + "--dangerously-skip-permissions", + ], + mcp_config=MCPConfigSpec( + method="config_file", + path_template="{home}/.config/opencode/opencode.json", + ), + default_timeout_s=900.0, + setup=( + "set -e && " + "curl -fsSL https://opencode.ai/install | bash && " + "mkdir -p /home/user/.config/opencode /home/user/logs/agent " + "/home/user/logs/verifier /home/user/task /home/user/workdir && " + 'export PATH="$HOME/.opencode/bin:$PATH" && ' + "opencode --version" + ), + files={ + "/home/user/task/instruction.md": _instruction_file_content, + "/home/user/task/system.md": 
_system_prompt_content, + "/home/user/.config/opencode/opencode.json": _build_opencode_config_file, + }, + artifacts={ + "agent_log": ArtifactSpec( + path="/home/user/logs/agent/opencode.jsonl", + format="jsonl", + ), + }, + env={ + "PATH": "/home/user/.opencode/bin:$PATH", + "OPENAI_BASE_URL": "{base_url}", + "OPENAI_API_KEY": "{api_key}", + }, + build_command=_build_opencode_command, + build_mcp_config=_build_opencode_mcp_config, + parse_events=_parse_opencode_event, + build_env_vars=_build_opencode_env_vars, +) + +register_agent(OPENCODE_SPEC) + +__all__ = [ + "OPENCODE_SPEC", +] diff --git a/src/openenv/core/harness/agents/pi.py b/src/openenv/core/harness/agents/pi.py new file mode 100644 index 000000000..a2fdd7537 --- /dev/null +++ b/src/openenv/core/harness/agents/pi.py @@ -0,0 +1,193 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Pi coding agent adapter. + +Pi runs in print mode for non-interactive harness usage:: + + pi --no-session --no-context-files --provider

--model --thinking off \\ + -p @/home/user/task/instruction.txt 2>&1 | tee /home/user/logs/agent/pi.txt + +The provider and model are passed as CLI flags. Provider-specific credentials +are exported via ``build_env_vars`` according to Pi's provider docs +(``HF_TOKEN`` for ``huggingface``, ``OPENAI_API_KEY`` for ``openai``, etc.). + +Registered on import:: + + import openenv.core.harness.agents.pi + # PI_SPEC is now in the registry +""" + +from __future__ import annotations + +import json +import shlex +from typing import Any + +from . import register_agent +from .base import AgentEvent, ArtifactSpec, CLIAgentSpec, MCPConfigSpec + + +def _instruction(task: Any, config: Any) -> str: + return task.instruction if hasattr(task, "instruction") else str(task) + + +def _system_prompt(task: Any, config: Any) -> str | None: + if hasattr(config, "system_prompt") and config.system_prompt: + return config.system_prompt + return None + + +def _build_command( + spec: CLIAgentSpec, + config: Any, + task: Any, + mcp_config_path: str | None, +) -> str: + home = config.sandbox_home if hasattr(config, "sandbox_home") else "/home/user" + instruction_file = f"{home}/task/instruction.txt" + log_file = f"{home}/logs/agent/pi.txt" + workdir = ( + config.workdir + if hasattr(config, "workdir") and getattr(config, "workdir") + else f"{home}/workdir" + ) + + provider = "" + if hasattr(config, "provider") and config.provider: + provider = f" --provider {shlex.quote(config.provider)}" + model = "" + if hasattr(config, "model") and config.model: + model = f" --model {shlex.quote(config.model)}" + thinking = " --thinking off" + if hasattr(config, "thinking") and config.thinking: + thinking = f" --thinking {shlex.quote(config.thinking)}" + + workdir_q = shlex.quote(workdir) + instruction_q = shlex.quote(instruction_file) + log_q = shlex.quote(log_file) + + return ( + f"cd {workdir_q} && git init -q 2>/dev/null; " + f"pi --no-session --no-context-files" + f"{provider}{model}{thinking}" + f" -p 
@{instruction_q}" + f" 2>&1 | tee {log_q}" + ) + + +def _build_mcp_config( + spec: CLIAgentSpec, + tools: list[Any], + workdir: str, +) -> str: + return json.dumps({"mcpServers": {}}, indent=2) + + +def _parse_events(line: str) -> AgentEvent | None: + line = line.strip() + if not line: + return None + try: + data = json.loads(line) + except json.JSONDecodeError: + return AgentEvent(type="assistant", data={"text": line}, raw=line) + + event_type = data.get("type", "") + if event_type in ("assistant", "message", "response"): + return AgentEvent(type="assistant", data=data, raw=line) + if event_type in ("tool_call", "tool_use", "function_call"): + return AgentEvent(type="tool_call", data=data, raw=line) + if event_type in ("tool_result", "tool_response"): + return AgentEvent(type="tool_result", data=data, raw=line) + if event_type in ("thinking", "reasoning"): + return AgentEvent(type="reasoning", data=data, raw=line) + if event_type == "error": + return AgentEvent(type="error", data=data, raw=line) + if event_type in ("done", "complete", "end"): + return AgentEvent(type="done", data=data, raw=line) + return AgentEvent(type="assistant", data=data, raw=line) + + +def _provider_api_key_env(provider: str) -> str: + provider_key = provider.strip().lower() + env_by_provider = { + # https://github.com/earendil-works/pi/tree/main/packages/coding-agent#providers--models + "openai": "OPENAI_API_KEY", + "openenv": "OPENAI_API_KEY", + "huggingface": "HF_TOKEN", + "anthropic": "ANTHROPIC_API_KEY", + "gemini": "GEMINI_API_KEY", + "google": "GEMINI_API_KEY", + } + env_name = env_by_provider.get(provider_key) + if env_name is None: + raise ValueError( + f"Unsupported pi provider {provider!r}; expected one of " + f"{sorted(env_by_provider)}" + ) + return env_name + + +def _build_env_vars(spec: CLIAgentSpec, config: Any) -> dict[str, str]: + provider = config.provider if hasattr(config, "provider") else "openai" + if not isinstance(provider, str) or not provider.strip(): + provider = 
"openai" + api_key = config.api_key if hasattr(config, "api_key") else "" + base_url = config.base_url if hasattr(config, "base_url") else "" + extra_env = config.extra_env if hasattr(config, "extra_env") else {} + + env = dict(extra_env) + env["PI_SKIP_VERSION_CHECK"] = "1" + env["PI_TELEMETRY"] = "0" + + if base_url: + env["OPENAI_BASE_URL"] = base_url + + key_env_var = _provider_api_key_env(provider) + if api_key: + env[key_env_var] = api_key + + return env + + +PI_SPEC = CLIAgentSpec( + name="pi", + install_check_cmd=["pi", "--version"], + base_command=["pi", "--no-session", "--no-context-files"], + mcp_config=MCPConfigSpec( + method="config_file", + path_template="{workdir}/.mcp.json", + ), + default_timeout_s=600.0, + setup=( + "set -e && " + "apt-get update -qq && apt-get install -y -qq curl ca-certificates gnupg && " + "curl -fsSL https://deb.nodesource.com/setup_22.x | bash - && " + "apt-get install -y -qq nodejs && " + "curl -fsSL https://pi.dev/install.sh | sh && " + "mkdir -p /home/user/logs/agent /home/user/task /home/user/workdir && " + 'export PATH="$HOME/.local/bin:$HOME/.pi/bin:$PATH" && ' + "pi --version" + ), + files={ + "/home/user/task/instruction.txt": _instruction, + "/home/user/task/system.txt": _system_prompt, + }, + artifacts={ + "agent_log": ArtifactSpec(path="/home/user/logs/agent/pi.txt"), + }, + env=None, + extension_dir_template="{home}/.pi/agent/extensions", + build_command=_build_command, + build_mcp_config=_build_mcp_config, + parse_events=_parse_events, + build_env_vars=_build_env_vars, +) + +register_agent(PI_SPEC) + +__all__ = ["PI_SPEC"] diff --git a/src/openenv/core/harness/sandbox/__init__.py b/src/openenv/core/harness/sandbox/__init__.py new file mode 100644 index 000000000..208fe54d5 --- /dev/null +++ b/src/openenv/core/harness/sandbox/__init__.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Sandbox backends for harness-driven rollouts. + +Provides the :class:`SandboxBackend` / :class:`SandboxHandle` protocols and +concrete implementations. Any harness adapter can use any backend -- the +sandbox layer is orthogonal to the agent CLI choice. + +Optional backend imports are wrapped in ``try/except`` so this package +loads cleanly when dependencies aren't installed (CI smoke tests, lint). +""" + +from typing import Any, Literal + +from .base import BgJob, ExecResult, SandboxBackend, SandboxHandle +from .docker_backend import DockerBgJob, DockerSandboxBackend, DockerSandboxHandle + +__all__ = [ + "BgJob", + "DockerBgJob", + "DockerSandboxBackend", + "DockerSandboxHandle", + "ExecResult", + "SandboxBackend", + "SandboxHandle", + "create_sandbox_backend", +] + +try: + from .e2b_backend import E2BBgJob, E2BSandboxBackend, E2BSandboxHandle # noqa: F401 + + __all__.extend(["E2BBgJob", "E2BSandboxBackend", "E2BSandboxHandle"]) +except ImportError: + pass # e2b not installed + +try: + from .hf_backend import HFBgJob, HFSandboxBackend, HFSandboxHandle # noqa: F401 + + __all__.extend(["HFBgJob", "HFSandboxBackend", "HFSandboxHandle"]) +except ImportError: + pass # hf-sandbox not installed + + +def create_sandbox_backend( + backend: Literal["e2b", "docker", "hf"] = "e2b", + **kwargs: Any, +) -> SandboxBackend: + """Create a sandbox backend by name. + + For ``"e2b"``: works with both E2B cloud and CubeSandbox + (set ``E2B_API_URL``). + + For ``"docker"``: local Docker, no external dependencies. + + For ``"hf"``: Hugging Face Jobs via ``hf-sandbox``. 
def shell_quote(s: str) -> str:
    """Wrap ``s`` in single quotes for use as one shell word.

    Each embedded single quote becomes ``'\\''``: close the quoted run,
    emit an escaped quote, reopen the quoted run.
    """
    escaped = "'\\''".join(s.split("'"))
    return f"'{escaped}'"
+Backends can be implemented against any provider (E2B, CubeSandbox, Docker, +Modal) as long as they satisfy the Protocols defined here. """ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Protocol, runtime_checkable +from typing import Protocol, runtime_checkable @dataclass @@ -96,5 +96,10 @@ def create( timeout_s: int = 900, envs: dict[str, str] | None = None, metadata: dict[str, str] | None = None, + image: str | None = None, ) -> SandboxHandle: - """Create and return a new, ready-to-use sandbox.""" + """Create and return a new, ready-to-use sandbox. + + ``image`` is backend-specific and may be ignored by providers that do + not support per-sandbox image selection. + """ diff --git a/src/openenv/core/harness/sandbox/docker_backend.py b/src/openenv/core/harness/sandbox/docker_backend.py new file mode 100644 index 000000000..120fb9a11 --- /dev/null +++ b/src/openenv/core/harness/sandbox/docker_backend.py @@ -0,0 +1,373 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Docker implementation of :class:`SandboxBackend`. + +Runs each sandbox as a ``docker run -d`` container on the local machine. +Commands execute via ``docker exec``, files transfer via ``docker exec`` +with stdin piping. Suitable for CI, local dev, and environments without +KVM or cloud sandbox credentials. 
class DockerBgJob:
    """Handle to a background process running inside a Docker container.

    Holds the container id and the PID of the wrapper shell. Completion is
    signalled through an internal event; ``_exit_code`` and ``_done`` are
    set either by the owner's polling thread or by :meth:`kill`.
    """

    # Exit status reported for a job that was killed: 128 + SIGKILL (9).
    _KILLED_EXIT_CODE = 137

    def __init__(
        self,
        container_id: str,
        pid: int,
        poll_thread: threading.Thread | None = None,
    ) -> None:
        self._container_id = container_id
        self._pid = pid
        self._exit_code: int | None = None
        self._done = threading.Event()
        self._poll_thread = poll_thread

    @property
    def pid(self) -> int:
        """PID of the tracked process inside the container."""
        return self._pid

    def wait(self, timeout: float | None = None) -> int:
        """Block until the job completes and return its exit code.

        Args:
            timeout: Seconds to wait; ``None`` waits indefinitely.

        Raises:
            TimeoutError: if the job is not done within ``timeout``.
        """
        if not self._done.wait(timeout=timeout):
            raise TimeoutError(
                f"Background command (pid={self._pid}) did not exit within {timeout}s"
            )
        return self._exit_code if self._exit_code is not None else 0

    def kill(self) -> None:
        """Best-effort ``kill -9`` of the tracked PID; marks the job done.

        Does nothing when the job has already finished, so a PID reused
        inside the container is never signalled. A job terminated here
        reports exit code 137 from :meth:`wait` instead of a misleading 0.
        Failures to reach Docker are ignored: this is cleanup.
        """
        if self._done.is_set():
            return
        try:
            subprocess.run(
                ["docker", "exec", self._container_id, "kill", "-9", str(self._pid)],
                capture_output=True,
                timeout=5,
            )
        except (OSError, subprocess.SubprocessError):
            # Docker missing, unreachable or timed out: nothing more to do.
            pass
        if self._exit_code is None:
            self._exit_code = self._KILLED_EXIT_CODE
        self._done.set()
+ return self._container_id[:12] + + def exec( + self, + cmd: str, + *, + envs: dict[str, str] | None = None, + cwd: str | None = None, + timeout: float | None = 60, + ) -> ExecResult: + docker_cmd = self._build_exec_cmd(envs=envs, cwd=cwd) + docker_cmd.extend(["bash", "-c", cmd]) + try: + result = subprocess.run( + docker_cmd, + capture_output=True, + text=True, + timeout=timeout, + ) + return ExecResult( + exit_code=result.returncode, + stdout=result.stdout, + stderr=result.stderr, + ) + except subprocess.TimeoutExpired: + return ExecResult( + exit_code=-1, stdout="", stderr=f"Command timed out after {timeout}s" + ) + except Exception as exc: + return ExecResult(exit_code=-1, stdout="", stderr=str(exc)) + + def start_bg( + self, + cmd: str, + *, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> BgJob: + marker = f"/tmp/.bg_{uuid.uuid4().hex}" + wrapped = f"bash -c {shell_quote(cmd + f'; echo $? > {marker}')} &\necho $!" + docker_cmd = self._build_exec_cmd(envs=envs, cwd=cwd) + docker_cmd.extend(["bash", "-c", wrapped]) + result = subprocess.run(docker_cmd, capture_output=True, text=True, timeout=10) + if result.returncode != 0: + raise RuntimeError(f"Failed to start background command: {result.stderr}") + # Extract PID from the last numeric-only line (commands may print + # banners like "Database migration complete." before the PID). 
+ pid_line = None + for line in reversed(result.stdout.strip().splitlines()): + if line.strip().isdigit(): + pid_line = line.strip() + break + if pid_line is None: + raise RuntimeError( + f"Could not extract PID from start_bg output: {result.stdout!r}" + ) + pid = int(pid_line) + + job = DockerBgJob(self._container_id, pid) + poll_thread = threading.Thread( + target=self._poll_bg_job, + args=(job, marker), + daemon=True, + ) + job._poll_thread = poll_thread + self._bg_jobs.append(job) + poll_thread.start() + return job + + def write_text(self, path: str, content: str) -> None: + parent = str(PurePosixPath(path).parent) + if parent not in ("", "/"): + mkdir_result = subprocess.run( + ["docker", "exec", self._container_id, "mkdir", "-p", parent], + capture_output=True, + timeout=10, + ) + if mkdir_result.returncode != 0: + raise RuntimeError( + f"Failed to create directory {parent!r} in container " + f"{self._container_id}: {mkdir_result.stderr.decode(errors='replace')}" + ) + write_result = subprocess.run( + [ + "docker", + "exec", + "-i", + self._container_id, + "bash", + "-c", + f"cat > {shell_quote(path)}", + ], + input=content.encode(), + capture_output=True, + timeout=30, + ) + if write_result.returncode != 0: + raise RuntimeError( + f"Failed to write file {path!r} in container " + f"{self._container_id}: {write_result.stderr.decode(errors='replace')}" + ) + + def read_text(self, path: str) -> str: + result = subprocess.run( + ["docker", "exec", self._container_id, "cat", path], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0: + raise FileNotFoundError(f"No such file in container: {path}") + return result.stdout + + def exists(self, path: str) -> bool: + result = subprocess.run( + ["docker", "exec", self._container_id, "test", "-e", path], + capture_output=True, + timeout=10, + ) + return result.returncode == 0 + + def kill(self) -> None: + for job in self._bg_jobs: + try: + job.kill() + except Exception: + pass + 
self._bg_jobs.clear() + try: + subprocess.run( + ["docker", "rm", "-f", self._container_id], + capture_output=True, + timeout=15, + ) + except Exception: + pass + + def _build_exec_cmd( + self, + *, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> list[str]: + cmd = ["docker", "exec"] + if self._user: + cmd.extend(["-u", self._user]) + if cwd: + cmd.extend(["-w", cwd]) + for k, v in (envs or {}).items(): + cmd.extend(["-e", f"{k}={v}"]) + cmd.append(self._container_id) + return cmd + + def _poll_bg_job(self, job: DockerBgJob, marker: str) -> None: + consecutive_failures = 0 + while not job._done.is_set(): + try: + result = subprocess.run( + ["docker", "exec", self._container_id, "cat", marker], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0 and result.stdout.strip(): + job._exit_code = int(result.stdout.strip()) + job._done.set() + return + if "No such container" in (result.stderr or ""): + job._exit_code = 1 + job._done.set() + return + except Exception: + consecutive_failures += 1 + else: + consecutive_failures = 0 + + # Also check if PID is gone (crash without writing marker). + try: + check = subprocess.run( + ["docker", "exec", self._container_id, "kill", "-0", str(job._pid)], + capture_output=True, + text=True, + timeout=5, + ) + if check.returncode != 0: + job._exit_code = 1 + job._done.set() + return + if "No such container" in (check.stderr or ""): + job._exit_code = 1 + job._done.set() + return + except Exception: + consecutive_failures += 1 + + if consecutive_failures >= 10: + job._exit_code = 1 + job._done.set() + return + + time.sleep(0.5) + + +class DockerSandboxBackend: + """Creates Docker container sandboxes. + + Each :meth:`create` call spawns a fresh ``docker run -d`` container + that stays alive until :meth:`SandboxHandle.kill` is called or the + container's ``timeout_s`` sleep expires. 
+ """ + + def __init__( + self, + *, + image: str = "ubuntu:22.04", + docker_args: list[str] | None = None, + user: str | None = None, + ) -> None: + self._image = image + self._docker_args = list(docker_args or []) + self._user = user + + # Linux Docker Engine does not auto-resolve host.docker.internal + # unless we explicitly map it. + if "host.docker.internal:host-gateway" not in self._docker_args: + self._docker_args.extend( + ["--add-host", "host.docker.internal:host-gateway"] + ) + + try: + subprocess.run( + ["docker", "version"], + capture_output=True, + check=True, + timeout=5, + ) + except ( + subprocess.CalledProcessError, + FileNotFoundError, + subprocess.TimeoutExpired, + ) as exc: + raise RuntimeError( + "DockerSandboxBackend requires a running Docker daemon." + ) from exc + + def create( + self, + *, + timeout_s: int = 900, + envs: dict[str, str] | None = None, + metadata: dict[str, str] | None = None, + image: str | None = None, + ) -> DockerSandboxHandle: + cmd = [ + "docker", + "run", + "-d", + "--label", + "openenv.sandbox=true", + ] + if metadata: + for k, v in metadata.items(): + cmd.extend(["--label", f"openenv.{k}={v}"]) + for k, v in (envs or {}).items(): + cmd.extend(["-e", f"{k}={v}"]) + cmd.extend(self._docker_args) + effective_image = image or self._image + cmd.extend([effective_image, "sleep", str(timeout_s)]) + + result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) + if result.returncode != 0: + raise RuntimeError( + f"Failed to create Docker sandbox: {result.stderr.strip()}" + ) + container_id = result.stdout.strip() + _log.info( + "Docker sandbox created: %s (image=%s)", + container_id[:12], + effective_image, + ) + return DockerSandboxHandle(container_id, user=self._user) diff --git a/envs/opencode_env/sandbox/e2b.py b/src/openenv/core/harness/sandbox/e2b_backend.py similarity index 92% rename from envs/opencode_env/sandbox/e2b.py rename to src/openenv/core/harness/sandbox/e2b_backend.py index b567a9e65..c0cbf75ba 
100644 --- a/envs/opencode_env/sandbox/e2b.py +++ b/src/openenv/core/harness/sandbox/e2b_backend.py @@ -4,7 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""E2B implementation of :class:`SandboxBackend`.""" +"""E2B implementation of :class:`SandboxBackend`. + +Works with both E2B cloud (api.e2b.dev) and self-hosted E2B-compatible +backends like CubeSandbox. For CubeSandbox, set:: + + E2B_API_URL=http://your-cubesandbox:3000 + E2B_API_KEY=dummy # any non-empty string +""" from __future__ import annotations @@ -14,8 +21,7 @@ from e2b import Sandbox from e2b.sandbox_sync.commands.command_handle import CommandHandle - -from .base import BgJob, ExecResult, SandboxBackend, SandboxHandle +from openenv.core.harness.sandbox.base import BgJob, ExecResult, SandboxHandle class E2BBgJob: @@ -46,9 +52,7 @@ def pid(self) -> int: def wait(self, timeout: float | None = None) -> int: self._thread.join(timeout) if self._thread.is_alive(): - raise TimeoutError( - f"Background command did not exit within {timeout}s" - ) + raise TimeoutError(f"Background command did not exit within {timeout}s") if self._error is not None: # E2B raises CommandExitException on non-zero; treat as exit code. code = getattr(self._error, "exit_code", None) @@ -180,7 +184,9 @@ def create( timeout_s: int = 900, envs: dict[str, str] | None = None, metadata: dict[str, str] | None = None, + image: str | None = None, ) -> SandboxHandle: + del image sbx = Sandbox.create( template=self._template, timeout=timeout_s, diff --git a/src/openenv/core/harness/sandbox/hf_backend.py b/src/openenv/core/harness/sandbox/hf_backend.py new file mode 100644 index 000000000..3b7b060b5 --- /dev/null +++ b/src/openenv/core/harness/sandbox/hf_backend.py @@ -0,0 +1,308 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
class HFBgJob:
    """Background process handle for :class:`HFSandboxHandle`.

    Completion is detected lazily inside :meth:`wait` by reading the
    exit-code marker file and probing the PID through the sandbox's
    ``exec``.
    """

    def __init__(
        self,
        sandbox: "HFSandboxHandle",
        *,
        pid: int,
        marker_path: str,
        poll_interval_s: float = 0.5,
    ) -> None:
        self._sandbox = sandbox
        self._pid = pid
        self._marker_path = marker_path
        self._poll_interval_s = poll_interval_s
        self._done = Event()
        self._exit_code: int | None = None

    @property
    def pid(self) -> int:
        """PID of the tracked process inside the sandbox."""
        return self._pid

    def _read_marker(self) -> int | None:
        """Return the exit code from the marker file, or ``None`` if absent/empty."""
        marker = self._sandbox.exec(
            f"cat {shell_quote(self._marker_path)}",
            timeout=10,
        )
        recorded = marker.stdout.strip()
        if marker.exit_code == 0 and recorded:
            return _parse_exit_code(recorded, default=0)
        return None

    def wait(self, timeout: float | None = None) -> int:
        """Poll until the job finishes and return its exit code.

        Args:
            timeout: Seconds to wait; ``None`` waits indefinitely. The
                deadline is checked once per polling cycle.

        Returns:
            The code from the marker file, or 1 when the PID is gone and no
            marker was written.

        Raises:
            TimeoutError: if the deadline passes before the job is done.
        """
        deadline = None if timeout is None else (time.monotonic() + timeout)
        while True:
            if self._done.is_set():
                return self._exit_code if self._exit_code is not None else 0
            if deadline is not None and time.monotonic() > deadline:
                raise TimeoutError(
                    f"Background command (pid={self._pid}) "
                    f"did not exit within {timeout}s"
                )

            exit_code = self._read_marker()
            if exit_code is None:
                alive = self._sandbox.exec(f"kill -0 {self._pid}", timeout=10)
                if alive.exit_code == 0:
                    time.sleep(self._poll_interval_s)
                    continue
                # The command may have finished between the marker read and
                # the liveness probe. Re-read the marker so a clean exit is
                # not misreported as a crash.
                exit_code = self._read_marker()
                if exit_code is None:
                    exit_code = 1

            self._exit_code = exit_code
            self._done.set()
            return exit_code

    def kill(self) -> None:
        """Best-effort ``kill -9``; marks the job done with exit code 137.

        Does nothing if the job is already done.
        """
        if self._done.is_set():
            return
        try:
            self._sandbox.exec(f"kill -9 {self._pid}", timeout=5)
        except Exception:
            # Cleanup path: the sandbox may already be gone.
            pass
        self._exit_code = 137
        self._done.set()
2>&1 & echo $!" + + result = self.exec(launch_cmd, envs=envs, cwd=cwd, timeout=30) + if result.exit_code != 0: + raise RuntimeError( + f"Failed to start background command: {result.stderr or result.stdout}" + ) + + pid = _parse_pid(result.stdout) + if pid is None: + raise RuntimeError( + f"Could not extract PID from start_bg output: {result.stdout!r}" + ) + + job = HFBgJob(self, pid=pid, marker_path=marker_path) + self._bg_jobs.append(job) + return job + + def write_text(self, path: str, content: str) -> None: + parent = str(PurePosixPath(path).parent) + if parent not in ("", "/"): + r = self.exec(f"mkdir -p {shell_quote(parent)}", timeout=10) + if r.exit_code != 0: + raise RuntimeError( + f"Failed to create parent directory {parent!r}: {r.stderr}" + ) + self._sbx.write_file(path, content) + + def read_text(self, path: str) -> str: + return str(self._sbx.read_file(path, text=True)) + + def exists(self, path: str) -> bool: + r = self.exec(f"test -e {shell_quote(path)}", timeout=10) + return r.exit_code == 0 + + def kill(self) -> None: + for job in self._bg_jobs: + try: + job.kill() + except Exception: + pass + self._bg_jobs.clear() + try: + self._sbx.terminate() + except Exception: + pass + + +class HFSandboxBackend: + """Creates HF sandboxes for harness rollouts via ``hf-sandbox``.""" + + def __init__( + self, + *, + image: str = "python:3.12", + flavor: str = "cpu-basic", + timeout: str | None = None, + forward_hf_token: bool = False, + create_retries: int = 3, + create_backoff_s: float = 2.0, + ) -> None: + self._image = image + self._flavor = flavor + self._timeout = timeout + self._forward_hf_token = forward_hf_token + self._create_retries = max(1, int(create_retries)) + self._create_backoff_s = max(0.0, float(create_backoff_s)) + + def create( + self, + *, + timeout_s: int = 900, + envs: dict[str, str] | None = None, + metadata: dict[str, str] | None = None, + image: str | None = None, + ) -> SandboxHandle: + # `hf-sandbox` does not support metadata at 
create-time yet. + del metadata + + timeout = self._timeout or _format_timeout(timeout_s) + effective_image = image or self._image + last_error: Exception | None = None + + for attempt in range(self._create_retries): + try: + sbx = Sandbox.create( + image=effective_image, + flavor=self._flavor, + timeout=timeout, + forward_hf_token=self._forward_hf_token, + ) + return HFSandboxHandle(sbx, default_envs=envs) + except Exception as exc: # noqa: BLE001 + last_error = exc + if attempt + 1 < self._create_retries: + time.sleep(self._create_backoff_s * (2**attempt)) + + assert last_error is not None + raise HFSandboxCreateError( + f"Failed to create HF sandbox after {self._create_retries} attempts " + f"({type(last_error).__name__})." + ) from last_error + + +def _with_env_prefix(cmd: str, envs: dict[str, str]) -> str: + if not envs: + return cmd + parts: list[str] = [] + for key, value in envs.items(): + if not _ENV_KEY_RE.match(key): + raise ValueError(f"Invalid environment variable name: {key!r}") + parts.append(f"export {key}={shell_quote(str(value))};") + return " ".join(parts) + f" {cmd}" + + +def _normalize_exec_timeout(timeout: float | None) -> int: + if timeout is None: + return 24 * 60 * 60 + return max(1, int(timeout)) + + +def _format_timeout(timeout_s: int) -> str: + timeout_s = max(1, int(timeout_s)) + if timeout_s % 3600 == 0: + return f"{timeout_s // 3600}h" + if timeout_s % 60 == 0: + return f"{timeout_s // 60}m" + return f"{timeout_s}s" + + +def _parse_pid(stdout: str) -> int | None: + for line in reversed(stdout.strip().splitlines()): + raw = line.strip() + if raw.isdigit(): + return int(raw) + return None + + +def _parse_exit_code(raw: str, *, default: int) -> int: + try: + return int(raw.splitlines()[-1].strip()) + except Exception: + return default + + +__all__ = [ + "HFBgJob", + "HFSandboxBackend", + "HFSandboxCreateError", + "HFSandboxError", + "HFSandboxHandle", +] diff --git a/tests/core/test_cli_agent_driver.py 
b/tests/core/test_cli_agent_driver.py new file mode 100644 index 000000000..f174b6733 --- /dev/null +++ b/tests/core/test_cli_agent_driver.py @@ -0,0 +1,1317 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for the CLI agent driver abstraction (Phase 2). + +Covers: + - Agent spec + event protocols (base.py) + - Agent registry (__init__.py) + - CLIAgentDriver / CLIAgentSession / CLIAgentSessionFactory (cli_driver.py) + - OpenCode adapter spec (opencode.py) + +All tests run without external dependencies (no E2B, no LLM, no network). +""" + +from __future__ import annotations + +import asyncio +import json +import queue as _queue_mod +import threading +import time +from dataclasses import dataclass, field +from typing import Any + +import pytest +from openenv.core.harness.sandbox.base import ExecResult, SandboxHandle + + +# Fake sandbox infrastructure (mirrors test_opencode_env.py pattern) + + +@dataclass +class FakeBgJob: + cmd: str = "" + envs: dict[str, str] | None = None + _exit_code: int = 0 + + @property + def pid(self) -> int: + return 12345 + + def wait(self, timeout: float | None = None) -> int: + return self._exit_code + + def kill(self) -> None: + pass + + +class FakeSandbox: + """In-memory sandbox for unit testing.""" + + def __init__( + self, + *, + install_check_succeeds: bool = False, + healthz_succeeds: bool = True, + ) -> None: + self.sandbox_id = "fake-sandbox-001" + self.written: dict[str, str] = {} + self.executed: list[str] = [] + self.bg_commands: list[tuple[str, dict[str, str] | None]] = [] + self._install_check_succeeds = install_check_succeeds + self._healthz_succeeds = healthz_succeeds + self._killed = False + + def exec( + self, + cmd: str, + *, + envs: dict[str, str] | None = None, + cwd: str | None = None, + timeout: float | None = 60, + ) -> ExecResult: + 
self.executed.append(cmd) + if cmd == "echo ok": + return ExecResult(exit_code=0, stdout="ok", stderr="") + # install check β€” only standalone version-check commands (short, just + # binary + --version) should be treated as install probes. Multi-part + # setup scripts that happen to end with --version should succeed. + if "--version" in cmd and len(cmd) < 80 and "&&" not in cmd: + if self._install_check_succeeds: + return ExecResult(exit_code=0, stdout="1.0.0", stderr="") + return ExecResult(exit_code=127, stdout="", stderr="not found") + # healthz check + if "healthz" in cmd: + if self._healthz_succeeds: + return ExecResult(exit_code=0, stdout='{"status":"ok"}', stderr="") + return ExecResult(exit_code=7, stdout="", stderr="connection refused") + # All other commands succeed + return ExecResult(exit_code=0, stdout="", stderr="") + + def start_bg( + self, + cmd: str, + *, + envs: dict[str, str] | None = None, + cwd: str | None = None, + ) -> FakeBgJob: + self.bg_commands.append((cmd, envs)) + return FakeBgJob(cmd=cmd, envs=envs) + + def write_text(self, path: str, content: str) -> None: + self.written[path] = content + + def read_text(self, path: str) -> str: + if path not in self.written: + raise FileNotFoundError(f"No such file: {path}") + return self.written[path] + + def exists(self, path: str) -> bool: + return path in self.written + + def kill(self) -> None: + self._killed = True + + +class FakeSandboxBackend: + """Backend that returns FakeSandbox instances.""" + + def __init__( + self, + *, + install_check_succeeds: bool = False, + healthz_succeeds: bool = True, + ) -> None: + self._install_check_succeeds = install_check_succeeds + self._healthz_succeeds = healthz_succeeds + self.created: list[FakeSandbox] = [] + + def create( + self, + *, + timeout_s: int = 900, + envs: dict[str, str] | None = None, + metadata: dict[str, str] | None = None, + ) -> SandboxHandle: + sbx = FakeSandbox( + install_check_succeeds=self._install_check_succeeds, + 
healthz_succeeds=self._healthz_succeeds, + ) + self.created.append(sbx) + return sbx + + +@dataclass +class FakeTask: + instruction: str = "Write hello.py" + setup_shell: str | None = None + upload_files: dict[str, str] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class FakeConfig: + base_url: str = "https://api.example.com/v1" + api_key: str = "sk-test-key" + model: str = "test-model" + agent_timeout_s: float = 300.0 + sandbox_home: str = "/home/user" + workdir: str | None = None + extra_env: dict[str, str] = field(default_factory=dict) + + +# PR 2.1: Agent Spec and Event Parser Protocols + + +class TestAgentSpecProtocols: + """Tests for base.py data models.""" + + def test_mcp_config_spec_frozen(self): + from openenv.core.harness.agents.base import MCPConfigSpec + + spec = MCPConfigSpec(method="config_file", path_template="{workdir}/mcp.json") + assert spec.method == "config_file" + assert spec.path_template == "{workdir}/mcp.json" + with pytest.raises(AttributeError): + spec.method = "cli_flags" # type: ignore[misc] + + def test_artifact_spec_defaults(self): + from openenv.core.harness.agents.base import ArtifactSpec + + a = ArtifactSpec(path="/logs/agent/out.log") + assert a.format == "text" + assert a.optional is True + + def test_artifact_spec_json(self): + from openenv.core.harness.agents.base import ArtifactSpec + + a = ArtifactSpec(path="/data/traj.json", format="json", optional=False) + assert a.format == "json" + assert a.optional is False + + def test_agent_event_creation(self): + from openenv.core.harness.agents.base import AgentEvent + + e = AgentEvent( + type="tool_call", data={"name": "bash"}, raw='{"type":"tool_call"}' + ) + assert e.type == "tool_call" + assert e.data["name"] == "bash" + + def test_cli_agent_spec_minimal(self): + from openenv.core.harness.agents.base import CLIAgentSpec, MCPConfigSpec + + spec = CLIAgentSpec( + name="test-agent", + install_check_cmd=["test-agent", 
"--version"], + base_command=["test-agent", "run"], + mcp_config=MCPConfigSpec(method="cli_flags"), + ) + assert spec.name == "test-agent" + assert spec.default_timeout_s == 600.0 + assert spec.setup is None + assert spec.files is None + assert spec.artifacts is None + assert spec.env is None + assert spec.extension_dir_template is None + assert spec.build_command is None + + def test_cli_agent_spec_full(self): + from openenv.core.harness.agents.base import ( + ArtifactSpec, + CLIAgentSpec, + MCPConfigSpec, + ) + + spec = CLIAgentSpec( + name="full-agent", + install_check_cmd=["full-agent", "--version"], + base_command=["full-agent", "exec"], + mcp_config=MCPConfigSpec( + method="config_file", path_template="{workdir}/mcp.json" + ), + default_timeout_s=900.0, + setup="npm install -g full-agent", + files={ + "/task.txt": "hello", + "/dynamic.txt": lambda task, config: task.instruction, + }, + artifacts={ + "log": ArtifactSpec(path="/logs/out.log"), + "traj": ArtifactSpec(path="/logs/traj.json", format="json"), + }, + env={"API_KEY": "{api_key}", "MODEL": "{model}"}, + build_command=lambda spec, config, task, mcp: "full-agent exec", + build_mcp_config=lambda spec, tools, workdir: "{}", + parse_events=lambda line: None, + ) + assert spec.name == "full-agent" + assert spec.artifacts is not None + assert len(spec.artifacts) == 2 + assert spec.files is not None + assert callable(spec.files["/dynamic.txt"]) + + +# PR 2.2: Agent Registry + + +class TestAgentRegistry: + """Tests for the agent registry.""" + + def test_register_and_lookup(self): + from openenv.core.harness.agents import ( + get_agent_spec, + list_agents, + register_agent, + unregister_agent, + ) + from openenv.core.harness.agents.base import CLIAgentSpec, MCPConfigSpec + + spec = CLIAgentSpec( + name="test-registry-agent", + install_check_cmd=["tra", "--version"], + base_command=["tra", "run"], + mcp_config=MCPConfigSpec(method="cli_flags"), + ) + try: + register_agent(spec) + assert "test-registry-agent" in 
list_agents() + assert get_agent_spec("test-registry-agent") is spec + finally: + unregister_agent("test-registry-agent") + + def test_duplicate_registration_same_object_ok(self): + from openenv.core.harness.agents import register_agent, unregister_agent + from openenv.core.harness.agents.base import CLIAgentSpec, MCPConfigSpec + + spec = CLIAgentSpec( + name="test-dup-ok", + install_check_cmd=["x"], + base_command=["x"], + mcp_config=MCPConfigSpec(method="cli_flags"), + ) + try: + register_agent(spec) + register_agent(spec) # same object β€” should be fine + finally: + unregister_agent("test-dup-ok") + + def test_duplicate_registration_different_object_raises(self): + from openenv.core.harness.agents import register_agent, unregister_agent + from openenv.core.harness.agents.base import CLIAgentSpec, MCPConfigSpec + + spec1 = CLIAgentSpec( + name="test-dup-fail", + install_check_cmd=["x"], + base_command=["x"], + mcp_config=MCPConfigSpec(method="cli_flags"), + ) + spec2 = CLIAgentSpec( + name="test-dup-fail", + install_check_cmd=["y"], + base_command=["y"], + mcp_config=MCPConfigSpec(method="cli_flags"), + ) + try: + register_agent(spec1) + with pytest.raises(ValueError, match="already registered"): + register_agent(spec2) + finally: + unregister_agent("test-dup-fail") + + def test_unknown_agent_raises_keyerror(self): + from openenv.core.harness.agents import get_agent_spec + + with pytest.raises(KeyError, match="Unknown agent"): + get_agent_spec("nonexistent-agent-xyz") + + def test_unregister_returns_spec(self): + from openenv.core.harness.agents import register_agent, unregister_agent + from openenv.core.harness.agents.base import CLIAgentSpec, MCPConfigSpec + + spec = CLIAgentSpec( + name="test-unreg", + install_check_cmd=["x"], + base_command=["x"], + mcp_config=MCPConfigSpec(method="cli_flags"), + ) + register_agent(spec) + removed = unregister_agent("test-unreg") + assert removed is spec + assert unregister_agent("test-unreg") is None + + def 
test_auto_import_opencode(self): + """Auto-import triggers registration of built-in agents.""" + from openenv.core.harness.agents import get_agent_spec + + spec = get_agent_spec("opencode") + assert spec.name == "opencode" + + +# PR 2.3: CLIAgentDriver / CLIAgentSession / CLIAgentSessionFactory + + +def _make_test_spec(**overrides: Any): + from openenv.core.harness.agents.base import ( + ArtifactSpec, + CLIAgentSpec, + MCPConfigSpec, + ) + + defaults: dict[str, Any] = dict( + name="test-agent", + install_check_cmd=["test-agent", "--version"], + base_command=["test-agent", "run", "--json"], + mcp_config=MCPConfigSpec( + method="config_file", path_template="{workdir}/mcp.json" + ), + setup="apt-get install -y test-agent", + files={ + "/home/user/task/instruction.txt": lambda task, config: task.instruction, + }, + artifacts={ + "agent_log": ArtifactSpec(path="/home/user/logs/agent.log"), + }, + env={ + "API_KEY": "{api_key}", + "BASE_URL": "{base_url}", + "MODEL": "{model}", + }, + build_command=lambda spec, config, task, mcp: ( + f"test-agent run --json '{task.instruction}' 2>&1 | tee /home/user/logs/agent.log" + ), + build_mcp_config=lambda spec, tools, workdir: json.dumps({"tools": []}), + parse_events=lambda line: None, + ) + defaults.update(overrides) + return CLIAgentSpec(**defaults) + + +class TestCLIAgentDriver: + """Tests for the shared CLI agent driver.""" + + def test_create_session_full_lifecycle(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + + spec = _make_test_spec() + backend = FakeSandboxBackend() + driver = CLIAgentDriver(spec=spec, sandbox_backend=backend, mode="black_box") + + task = FakeTask(instruction="Write hello.py") + config = FakeConfig() + session = driver.create_session(task=task, config=config) + + # Verify sandbox was created + assert len(backend.created) == 1 + sbx = backend.created[0] + + # Verify sandbox readiness was probed + assert "echo ok" in sbx.executed + + # Verify install was attempted (agent not 
pre-installed) + assert any("apt-get install" in cmd for cmd in sbx.executed) + + # Verify files were uploaded + assert "/home/user/task/instruction.txt" in sbx.written + assert sbx.written["/home/user/task/instruction.txt"] == "Write hello.py" + + # Verify MCP config was written + assert "/home/user/workdir/mcp.json" in sbx.written + + # Verify agent was launched as bg process + assert len(sbx.bg_commands) == 1 + bg_cmd, bg_envs = sbx.bg_commands[0] + assert "test-agent run" in bg_cmd + + # Verify env vars were resolved + assert bg_envs is not None + assert bg_envs["API_KEY"] == "sk-test-key" + assert bg_envs["BASE_URL"] == "https://api.example.com/v1" + assert bg_envs["MODEL"] == "test-model" + + # Session API + assert session.initial_messages() == [ + {"role": "user", "content": "Write hello.py"} + ] + assert session.list_tools() == [] + assert session.call_tool("x", {}).error is not None + assert session.wait_for_completion() == 0 + + session.close() + assert sbx._killed + + def test_create_session_honors_configured_workdir_for_mcp_file(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + + spec = _make_test_spec() + backend = FakeSandboxBackend() + driver = CLIAgentDriver(spec=spec, sandbox_backend=backend, mode="black_box") + + config = FakeConfig(workdir="/testbed") + session = driver.create_session(task=FakeTask(), config=config) + + sbx = backend.created[0] + assert "/testbed/mcp.json" in sbx.written + session.close() + + def test_create_session_creates_extension_dir_when_spec_declares_one(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + + spec = _make_test_spec(extension_dir_template="{home}/.agent/extensions") + backend = FakeSandboxBackend() + driver = CLIAgentDriver(spec=spec, sandbox_backend=backend, mode="black_box") + + session = driver.create_session(task=FakeTask(), config=FakeConfig()) + sbx = backend.created[0] + assert any( + cmd.startswith("mkdir -p /home/user/.agent/extensions") + for cmd in 
sbx.executed + ) + session.close() + + def test_create_session_skips_install_when_prebaked(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + + spec = _make_test_spec() + backend = FakeSandboxBackend(install_check_succeeds=True) + driver = CLIAgentDriver(spec=spec, sandbox_backend=backend, mode="black_box") + + session = driver.create_session( + task=FakeTask(), + config=FakeConfig(), + ) + + sbx = backend.created[0] + # install should have been skipped + assert not any("apt-get install" in cmd for cmd in sbx.executed) + session.close() + + def test_create_session_interception_gate_requires_server(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + + spec = _make_test_spec() + with pytest.raises(ValueError, match="InterceptionServer"): + CLIAgentDriver( + spec=spec, + sandbox_backend=FakeSandboxBackend(), + mode="interception_gate", + ) + + def test_create_session_uploads_task_files(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + + spec = _make_test_spec() + backend = FakeSandboxBackend() + driver = CLIAgentDriver(spec=spec, sandbox_backend=backend, mode="black_box") + + task = FakeTask( + instruction="Write code", + upload_files={"/extra/data.json": '{"key": "value"}'}, + ) + session = driver.create_session(task=task, config=FakeConfig()) + + sbx = backend.created[0] + assert sbx.written["/extra/data.json"] == '{"key": "value"}' + session.close() + + def test_opencode_black_box_api_key_stays_out_of_command_argv(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + from openenv.core.harness.agents.opencode import OPENCODE_SPEC + + secret = "sk-test '$(leak)" + config = FakeConfig(api_key=secret) + backend = FakeSandboxBackend() + driver = CLIAgentDriver( + spec=OPENCODE_SPEC, + sandbox_backend=backend, + mode="black_box", + ) + + session = driver.create_session(task=FakeTask(), config=config) + sbx = backend.created[0] + cmd, envs = sbx.bg_commands[-1] + assert 
secret not in cmd + assert envs is not None + assert envs["OPENAI_API_KEY"] == secret + session.close() + + def test_opencode_interception_gate_uses_server_secret_not_user_key(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + from openenv.core.harness.agents.interception_server import InterceptionServer + from openenv.core.harness.agents.opencode import OPENCODE_SPEC + + secret = "sk-test '$(leak)" + config = FakeConfig(api_key=secret) + backend = FakeSandboxBackend() + server = InterceptionServer(port=0, secret="gate-secret") + driver = CLIAgentDriver( + spec=OPENCODE_SPEC, + sandbox_backend=backend, + mode="interception_gate", + interception_server=server, + interception_base_url="http://127.0.0.1:8765", + ) + + session = driver.create_session(task=FakeTask(), config=config) + sbx = backend.created[0] + cmd, envs = sbx.bg_commands[-1] + assert secret not in cmd + assert envs is not None + assert envs["OPENAI_API_KEY"] == "gate-secret" + session.close() + + def test_pi_interception_gate_writes_models_json_and_uses_openenv_provider(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + from openenv.core.harness.agents.interception_server import InterceptionServer + from openenv.core.harness.agents.pi import PI_SPEC + + backend = FakeSandboxBackend() + server = InterceptionServer(port=0, secret="gate-secret") + driver = CLIAgentDriver( + spec=PI_SPEC, + sandbox_backend=backend, + mode="interception_gate", + interception_server=server, + interception_base_url="http://127.0.0.1:8765", + ) + + session = driver.create_session(task=FakeTask(), config=FakeConfig()) + sbx = backend.created[0] + + # Command should force the custom provider backed by models.json. 
+ cmd, envs = sbx.bg_commands[-1] + assert "--provider openenv" in cmd + assert envs is not None + assert envs["PI_CODING_AGENT_DIR"] == "/home/user/.pi/agent" + + home_models = "/home/user/.pi/agent/models.json" + root_models = "/root/.pi/agent/models.json" + assert home_models in sbx.written + assert root_models not in sbx.written + + cfg = json.loads(sbx.written[home_models]) + provider = cfg["providers"]["openenv"] + assert provider["api"] == "openai-completions" + assert provider["apiKey"] == "gate-secret" + assert provider["models"][0]["id"] == "test-model" + assert "/rollout/" in provider["baseUrl"] + assert provider["baseUrl"].endswith("/v1") + + session.close() + + def test_pi_interception_gate_uses_explicit_pi_config_dir(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + from openenv.core.harness.agents.interception_server import InterceptionServer + from openenv.core.harness.agents.pi import PI_SPEC + + backend = FakeSandboxBackend() + server = InterceptionServer(port=0, secret="gate-secret") + driver = CLIAgentDriver( + spec=PI_SPEC, + sandbox_backend=backend, + mode="interception_gate", + interception_server=server, + interception_base_url="http://127.0.0.1:8765", + ) + + config = FakeConfig(sandbox_home="/custom/home") + session = driver.create_session(task=FakeTask(), config=config) + sbx = backend.created[0] + + _cmd, envs = sbx.bg_commands[-1] + assert envs is not None + assert envs["PI_CODING_AGENT_DIR"] == "/custom/home/.pi/agent" + assert "/custom/home/.pi/agent/models.json" in sbx.written + assert "/root/.pi/agent/models.json" not in sbx.written + + session.close() + + def test_create_session_runs_task_setup_shell(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + + spec = _make_test_spec() + backend = FakeSandboxBackend() + driver = CLIAgentDriver(spec=spec, sandbox_backend=backend, mode="black_box") + + task = FakeTask( + instruction="Write code", + setup_shell="pip install pandas", + ) + 
session = driver.create_session(task=task, config=FakeConfig()) + + sbx = backend.created[0] + assert "pip install pandas" in sbx.executed + session.close() + + def test_create_session_with_verifier(self): + from openenv.core.harness import VerifyResult + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + + spec = _make_test_spec() + backend = FakeSandboxBackend() + driver = CLIAgentDriver(spec=spec, sandbox_backend=backend, mode="black_box") + + def verifier(sandbox, task): + return VerifyResult(env_reward=1.0, done=True, metrics={"correct": True}) + + session = driver.create_session( + task=FakeTask(), + config=FakeConfig(), + verifier=verifier, + ) + + result = session.verify([]) + assert result.env_reward == 1.0 + assert result.metrics["correct"] is True + session.close() + + def test_session_verify_without_verifier(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + + spec = _make_test_spec() + backend = FakeSandboxBackend() + driver = CLIAgentDriver(spec=spec, sandbox_backend=backend, mode="black_box") + + session = driver.create_session(task=FakeTask(), config=FakeConfig()) + + result = session.verify([]) + assert result.env_reward is None + assert result.done is True + session.close() + + def test_invalid_mode_raises(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + + spec = _make_test_spec() + with pytest.raises(ValueError, match="Unknown mode"): + CLIAgentDriver( + spec=spec, + sandbox_backend=FakeSandboxBackend(), + mode="invalid", # type: ignore[arg-type] + ) + + +class TestCLIAgentSession: + """Tests for CLIAgentSession.""" + + def test_collect_artifacts_text(self): + from openenv.core.harness.agents.base import ArtifactSpec + from openenv.core.harness.agents.cli_driver import CLIAgentSession + + spec = _make_test_spec( + artifacts={ + "log": ArtifactSpec(path="/logs/out.log"), + }, + ) + sbx = FakeSandbox() + sbx.written["/logs/out.log"] = "line1\nline2\n" + + session = 
CLIAgentSession( + spec=spec, + sandbox=sbx, + task=FakeTask(), + config=FakeConfig(), + ) + arts = session.collect_artifacts() + assert arts["log"] == "line1\nline2\n" + + def test_collect_artifacts_json(self): + from openenv.core.harness.agents.base import ArtifactSpec + from openenv.core.harness.agents.cli_driver import CLIAgentSession + + spec = _make_test_spec( + artifacts={ + "traj": ArtifactSpec(path="/logs/traj.json", format="json"), + }, + ) + sbx = FakeSandbox() + sbx.written["/logs/traj.json"] = json.dumps({"steps": [1, 2, 3]}) + + session = CLIAgentSession( + spec=spec, + sandbox=sbx, + task=FakeTask(), + config=FakeConfig(), + ) + arts = session.collect_artifacts() + assert arts["traj"] == {"steps": [1, 2, 3]} + + def test_collect_artifacts_jsonl(self): + from openenv.core.harness.agents.base import ArtifactSpec + from openenv.core.harness.agents.cli_driver import CLIAgentSession + + spec = _make_test_spec( + artifacts={ + "events": ArtifactSpec(path="/logs/events.jsonl", format="jsonl"), + }, + ) + sbx = FakeSandbox() + sbx.written["/logs/events.jsonl"] = '{"a":1}\n{"b":2}\n' + + session = CLIAgentSession( + spec=spec, + sandbox=sbx, + task=FakeTask(), + config=FakeConfig(), + ) + arts = session.collect_artifacts() + assert arts["events"] == [{"a": 1}, {"b": 2}] + + def test_collect_artifacts_missing_optional(self): + from openenv.core.harness.agents.base import ArtifactSpec + from openenv.core.harness.agents.cli_driver import CLIAgentSession + + spec = _make_test_spec( + artifacts={ + "log": ArtifactSpec(path="/missing/file.log", optional=True), + }, + ) + sbx = FakeSandbox() + session = CLIAgentSession( + spec=spec, + sandbox=sbx, + task=FakeTask(), + config=FakeConfig(), + ) + arts = session.collect_artifacts() + assert "log" not in arts + + def test_collect_artifacts_missing_required_raises(self): + from openenv.core.harness.agents.base import ArtifactSpec + from openenv.core.harness.agents.cli_driver import CLIAgentSession + + spec = 
_make_test_spec( + artifacts={ + "log": ArtifactSpec(path="/missing/file.log", optional=False), + }, + ) + sbx = FakeSandbox() + session = CLIAgentSession( + spec=spec, + sandbox=sbx, + task=FakeTask(), + config=FakeConfig(), + ) + with pytest.raises(FileNotFoundError): + session.collect_artifacts() + + def test_close_kills_sandbox_and_jobs(self): + from openenv.core.harness.agents.cli_driver import CLIAgentSession + + spec = _make_test_spec() + sbx = FakeSandbox() + agent_job = FakeBgJob() + + session = CLIAgentSession( + spec=spec, + sandbox=sbx, + task=FakeTask(), + config=FakeConfig(), + agent_bg_job=agent_job, + ) + session.close() + assert sbx._killed + assert session._agent_bg_job is None + + @pytest.mark.asyncio + async def test_next_request_handles_missing_intercept_without_keyerror(self): + from openenv.core.harness.agents.cli_driver import CLIAgentSession + from openenv.core.harness.agents.interception_server import InterceptionServer + + spec = _make_test_spec() + sbx = FakeSandbox() + q: _queue_mod.Queue[str] = _queue_mod.Queue() + q.put("req_missing") + + session = CLIAgentSession( + spec=spec, + sandbox=sbx, + task=FakeTask(), + config=FakeConfig(), + agent_bg_job=FakeBgJob(), + interception_server=InterceptionServer(secret="s"), + interception_rollout_id="rollout-1", + interception_queue=q, + ) + + # Missing request IDs can happen if unregister_rollout races with queue.get(). + assert await session.next_request(timeout_s=0.2) is None + + def test_next_request_soak_cross_loop_queue_get(self): + """Soak test cross-loop request dequeueing via queue.Queue. + + Exercises the worker pattern that used to be unsafe with asyncio.Queue: + repeatedly call next_request() from fresh event loops (asyncio.run) + while request IDs are pushed from another thread. 
+ """ + from openenv.core.harness.agents.cli_driver import CLIAgentSession + from openenv.core.harness.agents.interception_server import InterceptionServer + + spec = _make_test_spec() + sbx = FakeSandbox() + server = InterceptionServer(secret="s") + request_queue = server.register_rollout("rollout-soak") + + session = CLIAgentSession( + spec=spec, + sandbox=sbx, + task=FakeTask(), + config=FakeConfig(), + interception_server=server, + interception_rollout_id="rollout-soak", + interception_queue=request_queue, + ) + + total_requests = 200 + consumed: list[str] = [] + failures: list[BaseException] = [] + + def _consumer() -> None: + try: + for _ in range(total_requests): + intercept = asyncio.run(session.next_request(timeout_s=2.0)) + assert intercept is not None + request_id = intercept["request_id"] + consumed.append(request_id) + with server._state_lock: + server.intercepts.pop(request_id, None) + except BaseException as exc: # pragma: no cover - assertion path + failures.append(exc) + + def _producer() -> None: + try: + for i in range(total_requests): + request_id = f"req_soak_{i:04d}" + with server._state_lock: + server.intercepts[request_id] = { + "request_id": request_id, + "messages": [{"role": "user", "content": "ping"}], + } + request_queue.put_nowait(request_id) + if i % 10 == 0: + time.sleep(0.001) + except BaseException as exc: # pragma: no cover - unexpected + failures.append(exc) + + consumer_t = threading.Thread(target=_consumer, name="soak-consumer") + producer_t = threading.Thread(target=_producer, name="soak-producer") + + consumer_t.start() + producer_t.start() + + producer_t.join(timeout=10) + consumer_t.join(timeout=15) + + assert not producer_t.is_alive(), "producer thread hung" + assert not consumer_t.is_alive(), "consumer thread hung" + assert not failures + assert len(consumed) == total_requests + assert len(set(consumed)) == total_requests + + session.close() + + +class TestCLIAgentSessionFactory: + """Tests for the ResourceSessionFactory 
wrapper.""" + + def test_factory_creates_sessions(self): + from openenv.core.harness.agents.cli_driver import CLIAgentSessionFactory + + spec = _make_test_spec() + backend = FakeSandboxBackend() + + factory = CLIAgentSessionFactory( + spec=spec, + config=FakeConfig(), + sandbox_backend=backend, + mode="black_box", + ) + + session = factory.create(task=FakeTask()) + assert len(backend.created) == 1 + assert session.initial_messages()[0]["content"] == "Write hello.py" + session.close() + + def test_factory_with_verifier(self): + from openenv.core.harness import VerifyResult + from openenv.core.harness.agents.cli_driver import CLIAgentSessionFactory + + spec = _make_test_spec() + backend = FakeSandboxBackend() + + def verifier(sandbox, task): + return VerifyResult(env_reward=0.5, done=True) + + factory = CLIAgentSessionFactory( + spec=spec, + config=FakeConfig(), + sandbox_backend=backend, + mode="black_box", + verifier=verifier, + ) + + session = factory.create(task=FakeTask()) + result = session.verify([]) + assert result.env_reward == 0.5 + session.close() + + def test_factory_implements_resource_session_factory(self): + from openenv.core.harness import ResourceSessionFactory + from openenv.core.harness.agents.cli_driver import CLIAgentSessionFactory + + assert issubclass(CLIAgentSessionFactory, ResourceSessionFactory) + + def test_session_implements_resource_session(self): + from openenv.core.harness import ResourceSession + from openenv.core.harness.agents.cli_driver import CLIAgentSession + + assert issubclass(CLIAgentSession, ResourceSession) + + +# PR 2.4: OpenCode Adapter Spec + + +class TestOpenCodeSpec: + """Tests for the OpenCode declarative spec.""" + + def test_spec_is_registered(self): + from openenv.core.harness.agents import get_agent_spec + + spec = get_agent_spec("opencode") + assert spec.name == "opencode" + + def test_spec_fields(self): + from openenv.core.harness.agents.opencode import OPENCODE_SPEC + + assert OPENCODE_SPEC.name == "opencode" + 
assert OPENCODE_SPEC.install_check_cmd == [ + "/home/user/.opencode/bin/opencode", + "--version", + ] + assert OPENCODE_SPEC.default_timeout_s == 900.0 + assert OPENCODE_SPEC.mcp_config.method == "config_file" + assert OPENCODE_SPEC.mcp_config.path_template is not None + assert "{home}" in OPENCODE_SPEC.mcp_config.path_template + assert OPENCODE_SPEC.artifacts is not None + assert "agent_log" in OPENCODE_SPEC.artifacts + assert OPENCODE_SPEC.artifacts["agent_log"].format == "jsonl" + + def test_build_command(self): + from openenv.core.harness.agents.opencode import OPENCODE_SPEC + + @dataclass + class OcConfig: + sandbox_home: str = "/home/user" + run_format: str = "json" + + assert OPENCODE_SPEC.build_command is not None + cmd = OPENCODE_SPEC.build_command( + OPENCODE_SPEC, + OcConfig(), + FakeTask(instruction="Write hello.py"), + None, + ) + assert "opencode run" in cmd + assert "--format json" in cmd + assert "/home/user/task/instruction.md" in cmd + + def test_build_command_quotes_paths(self): + from openenv.core.harness.agents.opencode import OPENCODE_SPEC + + @dataclass + class OcConfig: + sandbox_home: str = "/home/user with space" + run_format: str = "json" + + assert OPENCODE_SPEC.build_command is not None + cmd = OPENCODE_SPEC.build_command( + OPENCODE_SPEC, + OcConfig(), + FakeTask(instruction="Write hello.py"), + None, + ) + assert "cd '/home/user with space/workdir'" in cmd + assert "cat '/home/user with space/task/instruction.md'" in cmd + assert "tee '/home/user with space/logs/agent/opencode.jsonl'" in cmd + + def test_build_mcp_config(self): + from openenv.core.harness.agents.opencode import OPENCODE_SPEC + + assert OPENCODE_SPEC.build_mcp_config is not None + config_str = OPENCODE_SPEC.build_mcp_config( + OPENCODE_SPEC, + [], + "/home/user/workdir", + ) + # OpenCode returns empty string because the config is written + # via spec.files using _build_opencode_config_file instead. 
+ assert config_str == "" + + def test_parse_events_assistant(self): + from openenv.core.harness.agents.opencode import OPENCODE_SPEC + + assert OPENCODE_SPEC.parse_events is not None + line = json.dumps({"type": "assistant", "content": "hello"}) + event = OPENCODE_SPEC.parse_events(line) + assert event is not None + assert event.type == "assistant" + + def test_parse_events_tool_call(self): + from openenv.core.harness.agents.opencode import OPENCODE_SPEC + + assert OPENCODE_SPEC.parse_events is not None + line = json.dumps({"type": "tool_call", "name": "bash", "args": {}}) + event = OPENCODE_SPEC.parse_events(line) + assert event is not None + assert event.type == "tool_call" + + def test_parse_events_error(self): + from openenv.core.harness.agents.opencode import OPENCODE_SPEC + + assert OPENCODE_SPEC.parse_events is not None + line = json.dumps({"type": "error", "message": "boom"}) + event = OPENCODE_SPEC.parse_events(line) + assert event is not None + assert event.type == "error" + + def test_parse_events_done(self): + from openenv.core.harness.agents.opencode import OPENCODE_SPEC + + assert OPENCODE_SPEC.parse_events is not None + line = json.dumps({"type": "done"}) + event = OPENCODE_SPEC.parse_events(line) + assert event is not None + assert event.type == "done" + + def test_parse_events_invalid_json(self): + from openenv.core.harness.agents.opencode import OPENCODE_SPEC + + assert OPENCODE_SPEC.parse_events is not None + assert OPENCODE_SPEC.parse_events("not json") is None + assert OPENCODE_SPEC.parse_events("") is None + + def test_build_env_vars(self): + from openenv.core.harness.agents.opencode import OPENCODE_SPEC + + config = FakeConfig() + config.extra_env = {"EXTRA": "val"} + assert OPENCODE_SPEC.build_env_vars is not None + envs = OPENCODE_SPEC.build_env_vars(OPENCODE_SPEC, config) + assert envs["OPENAI_BASE_URL"] == "https://api.example.com/v1" + assert envs["OPENAI_API_KEY"] == "sk-test-key" + assert envs["OPENCODE_CONFIG"] == 
"/home/user/.config/opencode/opencode.json" + assert envs["EXTRA"] == "val" + + def test_files_instruction_resolver(self): + from openenv.core.harness.agents.opencode import OPENCODE_SPEC + + task = FakeTask(instruction="Build a REST API") + config = FakeConfig() + assert OPENCODE_SPEC.files is not None + instruction_fn = OPENCODE_SPEC.files["/home/user/task/instruction.md"] + assert callable(instruction_fn) + assert instruction_fn(task, config) == "Build a REST API" + + def test_files_system_prompt_resolver(self): + from openenv.core.harness.agents.opencode import OPENCODE_SPEC + + task = FakeTask() + config = FakeConfig() + assert OPENCODE_SPEC.files is not None + system_fn = OPENCODE_SPEC.files["/home/user/task/system.md"] + assert callable(system_fn) + # No system prompt on FakeConfig → returns None + assert system_fn(task, config) is None + + def test_opencode_driver_integration(self): + """End-to-end: create a session using the OpenCode spec via the driver.""" + from openenv.core.harness.agents.cli_driver import CLIAgentSessionFactory + from openenv.core.harness.agents.opencode import OPENCODE_SPEC + + backend = FakeSandboxBackend() + factory = CLIAgentSessionFactory( + spec=OPENCODE_SPEC, + config=FakeConfig(), + sandbox_backend=backend, + mode="black_box", + ) + + session = factory.create(task=FakeTask(instruction="Hello")) + assert session.spec.name == "opencode" + assert session.initial_messages()[0]["content"] == "Hello" + + sbx = backend.created[0] + # Instruction file should have been written + assert sbx.written.get("/home/user/task/instruction.md") == "Hello" + + session.close() + + +class TestPiSpec: + def test_build_command_quotes_paths(self): + from openenv.core.harness.agents.pi import PI_SPEC + + @dataclass + class PiConfig: + sandbox_home: str = "/home/user with space" + provider: str = "openai" + model: str = "model/name" + thinking: str = "off" + + assert PI_SPEC.build_command is not None + cmd = PI_SPEC.build_command( + PI_SPEC, +
PiConfig(), + FakeTask(instruction="Write hello.py"), + None, + ) + assert "cd '/home/user with space/workdir'" in cmd + assert "-p @'/home/user with space/task/instruction.txt'" in cmd + assert "tee '/home/user with space/logs/agent/pi.txt'" in cmd + + def test_build_command_uses_config_workdir_when_present(self): + from openenv.core.harness.agents.pi import PI_SPEC + + @dataclass + class PiConfig: + sandbox_home: str = "/home/user" + workdir: str = "/testbed" + provider: str = "openai" + model: str = "model/name" + thinking: str = "off" + + assert PI_SPEC.build_command is not None + cmd = PI_SPEC.build_command( + PI_SPEC, + PiConfig(), + FakeTask(instruction="Write hello.py"), + None, + ) + assert "cd /testbed" in cmd + + def test_spec_declares_extension_dir_template(self): + from openenv.core.harness.agents.pi import PI_SPEC + + assert PI_SPEC.extension_dir_template == "{home}/.pi/agent/extensions" + + +# Env var resolution + + +class TestEnvVarResolution: + """Tests for environment variable placeholder resolution.""" + + def test_resolve_placeholders(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + + spec = _make_test_spec( + env={ + "KEY": "{api_key}", + "URL": "{base_url}", + "MDL": "{model}", + "STATIC": "fixed_value", + }, + build_env_vars=None, # use placeholder resolution + ) + driver = CLIAgentDriver( + spec=spec, + sandbox_backend=FakeSandboxBackend(), + mode="black_box", + ) + envs = driver._resolve_env_vars(FakeConfig()) + assert envs["KEY"] == "sk-test-key" + assert envs["URL"] == "https://api.example.com/v1" + assert envs["MDL"] == "test-model" + assert envs["STATIC"] == "fixed_value" + + def test_resolve_with_proxy_override(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + + spec = _make_test_spec( + env={"URL": "{base_url}"}, + build_env_vars=None, + ) + driver = CLIAgentDriver( + spec=spec, + sandbox_backend=FakeSandboxBackend(), + mode="black_box", + ) + envs = driver._resolve_env_vars( + 
FakeConfig(), + base_url_override="http://127.0.0.1:7000/v1", + ) + assert envs["URL"] == "http://127.0.0.1:7000/v1" + + def test_build_env_vars_hook_takes_precedence(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + + def custom_env(spec, config): + return {"CUSTOM": "yes", "MODEL": config.model} + + spec = _make_test_spec( + env={"SHOULD_NOT": "appear"}, + build_env_vars=custom_env, + ) + driver = CLIAgentDriver( + spec=spec, + sandbox_backend=FakeSandboxBackend(), + mode="black_box", + ) + envs = driver._resolve_env_vars(FakeConfig()) + assert envs == {"CUSTOM": "yes", "MODEL": "test-model"} + assert "SHOULD_NOT" not in envs + + def test_empty_env_dict(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + + spec = _make_test_spec(env=None, build_env_vars=None) + driver = CLIAgentDriver( + spec=spec, + sandbox_backend=FakeSandboxBackend(), + mode="black_box", + ) + envs = driver._resolve_env_vars(FakeConfig()) + assert envs == {} + + +# Multiple setup commands + + +class TestMultiStepSetup: + """Tests for specs with multi-step setup commands.""" + + def test_list_of_setup_commands(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + + spec = _make_test_spec( + setup=[ + "apt-get update", + "apt-get install -y nodejs", + "npm install -g test-agent", + ], + ) + backend = FakeSandboxBackend() + driver = CLIAgentDriver(spec=spec, sandbox_backend=backend, mode="black_box") + + session = driver.create_session(task=FakeTask(), config=FakeConfig()) + sbx = backend.created[0] + + # All three setup commands should have been executed + assert any("apt-get update" in cmd for cmd in sbx.executed) + assert any("apt-get install" in cmd for cmd in sbx.executed) + assert any("npm install" in cmd for cmd in sbx.executed) + session.close() + + def test_no_setup_and_not_installed_raises(self): + from openenv.core.harness.agents.cli_driver import CLIAgentDriver + + spec = _make_test_spec(setup=None) + backend = 
FakeSandboxBackend(install_check_succeeds=False) + driver = CLIAgentDriver(spec=spec, sandbox_backend=backend, mode="black_box") + + with pytest.raises(RuntimeError, match="not installed"): + driver.create_session(task=FakeTask(), config=FakeConfig()) diff --git a/tests/core/test_docker_sandbox_backend.py b/tests/core/test_docker_sandbox_backend.py new file mode 100644 index 000000000..c309e63b3 --- /dev/null +++ b/tests/core/test_docker_sandbox_backend.py @@ -0,0 +1,366 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for the Docker sandbox backend. + +Tests marked ``@pytest.mark.docker`` require a running Docker daemon and +are skipped in CI when Docker is unavailable. They exercise the real +``docker run`` / ``docker exec`` / ``docker rm`` lifecycle. +""" + +from __future__ import annotations + +import subprocess +import time + +import pytest + +_DOCKER_AVAILABLE = False +try: + subprocess.run( + ["docker", "version"], + capture_output=True, + check=True, + timeout=5, + ) + _DOCKER_AVAILABLE = True +except Exception: + pass + +docker = pytest.mark.skipif(not _DOCKER_AVAILABLE, reason="Docker not available") + + +class TestDockerSandboxBackendUnit: + """Unit tests that don't require Docker.""" + + def test_import(self): + from openenv.core.harness.sandbox.docker_backend import ( + DockerBgJob, + DockerSandboxBackend, + DockerSandboxHandle, + ) + + assert DockerSandboxBackend is not None + assert DockerSandboxHandle is not None + assert DockerBgJob is not None + + def test_exported_from_package(self): + from openenv.core.harness.sandbox import ( + DockerBgJob, + DockerSandboxBackend, + DockerSandboxHandle, + ) + + assert DockerSandboxBackend is not None + assert DockerSandboxHandle is not None + assert DockerBgJob is not None + + def test_create_sandbox_backend_factory(self): + from 
openenv.core.harness.sandbox import create_sandbox_backend + + assert callable(create_sandbox_backend) + + def test_create_sandbox_backend_unknown_raises(self): + from openenv.core.harness.sandbox import create_sandbox_backend + + with pytest.raises(ValueError, match="Unknown sandbox backend"): + create_sandbox_backend("bogus") # type: ignore[arg-type] + + def test_create_adds_host_gateway_and_supports_image_override(self, monkeypatch): + import openenv.core.harness.sandbox.docker_backend as docker_backend + + calls: list[list[str]] = [] + + def _fake_run(cmd, *args, **kwargs): + calls.append(list(cmd)) + if cmd[:2] == ["docker", "version"]: + return subprocess.CompletedProcess(cmd, 0, "", "") + if cmd[:2] == ["docker", "run"]: + return subprocess.CompletedProcess( + cmd, + 0, + "1234567890abcdef\n", + "", + ) + return subprocess.CompletedProcess(cmd, 0, "", "") + + monkeypatch.setattr(docker_backend.subprocess, "run", _fake_run) + + backend = docker_backend.DockerSandboxBackend(image="base:latest") + handle = backend.create(image="override:latest") + assert handle.sandbox_id == "1234567890ab" + + run_cmds = [cmd for cmd in calls if cmd[:2] == ["docker", "run"]] + assert len(run_cmds) == 1 + run_cmd = run_cmds[0] + assert "--add-host" in run_cmd + assert "host.docker.internal:host-gateway" in run_cmd + assert "override:latest" in run_cmd + + @pytest.mark.skipif(_DOCKER_AVAILABLE, reason="Only test error when Docker missing") + def test_backend_raises_without_docker(self): + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + with pytest.raises(RuntimeError, match="Docker daemon"): + DockerSandboxBackend() + + +@docker +class TestDockerSandboxBackendIntegration: + """Integration tests against a real Docker daemon.""" + + def test_create_and_kill(self): + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + try: + 
assert sandbox.sandbox_id + assert len(sandbox.sandbox_id) == 12 + finally: + sandbox.kill() + + def test_exec_echo(self): + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + try: + result = sandbox.exec("echo hello world") + assert result.exit_code == 0 + assert "hello world" in result.stdout + finally: + sandbox.kill() + + def test_exec_nonzero_exit(self): + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + try: + result = sandbox.exec("exit 42") + assert result.exit_code == 42 + finally: + sandbox.kill() + + def test_exec_with_env(self): + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + try: + result = sandbox.exec("echo $MY_VAR", envs={"MY_VAR": "test123"}) + assert result.exit_code == 0 + assert "test123" in result.stdout + finally: + sandbox.kill() + + def test_exec_with_cwd(self): + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + try: + result = sandbox.exec("pwd", cwd="/tmp") + assert result.exit_code == 0 + assert "/tmp" in result.stdout + finally: + sandbox.kill() + + def test_write_and_read_text(self): + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + try: + sandbox.write_text("/tmp/test.txt", "hello from test") + content = sandbox.read_text("/tmp/test.txt") + assert content == "hello from test" + finally: + sandbox.kill() + + def test_write_creates_parent_dirs(self): + from openenv.core.harness.sandbox.docker_backend import 
DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + try: + sandbox.write_text("/home/user/deep/nested/file.txt", "nested content") + content = sandbox.read_text("/home/user/deep/nested/file.txt") + assert content == "nested content" + finally: + sandbox.kill() + + def test_write_special_chars(self): + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + try: + text = "line1\nline2\n'quotes' and \"doubles\" and $vars" + sandbox.write_text("/tmp/special.txt", text) + content = sandbox.read_text("/tmp/special.txt") + assert content == text + finally: + sandbox.kill() + + def test_read_missing_file_raises(self): + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + try: + with pytest.raises(FileNotFoundError): + sandbox.read_text("/nonexistent/path.txt") + finally: + sandbox.kill() + + def test_exists(self): + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + try: + assert not sandbox.exists("/tmp/check_me.txt") + sandbox.write_text("/tmp/check_me.txt", "exists") + assert sandbox.exists("/tmp/check_me.txt") + finally: + sandbox.kill() + + def test_start_bg_and_wait(self): + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + try: + job = sandbox.start_bg("sleep 1 && echo done > /tmp/bg_out.txt") + exit_code = job.wait(timeout=10) + assert exit_code == 0 + content = sandbox.read_text("/tmp/bg_out.txt") + assert "done" in content + finally: + sandbox.kill() + + def test_start_bg_kill(self): + from 
openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + try: + job = sandbox.start_bg("sleep 300") + time.sleep(0.5) + job.kill() + # Should be able to wait without hanging + exit_code = job.wait(timeout=5) + # Exit code after kill is implementation-defined + assert isinstance(exit_code, int) + finally: + sandbox.kill() + + def test_start_bg_timeout(self): + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + try: + job = sandbox.start_bg("sleep 300") + with pytest.raises(TimeoutError): + job.wait(timeout=1) + job.kill() + finally: + sandbox.kill() + + def test_create_with_envs(self): + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60, envs={"INIT_VAR": "from_create"}) + try: + result = sandbox.exec("echo $INIT_VAR") + assert "from_create" in result.stdout + finally: + sandbox.kill() + + def test_create_with_metadata(self): + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create( + timeout_s=60, + metadata={"episode_id": "ep-123"}, + ) + try: + result = subprocess.run( + [ + "docker", + "inspect", + "--format", + '{{index .Config.Labels "openenv.episode_id"}}', + sandbox._container_id, + ], + capture_output=True, + text=True, + ) + assert "ep-123" in result.stdout + finally: + sandbox.kill() + + def test_factory_creates_docker_backend(self): + from openenv.core.harness.sandbox import create_sandbox_backend + + backend = create_sandbox_backend("docker", image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + try: + result = sandbox.exec("echo ok") + assert result.exit_code == 0 + finally: + 
sandbox.kill() + + def test_satisfies_sandbox_handle_protocol(self): + from openenv.core.harness.sandbox import SandboxHandle + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + try: + assert isinstance(sandbox, SandboxHandle) + assert hasattr(sandbox, "sandbox_id") + assert hasattr(sandbox, "exec") + assert hasattr(sandbox, "start_bg") + assert hasattr(sandbox, "write_text") + assert hasattr(sandbox, "read_text") + assert hasattr(sandbox, "exists") + assert hasattr(sandbox, "kill") + finally: + sandbox.kill() + + def test_satisfies_sandbox_backend_protocol(self): + from openenv.core.harness.sandbox import SandboxBackend + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + assert issubclass(DockerSandboxBackend, SandboxBackend) + + def test_satisfies_bg_job_protocol(self): + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + try: + job = sandbox.start_bg("sleep 1") + assert hasattr(job, "pid") + assert hasattr(job, "wait") + assert hasattr(job, "kill") + job.kill() + finally: + sandbox.kill() + + def test_kill_is_idempotent(self): + from openenv.core.harness.sandbox.docker_backend import DockerSandboxBackend + + backend = DockerSandboxBackend(image="ubuntu:22.04") + sandbox = backend.create(timeout_s=60) + sandbox.kill() + sandbox.kill() # should not raise diff --git a/tests/core/test_harness_adapters.py b/tests/core/test_harness_adapters.py new file mode 100644 index 000000000..1766b8ad4 --- /dev/null +++ b/tests/core/test_harness_adapters.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Tests for currently implemented harness adapters (OpenCode + Pi).""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import Any + +import pytest + + +@dataclass +class FakeTask: + instruction: str = "Write hello.py" + setup_shell: str | None = None + upload_files: dict[str, str] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class FakeConfig: + base_url: str = "https://api.example.com/v1" + api_key: str = "sk-test" + model: str = "test-model" + agent_timeout_s: float = 300.0 + sandbox_home: str = "/home/user" + system_prompt: str | None = None + + +class TestPiSpec: + def test_registered(self): + from openenv.core.harness.agents import get_agent_spec + + spec = get_agent_spec("pi") + assert spec.name == "pi" + + def test_fields(self): + from openenv.core.harness.agents.pi import PI_SPEC + + assert PI_SPEC.install_check_cmd == ["pi", "--version"] + assert PI_SPEC.mcp_config.method == "config_file" + assert PI_SPEC.mcp_config.path_template is not None + assert ".mcp.json" in PI_SPEC.mcp_config.path_template + assert PI_SPEC.build_env_vars is not None + + def test_build_env_vars_provider_specific_api_key(self): + from openenv.core.harness.agents.pi import PI_SPEC + + @dataclass + class PiConfig: + provider: str + api_key: str = "secret" + base_url: str = "https://api.example.com/v1" + extra_env: dict[str, str] = field(default_factory=dict) + + assert PI_SPEC.build_env_vars is not None + + hf_env = PI_SPEC.build_env_vars(PI_SPEC, PiConfig(provider="huggingface")) + assert hf_env["HF_TOKEN"] == "secret" + assert "OPENAI_API_KEY" not in hf_env + + oa_env = PI_SPEC.build_env_vars(PI_SPEC, PiConfig(provider="openai")) + assert oa_env["OPENAI_API_KEY"] == "secret" + assert "HF_TOKEN" not in oa_env + + def test_build_env_vars_rejects_unknown_provider(self): + from openenv.core.harness.agents.pi import PI_SPEC + + @dataclass + class PiConfig: + provider: 
str = "unknown" + api_key: str = "secret" + base_url: str = "https://api.example.com/v1" + extra_env: dict[str, str] = field(default_factory=dict) + + assert PI_SPEC.build_env_vars is not None + with pytest.raises(ValueError, match="Unsupported pi provider"): + PI_SPEC.build_env_vars(PI_SPEC, PiConfig()) + + def test_build_command(self): + from openenv.core.harness.agents.pi import PI_SPEC + + assert PI_SPEC.build_command is not None + cmd = PI_SPEC.build_command(PI_SPEC, FakeConfig(), FakeTask(), None) + assert "pi --no-session" in cmd + assert "--no-context-files" in cmd + + def test_build_mcp_config(self): + from openenv.core.harness.agents.pi import PI_SPEC + + assert PI_SPEC.build_mcp_config is not None + content = PI_SPEC.build_mcp_config(PI_SPEC, [], "/workdir") + assert "mcpServers" in json.loads(content) + + +class TestOpenCodeSpec: + def test_registered(self): + from openenv.core.harness.agents import get_agent_spec + + spec = get_agent_spec("opencode") + assert spec.name == "opencode" + + +class TestRegistryAutoImport: + @pytest.mark.parametrize("name", ["pi", "opencode"]) + def test_auto_import(self, name): + from openenv.core.harness.agents import get_agent_spec + + spec = get_agent_spec(name) + assert spec.name == name + + def test_list_agents_includes_current(self): + import openenv.core.harness.agents.opencode # noqa: F401 + import openenv.core.harness.agents.pi # noqa: F401 + from openenv.core.harness.agents import list_agents + + agents = list_agents() + for name in ["opencode", "pi"]: + assert name in agents, f"{name} not in {agents}" diff --git a/tests/core/test_hf_sandbox_backend.py b/tests/core/test_hf_sandbox_backend.py new file mode 100644 index 000000000..9cd94b5d8 --- /dev/null +++ b/tests/core/test_hf_sandbox_backend.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for the HF sandbox backend. + +These tests mock ``hf-sandbox`` so they run without network or HF credentials. +""" + +from __future__ import annotations + +import importlib +import re +import subprocess +import sys +import types +from dataclasses import dataclass, field + +import pytest + + +@dataclass +class _FakeSandbox: + job_id: str + files: dict[str, str] = field(default_factory=dict) + marker_files: dict[str, str] = field(default_factory=dict) + bg_jobs: dict[int, dict] = field(default_factory=dict) + next_pid: int = 1000 + terminated: bool = False + + def exec( + self, + *cmd: str, + workdir: str | None = None, + stdin: str | None = None, + timeout: int = 600, + ) -> subprocess.CompletedProcess: + del workdir, stdin, timeout + if len(cmd) < 3: + return subprocess.CompletedProcess(cmd, 1, "", "invalid command") + script = cmd[2] + + if "ok_cmd" in script: + return subprocess.CompletedProcess(cmd, 0, "ok\n", "") + if "fail_cmd" in script: + return subprocess.CompletedProcess(cmd, 42, "", "failed") + if "timeout_cmd" in script: + return subprocess.CompletedProcess(cmd, -1, "", "timeout") + + if "mkdir -p" in script: + return subprocess.CompletedProcess(cmd, 0, "", "") + + if "test -e " in script: + match = re.search(r"test -e '([^']+)'", script) + assert match is not None + path = match.group(1) + exists = path in self.files or path in self.marker_files + return subprocess.CompletedProcess(cmd, 0 if exists else 1, "", "") + + if "cat '/tmp/.openenv_bg_" in script: + match = re.search(r"cat '([^']+)'", script) + assert match is not None + marker = match.group(1) + if marker in self.marker_files: + return subprocess.CompletedProcess( + cmd, + 0, + f"{self.marker_files[marker]}\n", + "", + ) + return subprocess.CompletedProcess(cmd, 1, "", "missing") + + if script.strip().startswith("kill -0 "): + pid = 
int(script.strip().split()[2]) + alive = self.bg_jobs.get(pid, {}).get("alive", False) + return subprocess.CompletedProcess(cmd, 0 if alive else 1, "", "") + + if script.strip().startswith("kill -9 "): + pid = int(script.strip().split()[2]) + if pid in self.bg_jobs: + self.bg_jobs[pid]["alive"] = False + marker = self.bg_jobs[pid]["marker"] + self.marker_files[marker] = "137" + return subprocess.CompletedProcess(cmd, 0, "", "") + + if "echo $!" in script: + marker_match = re.search(r"(/tmp/\.openenv_bg_[A-Za-z0-9]+\.exit)", script) + assert marker_match is not None + marker = marker_match.group(1) + pid = self.next_pid + self.next_pid += 1 + long_running = "sleep 300" in script + self.bg_jobs[pid] = { + "marker": marker, + "alive": long_running, + } + if not long_running: + self.marker_files[marker] = "0" + return subprocess.CompletedProcess(cmd, 0, f"{pid}\n", "") + + return subprocess.CompletedProcess(cmd, 0, "", "") + + def write_file( + self, + path: str, + content: str | bytes | bytearray | memoryview, + ) -> None: + if isinstance(content, str): + normalized = content + else: + normalized = bytes(content).decode("utf-8", "replace") + self.files[path] = normalized + + def read_file(self, path: str, text: bool = True) -> str | bytes: + if path not in self.files: + raise FileNotFoundError(path) + return self.files[path] if text else self.files[path].encode() + + def terminate(self) -> None: + self.terminated = True + + +class _FakeSandboxAPI: + calls: list[dict] = [] + + @classmethod + def create( + cls, + image: str, + flavor: str, + timeout: str, + forward_hf_token: bool, + ) -> _FakeSandbox: + cls.calls.append( + { + "image": image, + "flavor": flavor, + "timeout": timeout, + "forward_hf_token": forward_hf_token, + } + ) + return _FakeSandbox(job_id="job-123") + + +def _install_fake_hf_sandbox(monkeypatch) -> None: + fake_module = types.ModuleType("hf_sandbox") + setattr(fake_module, "Sandbox", _FakeSandboxAPI) + monkeypatch.setitem(sys.modules, "hf_sandbox", 
fake_module) + + +@pytest.fixture(autouse=True) +def _reset_fake_hf_calls() -> None: + _FakeSandboxAPI.calls.clear() + + +class TestHFSandboxBackend: + def test_exported_from_package(self, monkeypatch): + _install_fake_hf_sandbox(monkeypatch) + + import openenv.core.harness.sandbox as sandbox_pkg + + importlib.reload(sandbox_pkg) + assert hasattr(sandbox_pkg, "HFSandboxBackend") + assert hasattr(sandbox_pkg, "HFSandboxHandle") + assert hasattr(sandbox_pkg, "HFBgJob") + + def test_create_exec_write_read_exists_bg_and_kill(self, monkeypatch): + import openenv.core.harness.sandbox.hf_backend as hf_backend + + _install_fake_hf_sandbox(monkeypatch) + importlib.reload(hf_backend) + + monkeypatch.setattr(hf_backend, "Sandbox", _FakeSandboxAPI) + + backend = hf_backend.HFSandboxBackend( + image="python:3.12", + flavor="cpu-basic", + forward_hf_token=True, + ) + sandbox = backend.create(timeout_s=120, envs={"GLOBAL_ENV": "on"}) + + assert sandbox.sandbox_id == "job-123" + assert _FakeSandboxAPI.calls[-1]["timeout"] == "2m" + + ok = sandbox.exec("ok_cmd") + assert ok.exit_code == 0 + + failed = sandbox.exec("fail_cmd") + assert failed.exit_code == 42 + + timed = sandbox.exec("timeout_cmd") + assert timed.exit_code == -1 + + sandbox.write_text("/tmp/hello.txt", "hello") + assert sandbox.exists("/tmp/hello.txt") + assert sandbox.read_text("/tmp/hello.txt") == "hello" + + short_job = sandbox.start_bg("echo done > /tmp/bg.txt") + assert short_job.wait(timeout=2) == 0 + + long_job = sandbox.start_bg("sleep 300") + with pytest.raises(TimeoutError): + long_job.wait(timeout=0.1) + long_job.kill() + assert isinstance(long_job.wait(timeout=2), int) + + sandbox.kill() + raw = getattr(sandbox, "raw", None) + assert raw is not None + assert raw.terminated is True + + def test_factory_creates_hf_backend(self, monkeypatch): + _install_fake_hf_sandbox(monkeypatch) + + import openenv.core.harness.sandbox as sandbox_pkg + import openenv.core.harness.sandbox.hf_backend as hf_backend + + 
importlib.reload(hf_backend) + importlib.reload(sandbox_pkg) + + monkeypatch.setattr(hf_backend, "Sandbox", _FakeSandboxAPI) + backend = sandbox_pkg.create_sandbox_backend("hf", image="python:3.12") + assert isinstance(backend, hf_backend.HFSandboxBackend) diff --git a/tests/core/test_interception_server.py b/tests/core/test_interception_server.py new file mode 100644 index 000000000..41ef38fe5 --- /dev/null +++ b/tests/core/test_interception_server.py @@ -0,0 +1,315 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import asyncio + +import aiohttp +import pytest + +from openenv.core.harness.agents.interception_server import ( + InterceptionServer, + deliver_response, +) + + +_ANSWER_TOOL = { + "type": "function", + "function": { + "name": "answer", + "description": "Submit final answer for grading", + "parameters": { + "type": "object", + "properties": { + "answer": {"type": "string"}, + }, + "required": ["answer"], + }, + }, +} + + +@pytest.mark.asyncio +async def test_interception_server_rejects_unauthorized_requests() -> None: + server = InterceptionServer(port=0, secret="secret-token") + await server.start() + try: + async with aiohttp.ClientSession() as client: + resp = await client.post( + f"http://127.0.0.1:{server.port}/rollout/r1/v1/chat/completions", + json={"messages": []}, + ) + assert resp.status == 401 + finally: + await server.stop() + + +@pytest.mark.asyncio +async def test_interception_server_returns_404_for_unknown_rollout() -> None: + server = InterceptionServer(port=0, secret="secret-token") + await server.start() + try: + async with aiohttp.ClientSession() as client: + resp = await client.post( + f"http://127.0.0.1:{server.port}/rollout/missing/v1/chat/completions", + headers={"Authorization": "Bearer secret-token"}, + json={"messages": 
[]}, + ) + assert resp.status == 404 + finally: + await server.stop() + + +@pytest.mark.asyncio +async def test_interception_server_non_stream_roundtrip_cleans_intercept() -> None: + server = InterceptionServer(port=0, secret="secret-token") + await server.start() + queue = server.register_rollout("r1") + try: + async with aiohttp.ClientSession() as client: + request_task = asyncio.create_task( + client.post( + f"http://127.0.0.1:{server.port}/rollout/r1/v1/chat/completions", + headers={"Authorization": "Bearer secret-token"}, + json={ + "messages": [{"role": "user", "content": "hi"}], + "stream": False, + }, + ) + ) + request_id = await asyncio.to_thread(queue.get, timeout=1.0) + intercept = server.get_intercept(request_id) + assert intercept is not None + + await deliver_response( + intercept, + { + "id": "resp-1", + "model": "test-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "hello"}, + "finish_reason": "stop", + } + ], + }, + ) + + resp = await request_task + assert resp.status == 200 + payload = await resp.json() + assert payload["id"] == "resp-1" + + # Request entries should not leak after completion. 
+ assert server.get_intercept(request_id) is None + finally: + server.unregister_rollout("r1") + await server.stop() + + +@pytest.mark.asyncio +async def test_interception_server_unregister_rollout_cancels_pending_request() -> None: + server = InterceptionServer(port=0, secret="secret-token") + await server.start() + queue = server.register_rollout("r1") + try: + async with aiohttp.ClientSession() as client: + request_task = asyncio.create_task( + client.post( + f"http://127.0.0.1:{server.port}/rollout/r1/v1/chat/completions", + headers={"Authorization": "Bearer secret-token"}, + json={ + "messages": [{"role": "user", "content": "hi"}], + "stream": False, + }, + ) + ) + _request_id = await asyncio.to_thread(queue.get, timeout=1.0) + server.unregister_rollout("r1") + + resp = await request_task + assert resp.status == 499 + payload = await resp.json() + assert payload["error"] == "rollout cancelled" + finally: + await server.stop() + + +@pytest.mark.asyncio +async def test_interception_server_tool_endpoint_executes_registered_handler() -> None: + server = InterceptionServer( + port=0, + secret="secret-token", + tool_name_allowlist={"answer"}, + ) + await server.start() + server.register_rollout("r1") + seen: dict[str, object] = {} + + async def _handler(arguments: dict) -> dict: + seen["arguments"] = arguments + return {"content": [{"type": "text", "text": "βœ…"}]} + + server.register_tool_handler("r1", "answer", _handler) + try: + async with aiohttp.ClientSession() as client: + resp = await client.post( + f"http://127.0.0.1:{server.port}/rollout/r1/v1/tools/answer", + headers={"Authorization": "Bearer secret-token"}, + json={"arguments": {"answer": "42"}}, + ) + assert resp.status == 200 + payload = await resp.json() + assert payload["content"][0]["text"] == "βœ…" + assert seen["arguments"] == {"answer": "42"} + finally: + server.unregister_rollout("r1") + await server.stop() + + +@pytest.mark.asyncio +async def 
test_interception_server_tool_endpoint_returns_404_for_unknown_tool() -> None: + server = InterceptionServer(port=0, secret="secret-token") + await server.start() + server.register_rollout("r1") + try: + async with aiohttp.ClientSession() as client: + resp = await client.post( + f"http://127.0.0.1:{server.port}/rollout/r1/v1/tools/missing", + headers={"Authorization": "Bearer secret-token"}, + json={"arguments": {}}, + ) + assert resp.status == 404 + finally: + server.unregister_rollout("r1") + await server.stop() + + +def test_interception_server_rejects_reserved_tool_name_registration() -> None: + server = InterceptionServer( + port=0, + secret="secret-token", + tool_name_allowlist={"reset"}, + ) + server.register_rollout("r1") + + async def _handler(arguments: dict) -> dict: + return {"ok": True} + + with pytest.raises(ValueError, match="reserved"): + server.register_tool_handler("r1", "reset", _handler) + + +def test_interception_server_rejects_tool_definition_name_mismatch() -> None: + server = InterceptionServer( + port=0, + secret="secret-token", + tool_name_allowlist={"answer"}, + ) + server.register_rollout("r1") + + async def _handler(arguments: dict) -> dict: + return {"ok": True} + + mismatched = { + "type": "function", + "function": { + "name": "not_answer", + "description": "Mismatch", + "parameters": {"type": "object", "properties": {}}, + }, + } + + with pytest.raises(ValueError, match="must exactly match"): + server.register_tool_handler( + "r1", + "answer", + _handler, + tool_definition=mismatched, + ) + + +def test_interception_server_rejects_tool_not_in_allowlist() -> None: + server = InterceptionServer( + port=0, + secret="secret-token", + tool_name_allowlist={"answer"}, + ) + server.register_rollout("r1") + + async def _handler(arguments: dict) -> dict: + return {"ok": True} + + with pytest.raises(ValueError, match="allowlist"): + server.register_tool_handler("r1", "search", _handler) + + +@pytest.mark.asyncio +async def 
test_interception_server_injects_registered_tool_defs_into_intercept() -> ( + None +): + server = InterceptionServer( + port=0, + secret="secret-token", + tool_name_allowlist={"answer"}, + ) + await server.start() + queue = server.register_rollout("r1") + + async def _handler(arguments: dict) -> dict: + return {"content": [{"type": "text", "text": str(arguments)}]} + + server.register_tool_handler( + "r1", + "answer", + _handler, + tool_definition=_ANSWER_TOOL, + ) + + try: + async with aiohttp.ClientSession() as client: + request_task = asyncio.create_task( + client.post( + f"http://127.0.0.1:{server.port}/rollout/r1/v1/chat/completions", + headers={"Authorization": "Bearer secret-token"}, + json={ + "messages": [{"role": "user", "content": "grade this"}], + "stream": False, + }, + ) + ) + request_id = await asyncio.to_thread(queue.get, timeout=1.0) + intercept = server.get_intercept(request_id) + assert intercept is not None + tool_names = { + tool["function"]["name"] + for tool in intercept.get("tools", []) + if isinstance(tool, dict) and isinstance(tool.get("function"), dict) + } + assert "answer" in tool_names + + await deliver_response( + intercept, + { + "id": "resp-1", + "model": "test-model", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": "done"}, + "finish_reason": "stop", + } + ], + }, + ) + + resp = await request_task + assert resp.status == 200 + finally: + server.unregister_rollout("r1") + await server.stop() diff --git a/tests/envs/test_opencode_env.py b/tests/envs/test_opencode_env.py index 812ade194..701d562e9 100644 --- a/tests/envs/test_opencode_env.py +++ b/tests/envs/test_opencode_env.py @@ -24,7 +24,6 @@ from __future__ import annotations import os -import shlex import sys import pytest @@ -46,14 +45,14 @@ def test_public_api_imports() -> None: """Top-level package re-exports the documented surface.""" from opencode_env import ( # noqa: F401 - CommandResult, - E2BSandboxBackend, OpenCodeConfig, OpenCodeEnv, 
OpenCodeSession, OpenCodeSessionFactory, OpenCodeState, OpenCodeTask, + CommandResult, + E2BSandboxBackend, Provider, RolloutResult, RolloutTurn, @@ -166,6 +165,73 @@ def test_catalog_summary_shape() -> None: } <= entry.keys() +def test_build_agent_config_opencode() -> None: + from opencode_env.server.opencode_environment import OpenCodeEnvironment + + env = OpenCodeEnvironment() + cfg = env._build_agent_config( + base_url="https://api.openai.com/v1", + api_key="sk-test", + model="gpt-4o-mini", + agent_timeout_s=123.0, + disable_thinking=True, + top_logprobs=7, + max_tokens_cap=2048, + ) + assert isinstance(cfg, env._OpenCodeConfig) + assert cfg.model == "gpt-4o-mini" + assert cfg.agent_timeout_s == 123.0 + assert cfg.max_tokens_cap == 2048 + assert cfg.proxy_max_tokens_cap == 2048 + assert cfg.proxy_top_logprobs == 7 + assert cfg.proxy_disable_thinking is True + + cfg_4096 = env._build_agent_config( + base_url="https://api.openai.com/v1", + api_key="sk-test", + model="gpt-4o-mini", + agent_timeout_s=123.0, + disable_thinking=True, + top_logprobs=7, + max_tokens_cap=4096, + ) + assert cfg_4096.max_tokens_cap == 4096 + + cfg_uncapped = env._build_agent_config( + base_url="https://api.openai.com/v1", + api_key="sk-test", + model="gpt-4o-mini", + agent_timeout_s=123.0, + disable_thinking=True, + top_logprobs=7, + max_tokens_cap=0, + ) + assert cfg_uncapped.max_tokens_cap is None + + +def test_build_session_factory_requires_e2b_dependency() -> None: + from opencode_env.server.opencode_environment import OpenCodeEnvironment + + env = OpenCodeEnvironment() + env._E2BSandboxBackend = None + cfg = env._build_agent_config( + base_url="https://api.openai.com/v1", + api_key="sk-test", + model="gpt-4o-mini", + agent_timeout_s=180.0, + disable_thinking=False, + top_logprobs=5, + max_tokens_cap=4096, + ) + + with pytest.raises(RuntimeError, match="E2BSandboxBackend unavailable"): + env._build_session_factory( + config=cfg, + mode="black_box", + template="", + ) + + # 
--------------------------------------------------------------------------- # Models + task coercion # --------------------------------------------------------------------------- @@ -180,16 +246,15 @@ def test_rollout_result_serializes_round_trip() -> None: reward=0.75, agent_exit_code=0, wall_s=12.5, - mode="transparent_proxy", + mode="black_box", setup_results=[CommandResult(cmd="pip install pandas", exit_code=0)], verify_results=[CommandResult(cmd="pytest", exit_code=1, stderr="boom")], proxy_turns=[ RolloutTurn( turn=1, - finish_reason="stop", - completion_tokens=["hi"], + completion_tokens=["ok"], + completion_token_ids=[123], per_token_logps=[-0.1], - latency_s=0.2, ) ], files={"/home/user/workdir/x.py": "print('x')"}, @@ -198,7 +263,7 @@ def test_rollout_result_serializes_round_trip() -> None: rebuilt = RolloutResult.model_validate_json(blob) assert rebuilt.reward == 0.75 assert rebuilt.verify_results[0].exit_code == 1 - assert rebuilt.proxy_turns[0].completion_tokens == ["hi"] + assert rebuilt.proxy_turns[0].per_token_logps == [-0.1] def test_opencode_task_coerce_str() -> None: @@ -213,7 +278,9 @@ def test_opencode_task_coerce_str() -> None: def test_opencode_task_coerce_dict() -> None: from opencode_env import OpenCodeTask - t = OpenCodeTask.coerce({"instruction": "x", "setup_shell": "pip install pandas"}) + t = OpenCodeTask.coerce( + {"instruction": "x", "setup_shell": "pip install pandas"} + ) assert t.instruction == "x" assert t.setup_shell == "pip install pandas" @@ -232,107 +299,6 @@ def test_opencode_task_coerce_rejects_unknown_type() -> None: OpenCodeTask.coerce(42) # type: ignore[arg-type] -def test_start_proxy_keeps_upstream_key_out_of_command() -> None: - """The proxy API key must be passed via env, not shell argv.""" - from opencode_env import OpenCodeConfig, OpenCodeSessionFactory - - class FakeExecResult: - exit_code = 0 - stdout = "ok" - stderr = "" - - class FakeBgJob: - def wait(self, timeout: float | None = None) -> int: - return 0 - - def 
kill(self) -> None: - pass - - class FakeSandbox: - sandbox_id = "fake-sandbox" - - def __init__(self) -> None: - self.started_cmd: str | None = None - self.started_envs: dict[str, str] | None = None - self.written: dict[str, str] = {} - - def exec(self, *args, **kwargs) -> FakeExecResult: - return FakeExecResult() - - def start_bg(self, cmd: str, *, envs=None, cwd=None) -> FakeBgJob: - self.started_cmd = cmd - self.started_envs = envs - return FakeBgJob() - - def write_text(self, path: str, content: str) -> None: - self.written[path] = content - - def read_text(self, path: str) -> str: - return "" - - def exists(self, path: str) -> bool: - return path in self.written - - def kill(self) -> None: - pass - - class NoopInstallFactory(OpenCodeSessionFactory): - def _exec_with_retry(self, *args, **kwargs): - return FakeExecResult() - - secret = "sk-test '$(leak)" - model = "provider/model'; touch /tmp/pwn #" - config = OpenCodeConfig( - base_url="https://example.test/v1?x='y", - api_key=secret, - model=model, - ) - sandbox = FakeSandbox() - factory = NoopInstallFactory( - config=config, - sandbox_backend=object(), # unused by this protected-method test - mode="transparent_proxy", - ) - - factory._start_proxy(sandbox) - - assert sandbox.started_cmd is not None - assert sandbox.started_envs == {"OPENCODE_UPSTREAM_API_KEY": secret} - assert secret not in sandbox.started_cmd - assert "--upstream-api-key" not in sandbox.started_cmd - - argv = shlex.split(sandbox.started_cmd.split("&&", 1)[1].split(">", 1)[0].strip()) - assert argv[argv.index("--upstream-url") + 1] == config.base_url - assert argv[argv.index("--model-override") + 1] == model - - -def test_interception_cli_reads_upstream_key_from_env( - monkeypatch: pytest.MonkeyPatch, -) -> None: - from opencode_env.sandbox import interception - - captured = {} - - def fake_serve(cfg) -> None: - captured["cfg"] = cfg - - monkeypatch.setattr(interception, "serve", fake_serve) - monkeypatch.setenv("OPENCODE_UPSTREAM_API_KEY", 
"sk-from-env") - monkeypatch.setattr( - sys, - "argv", - [ - "interception.py", - "--upstream-url", - "https://example.test/v1", - ], - ) - - interception.main() - - assert captured["cfg"].upstream_api_key == "sk-from-env" - - # --------------------------------------------------------------------------- # Integration β€” only runs when E2B + endpoint creds are present and the # user explicitly opts in via ``pytest -m integration``. @@ -393,7 +359,6 @@ async def _go() -> RolloutResult: assert result.reward == 1.0, ( f"expected reward=1.0 got {result.reward}: {result.error}" ) - assert result.proxy_turns, "expected at least one captured LLM turn" assert any(f.endswith("/binary_search.py") for f in result.files), ( f"expected binary_search.py in workdir, got {list(result.files)}" )