Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions src/cortex-app-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ cortex-common = { path = "../cortex-common" }
# Web framework
axum = { workspace = true }
tower-http = { workspace = true }
http-body-util = { workspace = true }

# Async
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros", "sync", "fs", "net", "process", "time"] }
Expand Down Expand Up @@ -73,3 +74,4 @@ gethostname = "0.5"

[dev-dependencies]
tokio-test = { workspace = true }
tower = { version = "0.5", features = ["util"] }
103 changes: 103 additions & 0 deletions src/cortex-app-server/src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;
use std::time::{Duration, Instant};

use axum::{
body::{Body, to_bytes},
extract::{Request, State},
http::{HeaderValue, Method, StatusCode, header},
middleware::Next,
Expand Down Expand Up @@ -271,9 +272,36 @@ pub async fn body_limit_middleware(
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}

let (parts, body) = request.into_parts();
let body = match to_bytes(body, max_size).await {
Ok(bytes) => Body::from(bytes),
Err(err) if is_body_length_limit_error(&err) => {
warn!("Request body exceeded max size: {} bytes", max_size);
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}
Err(err) => {
warn!("Failed to read request body: {}", err);
return Err(StatusCode::BAD_REQUEST);
}
};

let request = Request::from_parts(parts, body);
Ok(next.run(request).await)
}

fn is_body_length_limit_error(error: &(dyn std::error::Error + 'static)) -> bool {
let mut current = Some(error);

while let Some(error) = current {
if error.is::<http_body_util::LengthLimitError>() {
return true;
}
current = error.source();
}

false
}

/// CORS configuration.
/// Includes Access-Control-Max-Age header to allow browsers to cache
/// preflight responses, reducing the number of OPTIONS requests.
Expand Down Expand Up @@ -458,7 +486,19 @@ pub async fn health_check_bypass_middleware(request: Request, next: Next) -> Res

#[cfg(test)]
mod tests {
use std::convert::Infallible;

use axum::{
Router,
body::{Bytes, to_bytes},
middleware::from_fn_with_state,
routing::post,
};
use futures::stream;
use tower::ServiceExt;

use super::*;
use crate::config::ServerConfig;

#[test]
fn test_request_id() {
Expand All @@ -477,4 +517,67 @@ mod tests {
.contains(&"Authorization".to_string())
);
}

async fn body_len_handler(request: Request) -> String {
let body = to_bytes(request.into_body(), usize::MAX).await.unwrap();
body.len().to_string()
}

async fn app_with_body_limit(limit: usize) -> Router {
let mut config = ServerConfig::default();
config.max_body_size = limit;
let state = Arc::new(AppState::new(config).await.unwrap());

Router::new()
.route("/upload", post(body_len_handler))
.layer(from_fn_with_state(state, body_limit_middleware))
}

fn streaming_body(chunks: &'static [&'static [u8]]) -> Body {
Body::from_stream(stream::iter(
chunks
.iter()
.map(|chunk| Ok::<_, Infallible>(Bytes::from_static(chunk))),
))
}

#[tokio::test]
async fn body_limit_rejects_streaming_body_without_content_length() {
let app = app_with_body_limit(8).await;
let body = streaming_body(&[b"12345", b"67890"]);

let response = app
.oneshot(
Request::builder()
.method(Method::POST)
.uri("/upload")
.body(body)
.unwrap(),
)
.await
.unwrap();

assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
}

#[tokio::test]
async fn body_limit_allows_streaming_body_under_limit_without_content_length() {
let app = app_with_body_limit(8).await;
let body = streaming_body(&[b"123", b"45"]);

let response = app
.oneshot(
Request::builder()
.method(Method::POST)
.uri("/upload")
.body(body)
.unwrap(),
)
.await
.unwrap();

assert_eq!(response.status(), StatusCode::OK);
let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
assert_eq!(&body[..], b"5");
}
}