From 74bc589e461ee9a4c33e31d69f1efa42d3a025f3 Mon Sep 17 00:00:00 2001 From: Winston Howes Date: Thu, 14 May 2026 12:54:06 -0700 Subject: [PATCH] Add SOCKS5 TCP MITM coverage --- codex-rs/network-proxy/src/mitm.rs | 26 ++- codex-rs/network-proxy/src/socks5.rs | 336 +++++++++++++++++++++++---- 2 files changed, 310 insertions(+), 52 deletions(-) diff --git a/codex-rs/network-proxy/src/mitm.rs b/codex-rs/network-proxy/src/mitm.rs index f1aafddb2326..e456009435e9 100644 --- a/codex-rs/network-proxy/src/mitm.rs +++ b/codex-rs/network-proxy/src/mitm.rs @@ -20,10 +20,12 @@ use rama_core::Layer; use rama_core::Service; use rama_core::bytes::Bytes; use rama_core::error::BoxError; +use rama_core::extensions::ExtensionsMut; use rama_core::extensions::ExtensionsRef; -use rama_core::futures::stream::Stream; +use rama_core::futures::stream::Stream as FuturesStream; use rama_core::rt::Executor; use rama_core::service::service_fn; +use rama_core::stream::Stream; use rama_http::Body; use rama_http::BodyDataStream; use rama_http::HeaderMap; @@ -138,17 +140,25 @@ impl MitmState { /// Terminate the upgraded CONNECT stream with a generated leaf cert and proxy inner HTTPS traffic. pub(crate) async fn mitm_tunnel(upgraded: Upgraded) -> Result<()> { - let mitm = upgraded + mitm_stream(upgraded).await +} + +/// Terminate a raw client stream with a generated leaf cert and proxy inner HTTPS traffic. +pub(crate) async fn mitm_stream(stream: S) -> Result<()> +where + S: Stream + Unpin + ExtensionsMut, +{ + let mitm = stream .extensions() .get::>() .cloned() .context("missing MITM state")?; - let app_state = upgraded + let app_state = stream .extensions() .get::>() .cloned() .context("missing app state")?; - let target = upgraded + let target = stream .extensions() .get::() .context("missing proxy target")? @@ -157,7 +167,7 @@ pub(crate) async fn mitm_tunnel(upgraded: Upgraded) -> Result<()> { let target_host = normalize_host(&target.host.to_string()); let target_port = target.port; let acceptor_data = mitm.tls_acceptor_data_for_host(&target_host)?; - let mode = upgraded + let mode = stream .extensions() .get::() .copied() @@ -172,7 +182,7 @@ pub(crate) async fn mitm_tunnel(upgraded: Upgraded) -> Result<()> { mitm, }); - let executor = upgraded + let executor = stream .extensions() .get::() .cloned() @@ -197,7 +207,7 @@ pub(crate) async fn mitm_tunnel(upgraded: Upgraded) -> Result<()> { .into_layer(http_service); https_service - .serve(upgraded) + .serve(stream) .await .map_err(|err| anyhow!("MITM serve error: {err}"))?; Ok(()) @@ -480,7 +490,7 @@ struct InspectStream { max_body_bytes: usize, } -impl Stream for InspectStream { +impl FuturesStream for InspectStream { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { diff --git a/codex-rs/network-proxy/src/socks5.rs b/codex-rs/network-proxy/src/socks5.rs index a1c430c7db8e..19f0b5098089 100644 --- a/codex-rs/network-proxy/src/socks5.rs +++ b/codex-rs/network-proxy/src/socks5.rs @@ -1,5 +1,6 @@ use crate::config::NetworkMode; use crate::connect_policy::TargetCheckedTcpConnector; +use crate::mitm; use crate::network_policy::BlockDecisionAuditEventArgs; use crate::network_policy::NetworkDecision; use crate::network_policy::NetworkDecisionSource; @@ -12,6 +13,7 @@ use crate::network_policy::emit_block_decision_audit_event; use crate::network_policy::evaluate_host_policy; use crate::policy::normalize_host; use crate::reasons::REASON_METHOD_NOT_ALLOWED; +use crate::reasons::REASON_MITM_REQUIRED; use crate::reasons::REASON_PROXY_DISABLED; use crate::responses::PolicyDecisionDetails; use crate::responses::blocked_message_with_policy; @@ -23,10 +25,17 @@ use anyhow::Result; use rama_core::Layer; use rama_core::Service; use rama_core::error::BoxError; +use rama_core::extensions::Extensions; +use rama_core::extensions::ExtensionsMut; use rama_core::extensions::ExtensionsRef; use rama_core::layer::AddInputExtensionLayer; use rama_core::service::service_fn; +use rama_net::address::HostWithPort; use rama_net::client::EstablishedClientConnection; +use rama_net::proxy::ProxyRequest; +use rama_net::proxy::ProxyTarget; +use rama_net::proxy::StreamForwardService; +use rama_net::stream::Socket; use rama_net::stream::SocketInfo; use rama_socks5::Socks5Acceptor; use rama_socks5::server::DefaultConnector; @@ -39,8 +48,14 @@ use rama_tcp::server::TcpListener; use std::io; use std::net::SocketAddr; use std::net::TcpListener as StdTcpListener; +use std::pin::Pin; use std::sync::Arc; +use std::task::Context as TaskContext; +use std::task::Poll; use std::time::Instant; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::io::ReadBuf; use tracing::error; use tracing::info; use tracing::warn; @@ -87,7 +102,7 @@ async fn run_socks5_with_listener( match state.network_mode().await { Ok(NetworkMode::Limited) => { - info!("SOCKS5 is blocked in limited mode; set mode=\"full\" to allow SOCKS5"); + info!("SOCKS5 UDP is blocked in limited mode; SOCKS5 TCP requires MITM inspection"); } Ok(NetworkMode::Full) => {} Err(err) => { @@ -105,7 +120,10 @@ async fn run_socks5_with_listener( } }); - let socks_connector = DefaultConnector::default().with_connector(policy_tcp_connector); + let socks_proxy = service_fn(|request| async move { proxy_socks5_tcp(request).await }); + let socks_connector = DefaultConnector::default() + .with_connector(policy_tcp_connector) + .with_service(socks_proxy); let base = Socks5Acceptor::new().with_connector(socks_connector); if enable_socks5_udp { @@ -134,7 +152,7 @@ async fn handle_socks5_tcp( req: TcpRequest, tcp_connector: TargetCheckedTcpConnector, policy_decider: Option>, -) -> Result, BoxError> { +) -> Result, BoxError> { let app_state = req .extensions() .get::>() @@ -143,6 +161,7 @@ async fn handle_socks5_tcp( let host = normalize_host(&req.authority.host.to_string()); let port = req.authority.port; + let target = req.authority.clone(); if host.is_empty() { return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid host").into()); } @@ -195,50 +214,13 @@ async fn handle_socks5_tcp( } } - match app_state.network_mode().await { - Ok(NetworkMode::Limited) => { - emit_socks_block_decision_audit_event( - &app_state, - NetworkDecisionSource::ModeGuard, - REASON_METHOD_NOT_ALLOWED, - NetworkProtocol::Socks5Tcp, - host.as_str(), - port, - client.as_deref(), - ); - let details = PolicyDecisionDetails { - decision: NetworkPolicyDecision::Deny, - reason: REASON_METHOD_NOT_ALLOWED, - source: NetworkDecisionSource::ModeGuard, - protocol: NetworkProtocol::Socks5Tcp, - host: &host, - port, - }; - let _ = app_state - .record_blocked(BlockedRequest::new(BlockedRequestArgs { - host: host.clone(), - reason: REASON_METHOD_NOT_ALLOWED.to_string(), - client: client.clone(), - method: None, - mode: Some(NetworkMode::Limited), - protocol: "socks5".to_string(), - decision: Some(details.decision.as_str().to_string()), - source: Some(details.source.as_str().to_string()), - port: Some(port), - })) - .await; - let client = client.as_deref().unwrap_or_default(); - warn!( - "SOCKS blocked by method policy (client={client}, host={host}, mode=limited, allowed_methods=GET, HEAD, OPTIONS)" - ); - return Err(policy_denied_error(REASON_METHOD_NOT_ALLOWED, &details).into()); - } - Ok(NetworkMode::Full) => {} + let mode = match app_state.network_mode().await { + Ok(mode) => mode, Err(err) => { error!("failed to evaluate method policy: {err}"); return Err(io::Error::other("proxy error").into()); } - } + }; let request = NetworkPolicyRequest::new(NetworkPolicyRequestArgs { protocol: NetworkProtocol::Socks5Tcp, @@ -291,9 +273,82 @@ async fn handle_socks5_tcp( } } + let mitm_state = match app_state.mitm_state().await { + Ok(state) => state, + Err(err) => { + error!("failed to load MITM state: {err}"); + return Err(io::Error::other("proxy error").into()); + } + }; + let host_has_mitm_hooks = match app_state.host_has_mitm_hooks(&host).await { + Ok(has_hooks) => has_hooks, + Err(err) => { + error!("failed to inspect MITM hooks for {host}: {err}"); + return Err(io::Error::other("proxy error").into()); + } + }; + let socks_needs_mitm = mode == NetworkMode::Limited || host_has_mitm_hooks; + if socks_needs_mitm { + let Some(mitm_state) = mitm_state else { + emit_socks_block_decision_audit_event( + &app_state, + NetworkDecisionSource::ModeGuard, + REASON_MITM_REQUIRED, + NetworkProtocol::Socks5Tcp, + host.as_str(), + port, + client.as_deref(), + ); + let details = PolicyDecisionDetails { + decision: NetworkPolicyDecision::Deny, + reason: REASON_MITM_REQUIRED, + source: NetworkDecisionSource::ModeGuard, + protocol: NetworkProtocol::Socks5Tcp, + host: &host, + port, + }; + let _ = app_state + .record_blocked(BlockedRequest::new(BlockedRequestArgs { + host: host.clone(), + reason: REASON_MITM_REQUIRED.to_string(), + client: client.clone(), + method: None, + mode: Some(mode), + protocol: "socks5".to_string(), + decision: Some(details.decision.as_str().to_string()), + source: Some(details.source.as_str().to_string()), + port: Some(port), + })) + .await; + let client = client.as_deref().unwrap_or_default(); + warn!( + "SOCKS blocked; MITM required to enforce HTTPS policy (client={client}, host={host}, mode={mode:?}, hooked_host={host_has_mitm_hooks})" + ); + return Err(policy_denied_error(REASON_MITM_REQUIRED, &details).into()); + }; + + let client = client.as_deref().unwrap_or_default(); + info!("SOCKS MITM enabled (client={client}, host={host}, port={port}, mode={mode:?})"); + return Ok(EstablishedClientConnection { + input: req, + conn: Socks5TcpConnection::Mitm { + target, + mode, + mitm: mitm_state, + extensions: Extensions::new(), + }, + }); + } + info!("SOCKS upstream dial started (host={host}, port={port})"); let connect_started_at = Instant::now(); - let result = tcp_connector.serve(req).await; + let result = tcp_connector.serve(req).await.map(|connection| { + let EstablishedClientConnection { input, conn } = connection; + EstablishedClientConnection { + input, + conn: Socks5TcpConnection::Direct(conn), + } + }); match &result { Ok(_) => info!( "SOCKS upstream dial established (host={host}, port={port}, elapsed_ms={})", @@ -307,6 +362,113 @@ async fn handle_socks5_tcp( result } +/// Internal connector output for SOCKS5 TCP. MITM requests do not dial upstream before the +/// inner HTTPS request is inspected, so they carry the target metadata instead of a socket. +#[derive(Debug)] +enum Socks5TcpConnection { + Direct(TcpStream), + Mitm { + target: HostWithPort, + mode: NetworkMode, + mitm: Arc, + extensions: Extensions, + }, +} + +impl AsyncRead for Socks5TcpConnection { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + Self::Direct(stream) => Pin::new(stream).poll_read(cx, buf), + Self::Mitm { .. } => Poll::Ready(Ok(())), + } + } +} + +impl AsyncWrite for Socks5TcpConnection { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + Self::Direct(stream) => Pin::new(stream).poll_write(cx, buf), + Self::Mitm { .. } => Poll::Ready(Ok(buf.len())), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + match self.get_mut() { + Self::Direct(stream) => Pin::new(stream).poll_flush(cx), + Self::Mitm { .. } => Poll::Ready(Ok(())), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + match self.get_mut() { + Self::Direct(stream) => Pin::new(stream).poll_shutdown(cx), + Self::Mitm { .. } => Poll::Ready(Ok(())), + } + } +} + +impl Socket for Socks5TcpConnection { + fn local_addr(&self) -> io::Result { + match self { + Self::Direct(stream) => stream.local_addr(), + Self::Mitm { .. } => Ok(SocketAddr::from(([0, 0, 0, 0], 0))), + } + } + + fn peer_addr(&self) -> io::Result { + match self { + Self::Direct(stream) => stream.peer_addr(), + Self::Mitm { .. } => Ok(SocketAddr::from(([0, 0, 0, 0], 0))), + } + } +} + +impl ExtensionsRef for Socks5TcpConnection { + fn extensions(&self) -> &Extensions { + match self { + Self::Direct(stream) => stream.extensions(), + Self::Mitm { extensions, .. } => extensions, + } + } +} + +impl ExtensionsMut for Socks5TcpConnection { + fn extensions_mut(&mut self) -> &mut Extensions { + match self { + Self::Direct(stream) => stream.extensions_mut(), + Self::Mitm { extensions, .. } => extensions, + } + } +} + +async fn proxy_socks5_tcp( + request: ProxyRequest, +) -> Result<(), BoxError> { + let ProxyRequest { mut source, target } = request; + match target { + Socks5TcpConnection::Direct(target) => StreamForwardService::default() + .serve(ProxyRequest { source, target }) + .await + .map_err(Into::into), + Socks5TcpConnection::Mitm { + target, mode, mitm, .. + } => { + source.extensions_mut().insert(ProxyTarget(target)); + source.extensions_mut().insert(mode); + source.extensions_mut().insert(mitm); + mitm::mitm_stream(source).await.map_err(Into::into) + } + } +} + async fn inspect_socks5_udp( request: RelayRequest, state: Arc, @@ -589,6 +751,92 @@ mod tests { assert_eq!(event.field("client.address"), Some("unknown")); } + #[tokio::test(flavor = "current_thread")] + async fn handle_socks5_tcp_uses_mitm_in_limited_mode() { + let mut settings = NetworkProxySettings { + enabled: true, + mode: NetworkMode::Limited, + mitm: true, + ..NetworkProxySettings::default() + }; + settings.set_allowed_domains(vec!["example.com".to_string()]); + let state = state_for_settings(settings); + let mut request = + TcpRequest::new(HostWithPort::try_from("example.com:443").expect("valid authority")); + request.extensions_mut().insert(state.clone()); + + let result = handle_socks5_tcp( + request, + TargetCheckedTcpConnector::new(state), + /*policy_decider*/ None, + ) + .await + .expect("limited-mode HTTPS should use MITM"); + + assert!(matches!(result.conn, Socks5TcpConnection::Mitm { .. })); + } + + #[tokio::test(flavor = "current_thread")] + async fn handle_socks5_tcp_blocks_limited_mode_without_mitm_state() { + let mut settings = NetworkProxySettings { + enabled: true, + mode: NetworkMode::Limited, + ..NetworkProxySettings::default() + }; + settings.set_allowed_domains(vec!["example.com".to_string()]); + let state = state_for_settings(settings); + let mut request = + TcpRequest::new(HostWithPort::try_from("example.com:443").expect("valid authority")); + request.extensions_mut().insert(state.clone()); + + let err = handle_socks5_tcp( + request, + TargetCheckedTcpConnector::new(state), + /*policy_decider*/ None, + ) + .await + .expect_err("limited-mode HTTPS requires MITM"); + + assert!( + format!("{err:?}").contains("MITM required"), + "unexpected error: {err:?}" + ); + } + + #[tokio::test(flavor = "current_thread")] + async fn handle_socks5_tcp_uses_mitm_for_hooked_host_in_full_mode() { + let mut settings = NetworkProxySettings { + enabled: true, + mode: NetworkMode::Full, + mitm: true, + mitm_hooks: vec![crate::mitm_hook::MitmHookConfig { + host: "api.github.com".to_string(), + matcher: crate::mitm_hook::MitmHookMatchConfig { + methods: vec!["POST".to_string()], + path_prefixes: vec!["/repos/openai/".to_string()], + ..crate::mitm_hook::MitmHookMatchConfig::default() + }, + actions: crate::mitm_hook::MitmHookActionsConfig::default(), + }], + ..NetworkProxySettings::default() + }; + settings.set_allowed_domains(vec!["api.github.com".to_string()]); + let state = state_for_settings(settings); + let mut request = + TcpRequest::new(HostWithPort::try_from("api.github.com:443").expect("valid authority")); + request.extensions_mut().insert(state.clone()); + + let result = handle_socks5_tcp( + request, + TargetCheckedTcpConnector::new(state), + /*policy_decider*/ None, + ) + .await + .expect("hooked HTTPS should use MITM"); + + assert!(matches!(result.conn, Socks5TcpConnection::Mitm { .. })); + } + #[tokio::test(flavor = "current_thread")] async fn inspect_socks5_udp_emits_block_decision_for_mode_guard_deny() { let state = state_for_settings(NetworkProxySettings {