Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

129 changes: 128 additions & 1 deletion crates/adaptive/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl Default for AdaptiveConfig {
}

/// Shared state configuration consumed by adaptive features that need persistence.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StateConfig {
/// Backend selection for adaptive state.
pub backend: BackendSpec,
Expand All @@ -68,6 +68,12 @@ pub struct BackendSpec {
pub config: Map<String, Json>,
}

impl Default for BackendSpec {
fn default() -> Self {
Self::in_memory()
}
}

impl BackendSpec {
/// Creates an in-memory backend spec.
pub fn in_memory() -> Self {
Expand Down Expand Up @@ -209,6 +215,127 @@ fn default_acg_priority() -> i32 {
50
}

// Editor schema declaration for `AdaptiveConfig`, expanded by `nemo_flow::editor_config!`.
// Each entry is keyed by a field name and carries a display label plus an editor kind.
// `Section` entries name another type through `nested`; those types have their own
// `editor_config!` declarations in this file, except `ConfigPolicy`, whose declaration is
// not part of the portion of the file shown here.
// The entry order matters: the unit test
// `test_adaptive_editor_schema_covers_canonical_options` asserts the field names in
// exactly this sequence.
// NOTE(review): the precise meaning of `optional` and `default` is defined by the macro,
// which is not visible here — confirm against `nemo_flow::editor_config!`.
nemo_flow::editor_config! {
impl AdaptiveConfig {
agent_id => { label: "agent_id", kind: String, optional: true },
state => {
label: "state",
kind: Section,
optional: true,
nested: StateConfig,
default: StateConfig,
},
telemetry => {
label: "telemetry",
kind: Section,
optional: true,
nested: TelemetryComponentConfig,
default: TelemetryComponentConfig,
},
adaptive_hints => {
label: "adaptive_hints",
kind: Section,
optional: true,
nested: AdaptiveHintsComponentConfig,
default: AdaptiveHintsComponentConfig,
},
tool_parallelism => {
label: "tool_parallelism",
kind: Section,
optional: true,
nested: ToolParallelismComponentConfig,
default: ToolParallelismComponentConfig,
},
acg => {
label: "acg",
kind: Section,
optional: true,
nested: AcgComponentConfig,
default: AcgComponentConfig,
},
// Unlike the sections above, `policy` is declared without `optional: true`.
policy => {
label: "policy",
kind: Section,
nested: ConfigPolicy,
default: ConfigPolicy,
},
}
}

// Editor schema declaration for `StateConfig`: a single `backend` section whose nested
// schema is the one declared for `BackendSpec`.
nemo_flow::editor_config! {
impl StateConfig {
backend => {
label: "backend",
kind: Section,
nested: BackendSpec,
default: BackendSpec,
},
}
}

// Editor schema declaration for `BackendSpec`.
// `kind` is an enum-kind field listing `in_memory` and `redis`; `config` is edited as JSON.
nemo_flow::editor_config! {
impl BackendSpec {
kind => { label: "kind", kind: Enum, values: ["in_memory", "redis"] },
config => { label: "config", kind: Json },
}
}

// Editor schema declaration for `TelemetryComponentConfig`: an optional string
// `subscriber_name` and a `learners` field edited as JSON.
nemo_flow::editor_config! {
impl TelemetryComponentConfig {
subscriber_name => { label: "subscriber_name", kind: String, optional: true },
learners => { label: "learners", kind: Json },
}
}

// Editor schema declaration for `AdaptiveHintsComponentConfig`: an integer priority,
// two boolean switches, and a string `inject_body_path`.
nemo_flow::editor_config! {
impl AdaptiveHintsComponentConfig {
priority => { label: "priority", kind: Integer },
break_chain => { label: "break_chain", kind: Boolean },
inject_header => { label: "inject_header", kind: Boolean },
inject_body_path => { label: "inject_body_path", kind: String },
}
}

// Editor schema declaration for `ToolParallelismComponentConfig`: an integer priority and
// an enum-kind `mode` listing `observe_only`, `inject_hints`, and `schedule`.
nemo_flow::editor_config! {
impl ToolParallelismComponentConfig {
priority => { label: "priority", kind: Integer },
mode => {
label: "mode",
kind: Enum,
values: ["observe_only", "inject_hints", "schedule"],
},
}
}

// Editor schema declaration for `AcgComponentConfig`.
// `provider` is an enum-kind field listing `passthrough`, `anthropic`, and `openai`;
// `stability_thresholds` is a section backed by the schema declared for
// `crate::acg::stability::StabilityThresholds`.
nemo_flow::editor_config! {
impl AcgComponentConfig {
provider => {
label: "provider",
kind: Enum,
values: ["passthrough", "anthropic", "openai"],
},
observation_window => { label: "observation_window", kind: Integer },
priority => { label: "priority", kind: Integer },
stability_thresholds => {
label: "stability_thresholds",
kind: Section,
nested: crate::acg::stability::StabilityThresholds,
default: crate::acg::stability::StabilityThresholds,
},
}
}

// Editor schema declaration for `crate::acg::stability::StabilityThresholds`: two float
// thresholds and an integer `min_observations_for_full_confidence`.
nemo_flow::editor_config! {
impl crate::acg::stability::StabilityThresholds {
stable_threshold => { label: "stable_threshold", kind: Float },
semi_stable_threshold => { label: "semi_stable_threshold", kind: Float },
min_observations_for_full_confidence => {
label: "min_observations_for_full_confidence",
kind: Integer,
},
}
}

// Test-only module; `#[path]` points it at a file under `tests/unit/` rather than a
// sibling of this source file.
#[cfg(test)]
#[path = "../tests/unit/config_tests.rs"]
mod tests;
48 changes: 48 additions & 0 deletions crates/adaptive/tests/unit/config_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
//! Unit tests for config in the NeMo Flow adaptive crate.

use super::*;
use nemo_flow::config_editor::{EditorConfig, EditorFieldKind};
use serde_json::json;

#[test]
Expand Down Expand Up @@ -71,3 +72,50 @@ fn test_component_configs_deserialize_with_default_helpers() {
assert_eq!(tool_parallelism.priority, 100);
assert_eq!(tool_parallelism.mode, "observe_only");
}

/// Checks that the editor schema for `AdaptiveConfig` lists every top-level option in
/// order and that selected nested fields report the expected editor kinds.
#[test]
fn test_adaptive_editor_schema_covers_canonical_options() {
    let root = AdaptiveConfig::editor_schema();

    // Top-level field names, in declaration order.
    let top_level_names: Vec<_> = root.fields.iter().map(|entry| entry.name).collect();
    assert_eq!(
        top_level_names,
        vec![
            "agent_id",
            "state",
            "telemetry",
            "adaptive_hints",
            "tool_parallelism",
            "acg",
            "policy",
        ]
    );

    // state -> backend -> { kind, config }
    let state_schema = root.field("state").unwrap().schema().unwrap();
    let backend_schema = state_schema.field("backend").unwrap().schema().unwrap();
    let backend_kind = backend_schema.field("kind").unwrap();
    assert_eq!(backend_kind.kind, EditorFieldKind::Enum);
    let backend_config = backend_schema.field("config").unwrap();
    assert_eq!(backend_config.kind, EditorFieldKind::Json);

    // telemetry -> learners
    let telemetry_schema = root.field("telemetry").unwrap().schema().unwrap();
    let learners = telemetry_schema.field("learners").unwrap();
    assert_eq!(learners.kind, EditorFieldKind::Json);

    // acg -> stability_thresholds -> { stable_threshold, min_observations_for_full_confidence }
    let acg_schema = root.field("acg").unwrap().schema().unwrap();
    let thresholds_schema = acg_schema
        .field("stability_thresholds")
        .unwrap()
        .schema()
        .unwrap();
    let stable_threshold = thresholds_schema.field("stable_threshold").unwrap();
    assert_eq!(stable_threshold.kind, EditorFieldKind::Float);
    let min_observations = thresholds_schema
        .field("min_observations_for_full_confidence")
        .unwrap();
    assert_eq!(min_observations.kind, EditorFieldKind::Integer);
}
1 change: 1 addition & 0 deletions crates/cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ workspace = true

[dependencies]
nemo-flow = { workspace = true, features = ["openinference"] }
nemo-flow-adaptive = { workspace = true, features = ["redis-backend"] }
async-stream = "0.3"
axum = "0.8"
bytes = "1"
Expand Down
106 changes: 80 additions & 26 deletions crates/cli/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex};
use async_stream::stream;
use axum::body::{Body, Bytes};
use axum::extract::State;
use axum::http::{HeaderMap, HeaderName, Method, Request, Response, StatusCode};
use axum::http::{HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode};
use futures_util::StreamExt;
use nemo_flow::api::llm::{
LlmCallExecuteParams, LlmRequest, LlmStreamCallExecuteParams, llm_call_execute,
Expand Down Expand Up @@ -297,7 +297,7 @@ fn build_buffered_func(
let body_bytes = prepared.body_bytes.clone();
let headers = prepared.headers.clone();
let route = prepared.provider;
Arc::new(move |_request| {
Arc::new(move |request| {
let http = http.clone();
let method = method.clone();
let url = url.clone();
Expand All @@ -307,17 +307,24 @@ fn build_buffered_func(
let upstream_error = upstream_error.clone();
let response_bytes = response_bytes.clone();
Box::pin(async move {
let response =
match forward_upstream_request(&http, &method, &url, &body_bytes, &headers, route)
.await
{
Ok(response) => response,
Err(error) => {
let message = error.to_string();
*upstream_error.lock().expect("upstream error lock poisoned") = Some(error);
return Err(FlowError::Internal(message));
}
};
let response = match forward_upstream_request(
&http,
&method,
&url,
&body_bytes,
&headers,
Some(&request),
route,
)
.await
{
Ok(response) => response,
Err(error) => {
let message = error.to_string();
*upstream_error.lock().expect("upstream error lock poisoned") = Some(error);
return Err(FlowError::Internal(message));
}
};
let status = response.status();
let response_headers = response_headers(response.headers());
let bytes = match response.bytes().await {
Expand Down Expand Up @@ -431,7 +438,7 @@ fn build_streaming_func(
let body_bytes = prepared.body_bytes.clone();
let headers = prepared.headers.clone();
let route = prepared.provider;
Arc::new(move |_request| {
Arc::new(move |request| {
let http = http.clone();
let method = method.clone();
let url = url.clone();
Expand All @@ -440,17 +447,24 @@ fn build_streaming_func(
let upstream_info = upstream_info.clone();
let upstream_error = upstream_error.clone();
Box::pin(async move {
let response =
match forward_upstream_request(&http, &method, &url, &body_bytes, &headers, route)
.await
{
Ok(response) => response,
Err(error) => {
let message = error.to_string();
*upstream_error.lock().expect("upstream error lock poisoned") = Some(error);
return Err(FlowError::Internal(message));
}
};
let response = match forward_upstream_request(
&http,
&method,
&url,
&body_bytes,
&headers,
Some(&request),
route,
)
.await
{
Ok(response) => response,
Err(error) => {
let message = error.to_string();
*upstream_error.lock().expect("upstream error lock poisoned") = Some(error);
return Err(FlowError::Internal(message));
}
};
let status = response.status();
let response_headers = response_headers(response.headers());
*upstream_info.lock().expect("upstream info lock poisoned") =
Expand Down Expand Up @@ -554,8 +568,10 @@ async fn forward_upstream_request(
url: &str,
body_bytes: &Bytes,
headers: &HeaderMap,
effective_request: Option<&LlmRequest>,
route: ProviderRoute,
) -> Result<reqwest::Response, reqwest::Error> {
let (body_bytes, headers) = effective_upstream_request(body_bytes, headers, effective_request);
// Only strip the inbound JWT when we actually have a replacement key to inject. Without one
// the upstream just receives no auth and 401s, which is no better than letting it reject the
// JWT itself — and stripping silently can break setups that point the gateway at an upstream
Expand All @@ -566,7 +582,7 @@ async fn forward_upstream_request(
.ok()
.filter(|v| !v.trim().is_empty())
.is_some();
let sanitized = strip_chatgpt_oauth_for_openai_route(headers, route, has_openai_env);
let sanitized = strip_chatgpt_oauth_for_openai_route(&headers, route, has_openai_env);
let mut upstream = http.request(method.clone(), url).body(body_bytes.clone());
for (name, value) in &sanitized {
if should_forward_request_header(name) {
Expand All @@ -577,6 +593,43 @@ async fn forward_upstream_request(
upstream.send().await
}

/// Chooses the body and headers that are actually sent upstream.
///
/// With no effective request, the inbound body and headers are returned as clones.
/// Otherwise:
/// - a non-null `request.content` is serialized to JSON and used as the body, falling back
///   to the inbound bytes if serialization fails; a null `content` keeps the inbound bytes;
/// - each entry of `request.headers` that yields a valid header name and value is inserted
///   into a copy of the inbound headers; entries that do not convert are skipped.
fn effective_upstream_request(
    body_bytes: &Bytes,
    headers: &HeaderMap,
    effective_request: Option<&LlmRequest>,
) -> (Bytes, HeaderMap) {
    let request = match effective_request {
        Some(request) => request,
        None => return (body_bytes.clone(), headers.clone()),
    };

    let mut merged_headers = headers.clone();
    for (raw_name, raw_value) in &request.headers {
        let converted = HeaderName::from_bytes(raw_name.as_bytes())
            .ok()
            .and_then(|name| json_header_value(raw_value).map(|value| (name, value)));
        if let Some((name, value)) = converted {
            merged_headers.insert(name, value);
        }
    }

    let upstream_body = if request.content.is_null() {
        body_bytes.clone()
    } else {
        match serde_json::to_vec(&request.content) {
            Ok(encoded) => Bytes::from(encoded),
            Err(_) => body_bytes.clone(),
        }
    };

    (upstream_body, merged_headers)
}

/// Converts a JSON value into an HTTP header value.
///
/// A JSON string is used verbatim, without surrounding quotes; any other JSON value is
/// rendered with `serde_json::to_string`. Returns `None` if rendering fails or if
/// `HeaderValue::from_str` rejects the resulting text.
fn json_header_value(value: &Value) -> Option<HeaderValue> {
    match value {
        // Borrow the string as-is: `from_str` only needs a `&str`, so no clone is required.
        Value::String(text) => HeaderValue::from_str(text).ok(),
        other => {
            let rendered = serde_json::to_string(other).ok()?;
            HeaderValue::from_str(&rendered).ok()
        }
    }
}

// Builds the upstream URL for the ChatGPT backend. OpenAI API bases commonly include `/v1`, while
// the ChatGPT backend base is
// `chatgpt.com/backend-api/codex` (no `/v1`). Both append `/responses` to their base, so the
Expand Down Expand Up @@ -718,6 +771,7 @@ async fn passthrough_streaming(
&prepared.upstream_url,
&prepared.body_bytes,
&prepared.headers,
None,
prepared.provider,
)
.await?;
Expand Down
Loading
Loading