Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/cortex-app-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,4 @@ gethostname = "0.5"

[dev-dependencies]
tokio-test = { workspace = true }
tower = { version = "0.5", default-features = false, features = ["util"] }
1 change: 1 addition & 0 deletions src/cortex-app-server/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub use types::{
/// Create the API routes.
pub fn routes() -> Router<Arc<AppState>> {
Router::new()
.without_v07_checks()
// Health and metrics
.route("/health", get(health::health_check))
.route("/metrics", get(health::get_metrics))
Expand Down
171 changes: 163 additions & 8 deletions src/cortex-app-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ pub mod websocket;
use std::net::SocketAddr;
use std::sync::Arc;

use axum::Router;
use axum::{Router, middleware as axum_middleware};
use tokio::net::TcpListener;
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use tracing::{info, warn};

Expand Down Expand Up @@ -131,15 +130,171 @@ pub fn create_router(state: AppState) -> Router {
/// This variant is useful when you need to keep a reference to the state
/// for cleanup purposes (e.g., during graceful shutdown).
///
/// The API, WebSocket, streaming, share and admin routers are merged, wrapped
/// by [`add_api_middleware`], and mounted under `/api/v1`. HTTP tracing and the
/// CORS policy built from `state.config.cors_origins` are layered around the
/// nested API before the shared state is attached.
pub fn create_router_with_state(state: Arc<AppState>) -> Router {
    // Build the CORS policy first: it only borrows the configuration, and
    // `state` is moved into the router at the end of this function.
    let cors_layer = middleware::cors_layer(&state.config.cors_origins);

    // Merge every feature router exactly once, then apply the API middleware
    // stack to the combined router.
    let api_routes = add_api_middleware(
        api::routes()
            .merge(websocket::routes())
            .merge(streaming::routes())
            .merge(share::routes())
            .merge(admin::routes()),
        Arc::clone(&state),
    );

    Router::new()
        .without_v07_checks()
        .nest("/api/v1", api_routes)
        .layer(TraceLayer::new_for_http())
        // Only the configured CORS layer is installed. Stacking a permissive
        // CORS layer beneath it would undermine the configured origin
        // restriction.
        .layer(cors_layer)
        .with_state(state)
}

/// Wraps `router` in the API middleware stack.
///
/// Four layers are added, in this order: rate limiting, content-type
/// validation, request timeout, and security headers. The rate-limit and
/// timeout middleware each receive a handle to the shared `state`; the other
/// two are stateless.
///
/// NOTE(review): per axum's layering rules a layer added later wraps the ones
/// added before it, so the security-headers middleware should be outermost —
/// confirm against the axum middleware documentation before reordering.
fn add_api_middleware(
    router: Router<Arc<AppState>>,
    state: Arc<AppState>,
) -> Router<Arc<AppState>> {
    // State-aware layers. The first takes a cloned handle so the original
    // `Arc` can be moved into the second.
    let rate_limit = axum_middleware::from_fn_with_state(
        Arc::clone(&state),
        middleware::rate_limit_middleware,
    );
    let timeout = axum_middleware::from_fn_with_state(state, middleware::timeout_middleware);

    // Stateless layers.
    let content_type = axum_middleware::from_fn(middleware::content_type_middleware);
    let security_headers = axum_middleware::from_fn(middleware::security_headers_middleware);

    // Application order matters and is kept exactly as before.
    router
        .layer(rate_limit)
        .layer(content_type)
        .layer(timeout)
        .layer(security_headers)
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode, header},
        routing::get,
    };
    use tower::ServiceExt;

    /// Builds the complete application router from `config`.
    async fn test_app(config: ServerConfig) -> Router {
        create_router(AppState::new(config).await.unwrap())
    }

    /// Handler that sleeps for 25 ms before replying; used to trip the
    /// timeout middleware.
    async fn slow_test_handler() -> &'static str {
        let delay = std::time::Duration::from_millis(25);
        tokio::time::sleep(delay).await;
        "done"
    }

    /// Builds a request for `uri` with an empty body and the builder's
    /// default method.
    fn empty_request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    /// Asserts that `response` carries `X-Content-Type-Options: nosniff`.
    #[track_caller]
    fn assert_nosniff<B>(response: &axum::http::Response<B>) {
        assert_eq!(
            response.headers().get("X-Content-Type-Options").unwrap(),
            "nosniff"
        );
    }

    #[tokio::test]
    async fn create_router_applies_security_headers_middleware() {
        let router = test_app(ServerConfig::default()).await;

        let response = router
            .oneshot(empty_request("/api/v1/models"))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_nosniff(&response);
        assert_eq!(response.headers().get("X-Frame-Options").unwrap(), "DENY");
    }

    #[tokio::test]
    async fn create_router_applies_rate_limit_middleware() {
        // Allow a single request and exempt no paths, so the second call to
        // the same endpoint is expected to be rejected.
        let mut config = ServerConfig::default();
        config.rate_limit.requests_per_minute = 1;
        config.rate_limit.burst_size = 1;
        config.rate_limit.exempt_paths.clear();

        let router = test_app(config).await;

        let first = router
            .clone()
            .oneshot(empty_request("/api/v1/models"))
            .await
            .unwrap();
        assert_eq!(first.status(), StatusCode::OK);

        let second = router
            .oneshot(empty_request("/api/v1/models"))
            .await
            .unwrap();
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(second.headers().get(header::RETRY_AFTER).unwrap(), "60");
        assert_nosniff(&second);
    }

    #[tokio::test]
    async fn create_router_applies_content_type_middleware() {
        let router = test_app(ServerConfig::default()).await;

        // A POST whose body is declared as plain text rather than JSON.
        let request = Request::builder()
            .method("POST")
            .uri("/api/v1/models")
            .header(header::CONTENT_TYPE, "text/plain")
            .body(Body::from("not json"))
            .unwrap();

        let response = router.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE);
        assert_nosniff(&response);
    }

    #[tokio::test]
    async fn create_router_applies_timeout_middleware() {
        // Zero timeout with rate limiting disabled, so only the timeout
        // middleware can reject the slow handler.
        let mut config = ServerConfig::default();
        config.request_timeout = 0;
        config.rate_limit.enabled = false;

        let state = Arc::new(AppState::new(config).await.unwrap());
        let slow_routes = Router::new().route("/slow", get(slow_test_handler));
        let router = add_api_middleware(slow_routes, Arc::clone(&state)).with_state(state);

        let response = router.oneshot(empty_request("/slow")).await.unwrap();

        assert_eq!(response.status(), StatusCode::GATEWAY_TIMEOUT);
        assert_nosniff(&response);
    }
}
1 change: 1 addition & 0 deletions src/cortex-app-server/src/share.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::storage::StoredMessage;
/// Create share routes.
pub fn routes() -> Router<Arc<AppState>> {
Router::new()
.without_v07_checks()
.route("/share", post(create_share))
.route("/share/:token", get(get_shared_session))
.route("/share/:token", delete(revoke_share))
Expand Down
1 change: 1 addition & 0 deletions src/cortex-app-server/src/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use crate::state::AppState;
/// Create streaming API routes.
pub fn routes() -> Router<Arc<AppState>> {
Router::new()
.without_v07_checks()
// CLI Session management
.route("/cli/sessions", post(create_cli_session))
.route("/cli/sessions", get(list_cli_sessions))
Expand Down
1 change: 1 addition & 0 deletions src/cortex-app-server/src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::state::AppState;
/// Create WebSocket routes.
///
/// Registers two GET endpoints:
/// - `/ws`, served by `websocket_handler`
/// - `/ws/sessions/{id}`, served by `session_websocket_handler`, where `{id}`
///   is a path capture
pub fn routes() -> Router<Arc<AppState>> {
    Router::new()
        // Left in place so this router keeps the same check settings as the
        // sibling routers it is merged with.
        .without_v07_checks()
        .route("/ws", get(websocket_handler))
        // NOTE(review): assumes axum 0.8 (the release that introduced
        // `without_v07_checks`). There, captures are written `{id}`; with the
        // 0.7 checks disabled, the old `:id` spelling is registered as a
        // literal segment and never captures a session id.
        .route("/ws/sessions/{id}", get(session_websocket_handler))
}
Expand Down