diff --git a/crates/braintrust-llm-router/examples/custom_auth.rs b/crates/braintrust-llm-router/examples/custom_auth.rs index c45466fa..db19c454 100644 --- a/crates/braintrust-llm-router/examples/custom_auth.rs +++ b/crates/braintrust-llm-router/examples/custom_auth.rs @@ -143,7 +143,7 @@ async fn main() -> Result<()> { ProviderFormat::ChatCompletions, ))?; let (request, _metadata) = router - .create_request(body, ProviderFormat::ChatCompletions, route) + .create_request(body, ProviderFormat::ChatCompletions, route, false) .await?; let bytes = router.complete(request, &ClientHeaders::default()).await?; let response: Value = serde_json::from_slice(&bytes)?; diff --git a/crates/braintrust-llm-router/examples/multi_provider.rs b/crates/braintrust-llm-router/examples/multi_provider.rs index 4134ed38..62002ccf 100644 --- a/crates/braintrust-llm-router/examples/multi_provider.rs +++ b/crates/braintrust-llm-router/examples/multi_provider.rs @@ -85,7 +85,7 @@ async fn main() -> Result<()> { match router.resolve_provider_routes(model, ProviderFormat::ChatCompletions, &[]) { Ok(routes) => match routes.first() { Some(route) => match router - .create_request(body, ProviderFormat::ChatCompletions, route) + .create_request(body, ProviderFormat::ChatCompletions, route, false) .await { Ok((request, _metadata)) => { @@ -122,7 +122,7 @@ async fn main() -> Result<()> { match router.resolve_provider_routes(model, ProviderFormat::ChatCompletions, &[]) { Ok(routes) => match routes.first() { Some(route) => match router - .create_request(body, ProviderFormat::ChatCompletions, route) + .create_request(body, ProviderFormat::ChatCompletions, route, false) .await { Ok((request, _metadata)) => { diff --git a/crates/braintrust-llm-router/examples/simple.rs b/crates/braintrust-llm-router/examples/simple.rs index 0a52c9e2..41ef2868 100644 --- a/crates/braintrust-llm-router/examples/simple.rs +++ b/crates/braintrust-llm-router/examples/simple.rs @@ -68,7 +68,7 @@ async fn main() -> Result<()> { ProviderFormat::ChatCompletions, ))?; let (request, _metadata) = router - .create_request(body, ProviderFormat::ChatCompletions, route) + .create_request(body, ProviderFormat::ChatCompletions, route, false) .await?; let bytes = router.complete(request, &ClientHeaders::default()).await?; let response: Value = serde_json::from_slice(&bytes)?; diff --git a/crates/braintrust-llm-router/examples/streaming.rs b/crates/braintrust-llm-router/examples/streaming.rs index b4c02a14..b7eac7a2 100644 --- a/crates/braintrust-llm-router/examples/streaming.rs +++ b/crates/braintrust-llm-router/examples/streaming.rs @@ -67,7 +67,7 @@ async fn main() -> Result<()> { ProviderFormat::ChatCompletions, ))?; let (request, _metadata) = router - .create_stream_request(body, ProviderFormat::ChatCompletions, route) + .create_stream_request(body, ProviderFormat::ChatCompletions, route, false) .await?; let mut stream = router .complete_stream(request, &ClientHeaders::default(), None) @@ -150,7 +150,7 @@ async fn main() -> Result<()> { ProviderFormat::ChatCompletions, ))?; let (request, _metadata) = router - .create_stream_request(body, ProviderFormat::ChatCompletions, route) + .create_stream_request(body, ProviderFormat::ChatCompletions, route, false) .await?; let stream = router .complete_stream(request, &ClientHeaders::default(), None) diff --git a/crates/braintrust-llm-router/src/catalog/fallback.rs b/crates/braintrust-llm-router/src/catalog/fallback.rs new file mode 100644 index 00000000..16705f6f --- /dev/null +++ b/crates/braintrust-llm-router/src/catalog/fallback.rs @@ -0,0 +1,320 @@ +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +use super::{ModelCatalog, ModelSpec}; +use crate::error::{Error, Result}; + +/// A request-local catalog overlay. +/// +/// Secret-defined custom models live in `custom` and shadow entries in the +/// shared `base` catalog. This avoids cloning the base catalog when adding +/// per-request model definitions. +#[derive(Debug, Clone)] +pub struct OverlayModelCatalog { + base: Arc, + custom: ModelCatalog, + custom_model_names: HashSet, + overlay_edges: HashMap>, +} + +impl OverlayModelCatalog { + pub fn new(base: Arc, custom: ModelCatalog) -> Self { + let custom_model_names: HashSet = custom.models.keys().cloned().collect(); + let mut overlay_edges: HashMap> = HashMap::new(); + for (name, fallbacks) in &custom.fallback_models { + if !custom.models.contains_key(name) { + continue; + } + for fallback_model in fallbacks { + let fallback_is_visible = custom.models.contains_key(fallback_model) + || (base.models.contains_key(fallback_model) + && !custom_model_names.contains(fallback_model)); + if !fallback_is_visible { + continue; + } + overlay_edges + .entry(name.clone()) + .or_default() + .push(fallback_model.clone()); + overlay_edges + .entry(fallback_model.clone()) + .or_default() + .push(name.clone()); + } + } + Self { + base, + custom, + custom_model_names, + overlay_edges, + } + } + + pub fn base_catalog(&self) -> Arc { + Arc::clone(&self.base) + } + + pub fn get(&self, name: &str) -> Option> { + self.custom.get(name).or_else(|| self.base.get(name)) + } + + pub fn find_fallback_models(&self, name: &str) -> Vec { + let Some(_) = self.get(name) else { + return Vec::new(); + }; + + let mut visited = HashSet::new(); + let mut stack = vec![name.to_string()]; + while let Some(current) = stack.pop() { + if !visited.insert(current.clone()) { + continue; + } + + if !self.custom_model_names.contains(¤t) { + stack.extend( + self.base + .fallback_models(¤t) + .into_iter() + .filter(|model_name| !self.custom_model_names.contains(model_name)), + ); + } + if let Some(neighbors) = self.overlay_edges.get(¤t) { + stack.extend(neighbors.iter().cloned()); + } + } + + let mut names = vec![name.to_string()]; + visited.remove(name); + let mut equivalent_names: Vec = visited.into_iter().collect(); + equivalent_names.sort(); + names.extend(equivalent_names); + names + } +} + +enum FallbackModelSource<'a> { + Json(&'a str), + Parsed(HashMap>), +} + +impl ModelCatalog { + pub fn fallback_models(&self, name: &str) -> Vec { + let Some(_) = self.models.get(name) else { + return Vec::new(); + }; + + let mut names = vec![name.to_string()]; + if let Some(equivalent_names) = self.equivalence_index.get(name) { + names.extend(equivalent_names.iter().cloned()); + } + names + } + + pub fn add_fallback_models(&mut self, name: String, fallback_models: I) -> Result<()> + where + I: IntoIterator, + { + if !self.models.contains_key(&name) { + return Err(Error::InvalidRequest(format!( + "model '{name}' references fallback_models but is missing from catalog" + ))); + } + + let fallback_models: Vec = fallback_models + .into_iter() + .filter(|fallback_model| !fallback_model.is_empty()) + .collect(); + for fallback_model in &fallback_models { + if !self.models.contains_key(fallback_model) { + return Err(Error::InvalidRequest(format!( + "model '{name}' references missing fallback model '{fallback_model}'" + ))); + } + } + + let mut next_fallback_models = self.fallback_models.clone(); + let entry = next_fallback_models.entry(name).or_default(); + for fallback_model in fallback_models { + if entry.contains(&fallback_model) { + continue; + } + entry.push(fallback_model); + } + self.set_fallback_models(FallbackModelSource::Parsed(next_fallback_models), false)?; + Ok(()) + } + + pub fn add_external_fallback_models( + &mut self, + name: String, + fallback_models: I, + ) -> Result<()> + where + I: IntoIterator, + { + if !self.models.contains_key(&name) { + return Err(Error::InvalidRequest(format!( + "model '{name}' references fallback_models but is missing from catalog" + ))); + } + + let mut next_fallback_models = self.fallback_models.clone(); + let entry = next_fallback_models.entry(name).or_default(); + for fallback_model in fallback_models { + if fallback_model.is_empty() || entry.contains(&fallback_model) { + continue; + } + entry.push(fallback_model); + } + self.set_fallback_models(FallbackModelSource::Parsed(next_fallback_models), false)?; + Ok(()) + } + + pub(super) fn set_fallback_models_from_json( + &mut self, + content: &str, + validate_targets: bool, + ) -> Result<()> { + self.set_fallback_models(FallbackModelSource::Json(content), validate_targets) + } + + pub(super) fn set_fallback_models_from_parsed( + &mut self, + fallback_models: HashMap>, + validate_targets: bool, + ) -> Result<()> { + self.set_fallback_models( + FallbackModelSource::Parsed(fallback_models), + validate_targets, + ) + } + + fn set_fallback_models( + &mut self, + source: FallbackModelSource<'_>, + validate_targets: bool, + ) -> Result<()> { + let fallback_models = match source { + FallbackModelSource::Json(content) => parse_fallback_models(content)?, + FallbackModelSource::Parsed(fallback_models) => fallback_models, + }; + + if validate_targets { + validate_fallback_models(&self.models, &fallback_models)?; + } + self.equivalence_index = + build_equivalence_index(self.models.keys().cloned().collect(), &fallback_models); + self.fallback_models = fallback_models; + Ok(()) + } +} + +fn parse_fallback_models(content: &str) -> Result>> { + let raw: HashMap = serde_json::from_str(content)?; + let mut fallback_models = HashMap::new(); + for (name, value) in raw { + let Some(fallbacks) = value.get("fallback_models") else { + continue; + }; + let Some(fallbacks) = fallbacks.as_array() else { + return Err(Error::InvalidRequest(format!( + "model '{name}' has invalid fallback_models" + ))); + }; + let mut parsed = Vec::with_capacity(fallbacks.len()); + for fallback_model in fallbacks { + let Some(fallback_model) = fallback_model.as_str() else { + return Err(Error::InvalidRequest(format!( + "model '{name}' has invalid fallback_models" + ))); + }; + parsed.push(fallback_model.to_string()); + } + if !parsed.is_empty() { + fallback_models.insert(name, parsed); + } + } + Ok(fallback_models) +} + +fn validate_fallback_models( + models: &HashMap>, + fallback_models: &HashMap>, +) -> Result<()> { + for (name, fallback_models) in fallback_models { + for fallback_model in fallback_models { + if !models.contains_key(fallback_model) { + return Err(Error::InvalidRequest(format!( + "model '{name}' references missing fallback model '{fallback_model}'" + ))); + } + } + } + Ok(()) +} + +fn build_equivalence_index( + model_names: HashSet, + fallback_models: &HashMap>, +) -> HashMap> { + let mut adjacency: HashMap> = HashMap::new(); + for name in &model_names { + adjacency.entry(name.clone()).or_default(); + } + + for (name, fallbacks) in fallback_models { + if !model_names.contains(name) { + continue; + } + for fallback_model in fallbacks { + if !model_names.contains(fallback_model) { + continue; + } + adjacency + .entry(name.clone()) + .or_default() + .push(fallback_model.clone()); + adjacency + .entry(fallback_model.clone()) + .or_default() + .push(name.clone()); + } + } + + let mut visited = HashSet::new(); + let mut index = HashMap::new(); + for name in model_names { + if visited.contains(&name) { + continue; + } + + let mut stack = vec![name.clone()]; + let mut component = Vec::new(); + while let Some(current) = stack.pop() { + if !visited.insert(current.clone()) { + continue; + } + component.push(current.clone()); + if let Some(neighbors) = adjacency.get(¤t) { + stack.extend(neighbors.iter().cloned()); + } + } + + if component.len() <= 1 { + continue; + } + component.sort(); + for member in &component { + index.insert( + member.clone(), + component + .iter() + .filter(|other| *other != member) + .cloned() + .collect(), + ); + } + } + + index +} diff --git a/crates/braintrust-llm-router/src/catalog/mod.rs b/crates/braintrust-llm-router/src/catalog/mod.rs index 68f20fc9..be6be2f2 100644 --- a/crates/braintrust-llm-router/src/catalog/mod.rs +++ b/crates/braintrust-llm-router/src/catalog/mod.rs @@ -1,6 +1,8 @@ +mod fallback; mod resolver; pub mod spec; +pub use fallback::OverlayModelCatalog; pub(crate) use resolver::is_gemini_api_model; pub use resolver::ModelResolver; pub use spec::{ModelFlavor, ModelSpec}; @@ -20,17 +22,8 @@ pub struct ModelCatalog { models: HashMap>, by_format: HashMap>, by_parent: HashMap>, -} - -/// A request-local catalog overlay. -/// -/// Secret-defined custom models live in `custom` and shadow entries in the -/// shared `base` catalog. This avoids cloning the base catalog when adding -/// per-request model definitions. -#[derive(Debug, Clone)] -pub struct OverlayModelCatalog { - pub base: Arc, - pub custom: ModelCatalog, + fallback_models: HashMap>, + equivalence_index: HashMap>, } /// Catalog view used by the router resolver. @@ -40,21 +33,28 @@ pub struct OverlayModelCatalog { #[derive(Debug, Clone)] pub enum CatalogResolver { Base(Arc), - Overlay(OverlayModelCatalog), + Overlay(Box), } impl CatalogResolver { pub fn base_catalog(&self) -> Arc { match self { Self::Base(catalog) => Arc::clone(catalog), - Self::Overlay(overlay) => Arc::clone(&overlay.base), + Self::Overlay(overlay) => overlay.base_catalog(), } } pub fn get(&self, name: &str) -> Option> { match self { Self::Base(catalog) => catalog.get(name), - Self::Overlay(overlay) => overlay.custom.get(name).or_else(|| overlay.base.get(name)), + Self::Overlay(overlay) => overlay.get(name), + } + } + + pub fn fallback_models(&self, name: &str) -> Vec { + match self { + Self::Base(catalog) => catalog.fallback_models(name), + Self::Overlay(overlay) => overlay.find_fallback_models(name), } } } @@ -76,6 +76,7 @@ impl ModelCatalog { for (name, spec) in raw { catalog.insert(name, spec); } + catalog.set_fallback_models_from_json(content, true)?; Ok(catalog) } @@ -142,6 +143,19 @@ impl ModelCatalog { self.models.iter() } + pub fn map_specs(&self, mut f: F) -> Self + where + F: FnMut(&str, &ModelSpec) -> ModelSpec, + { + let mut out = Self::empty(); + for (name, spec) in &self.models { + out.insert(name.clone(), f(name, spec.as_ref())); + } + out.set_fallback_models_from_parsed(self.fallback_models.clone(), false) + .expect("existing catalog fallback_models remain valid after mapping specs"); + out + } + pub fn len(&self) -> usize { self.models.len() } @@ -174,3 +188,304 @@ pub fn load_catalog_from_disk>(path: P) -> Result, ProviderFormat, Vec); + #[derive(Debug, Clone)] pub struct ModelResolver { catalog: CatalogResolver, @@ -19,13 +21,9 @@ impl ModelResolver { } } - /// Resolve models against custom entries first, then the shared base catalog. - /// - /// The base catalog remains the public `catalog()` view so existing callers - /// do not observe per-request custom models as global catalog entries. pub fn with_overlay(base: Arc, custom: ModelCatalog) -> Self { Self { - catalog: CatalogResolver::Overlay(OverlayModelCatalog { base, custom }), + catalog: CatalogResolver::Overlay(Box::new(OverlayModelCatalog::new(base, custom))), aliases: HashMap::new(), } } @@ -39,7 +37,22 @@ impl ModelResolver { self.catalog.base_catalog() } - pub fn resolve(&self, model: &str) -> Result<(Arc, ProviderFormat, Vec)> { + pub fn resolve(&self, model: &str) -> Result { + self.resolve_one(model) + } + + pub fn resolve_all_equivalent_model_routes(&self, model: &str) -> Result> { + let mut resolved = Vec::new(); + for model_name in self.catalog.fallback_models(model) { + resolved.push(self.resolve_one(&model_name)?); + } + if resolved.is_empty() { + return Err(Error::UnknownModel(model.to_string())); + } + Ok(resolved) + } + + fn resolve_one(&self, model: &str) -> Result { let spec = self .catalog .get(model) diff --git a/crates/braintrust-llm-router/src/providers/bedrock.rs b/crates/braintrust-llm-router/src/providers/bedrock.rs index ea7c5ae0..0aaddd1b 100644 --- a/crates/braintrust-llm-router/src/providers/bedrock.rs +++ b/crates/braintrust-llm-router/src/providers/bedrock.rs @@ -23,7 +23,7 @@ use crate::auth::AuthConfig; use crate::catalog::ModelSpec; use crate::client::{build_middleware_client, ClientSettings}; use crate::error::{Error, Result, UpstreamHttpError}; -use crate::providers::ClientHeaders; +use crate::providers::{rewrite_body_model_if_required, ClientHeaders}; use crate::streaming::{bedrock_event_stream, sse_stream, RawResponseStream}; use lingua::{ProviderFormat, TransformError}; @@ -91,7 +91,7 @@ where }; if source_adapter.format() == format { - return Ok(body); + return Ok(rewrite_body_model_if_required(body, format, &spec.model)); } let mut request = match source_adapter.request_to_universal(payload) { @@ -100,10 +100,7 @@ where }; inline_remote_image_urls_with_fetch(&mut request, fetch).await?; - - if request.model.is_none() { - request.model = Some(spec.model.clone()); - } + request.model = Some(spec.model.clone()); let target_adapter = adapter_for_format(format).ok_or(TransformError::UnsupportedTargetFormat(format))?; @@ -624,6 +621,11 @@ mod tests { let body = Bytes::from( lingua::serde_json::to_vec(&lingua::serde_json::json!({ "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + "system": [{"text": "You are helpful."}], + "guardrailConfig": { + "guardrailIdentifier": "test", + "guardrailVersion": "1" + }, "messages": [{ "role": "user", "content": [{"text": "Hello"}] @@ -651,6 +653,57 @@ mod tests { assert_eq!(prepared, body); } + #[tokio::test] + async fn prepare_request_preserves_same_format_converse_model_and_native_fields() { + let body = Bytes::from( + lingua::serde_json::to_vec(&lingua::serde_json::json!({ + "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + "system": [{"text": "You are helpful."}], + "guardrailConfig": { + "guardrailIdentifier": "test", + "guardrailVersion": "1" + }, + "messages": [{ + "role": "user", + "content": [{"text": "Hello"}] + }] + })) + .unwrap(), + ); + + let prepared = prepare_bedrock_request_with_fetch( + body, + &bedrock_spec( + "anthropic.claude-3-5-sonnet-20241022-v2:0", + ProviderFormat::Converse, + ), + ProviderFormat::Converse, + |_url| { + Box::pin(async { + panic!("fetch should not be called for same-format converse requests"); + }) + }, + ) + .await + .unwrap(); + + let value: lingua::serde_json::Value = lingua::serde_json::from_slice(&prepared).unwrap(); + assert_eq!( + value.get("modelId").and_then(|v| v.as_str()), + Some("anthropic.claude-3-haiku-20240307-v1:0") + ); + assert_eq!( + value.pointer("/system/0/text").and_then(|v| v.as_str()), + Some("You are helpful.") + ); + assert_eq!( + value + .pointer("/guardrailConfig/guardrailIdentifier") + .and_then(|v| v.as_str()), + Some("test") + ); + } + #[tokio::test] async fn prepare_request_repairs_lone_surrogate_for_same_format_converse() { let body = Bytes::from_static( diff --git a/crates/braintrust-llm-router/src/providers/body_model.rs b/crates/braintrust-llm-router/src/providers/body_model.rs new file mode 100644 index 00000000..1af9a304 --- /dev/null +++ b/crates/braintrust-llm-router/src/providers/body_model.rs @@ -0,0 +1,143 @@ +use bytes::Bytes; +use lingua::serde_json::{self, Value}; +use lingua::ProviderFormat; + +fn body_model_field(format: ProviderFormat) -> Option<&'static str> { + match format { + ProviderFormat::ChatCompletions + | ProviderFormat::Responses + | ProviderFormat::Anthropic + | ProviderFormat::Mistral => Some("model"), + ProviderFormat::Google + | ProviderFormat::Converse + | ProviderFormat::BedrockAnthropic + | ProviderFormat::VertexAnthropic + | ProviderFormat::Unknown => None, + } +} + +enum BodyModelRewrite { + Required, + NotRequired, + Unknown, +} + +#[derive(serde::Deserialize)] +struct BodyModel { + model: Option, +} + +fn body_model_rewrite_status( + payload: &[u8], + format: ProviderFormat, + model: &str, +) -> BodyModelRewrite { + match body_model_field(format) { + Some("model") => match serde_json::from_slice::(payload) { + Ok(parsed) => { + if parsed.model.as_deref() == Some(model) { + BodyModelRewrite::NotRequired + } else { + BodyModelRewrite::Required + } + } + Err(_) => BodyModelRewrite::Unknown, + }, + Some(_) | None => BodyModelRewrite::Unknown, + } +} + +pub(crate) fn rewrite_body_model_if_required( + payload: Bytes, + format: ProviderFormat, + model: &str, +) -> Bytes { + match body_model_rewrite_status(&payload, format, model) { + BodyModelRewrite::Required => {} + BodyModelRewrite::NotRequired | BodyModelRewrite::Unknown => return payload, + } + + let Some(model_field) = body_model_field(format) else { + return payload; + }; + let Ok(mut value) = serde_json::from_slice::(&payload) else { + return payload; + }; + let Some(object) = value.as_object_mut() else { + return payload; + }; + if object.get(model_field).and_then(Value::as_str) == Some(model) { + return payload; + } + object.insert(model_field.to_string(), Value::String(model.to_string())); + match serde_json::to_vec(&value) { + Ok(serialized) => Bytes::from(serialized), + Err(_) => payload, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rewrite_body_model_if_required_leaves_matching_model_bytes_unchanged() { + let payload = Bytes::from_static(br#"{"model":"gpt-4o","messages":[]}"#); + let original_ptr = payload.as_ptr(); + + let updated = + rewrite_body_model_if_required(payload, ProviderFormat::ChatCompletions, "gpt-4o"); + + assert_eq!(updated.as_ptr(), original_ptr); + } + + #[test] + fn rewrite_body_model_if_required_rewrites_mismatched_model_field() { + let payload = Bytes::from_static(br#"{"model":"gpt-4","messages":[]}"#); + + let updated = + rewrite_body_model_if_required(payload, ProviderFormat::ChatCompletions, "gpt-4o"); + let value: Value = serde_json::from_slice(&updated).unwrap(); + + assert_eq!(value.get("model").and_then(Value::as_str), Some("gpt-4o")); + } + + #[test] + fn rewrite_body_model_if_required_leaves_converse_payload_unchanged() { + let payload = Bytes::from_static( + br#"{"modelId":"model-a","messages":[{"role":"user","content":[]}]}"#, + ); + let original_ptr = payload.as_ptr(); + + let updated = rewrite_body_model_if_required(payload, ProviderFormat::Converse, "model-b"); + + assert_eq!(updated.as_ptr(), original_ptr); + } + + #[test] + fn rewrite_body_model_if_required_leaves_google_payload_unchanged() { + let payload = Bytes::from_static( + br#"{"model":"gemini-2.5-flash","contents":[{"role":"user","parts":[{"text":"hi"}]}]}"#, + ); + let original_ptr = payload.as_ptr(); + + let updated = rewrite_body_model_if_required( + payload, + ProviderFormat::Google, + "models/gemini-2.5-pro", + ); + + assert_eq!(updated.as_ptr(), original_ptr); + } + + #[test] + fn rewrite_body_model_if_required_leaves_unknown_payload_unchanged() { + let payload = Bytes::from_static(b"not-json"); + let original_ptr = payload.as_ptr(); + + let updated = + rewrite_body_model_if_required(payload, ProviderFormat::ChatCompletions, "gpt-4o"); + + assert_eq!(updated.as_ptr(), original_ptr); + } +} diff --git a/crates/braintrust-llm-router/src/providers/mod.rs b/crates/braintrust-llm-router/src/providers/mod.rs index f9f84872..6f355ecc 100644 --- a/crates/braintrust-llm-router/src/providers/mod.rs +++ b/crates/braintrust-llm-router/src/providers/mod.rs @@ -1,6 +1,7 @@ pub(crate) mod anthropic; mod azure; mod bedrock; +mod body_model; mod databricks; mod google; mod mistral; @@ -11,6 +12,7 @@ pub use anthropic::{AnthropicConfig, AnthropicProvider}; pub use azure::{AzureConfig, AzureProvider}; pub(crate) use bedrock::{prepare_bedrock_request, requires_bedrock_request_preparation}; pub use bedrock::{BedrockConfig, BedrockProvider}; +pub(crate) use body_model::rewrite_body_model_if_required; pub use databricks::{DatabricksConfig, DatabricksProvider}; pub use google::{GoogleConfig, GoogleProvider}; pub use mistral::{MistralConfig, MistralProvider}; @@ -224,6 +226,11 @@ pub trait Provider: Send + Sync { /// Provider identifier (e.g., "openai", "anthropic"). fn id(&self) -> &'static str; + /// Whether this provider registration satisfies a catalog provider alias. + fn matches_provider_alias(&self, alias: &str) -> bool { + self.id() == alias + } + /// All formats this provider can handle. fn provider_formats(&self) -> Vec; diff --git a/crates/braintrust-llm-router/src/providers/openai.rs b/crates/braintrust-llm-router/src/providers/openai.rs index 21c3fba2..0b62b6e2 100644 --- a/crates/braintrust-llm-router/src/providers/openai.rs +++ b/crates/braintrust-llm-router/src/providers/openai.rs @@ -43,6 +43,7 @@ pub struct OpenAIProvider { client: ClientWithMiddleware, config: OpenAIConfig, endpoint_template: Option, + provider_alias: String, } impl OpenAIProvider { @@ -63,9 +64,15 @@ impl OpenAIProvider { client, endpoint_template, config, + provider_alias: "openai".to_string(), }) } + pub fn with_provider_alias(mut self, provider_alias: impl Into) -> Self { + self.provider_alias = provider_alias.into(); + self + } + /// Create an OpenAI provider from configuration parameters. /// /// Extracts OpenAI-specific options from metadata: @@ -182,6 +189,10 @@ impl crate::providers::Provider for OpenAIProvider { "openai" } + fn matches_provider_alias(&self, alias: &str) -> bool { + self.provider_alias == alias + } + fn provider_formats(&self) -> Vec { vec![ProviderFormat::ChatCompletions, ProviderFormat::Responses] } diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 382e9f4f..a6dd5d46 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -16,7 +16,7 @@ use crate::client::ClientSettings; use crate::error::{Error, Result}; use crate::providers::{ enable_streaming_payload, prepare_bedrock_request, requires_bedrock_request_preparation, - ClientHeaders, Provider, + rewrite_body_model_if_required, ClientHeaders, Provider, }; use crate::retry::{RetryPolicy, RetryStrategy}; use crate::streaming::{ @@ -124,13 +124,16 @@ pub fn create_provider( timeout, client_settings, )?)), - kind if is_openai_compatible(kind) => Ok(Arc::new(OpenAIProvider::from_config( - endpoint, - endpoint_template, - timeout, - metadata, - client_settings, - )?)), + kind if is_openai_compatible(kind) => Ok(Arc::new( + OpenAIProvider::from_config( + endpoint, + endpoint_template, + timeout, + metadata, + client_settings, + )? + .with_provider_alias(kind.to_ascii_lowercase()), + )), other => Err(Error::InvalidRequest(format!( "unsupported provider kind: {other}" ))), @@ -205,29 +208,50 @@ struct PreparedRequestInner { strategy: RetryStrategy, } +#[derive(Clone, Copy)] +struct RequestPreparationOptions { + rewrite_body_model: bool, +} + +impl Default for RequestPreparationOptions { + fn default() -> Self { + Self { + rewrite_body_model: true, + } + } +} + async fn prepare_provider_request( body: Bytes, spec: &ModelSpec, format: ProviderFormat, stream: bool, + options: RequestPreparationOptions, ) -> Result<(Bytes, Option, ProviderFormat)> { if requires_bedrock_request_preparation(format) { let bytes = prepare_bedrock_request(body, spec, format).await?; return Ok((bytes, Some(format), format)); } - let (transformed, detected_format, actual_format) = - match lingua::transform_request(body.clone(), format, Some(&spec.model)) { - Ok(TransformResult::PassThrough(bytes)) => (bytes, None, format), + let model_override = options.rewrite_body_model.then_some(spec.model.as_str()); + let (transformed, detected_format, actual_format, maybe_rewrite_model) = + match lingua::transform_request(body.clone(), format, model_override) { + Ok(TransformResult::PassThrough(bytes)) => (bytes, None, format, true), Ok(TransformResult::Transformed { bytes, source_format, actual_target_format, - }) => (bytes, Some(source_format), actual_target_format), - Err(TransformError::UnsupportedTargetFormat(_)) => (body, None, format), + }) => (bytes, Some(source_format), actual_target_format, false), + Err(TransformError::UnsupportedTargetFormat(_)) => (body, None, format, true), Err(err) => return Err(err.into()), }; + let transformed = if options.rewrite_body_model && maybe_rewrite_model { + rewrite_body_model_if_required(transformed, actual_format, &spec.model) + } else { + transformed + }; + if stream { // TODO: Fold streaming intent into `lingua::transform_request` once we // are ready to update its Rust/WASM/Python/TS call sites together. @@ -266,9 +290,11 @@ impl Router { output_format: ProviderFormat, route: &ProviderRoute, stream: bool, + options: RequestPreparationOptions, ) -> Result<(PreparedRequestInner, RouterMetadata)> { let (payload, detected_format, actual_format) = - prepare_provider_request(body, route.spec.as_ref(), route.format, stream).await?; + prepare_provider_request(body, route.spec.as_ref(), route.format, stream, options) + .await?; Ok(( PreparedRequestInner { provider: route.provider.clone(), @@ -294,6 +320,7 @@ impl Router { /// * `body` - Raw request body bytes in any supported format (OpenAI, Anthropic, Google, etc.) /// * `output_format` - The output format, or None to auto-detect from body /// * `route` - The already-resolved provider route to prepare for + /// * `preserve_body_model` - Keep the request body's model instead of rewriting it to the route model /// /// The body will be automatically transformed to the selected provider format if needed. #[cfg_attr( @@ -309,9 +336,18 @@ impl Router { body: Bytes, output_format: ProviderFormat, route: &ProviderRoute, + preserve_body_model: bool, ) -> Result<(PreparedRequest, RouterMetadata)> { let (inner, metadata) = self - .create_prepared_request_internal(body, output_format, route, false) + .create_prepared_request_internal( + body, + output_format, + route, + false, + RequestPreparationOptions { + rewrite_body_model: !preserve_body_model, + }, + ) .await?; Ok((PreparedRequest { inner }, metadata)) } @@ -400,6 +436,7 @@ impl Router { /// * `body` - Raw request body bytes in any supported format (OpenAI, Anthropic, Google, etc.) /// * `output_format` - The output format, or None to auto-detect from body /// * `route` - The already-resolved provider route to prepare for + /// * `preserve_body_model` - Keep the request body's model instead of rewriting it to the route model /// /// The body will be automatically transformed to the selected provider format if needed. #[cfg_attr( @@ -415,9 +452,18 @@ impl Router { body: Bytes, output_format: ProviderFormat, route: &ProviderRoute, + preserve_body_model: bool, ) -> Result<(PreparedStreamRequest, RouterMetadata)> { let (inner, metadata) = self - .create_prepared_request_internal(body, output_format, route, true) + .create_prepared_request_internal( + body, + output_format, + route, + true, + RequestPreparationOptions { + rewrite_body_model: !preserve_body_model, + }, + ) .await?; Ok((PreparedStreamRequest { inner }, metadata)) } @@ -519,6 +565,14 @@ impl Router { output_format: ProviderFormat, fallback_aliases: &[String], ) -> Result> { + if !fallback_aliases.is_empty() { + return self.resolve_provider_routes_for_failover( + model, + output_format, + fallback_aliases, + ); + } + let (spec, catalog_format, aliases) = self.resolver.resolve(model)?; let routes: Vec> = aliases .iter() @@ -618,6 +672,106 @@ impl Router { Ok(successes) } + fn resolve_provider_routes_for_failover( + &self, + model: &str, + output_format: ProviderFormat, + fallback_aliases: &[String], + ) -> Result> { + let resolved_models = self.resolver.resolve_all_equivalent_model_routes(model)?; + let (_, first_catalog_format, _) = resolved_models + .first() + .ok_or_else(|| Error::UnknownModel(model.to_string()))?; + let mut first_error = None; + let mut routes = Vec::new(); + let mut seen = HashSet::new(); + + if let Some((spec, catalog_format, aliases)) = resolved_models.first() { + for alias in aliases { + match self.resolve_provider( + output_format, + spec.clone(), + *catalog_format, + alias.to_string(), + ) { + Ok(route) => { + seen.insert(route.provider_alias.clone()); + routes.push(route); + break; + } + Err(err) => { + if first_error.is_none() { + first_error = Some(err); + } + } + } + } + } + + for fallback_alias in fallback_aliases { + if seen.contains(fallback_alias) { + continue; + } + + for (spec, catalog_format, aliases) in &resolved_models { + if !aliases + .iter() + .any(|alias| self.alias_matches_provider(alias, fallback_alias)) + { + continue; + } + + match self.resolve_provider( + output_format, + spec.clone(), + *catalog_format, + fallback_alias.clone(), + ) { + Ok(route) => { + seen.insert(route.provider_alias.clone()); + routes.push(route); + break; + } + Err(err) => { + if first_error.is_none() { + first_error = Some(err); + } + } + } + } + } + + if routes.is_empty() { + return Err(first_error.unwrap_or_else(|| Error::NoProvider(*first_catalog_format))); + } + + Ok(routes) + } + + fn alias_matches_provider(&self, resolver_alias: &str, provider_alias: &str) -> bool { + if resolver_alias == provider_alias { + return true; + } + + if let Some(provider_id) = default_alias_provider_id(resolver_alias) { + return self + .providers + .get(provider_alias) + .is_some_and(|provider| provider.matches_provider_alias(provider_id)); + } + + if self + .providers + .get(provider_alias) + .is_some_and(|provider| provider.matches_provider_alias(resolver_alias)) + { + return true; + } + + default_alias_provider_id(provider_alias) + .is_some_and(|provider_id| provider_id == resolver_alias) + } + #[cfg(test)] fn resolve_providers( &self, @@ -810,6 +964,20 @@ impl Router { } } +fn default_alias_provider_id(alias: &str) -> Option<&'static str> { + match alias { + "OPENAI_API_KEY" => Some("openai"), + "ANTHROPIC_API_KEY" => Some("anthropic"), + "GEMINI_API_KEY" => Some("google"), + "MISTRAL_API_KEY" => Some("mistral"), + "AWS_DEFAULT_CREDENTIALS" => Some("bedrock"), + "GOOGLE_DEFAULT_CREDENTIALS" => Some("vertex"), + "AZURE_DEFAULT_CREDENTIALS" => Some("azure"), + "DATABRICKS_DEFAULT_CREDENTIALS" => Some("databricks"), + _ => None, + } +} + /// One provider registration: alias, provider, auth, and default formats. struct ProviderEntry { alias: String, @@ -1072,6 +1240,51 @@ mod tests { } } + struct FakeOpenAICompatibleProvider { + alias: &'static str, + } + + #[async_trait] + impl Provider for FakeOpenAICompatibleProvider { + fn id(&self) -> &'static str { + "openai" + } + + fn matches_provider_alias(&self, alias: &str) -> bool { + self.alias == alias + } + + fn provider_formats(&self) -> Vec { + vec![ProviderFormat::ChatCompletions] + } + + async fn complete( + &self, + _payload: Bytes, + _auth: &AuthConfig, + _spec: &ModelSpec, + _format: ProviderFormat, + _client_headers: &ClientHeaders, + ) -> Result { + Ok(Bytes::from("{}")) + } + + async fn complete_stream( + &self, + _payload: Bytes, + _auth: &AuthConfig, + _spec: &ModelSpec, + _format: ProviderFormat, + _client_headers: &ClientHeaders, + ) -> Result { + unimplemented!() + } + + async fn health_check(&self, _auth: &AuthConfig) -> Result<()> { + Ok(()) + } + } + fn google_spec(model: &str) -> ModelSpec { ModelSpec { model: model.to_string(), @@ -1197,7 +1410,9 @@ mod tests { let route = routes .first() .ok_or_else(|| Error::NoProvider(output_format))?; - router.create_request(body, output_format, route).await + router + .create_request(body, output_format, route, false) + .await } async fn create_test_stream_request( @@ -1211,7 +1426,7 @@ mod tests { .first() .ok_or_else(|| Error::NoProvider(output_format))?; router - .create_stream_request(body, output_format, route) + .create_stream_request(body, output_format, route, false) .await } @@ -1222,10 +1437,15 @@ mod tests { ); let spec = openai_spec("gpt-5-mini", ModelFlavor::Chat); - let (payload, _, _) = - prepare_provider_request(body, &spec, ProviderFormat::ChatCompletions, true) - .await - .expect("request prepares"); + let (payload, _, _) = prepare_provider_request( + body, + &spec, + ProviderFormat::ChatCompletions, + true, + RequestPreparationOptions::default(), + ) + .await + .expect("request prepares"); let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); assert_eq!(parsed.get("stream"), Some(&Value::Bool(true))); @@ -1240,16 +1460,186 @@ mod tests { ); let spec = openai_spec("gpt-5-mini", ModelFlavor::Chat); - let (payload, _, _) = - prepare_provider_request(body, &spec, ProviderFormat::ChatCompletions, false) - .await - .expect("request prepares"); + let (payload, _, _) = prepare_provider_request( + body, + &spec, + ProviderFormat::ChatCompletions, + false, + RequestPreparationOptions::default(), + ) + .await + .expect("request prepares"); let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); assert_eq!(parsed.get("stream"), None); assert_eq!(parsed.get("stream_options"), None); } + #[tokio::test] + async fn prepare_provider_request_does_not_read_model_for_vertex_anthropic() { + let body = Bytes::from_static( + br#"{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"Ping"}]}"#, + ); + let spec = ModelSpec { + model: "publishers/anthropic/models/claude-sonnet-4-6".to_string(), + format: ProviderFormat::Anthropic, + flavor: ModelFlavor::Chat, + display_name: None, + parent: None, + input_cost_per_mil_tokens: None, + output_cost_per_mil_tokens: None, + input_cache_read_cost_per_mil_tokens: None, + multimodal: None, + reasoning: None, + max_input_tokens: None, + max_output_tokens: None, + supports_streaming: true, + extra: Default::default(), + available_providers: vec!["vertex".to_string()], + }; + + let (payload, _, actual_format) = prepare_provider_request( + body, + &spec, + ProviderFormat::VertexAnthropic, + false, + RequestPreparationOptions::default(), + ) + .await + .expect("request prepares"); + let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); + + assert_eq!(actual_format, ProviderFormat::VertexAnthropic); + assert_eq!(parsed.get("model"), None); + assert!(parsed.get("anthropic_version").is_some()); + assert!(parsed.get("messages").is_some()); + } + + #[tokio::test] + async fn prepare_provider_request_does_not_rewrite_model_for_google_pass_through() { + let body = Bytes::from_static( + br#"{"model":"models/gemini-2.5-flash","contents":[{"role":"user","parts":[{"text":"Ping"}]}]}"#, + ); + let spec = ModelSpec { + model: "models/gemini-2.5-pro".to_string(), + format: ProviderFormat::Google, + flavor: ModelFlavor::Chat, + display_name: None, + parent: None, + input_cost_per_mil_tokens: None, + output_cost_per_mil_tokens: None, + input_cache_read_cost_per_mil_tokens: None, + multimodal: None, + reasoning: None, + max_input_tokens: None, + max_output_tokens: None, + supports_streaming: true, + extra: Default::default(), + available_providers: vec!["google".to_string()], + }; + + let (payload, _, actual_format) = prepare_provider_request( + body, + &spec, + ProviderFormat::Google, + false, + RequestPreparationOptions::default(), + ) + .await + .expect("request prepares"); + let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); + + assert_eq!(actual_format, ProviderFormat::Google); + assert_eq!( + parsed.get("model").and_then(Value::as_str), + Some("models/gemini-2.5-flash") + ); + assert!(parsed.get("contents").is_some()); + } + + #[tokio::test] + async fn prepare_provider_request_rewrites_same_format_chat_model_without_losing_native_fields() + { + let body = Bytes::from_static( + br#"{"model":"gpt-4","messages":[{"role":"user","name":"example_user","content":"Ping"}]}"#, + ); + let spec = openai_spec("gpt-4o", ModelFlavor::Chat); + + let (payload, _, actual_format) = prepare_provider_request( + body, + &spec, + ProviderFormat::ChatCompletions, + false, + RequestPreparationOptions::default(), + ) + .await + .expect("request prepares"); + let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); + + assert_eq!(actual_format, ProviderFormat::ChatCompletions); + assert_eq!(parsed.get("model").and_then(Value::as_str), Some("gpt-4o")); + assert_eq!( + parsed.pointer("/messages/0/name").and_then(Value::as_str), + Some("example_user") + ); + } + + #[tokio::test] + async fn prepare_provider_request_can_preserve_same_format_body_model() { + let body = Bytes::from_static( + br#"{"model":"gpt-4","messages":[{"role":"user","name":"example_user","content":"Ping"}]}"#, + ); + let spec = openai_spec("gpt-4o", ModelFlavor::Chat); + + let (payload, _, actual_format) = prepare_provider_request( + body, + &spec, + ProviderFormat::ChatCompletions, + false, + RequestPreparationOptions { + rewrite_body_model: false, + }, + ) + .await + .expect("request prepares"); + let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); + + assert_eq!(actual_format, ProviderFormat::ChatCompletions); + assert_eq!(parsed.get("model").and_then(Value::as_str), Some("gpt-4")); + assert_eq!( + parsed.pointer("/messages/0/name").and_then(Value::as_str), + Some("example_user") + ); + } + + #[tokio::test] + async fn prepare_provider_request_can_preserve_body_model_across_format_transform() { + let body = Bytes::from_static( + br#"{"model":"claude-3-5-haiku-20241022","max_tokens":128,"messages":[{"role":"user","content":"Ping"}]}"#, + ); + let spec = openai_spec("gpt-4o", ModelFlavor::Chat); + + let (payload, detected_format, actual_format) = prepare_provider_request( + body, + &spec, + ProviderFormat::ChatCompletions, + false, + RequestPreparationOptions { + rewrite_body_model: false, + }, + ) + .await + .expect("request prepares"); + let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); + + assert_eq!(detected_format, Some(ProviderFormat::Anthropic)); + assert_eq!(actual_format, ProviderFormat::ChatCompletions); + assert_eq!( + parsed.get("model").and_then(Value::as_str), + Some("claude-3-5-haiku-20241022") + ); + } + #[tokio::test] async fn prepare_provider_request_upgrades_actual_format_to_responses_for_reasoning_plus_tools() { @@ -1277,10 +1667,15 @@ mod tests { ); let spec = openai_spec("gpt-5.4-mini", ModelFlavor::Chat); - let (_, _, actual_format) = - prepare_provider_request(body, &spec, ProviderFormat::ChatCompletions, false) - .await - .expect("request prepares"); + let (_, _, actual_format) = prepare_provider_request( + body, + &spec, + ProviderFormat::ChatCompletions, + false, + RequestPreparationOptions::default(), + ) + .await + .expect("request prepares"); assert_eq!( actual_format, @@ -2424,28 +2819,37 @@ mod tests { } #[test] - fn resolved_aliases_returns_only_registered_available_providers() { - let model = "gpt-4o"; - let mut catalog = ModelCatalog::empty(); - catalog.insert( - model.into(), - openai_spec_with_available_providers(model, ModelFlavor::Chat), - ); + fn overlay_catalog_failover_routes_use_equivalent_custom_model() { + let base = ModelCatalog::empty(); + let mut custom = ModelCatalog::empty(); + let mut primary = openai_spec("custom-primary", ModelFlavor::Chat); + primary.available_providers = vec!["provider-a".to_string()]; + let mut fallback = openai_spec("custom-fallback", ModelFlavor::Chat); + fallback.available_providers = vec!["provider-b".to_string()]; + custom.insert("custom-primary".into(), primary); + custom.insert("custom-fallback".into(), fallback); + custom + .add_fallback_models( + "custom-primary".to_string(), + vec!["custom-fallback".to_string()], + ) + .expect("equivalence is valid"); + let router = Router::builder() - .with_catalog(Arc::new(catalog)) + .with_overlay_catalog(Arc::new(base), custom) .add_provider( - "openai", + "provider-a", FakeProvider { name: "openai", formats: vec![ProviderFormat::ChatCompletions], }, dummy_auth(), - vec![ProviderFormat::ChatCompletions], + vec![], ) .add_provider( - "azure", + "provider-b", FakeProvider { - name: "azure", + name: "openai", formats: vec![ProviderFormat::ChatCompletions], }, dummy_auth(), @@ -2454,10 +2858,126 @@ mod tests { .build() .expect("router builds"); - let aliases = resolved_aliases(&router, model, ProviderFormat::ChatCompletions) - .expect("resolved aliases"); - assert_eq!(aliases, vec!["openai".to_string(), "azure".to_string()]); - } + let routes = router + .resolve_provider_routes( + "custom-primary", + ProviderFormat::ChatCompletions, + &["provider-a".to_string(), "provider-b".to_string()], + ) + .expect("failover routes resolve"); + let route_info: Vec<(&str, &str)> = routes + .iter() + .map(|route| (route.provider_alias(), route.model())) + .collect(); + + assert_eq!( + route_info, + vec![ + ("provider-a", "custom-primary"), + ("provider-b", "custom-fallback"), + ] + ); + } + + #[test] + fn overlay_catalog_failover_routes_use_equivalent_base_model() { + let mut base = ModelCatalog::empty(); + base.insert( + "base-fallback".into(), + openai_spec_with_available_providers("base-fallback", ModelFlavor::Chat), + ); + let mut custom = ModelCatalog::empty(); + let mut primary = openai_spec("custom-primary", ModelFlavor::Chat); + primary.available_providers = vec!["provider-a".to_string()]; + custom.insert("custom-primary".into(), primary); + custom + .add_external_fallback_models( + "custom-primary".to_string(), + vec!["base-fallback".to_string()], + ) + .expect("equivalence is valid"); + + let router = Router::builder() + .with_overlay_catalog(Arc::new(base), custom) + .add_provider( + "provider-a", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .add_provider( + "openai", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_provider_routes( + "custom-primary", + ProviderFormat::ChatCompletions, + &["provider-a".to_string(), "openai".to_string()], + ) + .expect("failover routes resolve"); + let route_info: Vec<(&str, &str)> = routes + .iter() + .map(|route| (route.provider_alias(), route.model())) + .collect(); + + assert_eq!( + route_info, + vec![ + ("provider-a", "custom-primary"), + ("openai", "base-fallback"), + ] + ); + assert!(router.catalog().get("base-fallback").is_some()); + assert!(router.catalog().get("custom-primary").is_none()); + } + + #[test] + fn resolved_aliases_returns_only_registered_available_providers() { + let model = "gpt-4o"; + let mut catalog = ModelCatalog::empty(); + catalog.insert( + model.into(), + openai_spec_with_available_providers(model, ModelFlavor::Chat), + ); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "openai", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![ProviderFormat::ChatCompletions], + ) + .add_provider( + "azure", + FakeProvider { + name: "azure", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let aliases = resolved_aliases(&router, model, ProviderFormat::ChatCompletions) + .expect("resolved aliases"); + assert_eq!(aliases, vec!["openai".to_string(), "azure".to_string()]); + } #[test] fn fallback_provider_routes_append_after_primary() { @@ -2536,6 +3056,417 @@ mod tests { ); } + #[test] + fn fallback_provider_routes_do_not_treat_openai_provider_id_as_allowlist_match() { + let model = "gpt-4o"; + let mut catalog = ModelCatalog::empty(); + let mut spec = openai_spec(model, ModelFlavor::Chat); + spec.available_providers = vec!["openai".into()]; + catalog.insert(model.into(), spec); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "openai", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![ProviderFormat::ChatCompletions], + ) + .add_provider( + "cerebras", + FakeOpenAICompatibleProvider { alias: "cerebras" }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + assert_eq!( + explicit_route_aliases( + &router, + model, + ProviderFormat::ChatCompletions, + &["cerebras"] + ) + .expect("routes"), + vec!["openai".to_string()] + ); + } + + #[test] + fn fallback_provider_routes_match_named_openai_secret_for_default_alias() { + let model = "gpt-4o"; + let mut catalog = ModelCatalog::empty(); + let mut spec = openai_spec(model, ModelFlavor::Chat); + spec.available_providers = vec!["OPENAI_API_KEY".into()]; + catalog.insert(model.into(), spec); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "my-openai", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + assert_eq!( + explicit_route_aliases( + &router, + model, + ProviderFormat::ChatCompletions, + &["my-openai"] + ) + .expect("routes"), + vec!["my-openai".to_string()] + ); + } + + #[test] + fn failover_routes_match_named_openai_when_equivalents_omit_available_providers() { + let model = "custom-primary"; + let fallback_model = "gpt-4o"; + let mut base = ModelCatalog::empty(); + base.insert( + fallback_model.into(), + openai_spec(fallback_model, ModelFlavor::Chat), + ); + let mut custom = ModelCatalog::empty(); + let mut primary = openai_spec(model, ModelFlavor::Chat); + primary.available_providers = vec!["provider-a".to_string()]; + custom.insert(model.into(), primary); + custom + .add_external_fallback_models(model.to_string(), vec![fallback_model.to_string()]) + .expect("equivalence is valid"); + + let router = Router::builder() + .with_overlay_catalog(Arc::new(base), custom) + .add_provider( + "provider-a", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .add_provider( + "my-openai", + FakeOpenAICompatibleProvider { alias: "openai" }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_provider_routes( + model, + ProviderFormat::ChatCompletions, + &["provider-a".to_string(), "my-openai".to_string()], + ) + .expect("routes resolve"); + let route_info: Vec<(&str, &str)> = routes + .iter() + .map(|route| (route.provider_alias(), route.model())) + .collect(); + + assert_eq!( + route_info, + vec![("provider-a", model), ("my-openai", fallback_model)] + ); + } + + #[test] + fn failover_routes_match_named_openai_compatible_provider_id() { + let model = "custom-primary"; + let fallback_model = "llama-4-scout"; + let mut base = ModelCatalog::empty(); + let mut fallback = openai_spec(fallback_model, ModelFlavor::Chat); + fallback.available_providers = vec!["cerebras".to_string()]; + base.insert(fallback_model.into(), fallback); + let mut custom = ModelCatalog::empty(); + let mut primary = openai_spec(model, ModelFlavor::Chat); + primary.available_providers = vec!["provider-a".to_string()]; + custom.insert(model.into(), primary); + custom + .add_external_fallback_models(model.to_string(), vec![fallback_model.to_string()]) + .expect("equivalence is valid"); + + let router = Router::builder() + .with_overlay_catalog(Arc::new(base), custom) + .add_provider( + "provider-a", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .add_provider( + "my-cerebras", + FakeOpenAICompatibleProvider { alias: "cerebras" }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_provider_routes( + model, + ProviderFormat::ChatCompletions, + &["provider-a".to_string(), "my-cerebras".to_string()], + ) + .expect("routes resolve"); + let route_info: Vec<(&str, &str)> = routes + .iter() + .map(|route| (route.provider_alias(), route.model())) + .collect(); + + assert_eq!( + route_info, + vec![("provider-a", model), ("my-cerebras", fallback_model)] + ); + } + + #[tokio::test] + async fn failover_request_payload_uses_equivalent_route_model_for_same_format() { + let model = "gpt-4o"; + let fallback_model = "other-provider/gpt-4o"; + let catalog = ModelCatalog::from_json_str( + r#"{ + "gpt-4o": { + "format": "openai", + "flavor": "chat", + "available_providers": ["provider-a"], + "fallback_models": ["other-provider/gpt-4o"] + }, + "other-provider/gpt-4o": { + "format": "openai", + "flavor": "chat" + } +}"#, + ) + .expect("catalog parses"); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "provider-a", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .add_provider( + "my-openai", + FakeOpenAICompatibleProvider { alias: "openai" }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_provider_routes( + model, + ProviderFormat::ChatCompletions, + &["provider-a".to_string(), "my-openai".to_string()], + ) + .expect("routes resolve"); + let fallback_route = routes + .iter() + .find(|route| route.provider_alias() == "my-openai") + .expect("fallback route exists"); + assert_eq!(fallback_route.model(), fallback_model); + + let body = Bytes::from_static( + br#"{"model":"gpt-4o","messages":[{"role":"user","content":"Ping"}]}"#, + ); + let (request, _) = router + .create_request(body, ProviderFormat::ChatCompletions, fallback_route, false) + .await + .expect("request prepares"); + let payload: Value = serde_json::from_slice(&request.inner.payload).expect("json"); + + assert_eq!( + payload.get("model").and_then(Value::as_str), + Some(fallback_model) + ); + } + + #[test] + fn fallback_provider_routes_do_not_match_openai_default_alias_to_compatible_provider() { + let model = "gpt-4o"; + let mut catalog = ModelCatalog::empty(); + let mut spec = openai_spec(model, ModelFlavor::Chat); + spec.available_providers = vec!["OPENAI_API_KEY".into()]; + catalog.insert(model.into(), spec); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "openai", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![ProviderFormat::ChatCompletions], + ) + .add_provider( + "cerebras", + FakeOpenAICompatibleProvider { alias: "cerebras" }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let err = match explicit_route_aliases( + &router, + model, + ProviderFormat::ChatCompletions, + &["cerebras"], + ) { + Ok(_) => panic!("OpenAI-compatible provider should not satisfy OpenAI default alias"), + Err(err) => err, + }; + + assert!(matches!( + err, + Error::NoProvider(ProviderFormat::ChatCompletions) + )); + } + + #[test] + fn failover_routes_match_named_secrets_by_concrete_provider_id() { + let model = "claude-sonnet-4-6"; + let vertex_model = "publishers/anthropic/models/claude-sonnet-4-6"; + let catalog = ModelCatalog::from_json_str( + r#"{ + "claude-sonnet-4-6": { + "format": "anthropic", + "flavor": "chat", + "available_providers": ["ANTHROPIC_API_KEY"], + "fallback_models": ["publishers/anthropic/models/claude-sonnet-4-6"] + }, + "publishers/anthropic/models/claude-sonnet-4-6": { + "format": "anthropic", + "flavor": "chat", + "available_providers": ["GOOGLE_DEFAULT_CREDENTIALS"] + } +}"#, + ) + .expect("catalog parses"); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "my-anthropic", + FakeProvider { + name: "anthropic", + formats: vec![ProviderFormat::Anthropic], + }, + dummy_auth(), + vec![], + ) + .add_provider( + "my-vertex", + FakeProvider { + name: "vertex", + formats: vec![ProviderFormat::VertexAnthropic], + }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_provider_routes( + model, + ProviderFormat::ChatCompletions, + &["my-anthropic".to_string(), "my-vertex".to_string()], + ) + .expect("routes resolve"); + let route_info: Vec<(&str, &str)> = routes + .iter() + .map(|route| (route.provider_alias(), route.model())) + .collect(); + + assert_eq!( + route_info, + vec![("my-anthropic", model), ("my-vertex", vertex_model),] + ); + } + + #[test] + fn failover_routes_match_named_secrets_when_equivalents_omit_available_providers() { + let model = "claude-sonnet-4-6"; + let vertex_model = "publishers/anthropic/models/claude-sonnet-4-6"; + let catalog = ModelCatalog::from_json_str( + r#"{ + "claude-sonnet-4-6": { + "format": "anthropic", + "flavor": "chat", + "fallback_models": ["publishers/anthropic/models/claude-sonnet-4-6"] + }, + "publishers/anthropic/models/claude-sonnet-4-6": { + "format": "anthropic", + "flavor": "chat" + } +}"#, + ) + .expect("catalog parses"); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "my-anthropic", + FakeProvider { + name: "anthropic", + formats: vec![ProviderFormat::Anthropic], + }, + dummy_auth(), + vec![], + ) + .add_provider( + "my-vertex", + FakeProvider { + name: "vertex", + formats: vec![ProviderFormat::VertexAnthropic], + }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_provider_routes( + model, + ProviderFormat::ChatCompletions, + &["my-anthropic".to_string(), "my-vertex".to_string()], + ) + .expect("routes resolve"); + let route_info: Vec<(&str, &str)> = routes + .iter() + .map(|route| (route.provider_alias(), route.model())) + .collect(); + + assert_eq!( + route_info, + vec![("my-anthropic", model), ("my-vertex", vertex_model),] + ); + } + #[test] fn fallback_provider_routes_do_not_use_format_default_for_ineligible_alias() { let model = "gpt-4o"; @@ -2690,4 +3621,85 @@ mod tests { assert_eq!(aliases, vec!["openai", "azure"]); } + + #[test] + fn failover_routes_use_equivalent_provider_native_vertex_model() { + let model = "claude-sonnet-4-6"; + let vertex_model = "publishers/anthropic/models/claude-sonnet-4-6"; + let catalog = ModelCatalog::from_json_str( + r#"{ + "claude-sonnet-4-6": { + "format": "anthropic", + "flavor": "chat", + "available_providers": ["anthropic"], + "fallback_models": ["publishers/anthropic/models/claude-sonnet-4-6"] + }, + "publishers/anthropic/models/claude-sonnet-4-6": { + "format": "anthropic", + "flavor": "chat" + } +}"#, + ) + .expect("catalog parses"); + let catalog = catalog.map_specs(|_, spec| { + let mut spec = spec.clone(); + spec.available_providers = spec + .available_providers + .iter() + .map(|provider| { + if provider == "anthropic" { + "ANTHROPIC_API_KEY".to_string() + } else { + provider.clone() + } + }) + .collect(); + spec + }); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "ANTHROPIC_API_KEY", + FakeProvider { + name: "anthropic", + formats: vec![ProviderFormat::Anthropic], + }, + dummy_auth(), + vec![], + ) + .add_provider( + "GOOGLE_DEFAULT_CREDENTIALS", + FakeProvider { + name: "vertex", + formats: vec![ProviderFormat::VertexAnthropic], + }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_provider_routes( + model, + ProviderFormat::ChatCompletions, + &[ + "ANTHROPIC_API_KEY".to_string(), + "GOOGLE_DEFAULT_CREDENTIALS".to_string(), + ], + ) + .expect("failover routes resolve"); + let route_info: Vec<(&str, &str)> = routes + .iter() + .map(|route| (route.provider_alias(), route.model())) + .collect(); + + assert_eq!( + route_info, + vec![ + ("ANTHROPIC_API_KEY", model), + ("GOOGLE_DEFAULT_CREDENTIALS", vertex_model), + ] + ); + } } diff --git a/crates/braintrust-llm-router/tests/router.rs b/crates/braintrust-llm-router/tests/router.rs index c285a7ad..9cd46ac3 100644 --- a/crates/braintrust-llm-router/tests/router.rs +++ b/crates/braintrust-llm-router/tests/router.rs @@ -28,7 +28,9 @@ async fn create_request( let route = routes .first() .ok_or_else(|| Error::NoProvider(output_format))?; - router.create_request(body, output_format, route).await + router + .create_request(body, output_format, route, false) + .await } #[derive(Clone)] @@ -918,7 +920,7 @@ async fn responses_required_model_uses_responses_for_anthropic_messages_output() .expect("resolve routes"); let route = routes.first().expect("route"); let (_request, metadata) = router - .create_request(body, ProviderFormat::Anthropic, route) + .create_request(body, ProviderFormat::Anthropic, route, false) .await .expect("create request"); diff --git a/crates/lingua/src/processing/transform.rs b/crates/lingua/src/processing/transform.rs index 89ad7079..b11f96e8 100644 --- a/crates/lingua/src/processing/transform.rs +++ b/crates/lingua/src/processing/transform.rs @@ -848,6 +848,30 @@ mod tests { assert_eq!(output.as_ptr(), input_ptr); } + #[test] + #[cfg(feature = "openai")] + fn test_transform_request_passthrough_with_identical_model_override() { + let payload = json!({ + "model": "gpt-4", + "messages": [{"role": "user", "name": "example_user", "content": "Hello"}] + }); + let input = to_bytes(&payload); + let input_ptr = input.as_ptr(); + + let result = + transform_request(input, ProviderFormat::ChatCompletions, Some("gpt-4")).unwrap(); + + assert!(result.is_passthrough()); + let output = result.into_bytes(); + assert_eq!(output.as_ptr(), input_ptr); + + let parsed: Value = crate::serde_json::from_slice(&output).unwrap(); + assert_eq!( + parsed["messages"][0].get("name").and_then(Value::as_str), + Some("example_user") + ); + } + #[test] #[cfg(feature = "openai")] fn test_transform_request_passthrough_repairs_lone_surrogate() { @@ -2158,9 +2182,46 @@ mod tests { ); let output: Value = crate::serde_json::from_slice(result.as_bytes()).unwrap(); + assert_eq!( + output.get("modelId").and_then(Value::as_str), + Some("amazon.nova-pro-v1:0") + ); assert!( output.get("anthropic_version").is_none(), "Non-anthropic models should not have anthropic_version" ); } + + #[test] + #[cfg(feature = "bedrock")] + fn test_transform_request_preserves_same_format_converse_passthrough_with_model_override() { + let payload = json!({ + "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + "guardrailConfig": { + "guardrailIdentifier": "test", + "guardrailVersion": "1" + }, + "messages": [{ + "role": "user", + "content": [{"text": "Hello"}] + }] + }); + let input = to_bytes(&payload); + + let result = transform_request( + input, + ProviderFormat::Converse, + Some("anthropic.claude-3-5-sonnet-20241022-v2:0"), + ) + .unwrap(); + + assert!(result.is_passthrough()); + + let output: Value = crate::serde_json::from_slice(result.as_bytes()).unwrap(); + assert_eq!( + output.get("modelId").and_then(Value::as_str), + Some("anthropic.claude-3-haiku-20240307-v1:0") + ); + assert!(output.get("guardrailConfig").is_some()); + } }