Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/cortex-app-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,12 @@ where
};

let listener = TcpListener::bind(addr).await?;
axum::serve(listener, app)
.with_graceful_shutdown(shutdown)
.await?;
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(shutdown)
.await?;

// Graceful shutdown: close all active sessions first
// This ensures WebSocket clients receive proper close frames
Expand Down
38 changes: 36 additions & 2 deletions src/cortex-app-server/src/middleware.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
//! HTTP middleware components.

use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};

use axum::{
extract::{Request, State},
extract::{ConnectInfo, Request, State},
http::{HeaderValue, Method, StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
Expand Down Expand Up @@ -178,7 +179,11 @@ fn get_rate_limit_key(request: &Request, state: &AppState) -> String {
}
}

// Default to unknown when not behind proxy or headers not present
if let Some(ConnectInfo(addr)) = request.extensions().get::<ConnectInfo<SocketAddr>>() {
return format!("ip:{}", addr.ip());
}

// Default to unknown only when connection metadata is unavailable.
"ip:unknown".to_string()
}

Expand Down Expand Up @@ -458,7 +463,10 @@ pub async fn health_check_bypass_middleware(request: Request, next: Next) -> Res

#[cfg(test)]
mod tests {
use axum::body::Body;

use super::*;
use crate::config::ServerConfig;

#[test]
fn test_request_id() {
Expand All @@ -477,4 +485,30 @@ mod tests {
.contains(&"Authorization".to_string())
);
}

#[tokio::test]
async fn rate_limit_key_uses_connect_info_for_direct_clients() {
let config = ServerConfig::default();
let state = AppState::new(config).await.unwrap();
let mut request = Request::builder().body(Body::empty()).unwrap();
let addr: SocketAddr = "203.0.113.10:4242".parse().unwrap();
request.extensions_mut().insert(ConnectInfo(addr));

assert_eq!(get_rate_limit_key(&request, &state), "ip:203.0.113.10");
}

#[tokio::test]
async fn rate_limit_key_prefers_proxy_headers_when_trusted() {
let mut config = ServerConfig::default();
config.rate_limit.trust_proxy = true;
let state = AppState::new(config).await.unwrap();
let mut request = Request::builder()
.header("X-Forwarded-For", "198.51.100.7, 203.0.113.8")
.body(Body::empty())
.unwrap();
let addr: SocketAddr = "203.0.113.10:4242".parse().unwrap();
request.extensions_mut().insert(ConnectInfo(addr));

assert_eq!(get_rate_limit_key(&request, &state), "ip:198.51.100.7");
}
}