diff --git a/Cargo.lock b/Cargo.lock index a5c5a1f3..6875e75c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -949,6 +949,7 @@ dependencies = [ "fs2", "futures", "gethostname 0.5.0", + "http-body-util", "if-addrs", "jsonwebtoken", "mdns-sd", @@ -962,6 +963,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-test", + "tower", "tower-http", "tracing", "tracing-subscriber", diff --git a/src/cortex-app-server/Cargo.toml b/src/cortex-app-server/Cargo.toml index 8da2d7bd..cae0e43c 100644 --- a/src/cortex-app-server/Cargo.toml +++ b/src/cortex-app-server/Cargo.toml @@ -25,6 +25,7 @@ cortex-common = { path = "../cortex-common" } # Web framework axum = { workspace = true } tower-http = { workspace = true } +http-body-util = { workspace = true } # Async tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros", "sync", "fs", "net", "process", "time"] } @@ -73,3 +74,4 @@ gethostname = "0.5" [dev-dependencies] tokio-test = { workspace = true } +tower = { version = "0.5", features = ["util"] } diff --git a/src/cortex-app-server/src/middleware.rs b/src/cortex-app-server/src/middleware.rs index a9971576..0557ab30 100644 --- a/src/cortex-app-server/src/middleware.rs +++ b/src/cortex-app-server/src/middleware.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use axum::{ + body::{Body, to_bytes}, extract::{Request, State}, http::{HeaderValue, Method, StatusCode, header}, middleware::Next, @@ -271,9 +272,36 @@ pub async fn body_limit_middleware( return Err(StatusCode::PAYLOAD_TOO_LARGE); } + let (parts, body) = request.into_parts(); + let body = match to_bytes(body, max_size).await { + Ok(bytes) => Body::from(bytes), + Err(err) if is_body_length_limit_error(&err) => { + warn!("Request body exceeded max size: {} bytes", max_size); + return Err(StatusCode::PAYLOAD_TOO_LARGE); + } + Err(err) => { + warn!("Failed to read request body: {}", err); + return Err(StatusCode::BAD_REQUEST); + } + }; + + let request = Request::from_parts(parts, body); Ok(next.run(request).await) } +fn is_body_length_limit_error(error: &(dyn std::error::Error + 'static)) -> bool { + let mut current = Some(error); + + while let Some(error) = current { + if error.is::() { + return true; + } + current = error.source(); + } + + false +} + /// CORS configuration. /// Includes Access-Control-Max-Age header to allow browsers to cache /// preflight responses, reducing the number of OPTIONS requests. @@ -458,7 +486,19 @@ pub async fn health_check_bypass_middleware(request: Request, next: Next) -> Res #[cfg(test)] mod tests { + use std::convert::Infallible; + + use axum::{ + Router, + body::{Bytes, to_bytes}, + middleware::from_fn_with_state, + routing::post, + }; + use futures::stream; + use tower::ServiceExt; + use super::*; + use crate::config::ServerConfig; #[test] fn test_request_id() { @@ -477,4 +517,67 @@ mod tests { .contains(&"Authorization".to_string()) ); } + + async fn body_len_handler(request: Request) -> String { + let body = to_bytes(request.into_body(), usize::MAX).await.unwrap(); + body.len().to_string() + } + + async fn app_with_body_limit(limit: usize) -> Router { + let mut config = ServerConfig::default(); + config.max_body_size = limit; + let state = Arc::new(AppState::new(config).await.unwrap()); + + Router::new() + .route("/upload", post(body_len_handler)) + .layer(from_fn_with_state(state, body_limit_middleware)) + } + + fn streaming_body(chunks: &'static [&'static [u8]]) -> Body { + Body::from_stream(stream::iter( + chunks + .iter() + .map(|chunk| Ok::<_, Infallible>(Bytes::from_static(chunk))), + )) + } + + #[tokio::test] + async fn body_limit_rejects_streaming_body_without_content_length() { + let app = app_with_body_limit(8).await; + let body = streaming_body(&[b"12345", b"67890"]); + + let response = app + .oneshot( + Request::builder() + .method(Method::POST) + .uri("/upload") + .body(body) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE); + } + + #[tokio::test] + async fn body_limit_allows_streaming_body_under_limit_without_content_length() { + let app = app_with_body_limit(8).await; + let body = streaming_body(&[b"123", b"45"]); + + let response = app + .oneshot( + Request::builder() + .method(Method::POST) + .uri("/upload") + .body(body) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let body = to_bytes(response.into_body(), usize::MAX).await.unwrap(); + assert_eq!(&body[..], b"5"); + } }