Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ workspace = true

[dependencies]
nemo-flow = { workspace = true, features = ["openinference"] }
nemo-flow-adaptive = { workspace = true, features = ["redis-backend"] }
async-stream = "0.3"
axum = "0.8"
bytes = "1"
Expand Down
106 changes: 80 additions & 26 deletions crates/cli/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex};
use async_stream::stream;
use axum::body::{Body, Bytes};
use axum::extract::State;
use axum::http::{HeaderMap, HeaderName, Method, Request, Response, StatusCode};
use axum::http::{HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode};
use futures_util::StreamExt;
use nemo_flow::api::llm::{
LlmCallExecuteParams, LlmRequest, LlmStreamCallExecuteParams, llm_call_execute,
Expand Down Expand Up @@ -297,7 +297,7 @@ fn build_buffered_func(
let body_bytes = prepared.body_bytes.clone();
let headers = prepared.headers.clone();
let route = prepared.provider;
Arc::new(move |_request| {
Arc::new(move |request| {
let http = http.clone();
let method = method.clone();
let url = url.clone();
Expand All @@ -307,17 +307,24 @@ fn build_buffered_func(
let upstream_error = upstream_error.clone();
let response_bytes = response_bytes.clone();
Box::pin(async move {
let response =
match forward_upstream_request(&http, &method, &url, &body_bytes, &headers, route)
.await
{
Ok(response) => response,
Err(error) => {
let message = error.to_string();
*upstream_error.lock().expect("upstream error lock poisoned") = Some(error);
return Err(FlowError::Internal(message));
}
};
let response = match forward_upstream_request(
&http,
&method,
&url,
&body_bytes,
&headers,
Some(&request),
route,
)
.await
{
Ok(response) => response,
Err(error) => {
let message = error.to_string();
*upstream_error.lock().expect("upstream error lock poisoned") = Some(error);
return Err(FlowError::Internal(message));
}
};
let status = response.status();
let response_headers = response_headers(response.headers());
let bytes = match response.bytes().await {
Expand Down Expand Up @@ -431,7 +438,7 @@ fn build_streaming_func(
let body_bytes = prepared.body_bytes.clone();
let headers = prepared.headers.clone();
let route = prepared.provider;
Arc::new(move |_request| {
Arc::new(move |request| {
let http = http.clone();
let method = method.clone();
let url = url.clone();
Expand All @@ -440,17 +447,24 @@ fn build_streaming_func(
let upstream_info = upstream_info.clone();
let upstream_error = upstream_error.clone();
Box::pin(async move {
let response =
match forward_upstream_request(&http, &method, &url, &body_bytes, &headers, route)
.await
{
Ok(response) => response,
Err(error) => {
let message = error.to_string();
*upstream_error.lock().expect("upstream error lock poisoned") = Some(error);
return Err(FlowError::Internal(message));
}
};
let response = match forward_upstream_request(
&http,
&method,
&url,
&body_bytes,
&headers,
Some(&request),
route,
)
.await
{
Ok(response) => response,
Err(error) => {
let message = error.to_string();
*upstream_error.lock().expect("upstream error lock poisoned") = Some(error);
return Err(FlowError::Internal(message));
}
};
let status = response.status();
let response_headers = response_headers(response.headers());
*upstream_info.lock().expect("upstream info lock poisoned") =
Expand Down Expand Up @@ -554,8 +568,10 @@ async fn forward_upstream_request(
url: &str,
body_bytes: &Bytes,
headers: &HeaderMap,
effective_request: Option<&LlmRequest>,
route: ProviderRoute,
) -> Result<reqwest::Response, reqwest::Error> {
let (body_bytes, headers) = effective_upstream_request(body_bytes, headers, effective_request);
// Only strip the inbound JWT when we actually have a replacement key to inject. Without one
// the upstream just receives no auth and 401s, which is no better than letting it reject the
// JWT itself — and stripping silently can break setups that point the gateway at an upstream
Expand All @@ -566,7 +582,7 @@ async fn forward_upstream_request(
.ok()
.filter(|v| !v.trim().is_empty())
.is_some();
let sanitized = strip_chatgpt_oauth_for_openai_route(headers, route, has_openai_env);
let sanitized = strip_chatgpt_oauth_for_openai_route(&headers, route, has_openai_env);
let mut upstream = http.request(method.clone(), url).body(body_bytes.clone());
for (name, value) in &sanitized {
if should_forward_request_header(name) {
Expand All @@ -577,6 +593,43 @@ async fn forward_upstream_request(
upstream.send().await
}

fn effective_upstream_request(
body_bytes: &Bytes,
headers: &HeaderMap,
effective_request: Option<&LlmRequest>,
) -> (Bytes, HeaderMap) {
let Some(request) = effective_request else {
return (body_bytes.clone(), headers.clone());
};

let body_bytes = if request.content.is_null() {
body_bytes.clone()
} else {
serde_json::to_vec(&request.content)
.map(Bytes::from)
.unwrap_or_else(|_| body_bytes.clone())
};
let mut headers = headers.clone();
Comment thread
coderabbitai[bot] marked this conversation as resolved.
for (name, value) in &request.headers {
let Ok(name) = HeaderName::from_bytes(name.as_bytes()) else {
continue;
};
let Some(value) = json_header_value(value) else {
continue;
};
headers.insert(name, value);
}
(body_bytes, headers)
}

/// Renders a JSON value as an HTTP header value.
///
/// JSON strings are used verbatim, without surrounding quotes; any other JSON value is rendered
/// as its serialized JSON text. Returns `None` when serialization fails or when the resulting
/// text is rejected by `HeaderValue::from_str`.
fn json_header_value(value: &Value) -> Option<HeaderValue> {
    match value {
        Value::String(text) => HeaderValue::from_str(text).ok(),
        other => {
            let rendered = serde_json::to_string(other).ok()?;
            HeaderValue::from_str(&rendered).ok()
        }
    }
}

// Builds the upstream URL for the ChatGPT backend. OpenAI API bases commonly include `/v1`, while
// the ChatGPT backend base is
// `chatgpt.com/backend-api/codex` (no `/v1`). Both append `/responses` to their base, so the
Expand Down Expand Up @@ -718,6 +771,7 @@ async fn passthrough_streaming(
&prepared.upstream_url,
&prepared.body_bytes,
&prepared.headers,
None,
prepared.provider,
)
.await?;
Expand Down
4 changes: 4 additions & 0 deletions crates/cli/src/plugins/config_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::path::{Path, PathBuf};

use console::style;
use nemo_flow::plugin::{ConfigPolicy, PluginConfig, validate_plugin_config};
use nemo_flow_adaptive::plugin_component::register_adaptive_component;
use serde_json::{Map, Value};

use crate::config::{
Expand Down Expand Up @@ -115,6 +116,9 @@ pub(super) fn print_preview(config: &PluginConfig) -> Result<(), CliError> {
}

pub(super) fn validate_config(config: &PluginConfig) -> Result<(), CliError> {
register_adaptive_component().map_err(|error| {
CliError::Config(format!("adaptive plugin registration failed: {error}"))
})?;
let report = validate_plugin_config(config);
if report.has_errors() {
let messages = report
Expand Down
4 changes: 4 additions & 0 deletions crates/cli/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use axum::http::HeaderMap;
use axum::routing::{get, post};
use axum::{Json, Router};
use nemo_flow::plugin::{PluginConfig, clear_plugin_configuration, initialize_plugins};
use nemo_flow_adaptive::plugin_component::register_adaptive_component;
use reqwest::Client;
use serde_json::Value;
use tokio::net::TcpListener;
Expand Down Expand Up @@ -152,6 +153,9 @@ impl PluginActivation {
let Some(config) = config else {
return Ok(Self { active: false });
};
register_adaptive_component().map_err(|error| {
CliError::Config(format!("adaptive plugin registration failed: {error}"))
})?;
let plugin_config: PluginConfig = serde_json::from_value(config)
.map_err(|error| CliError::Config(format!("invalid plugin config: {error}")))?;
initialize_plugins(plugin_config)
Expand Down
101 changes: 100 additions & 1 deletion crates/cli/tests/coverage/gateway_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::server::AppState;
use crate::session::SessionManager;
use axum::body::Body;
use axum::extract::State;
use axum::http::{HeaderMap, HeaderValue, Method, Request, StatusCode};
use axum::http::{HeaderMap, HeaderValue, Method, Request, StatusCode, header};
use http_body_util::BodyExt;
use reqwest::Client;

Expand Down Expand Up @@ -140,6 +140,105 @@ fn openai_upstream_url_accepts_origin_or_v1_base() {
);
}

/// A runtime request with non-null content and extra headers replaces the forwarded body and
/// layers its headers on top of the inbound ones.
#[test]
fn effective_upstream_request_overlays_runtime_body_and_headers() {
    let runtime = LlmRequest {
        headers: Map::from_iter([
            ("x-runtime".to_string(), json!("enabled")),
            ("x-runtime-json".to_string(), json!({ "enabled": true })),
        ]),
        content: json!({
            "model": "rewritten",
            "nvext": { "agent_hints": { "priority": 1 } }
        }),
    };
    let inbound_body = Bytes::from_static(br#"{"model":"original"}"#);
    let mut inbound_headers = HeaderMap::new();
    inbound_headers.insert(
        header::AUTHORIZATION,
        HeaderValue::from_static("Bearer original"),
    );

    let (sent_body, sent_headers) =
        effective_upstream_request(&inbound_body, &inbound_headers, Some(&runtime));

    // The forwarded body is the runtime content, not the inbound bytes.
    let sent_json: Value = serde_json::from_slice(&sent_body).unwrap();
    assert_eq!(sent_json["model"], json!("rewritten"));
    assert_eq!(sent_json["nvext"]["agent_hints"]["priority"], json!(1));

    // Inbound headers survive; runtime headers are added alongside them.
    assert_eq!(
        sent_headers.get(header::AUTHORIZATION).unwrap(),
        "Bearer original"
    );
    assert_eq!(sent_headers.get("x-runtime").unwrap(), "enabled");
    // Non-string JSON values are forwarded as their JSON text.
    assert_eq!(
        sent_headers.get("x-runtime-json").unwrap(),
        r#"{"enabled":true}"#
    );
}

/// Without a runtime request, the inbound body and headers are passed through untouched.
#[test]
fn effective_upstream_request_returns_original_without_runtime_request() {
    let inbound_body = Bytes::from_static(br#"{"model":"original"}"#);
    let mut inbound_headers = HeaderMap::new();
    inbound_headers.insert("x-request-id", HeaderValue::from_static("request-1"));
    inbound_headers.insert(
        header::AUTHORIZATION,
        HeaderValue::from_static("Bearer original"),
    );

    let (sent_body, sent_headers) =
        effective_upstream_request(&inbound_body, &inbound_headers, None);

    assert_eq!(sent_body, inbound_body);
    assert_eq!(
        sent_headers.get(header::AUTHORIZATION).unwrap(),
        "Bearer original"
    );
    assert_eq!(sent_headers.get("x-request-id").unwrap(), "request-1");
}

/// A runtime request whose content is JSON `null` keeps the inbound body byte-for-byte (even when
/// that body is not JSON) while still applying the runtime headers.
#[test]
fn effective_upstream_request_preserves_original_body_for_null_runtime_content() {
    let runtime = LlmRequest {
        headers: Map::from_iter([("x-runtime".to_string(), json!("enabled"))]),
        content: Value::Null,
    };
    let inbound_body = Bytes::from_static(b"not-json-but-still-upstream-body");
    let mut inbound_headers = HeaderMap::new();
    inbound_headers.insert("x-original", HeaderValue::from_static("kept"));

    let (sent_body, sent_headers) =
        effective_upstream_request(&inbound_body, &inbound_headers, Some(&runtime));

    assert_eq!(sent_body, inbound_body);
    assert_eq!(sent_headers.get("x-original").unwrap(), "kept");
    assert_eq!(sent_headers.get("x-runtime").unwrap(), "enabled");
}

/// Runtime headers with an invalid name or an invalid value are dropped individually; valid
/// runtime headers and the rewritten body still go through.
#[test]
fn effective_upstream_request_skips_invalid_runtime_headers() {
    let runtime = LlmRequest {
        headers: Map::from_iter([
            // Space is not allowed in a header name.
            ("bad header".to_string(), json!("skip")),
            // A line break is not allowed in a header value.
            ("x-invalid-value".to_string(), json!("line\nbreak")),
            ("x-good".to_string(), json!("ok")),
        ]),
        content: json!({ "model": "rewritten" }),
    };
    let inbound_body = Bytes::from_static(br#"{"model":"original"}"#);
    let mut inbound_headers = HeaderMap::new();
    inbound_headers.insert("x-original", HeaderValue::from_static("kept"));

    let (sent_body, sent_headers) =
        effective_upstream_request(&inbound_body, &inbound_headers, Some(&runtime));

    let sent_json: Value = serde_json::from_slice(&sent_body).unwrap();
    assert_eq!(sent_json["model"], json!("rewritten"));

    assert_eq!(sent_headers.get("x-original").unwrap(), "kept");
    assert_eq!(sent_headers.get("x-good").unwrap(), "ok");
    assert!(!sent_headers.contains_key("x-invalid-value"));
    assert!(!sent_headers.keys().any(|name| name.as_str() == "bad header"));
}

#[test]
fn gateway_session_id_prefers_headers_and_has_fallbacks() {
let mut headers = HeaderMap::new();
Expand Down
Loading
Loading