Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions codex-rs/core/src/session/tests/guardian_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ use crate::exec_policy::ExecPolicyManager;
use crate::guardian::GUARDIAN_REVIEWER_NAME;
use crate::sandboxing::SandboxPermissions;
use crate::test_support::models_manager_with_provider;
use crate::tools::context::FunctionToolOutput;
use crate::tools::context::ToolCallSource;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_app_server_protocol::ConfigLayerSource;
use codex_config::ConfigLayerEntry;
Expand All @@ -22,8 +23,8 @@ use codex_protocol::config_types::ApprovalsReviewer;
use codex_protocol::models::AdditionalPermissionProfile as PermissionProfile;
use codex_protocol::models::ContentItem;
use codex_protocol::models::NetworkPermissions;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::models::function_call_output_content_items_to_text;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::request_permissions::PermissionGrantScope;
use codex_protocol::request_permissions::RequestPermissionProfile;
Expand All @@ -48,8 +49,23 @@ use tempfile::tempdir;
use tokio::time::timeout;
use tokio_util::sync::CancellationToken;

fn expect_text_output(output: &FunctionToolOutput) -> String {
function_call_output_content_items_to_text(&output.body).unwrap_or_default()
fn expect_text_output<T>(output: &T) -> String
where
T: ToolOutput + ?Sized,
{
let response = output.to_response_item(
"call-guardian",
&ToolPayload::Function {
arguments: "{}".to_string(),
},
);
match response {
ResponseInputItem::FunctionCallOutput { output, .. }
| ResponseInputItem::CustomToolCallOutput { output, .. } => {
output.body.to_text().unwrap_or_default()
}
other => panic!("expected function output, got {other:?}"),
}
}

#[tokio::test]
Expand Down
19 changes: 11 additions & 8 deletions codex-rs/core/src/tools/code_mode/execute_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use crate::function_tool::FunctionCallError;
use crate::tools::context::FunctionToolOutput;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::context::boxed_tool_output;
use crate::tools::registry::CoreToolRuntime;
use crate::tools::registry::ToolExecutor;
use crate::tools::registry::ToolHandler;
use codex_tools::ToolName;
use codex_tools::ToolSpec;

Expand Down Expand Up @@ -89,8 +90,6 @@ impl CodeModeExecuteHandler {

#[async_trait::async_trait]
impl ToolExecutor<ToolInvocation> for CodeModeExecuteHandler {
type Output = FunctionToolOutput;

fn tool_name(&self) -> ToolName {
ToolName::plain(PUBLIC_TOOL_NAME)
}
Expand All @@ -99,7 +98,10 @@ impl ToolExecutor<ToolInvocation> for CodeModeExecuteHandler {
Some(self.spec.clone())
}

async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
async fn handle(
&self,
invocation: ToolInvocation,
) -> Result<Box<dyn crate::tools::context::ToolOutput>, FunctionCallError> {
let ToolInvocation {
session,
turn,
Expand All @@ -110,17 +112,18 @@ impl ToolExecutor<ToolInvocation> for CodeModeExecuteHandler {
} = invocation;

match payload {
ToolPayload::Custom { input } if is_exec_tool_name(&tool_name) => {
self.execute(session, turn, call_id, input).await
}
ToolPayload::Custom { input } if is_exec_tool_name(&tool_name) => self
.execute(session, turn, call_id, input)
.await
.map(boxed_tool_output),
_ => Err(FunctionCallError::RespondToModel(format!(
"{PUBLIC_TOOL_NAME} expects raw JavaScript source text"
))),
}
}
}

impl ToolHandler for CodeModeExecuteHandler {
impl CoreToolRuntime for CodeModeExecuteHandler {
fn matches_kind(&self, payload: &ToolPayload) -> bool {
matches!(payload, ToolPayload::Custom { .. })
}
Expand Down
14 changes: 8 additions & 6 deletions codex-rs/core/src/tools/code_mode/wait_handler.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use serde::Deserialize;

use crate::function_tool::FunctionCallError;
use crate::tools::context::FunctionToolOutput;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::context::boxed_tool_output;
use crate::tools::registry::CoreToolRuntime;
use crate::tools::registry::ToolExecutor;
use crate::tools::registry::ToolHandler;
use codex_tools::ToolName;
use codex_tools::ToolSpec;

Expand Down Expand Up @@ -43,8 +43,6 @@ where

#[async_trait::async_trait]
impl ToolExecutor<ToolInvocation> for CodeModeWaitHandler {
type Output = FunctionToolOutput;

fn tool_name(&self) -> ToolName {
ToolName::plain(WAIT_TOOL_NAME)
}
Expand All @@ -53,7 +51,10 @@ impl ToolExecutor<ToolInvocation> for CodeModeWaitHandler {
Some(create_wait_tool())
}

async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
async fn handle(
&self,
invocation: ToolInvocation,
) -> Result<Box<dyn crate::tools::context::ToolOutput>, FunctionCallError> {
let ToolInvocation {
session,
turn,
Expand Down Expand Up @@ -99,6 +100,7 @@ impl ToolExecutor<ToolInvocation> for CodeModeWaitHandler {
}
handle_runtime_response(&exec, wait_response.into(), args.max_tokens, started_at)
.await
.map(boxed_tool_output)
.map_err(FunctionCallError::RespondToModel)
}
_ => Err(FunctionCallError::RespondToModel(format!(
Expand All @@ -108,4 +110,4 @@ impl ToolExecutor<ToolInvocation> for CodeModeWaitHandler {
}
}

impl ToolHandler for CodeModeWaitHandler {}
impl CoreToolRuntime for CodeModeWaitHandler {}
25 changes: 25 additions & 0 deletions codex-rs/core/src/tools/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ use tokio_util::sync::CancellationToken;
pub use codex_tools::ToolOutput;
pub use codex_tools::ToolPayload;

pub(crate) fn boxed_tool_output<T>(output: T) -> Box<dyn ToolOutput>
where
T: ToolOutput + 'static,
{
Box::new(output)
}

pub type SharedTurnDiffTracker = Arc<Mutex<TurnDiffTracker>>;

#[derive(Clone, Debug, Eq, PartialEq)]
Expand Down Expand Up @@ -91,6 +98,10 @@ impl ToolOutput for McpToolOutput {
})
}

fn post_tool_use_input(&self, _payload: &ToolPayload) -> Option<JsonValue> {
Some(self.tool_input.clone())
}

fn post_tool_use_response(&self, _call_id: &str, _payload: &ToolPayload) -> Option<JsonValue> {
serde_json::to_value(&self.result).ok()
}
Expand Down Expand Up @@ -327,6 +338,20 @@ impl ToolOutput for ExecCommandToolOutput {
)
}

fn post_tool_use_id(&self, call_id: &str) -> String {
if self.event_call_id.is_empty() {
call_id.to_string()
} else {
self.event_call_id.clone()
}
}

fn post_tool_use_input(&self, _payload: &ToolPayload) -> Option<JsonValue> {
self.hook_command
.as_ref()
.map(|command| serde_json::json!({ "command": command }))
}

fn post_tool_use_response(&self, _call_id: &str, _payload: &ToolPayload) -> Option<JsonValue> {
if self.process_id.is_some() || self.hook_command.is_none() {
return None;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use crate::function_tool::FunctionCallError;
use crate::tools::context::FunctionToolOutput;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::context::boxed_tool_output;
use crate::tools::handlers::agent_jobs_spec::create_report_agent_job_result_tool;
use crate::tools::registry::CoreToolRuntime;
use crate::tools::registry::ToolExecutor;
use crate::tools::registry::ToolHandler;
use codex_tools::ToolName;
use codex_tools::ToolSpec;

Expand All @@ -14,8 +15,6 @@ pub struct ReportAgentJobResultHandler;

#[async_trait::async_trait]
impl ToolExecutor<ToolInvocation> for ReportAgentJobResultHandler {
type Output = FunctionToolOutput;

fn tool_name(&self) -> ToolName {
ToolName::plain("report_agent_job_result")
}
Expand All @@ -24,7 +23,10 @@ impl ToolExecutor<ToolInvocation> for ReportAgentJobResultHandler {
Some(create_report_agent_job_result_tool())
}

async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
async fn handle(
&self,
invocation: ToolInvocation,
) -> Result<Box<dyn crate::tools::context::ToolOutput>, FunctionCallError> {
let ToolInvocation {
session, payload, ..
} = invocation;
Expand All @@ -38,11 +40,11 @@ impl ToolExecutor<ToolInvocation> for ReportAgentJobResultHandler {
}
};

handle(session, arguments).await
handle(session, arguments).await.map(boxed_tool_output)
}
}

impl ToolHandler for ReportAgentJobResultHandler {
impl CoreToolRuntime for ReportAgentJobResultHandler {
fn matches_kind(&self, payload: &ToolPayload) -> bool {
matches!(payload, ToolPayload::Function { .. })
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use crate::function_tool::FunctionCallError;
use crate::tools::context::FunctionToolOutput;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::context::boxed_tool_output;
use crate::tools::handlers::agent_jobs_spec::create_spawn_agents_on_csv_tool;
use crate::tools::registry::CoreToolRuntime;
use crate::tools::registry::ToolExecutor;
use crate::tools::registry::ToolHandler;
use codex_tools::ToolName;
use codex_tools::ToolSpec;
use codex_utils_absolute_path::AbsolutePathBuf;
Expand All @@ -15,8 +16,6 @@ pub struct SpawnAgentsOnCsvHandler;

#[async_trait::async_trait]
impl ToolExecutor<ToolInvocation> for SpawnAgentsOnCsvHandler {
type Output = FunctionToolOutput;

fn tool_name(&self) -> ToolName {
ToolName::plain("spawn_agents_on_csv")
}
Expand All @@ -25,7 +24,10 @@ impl ToolExecutor<ToolInvocation> for SpawnAgentsOnCsvHandler {
Some(create_spawn_agents_on_csv_tool())
}

async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
async fn handle(
&self,
invocation: ToolInvocation,
) -> Result<Box<dyn crate::tools::context::ToolOutput>, FunctionCallError> {
let ToolInvocation {
session,
turn,
Expand All @@ -42,11 +44,13 @@ impl ToolExecutor<ToolInvocation> for SpawnAgentsOnCsvHandler {
}
};

handle(session, turn, arguments).await
handle(session, turn, arguments)
.await
.map(boxed_tool_output)
}
}

impl ToolHandler for SpawnAgentsOnCsvHandler {
impl CoreToolRuntime for SpawnAgentsOnCsvHandler {
fn matches_kind(&self, payload: &ToolPayload) -> bool {
matches!(payload, ToolPayload::Function { .. })
}
Expand Down
19 changes: 10 additions & 9 deletions codex-rs/core/src/tools/handlers/apply_patch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use crate::tools::context::ApplyPatchToolOutput;
use crate::tools::context::FunctionToolOutput;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::context::boxed_tool_output;
use crate::tools::events::ToolEmitter;
use crate::tools::events::ToolEventCtx;
use crate::tools::handlers::apply_granted_turn_permissions;
Expand All @@ -27,11 +27,11 @@ use crate::tools::handlers::resolve_tool_environment;
use crate::tools::handlers::updated_hook_command;
use crate::tools::hook_names::HookToolName;
use crate::tools::orchestrator::ToolOrchestrator;
use crate::tools::registry::CoreToolRuntime;
use crate::tools::registry::PostToolUsePayload;
use crate::tools::registry::PreToolUsePayload;
use crate::tools::registry::ToolArgumentDiffConsumer;
use crate::tools::registry::ToolExecutor;
use crate::tools::registry::ToolHandler;
use crate::tools::runtimes::apply_patch::ApplyPatchRequest;
use crate::tools::runtimes::apply_patch::ApplyPatchRuntime;
use crate::tools::sandboxing::ToolCtx;
Expand Down Expand Up @@ -299,8 +299,6 @@ async fn effective_patch_permissions(

#[async_trait::async_trait]
impl ToolExecutor<ToolInvocation> for ApplyPatchHandler {
type Output = ApplyPatchToolOutput;

fn tool_name(&self) -> ToolName {
ToolName::plain("apply_patch")
}
Expand All @@ -309,7 +307,10 @@ impl ToolExecutor<ToolInvocation> for ApplyPatchHandler {
Some(create_apply_patch_freeform_tool(self.multi_environment))
}

async fn handle(&self, invocation: ToolInvocation) -> Result<Self::Output, FunctionCallError> {
async fn handle(
&self,
invocation: ToolInvocation,
) -> Result<Box<dyn crate::tools::context::ToolOutput>, FunctionCallError> {
let ToolInvocation {
session,
turn,
Expand Down Expand Up @@ -359,7 +360,7 @@ impl ToolExecutor<ToolInvocation> for ApplyPatchHandler {
{
InternalApplyPatchInvocation::Output(item) => {
let content = item?;
Ok(ApplyPatchToolOutput::from_text(content))
Ok(boxed_tool_output(ApplyPatchToolOutput::from_text(content)))
}
InternalApplyPatchInvocation::DelegateToRuntime(apply) => {
let changes = convert_apply_patch_to_protocol(&apply.action);
Expand Down Expand Up @@ -414,7 +415,7 @@ impl ToolExecutor<ToolInvocation> for ApplyPatchHandler {
Some(&tracker),
);
let content = emitter.finish(event_ctx, out, delta.as_ref()).await?;
Ok(ApplyPatchToolOutput::from_text(content))
Ok(boxed_tool_output(ApplyPatchToolOutput::from_text(content)))
}
}
}
Expand All @@ -438,7 +439,7 @@ impl ToolExecutor<ToolInvocation> for ApplyPatchHandler {
}
}

impl ToolHandler for ApplyPatchHandler {
impl CoreToolRuntime for ApplyPatchHandler {
fn matches_kind(&self, payload: &ToolPayload) -> bool {
matches!(payload, ToolPayload::Custom { .. })
}
Expand Down Expand Up @@ -472,7 +473,7 @@ impl ToolHandler for ApplyPatchHandler {
fn post_tool_use_payload(
&self,
invocation: &ToolInvocation,
result: &Self::Output,
result: &dyn crate::tools::context::ToolOutput,
) -> Option<PostToolUsePayload> {
let tool_response =
result.post_tool_use_response(&invocation.call_id, &invocation.payload)?;
Expand Down
Loading
Loading