From 4cfa07326841020bcacdda8453914c44ffbe4385 Mon Sep 17 00:00:00 2001 From: onlyyu1996 <1158673577@qq.com> Date: Mon, 18 May 2026 14:41:26 +0800 Subject: [PATCH] fix(app-server): rate limit direct clients by IP --- src/cortex-app-server/src/lib.rs | 9 ++++-- src/cortex-app-server/src/middleware.rs | 38 +++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/src/cortex-app-server/src/lib.rs b/src/cortex-app-server/src/lib.rs index 8e7acdf88..c2e753a1b 100644 --- a/src/cortex-app-server/src/lib.rs +++ b/src/cortex-app-server/src/lib.rs @@ -99,9 +99,12 @@ where }; let listener = TcpListener::bind(addr).await?; - axum::serve(listener, app) - .with_graceful_shutdown(shutdown) - .await?; + axum::serve( + listener, + app.into_make_service_with_connect_info::(), + ) + .with_graceful_shutdown(shutdown) + .await?; // Graceful shutdown: close all active sessions first // This ensures WebSocket clients receive proper close frames diff --git a/src/cortex-app-server/src/middleware.rs b/src/cortex-app-server/src/middleware.rs index a9971576d..1653f8894 100644 --- a/src/cortex-app-server/src/middleware.rs +++ b/src/cortex-app-server/src/middleware.rs @@ -1,10 +1,11 @@ //! HTTP middleware components. +use std::net::SocketAddr; use std::sync::Arc; use std::time::{Duration, Instant}; use axum::{ - extract::{Request, State}, + extract::{ConnectInfo, Request, State}, http::{HeaderValue, Method, StatusCode, header}, middleware::Next, response::{IntoResponse, Response}, @@ -178,7 +179,11 @@ fn get_rate_limit_key(request: &Request, state: &AppState) -> String { } } - // Default to unknown when not behind proxy or headers not present + if let Some(ConnectInfo(addr)) = request.extensions().get::>() { + return format!("ip:{}", addr.ip()); + } + + // Default to unknown only when connection metadata is unavailable. "ip:unknown".to_string() } @@ -458,7 +463,10 @@ pub async fn health_check_bypass_middleware(request: Request, next: Next) -> Res #[cfg(test)] mod tests { + use axum::body::Body; + use super::*; + use crate::config::ServerConfig; #[test] fn test_request_id() { @@ -477,4 +485,30 @@ mod tests { .contains(&"Authorization".to_string()) ); } + + #[tokio::test] + async fn rate_limit_key_uses_connect_info_for_direct_clients() { + let config = ServerConfig::default(); + let state = AppState::new(config).await.unwrap(); + let mut request = Request::builder().body(Body::empty()).unwrap(); + let addr: SocketAddr = "203.0.113.10:4242".parse().unwrap(); + request.extensions_mut().insert(ConnectInfo(addr)); + + assert_eq!(get_rate_limit_key(&request, &state), "ip:203.0.113.10"); + } + + #[tokio::test] + async fn rate_limit_key_prefers_proxy_headers_when_trusted() { + let mut config = ServerConfig::default(); + config.rate_limit.trust_proxy = true; + let state = AppState::new(config).await.unwrap(); + let mut request = Request::builder() + .header("X-Forwarded-For", "198.51.100.7, 203.0.113.8") + .body(Body::empty()) + .unwrap(); + let addr: SocketAddr = "203.0.113.10:4242".parse().unwrap(); + request.extensions_mut().insert(ConnectInfo(addr)); + + assert_eq!(get_rate_limit_key(&request, &state), "ip:198.51.100.7"); + } }