Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions codex-rs/network-proxy/src/mitm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ use rama_core::Layer;
use rama_core::Service;
use rama_core::bytes::Bytes;
use rama_core::error::BoxError;
use rama_core::extensions::ExtensionsMut;
use rama_core::extensions::ExtensionsRef;
use rama_core::futures::stream::Stream;
use rama_core::futures::stream::Stream as FuturesStream;
use rama_core::rt::Executor;
use rama_core::service::service_fn;
use rama_core::stream::Stream;
use rama_http::Body;
use rama_http::BodyDataStream;
use rama_http::HeaderMap;
Expand Down Expand Up @@ -138,17 +140,25 @@ impl MitmState {

/// Terminate the upgraded CONNECT stream with a generated leaf cert and proxy inner HTTPS traffic.
pub(crate) async fn mitm_tunnel(upgraded: Upgraded) -> Result<()> {
let mitm = upgraded
mitm_stream(upgraded).await
}

/// Terminate a raw client stream with a generated leaf cert and proxy inner HTTPS traffic.
pub(crate) async fn mitm_stream<S>(stream: S) -> Result<()>
where
S: Stream + Unpin + ExtensionsMut,
{
let mitm = stream
.extensions()
.get::<Arc<MitmState>>()
.cloned()
.context("missing MITM state")?;
let app_state = upgraded
let app_state = stream
.extensions()
.get::<Arc<NetworkProxyState>>()
.cloned()
.context("missing app state")?;
let target = upgraded
let target = stream
.extensions()
.get::<ProxyTarget>()
.context("missing proxy target")?
Expand All @@ -157,7 +167,7 @@ pub(crate) async fn mitm_tunnel(upgraded: Upgraded) -> Result<()> {
let target_host = normalize_host(&target.host.to_string());
let target_port = target.port;
let acceptor_data = mitm.tls_acceptor_data_for_host(&target_host)?;
let mode = upgraded
let mode = stream
.extensions()
.get::<NetworkMode>()
.copied()
Expand All @@ -172,7 +182,7 @@ pub(crate) async fn mitm_tunnel(upgraded: Upgraded) -> Result<()> {
mitm,
});

let executor = upgraded
let executor = stream
.extensions()
.get::<Executor>()
.cloned()
Expand All @@ -197,7 +207,7 @@ pub(crate) async fn mitm_tunnel(upgraded: Upgraded) -> Result<()> {
.into_layer(http_service);

https_service
.serve(upgraded)
.serve(stream)
.await
.map_err(|err| anyhow!("MITM serve error: {err}"))?;
Ok(())
Expand Down Expand Up @@ -480,7 +490,7 @@ struct InspectStream<T> {
max_body_bytes: usize,
}

impl<T: BodyLoggable> Stream for InspectStream<T> {
impl<T: BodyLoggable> FuturesStream for InspectStream<T> {
type Item = Result<Bytes, BoxError>;

fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
Expand Down
Loading
Loading