diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 589be7aac2e..62c4dedd47a 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -1006,6 +1006,21 @@ impl From for StartupOutcomeError { } } +fn elicitation_capability_for_server(server_name: &str) -> Option { + if server_name == CODEX_APPS_MCP_SERVER_NAME { + // https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities + // indicates this should be an empty object. + Some(ElicitationCapability { + form: Some(FormElicitationCapability { + schema_validation: None, + }), + url: None, + }) + } else { + None + } +} + async fn start_server_task( server_name: String, client: Arc, @@ -1015,6 +1030,8 @@ async fn start_server_task( tx_event: Sender, elicitation_requests: ElicitationRequestManager, ) -> Result { + let elicitation = elicitation_capability_for_server(&server_name); + let params = InitializeRequestParams { meta: None, capabilities: ClientCapabilities { @@ -1022,14 +1039,7 @@ async fn start_server_task( extensions: None, roots: None, sampling: None, - // https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities - // indicates this should be an empty object. - elicitation: Some(ElicitationCapability { - form: Some(FormElicitationCapability { - schema_validation: None, - }), - url: None, - }), + elicitation, tasks: None, }, client_info: Implementation { @@ -1541,6 +1551,22 @@ mod tests { }); } + #[test] + fn elicitation_capability_enabled_only_for_codex_apps() { + let codex_apps_capability = elicitation_capability_for_server(CODEX_APPS_MCP_SERVER_NAME); + assert!(matches!( + codex_apps_capability, + Some(ElicitationCapability { + form: Some(FormElicitationCapability { + schema_validation: None + }), + url: None, + }) + )); + + assert!(elicitation_capability_for_server("custom_mcp").is_none()); + } + #[test] fn mcp_init_error_display_prompts_for_github_pat() { let server_name = "github"; diff --git a/codex-rs/core/src/mcp_tool_call.rs b/codex-rs/core/src/mcp_tool_call.rs index 745633b1488..5f7734118cf 100644 --- a/codex-rs/core/src/mcp_tool_call.rs +++ b/codex-rs/core/src/mcp_tool_call.rs @@ -282,7 +282,7 @@ const MCP_TOOL_APPROVAL_CANCEL: &str = "Cancel"; #[derive(Debug, Serialize)] struct McpToolApprovalKey { server: String, - connector_id: String, + connector_id: Option, tool_name: String, } @@ -296,36 +296,29 @@ async fn maybe_request_mcp_tool_approval( if is_full_access_mode(turn_context) { return None; } - if server != CODEX_APPS_MCP_SERVER_NAME { - return None; - } let metadata = lookup_mcp_tool_metadata(sess, server, tool_name).await?; if !requires_mcp_tool_approval(&metadata.annotations) { return None; } - let approval_key = metadata - .connector_id - .as_deref() - .map(|connector_id| McpToolApprovalKey { - server: server.to_string(), - connector_id: connector_id.to_string(), - tool_name: tool_name.to_string(), - }); - if let Some(key) = approval_key.as_ref() - && mcp_tool_approval_is_remembered(sess, key).await - { + let approval_key = McpToolApprovalKey { + server: server.to_string(), + connector_id: metadata.connector_id.clone(), + tool_name: tool_name.to_string(), + }; + if mcp_tool_approval_is_remembered(sess, &approval_key).await { return Some(McpToolApprovalDecision::Accept); } let question_id = format!("{MCP_TOOL_APPROVAL_QUESTION_ID_PREFIX}_{call_id}"); let question = build_mcp_tool_approval_question( question_id.clone(), + server, tool_name, metadata.tool_title.as_deref(), metadata.connector_name.as_deref(), &metadata.annotations, - approval_key.is_some(), + true, ); let args = RequestUserInputArgs { questions: vec![question], @@ -334,10 +327,8 @@ async fn maybe_request_mcp_tool_approval( .request_user_input(turn_context, call_id.to_string(), args) .await; let decision = parse_mcp_tool_approval_response(response, &question_id); - if matches!(decision, McpToolApprovalDecision::AcceptAndRemember) - && let Some(key) = approval_key - { - remember_mcp_tool_approval(sess, key).await; + if matches!(decision, McpToolApprovalDecision::AcceptAndRemember) { + remember_mcp_tool_approval(sess, approval_key).await; } Some(decision) } @@ -407,6 +398,7 @@ async fn lookup_mcp_app_usage_metadata( fn build_mcp_tool_approval_question( question_id: String, + server: &str, tool_name: &str, tool_title: Option<&str>, connector_name: Option<&str>, @@ -425,7 +417,13 @@ fn build_mcp_tool_approval_question( let tool_label = tool_title.unwrap_or(tool_name); let app_label = connector_name .map(|name| format!("The {name} app")) - .unwrap_or_else(|| "This app".to_string()); + .unwrap_or_else(|| { + if server == CODEX_APPS_MCP_SERVER_NAME { + "This app".to_string() + } else { + format!("The {server} MCP server") + } + }); let question = format!( "{app_label} wants to run the tool \"{tool_label}\", which {reason}. Allow this action?" ); @@ -570,6 +568,52 @@ mod tests { assert_eq!(requires_mcp_tool_approval(&annotations), false); } + #[test] + fn custom_mcp_tool_question_mentions_server_name() { + let question = build_mcp_tool_approval_question( + "q".to_string(), + "custom_server", + "run_action", + Some("Run Action"), + None, + &annotations(Some(false), Some(true), None), + true, + ); + + assert_eq!(question.header, "Approve app tool call?"); + assert_eq!( + question.question, + "The custom_server MCP server wants to run the tool \"Run Action\", which may modify or delete data. Allow this action?" + ); + assert!( + question + .options + .expect("options") + .into_iter() + .map(|option| option.label) + .any(|label| label == MCP_TOOL_APPROVAL_ACCEPT_AND_REMEMBER) + ); + } + + #[test] + fn codex_apps_tool_question_keeps_legacy_app_label() { + let question = build_mcp_tool_approval_question( + "q".to_string(), + CODEX_APPS_MCP_SERVER_NAME, + "run_action", + Some("Run Action"), + None, + &annotations(Some(false), Some(true), None), + true, + ); + + assert!( + question + .question + .starts_with("This app wants to run the tool \"Run Action\"") + ); + } + #[test] fn sanitize_mcp_tool_result_for_model_rewrites_image_content() { let result = Ok(CallToolResult {