From 6a0b40854b5637d1b2d3c37072c85c94262c8577 Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Fri, 12 Jun 2026 16:02:41 -0500 Subject: [PATCH 01/28] detect equivalient models between providers --- .../braintrust-llm-router/src/catalog/mod.rs | 276 +++++++++++++++++- .../src/catalog/resolver.rs | 18 ++ crates/braintrust-llm-router/src/router.rs | 175 +++++++++++ 3 files changed, 468 insertions(+), 1 deletion(-) diff --git a/crates/braintrust-llm-router/src/catalog/mod.rs b/crates/braintrust-llm-router/src/catalog/mod.rs index 68f20fc9..21ad45cf 100644 --- a/crates/braintrust-llm-router/src/catalog/mod.rs +++ b/crates/braintrust-llm-router/src/catalog/mod.rs @@ -13,13 +13,15 @@ use std::io::Read; use std::path::Path; use std::sync::Arc; -use crate::error::Result; +use crate::error::{Error, Result}; #[derive(Debug, Clone, Default)] pub struct ModelCatalog { models: HashMap>, by_format: HashMap>, by_parent: HashMap>, + equivalent_models: HashMap>, + equivalence_index: HashMap>, } /// A request-local catalog overlay. @@ -57,6 +59,19 @@ impl CatalogResolver { Self::Overlay(overlay) => overlay.custom.get(name).or_else(|| overlay.base.get(name)), } } + + pub fn equivalent_model_names(&self, name: &str) -> Vec { + match self { + Self::Base(catalog) => catalog.equivalent_model_names(name), + Self::Overlay(overlay) => { + if overlay.custom.get(name).is_some() { + overlay.custom.equivalent_model_names(name) + } else { + overlay.base.equivalent_model_names(name) + } + } + } + } } impl From> for CatalogResolver { @@ -65,6 +80,34 @@ impl From> for CatalogResolver { } } +fn parse_equivalent_models(content: &str) -> Result>> { + let raw: HashMap = serde_json::from_str(content)?; + let mut equivalent_models = HashMap::new(); + for (name, value) in raw { + let Some(equivalents) = value.get("equivalent_models") else { + continue; + }; + let Some(equivalents) = equivalents.as_array() else { + return Err(Error::InvalidRequest(format!( + "model '{name}' has invalid equivalent_models" + ))); + }; + let mut parsed = Vec::with_capacity(equivalents.len()); + for equivalent in equivalents { + let Some(equivalent) = equivalent.as_str() else { + return Err(Error::InvalidRequest(format!( + "model '{name}' has invalid equivalent_models" + ))); + }; + parsed.push(equivalent.to_string()); + } + if !parsed.is_empty() { + equivalent_models.insert(name, parsed); + } + } + Ok(equivalent_models) +} + impl ModelCatalog { pub fn empty() -> Self { Self::default() @@ -72,10 +115,14 @@ impl ModelCatalog { pub fn from_json_str(content: &str) -> Result { let raw: HashMap = serde_json::from_str(content)?; + let equivalent_models = parse_equivalent_models(content)?; let mut catalog = Self::empty(); for (name, spec) in raw { catalog.insert(name, spec); } + catalog.equivalent_models = equivalent_models; + catalog.validate_equivalent_models()?; + catalog.rebuild_equivalence_index(); Ok(catalog) } @@ -94,6 +141,18 @@ impl ModelCatalog { self.models.get(name).cloned() } + pub fn equivalent_model_names(&self, name: &str) -> Vec { + let Some(_) = self.models.get(name) else { + return Vec::new(); + }; + + let mut names = vec![name.to_string()]; + if let Some(equivalent_names) = self.equivalence_index.get(name) { + names.extend(equivalent_names.iter().cloned()); + } + names + } + pub fn resolve_format(&self, model: &str) -> Option { self.models.get(model).map(|spec| spec.format) } @@ -142,6 +201,21 @@ impl ModelCatalog { self.models.iter() } + pub fn map_specs(&self, mut f: F) -> Self + where + F: FnMut(&str, &ModelSpec) -> ModelSpec, + { + let mut out = Self { + equivalent_models: self.equivalent_models.clone(), + ..Self::empty() + }; + for (name, spec) in &self.models { + out.insert(name.clone(), f(name, spec.as_ref())); + } + out.rebuild_equivalence_index(); + out + } + pub fn len(&self) -> usize { self.models.len() } @@ -163,6 +237,82 @@ impl ModelCatalog { self.by_parent.entry(parent).or_default().push(name); } } + + fn validate_equivalent_models(&self) -> Result<()> { + for (name, equivalents) in &self.equivalent_models { + for equivalent_model in equivalents { + if !self.models.contains_key(equivalent_model) { + return Err(Error::InvalidRequest(format!( + "model '{name}' references missing equivalent model '{equivalent_model}'" + ))); + } + } + } + Ok(()) + } + + fn rebuild_equivalence_index(&mut self) { + let mut adjacency: HashMap> = HashMap::new(); + for name in self.models.keys() { + adjacency.entry(name.clone()).or_default(); + } + + for (name, equivalents) in &self.equivalent_models { + if !self.models.contains_key(name) { + continue; + } + for equivalent_model in equivalents { + if !self.models.contains_key(equivalent_model) { + continue; + } + adjacency + .entry(name.clone()) + .or_default() + .push(equivalent_model.clone()); + adjacency + .entry(equivalent_model.clone()) + .or_default() + .push(name.clone()); + } + } + + let mut visited = std::collections::HashSet::new(); + let mut index = HashMap::new(); + for name in self.models.keys() { + if visited.contains(name) { + continue; + } + + let mut stack = vec![name.clone()]; + let mut component = Vec::new(); + while let Some(current) = stack.pop() { + if !visited.insert(current.clone()) { + continue; + } + component.push(current.clone()); + if let Some(neighbors) = adjacency.get(¤t) { + stack.extend(neighbors.iter().cloned()); + } + } + + if component.len() <= 1 { + continue; + } + component.sort(); + for member in &component { + index.insert( + member.clone(), + component + .iter() + .filter(|other| *other != member) + .cloned() + .collect(), + ); + } + } + + self.equivalence_index = index; + } } pub fn load_catalog_from_disk>(path: P) -> Result> { @@ -174,3 +324,127 @@ pub fn load_catalog_from_disk>(path: P) -> Result Result<(Arc, ProviderFormat, Vec)> { + self.resolve_one(model) + } + + pub fn resolve_for_failover( + &self, + model: &str, + ) -> Result, ProviderFormat, Vec)>> { + let mut resolved = Vec::new(); + for model_name in self.catalog.equivalent_model_names(model) { + resolved.push(self.resolve_one(&model_name)?); + } + if resolved.is_empty() { + return Err(Error::UnknownModel(model.to_string())); + } + Ok(resolved) + } + + fn resolve_one(&self, model: &str) -> Result<(Arc, ProviderFormat, Vec)> { let spec = self .catalog .get(model) diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 597c4af7..c2b27d94 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -519,6 +519,14 @@ impl Router { output_format: ProviderFormat, fallback_aliases: &[String], ) -> Result> { + if !fallback_aliases.is_empty() { + return self.resolve_provider_routes_for_failover( + model, + output_format, + fallback_aliases, + ); + } + let (spec, catalog_format, aliases) = self.resolver.resolve(model)?; let routes: Vec> = aliases .iter() @@ -618,6 +626,92 @@ impl Router { Ok(successes) } + fn resolve_provider_routes_for_failover( + &self, + model: &str, + output_format: ProviderFormat, + fallback_aliases: &[String], + ) -> Result> { + let resolved_models = self.resolver.resolve_for_failover(model)?; + let (_, first_catalog_format, _) = resolved_models + .first() + .ok_or_else(|| Error::UnknownModel(model.to_string()))?; + let mut first_error = None; + let mut routes = Vec::new(); + let mut seen = HashSet::new(); + + if let Some((spec, catalog_format, aliases)) = resolved_models.first() { + for alias in aliases { + match self.resolve_provider( + output_format, + spec.clone(), + *catalog_format, + alias.to_string(), + ) { + Ok(route) => { + seen.insert(route.provider_alias.clone()); + routes.push(route); + break; + } + Err(err) => { + if first_error.is_none() { + first_error = Some(err); + } + } + } + } + } + + for fallback_alias in fallback_aliases { + if seen.contains(fallback_alias) { + continue; + } + + for (spec, catalog_format, aliases) in &resolved_models { + if !aliases + .iter() + .any(|alias| self.alias_matches_provider(alias, fallback_alias)) + { + continue; + } + + match self.resolve_provider( + output_format, + spec.clone(), + *catalog_format, + fallback_alias.clone(), + ) { + Ok(route) => { + seen.insert(route.provider_alias.clone()); + routes.push(route); + break; + } + Err(err) => { + if first_error.is_none() { + first_error = Some(err); + } + } + } + } + } + + if routes.is_empty() { + return Err(first_error.unwrap_or_else(|| Error::NoProvider(*first_catalog_format))); + } + + Ok(routes) + } + + fn alias_matches_provider(&self, resolver_alias: &str, provider_alias: &str) -> bool { + if resolver_alias == provider_alias { + return true; + } + + self.providers + .get(provider_alias) + .is_some_and(|provider| provider.id() == resolver_alias) + } + #[cfg(test)] fn resolve_providers( &self, @@ -2689,4 +2783,85 @@ mod tests { assert_eq!(aliases, vec!["openai", "azure"]); } + + #[test] + fn failover_routes_use_equivalent_provider_native_vertex_model() { + let model = "claude-sonnet-4-6"; + let vertex_model = "publishers/anthropic/models/claude-sonnet-4-6"; + let catalog = ModelCatalog::from_json_str( + r#"{ + "claude-sonnet-4-6": { + "format": "anthropic", + "flavor": "chat", + "available_providers": ["anthropic"], + "equivalent_models": ["publishers/anthropic/models/claude-sonnet-4-6"] + }, + "publishers/anthropic/models/claude-sonnet-4-6": { + "format": "anthropic", + "flavor": "chat" + } +}"#, + ) + .expect("catalog parses"); + let catalog = catalog.map_specs(|_, spec| { + let mut spec = spec.clone(); + spec.available_providers = spec + .available_providers + .iter() + .map(|provider| { + if provider == "anthropic" { + "ANTHROPIC_API_KEY".to_string() + } else { + provider.clone() + } + }) + .collect(); + spec + }); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "ANTHROPIC_API_KEY", + FakeProvider { + name: "anthropic", + formats: vec![ProviderFormat::Anthropic], + }, + dummy_auth(), + vec![], + ) + .add_provider( + "GOOGLE_DEFAULT_CREDENTIALS", + FakeProvider { + name: "vertex", + formats: vec![ProviderFormat::VertexAnthropic], + }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_provider_routes( + model, + ProviderFormat::ChatCompletions, + &[ + "ANTHROPIC_API_KEY".to_string(), + "GOOGLE_DEFAULT_CREDENTIALS".to_string(), + ], + ) + .expect("failover routes resolve"); + let route_info: Vec<(&str, &str)> = routes + .iter() + .map(|route| (route.provider_alias(), route.model())) + .collect(); + + assert_eq!( + route_info, + vec![ + ("ANTHROPIC_API_KEY", model), + ("GOOGLE_DEFAULT_CREDENTIALS", vertex_model), + ] + ); + } } From 045daf334b0d6469958dd4531362a68eca83d76d Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Fri, 12 Jun 2026 16:57:59 -0500 Subject: [PATCH 02/28] fixes --- .../braintrust-llm-router/src/catalog/mod.rs | 2 +- .../src/catalog/resolver.rs | 13 +- crates/braintrust-llm-router/src/router.rs | 122 +++++++++++++++++- 3 files changed, 126 insertions(+), 11 deletions(-) diff --git a/crates/braintrust-llm-router/src/catalog/mod.rs b/crates/braintrust-llm-router/src/catalog/mod.rs index 21ad45cf..730afda8 100644 --- a/crates/braintrust-llm-router/src/catalog/mod.rs +++ b/crates/braintrust-llm-router/src/catalog/mod.rs @@ -42,7 +42,7 @@ pub struct OverlayModelCatalog { #[derive(Debug, Clone)] pub enum CatalogResolver { Base(Arc), - Overlay(OverlayModelCatalog), + Overlay(Box), } impl CatalogResolver { diff --git a/crates/braintrust-llm-router/src/catalog/resolver.rs b/crates/braintrust-llm-router/src/catalog/resolver.rs index 8d115861..b1bcef72 100644 --- a/crates/braintrust-llm-router/src/catalog/resolver.rs +++ b/crates/braintrust-llm-router/src/catalog/resolver.rs @@ -5,6 +5,8 @@ use crate::catalog::{CatalogResolver, ModelCatalog, ModelSpec, OverlayModelCatal use crate::error::{Error, Result}; use lingua::ProviderFormat; +pub type ResolvedModel = (Arc, ProviderFormat, Vec); + #[derive(Debug, Clone)] pub struct ModelResolver { catalog: CatalogResolver, @@ -25,7 +27,7 @@ impl ModelResolver { /// do not observe per-request custom models as global catalog entries. pub fn with_overlay(base: Arc, custom: ModelCatalog) -> Self { Self { - catalog: CatalogResolver::Overlay(OverlayModelCatalog { base, custom }), + catalog: CatalogResolver::Overlay(Box::new(OverlayModelCatalog { base, custom })), aliases: HashMap::new(), } } @@ -39,14 +41,11 @@ impl ModelResolver { self.catalog.base_catalog() } - pub fn resolve(&self, model: &str) -> Result<(Arc, ProviderFormat, Vec)> { + pub fn resolve(&self, model: &str) -> Result { self.resolve_one(model) } - pub fn resolve_for_failover( - &self, - model: &str, - ) -> Result, ProviderFormat, Vec)>> { + pub fn resolve_for_failover(&self, model: &str) -> Result> { let mut resolved = Vec::new(); for model_name in self.catalog.equivalent_model_names(model) { resolved.push(self.resolve_one(&model_name)?); @@ -57,7 +56,7 @@ impl ModelResolver { Ok(resolved) } - fn resolve_one(&self, model: &str) -> Result<(Arc, ProviderFormat, Vec)> { + fn resolve_one(&self, model: &str) -> Result { let spec = self .catalog .get(model) diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index c2b27d94..4ae4ff38 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -707,9 +707,22 @@ impl Router { return true; } - self.providers - .get(provider_alias) - .is_some_and(|provider| provider.id() == resolver_alias) + if resolver_alias != "openai" + && self + .providers + .get(provider_alias) + .is_some_and(|provider| provider.id() == resolver_alias) + { + return true; + } + + matches!( + (resolver_alias, provider_alias), + ("bedrock", "AWS_DEFAULT_CREDENTIALS") + | ("databricks", "DATABRICKS_DEFAULT_CREDENTIALS") + | ("vertex", "GOOGLE_DEFAULT_CREDENTIALS") + | ("azure", "AZURE_DEFAULT_CREDENTIALS") + ) } #[cfg(test)] @@ -2629,6 +2642,109 @@ mod tests { ); } + #[test] + fn fallback_provider_routes_do_not_treat_openai_provider_id_as_allowlist_match() { + let model = "gpt-4o"; + let mut catalog = ModelCatalog::empty(); + let mut spec = openai_spec(model, ModelFlavor::Chat); + spec.available_providers = vec!["openai".into()]; + catalog.insert(model.into(), spec); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "openai", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![ProviderFormat::ChatCompletions], + ) + .add_provider( + "cerebras", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + assert_eq!( + explicit_route_aliases( + &router, + model, + ProviderFormat::ChatCompletions, + &["cerebras"] + ) + .expect("routes"), + vec!["openai".to_string()] + ); + } + + #[test] + fn failover_routes_match_named_secrets_by_concrete_provider_id() { + let model = "claude-sonnet-4-6"; + let vertex_model = "publishers/anthropic/models/claude-sonnet-4-6"; + let catalog = ModelCatalog::from_json_str( + r#"{ + "claude-sonnet-4-6": { + "format": "anthropic", + "flavor": "chat", + "available_providers": ["ANTHROPIC_API_KEY", "anthropic"], + "equivalent_models": ["publishers/anthropic/models/claude-sonnet-4-6"] + }, + "publishers/anthropic/models/claude-sonnet-4-6": { + "format": "anthropic", + "flavor": "chat", + "available_providers": ["GOOGLE_DEFAULT_CREDENTIALS", "vertex"] + } +}"#, + ) + .expect("catalog parses"); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "my-anthropic", + FakeProvider { + name: "anthropic", + formats: vec![ProviderFormat::Anthropic], + }, + dummy_auth(), + vec![], + ) + .add_provider( + "my-vertex", + FakeProvider { + name: "vertex", + formats: vec![ProviderFormat::VertexAnthropic], + }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_provider_routes( + model, + ProviderFormat::ChatCompletions, + &["my-anthropic".to_string(), "my-vertex".to_string()], + ) + .expect("routes resolve"); + let route_info: Vec<(&str, &str)> = routes + .iter() + .map(|route| (route.provider_alias(), route.model())) + .collect(); + + assert_eq!( + route_info, + vec![("my-anthropic", model), ("my-vertex", vertex_model),] + ); + } + #[test] fn fallback_provider_routes_do_not_use_format_default_for_ineligible_alias() { let model = "gpt-4o"; From 6a15bd0a412cfd79e27fcc5ffd01040dd71a10b9 Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Fri, 12 Jun 2026 18:12:24 -0500 Subject: [PATCH 03/28] cleanpu provider alias stuff --- crates/braintrust-llm-router/src/router.rs | 34 +++++++++++++--------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 4ae4ff38..918315f5 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -707,22 +707,15 @@ impl Router { return true; } - if resolver_alias != "openai" - && self + if let Some(provider_id) = default_alias_provider_id(resolver_alias) { + return self .providers .get(provider_alias) - .is_some_and(|provider| provider.id() == resolver_alias) - { - return true; + .is_some_and(|provider| provider.id() == provider_id); } - matches!( - (resolver_alias, provider_alias), - ("bedrock", "AWS_DEFAULT_CREDENTIALS") - | ("databricks", "DATABRICKS_DEFAULT_CREDENTIALS") - | ("vertex", "GOOGLE_DEFAULT_CREDENTIALS") - | ("azure", "AZURE_DEFAULT_CREDENTIALS") - ) + default_alias_provider_id(provider_alias) + .is_some_and(|provider_id| provider_id == resolver_alias) } #[cfg(test)] @@ -916,6 +909,19 @@ impl Router { } } +fn default_alias_provider_id(alias: &str) -> Option<&'static str> { + match alias { + "ANTHROPIC_API_KEY" => Some("anthropic"), + "GEMINI_API_KEY" => Some("google"), + "MISTRAL_API_KEY" => Some("mistral"), + "AWS_DEFAULT_CREDENTIALS" => Some("bedrock"), + "GOOGLE_DEFAULT_CREDENTIALS" => Some("vertex"), + "AZURE_DEFAULT_CREDENTIALS" => Some("azure"), + "DATABRICKS_DEFAULT_CREDENTIALS" => Some("databricks"), + _ => None, + } +} + /// One provider registration: alias, provider, auth, and default formats. struct ProviderEntry { alias: String, @@ -2693,13 +2699,13 @@ mod tests { "claude-sonnet-4-6": { "format": "anthropic", "flavor": "chat", - "available_providers": ["ANTHROPIC_API_KEY", "anthropic"], + "available_providers": ["ANTHROPIC_API_KEY"], "equivalent_models": ["publishers/anthropic/models/claude-sonnet-4-6"] }, "publishers/anthropic/models/claude-sonnet-4-6": { "format": "anthropic", "flavor": "chat", - "available_providers": ["GOOGLE_DEFAULT_CREDENTIALS", "vertex"] + "available_providers": ["GOOGLE_DEFAULT_CREDENTIALS"] } }"#, ) From b0d9bdddf70eaa487c229397937b83801b6fb8c5 Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Mon, 15 Jun 2026 14:14:56 -0400 Subject: [PATCH 04/28] allow custom models to specify equivalent models --- .../braintrust-llm-router/src/catalog/mod.rs | 86 +++++++++++++++++++ crates/braintrust-llm-router/src/router.rs | 61 +++++++++++++ 2 files changed, 147 insertions(+) diff --git a/crates/braintrust-llm-router/src/catalog/mod.rs b/crates/braintrust-llm-router/src/catalog/mod.rs index 730afda8..e20e0986 100644 --- a/crates/braintrust-llm-router/src/catalog/mod.rs +++ b/crates/braintrust-llm-router/src/catalog/mod.rs @@ -238,6 +238,39 @@ impl ModelCatalog { } } + pub fn add_equivalent_models(&mut self, name: String, equivalents: I) -> Result<()> + where + I: IntoIterator, + { + if !self.models.contains_key(&name) { + return Err(Error::InvalidRequest(format!( + "model '{name}' references equivalent_models but is missing from catalog" + ))); + } + + let equivalents: Vec = equivalents + .into_iter() + .filter(|equivalent_model| !equivalent_model.is_empty()) + .collect(); + for equivalent_model in &equivalents { + if !self.models.contains_key(equivalent_model) { + return Err(Error::InvalidRequest(format!( + "model '{name}' references missing equivalent model '{equivalent_model}'" + ))); + } + } + + let entry = self.equivalent_models.entry(name).or_default(); + for equivalent_model in equivalents { + if entry.contains(&equivalent_model) { + continue; + } + entry.push(equivalent_model); + } + self.rebuild_equivalence_index(); + Ok(()) + } + fn validate_equivalent_models(&self) -> Result<()> { for (name, equivalents) in &self.equivalent_models { for equivalent_model in equivalents { @@ -419,6 +452,59 @@ mod tests { assert!(matches!(error, Error::InvalidRequest(_))); } + #[test] + fn add_equivalent_models_rebuilds_index() { + let mut catalog = ModelCatalog::from_json_str( + r#"{ + "model-a": { + "format": "openai", + "flavor": "chat" + }, + "model-b": { + "format": "openai", + "flavor": "chat" + } +}"#, + ) + .expect("catalog parses"); + + catalog + .add_equivalent_models("model-a".to_string(), vec!["model-b".to_string()]) + .expect("equivalence is valid"); + + assert_eq!( + catalog.equivalent_model_names("model-a"), + vec!["model-a".to_string(), "model-b".to_string()] + ); + assert_eq!( + catalog.equivalent_model_names("model-b"), + vec!["model-b".to_string(), "model-a".to_string()] + ); + } + + #[test] + fn add_equivalent_models_rejects_missing_reference() { + let mut catalog = ModelCatalog::from_json_str( + r#"{ + "model-a": { + "format": "openai", + "flavor": "chat" + } +}"#, + ) + .expect("catalog parses"); + + let error = catalog + .add_equivalent_models("model-a".to_string(), vec!["missing".to_string()]) + .expect_err("missing equivalent model should fail"); + + assert!(matches!(error, Error::InvalidRequest(_))); + assert_eq!( + catalog.equivalent_model_names("model-a"), + vec!["model-a".to_string()] + ); + } + #[test] fn map_specs_preserves_equivalent_model_index() { let catalog = ModelCatalog::from_json_str( diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 918315f5..081d1ce5 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -2535,6 +2535,67 @@ mod tests { .is_ok()); } + #[test] + fn overlay_catalog_failover_routes_use_equivalent_custom_model() { + let base = ModelCatalog::empty(); + let mut custom = ModelCatalog::empty(); + let mut primary = openai_spec("custom-primary", ModelFlavor::Chat); + primary.available_providers = vec!["provider-a".to_string()]; + let mut fallback = openai_spec("custom-fallback", ModelFlavor::Chat); + fallback.available_providers = vec!["provider-b".to_string()]; + custom.insert("custom-primary".into(), primary); + custom.insert("custom-fallback".into(), fallback); + custom + .add_equivalent_models( + "custom-primary".to_string(), + vec!["custom-fallback".to_string()], + ) + .expect("equivalence is valid"); + + let router = Router::builder() + .with_overlay_catalog(Arc::new(base), custom) + .add_provider( + "provider-a", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .add_provider( + "provider-b", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_provider_routes( + "custom-primary", + ProviderFormat::ChatCompletions, + &["provider-a".to_string(), "provider-b".to_string()], + ) + .expect("failover routes resolve"); + let route_info: Vec<(&str, &str)> = routes + .iter() + .map(|route| (route.provider_alias(), route.model())) + .collect(); + + assert_eq!( + route_info, + vec![ + ("provider-a", "custom-primary"), + ("provider-b", "custom-fallback"), + ] + ); + } + #[test] fn resolved_aliases_returns_only_registered_available_providers() { let model = "gpt-4o"; From f46669b5fc901abe565121446763824b72392b78 Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Mon, 15 Jun 2026 16:57:38 -0400 Subject: [PATCH 05/28] merge base and custom catalog, this is getting too complicated to have seperate ones --- .../braintrust-llm-router/src/catalog/mod.rs | 169 ++++++++++++------ .../src/catalog/resolver.rs | 29 ++- crates/braintrust-llm-router/src/router.rs | 31 +--- 3 files changed, 130 insertions(+), 99 deletions(-) diff --git a/crates/braintrust-llm-router/src/catalog/mod.rs b/crates/braintrust-llm-router/src/catalog/mod.rs index e20e0986..047008e5 100644 --- a/crates/braintrust-llm-router/src/catalog/mod.rs +++ b/crates/braintrust-llm-router/src/catalog/mod.rs @@ -7,7 +7,7 @@ pub use spec::{ModelFlavor, ModelSpec}; use lingua::ProviderFormat; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fs::File; use std::io::Read; use std::path::Path; @@ -24,62 +24,6 @@ pub struct ModelCatalog { equivalence_index: HashMap>, } -/// A request-local catalog overlay. -/// -/// Secret-defined custom models live in `custom` and shadow entries in the -/// shared `base` catalog. This avoids cloning the base catalog when adding -/// per-request model definitions. -#[derive(Debug, Clone)] -pub struct OverlayModelCatalog { - pub base: Arc, - pub custom: ModelCatalog, -} - -/// Catalog view used by the router resolver. -/// -/// `Base` preserves the existing router behavior. `Overlay` checks custom -/// models first and then falls back to the shared base catalog. -#[derive(Debug, Clone)] -pub enum CatalogResolver { - Base(Arc), - Overlay(Box), -} - -impl CatalogResolver { - pub fn base_catalog(&self) -> Arc { - match self { - Self::Base(catalog) => Arc::clone(catalog), - Self::Overlay(overlay) => Arc::clone(&overlay.base), - } - } - - pub fn get(&self, name: &str) -> Option> { - match self { - Self::Base(catalog) => catalog.get(name), - Self::Overlay(overlay) => overlay.custom.get(name).or_else(|| overlay.base.get(name)), - } - } - - pub fn equivalent_model_names(&self, name: &str) -> Vec { - match self { - Self::Base(catalog) => catalog.equivalent_model_names(name), - Self::Overlay(overlay) => { - if overlay.custom.get(name).is_some() { - overlay.custom.equivalent_model_names(name) - } else { - overlay.base.equivalent_model_names(name) - } - } - } - } -} - -impl From> for CatalogResolver { - fn from(catalog: Arc) -> Self { - Self::Base(catalog) - } -} - fn parse_equivalent_models(content: &str) -> Result>> { let raw: HashMap = serde_json::from_str(content)?; let mut equivalent_models = HashMap::new(); @@ -216,6 +160,37 @@ impl ModelCatalog { out } + /// Merge request-local custom models into this catalog. + /// + /// Custom model specs shadow base specs with the same name. Any base + /// equivalence edges involving shadowed names are removed so custom models + /// do not inherit equivalence relationships from the model they replace. + pub fn merge_custom_catalog(&self, custom: ModelCatalog) -> Self { + let mut merged = self.clone(); + let ModelCatalog { + models, + equivalent_models, + .. + } = custom; + let custom_model_names: HashSet = models.keys().cloned().collect(); + + for name in &custom_model_names { + merged.equivalent_models.remove(name); + } + for equivalents in merged.equivalent_models.values_mut() { + equivalents.retain(|name| !custom_model_names.contains(name)); + } + + for (name, spec) in models { + merged.insert(name, spec.as_ref().clone()); + } + for (name, equivalents) in equivalent_models { + merged.equivalent_models.insert(name, equivalents); + } + merged.rebuild_equivalence_index(); + merged + } + pub fn len(&self) -> usize { self.models.len() } @@ -225,6 +200,32 @@ impl ModelCatalog { } pub fn insert(&mut self, name: String, mut spec: ModelSpec) { + if let Some(existing) = self.models.get(&name).cloned() { + let format = existing.format; + let parent = existing.parent.clone(); + if let Some(models) = self.by_format.get_mut(&format) { + models.retain(|model| model != &name); + } + if self + .by_format + .get(&format) + .is_some_and(|models| models.is_empty()) + { + self.by_format.remove(&format); + } + if let Some(parent) = parent { + if let Some(models) = self.by_parent.get_mut(&parent) { + models.retain(|model| model != &name); + } + if self + .by_parent + .get(&parent) + .is_some_and(|models| models.is_empty()) + { + self.by_parent.remove(&parent); + } + } + } if spec.model.is_empty() { spec.model = name.clone(); } @@ -533,4 +534,58 @@ mod tests { vec!["model-a".to_string(), "model-b".to_string()] ); } + + #[test] + fn merge_custom_catalog_shadows_base_specs_and_equivalence() { + let base = ModelCatalog::from_json_str( + r#"{ + "model-a": { + "format": "anthropic", + "flavor": "chat", + "equivalent_models": ["model-b"] + }, + "model-b": { + "format": "anthropic", + "flavor": "chat" + } +}"#, + ) + .expect("base catalog parses"); + let custom = ModelCatalog::from_json_str( + r#"{ + "model-a": { + "format": "openai", + "flavor": "chat", + "available_providers": ["custom-provider"] + } +}"#, + ) + .expect("custom catalog parses"); + + let merged = base.merge_custom_catalog(custom); + + assert_eq!( + merged + .get("model-a") + .expect("custom model") + .available_providers, + vec!["custom-provider".to_string()] + ); + assert_eq!( + merged.equivalent_model_names("model-a"), + vec!["model-a".to_string()] + ); + assert_eq!( + merged.equivalent_model_names("model-b"), + vec!["model-b".to_string()] + ); + assert_eq!( + merged.models_for_format(ProviderFormat::Anthropic), + Some(&["model-b".to_string()][..]) + ); + assert_eq!( + merged.models_for_format(ProviderFormat::ChatCompletions), + Some(&["model-a".to_string()][..]) + ); + } } diff --git a/crates/braintrust-llm-router/src/catalog/resolver.rs b/crates/braintrust-llm-router/src/catalog/resolver.rs index b1bcef72..73a7c3fb 100644 --- a/crates/braintrust-llm-router/src/catalog/resolver.rs +++ b/crates/braintrust-llm-router/src/catalog/resolver.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::sync::Arc; -use crate::catalog::{CatalogResolver, ModelCatalog, ModelSpec, OverlayModelCatalog}; +use crate::catalog::{ModelCatalog, ModelSpec}; use crate::error::{Error, Result}; use lingua::ProviderFormat; @@ -9,25 +9,14 @@ pub type ResolvedModel = (Arc, ProviderFormat, Vec); #[derive(Debug, Clone)] pub struct ModelResolver { - catalog: CatalogResolver, + catalog: Arc, aliases: HashMap, } impl ModelResolver { pub fn new(catalog: Arc) -> Self { Self { - catalog: catalog.into(), - aliases: HashMap::new(), - } - } - - /// Resolve models against custom entries first, then the shared base catalog. - /// - /// The base catalog remains the public `catalog()` view so existing callers - /// do not observe per-request custom models as global catalog entries. - pub fn with_overlay(base: Arc, custom: ModelCatalog) -> Self { - Self { - catalog: CatalogResolver::Overlay(Box::new(OverlayModelCatalog { base, custom })), + catalog, aliases: HashMap::new(), } } @@ -38,7 +27,7 @@ impl ModelResolver { } pub fn catalog(&self) -> Arc { - self.catalog.base_catalog() + Arc::clone(&self.catalog) } pub fn resolve(&self, model: &str) -> Result { @@ -182,7 +171,7 @@ mod tests { } #[test] - fn resolve_overlay_custom_model_before_base_catalog() { + fn resolve_merged_catalog_custom_model_before_base_catalog() { let mut base = ModelCatalog::empty(); base.insert( "model".into(), @@ -197,7 +186,8 @@ mod tests { vec!["custom-provider".into()], ), ); - let resolver = ModelResolver::with_overlay(Arc::new(base), custom); + let merged = base.merge_custom_catalog(custom); + let resolver = ModelResolver::new(Arc::new(merged)); let (spec, format, aliases) = resolver.resolve("model").expect("resolves"); assert_eq!(spec.model, "custom-model"); @@ -206,7 +196,7 @@ mod tests { } #[test] - fn resolve_overlay_falls_back_to_base_catalog() { + fn resolve_merged_catalog_keeps_base_catalog_models() { let mut base = ModelCatalog::empty(); base.insert( "base-model".into(), @@ -217,7 +207,8 @@ mod tests { "custom-model".into(), spec("custom-model", ProviderFormat::ChatCompletions), ); - let resolver = ModelResolver::with_overlay(Arc::new(base), custom); + let merged = base.merge_custom_catalog(custom); + let resolver = ModelResolver::new(Arc::new(merged)); let (spec, format, aliases) = resolver.resolve("base-model").expect("resolves"); assert_eq!(spec.model, "base-model"); diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 081d1ce5..d6937904 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -932,7 +932,6 @@ struct ProviderEntry { pub struct RouterBuilder { catalog: Option>, - custom_catalog: Option, provider_entries: Vec, retry_policy: RetryPolicy, } @@ -947,7 +946,6 @@ impl RouterBuilder { pub fn new() -> Self { Self { catalog: None, - custom_catalog: None, provider_entries: Vec::new(), retry_policy: RetryPolicy::default(), } @@ -956,23 +954,11 @@ impl RouterBuilder { pub fn load_models(mut self, path: impl AsRef) -> Result { let catalog = load_catalog_from_disk(path)?; self.catalog = Some(catalog); - self.custom_catalog = None; Ok(self) } pub fn with_catalog(mut self, catalog: Arc) -> Self { self.catalog = Some(catalog); - self.custom_catalog = None; - self - } - - /// Configure the router with custom models overlaid on a shared base catalog. - /// - /// Custom entries shadow base entries for resolution, while `Router::catalog()` - /// continues to expose the base catalog for compatibility. - pub fn with_overlay_catalog(mut self, base: Arc, custom: ModelCatalog) -> Self { - self.catalog = Some(base); - self.custom_catalog = Some(custom); self } @@ -1021,10 +1007,7 @@ impl RouterBuilder { let catalog = self .catalog .ok_or_else(|| Error::InvalidRequest("model catalog not configured".into()))?; - let resolver = match self.custom_catalog { - Some(custom) => ModelResolver::with_overlay(Arc::clone(&catalog), custom), - None => ModelResolver::new(Arc::clone(&catalog)), - }; + let resolver = ModelResolver::new(Arc::clone(&catalog)); let mut providers = HashMap::new(); let mut auth_configs = HashMap::new(); @@ -2499,7 +2482,7 @@ mod tests { } #[test] - fn overlay_catalog_resolves_custom_and_base_models() { + fn merged_catalog_resolves_custom_and_base_models() { let mut base = ModelCatalog::empty(); base.insert( "base-model".into(), @@ -2510,9 +2493,10 @@ mod tests { "custom-model".into(), openai_spec_with_available_providers("custom-model", ModelFlavor::Chat), ); + let catalog = base.merge_custom_catalog(custom); let router = Router::builder() - .with_overlay_catalog(Arc::new(base), custom) + .with_catalog(Arc::new(catalog)) .add_provider( "openai", FakeProvider { @@ -2526,7 +2510,7 @@ mod tests { .expect("router builds"); assert!(router.catalog().get("base-model").is_some()); - assert!(router.catalog().get("custom-model").is_none()); + assert!(router.catalog().get("custom-model").is_some()); assert!(router .resolve_provider_routes("base-model", ProviderFormat::ChatCompletions, &[]) .is_ok()); @@ -2536,7 +2520,7 @@ mod tests { } #[test] - fn overlay_catalog_failover_routes_use_equivalent_custom_model() { + fn merged_catalog_failover_routes_use_equivalent_custom_model() { let base = ModelCatalog::empty(); let mut custom = ModelCatalog::empty(); let mut primary = openai_spec("custom-primary", ModelFlavor::Chat); @@ -2551,9 +2535,10 @@ mod tests { vec!["custom-fallback".to_string()], ) .expect("equivalence is valid"); + let catalog = base.merge_custom_catalog(custom); let router = Router::builder() - .with_overlay_catalog(Arc::new(base), custom) + .with_catalog(Arc::new(catalog)) .add_provider( "provider-a", FakeProvider { From 5c9b327b96ba66518e6ae8ce9772924339ea79ba Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Mon, 15 Jun 2026 17:48:42 -0400 Subject: [PATCH 06/28] comments --- .../src/providers/mod.rs | 5 + .../src/providers/openai.rs | 11 + crates/braintrust-llm-router/src/router.rs | 214 +++++++++++++++++- 3 files changed, 222 insertions(+), 8 deletions(-) diff --git a/crates/braintrust-llm-router/src/providers/mod.rs b/crates/braintrust-llm-router/src/providers/mod.rs index f9f84872..051e4ea9 100644 --- a/crates/braintrust-llm-router/src/providers/mod.rs +++ b/crates/braintrust-llm-router/src/providers/mod.rs @@ -224,6 +224,11 @@ pub trait Provider: Send + Sync { /// Provider identifier (e.g., "openai", "anthropic"). fn id(&self) -> &'static str; + /// Whether this provider registration satisfies a catalog provider alias. + fn matches_provider_alias(&self, alias: &str) -> bool { + self.id() == alias + } + /// All formats this provider can handle. fn provider_formats(&self) -> Vec; diff --git a/crates/braintrust-llm-router/src/providers/openai.rs b/crates/braintrust-llm-router/src/providers/openai.rs index 21c3fba2..0b62b6e2 100644 --- a/crates/braintrust-llm-router/src/providers/openai.rs +++ b/crates/braintrust-llm-router/src/providers/openai.rs @@ -43,6 +43,7 @@ pub struct OpenAIProvider { client: ClientWithMiddleware, config: OpenAIConfig, endpoint_template: Option, + provider_alias: String, } impl OpenAIProvider { @@ -63,9 +64,15 @@ impl OpenAIProvider { client, endpoint_template, config, + provider_alias: "openai".to_string(), }) } + pub fn with_provider_alias(mut self, provider_alias: impl Into) -> Self { + self.provider_alias = provider_alias.into(); + self + } + /// Create an OpenAI provider from configuration parameters. /// /// Extracts OpenAI-specific options from metadata: @@ -182,6 +189,10 @@ impl crate::providers::Provider for OpenAIProvider { "openai" } + fn matches_provider_alias(&self, alias: &str) -> bool { + self.provider_alias == alias + } + fn provider_formats(&self) -> Vec { vec![ProviderFormat::ChatCompletions, ProviderFormat::Responses] } diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index d6937904..708fcb3a 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -124,13 +124,16 @@ pub fn create_provider( timeout, client_settings, )?)), - kind if is_openai_compatible(kind) => Ok(Arc::new(OpenAIProvider::from_config( - endpoint, - endpoint_template, - timeout, - metadata, - client_settings, - )?)), + kind if is_openai_compatible(kind) => Ok(Arc::new( + OpenAIProvider::from_config( + endpoint, + endpoint_template, + timeout, + metadata, + client_settings, + )? + .with_provider_alias(kind), + )), other => Err(Error::InvalidRequest(format!( "unsupported provider kind: {other}" ))), @@ -711,7 +714,14 @@ impl Router { return self .providers .get(provider_alias) - .is_some_and(|provider| provider.id() == provider_id); + .is_some_and(|provider| provider.matches_provider_alias(provider_id)); + } + + if concrete_provider_id_alias(resolver_alias) { + return self + .providers + .get(provider_alias) + .is_some_and(|provider| provider.matches_provider_alias(resolver_alias)); } default_alias_provider_id(provider_alias) @@ -911,6 +921,7 @@ impl Router { fn default_alias_provider_id(alias: &str) -> Option<&'static str> { match alias { + "OPENAI_API_KEY" => Some("openai"), "ANTHROPIC_API_KEY" => Some("anthropic"), "GEMINI_API_KEY" => Some("google"), "MISTRAL_API_KEY" => Some("mistral"), @@ -922,6 +933,13 @@ fn default_alias_provider_id(alias: &str) -> Option<&'static str> { } } +fn concrete_provider_id_alias(alias: &str) -> bool { + matches!( + alias, + "anthropic" | "google" | "mistral" | "bedrock" | "vertex" | "azure" | "databricks" + ) +} + /// One provider registration: alias, provider, auth, and default formats. struct ProviderEntry { alias: String, @@ -1167,6 +1185,51 @@ mod tests { } } + struct FakeOpenAICompatibleProvider { + alias: &'static str, + } + + #[async_trait] + impl Provider for FakeOpenAICompatibleProvider { + fn id(&self) -> &'static str { + "openai" + } + + fn matches_provider_alias(&self, alias: &str) -> bool { + self.alias == alias + } + + fn provider_formats(&self) -> Vec { + vec![ProviderFormat::ChatCompletions] + } + + async fn complete( + &self, + _payload: Bytes, + _auth: &AuthConfig, + _spec: &ModelSpec, + _format: ProviderFormat, + _client_headers: &ClientHeaders, + ) -> Result { + Ok(Bytes::from("{}")) + } + + async fn complete_stream( + &self, + _payload: Bytes, + _auth: &AuthConfig, + _spec: &ModelSpec, + _format: ProviderFormat, + _client_headers: &ClientHeaders, + ) -> Result { + unimplemented!() + } + + async fn health_check(&self, _auth: &AuthConfig) -> Result<()> { + Ok(()) + } + } + fn google_spec(model: &str) -> ModelSpec { ModelSpec { model: model.to_string(), @@ -2736,6 +2799,82 @@ mod tests { ); } + #[test] + fn fallback_provider_routes_match_named_openai_secret_for_default_alias() { + let model = "gpt-4o"; + let mut catalog = ModelCatalog::empty(); + let mut spec = openai_spec(model, ModelFlavor::Chat); + spec.available_providers = vec!["OPENAI_API_KEY".into()]; + catalog.insert(model.into(), spec); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "my-openai", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + assert_eq!( + explicit_route_aliases( + &router, + model, + ProviderFormat::ChatCompletions, + &["my-openai"] + ) + .expect("routes"), + vec!["my-openai".to_string()] + ); + } + + #[test] + fn fallback_provider_routes_do_not_match_openai_default_alias_to_compatible_provider() { + let model = "gpt-4o"; + let mut catalog = ModelCatalog::empty(); + let mut spec = openai_spec(model, ModelFlavor::Chat); + spec.available_providers = vec!["OPENAI_API_KEY".into()]; + catalog.insert(model.into(), spec); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "openai", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![ProviderFormat::ChatCompletions], + ) + .add_provider( + "cerebras", + FakeOpenAICompatibleProvider { alias: "cerebras" }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let err = match explicit_route_aliases( + &router, + model, + ProviderFormat::ChatCompletions, + &["cerebras"], + ) { + Ok(_) => panic!("OpenAI-compatible provider should not satisfy OpenAI default alias"), + Err(err) => err, + }; + + assert!(matches!( + err, + Error::NoProvider(ProviderFormat::ChatCompletions) + )); + } + #[test] fn failover_routes_match_named_secrets_by_concrete_provider_id() { let model = "claude-sonnet-4-6"; @@ -2797,6 +2936,65 @@ mod tests { ); } + #[test] + fn failover_routes_match_named_secrets_when_equivalents_omit_available_providers() { + let model = "claude-sonnet-4-6"; + let vertex_model = "publishers/anthropic/models/claude-sonnet-4-6"; + let catalog = ModelCatalog::from_json_str( + r#"{ + "claude-sonnet-4-6": { + "format": "anthropic", + "flavor": "chat", + "equivalent_models": ["publishers/anthropic/models/claude-sonnet-4-6"] + }, + "publishers/anthropic/models/claude-sonnet-4-6": { + "format": "anthropic", + "flavor": "chat" + } +}"#, + ) + .expect("catalog parses"); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "my-anthropic", + FakeProvider { + name: "anthropic", + formats: vec![ProviderFormat::Anthropic], + }, + dummy_auth(), + vec![], + ) + .add_provider( + "my-vertex", + FakeProvider { + name: "vertex", + formats: vec![ProviderFormat::VertexAnthropic], + }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_provider_routes( + model, + ProviderFormat::ChatCompletions, + &["my-anthropic".to_string(), "my-vertex".to_string()], + ) + .expect("routes resolve"); + let route_info: Vec<(&str, &str)> = routes + .iter() + .map(|route| (route.provider_alias(), route.model())) + .collect(); + + assert_eq!( + route_info, + vec![("my-anthropic", model), ("my-vertex", vertex_model),] + ); + } + #[test] fn fallback_provider_routes_do_not_use_format_default_for_ineligible_alias() { let model = "gpt-4o"; From c8458d135d11be6a590d7c000b6714b15ecc6f2a Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Mon, 15 Jun 2026 17:54:08 -0400 Subject: [PATCH 07/28] Revert "merge base and custom catalog, this is getting too complicated to have seperate ones" This reverts commit f46669b5fc901abe565121446763824b72392b78. --- .../braintrust-llm-router/src/catalog/mod.rs | 169 ++++++------------ .../src/catalog/resolver.rs | 29 +-- crates/braintrust-llm-router/src/router.rs | 31 +++- 3 files changed, 99 insertions(+), 130 deletions(-) diff --git a/crates/braintrust-llm-router/src/catalog/mod.rs b/crates/braintrust-llm-router/src/catalog/mod.rs index 047008e5..e20e0986 100644 --- a/crates/braintrust-llm-router/src/catalog/mod.rs +++ b/crates/braintrust-llm-router/src/catalog/mod.rs @@ -7,7 +7,7 @@ pub use spec::{ModelFlavor, ModelSpec}; use lingua::ProviderFormat; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::fs::File; use std::io::Read; use std::path::Path; @@ -24,6 +24,62 @@ pub struct ModelCatalog { equivalence_index: HashMap>, } +/// A request-local catalog overlay. +/// +/// Secret-defined custom models live in `custom` and shadow entries in the +/// shared `base` catalog. This avoids cloning the base catalog when adding +/// per-request model definitions. +#[derive(Debug, Clone)] +pub struct OverlayModelCatalog { + pub base: Arc, + pub custom: ModelCatalog, +} + +/// Catalog view used by the router resolver. +/// +/// `Base` preserves the existing router behavior. `Overlay` checks custom +/// models first and then falls back to the shared base catalog. +#[derive(Debug, Clone)] +pub enum CatalogResolver { + Base(Arc), + Overlay(Box), +} + +impl CatalogResolver { + pub fn base_catalog(&self) -> Arc { + match self { + Self::Base(catalog) => Arc::clone(catalog), + Self::Overlay(overlay) => Arc::clone(&overlay.base), + } + } + + pub fn get(&self, name: &str) -> Option> { + match self { + Self::Base(catalog) => catalog.get(name), + Self::Overlay(overlay) => overlay.custom.get(name).or_else(|| overlay.base.get(name)), + } + } + + pub fn equivalent_model_names(&self, name: &str) -> Vec { + match self { + Self::Base(catalog) => catalog.equivalent_model_names(name), + Self::Overlay(overlay) => { + if overlay.custom.get(name).is_some() { + overlay.custom.equivalent_model_names(name) + } else { + overlay.base.equivalent_model_names(name) + } + } + } + } +} + +impl From> for CatalogResolver { + fn from(catalog: Arc) -> Self { + Self::Base(catalog) + } +} + fn parse_equivalent_models(content: &str) -> Result>> { let raw: HashMap = serde_json::from_str(content)?; let mut equivalent_models = HashMap::new(); @@ -160,37 +216,6 @@ impl ModelCatalog { out } - /// Merge request-local custom models into this catalog. - /// - /// Custom model specs shadow base specs with the same name. Any base - /// equivalence edges involving shadowed names are removed so custom models - /// do not inherit equivalence relationships from the model they replace. - pub fn merge_custom_catalog(&self, custom: ModelCatalog) -> Self { - let mut merged = self.clone(); - let ModelCatalog { - models, - equivalent_models, - .. - } = custom; - let custom_model_names: HashSet = models.keys().cloned().collect(); - - for name in &custom_model_names { - merged.equivalent_models.remove(name); - } - for equivalents in merged.equivalent_models.values_mut() { - equivalents.retain(|name| !custom_model_names.contains(name)); - } - - for (name, spec) in models { - merged.insert(name, spec.as_ref().clone()); - } - for (name, equivalents) in equivalent_models { - merged.equivalent_models.insert(name, equivalents); - } - merged.rebuild_equivalence_index(); - merged - } - pub fn len(&self) -> usize { self.models.len() } @@ -200,32 +225,6 @@ impl ModelCatalog { } pub fn insert(&mut self, name: String, mut spec: ModelSpec) { - if let Some(existing) = self.models.get(&name).cloned() { - let format = existing.format; - let parent = existing.parent.clone(); - if let Some(models) = self.by_format.get_mut(&format) { - models.retain(|model| model != &name); - } - if self - .by_format - .get(&format) - .is_some_and(|models| models.is_empty()) - { - self.by_format.remove(&format); - } - if let Some(parent) = parent { - if let Some(models) = self.by_parent.get_mut(&parent) { - models.retain(|model| model != &name); - } - if self - .by_parent - .get(&parent) - .is_some_and(|models| models.is_empty()) - { - self.by_parent.remove(&parent); - } - } - } if spec.model.is_empty() { spec.model = name.clone(); } @@ -534,58 +533,4 @@ mod tests { vec!["model-a".to_string(), "model-b".to_string()] ); } - - #[test] - fn merge_custom_catalog_shadows_base_specs_and_equivalence() { - let base = ModelCatalog::from_json_str( - r#"{ - "model-a": { - "format": "anthropic", - "flavor": "chat", - "equivalent_models": ["model-b"] - }, - "model-b": { - "format": "anthropic", - "flavor": "chat" - } -}"#, - ) - .expect("base catalog parses"); - let custom = ModelCatalog::from_json_str( - r#"{ - "model-a": { - "format": "openai", - "flavor": "chat", - "available_providers": ["custom-provider"] - } -}"#, - ) - .expect("custom catalog parses"); - - let merged = base.merge_custom_catalog(custom); - - assert_eq!( - merged - .get("model-a") - .expect("custom model") - .available_providers, - vec!["custom-provider".to_string()] - ); - assert_eq!( - merged.equivalent_model_names("model-a"), - vec!["model-a".to_string()] - ); - assert_eq!( - merged.equivalent_model_names("model-b"), - vec!["model-b".to_string()] - ); - assert_eq!( - merged.models_for_format(ProviderFormat::Anthropic), - Some(&["model-b".to_string()][..]) - ); - assert_eq!( - merged.models_for_format(ProviderFormat::ChatCompletions), - Some(&["model-a".to_string()][..]) - ); - } } diff --git a/crates/braintrust-llm-router/src/catalog/resolver.rs b/crates/braintrust-llm-router/src/catalog/resolver.rs index 73a7c3fb..b1bcef72 100644 --- a/crates/braintrust-llm-router/src/catalog/resolver.rs +++ b/crates/braintrust-llm-router/src/catalog/resolver.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::sync::Arc; -use crate::catalog::{ModelCatalog, ModelSpec}; +use crate::catalog::{CatalogResolver, ModelCatalog, ModelSpec, OverlayModelCatalog}; use crate::error::{Error, Result}; use lingua::ProviderFormat; @@ -9,14 +9,25 @@ pub type ResolvedModel = (Arc, ProviderFormat, Vec); #[derive(Debug, Clone)] pub struct ModelResolver { - catalog: Arc, + catalog: CatalogResolver, aliases: HashMap, } impl ModelResolver { pub fn new(catalog: Arc) -> Self { Self { - catalog, + catalog: catalog.into(), + aliases: HashMap::new(), + } + } + + /// Resolve models against custom entries first, then the shared base catalog. + /// + /// The base catalog remains the public `catalog()` view so existing callers + /// do not observe per-request custom models as global catalog entries. + pub fn with_overlay(base: Arc, custom: ModelCatalog) -> Self { + Self { + catalog: CatalogResolver::Overlay(Box::new(OverlayModelCatalog { base, custom })), aliases: HashMap::new(), } } @@ -27,7 +38,7 @@ impl ModelResolver { } pub fn catalog(&self) -> Arc { - Arc::clone(&self.catalog) + self.catalog.base_catalog() } pub fn resolve(&self, model: &str) -> Result { @@ -171,7 +182,7 @@ mod tests { } #[test] - fn resolve_merged_catalog_custom_model_before_base_catalog() { + fn resolve_overlay_custom_model_before_base_catalog() { let mut base = ModelCatalog::empty(); base.insert( "model".into(), @@ -186,8 +197,7 @@ mod tests { vec!["custom-provider".into()], ), ); - let merged = base.merge_custom_catalog(custom); - let resolver = ModelResolver::new(Arc::new(merged)); + let resolver = ModelResolver::with_overlay(Arc::new(base), custom); let (spec, format, aliases) = resolver.resolve("model").expect("resolves"); assert_eq!(spec.model, "custom-model"); @@ -196,7 +206,7 @@ mod tests { } #[test] - fn resolve_merged_catalog_keeps_base_catalog_models() { + fn resolve_overlay_falls_back_to_base_catalog() { let mut base = ModelCatalog::empty(); base.insert( "base-model".into(), @@ -207,8 +217,7 @@ mod tests { "custom-model".into(), spec("custom-model", ProviderFormat::ChatCompletions), ); - let merged = base.merge_custom_catalog(custom); - let resolver = ModelResolver::new(Arc::new(merged)); + let resolver = ModelResolver::with_overlay(Arc::new(base), custom); let (spec, format, aliases) = resolver.resolve("base-model").expect("resolves"); assert_eq!(spec.model, "base-model"); diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 708fcb3a..773bf4c9 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -950,6 +950,7 @@ struct ProviderEntry { pub struct RouterBuilder { catalog: Option>, + custom_catalog: Option, provider_entries: Vec, retry_policy: RetryPolicy, } @@ -964,6 +965,7 @@ impl RouterBuilder { pub fn new() -> Self { Self { catalog: None, + custom_catalog: None, provider_entries: Vec::new(), retry_policy: RetryPolicy::default(), } @@ -972,11 +974,23 @@ impl RouterBuilder { pub fn load_models(mut self, path: impl AsRef) -> Result { let catalog = load_catalog_from_disk(path)?; self.catalog = Some(catalog); + self.custom_catalog = None; Ok(self) } pub fn with_catalog(mut self, catalog: Arc) -> Self { self.catalog = Some(catalog); + self.custom_catalog = None; + self + } + + /// Configure the router with custom models overlaid on a shared base catalog. + /// + /// Custom entries shadow base entries for resolution, while `Router::catalog()` + /// continues to expose the base catalog for compatibility. + pub fn with_overlay_catalog(mut self, base: Arc, custom: ModelCatalog) -> Self { + self.catalog = Some(base); + self.custom_catalog = Some(custom); self } @@ -1025,7 +1039,10 @@ impl RouterBuilder { let catalog = self .catalog .ok_or_else(|| Error::InvalidRequest("model catalog not configured".into()))?; - let resolver = ModelResolver::new(Arc::clone(&catalog)); + let resolver = match self.custom_catalog { + Some(custom) => ModelResolver::with_overlay(Arc::clone(&catalog), custom), + None => ModelResolver::new(Arc::clone(&catalog)), + }; let mut providers = HashMap::new(); let mut auth_configs = HashMap::new(); @@ -2545,7 +2562,7 @@ mod tests { } #[test] - fn merged_catalog_resolves_custom_and_base_models() { + fn overlay_catalog_resolves_custom_and_base_models() { let mut base = ModelCatalog::empty(); base.insert( "base-model".into(), @@ -2556,10 +2573,9 @@ mod tests { "custom-model".into(), openai_spec_with_available_providers("custom-model", ModelFlavor::Chat), ); - let catalog = base.merge_custom_catalog(custom); let router = Router::builder() - .with_catalog(Arc::new(catalog)) + .with_overlay_catalog(Arc::new(base), custom) .add_provider( "openai", FakeProvider { @@ -2573,7 +2589,7 @@ mod tests { .expect("router builds"); assert!(router.catalog().get("base-model").is_some()); - assert!(router.catalog().get("custom-model").is_some()); + assert!(router.catalog().get("custom-model").is_none()); assert!(router .resolve_provider_routes("base-model", ProviderFormat::ChatCompletions, &[]) .is_ok()); @@ -2583,7 +2599,7 @@ mod tests { } #[test] - fn merged_catalog_failover_routes_use_equivalent_custom_model() { + fn overlay_catalog_failover_routes_use_equivalent_custom_model() { let base = ModelCatalog::empty(); let mut custom = ModelCatalog::empty(); let mut primary = openai_spec("custom-primary", ModelFlavor::Chat); @@ -2598,10 +2614,9 @@ mod tests { vec!["custom-fallback".to_string()], ) .expect("equivalence is valid"); - let catalog = base.merge_custom_catalog(custom); let router = Router::builder() - .with_catalog(Arc::new(catalog)) + .with_overlay_catalog(Arc::new(base), custom) .add_provider( "provider-a", FakeProvider { From f3e2c0ed0c512a74e9711df5530c32e1c07f854c Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Mon, 15 Jun 2026 18:04:51 -0400 Subject: [PATCH 08/28] make shared equivalence index --- .../braintrust-llm-router/src/catalog/mod.rs | 273 +++++++++++++----- .../src/catalog/resolver.rs | 2 +- crates/braintrust-llm-router/src/router.rs | 64 ++++ 3 files changed, 268 insertions(+), 71 deletions(-) diff --git a/crates/braintrust-llm-router/src/catalog/mod.rs b/crates/braintrust-llm-router/src/catalog/mod.rs index e20e0986..f9c170d5 100644 --- a/crates/braintrust-llm-router/src/catalog/mod.rs +++ b/crates/braintrust-llm-router/src/catalog/mod.rs @@ -7,7 +7,7 @@ pub use spec::{ModelFlavor, ModelSpec}; use lingua::ProviderFormat; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fs::File; use std::io::Read; use std::path::Path; @@ -33,6 +33,62 @@ pub struct ModelCatalog { pub struct OverlayModelCatalog { pub base: Arc, pub custom: ModelCatalog, + equivalence_index: HashMap>, +} + +impl OverlayModelCatalog { + pub fn new(base: Arc, custom: ModelCatalog) -> Self { + let custom_model_names: HashSet = custom.models.keys().cloned().collect(); + let visible_models = base + .models + .keys() + .filter(|name| !custom_model_names.contains(*name)) + .chain(custom.models.keys()) + .cloned() + .collect(); + let mut equivalence_edges: HashMap> = base + .equivalent_models + .iter() + .filter(|(name, _)| !custom_model_names.contains(*name)) + .map(|(name, equivalents)| { + ( + name.clone(), + equivalents + .iter() + .filter(|equivalent| !custom_model_names.contains(*equivalent)) + .cloned() + .collect(), + ) + }) + .collect(); + equivalence_edges.extend(custom.equivalent_models.clone()); + let equivalence_index = build_equivalence_index(visible_models, &equivalence_edges); + Self { + base, + custom, + equivalence_index, + } + } + + pub fn base_catalog(&self) -> Arc { + Arc::clone(&self.base) + } + + pub fn get(&self, name: &str) -> Option> { + self.custom.get(name).or_else(|| self.base.get(name)) + } + + pub fn equivalent_model_names(&self, name: &str) -> Vec { + let Some(_) = self.get(name) else { + return Vec::new(); + }; + + let mut names = vec![name.to_string()]; + if let Some(equivalent_names) = self.equivalence_index.get(name) { + names.extend(equivalent_names.iter().cloned()); + } + names + } } /// Catalog view used by the router resolver. @@ -49,27 +105,21 @@ impl CatalogResolver { pub fn base_catalog(&self) -> Arc { match self { Self::Base(catalog) => Arc::clone(catalog), - Self::Overlay(overlay) => Arc::clone(&overlay.base), + Self::Overlay(overlay) => overlay.base_catalog(), } } pub fn get(&self, name: &str) -> Option> { match self { Self::Base(catalog) => catalog.get(name), - Self::Overlay(overlay) => overlay.custom.get(name).or_else(|| overlay.base.get(name)), + Self::Overlay(overlay) => overlay.get(name), } } pub fn equivalent_model_names(&self, name: &str) -> Vec { match self { Self::Base(catalog) => catalog.equivalent_model_names(name), - Self::Overlay(overlay) => { - if overlay.custom.get(name).is_some() { - overlay.custom.equivalent_model_names(name) - } else { - overlay.base.equivalent_model_names(name) - } - } + Self::Overlay(overlay) => overlay.equivalent_model_names(name), } } } @@ -108,6 +158,72 @@ fn parse_equivalent_models(content: &str) -> Result> Ok(equivalent_models) } +fn build_equivalence_index( + model_names: HashSet, + equivalent_models: &HashMap>, +) -> HashMap> { + let mut adjacency: HashMap> = HashMap::new(); + for name in &model_names { + adjacency.entry(name.clone()).or_default(); + } + + for (name, equivalents) in equivalent_models { + if !model_names.contains(name) { + continue; + } + for equivalent_model in equivalents { + if !model_names.contains(equivalent_model) { + continue; + } + adjacency + .entry(name.clone()) + .or_default() + .push(equivalent_model.clone()); + adjacency + .entry(equivalent_model.clone()) + .or_default() + .push(name.clone()); + } + } + + let mut visited = HashSet::new(); + let mut index = HashMap::new(); + for name in model_names { + if visited.contains(&name) { + continue; + } + + let mut stack = vec![name.clone()]; + let mut component = Vec::new(); + while let Some(current) = stack.pop() { + if !visited.insert(current.clone()) { + continue; + } + component.push(current.clone()); + if let Some(neighbors) = adjacency.get(¤t) { + stack.extend(neighbors.iter().cloned()); + } + } + + if component.len() <= 1 { + continue; + } + component.sort(); + for member in &component { + index.insert( + member.clone(), + component + .iter() + .filter(|other| *other != member) + .cloned() + .collect(), + ); + } + } + + index +} + impl ModelCatalog { pub fn empty() -> Self { Self::default() @@ -271,6 +387,27 @@ impl ModelCatalog { Ok(()) } + pub fn add_external_equivalent_models(&mut self, name: String, equivalents: I) -> Result<()> + where + I: IntoIterator, + { + if !self.models.contains_key(&name) { + return Err(Error::InvalidRequest(format!( + "model '{name}' references equivalent_models but is missing from catalog" + ))); + } + + let entry = self.equivalent_models.entry(name).or_default(); + for equivalent_model in equivalents { + if equivalent_model.is_empty() || entry.contains(&equivalent_model) { + continue; + } + entry.push(equivalent_model); + } + self.rebuild_equivalence_index(); + Ok(()) + } + fn validate_equivalent_models(&self) -> Result<()> { for (name, equivalents) in &self.equivalent_models { for equivalent_model in equivalents { @@ -285,66 +422,10 @@ impl ModelCatalog { } fn rebuild_equivalence_index(&mut self) { - let mut adjacency: HashMap> = HashMap::new(); - for name in self.models.keys() { - adjacency.entry(name.clone()).or_default(); - } - - for (name, equivalents) in &self.equivalent_models { - if !self.models.contains_key(name) { - continue; - } - for equivalent_model in equivalents { - if !self.models.contains_key(equivalent_model) { - continue; - } - adjacency - .entry(name.clone()) - .or_default() - .push(equivalent_model.clone()); - adjacency - .entry(equivalent_model.clone()) - .or_default() - .push(name.clone()); - } - } - - let mut visited = std::collections::HashSet::new(); - let mut index = HashMap::new(); - for name in self.models.keys() { - if visited.contains(name) { - continue; - } - - let mut stack = vec![name.clone()]; - let mut component = Vec::new(); - while let Some(current) = stack.pop() { - if !visited.insert(current.clone()) { - continue; - } - component.push(current.clone()); - if let Some(neighbors) = adjacency.get(¤t) { - stack.extend(neighbors.iter().cloned()); - } - } - - if component.len() <= 1 { - continue; - } - component.sort(); - for member in &component { - index.insert( - member.clone(), - component - .iter() - .filter(|other| *other != member) - .cloned() - .collect(), - ); - } - } - - self.equivalence_index = index; + self.equivalence_index = build_equivalence_index( + self.models.keys().cloned().collect(), + &self.equivalent_models, + ); } } @@ -533,4 +614,56 @@ mod tests { vec!["model-a".to_string(), "model-b".to_string()] ); } + + #[test] + fn overlay_equivalence_index_does_not_inherit_shadowed_base_edges() { + let base = Arc::new( + ModelCatalog::from_json_str( + r#"{ + "model-a": { + "format": "openai", + "flavor": "chat", + "equivalent_models": ["model-b"] + }, + "model-b": { + "format": "openai", + "flavor": "chat" + } +}"#, + ) + .expect("base catalog parses"), + ); + let mut custom = ModelCatalog::empty(); + custom.insert( + "model-b".to_string(), + ModelSpec { + model: "custom-model-b".to_string(), + format: ProviderFormat::Anthropic, + flavor: ModelFlavor::Chat, + display_name: None, + parent: None, + input_cost_per_mil_tokens: None, + output_cost_per_mil_tokens: None, + input_cache_read_cost_per_mil_tokens: None, + multimodal: None, + reasoning: None, + max_input_tokens: None, + max_output_tokens: None, + supports_streaming: true, + extra: Default::default(), + available_providers: vec!["custom-provider".to_string()], + }, + ); + + let overlay = OverlayModelCatalog::new(base, custom); + + assert_eq!( + overlay.equivalent_model_names("model-a"), + vec!["model-a".to_string()] + ); + assert_eq!( + overlay.equivalent_model_names("model-b"), + vec!["model-b".to_string()] + ); + } } diff --git a/crates/braintrust-llm-router/src/catalog/resolver.rs b/crates/braintrust-llm-router/src/catalog/resolver.rs index b1bcef72..2522e1f5 100644 --- a/crates/braintrust-llm-router/src/catalog/resolver.rs +++ b/crates/braintrust-llm-router/src/catalog/resolver.rs @@ -27,7 +27,7 @@ impl ModelResolver { /// do not observe per-request custom models as global catalog entries. pub fn with_overlay(base: Arc, custom: ModelCatalog) -> Self { Self { - catalog: CatalogResolver::Overlay(Box::new(OverlayModelCatalog { base, custom })), + catalog: CatalogResolver::Overlay(Box::new(OverlayModelCatalog::new(base, custom))), aliases: HashMap::new(), } } diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 773bf4c9..0e278d18 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -2659,6 +2659,70 @@ mod tests { ); } + #[test] + fn overlay_catalog_failover_routes_use_equivalent_base_model() { + let mut base = ModelCatalog::empty(); + base.insert( + "base-fallback".into(), + openai_spec_with_available_providers("base-fallback", ModelFlavor::Chat), + ); + let mut custom = ModelCatalog::empty(); + let mut primary = openai_spec("custom-primary", ModelFlavor::Chat); + primary.available_providers = vec!["provider-a".to_string()]; + custom.insert("custom-primary".into(), primary); + custom + .add_external_equivalent_models( + "custom-primary".to_string(), + vec!["base-fallback".to_string()], + ) + .expect("equivalence is valid"); + + let router = Router::builder() + .with_overlay_catalog(Arc::new(base), custom) + .add_provider( + "provider-a", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .add_provider( + "openai", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_provider_routes( + "custom-primary", + ProviderFormat::ChatCompletions, + &["provider-a".to_string(), "openai".to_string()], + ) + .expect("failover routes resolve"); + let route_info: Vec<(&str, &str)> = routes + .iter() + .map(|route| (route.provider_alias(), route.model())) + .collect(); + + assert_eq!( + route_info, + vec![ + ("provider-a", "custom-primary"), + ("openai", "base-fallback"), + ] + ); + assert!(router.catalog().get("base-fallback").is_some()); + assert!(router.catalog().get("custom-primary").is_none()); + } + #[test] fn resolved_aliases_returns_only_registered_available_providers() { let model = "gpt-4o"; From 9817f7e122c65ffea6c415dcfd65399734a82a87 Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 10:00:19 -0400 Subject: [PATCH 09/28] address comments --- .../src/catalog/resolver.rs | 2 +- crates/braintrust-llm-router/src/router.rs | 157 +++++++++++++++++- 2 files changed, 152 insertions(+), 7 deletions(-) diff --git a/crates/braintrust-llm-router/src/catalog/resolver.rs b/crates/braintrust-llm-router/src/catalog/resolver.rs index 2522e1f5..d274b5a5 100644 --- a/crates/braintrust-llm-router/src/catalog/resolver.rs +++ b/crates/braintrust-llm-router/src/catalog/resolver.rs @@ -45,7 +45,7 @@ impl ModelResolver { self.resolve_one(model) } - pub fn resolve_for_failover(&self, model: &str) -> Result> { + pub fn resolve_all_equivalent_model_routes(&self, model: &str) -> Result> { let mut resolved = Vec::new(); for model_name in self.catalog.equivalent_model_names(model) { resolved.push(self.resolve_one(&model_name)?); diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 0e278d18..3fb28405 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -208,6 +208,23 @@ struct PreparedRequestInner { strategy: RetryStrategy, } +fn override_payload_model(payload: Bytes, model: &str) -> Bytes { + let Ok(mut value) = serde_json::from_slice::(&payload) else { + return payload; + }; + let Some(object) = value.as_object_mut() else { + return payload; + }; + if object.get("model").and_then(Value::as_str) == Some(model) { + return payload; + } + object.insert("model".to_string(), Value::String(model.to_string())); + match serde_json::to_vec(&value) { + Ok(serialized) => Bytes::from(serialized), + Err(_) => payload, + } +} + async fn prepare_provider_request( body: Bytes, spec: &ModelSpec, @@ -231,6 +248,8 @@ async fn prepare_provider_request( Err(err) => return Err(err.into()), }; + let transformed = override_payload_model(transformed, &spec.model); + if stream { // TODO: Fold streaming intent into `lingua::transform_request` once we // are ready to update its Rust/WASM/Python/TS call sites together. @@ -635,7 +654,7 @@ impl Router { output_format: ProviderFormat, fallback_aliases: &[String], ) -> Result> { - let resolved_models = self.resolver.resolve_for_failover(model)?; + let resolved_models = self.resolver.resolve_all_equivalent_model_routes(model)?; let (_, first_catalog_format, _) = resolved_models .first() .ok_or_else(|| Error::UnknownModel(model.to_string()))?; @@ -936,7 +955,14 @@ fn default_alias_provider_id(alias: &str) -> Option<&'static str> { fn concrete_provider_id_alias(alias: &str) -> bool { matches!( alias, - "anthropic" | "google" | "mistral" | "bedrock" | "vertex" | "azure" | "databricks" + "openai" + | "anthropic" + | "google" + | "mistral" + | "bedrock" + | "vertex" + | "azure" + | "databricks" ) } @@ -2856,10 +2882,7 @@ mod tests { ) .add_provider( "cerebras", - FakeProvider { - name: "openai", - formats: vec![ProviderFormat::ChatCompletions], - }, + FakeOpenAICompatibleProvider { alias: "cerebras" }, dummy_auth(), vec![], ) @@ -2911,6 +2934,128 @@ mod tests { ); } + #[test] + fn failover_routes_match_named_openai_when_equivalents_omit_available_providers() { + let model = "custom-primary"; + let fallback_model = "gpt-4o"; + let mut base = ModelCatalog::empty(); + base.insert( + fallback_model.into(), + openai_spec(fallback_model, ModelFlavor::Chat), + ); + let mut custom = ModelCatalog::empty(); + let mut primary = openai_spec(model, ModelFlavor::Chat); + primary.available_providers = vec!["provider-a".to_string()]; + custom.insert(model.into(), primary); + custom + .add_external_equivalent_models(model.to_string(), vec![fallback_model.to_string()]) + .expect("equivalence is valid"); + + let router = Router::builder() + .with_overlay_catalog(Arc::new(base), custom) + .add_provider( + "provider-a", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .add_provider( + "my-openai", + FakeOpenAICompatibleProvider { alias: "openai" }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_provider_routes( + model, + ProviderFormat::ChatCompletions, + &["provider-a".to_string(), "my-openai".to_string()], + ) + .expect("routes resolve"); + let route_info: Vec<(&str, &str)> = routes + .iter() + .map(|route| (route.provider_alias(), route.model())) + .collect(); + + assert_eq!( + route_info, + vec![("provider-a", model), ("my-openai", fallback_model)] + ); + } + + #[tokio::test] + async fn failover_request_payload_uses_equivalent_route_model_for_same_format() { + let model = "gpt-4o"; + let fallback_model = "other-provider/gpt-4o"; + let catalog = ModelCatalog::from_json_str( + r#"{ + "gpt-4o": { + "format": "openai", + "flavor": "chat", + "available_providers": ["provider-a"], + "equivalent_models": ["other-provider/gpt-4o"] + }, + "other-provider/gpt-4o": { + "format": "openai", + "flavor": "chat" + } +}"#, + ) + .expect("catalog parses"); + let router = Router::builder() + .with_catalog(Arc::new(catalog)) + .add_provider( + "provider-a", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .add_provider( + "my-openai", + FakeOpenAICompatibleProvider { alias: "openai" }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_provider_routes( + model, + ProviderFormat::ChatCompletions, + &["provider-a".to_string(), "my-openai".to_string()], + ) + .expect("routes resolve"); + let fallback_route = routes + .iter() + .find(|route| route.provider_alias() == "my-openai") + .expect("fallback route exists"); + assert_eq!(fallback_route.model(), fallback_model); + + let body = Bytes::from_static( + br#"{"model":"gpt-4o","messages":[{"role":"user","content":"Ping"}]}"#, + ); + let (request, _) = router + .create_request(body, ProviderFormat::ChatCompletions, fallback_route) + .await + .expect("request prepares"); + let payload: Value = serde_json::from_slice(&request.inner.payload).expect("json"); + + assert_eq!( + payload.get("model").and_then(Value::as_str), + Some(fallback_model) + ); + } + #[test] fn fallback_provider_routes_do_not_match_openai_default_alias_to_compatible_provider() { let model = "gpt-4o"; From 0248a518f1be3114e52c1267f181092c8994cdea Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 10:22:46 -0400 Subject: [PATCH 10/28] comment --- .../braintrust-llm-router/src/catalog/mod.rs | 4 +- .../src/catalog/resolver.rs | 4 -- .../src/providers/mod.rs | 10 ++++ crates/braintrust-llm-router/src/router.rs | 46 +++++++++++++++++-- 4 files changed, 54 insertions(+), 10 deletions(-) diff --git a/crates/braintrust-llm-router/src/catalog/mod.rs b/crates/braintrust-llm-router/src/catalog/mod.rs index f9c170d5..a489a00d 100644 --- a/crates/braintrust-llm-router/src/catalog/mod.rs +++ b/crates/braintrust-llm-router/src/catalog/mod.rs @@ -31,8 +31,8 @@ pub struct ModelCatalog { /// per-request model definitions. #[derive(Debug, Clone)] pub struct OverlayModelCatalog { - pub base: Arc, - pub custom: ModelCatalog, + base: Arc, + custom: ModelCatalog, equivalence_index: HashMap>, } diff --git a/crates/braintrust-llm-router/src/catalog/resolver.rs b/crates/braintrust-llm-router/src/catalog/resolver.rs index d274b5a5..53d8edef 100644 --- a/crates/braintrust-llm-router/src/catalog/resolver.rs +++ b/crates/braintrust-llm-router/src/catalog/resolver.rs @@ -21,10 +21,6 @@ impl ModelResolver { } } - /// Resolve models against custom entries first, then the shared base catalog. - /// - /// The base catalog remains the public `catalog()` view so existing callers - /// do not observe per-request custom models as global catalog entries. pub fn with_overlay(base: Arc, custom: ModelCatalog) -> Self { Self { catalog: CatalogResolver::Overlay(Box::new(OverlayModelCatalog::new(base, custom))), diff --git a/crates/braintrust-llm-router/src/providers/mod.rs b/crates/braintrust-llm-router/src/providers/mod.rs index 051e4ea9..539ce873 100644 --- a/crates/braintrust-llm-router/src/providers/mod.rs +++ b/crates/braintrust-llm-router/src/providers/mod.rs @@ -176,6 +176,16 @@ pub(crate) fn disable_streaming_payload(payload: Bytes) -> Bytes { } } +pub(crate) fn format_carries_model_in_body(format: ProviderFormat) -> bool { + matches!( + format, + ProviderFormat::ChatCompletions + | ProviderFormat::Responses + | ProviderFormat::Anthropic + | ProviderFormat::Mistral + ) +} + pub(crate) fn enable_streaming_payload(payload: Bytes, format: ProviderFormat) -> Bytes { let Ok(mut value) = serde_json::from_slice::(&payload) else { return payload; diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 3fb28405..2d2ada91 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -15,8 +15,8 @@ use crate::catalog::{ use crate::client::ClientSettings; use crate::error::{Error, Result}; use crate::providers::{ - enable_streaming_payload, prepare_bedrock_request, requires_bedrock_request_preparation, - ClientHeaders, Provider, + enable_streaming_payload, format_carries_model_in_body, prepare_bedrock_request, + requires_bedrock_request_preparation, ClientHeaders, Provider, }; use crate::retry::{RetryPolicy, RetryStrategy}; use crate::streaming::{ @@ -208,7 +208,10 @@ struct PreparedRequestInner { strategy: RetryStrategy, } -fn override_payload_model(payload: Bytes, model: &str) -> Bytes { +fn override_payload_model(payload: Bytes, format: ProviderFormat, model: &str) -> Bytes { + if !format_carries_model_in_body(format) { + return payload; + } let Ok(mut value) = serde_json::from_slice::(&payload) else { return payload; }; @@ -248,7 +251,7 @@ async fn prepare_provider_request( Err(err) => return Err(err.into()), }; - let transformed = override_payload_model(transformed, &spec.model); + let transformed = override_payload_model(transformed, actual_format, &spec.model); if stream { // TODO: Fold streaming intent into `lingua::transform_request` once we @@ -1451,6 +1454,41 @@ mod tests { assert_eq!(parsed.get("stream_options"), None); } + #[tokio::test] + async fn prepare_provider_request_does_not_readd_model_for_vertex_anthropic() { + let body = Bytes::from_static( + br#"{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"Ping"}]}"#, + ); + let spec = ModelSpec { + model: "publishers/anthropic/models/claude-sonnet-4-6".to_string(), + format: ProviderFormat::Anthropic, + flavor: ModelFlavor::Chat, + display_name: None, + parent: None, + input_cost_per_mil_tokens: None, + output_cost_per_mil_tokens: None, + input_cache_read_cost_per_mil_tokens: None, + multimodal: None, + reasoning: None, + max_input_tokens: None, + max_output_tokens: None, + supports_streaming: true, + extra: Default::default(), + available_providers: vec!["vertex".to_string()], + }; + + let (payload, _, actual_format) = + prepare_provider_request(body, &spec, ProviderFormat::VertexAnthropic, false) + .await + .expect("request prepares"); + let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); + + assert_eq!(actual_format, ProviderFormat::VertexAnthropic); + assert_eq!(parsed.get("model"), None); + assert!(parsed.get("anthropic_version").is_some()); + assert!(parsed.get("messages").is_some()); + } + #[tokio::test] async fn prepare_provider_request_upgrades_actual_format_to_responses_for_reasoning_plus_tools() { From 73fc1c099321a591ecc7174231420e8f944489f1 Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 10:32:18 -0400 Subject: [PATCH 11/28] address comment --- crates/braintrust-llm-router/src/router.rs | 79 ++++++++++++++++------ 1 file changed, 60 insertions(+), 19 deletions(-) diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 2d2ada91..0d45f46c 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -739,11 +739,12 @@ impl Router { .is_some_and(|provider| provider.matches_provider_alias(provider_id)); } - if concrete_provider_id_alias(resolver_alias) { - return self - .providers - .get(provider_alias) - .is_some_and(|provider| provider.matches_provider_alias(resolver_alias)); + if self + .providers + .get(provider_alias) + .is_some_and(|provider| provider.matches_provider_alias(resolver_alias)) + { + return true; } default_alias_provider_id(provider_alias) @@ -955,20 +956,6 @@ fn default_alias_provider_id(alias: &str) -> Option<&'static str> { } } -fn concrete_provider_id_alias(alias: &str) -> bool { - matches!( - alias, - "openai" - | "anthropic" - | "google" - | "mistral" - | "bedrock" - | "vertex" - | "azure" - | "databricks" - ) -} - /// One provider registration: alias, provider, auth, and default formats. struct ProviderEntry { alias: String, @@ -3027,6 +3014,60 @@ mod tests { ); } + #[test] + fn failover_routes_match_named_openai_compatible_provider_id() { + let model = "custom-primary"; + let fallback_model = "llama-4-scout"; + let mut base = ModelCatalog::empty(); + let mut fallback = openai_spec(fallback_model, ModelFlavor::Chat); + fallback.available_providers = vec!["cerebras".to_string()]; + base.insert(fallback_model.into(), fallback); + let mut custom = ModelCatalog::empty(); + let mut primary = openai_spec(model, ModelFlavor::Chat); + primary.available_providers = vec!["provider-a".to_string()]; + custom.insert(model.into(), primary); + custom + .add_external_equivalent_models(model.to_string(), vec![fallback_model.to_string()]) + .expect("equivalence is valid"); + + let router = Router::builder() + .with_overlay_catalog(Arc::new(base), custom) + .add_provider( + "provider-a", + FakeProvider { + name: "openai", + formats: vec![ProviderFormat::ChatCompletions], + }, + dummy_auth(), + vec![], + ) + .add_provider( + "my-cerebras", + FakeOpenAICompatibleProvider { alias: "cerebras" }, + dummy_auth(), + vec![], + ) + .build() + .expect("router builds"); + + let routes = router + .resolve_provider_routes( + model, + ProviderFormat::ChatCompletions, + &["provider-a".to_string(), "my-cerebras".to_string()], + ) + .expect("routes resolve"); + let route_info: Vec<(&str, &str)> = routes + .iter() + .map(|route| (route.provider_alias(), route.model())) + .collect(); + + assert_eq!( + route_info, + vec![("provider-a", model), ("my-cerebras", fallback_model)] + ); + } + #[tokio::test] async fn failover_request_payload_uses_equivalent_route_model_for_same_format() { let model = "gpt-4o"; From 7cb2937f0e5fd9a4813eddd0f324eaa55e971bd2 Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 10:44:00 -0400 Subject: [PATCH 12/28] comment --- crates/braintrust-llm-router/src/router.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 0d45f46c..a28e5967 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -132,7 +132,7 @@ pub fn create_provider( metadata, client_settings, )? - .with_provider_alias(kind), + .with_provider_alias(kind.to_ascii_lowercase()), )), other => Err(Error::InvalidRequest(format!( "unsupported provider kind: {other}" From ba087683dcf776757896c74de3dd9b03ba641901 Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 10:48:56 -0400 Subject: [PATCH 13/28] equivalent => fallback --- .../braintrust-llm-router/src/catalog/mod.rs | 142 +++++++++--------- crates/braintrust-llm-router/src/router.rs | 16 +- 2 files changed, 80 insertions(+), 78 deletions(-) diff --git a/crates/braintrust-llm-router/src/catalog/mod.rs b/crates/braintrust-llm-router/src/catalog/mod.rs index a489a00d..b50fc84a 100644 --- a/crates/braintrust-llm-router/src/catalog/mod.rs +++ b/crates/braintrust-llm-router/src/catalog/mod.rs @@ -20,7 +20,7 @@ pub struct ModelCatalog { models: HashMap>, by_format: HashMap>, by_parent: HashMap>, - equivalent_models: HashMap>, + fallback_models: HashMap>, equivalence_index: HashMap>, } @@ -46,23 +46,23 @@ impl OverlayModelCatalog { .chain(custom.models.keys()) .cloned() .collect(); - let mut equivalence_edges: HashMap> = base - .equivalent_models + let mut fallback_edges: HashMap> = base + .fallback_models .iter() - .filter(|(name, _)| !custom_model_names.contains(*name)) - .map(|(name, equivalents)| { + .filter(|(name, _fallbacks)| !custom_model_names.contains(*name)) + .map(|(name, fallbacks)| { ( name.clone(), - equivalents + fallbacks .iter() - .filter(|equivalent| !custom_model_names.contains(*equivalent)) + .filter(|fallback_model| !custom_model_names.contains(*fallback_model)) .cloned() .collect(), ) }) .collect(); - equivalence_edges.extend(custom.equivalent_models.clone()); - let equivalence_index = build_equivalence_index(visible_models, &equivalence_edges); + fallback_edges.extend(custom.fallback_models.clone()); + let equivalence_index = build_equivalence_index(visible_models, &fallback_edges); Self { base, custom, @@ -130,57 +130,57 @@ impl From> for CatalogResolver { } } -fn parse_equivalent_models(content: &str) -> Result>> { +fn parse_fallback_models(content: &str) -> Result>> { let raw: HashMap = serde_json::from_str(content)?; - let mut equivalent_models = HashMap::new(); + let mut fallback_models = HashMap::new(); for (name, value) in raw { - let Some(equivalents) = value.get("equivalent_models") else { + let Some(fallbacks) = value.get("fallback_models") else { continue; }; - let Some(equivalents) = equivalents.as_array() else { + let Some(fallbacks) = fallbacks.as_array() else { return Err(Error::InvalidRequest(format!( - "model '{name}' has invalid equivalent_models" + "model '{name}' has invalid fallback_models" ))); }; - let mut parsed = Vec::with_capacity(equivalents.len()); - for equivalent in equivalents { - let Some(equivalent) = equivalent.as_str() else { + let mut parsed = Vec::with_capacity(fallbacks.len()); + for fallback_model in fallbacks { + let Some(fallback_model) = fallback_model.as_str() else { return Err(Error::InvalidRequest(format!( - "model '{name}' has invalid equivalent_models" + "model '{name}' has invalid fallback_models" ))); }; - parsed.push(equivalent.to_string()); + parsed.push(fallback_model.to_string()); } if !parsed.is_empty() { - equivalent_models.insert(name, parsed); + fallback_models.insert(name, parsed); } } - Ok(equivalent_models) + Ok(fallback_models) } fn build_equivalence_index( model_names: HashSet, - equivalent_models: &HashMap>, + fallback_models: &HashMap>, ) -> HashMap> { let mut adjacency: HashMap> = HashMap::new(); for name in &model_names { adjacency.entry(name.clone()).or_default(); } - for (name, equivalents) in equivalent_models { + for (name, fallbacks) in fallback_models { if !model_names.contains(name) { continue; } - for equivalent_model in equivalents { - if !model_names.contains(equivalent_model) { + for fallback_model in fallbacks { + if !model_names.contains(fallback_model) { continue; } adjacency .entry(name.clone()) .or_default() - .push(equivalent_model.clone()); + .push(fallback_model.clone()); adjacency - .entry(equivalent_model.clone()) + .entry(fallback_model.clone()) .or_default() .push(name.clone()); } @@ -231,13 +231,13 @@ impl ModelCatalog { pub fn from_json_str(content: &str) -> Result { let raw: HashMap = serde_json::from_str(content)?; - let equivalent_models = parse_equivalent_models(content)?; + let fallback_models = parse_fallback_models(content)?; let mut catalog = Self::empty(); for (name, spec) in raw { catalog.insert(name, spec); } - catalog.equivalent_models = equivalent_models; - catalog.validate_equivalent_models()?; + catalog.fallback_models = fallback_models; + catalog.validate_fallback_models()?; catalog.rebuild_equivalence_index(); Ok(catalog) } @@ -322,7 +322,7 @@ impl ModelCatalog { F: FnMut(&str, &ModelSpec) -> ModelSpec, { let mut out = Self { - equivalent_models: self.equivalent_models.clone(), + fallback_models: self.fallback_models.clone(), ..Self::empty() }; for (name, spec) in &self.models { @@ -354,66 +354,70 @@ impl ModelCatalog { } } - pub fn add_equivalent_models(&mut self, name: String, equivalents: I) -> Result<()> + pub fn add_fallback_models(&mut self, name: String, fallback_models: I) -> Result<()> where I: IntoIterator, { if !self.models.contains_key(&name) { return Err(Error::InvalidRequest(format!( - "model '{name}' references equivalent_models but is missing from catalog" + "model '{name}' references fallback_models but is missing from catalog" ))); } - let equivalents: Vec = equivalents + let fallback_models: Vec = fallback_models .into_iter() - .filter(|equivalent_model| !equivalent_model.is_empty()) + .filter(|fallback_model| !fallback_model.is_empty()) .collect(); - for equivalent_model in &equivalents { - if !self.models.contains_key(equivalent_model) { + for fallback_model in &fallback_models { + if !self.models.contains_key(fallback_model) { return Err(Error::InvalidRequest(format!( - "model '{name}' references missing equivalent model '{equivalent_model}'" + "model '{name}' references missing fallback model '{fallback_model}'" ))); } } - let entry = self.equivalent_models.entry(name).or_default(); - for equivalent_model in equivalents { - if entry.contains(&equivalent_model) { + let entry = self.fallback_models.entry(name).or_default(); + for fallback_model in fallback_models { + if entry.contains(&fallback_model) { continue; } - entry.push(equivalent_model); + entry.push(fallback_model); } self.rebuild_equivalence_index(); Ok(()) } - pub fn add_external_equivalent_models(&mut self, name: String, equivalents: I) -> Result<()> + pub fn add_external_fallback_models( + &mut self, + name: String, + fallback_models: I, + ) -> Result<()> where I: IntoIterator, { if !self.models.contains_key(&name) { return Err(Error::InvalidRequest(format!( - "model '{name}' references equivalent_models but is missing from catalog" + "model '{name}' references fallback_models but is missing from catalog" ))); } - let entry = self.equivalent_models.entry(name).or_default(); - for equivalent_model in equivalents { - if equivalent_model.is_empty() || entry.contains(&equivalent_model) { + let entry = self.fallback_models.entry(name).or_default(); + for fallback_model in fallback_models { + if fallback_model.is_empty() || entry.contains(&fallback_model) { continue; } - entry.push(equivalent_model); + entry.push(fallback_model); } self.rebuild_equivalence_index(); Ok(()) } - fn validate_equivalent_models(&self) -> Result<()> { - for (name, equivalents) in &self.equivalent_models { - for equivalent_model in equivalents { - if !self.models.contains_key(equivalent_model) { + fn validate_fallback_models(&self) -> Result<()> { + for (name, fallback_models) in &self.fallback_models { + for fallback_model in fallback_models { + if !self.models.contains_key(fallback_model) { return Err(Error::InvalidRequest(format!( - "model '{name}' references missing equivalent model '{equivalent_model}'" + "model '{name}' references missing fallback model '{fallback_model}'" ))); } } @@ -422,10 +426,8 @@ impl ModelCatalog { } fn rebuild_equivalence_index(&mut self) { - self.equivalence_index = build_equivalence_index( - self.models.keys().cloned().collect(), - &self.equivalent_models, - ); + self.equivalence_index = + build_equivalence_index(self.models.keys().cloned().collect(), &self.fallback_models); } } @@ -450,7 +452,7 @@ mod tests { "claude-sonnet-4-6": { "format": "anthropic", "flavor": "chat", - "equivalent_models": [ + "fallback_models": [ "publishers/anthropic/models/claude-sonnet-4-6", "anthropic.claude-sonnet-4-6" ] @@ -492,12 +494,12 @@ mod tests { "model-a": { "format": "openai", "flavor": "chat", - "equivalent_models": ["model-b"] + "fallback_models": ["model-b"] }, "model-b": { "format": "openai", "flavor": "chat", - "equivalent_models": ["model-c"] + "fallback_models": ["model-c"] }, "model-c": { "format": "openai", @@ -518,23 +520,23 @@ mod tests { } #[test] - fn missing_equivalent_model_reference_is_invalid() { + fn missing_fallback_model_reference_is_invalid() { let error = ModelCatalog::from_json_str( r#"{ "model-a": { "format": "openai", "flavor": "chat", - "equivalent_models": ["missing-model"] + "fallback_models": ["missing-model"] } }"#, ) - .expect_err("missing equivalent model should fail"); + .expect_err("missing fallback model should fail"); assert!(matches!(error, Error::InvalidRequest(_))); } #[test] - fn add_equivalent_models_rebuilds_index() { + fn add_fallback_models_rebuilds_index() { let mut catalog = ModelCatalog::from_json_str( r#"{ "model-a": { @@ -550,7 +552,7 @@ mod tests { .expect("catalog parses"); catalog - .add_equivalent_models("model-a".to_string(), vec!["model-b".to_string()]) + .add_fallback_models("model-a".to_string(), vec!["model-b".to_string()]) .expect("equivalence is valid"); assert_eq!( @@ -564,7 +566,7 @@ mod tests { } #[test] - fn add_equivalent_models_rejects_missing_reference() { + fn add_fallback_models_rejects_missing_reference() { let mut catalog = ModelCatalog::from_json_str( r#"{ "model-a": { @@ -576,8 +578,8 @@ mod tests { .expect("catalog parses"); let error = catalog - .add_equivalent_models("model-a".to_string(), vec!["missing".to_string()]) - .expect_err("missing equivalent model should fail"); + .add_fallback_models("model-a".to_string(), vec!["missing".to_string()]) + .expect_err("missing fallback model should fail"); assert!(matches!(error, Error::InvalidRequest(_))); assert_eq!( @@ -593,7 +595,7 @@ mod tests { "model-a": { "format": "openai", "flavor": "chat", - "equivalent_models": ["model-b"] + "fallback_models": ["model-b"] }, "model-b": { "format": "openai", @@ -623,7 +625,7 @@ mod tests { "model-a": { "format": "openai", "flavor": "chat", - "equivalent_models": ["model-b"] + "fallback_models": ["model-b"] }, "model-b": { "format": "openai", diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index a28e5967..367fcbe1 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -2660,7 +2660,7 @@ mod tests { custom.insert("custom-primary".into(), primary); custom.insert("custom-fallback".into(), fallback); custom - .add_equivalent_models( + .add_fallback_models( "custom-primary".to_string(), vec!["custom-fallback".to_string()], ) @@ -2722,7 +2722,7 @@ mod tests { primary.available_providers = vec!["provider-a".to_string()]; custom.insert("custom-primary".into(), primary); custom - .add_external_equivalent_models( + .add_external_fallback_models( "custom-primary".to_string(), vec!["base-fallback".to_string()], ) @@ -2973,7 +2973,7 @@ mod tests { primary.available_providers = vec!["provider-a".to_string()]; custom.insert(model.into(), primary); custom - .add_external_equivalent_models(model.to_string(), vec![fallback_model.to_string()]) + .add_external_fallback_models(model.to_string(), vec![fallback_model.to_string()]) .expect("equivalence is valid"); let router = Router::builder() @@ -3027,7 +3027,7 @@ mod tests { primary.available_providers = vec!["provider-a".to_string()]; custom.insert(model.into(), primary); custom - .add_external_equivalent_models(model.to_string(), vec![fallback_model.to_string()]) + .add_external_fallback_models(model.to_string(), vec![fallback_model.to_string()]) .expect("equivalence is valid"); let router = Router::builder() @@ -3078,7 +3078,7 @@ mod tests { "format": "openai", "flavor": "chat", "available_providers": ["provider-a"], - "equivalent_models": ["other-provider/gpt-4o"] + "fallback_models": ["other-provider/gpt-4o"] }, "other-provider/gpt-4o": { "format": "openai", @@ -3188,7 +3188,7 @@ mod tests { "format": "anthropic", "flavor": "chat", "available_providers": ["ANTHROPIC_API_KEY"], - "equivalent_models": ["publishers/anthropic/models/claude-sonnet-4-6"] + "fallback_models": ["publishers/anthropic/models/claude-sonnet-4-6"] }, "publishers/anthropic/models/claude-sonnet-4-6": { "format": "anthropic", @@ -3248,7 +3248,7 @@ mod tests { "claude-sonnet-4-6": { "format": "anthropic", "flavor": "chat", - "equivalent_models": ["publishers/anthropic/models/claude-sonnet-4-6"] + "fallback_models": ["publishers/anthropic/models/claude-sonnet-4-6"] }, "publishers/anthropic/models/claude-sonnet-4-6": { "format": "anthropic", @@ -3463,7 +3463,7 @@ mod tests { "format": "anthropic", "flavor": "chat", "available_providers": ["anthropic"], - "equivalent_models": ["publishers/anthropic/models/claude-sonnet-4-6"] + "fallback_models": ["publishers/anthropic/models/claude-sonnet-4-6"] }, "publishers/anthropic/models/claude-sonnet-4-6": { "format": "anthropic", From 8c5cd8a6220d718cceeb141f17012dd6caac1d88 Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 11:15:36 -0400 Subject: [PATCH 14/28] don't recompute full base equivalnce index each time and consolidate fallback logic --- .../src/catalog/fallback.rs | 67 ++++ .../braintrust-llm-router/src/catalog/mod.rs | 333 ++++++++++-------- 2 files changed, 258 insertions(+), 142 deletions(-) create mode 100644 crates/braintrust-llm-router/src/catalog/fallback.rs diff --git a/crates/braintrust-llm-router/src/catalog/fallback.rs b/crates/braintrust-llm-router/src/catalog/fallback.rs new file mode 100644 index 00000000..8a4f9f7c --- /dev/null +++ b/crates/braintrust-llm-router/src/catalog/fallback.rs @@ -0,0 +1,67 @@ +use std::collections::{HashMap, HashSet}; + +pub(super) fn build_equivalence_index( + model_names: HashSet, + fallback_models: &HashMap>, +) -> HashMap> { + let mut adjacency: HashMap> = HashMap::new(); + for name in &model_names { + adjacency.entry(name.clone()).or_default(); + } + + for (name, fallbacks) in fallback_models { + if !model_names.contains(name) { + continue; + } + for fallback_model in fallbacks { + if !model_names.contains(fallback_model) { + continue; + } + adjacency + .entry(name.clone()) + .or_default() + .push(fallback_model.clone()); + adjacency + .entry(fallback_model.clone()) + .or_default() + .push(name.clone()); + } + } + + let mut visited = HashSet::new(); + let mut index = HashMap::new(); + for name in model_names { + if visited.contains(&name) { + continue; + } + + let mut stack = vec![name.clone()]; + let mut component = Vec::new(); + while let Some(current) = stack.pop() { + if !visited.insert(current.clone()) { + continue; + } + component.push(current.clone()); + if let Some(neighbors) = adjacency.get(¤t) { + stack.extend(neighbors.iter().cloned()); + } + } + + if component.len() <= 1 { + continue; + } + component.sort(); + for member in &component { + index.insert( + member.clone(), + component + .iter() + .filter(|other| *other != member) + .cloned() + .collect(), + ); + } + } + + index +} diff --git a/crates/braintrust-llm-router/src/catalog/mod.rs b/crates/braintrust-llm-router/src/catalog/mod.rs index b50fc84a..8f6a5f36 100644 --- a/crates/braintrust-llm-router/src/catalog/mod.rs +++ b/crates/braintrust-llm-router/src/catalog/mod.rs @@ -1,6 +1,9 @@ +mod fallback; mod resolver; pub mod spec; +use fallback::build_equivalence_index; + pub(crate) use resolver::is_gemini_api_model; pub use resolver::ModelResolver; pub use spec::{ModelFlavor, ModelSpec}; @@ -33,40 +36,40 @@ pub struct ModelCatalog { pub struct OverlayModelCatalog { base: Arc, custom: ModelCatalog, - equivalence_index: HashMap>, + custom_model_names: HashSet, + overlay_edges: HashMap>, } impl OverlayModelCatalog { pub fn new(base: Arc, custom: ModelCatalog) -> Self { let custom_model_names: HashSet = custom.models.keys().cloned().collect(); - let visible_models = base - .models - .keys() - .filter(|name| !custom_model_names.contains(*name)) - .chain(custom.models.keys()) - .cloned() - .collect(); - let mut fallback_edges: HashMap> = base - .fallback_models - .iter() - .filter(|(name, _fallbacks)| !custom_model_names.contains(*name)) - .map(|(name, fallbacks)| { - ( - name.clone(), - fallbacks - .iter() - .filter(|fallback_model| !custom_model_names.contains(*fallback_model)) - .cloned() - .collect(), - ) - }) - .collect(); - fallback_edges.extend(custom.fallback_models.clone()); - let equivalence_index = build_equivalence_index(visible_models, &fallback_edges); + let mut overlay_edges: HashMap> = HashMap::new(); + for (name, fallbacks) in &custom.fallback_models { + if !custom.models.contains_key(name) { + continue; + } + for fallback_model in fallbacks { + let fallback_is_visible = custom.models.contains_key(fallback_model) + || (base.models.contains_key(fallback_model) + && !custom_model_names.contains(fallback_model)); + if !fallback_is_visible { + continue; + } + overlay_edges + .entry(name.clone()) + .or_default() + .push(fallback_model.clone()); + overlay_edges + .entry(fallback_model.clone()) + .or_default() + .push(name.clone()); + } + } Self { base, custom, - equivalence_index, + custom_model_names, + overlay_edges, } } @@ -83,10 +86,31 @@ impl OverlayModelCatalog { return Vec::new(); }; - let mut names = vec![name.to_string()]; - if let Some(equivalent_names) = self.equivalence_index.get(name) { - names.extend(equivalent_names.iter().cloned()); + let mut visited = HashSet::new(); + let mut stack = vec![name.to_string()]; + while let Some(current) = stack.pop() { + if !visited.insert(current.clone()) { + continue; + } + + if !self.custom_model_names.contains(¤t) { + stack.extend( + self.base + .equivalent_model_names(¤t) + .into_iter() + .filter(|model_name| !self.custom_model_names.contains(model_name)), + ); + } + if let Some(neighbors) = self.overlay_edges.get(¤t) { + stack.extend(neighbors.iter().cloned()); + } } + + let mut names = vec![name.to_string()]; + visited.remove(name); + let mut equivalent_names: Vec = visited.into_iter().collect(); + equivalent_names.sort(); + names.extend(equivalent_names); names } } @@ -130,98 +154,9 @@ impl From> for CatalogResolver { } } -fn parse_fallback_models(content: &str) -> Result>> { - let raw: HashMap = serde_json::from_str(content)?; - let mut fallback_models = HashMap::new(); - for (name, value) in raw { - let Some(fallbacks) = value.get("fallback_models") else { - continue; - }; - let Some(fallbacks) = fallbacks.as_array() else { - return Err(Error::InvalidRequest(format!( - "model '{name}' has invalid fallback_models" - ))); - }; - let mut parsed = Vec::with_capacity(fallbacks.len()); - for fallback_model in fallbacks { - let Some(fallback_model) = fallback_model.as_str() else { - return Err(Error::InvalidRequest(format!( - "model '{name}' has invalid fallback_models" - ))); - }; - parsed.push(fallback_model.to_string()); - } - if !parsed.is_empty() { - fallback_models.insert(name, parsed); - } - } - Ok(fallback_models) -} - -fn build_equivalence_index( - model_names: HashSet, - fallback_models: &HashMap>, -) -> HashMap> { - let mut adjacency: HashMap> = HashMap::new(); - for name in &model_names { - adjacency.entry(name.clone()).or_default(); - } - - for (name, fallbacks) in fallback_models { - if !model_names.contains(name) { - continue; - } - for fallback_model in fallbacks { - if !model_names.contains(fallback_model) { - continue; - } - adjacency - .entry(name.clone()) - .or_default() - .push(fallback_model.clone()); - adjacency - .entry(fallback_model.clone()) - .or_default() - .push(name.clone()); - } - } - - let mut visited = HashSet::new(); - let mut index = HashMap::new(); - for name in model_names { - if visited.contains(&name) { - continue; - } - - let mut stack = vec![name.clone()]; - let mut component = Vec::new(); - while let Some(current) = stack.pop() { - if !visited.insert(current.clone()) { - continue; - } - component.push(current.clone()); - if let Some(neighbors) = adjacency.get(¤t) { - stack.extend(neighbors.iter().cloned()); - } - } - - if component.len() <= 1 { - continue; - } - component.sort(); - for member in &component { - index.insert( - member.clone(), - component - .iter() - .filter(|other| *other != member) - .cloned() - .collect(), - ); - } - } - - index +enum FallbackModelSource<'a> { + Json(&'a str), + Parsed(HashMap>), } impl ModelCatalog { @@ -231,14 +166,11 @@ impl ModelCatalog { pub fn from_json_str(content: &str) -> Result { let raw: HashMap = serde_json::from_str(content)?; - let fallback_models = parse_fallback_models(content)?; let mut catalog = Self::empty(); for (name, spec) in raw { catalog.insert(name, spec); } - catalog.fallback_models = fallback_models; - catalog.validate_fallback_models()?; - catalog.rebuild_equivalence_index(); + catalog.set_fallback_models(FallbackModelSource::Json(content), true)?; Ok(catalog) } @@ -321,14 +253,15 @@ impl ModelCatalog { where F: FnMut(&str, &ModelSpec) -> ModelSpec, { - let mut out = Self { - fallback_models: self.fallback_models.clone(), - ..Self::empty() - }; + let mut out = Self::empty(); for (name, spec) in &self.models { out.insert(name.clone(), f(name, spec.as_ref())); } - out.rebuild_equivalence_index(); + out.set_fallback_models( + FallbackModelSource::Parsed(self.fallback_models.clone()), + false, + ) + .expect("existing catalog fallback_models remain valid after mapping specs"); out } @@ -376,14 +309,15 @@ impl ModelCatalog { } } - let entry = self.fallback_models.entry(name).or_default(); + let mut next_fallback_models = self.fallback_models.clone(); + let entry = next_fallback_models.entry(name).or_default(); for fallback_model in fallback_models { if entry.contains(&fallback_model) { continue; } entry.push(fallback_model); } - self.rebuild_equivalence_index(); + self.set_fallback_models(FallbackModelSource::Parsed(next_fallback_models), false)?; Ok(()) } @@ -401,21 +335,70 @@ impl ModelCatalog { ))); } - let entry = self.fallback_models.entry(name).or_default(); + let mut next_fallback_models = self.fallback_models.clone(); + let entry = next_fallback_models.entry(name).or_default(); for fallback_model in fallback_models { if fallback_model.is_empty() || entry.contains(&fallback_model) { continue; } entry.push(fallback_model); } - self.rebuild_equivalence_index(); + self.set_fallback_models(FallbackModelSource::Parsed(next_fallback_models), false)?; + Ok(()) + } + + fn set_fallback_models( + &mut self, + source: FallbackModelSource<'_>, + validate_targets: bool, + ) -> Result<()> { + let fallback_models = match source { + FallbackModelSource::Json(content) => { + let raw: HashMap = serde_json::from_str(content)?; + let mut fallback_models = HashMap::new(); + for (name, value) in raw { + let Some(fallbacks) = value.get("fallback_models") else { + continue; + }; + let Some(fallbacks) = fallbacks.as_array() else { + return Err(Error::InvalidRequest(format!( + "model '{name}' has invalid fallback_models" + ))); + }; + let mut parsed = Vec::with_capacity(fallbacks.len()); + for fallback_model in fallbacks { + let Some(fallback_model) = fallback_model.as_str() else { + return Err(Error::InvalidRequest(format!( + "model '{name}' has invalid fallback_models" + ))); + }; + parsed.push(fallback_model.to_string()); + } + if !parsed.is_empty() { + fallback_models.insert(name, parsed); + } + } + fallback_models + } + FallbackModelSource::Parsed(fallback_models) => fallback_models, + }; + + if validate_targets { + Self::validate_fallback_models(&self.models, &fallback_models)?; + } + self.equivalence_index = + build_equivalence_index(self.models.keys().cloned().collect(), &fallback_models); + self.fallback_models = fallback_models; Ok(()) } - fn validate_fallback_models(&self) -> Result<()> { - for (name, fallback_models) in &self.fallback_models { + fn validate_fallback_models( + models: &HashMap>, + fallback_models: &HashMap>, + ) -> Result<()> { + for (name, fallback_models) in fallback_models { for fallback_model in fallback_models { - if !self.models.contains_key(fallback_model) { + if !models.contains_key(fallback_model) { return Err(Error::InvalidRequest(format!( "model '{name}' references missing fallback model '{fallback_model}'" ))); @@ -424,11 +407,6 @@ impl ModelCatalog { } Ok(()) } - - fn rebuild_equivalence_index(&mut self) { - self.equivalence_index = - build_equivalence_index(self.models.keys().cloned().collect(), &self.fallback_models); - } } pub fn load_catalog_from_disk>(path: P) -> Result> { @@ -617,6 +595,77 @@ mod tests { ); } + #[test] + fn overlay_equivalence_reaches_custom_and_touched_base_models() { + let base = Arc::new( + ModelCatalog::from_json_str( + r#"{ + "base-a": { + "format": "openai", + "flavor": "chat", + "fallback_models": ["base-b"] + }, + "base-b": { + "format": "openai", + "flavor": "chat" + } +}"#, + ) + .expect("base catalog parses"), + ); + let mut custom = ModelCatalog::empty(); + custom.insert( + "custom-a".to_string(), + ModelSpec { + model: "custom-a".to_string(), + format: ProviderFormat::Anthropic, + flavor: ModelFlavor::Chat, + display_name: None, + parent: None, + input_cost_per_mil_tokens: None, + output_cost_per_mil_tokens: None, + input_cache_read_cost_per_mil_tokens: None, + multimodal: None, + reasoning: None, + max_input_tokens: None, + max_output_tokens: None, + supports_streaming: true, + extra: Default::default(), + available_providers: vec!["custom-provider".to_string()], + }, + ); + custom + .add_external_fallback_models("custom-a".to_string(), vec!["base-a".to_string()]) + .expect("fallback is valid"); + + let overlay = OverlayModelCatalog::new(base, custom); + + assert_eq!( + overlay.equivalent_model_names("custom-a"), + vec![ + "custom-a".to_string(), + "base-a".to_string(), + "base-b".to_string() + ] + ); + assert_eq!( + overlay.equivalent_model_names("base-a"), + vec![ + "base-a".to_string(), + "base-b".to_string(), + "custom-a".to_string() + ] + ); + assert_eq!( + overlay.equivalent_model_names("base-b"), + vec![ + "base-b".to_string(), + "base-a".to_string(), + "custom-a".to_string() + ] + ); + } + #[test] fn overlay_equivalence_index_does_not_inherit_shadowed_base_edges() { let base = Arc::new( From 7bc93387a1ba0e9ecc13b36c58cab62d680aebcd Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 11:48:41 -0400 Subject: [PATCH 15/28] rename --- .../braintrust-llm-router/src/catalog/mod.rs | 38 +++++++++---------- .../src/catalog/resolver.rs | 2 +- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/crates/braintrust-llm-router/src/catalog/mod.rs b/crates/braintrust-llm-router/src/catalog/mod.rs index 8f6a5f36..4bbeb5b0 100644 --- a/crates/braintrust-llm-router/src/catalog/mod.rs +++ b/crates/braintrust-llm-router/src/catalog/mod.rs @@ -81,7 +81,7 @@ impl OverlayModelCatalog { self.custom.get(name).or_else(|| self.base.get(name)) } - pub fn equivalent_model_names(&self, name: &str) -> Vec { + pub fn find_fallback_models(&self, name: &str) -> Vec { let Some(_) = self.get(name) else { return Vec::new(); }; @@ -96,7 +96,7 @@ impl OverlayModelCatalog { if !self.custom_model_names.contains(¤t) { stack.extend( self.base - .equivalent_model_names(¤t) + .find_fallback_models(¤t) .into_iter() .filter(|model_name| !self.custom_model_names.contains(model_name)), ); @@ -140,10 +140,10 @@ impl CatalogResolver { } } - pub fn equivalent_model_names(&self, name: &str) -> Vec { + pub fn find_fallback_models(&self, name: &str) -> Vec { match self { - Self::Base(catalog) => catalog.equivalent_model_names(name), - Self::Overlay(overlay) => overlay.equivalent_model_names(name), + Self::Base(catalog) => catalog.find_fallback_models(name), + Self::Overlay(overlay) => overlay.find_fallback_models(name), } } } @@ -189,7 +189,7 @@ impl ModelCatalog { self.models.get(name).cloned() } - pub fn equivalent_model_names(&self, name: &str) -> Vec { + pub fn find_fallback_models(&self, name: &str) -> Vec { let Some(_) = self.models.get(name) else { return Vec::new(); }; @@ -424,7 +424,7 @@ mod tests { use super::*; #[test] - fn equivalent_model_names_are_available_from_any_member() { + fn find_fallback_models_are_available_from_any_member() { let catalog = ModelCatalog::from_json_str( r#"{ "claude-sonnet-4-6": { @@ -448,7 +448,7 @@ mod tests { .expect("catalog parses"); assert_eq!( - catalog.equivalent_model_names("claude-sonnet-4-6"), + catalog.find_fallback_models("claude-sonnet-4-6"), vec![ "claude-sonnet-4-6".to_string(), "anthropic.claude-sonnet-4-6".to_string(), @@ -456,7 +456,7 @@ mod tests { ] ); assert_eq!( - catalog.equivalent_model_names("publishers/anthropic/models/claude-sonnet-4-6"), + catalog.find_fallback_models("publishers/anthropic/models/claude-sonnet-4-6"), vec![ "publishers/anthropic/models/claude-sonnet-4-6".to_string(), "anthropic.claude-sonnet-4-6".to_string(), @@ -488,7 +488,7 @@ mod tests { .expect("catalog parses"); assert_eq!( - catalog.equivalent_model_names("model-a"), + catalog.find_fallback_models("model-a"), vec![ "model-a".to_string(), "model-b".to_string(), @@ -534,11 +534,11 @@ mod tests { .expect("equivalence is valid"); assert_eq!( - catalog.equivalent_model_names("model-a"), + catalog.find_fallback_models("model-a"), vec!["model-a".to_string(), "model-b".to_string()] ); assert_eq!( - catalog.equivalent_model_names("model-b"), + catalog.find_fallback_models("model-b"), vec!["model-b".to_string(), "model-a".to_string()] ); } @@ -561,7 +561,7 @@ mod tests { assert!(matches!(error, Error::InvalidRequest(_))); assert_eq!( - catalog.equivalent_model_names("model-a"), + catalog.find_fallback_models("model-a"), vec!["model-a".to_string()] ); } @@ -590,7 +590,7 @@ mod tests { }); assert_eq!( - mapped.equivalent_model_names("model-a"), + mapped.find_fallback_models("model-a"), vec!["model-a".to_string(), "model-b".to_string()] ); } @@ -641,7 +641,7 @@ mod tests { let overlay = OverlayModelCatalog::new(base, custom); assert_eq!( - overlay.equivalent_model_names("custom-a"), + overlay.find_fallback_models("custom-a"), vec![ "custom-a".to_string(), "base-a".to_string(), @@ -649,7 +649,7 @@ mod tests { ] ); assert_eq!( - overlay.equivalent_model_names("base-a"), + overlay.find_fallback_models("base-a"), vec![ "base-a".to_string(), "base-b".to_string(), @@ -657,7 +657,7 @@ mod tests { ] ); assert_eq!( - overlay.equivalent_model_names("base-b"), + overlay.find_fallback_models("base-b"), vec![ "base-b".to_string(), "base-a".to_string(), @@ -709,11 +709,11 @@ mod tests { let overlay = OverlayModelCatalog::new(base, custom); assert_eq!( - overlay.equivalent_model_names("model-a"), + overlay.find_fallback_models("model-a"), vec!["model-a".to_string()] ); assert_eq!( - overlay.equivalent_model_names("model-b"), + overlay.find_fallback_models("model-b"), vec!["model-b".to_string()] ); } diff --git a/crates/braintrust-llm-router/src/catalog/resolver.rs b/crates/braintrust-llm-router/src/catalog/resolver.rs index 53d8edef..9eef9a2f 100644 --- a/crates/braintrust-llm-router/src/catalog/resolver.rs +++ b/crates/braintrust-llm-router/src/catalog/resolver.rs @@ -43,7 +43,7 @@ impl ModelResolver { pub fn resolve_all_equivalent_model_routes(&self, model: &str) -> Result> { let mut resolved = Vec::new(); - for model_name in self.catalog.equivalent_model_names(model) { + for model_name in self.catalog.find_fallback_models(model) { resolved.push(self.resolve_one(&model_name)?); } if resolved.is_empty() { From 471718f74fa97b5730a9c677c3c2366ca28ed0c6 Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 11:53:21 -0400 Subject: [PATCH 16/28] rename --- .../braintrust-llm-router/src/catalog/mod.rs | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/crates/braintrust-llm-router/src/catalog/mod.rs b/crates/braintrust-llm-router/src/catalog/mod.rs index 4bbeb5b0..4d720a81 100644 --- a/crates/braintrust-llm-router/src/catalog/mod.rs +++ b/crates/braintrust-llm-router/src/catalog/mod.rs @@ -96,7 +96,7 @@ impl OverlayModelCatalog { if !self.custom_model_names.contains(¤t) { stack.extend( self.base - .find_fallback_models(¤t) + .fallback_models(¤t) .into_iter() .filter(|model_name| !self.custom_model_names.contains(model_name)), ); @@ -142,7 +142,7 @@ impl CatalogResolver { pub fn find_fallback_models(&self, name: &str) -> Vec { match self { - Self::Base(catalog) => catalog.find_fallback_models(name), + Self::Base(catalog) => catalog.fallback_models(name), Self::Overlay(overlay) => overlay.find_fallback_models(name), } } @@ -189,7 +189,7 @@ impl ModelCatalog { self.models.get(name).cloned() } - pub fn find_fallback_models(&self, name: &str) -> Vec { + pub fn fallback_models(&self, name: &str) -> Vec { let Some(_) = self.models.get(name) else { return Vec::new(); }; @@ -424,7 +424,7 @@ mod tests { use super::*; #[test] - fn find_fallback_models_are_available_from_any_member() { + fn fallback_models_are_available_from_any_member() { let catalog = ModelCatalog::from_json_str( r#"{ "claude-sonnet-4-6": { @@ -448,7 +448,7 @@ mod tests { .expect("catalog parses"); assert_eq!( - catalog.find_fallback_models("claude-sonnet-4-6"), + catalog.fallback_models("claude-sonnet-4-6"), vec![ "claude-sonnet-4-6".to_string(), "anthropic.claude-sonnet-4-6".to_string(), @@ -456,7 +456,7 @@ mod tests { ] ); assert_eq!( - catalog.find_fallback_models("publishers/anthropic/models/claude-sonnet-4-6"), + catalog.fallback_models("publishers/anthropic/models/claude-sonnet-4-6"), vec![ "publishers/anthropic/models/claude-sonnet-4-6".to_string(), "anthropic.claude-sonnet-4-6".to_string(), @@ -488,7 +488,7 @@ mod tests { .expect("catalog parses"); assert_eq!( - catalog.find_fallback_models("model-a"), + catalog.fallback_models("model-a"), vec![ "model-a".to_string(), "model-b".to_string(), @@ -534,11 +534,11 @@ mod tests { .expect("equivalence is valid"); assert_eq!( - catalog.find_fallback_models("model-a"), + catalog.fallback_models("model-a"), vec!["model-a".to_string(), "model-b".to_string()] ); assert_eq!( - catalog.find_fallback_models("model-b"), + catalog.fallback_models("model-b"), vec!["model-b".to_string(), "model-a".to_string()] ); } @@ -561,7 +561,7 @@ mod tests { assert!(matches!(error, Error::InvalidRequest(_))); assert_eq!( - catalog.find_fallback_models("model-a"), + catalog.fallback_models("model-a"), vec!["model-a".to_string()] ); } @@ -590,7 +590,7 @@ mod tests { }); assert_eq!( - mapped.find_fallback_models("model-a"), + mapped.fallback_models("model-a"), vec!["model-a".to_string(), "model-b".to_string()] ); } From 9a2743f316630151e1b57d9696960e7a4d7e9d0e Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 11:58:16 -0400 Subject: [PATCH 17/28] one more rename --- crates/braintrust-llm-router/src/catalog/mod.rs | 2 +- crates/braintrust-llm-router/src/catalog/resolver.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/braintrust-llm-router/src/catalog/mod.rs b/crates/braintrust-llm-router/src/catalog/mod.rs index 4d720a81..5fff8358 100644 --- a/crates/braintrust-llm-router/src/catalog/mod.rs +++ b/crates/braintrust-llm-router/src/catalog/mod.rs @@ -140,7 +140,7 @@ impl CatalogResolver { } } - pub fn find_fallback_models(&self, name: &str) -> Vec { + pub fn fallback_models(&self, name: &str) -> Vec { match self { Self::Base(catalog) => catalog.fallback_models(name), Self::Overlay(overlay) => overlay.find_fallback_models(name), diff --git a/crates/braintrust-llm-router/src/catalog/resolver.rs b/crates/braintrust-llm-router/src/catalog/resolver.rs index 9eef9a2f..6df57b57 100644 --- a/crates/braintrust-llm-router/src/catalog/resolver.rs +++ b/crates/braintrust-llm-router/src/catalog/resolver.rs @@ -43,7 +43,7 @@ impl ModelResolver { pub fn resolve_all_equivalent_model_routes(&self, model: &str) -> Result> { let mut resolved = Vec::new(); - for model_name in self.catalog.find_fallback_models(model) { + for model_name in self.catalog.fallback_models(model) { resolved.push(self.resolve_one(&model_name)?); } if resolved.is_empty() { From 8674f412ce167a212d8d647e37dd2bcc76aa6f23 Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 12:09:36 -0400 Subject: [PATCH 18/28] comment fix --- .../src/providers/mod.rs | 1 + crates/braintrust-llm-router/src/router.rs | 37 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/crates/braintrust-llm-router/src/providers/mod.rs b/crates/braintrust-llm-router/src/providers/mod.rs index 539ce873..042cb56f 100644 --- a/crates/braintrust-llm-router/src/providers/mod.rs +++ b/crates/braintrust-llm-router/src/providers/mod.rs @@ -182,6 +182,7 @@ pub(crate) fn format_carries_model_in_body(format: ProviderFormat) -> bool { ProviderFormat::ChatCompletions | ProviderFormat::Responses | ProviderFormat::Anthropic + | ProviderFormat::Google | ProviderFormat::Mistral ) } diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 367fcbe1..5f21a12f 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -1476,6 +1476,43 @@ mod tests { assert!(parsed.get("messages").is_some()); } + #[tokio::test] + async fn prepare_provider_request_rewrites_model_for_google_pass_through() { + let body = Bytes::from_static( + br#"{"model":"models/gemini-2.5-flash","contents":[{"role":"user","parts":[{"text":"Ping"}]}]}"#, + ); + let spec = ModelSpec { + model: "models/gemini-2.5-pro".to_string(), + format: ProviderFormat::Google, + flavor: ModelFlavor::Chat, + display_name: None, + parent: None, + input_cost_per_mil_tokens: None, + output_cost_per_mil_tokens: None, + input_cache_read_cost_per_mil_tokens: None, + multimodal: None, + reasoning: None, + max_input_tokens: None, + max_output_tokens: None, + supports_streaming: true, + extra: Default::default(), + available_providers: vec!["google".to_string()], + }; + + let (payload, _, actual_format) = + prepare_provider_request(body, &spec, ProviderFormat::Google, false) + .await + .expect("request prepares"); + let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); + + assert_eq!(actual_format, ProviderFormat::Google); + assert_eq!( + parsed.get("model").and_then(Value::as_str), + Some("models/gemini-2.5-pro") + ); + assert!(parsed.get("contents").is_some()); + } + #[tokio::test] async fn prepare_provider_request_upgrades_actual_format_to_responses_for_reasoning_plus_tools() { From 0eb3148d2766d6e85d1a053f60ddf3f250b185fa Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 13:34:05 -0400 Subject: [PATCH 19/28] stop doing bedrock request transforms in a seperate spot than everywhere else --- .../src/providers/bedrock.rs | 47 ++------- crates/lingua/src/lib.rs | 10 +- crates/lingua/src/processing/mod.rs | 7 +- crates/lingua/src/processing/transform.rs | 95 +++++++++++++++++-- 4 files changed, 107 insertions(+), 52 deletions(-) diff --git a/crates/braintrust-llm-router/src/providers/bedrock.rs b/crates/braintrust-llm-router/src/providers/bedrock.rs index ea7c5ae0..bbc87c95 100644 --- a/crates/braintrust-llm-router/src/providers/bedrock.rs +++ b/crates/braintrust-llm-router/src/providers/bedrock.rs @@ -11,10 +11,10 @@ use aws_sigv4::sign::v4; use aws_smithy_runtime_api::client::identity::Identity; use bytes::Bytes; use http::Request as HttpRequest; -use lingua::processing::{adapter_for_format, adapters}; use lingua::serde_json::Value; use lingua::universal::message::{Message, UserContent, UserContentPart}; use lingua::util::media::MediaBlock; +use lingua::{finish_request_transform, prepare_request_transform, RequestTransformPreparation}; use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE}; use reqwest::Url; use reqwest_middleware::ClientWithMiddleware; @@ -25,7 +25,7 @@ use crate::client::{build_middleware_client, ClientSettings}; use crate::error::{Error, Result, UpstreamHttpError}; use crate::providers::ClientHeaders; use crate::streaming::{bedrock_event_stream, sse_stream, RawResponseStream}; -use lingua::{ProviderFormat, TransformError}; +use lingua::ProviderFormat; const BEDROCK_REMOTE_MEDIA_MAX_BYTES: usize = 5 * 1024 * 1024; @@ -77,42 +77,15 @@ where return Ok(body); } - let parsed = lingua::parse_json_body(body)?; - let payload = parsed.value; - let body = parsed.bytes; - - let source_adapter = match adapters() - .iter() - .map(|adapter| adapter.as_ref()) - .find(|adapter| adapter.detect_request(&payload)) - { - Some(adapter) => adapter, - None => return Err(TransformError::UnableToDetectRequestFormat.into()), - }; - - if source_adapter.format() == format { - return Ok(body); - } - - let mut request = match source_adapter.request_to_universal(payload) { - Ok(request) => request, - Err(err) => return Err(err.into()), - }; - - inline_remote_image_urls_with_fetch(&mut request, fetch).await?; - - if request.model.is_none() { - request.model = Some(spec.model.clone()); + match prepare_request_transform(body, format, Some(&spec.model))? { + RequestTransformPreparation::PassThrough(bytes) => Ok(bytes), + RequestTransformPreparation::Prepared(mut prepared) => { + inline_remote_image_urls_with_fetch(&mut prepared.request, fetch).await?; + finish_request_transform(*prepared) + .map(|result| result.into_bytes()) + .map_err(Error::from) + } } - - let target_adapter = - adapter_for_format(format).ok_or(TransformError::UnsupportedTargetFormat(format))?; - target_adapter.apply_defaults(&mut request); - let prepared = target_adapter.request_from_universal(&request)?; - - lingua::serde_json::to_vec(&prepared) - .map(Bytes::from) - .map_err(Error::LinguaJson) } async fn inline_remote_image_urls_with_fetch( diff --git a/crates/lingua/src/lib.rs b/crates/lingua/src/lib.rs index 8bd31d5e..6defd7b2 100644 --- a/crates/lingua/src/lib.rs +++ b/crates/lingua/src/lib.rs @@ -36,10 +36,12 @@ pub use capabilities::ProviderFormat; // Re-export key processing functions (bytes-based API) pub use processing::{ - extract_model, normalize_json_lone_surrogate_escapes, parse_json, parse_json_body, - parse_json_value, parse_stream_event, response_to_universal, sanitize_payload, - transform_request, transform_response, transform_stream_chunk, ParsedJsonBody, - ParsedStreamEvent, StreamOutputChunk, StreamTransformSession, TransformError, TransformResult, + extract_model, finish_request_transform, normalize_json_lone_surrogate_escapes, parse_json, + parse_json_body, parse_json_value, parse_stream_event, prepare_request_transform, + response_to_universal, sanitize_payload, transform_request, transform_response, + transform_stream_chunk, ParsedJsonBody, ParsedStreamEvent, PreparedRequestTransform, + RequestTransformPreparation, StreamOutputChunk, StreamTransformSession, TransformError, + TransformResult, }; // Re-export universal types diff --git a/crates/lingua/src/processing/mod.rs b/crates/lingua/src/processing/mod.rs index 02351a47..c6b564fb 100644 --- a/crates/lingua/src/processing/mod.rs +++ b/crates/lingua/src/processing/mod.rs @@ -16,7 +16,8 @@ pub use stream::{ parse_stream_event, ParsedStreamEvent, StreamOutputChunk, StreamTransformSession, }; pub use transform::{ - extract_model, parse_json, parse_json_body, parse_json_value, response_to_universal, - sanitize_payload, transform_request, transform_response, transform_stream_chunk, - ParsedJsonBody, TransformError, TransformResult, + extract_model, finish_request_transform, parse_json, parse_json_body, parse_json_value, + prepare_request_transform, response_to_universal, sanitize_payload, transform_request, + transform_response, transform_stream_chunk, ParsedJsonBody, PreparedRequestTransform, + RequestTransformPreparation, TransformError, TransformResult, }; diff --git a/crates/lingua/src/processing/transform.rs b/crates/lingua/src/processing/transform.rs index 89ad7079..da13df03 100644 --- a/crates/lingua/src/processing/transform.rs +++ b/crates/lingua/src/processing/transform.rs @@ -22,9 +22,9 @@ use crate::providers::openai::model_needs_transforms; use crate::serde_json; use crate::serde_json::Value; use crate::universal::{ - AssistantContent, AssistantContentPart, Message, UniversalReasoningDelta, UniversalResponse, - UniversalStreamChoice, UniversalStreamChunk, UniversalStreamDelta, UniversalToolCallDelta, - UniversalToolFunctionDelta, + AssistantContent, AssistantContentPart, Message, UniversalReasoningDelta, UniversalRequest, + UniversalResponse, UniversalStreamChoice, UniversalStreamChunk, UniversalStreamDelta, + UniversalToolCallDelta, UniversalToolFunctionDelta, }; use serde::de::DeserializeOwned; use thiserror::Error; @@ -132,6 +132,15 @@ pub enum TransformResult { }, } +/// A request converted into Lingua's universal representation and ready for +/// caller-specific mutation before final target serialization. +#[derive(Debug, Clone)] +pub struct PreparedRequestTransform { + pub request: UniversalRequest, + pub source_format: ProviderFormat, + pub actual_target_format: ProviderFormat, +} + impl TransformResult { /// Check if this is a pass-through result (no transformation occurred). pub fn is_passthrough(&self) -> bool { @@ -303,6 +312,26 @@ pub fn transform_request( target_format: ProviderFormat, model: Option<&str>, ) -> Result { + match prepare_request_transform(input, target_format, model)? { + RequestTransformPreparation::PassThrough(bytes) => Ok(TransformResult::PassThrough(bytes)), + RequestTransformPreparation::Prepared(prepared) => finish_request_transform(*prepared), + } +} + +/// Result of preparing a request transform before final target serialization. +pub enum RequestTransformPreparation { + /// Payload was already valid for target format and no target model override + /// forced a universal conversion. + PassThrough(Bytes), + /// Payload has been converted to universal form with the target model applied. + Prepared(Box), +} + +pub fn prepare_request_transform( + input: Bytes, + target_format: ProviderFormat, + model: Option<&str>, +) -> Result { let parsed = parse_json_body(input)?; let payload = parsed.value; let request_bytes = parsed.bytes; @@ -339,9 +368,10 @@ pub fn transform_request( if source_format == target_format && !request_model_needs_forced_translation(request_model.as_deref(), model, target_format) + && model.is_none() && target_adapter.detect_passthrough_request(&payload) { - return Ok(TransformResult::PassThrough(request_bytes)); + return Ok(RequestTransformPreparation::PassThrough(request_bytes)); } let mut universal = source_adapter.request_to_universal(payload)?; @@ -350,19 +380,35 @@ pub fn transform_request( universal.model = Some(model.to_string()); } + Ok(RequestTransformPreparation::Prepared(Box::new( + PreparedRequestTransform { + request: universal, + source_format, + actual_target_format: target_format, + }, + ))) +} + +pub fn finish_request_transform( + mut prepared: PreparedRequestTransform, +) -> Result { + let target_adapter = adapter_for_format(prepared.actual_target_format).ok_or( + TransformError::UnsupportedTargetFormat(prepared.actual_target_format), + )?; + // Apply target provider defaults (e.g., Anthropic's required max_tokens) - target_adapter.apply_defaults(&mut universal); + target_adapter.apply_defaults(&mut prepared.request); // Convert to target format (validation happens in adapter) - let transformed = target_adapter.request_from_universal(&universal)?; + let transformed = target_adapter.request_from_universal(&prepared.request)?; let bytes = crate::serde_json::to_vec(&transformed) .map_err(|e| TransformError::SerializationFailed(e.to_string()))?; Ok(TransformResult::Transformed { bytes: Bytes::from(bytes), - source_format, - actual_target_format: target_format, + source_format: prepared.source_format, + actual_target_format: prepared.actual_target_format, }) } @@ -2158,9 +2204,42 @@ mod tests { ); let output: Value = crate::serde_json::from_slice(result.as_bytes()).unwrap(); + assert_eq!( + output.get("modelId").and_then(Value::as_str), + Some("amazon.nova-pro-v1:0") + ); assert!( output.get("anthropic_version").is_none(), "Non-anthropic models should not have anthropic_version" ); } + + #[test] + #[cfg(feature = "bedrock")] + fn test_converse_model_override_forces_universal_translation() { + let payload = json!({ + "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + "messages": [{ + "role": "user", + "content": [{"text": "Hello"}] + }] + }); + let input = to_bytes(&payload); + + let result = transform_request( + input, + ProviderFormat::Converse, + Some("anthropic.claude-3-5-sonnet-20241022-v2:0"), + ) + .unwrap(); + + assert!(!result.is_passthrough()); + assert_eq!(result.source_format(), Some(ProviderFormat::Converse)); + + let output: Value = crate::serde_json::from_slice(result.as_bytes()).unwrap(); + assert_eq!( + output.get("modelId").and_then(Value::as_str), + Some("anthropic.claude-3-5-sonnet-20241022-v2:0") + ); + } } From 8189e2d0b4f2e271f9ebc653872aa582f7f1fec9 Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 14:57:25 -0400 Subject: [PATCH 20/28] oops --- crates/lingua/src/processing/transform.rs | 36 ++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/crates/lingua/src/processing/transform.rs b/crates/lingua/src/processing/transform.rs index da13df03..d40adee7 100644 --- a/crates/lingua/src/processing/transform.rs +++ b/crates/lingua/src/processing/transform.rs @@ -368,7 +368,7 @@ pub fn prepare_request_transform( if source_format == target_format && !request_model_needs_forced_translation(request_model.as_deref(), model, target_format) - && model.is_none() + && !request_model_override_changes_model(request_model.as_deref(), model) && target_adapter.detect_passthrough_request(&payload) { return Ok(RequestTransformPreparation::PassThrough(request_bytes)); @@ -389,6 +389,16 @@ pub fn prepare_request_transform( ))) } +fn request_model_override_changes_model( + request_model: Option<&str>, + override_model: Option<&str>, +) -> bool { + match override_model { + Some(override_model) => request_model != Some(override_model), + None => false, + } +} + pub fn finish_request_transform( mut prepared: PreparedRequestTransform, ) -> Result { @@ -894,6 +904,30 @@ mod tests { assert_eq!(output.as_ptr(), input_ptr); } + #[test] + #[cfg(feature = "openai")] + fn test_transform_request_passthrough_with_identical_model_override() { + let payload = json!({ + "model": "gpt-4", + "messages": [{"role": "user", "name": "example_user", "content": "Hello"}] + }); + let input = to_bytes(&payload); + let input_ptr = input.as_ptr(); + + let result = + transform_request(input, ProviderFormat::ChatCompletions, Some("gpt-4")).unwrap(); + + assert!(result.is_passthrough()); + let output = result.into_bytes(); + assert_eq!(output.as_ptr(), input_ptr); + + let parsed: Value = crate::serde_json::from_slice(&output).unwrap(); + assert_eq!( + parsed["messages"][0].get("name").and_then(Value::as_str), + Some("example_user") + ); + } + #[test] #[cfg(feature = "openai")] fn test_transform_request_passthrough_repairs_lone_surrogate() { From 199be5b1929f12104b7fcb40af7bbe73966324fa Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 15:46:58 -0400 Subject: [PATCH 21/28] more holistic approach to re-write model name on failover in request body --- .../src/providers/bedrock.rs | 102 ++++++++++++++++-- .../src/providers/mod.rs | 39 +++++-- crates/braintrust-llm-router/src/router.rs | 48 +++++---- crates/lingua/src/lib.rs | 10 +- crates/lingua/src/processing/mod.rs | 7 +- crates/lingua/src/processing/transform.rs | 84 +++------------ 6 files changed, 170 insertions(+), 120 deletions(-) diff --git a/crates/braintrust-llm-router/src/providers/bedrock.rs b/crates/braintrust-llm-router/src/providers/bedrock.rs index bbc87c95..cb3435a3 100644 --- a/crates/braintrust-llm-router/src/providers/bedrock.rs +++ b/crates/braintrust-llm-router/src/providers/bedrock.rs @@ -11,10 +11,10 @@ use aws_sigv4::sign::v4; use aws_smithy_runtime_api::client::identity::Identity; use bytes::Bytes; use http::Request as HttpRequest; +use lingua::processing::{adapter_for_format, adapters}; use lingua::serde_json::Value; use lingua::universal::message::{Message, UserContent, UserContentPart}; use lingua::util::media::MediaBlock; -use lingua::{finish_request_transform, prepare_request_transform, RequestTransformPreparation}; use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE}; use reqwest::Url; use reqwest_middleware::ClientWithMiddleware; @@ -23,9 +23,9 @@ use crate::auth::AuthConfig; use crate::catalog::ModelSpec; use crate::client::{build_middleware_client, ClientSettings}; use crate::error::{Error, Result, UpstreamHttpError}; -use crate::providers::ClientHeaders; +use crate::providers::{rewrite_body_model, ClientHeaders}; use crate::streaming::{bedrock_event_stream, sse_stream, RawResponseStream}; -use lingua::ProviderFormat; +use lingua::{ProviderFormat, TransformError}; const BEDROCK_REMOTE_MEDIA_MAX_BYTES: usize = 5 * 1024 * 1024; @@ -77,15 +77,39 @@ where return Ok(body); } - match prepare_request_transform(body, format, Some(&spec.model))? { - RequestTransformPreparation::PassThrough(bytes) => Ok(bytes), - RequestTransformPreparation::Prepared(mut prepared) => { - inline_remote_image_urls_with_fetch(&mut prepared.request, fetch).await?; - finish_request_transform(*prepared) - .map(|result| result.into_bytes()) - .map_err(Error::from) - } + let parsed = lingua::parse_json_body(body)?; + let payload = parsed.value; + let body = parsed.bytes; + + let source_adapter = match adapters() + .iter() + .map(|adapter| adapter.as_ref()) + .find(|adapter| adapter.detect_request(&payload)) + { + Some(adapter) => adapter, + None => return Err(TransformError::UnableToDetectRequestFormat.into()), + }; + + if source_adapter.format() == format { + return Ok(rewrite_body_model(body, format, &spec.model)); } + + let mut request = match source_adapter.request_to_universal(payload) { + Ok(request) => request, + Err(err) => return Err(err.into()), + }; + + inline_remote_image_urls_with_fetch(&mut request, fetch).await?; + request.model = Some(spec.model.clone()); + + let target_adapter = + adapter_for_format(format).ok_or(TransformError::UnsupportedTargetFormat(format))?; + target_adapter.apply_defaults(&mut request); + let prepared = target_adapter.request_from_universal(&request)?; + + lingua::serde_json::to_vec(&prepared) + .map(Bytes::from) + .map_err(Error::LinguaJson) } async fn inline_remote_image_urls_with_fetch( @@ -597,6 +621,11 @@ mod tests { let body = Bytes::from( lingua::serde_json::to_vec(&lingua::serde_json::json!({ "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + "system": [{"text": "You are helpful."}], + "guardrailConfig": { + "guardrailIdentifier": "test", + "guardrailVersion": "1" + }, "messages": [{ "role": "user", "content": [{"text": "Hello"}] @@ -624,6 +653,57 @@ mod tests { assert_eq!(prepared, body); } + #[tokio::test] + async fn prepare_request_rewrites_same_format_converse_model_without_losing_native_fields() { + let body = Bytes::from( + lingua::serde_json::to_vec(&lingua::serde_json::json!({ + "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + "system": [{"text": "You are helpful."}], + "guardrailConfig": { + "guardrailIdentifier": "test", + "guardrailVersion": "1" + }, + "messages": [{ + "role": "user", + "content": [{"text": "Hello"}] + }] + })) + .unwrap(), + ); + + let prepared = prepare_bedrock_request_with_fetch( + body, + &bedrock_spec( + "anthropic.claude-3-5-sonnet-20241022-v2:0", + ProviderFormat::Converse, + ), + ProviderFormat::Converse, + |_url| { + Box::pin(async { + panic!("fetch should not be called for same-format converse requests"); + }) + }, + ) + .await + .unwrap(); + + let value: lingua::serde_json::Value = lingua::serde_json::from_slice(&prepared).unwrap(); + assert_eq!( + value.get("modelId").and_then(|v| v.as_str()), + Some("anthropic.claude-3-5-sonnet-20241022-v2:0") + ); + assert_eq!( + value.pointer("/system/0/text").and_then(|v| v.as_str()), + Some("You are helpful.") + ); + assert_eq!( + value + .pointer("/guardrailConfig/guardrailIdentifier") + .and_then(|v| v.as_str()), + Some("test") + ); + } + #[tokio::test] async fn prepare_request_repairs_lone_surrogate_for_same_format_converse() { let body = Bytes::from_static( diff --git a/crates/braintrust-llm-router/src/providers/mod.rs b/crates/braintrust-llm-router/src/providers/mod.rs index 042cb56f..68c7ce98 100644 --- a/crates/braintrust-llm-router/src/providers/mod.rs +++ b/crates/braintrust-llm-router/src/providers/mod.rs @@ -176,15 +176,38 @@ pub(crate) fn disable_streaming_payload(payload: Bytes) -> Bytes { } } -pub(crate) fn format_carries_model_in_body(format: ProviderFormat) -> bool { - matches!( - format, +pub(crate) fn body_model_field(format: ProviderFormat) -> Option<&'static str> { + match format { ProviderFormat::ChatCompletions - | ProviderFormat::Responses - | ProviderFormat::Anthropic - | ProviderFormat::Google - | ProviderFormat::Mistral - ) + | ProviderFormat::Responses + | ProviderFormat::Anthropic + | ProviderFormat::Google + | ProviderFormat::Mistral => Some("model"), + ProviderFormat::Converse => Some("modelId"), + ProviderFormat::BedrockAnthropic + | ProviderFormat::VertexAnthropic + | ProviderFormat::Unknown => None, + } +} + +pub(crate) fn rewrite_body_model(payload: Bytes, format: ProviderFormat, model: &str) -> Bytes { + let Some(model_field) = body_model_field(format) else { + return payload; + }; + let Ok(mut value) = serde_json::from_slice::(&payload) else { + return payload; + }; + let Some(object) = value.as_object_mut() else { + return payload; + }; + if object.get(model_field).and_then(Value::as_str) == Some(model) { + return payload; + } + object.insert(model_field.to_string(), Value::String(model.to_string())); + match serde_json::to_vec(&value) { + Ok(serialized) => Bytes::from(serialized), + Err(_) => payload, + } } pub(crate) fn enable_streaming_payload(payload: Bytes, format: ProviderFormat) -> Bytes { diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 5f21a12f..a88199d3 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -15,8 +15,8 @@ use crate::catalog::{ use crate::client::ClientSettings; use crate::error::{Error, Result}; use crate::providers::{ - enable_streaming_payload, format_carries_model_in_body, prepare_bedrock_request, - requires_bedrock_request_preparation, ClientHeaders, Provider, + enable_streaming_payload, prepare_bedrock_request, requires_bedrock_request_preparation, + rewrite_body_model, ClientHeaders, Provider, }; use crate::retry::{RetryPolicy, RetryStrategy}; use crate::streaming::{ @@ -208,26 +208,6 @@ struct PreparedRequestInner { strategy: RetryStrategy, } -fn override_payload_model(payload: Bytes, format: ProviderFormat, model: &str) -> Bytes { - if !format_carries_model_in_body(format) { - return payload; - } - let Ok(mut value) = serde_json::from_slice::(&payload) else { - return payload; - }; - let Some(object) = value.as_object_mut() else { - return payload; - }; - if object.get("model").and_then(Value::as_str) == Some(model) { - return payload; - } - object.insert("model".to_string(), Value::String(model.to_string())); - match serde_json::to_vec(&value) { - Ok(serialized) => Bytes::from(serialized), - Err(_) => payload, - } -} - async fn prepare_provider_request( body: Bytes, spec: &ModelSpec, @@ -251,7 +231,7 @@ async fn prepare_provider_request( Err(err) => return Err(err.into()), }; - let transformed = override_payload_model(transformed, actual_format, &spec.model); + let transformed = rewrite_body_model(transformed, actual_format, &spec.model); if stream { // TODO: Fold streaming intent into `lingua::transform_request` once we @@ -1513,6 +1493,28 @@ mod tests { assert!(parsed.get("contents").is_some()); } + #[tokio::test] + async fn prepare_provider_request_rewrites_same_format_chat_model_without_losing_native_fields() + { + let body = Bytes::from_static( + br#"{"model":"gpt-4","messages":[{"role":"user","name":"example_user","content":"Ping"}]}"#, + ); + let spec = openai_spec("gpt-4o", ModelFlavor::Chat); + + let (payload, _, actual_format) = + prepare_provider_request(body, &spec, ProviderFormat::ChatCompletions, false) + .await + .expect("request prepares"); + let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); + + assert_eq!(actual_format, ProviderFormat::ChatCompletions); + assert_eq!(parsed.get("model").and_then(Value::as_str), Some("gpt-4o")); + assert_eq!( + parsed.pointer("/messages/0/name").and_then(Value::as_str), + Some("example_user") + ); + } + #[tokio::test] async fn prepare_provider_request_upgrades_actual_format_to_responses_for_reasoning_plus_tools() { diff --git a/crates/lingua/src/lib.rs b/crates/lingua/src/lib.rs index 6defd7b2..8bd31d5e 100644 --- a/crates/lingua/src/lib.rs +++ b/crates/lingua/src/lib.rs @@ -36,12 +36,10 @@ pub use capabilities::ProviderFormat; // Re-export key processing functions (bytes-based API) pub use processing::{ - extract_model, finish_request_transform, normalize_json_lone_surrogate_escapes, parse_json, - parse_json_body, parse_json_value, parse_stream_event, prepare_request_transform, - response_to_universal, sanitize_payload, transform_request, transform_response, - transform_stream_chunk, ParsedJsonBody, ParsedStreamEvent, PreparedRequestTransform, - RequestTransformPreparation, StreamOutputChunk, StreamTransformSession, TransformError, - TransformResult, + extract_model, normalize_json_lone_surrogate_escapes, parse_json, parse_json_body, + parse_json_value, parse_stream_event, response_to_universal, sanitize_payload, + transform_request, transform_response, transform_stream_chunk, ParsedJsonBody, + ParsedStreamEvent, StreamOutputChunk, StreamTransformSession, TransformError, TransformResult, }; // Re-export universal types diff --git a/crates/lingua/src/processing/mod.rs b/crates/lingua/src/processing/mod.rs index c6b564fb..02351a47 100644 --- a/crates/lingua/src/processing/mod.rs +++ b/crates/lingua/src/processing/mod.rs @@ -16,8 +16,7 @@ pub use stream::{ parse_stream_event, ParsedStreamEvent, StreamOutputChunk, StreamTransformSession, }; pub use transform::{ - extract_model, finish_request_transform, parse_json, parse_json_body, parse_json_value, - prepare_request_transform, response_to_universal, sanitize_payload, transform_request, - transform_response, transform_stream_chunk, ParsedJsonBody, PreparedRequestTransform, - RequestTransformPreparation, TransformError, TransformResult, + extract_model, parse_json, parse_json_body, parse_json_value, response_to_universal, + sanitize_payload, transform_request, transform_response, transform_stream_chunk, + ParsedJsonBody, TransformError, TransformResult, }; diff --git a/crates/lingua/src/processing/transform.rs b/crates/lingua/src/processing/transform.rs index d40adee7..b11f96e8 100644 --- a/crates/lingua/src/processing/transform.rs +++ b/crates/lingua/src/processing/transform.rs @@ -22,9 +22,9 @@ use crate::providers::openai::model_needs_transforms; use crate::serde_json; use crate::serde_json::Value; use crate::universal::{ - AssistantContent, AssistantContentPart, Message, UniversalReasoningDelta, UniversalRequest, - UniversalResponse, UniversalStreamChoice, UniversalStreamChunk, UniversalStreamDelta, - UniversalToolCallDelta, UniversalToolFunctionDelta, + AssistantContent, AssistantContentPart, Message, UniversalReasoningDelta, UniversalResponse, + UniversalStreamChoice, UniversalStreamChunk, UniversalStreamDelta, UniversalToolCallDelta, + UniversalToolFunctionDelta, }; use serde::de::DeserializeOwned; use thiserror::Error; @@ -132,15 +132,6 @@ pub enum TransformResult { }, } -/// A request converted into Lingua's universal representation and ready for -/// caller-specific mutation before final target serialization. -#[derive(Debug, Clone)] -pub struct PreparedRequestTransform { - pub request: UniversalRequest, - pub source_format: ProviderFormat, - pub actual_target_format: ProviderFormat, -} - impl TransformResult { /// Check if this is a pass-through result (no transformation occurred). pub fn is_passthrough(&self) -> bool { @@ -312,26 +303,6 @@ pub fn transform_request( target_format: ProviderFormat, model: Option<&str>, ) -> Result { - match prepare_request_transform(input, target_format, model)? { - RequestTransformPreparation::PassThrough(bytes) => Ok(TransformResult::PassThrough(bytes)), - RequestTransformPreparation::Prepared(prepared) => finish_request_transform(*prepared), - } -} - -/// Result of preparing a request transform before final target serialization. -pub enum RequestTransformPreparation { - /// Payload was already valid for target format and no target model override - /// forced a universal conversion. - PassThrough(Bytes), - /// Payload has been converted to universal form with the target model applied. - Prepared(Box), -} - -pub fn prepare_request_transform( - input: Bytes, - target_format: ProviderFormat, - model: Option<&str>, -) -> Result { let parsed = parse_json_body(input)?; let payload = parsed.value; let request_bytes = parsed.bytes; @@ -368,10 +339,9 @@ pub fn prepare_request_transform( if source_format == target_format && !request_model_needs_forced_translation(request_model.as_deref(), model, target_format) - && !request_model_override_changes_model(request_model.as_deref(), model) && target_adapter.detect_passthrough_request(&payload) { - return Ok(RequestTransformPreparation::PassThrough(request_bytes)); + return Ok(TransformResult::PassThrough(request_bytes)); } let mut universal = source_adapter.request_to_universal(payload)?; @@ -380,45 +350,19 @@ pub fn prepare_request_transform( universal.model = Some(model.to_string()); } - Ok(RequestTransformPreparation::Prepared(Box::new( - PreparedRequestTransform { - request: universal, - source_format, - actual_target_format: target_format, - }, - ))) -} - -fn request_model_override_changes_model( - request_model: Option<&str>, - override_model: Option<&str>, -) -> bool { - match override_model { - Some(override_model) => request_model != Some(override_model), - None => false, - } -} - -pub fn finish_request_transform( - mut prepared: PreparedRequestTransform, -) -> Result { - let target_adapter = adapter_for_format(prepared.actual_target_format).ok_or( - TransformError::UnsupportedTargetFormat(prepared.actual_target_format), - )?; - // Apply target provider defaults (e.g., Anthropic's required max_tokens) - target_adapter.apply_defaults(&mut prepared.request); + target_adapter.apply_defaults(&mut universal); // Convert to target format (validation happens in adapter) - let transformed = target_adapter.request_from_universal(&prepared.request)?; + let transformed = target_adapter.request_from_universal(&universal)?; let bytes = crate::serde_json::to_vec(&transformed) .map_err(|e| TransformError::SerializationFailed(e.to_string()))?; Ok(TransformResult::Transformed { bytes: Bytes::from(bytes), - source_format: prepared.source_format, - actual_target_format: prepared.actual_target_format, + source_format, + actual_target_format: target_format, }) } @@ -2250,9 +2194,13 @@ mod tests { #[test] #[cfg(feature = "bedrock")] - fn test_converse_model_override_forces_universal_translation() { + fn test_transform_request_preserves_same_format_converse_passthrough_with_model_override() { let payload = json!({ "modelId": "anthropic.claude-3-haiku-20240307-v1:0", + "guardrailConfig": { + "guardrailIdentifier": "test", + "guardrailVersion": "1" + }, "messages": [{ "role": "user", "content": [{"text": "Hello"}] @@ -2267,13 +2215,13 @@ mod tests { ) .unwrap(); - assert!(!result.is_passthrough()); - assert_eq!(result.source_format(), Some(ProviderFormat::Converse)); + assert!(result.is_passthrough()); let output: Value = crate::serde_json::from_slice(result.as_bytes()).unwrap(); assert_eq!( output.get("modelId").and_then(Value::as_str), - Some("anthropic.claude-3-5-sonnet-20241022-v2:0") + Some("anthropic.claude-3-haiku-20240307-v1:0") ); + assert!(output.get("guardrailConfig").is_some()); } } From ce6b753ecc1309eaf777fc9f74ab18a3130bf29c Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 16:31:35 -0400 Subject: [PATCH 22/28] only rewrite model name in body if absolutely needed --- .../src/providers/bedrock.rs | 4 +- .../src/providers/mod.rs | 106 +++++++++++++++++- crates/braintrust-llm-router/src/router.rs | 16 ++- 3 files changed, 117 insertions(+), 9 deletions(-) diff --git a/crates/braintrust-llm-router/src/providers/bedrock.rs b/crates/braintrust-llm-router/src/providers/bedrock.rs index cb3435a3..1dac4f28 100644 --- a/crates/braintrust-llm-router/src/providers/bedrock.rs +++ b/crates/braintrust-llm-router/src/providers/bedrock.rs @@ -23,7 +23,7 @@ use crate::auth::AuthConfig; use crate::catalog::ModelSpec; use crate::client::{build_middleware_client, ClientSettings}; use crate::error::{Error, Result, UpstreamHttpError}; -use crate::providers::{rewrite_body_model, ClientHeaders}; +use crate::providers::{rewrite_body_model_if_required, ClientHeaders}; use crate::streaming::{bedrock_event_stream, sse_stream, RawResponseStream}; use lingua::{ProviderFormat, TransformError}; @@ -91,7 +91,7 @@ where }; if source_adapter.format() == format { - return Ok(rewrite_body_model(body, format, &spec.model)); + return Ok(rewrite_body_model_if_required(body, format, &spec.model)); } let mut request = match source_adapter.request_to_universal(payload) { diff --git a/crates/braintrust-llm-router/src/providers/mod.rs b/crates/braintrust-llm-router/src/providers/mod.rs index 68c7ce98..be93ea0b 100644 --- a/crates/braintrust-llm-router/src/providers/mod.rs +++ b/crates/braintrust-llm-router/src/providers/mod.rs @@ -190,7 +190,63 @@ pub(crate) fn body_model_field(format: ProviderFormat) -> Option<&'static str> { } } -pub(crate) fn rewrite_body_model(payload: Bytes, format: ProviderFormat, model: &str) -> Bytes { +enum BodyModelRewrite { + Required, + NotRequired, + Unknown, +} + +#[derive(serde::Deserialize)] +struct BodyModel { + model: Option, +} + +#[derive(serde::Deserialize)] +struct BodyModelId { + #[serde(rename = "modelId")] + model_id: Option, +} + +fn body_model_rewrite_status( + payload: &[u8], + format: ProviderFormat, + model: &str, +) -> BodyModelRewrite { + match body_model_field(format) { + Some("model") => match serde_json::from_slice::(payload) { + Ok(parsed) => { + if parsed.model.as_deref() == Some(model) { + BodyModelRewrite::NotRequired + } else { + BodyModelRewrite::Required + } + } + Err(_) => BodyModelRewrite::Unknown, + }, + Some("modelId") => match serde_json::from_slice::(payload) { + Ok(parsed) => { + if parsed.model_id.as_deref() == Some(model) { + BodyModelRewrite::NotRequired + } else { + BodyModelRewrite::Required + } + } + Err(_) => BodyModelRewrite::Unknown, + }, + Some(_) | None => BodyModelRewrite::Unknown, + } +} + +pub(crate) fn rewrite_body_model_if_required( + payload: Bytes, + format: ProviderFormat, + model: &str, +) -> Bytes { + match body_model_rewrite_status(&payload, format, model) { + BodyModelRewrite::Required => {} + BodyModelRewrite::NotRequired | BodyModelRewrite::Unknown => return payload, + } + let Some(model_field) = body_model_field(format) else { return payload; }; @@ -343,6 +399,54 @@ impl dyn Provider { mod tests { use super::*; + #[test] + fn rewrite_body_model_if_required_leaves_matching_model_bytes_unchanged() { + let payload = Bytes::from_static(br#"{"model":"gpt-4o","messages":[]}"#); + let original_ptr = payload.as_ptr(); + + let updated = + rewrite_body_model_if_required(payload, ProviderFormat::ChatCompletions, "gpt-4o"); + + assert_eq!(updated.as_ptr(), original_ptr); + } + + #[test] + fn rewrite_body_model_if_required_rewrites_mismatched_model_field() { + let payload = Bytes::from_static(br#"{"model":"gpt-4","messages":[]}"#); + + let updated = + rewrite_body_model_if_required(payload, ProviderFormat::ChatCompletions, "gpt-4o"); + let value: Value = serde_json::from_slice(&updated).unwrap(); + + assert_eq!(value.get("model").and_then(Value::as_str), Some("gpt-4o")); + } + + #[test] + fn rewrite_body_model_if_required_rewrites_mismatched_model_id_field() { + let payload = Bytes::from_static( + br#"{"modelId":"model-a","messages":[{"role":"user","content":[]}]}"#, + ); + + let updated = rewrite_body_model_if_required(payload, ProviderFormat::Converse, "model-b"); + let value: Value = serde_json::from_slice(&updated).unwrap(); + + assert_eq!( + value.get("modelId").and_then(Value::as_str), + Some("model-b") + ); + } + + #[test] + fn rewrite_body_model_if_required_leaves_unknown_payload_unchanged() { + let payload = Bytes::from_static(b"not-json"); + let original_ptr = payload.as_ptr(); + + let updated = + rewrite_body_model_if_required(payload, ProviderFormat::ChatCompletions, "gpt-4o"); + + assert_eq!(updated.as_ptr(), original_ptr); + } + #[test] fn disable_streaming_payload_removes_stream_fields() { let payload = Bytes::from_static( diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index a88199d3..510ed78b 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -16,7 +16,7 @@ use crate::client::ClientSettings; use crate::error::{Error, Result}; use crate::providers::{ enable_streaming_payload, prepare_bedrock_request, requires_bedrock_request_preparation, - rewrite_body_model, ClientHeaders, Provider, + rewrite_body_model_if_required, ClientHeaders, Provider, }; use crate::retry::{RetryPolicy, RetryStrategy}; use crate::streaming::{ @@ -219,19 +219,23 @@ async fn prepare_provider_request( return Ok((bytes, Some(format), format)); } - let (transformed, detected_format, actual_format) = + let (transformed, detected_format, actual_format, maybe_rewrite_model) = match lingua::transform_request(body.clone(), format, Some(&spec.model)) { - Ok(TransformResult::PassThrough(bytes)) => (bytes, None, format), + Ok(TransformResult::PassThrough(bytes)) => (bytes, None, format, true), Ok(TransformResult::Transformed { bytes, source_format, actual_target_format, - }) => (bytes, Some(source_format), actual_target_format), - Err(TransformError::UnsupportedTargetFormat(_)) => (body, None, format), + }) => (bytes, Some(source_format), actual_target_format, false), + Err(TransformError::UnsupportedTargetFormat(_)) => (body, None, format, true), Err(err) => return Err(err.into()), }; - let transformed = rewrite_body_model(transformed, actual_format, &spec.model); + let transformed = if maybe_rewrite_model { + rewrite_body_model_if_required(transformed, actual_format, &spec.model) + } else { + transformed + }; if stream { // TODO: Fold streaming intent into `lingua::transform_request` once we From 192cea1d39a5040c05c9dd1f3a2a697951379ac5 Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 16:44:33 -0400 Subject: [PATCH 23/28] tighten more --- .../src/providers/bedrock.rs | 4 +- .../src/providers/mod.rs | 45 +++++++++---------- crates/braintrust-llm-router/src/router.rs | 4 +- 3 files changed, 25 insertions(+), 28 deletions(-) diff --git a/crates/braintrust-llm-router/src/providers/bedrock.rs b/crates/braintrust-llm-router/src/providers/bedrock.rs index 1dac4f28..0aaddd1b 100644 --- a/crates/braintrust-llm-router/src/providers/bedrock.rs +++ b/crates/braintrust-llm-router/src/providers/bedrock.rs @@ -654,7 +654,7 @@ mod tests { } #[tokio::test] - async fn prepare_request_rewrites_same_format_converse_model_without_losing_native_fields() { + async fn prepare_request_preserves_same_format_converse_model_and_native_fields() { let body = Bytes::from( lingua::serde_json::to_vec(&lingua::serde_json::json!({ "modelId": "anthropic.claude-3-haiku-20240307-v1:0", @@ -690,7 +690,7 @@ mod tests { let value: lingua::serde_json::Value = lingua::serde_json::from_slice(&prepared).unwrap(); assert_eq!( value.get("modelId").and_then(|v| v.as_str()), - Some("anthropic.claude-3-5-sonnet-20241022-v2:0") + Some("anthropic.claude-3-haiku-20240307-v1:0") ); assert_eq!( value.pointer("/system/0/text").and_then(|v| v.as_str()), diff --git a/crates/braintrust-llm-router/src/providers/mod.rs b/crates/braintrust-llm-router/src/providers/mod.rs index be93ea0b..92980550 100644 --- a/crates/braintrust-llm-router/src/providers/mod.rs +++ b/crates/braintrust-llm-router/src/providers/mod.rs @@ -181,10 +181,10 @@ pub(crate) fn body_model_field(format: ProviderFormat) -> Option<&'static str> { ProviderFormat::ChatCompletions | ProviderFormat::Responses | ProviderFormat::Anthropic - | ProviderFormat::Google | ProviderFormat::Mistral => Some("model"), - ProviderFormat::Converse => Some("modelId"), - ProviderFormat::BedrockAnthropic + ProviderFormat::Google + | ProviderFormat::Converse + | ProviderFormat::BedrockAnthropic | ProviderFormat::VertexAnthropic | ProviderFormat::Unknown => None, } @@ -201,12 +201,6 @@ struct BodyModel { model: Option, } -#[derive(serde::Deserialize)] -struct BodyModelId { - #[serde(rename = "modelId")] - model_id: Option, -} - fn body_model_rewrite_status( payload: &[u8], format: ProviderFormat, @@ -223,16 +217,6 @@ fn body_model_rewrite_status( } Err(_) => BodyModelRewrite::Unknown, }, - Some("modelId") => match serde_json::from_slice::(payload) { - Ok(parsed) => { - if parsed.model_id.as_deref() == Some(model) { - BodyModelRewrite::NotRequired - } else { - BodyModelRewrite::Required - } - } - Err(_) => BodyModelRewrite::Unknown, - }, Some(_) | None => BodyModelRewrite::Unknown, } } @@ -422,18 +406,31 @@ mod tests { } #[test] - fn rewrite_body_model_if_required_rewrites_mismatched_model_id_field() { + fn rewrite_body_model_if_required_leaves_converse_payload_unchanged() { let payload = Bytes::from_static( br#"{"modelId":"model-a","messages":[{"role":"user","content":[]}]}"#, ); + let original_ptr = payload.as_ptr(); let updated = rewrite_body_model_if_required(payload, ProviderFormat::Converse, "model-b"); - let value: Value = serde_json::from_slice(&updated).unwrap(); - assert_eq!( - value.get("modelId").and_then(Value::as_str), - Some("model-b") + assert_eq!(updated.as_ptr(), original_ptr); + } + + #[test] + fn rewrite_body_model_if_required_leaves_google_payload_unchanged() { + let payload = Bytes::from_static( + br#"{"model":"gemini-2.5-flash","contents":[{"role":"user","parts":[{"text":"hi"}]}]}"#, + ); + let original_ptr = payload.as_ptr(); + + let updated = rewrite_body_model_if_required( + payload, + ProviderFormat::Google, + "models/gemini-2.5-pro", ); + + assert_eq!(updated.as_ptr(), original_ptr); } #[test] diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 510ed78b..25a17833 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -1461,7 +1461,7 @@ mod tests { } #[tokio::test] - async fn prepare_provider_request_rewrites_model_for_google_pass_through() { + async fn prepare_provider_request_does_not_rewrite_model_for_google_pass_through() { let body = Bytes::from_static( br#"{"model":"models/gemini-2.5-flash","contents":[{"role":"user","parts":[{"text":"Ping"}]}]}"#, ); @@ -1492,7 +1492,7 @@ mod tests { assert_eq!(actual_format, ProviderFormat::Google); assert_eq!( parsed.get("model").and_then(Value::as_str), - Some("models/gemini-2.5-pro") + Some("models/gemini-2.5-flash") ); assert!(parsed.get("contents").is_some()); } From 41e2a9f84f5f56ec257f6b156be2c6933025520e Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 16:58:22 -0400 Subject: [PATCH 24/28] if the model name diff was intentional, don't change it --- crates/braintrust-llm-router/src/router.rs | 182 +++++++++++++++++---- 1 file changed, 154 insertions(+), 28 deletions(-) diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 25a17833..b0ca76d3 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -208,11 +208,25 @@ struct PreparedRequestInner { strategy: RetryStrategy, } +#[derive(Clone, Copy)] +struct RequestPreparationOptions { + rewrite_body_model: bool, +} + +impl Default for RequestPreparationOptions { + fn default() -> Self { + Self { + rewrite_body_model: true, + } + } +} + async fn prepare_provider_request( body: Bytes, spec: &ModelSpec, format: ProviderFormat, stream: bool, + options: RequestPreparationOptions, ) -> Result<(Bytes, Option, ProviderFormat)> { if requires_bedrock_request_preparation(format) { let bytes = prepare_bedrock_request(body, spec, format).await?; @@ -231,7 +245,7 @@ async fn prepare_provider_request( Err(err) => return Err(err.into()), }; - let transformed = if maybe_rewrite_model { + let transformed = if options.rewrite_body_model && maybe_rewrite_model { rewrite_body_model_if_required(transformed, actual_format, &spec.model) } else { transformed @@ -275,9 +289,11 @@ impl Router { output_format: ProviderFormat, route: &ProviderRoute, stream: bool, + options: RequestPreparationOptions, ) -> Result<(PreparedRequestInner, RouterMetadata)> { let (payload, detected_format, actual_format) = - prepare_provider_request(body, route.spec.as_ref(), route.format, stream).await?; + prepare_provider_request(body, route.spec.as_ref(), route.format, stream, options) + .await?; Ok(( PreparedRequestInner { provider: route.provider.clone(), @@ -320,7 +336,33 @@ impl Router { route: &ProviderRoute, ) -> Result<(PreparedRequest, RouterMetadata)> { let (inner, metadata) = self - .create_prepared_request_internal(body, output_format, route, false) + .create_prepared_request_internal( + body, + output_format, + route, + false, + RequestPreparationOptions::default(), + ) + .await?; + Ok((PreparedRequest { inner }, metadata)) + } + + pub async fn create_request_preserving_body_model( + &self, + body: Bytes, + output_format: ProviderFormat, + route: &ProviderRoute, + ) -> Result<(PreparedRequest, RouterMetadata)> { + let (inner, metadata) = self + .create_prepared_request_internal( + body, + output_format, + route, + false, + RequestPreparationOptions { + rewrite_body_model: false, + }, + ) .await?; Ok((PreparedRequest { inner }, metadata)) } @@ -426,7 +468,33 @@ impl Router { route: &ProviderRoute, ) -> Result<(PreparedStreamRequest, RouterMetadata)> { let (inner, metadata) = self - .create_prepared_request_internal(body, output_format, route, true) + .create_prepared_request_internal( + body, + output_format, + route, + true, + RequestPreparationOptions::default(), + ) + .await?; + Ok((PreparedStreamRequest { inner }, metadata)) + } + + pub async fn create_stream_request_preserving_body_model( + &self, + body: Bytes, + output_format: ProviderFormat, + route: &ProviderRoute, + ) -> Result<(PreparedStreamRequest, RouterMetadata)> { + let (inner, metadata) = self + .create_prepared_request_internal( + body, + output_format, + route, + true, + RequestPreparationOptions { + rewrite_body_model: false, + }, + ) .await?; Ok((PreparedStreamRequest { inner }, metadata)) } @@ -1397,10 +1465,15 @@ mod tests { ); let spec = openai_spec("gpt-5-mini", ModelFlavor::Chat); - let (payload, _, _) = - prepare_provider_request(body, &spec, ProviderFormat::ChatCompletions, true) - .await - .expect("request prepares"); + let (payload, _, _) = prepare_provider_request( + body, + &spec, + ProviderFormat::ChatCompletions, + true, + RequestPreparationOptions::default(), + ) + .await + .expect("request prepares"); let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); assert_eq!(parsed.get("stream"), Some(&Value::Bool(true))); @@ -1415,10 +1488,15 @@ mod tests { ); let spec = openai_spec("gpt-5-mini", ModelFlavor::Chat); - let (payload, _, _) = - prepare_provider_request(body, &spec, ProviderFormat::ChatCompletions, false) - .await - .expect("request prepares"); + let (payload, _, _) = prepare_provider_request( + body, + &spec, + ProviderFormat::ChatCompletions, + false, + RequestPreparationOptions::default(), + ) + .await + .expect("request prepares"); let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); assert_eq!(parsed.get("stream"), None); @@ -1448,10 +1526,15 @@ mod tests { available_providers: vec!["vertex".to_string()], }; - let (payload, _, actual_format) = - prepare_provider_request(body, &spec, ProviderFormat::VertexAnthropic, false) - .await - .expect("request prepares"); + let (payload, _, actual_format) = prepare_provider_request( + body, + &spec, + ProviderFormat::VertexAnthropic, + false, + RequestPreparationOptions::default(), + ) + .await + .expect("request prepares"); let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); assert_eq!(actual_format, ProviderFormat::VertexAnthropic); @@ -1483,10 +1566,15 @@ mod tests { available_providers: vec!["google".to_string()], }; - let (payload, _, actual_format) = - prepare_provider_request(body, &spec, ProviderFormat::Google, false) - .await - .expect("request prepares"); + let (payload, _, actual_format) = prepare_provider_request( + body, + &spec, + ProviderFormat::Google, + false, + RequestPreparationOptions::default(), + ) + .await + .expect("request prepares"); let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); assert_eq!(actual_format, ProviderFormat::Google); @@ -1505,10 +1593,15 @@ mod tests { ); let spec = openai_spec("gpt-4o", ModelFlavor::Chat); - let (payload, _, actual_format) = - prepare_provider_request(body, &spec, ProviderFormat::ChatCompletions, false) - .await - .expect("request prepares"); + let (payload, _, actual_format) = prepare_provider_request( + body, + &spec, + ProviderFormat::ChatCompletions, + false, + RequestPreparationOptions::default(), + ) + .await + .expect("request prepares"); let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); assert_eq!(actual_format, ProviderFormat::ChatCompletions); @@ -1519,6 +1612,34 @@ mod tests { ); } + #[tokio::test] + async fn prepare_provider_request_can_preserve_same_format_body_model() { + let body = Bytes::from_static( + br#"{"model":"gpt-4","messages":[{"role":"user","name":"example_user","content":"Ping"}]}"#, + ); + let spec = openai_spec("gpt-4o", ModelFlavor::Chat); + + let (payload, _, actual_format) = prepare_provider_request( + body, + &spec, + ProviderFormat::ChatCompletions, + false, + RequestPreparationOptions { + rewrite_body_model: false, + }, + ) + .await + .expect("request prepares"); + let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); + + assert_eq!(actual_format, ProviderFormat::ChatCompletions); + assert_eq!(parsed.get("model").and_then(Value::as_str), Some("gpt-4")); + assert_eq!( + parsed.pointer("/messages/0/name").and_then(Value::as_str), + Some("example_user") + ); + } + #[tokio::test] async fn prepare_provider_request_upgrades_actual_format_to_responses_for_reasoning_plus_tools() { @@ -1546,10 +1667,15 @@ mod tests { ); let spec = openai_spec("gpt-5.4-mini", ModelFlavor::Chat); - let (_, _, actual_format) = - prepare_provider_request(body, &spec, ProviderFormat::ChatCompletions, false) - .await - .expect("request prepares"); + let (_, _, actual_format) = prepare_provider_request( + body, + &spec, + ProviderFormat::ChatCompletions, + false, + RequestPreparationOptions::default(), + ) + .await + .expect("request prepares"); assert_eq!( actual_format, From 0985268e480921cd643ece047f5570b8770defa9 Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Tue, 16 Jun 2026 17:30:31 -0400 Subject: [PATCH 25/28] move stuff out of mod.rs --- .../src/catalog/fallback.rs | 255 +++++++++++++++++- .../braintrust-llm-router/src/catalog/mod.rs | 243 +---------------- .../src/providers/body_model.rs | 143 ++++++++++ .../src/providers/mod.rs | 137 +--------- 4 files changed, 406 insertions(+), 372 deletions(-) create mode 100644 crates/braintrust-llm-router/src/providers/body_model.rs diff --git a/crates/braintrust-llm-router/src/catalog/fallback.rs b/crates/braintrust-llm-router/src/catalog/fallback.rs index 8a4f9f7c..16705f6f 100644 --- a/crates/braintrust-llm-router/src/catalog/fallback.rs +++ b/crates/braintrust-llm-router/src/catalog/fallback.rs @@ -1,6 +1,259 @@ use std::collections::{HashMap, HashSet}; +use std::sync::Arc; -pub(super) fn build_equivalence_index( +use super::{ModelCatalog, ModelSpec}; +use crate::error::{Error, Result}; + +/// A request-local catalog overlay. +/// +/// Secret-defined custom models live in `custom` and shadow entries in the +/// shared `base` catalog. This avoids cloning the base catalog when adding +/// per-request model definitions. +#[derive(Debug, Clone)] +pub struct OverlayModelCatalog { + base: Arc, + custom: ModelCatalog, + custom_model_names: HashSet, + overlay_edges: HashMap>, +} + +impl OverlayModelCatalog { + pub fn new(base: Arc, custom: ModelCatalog) -> Self { + let custom_model_names: HashSet = custom.models.keys().cloned().collect(); + let mut overlay_edges: HashMap> = HashMap::new(); + for (name, fallbacks) in &custom.fallback_models { + if !custom.models.contains_key(name) { + continue; + } + for fallback_model in fallbacks { + let fallback_is_visible = custom.models.contains_key(fallback_model) + || (base.models.contains_key(fallback_model) + && !custom_model_names.contains(fallback_model)); + if !fallback_is_visible { + continue; + } + overlay_edges + .entry(name.clone()) + .or_default() + .push(fallback_model.clone()); + overlay_edges + .entry(fallback_model.clone()) + .or_default() + .push(name.clone()); + } + } + Self { + base, + custom, + custom_model_names, + overlay_edges, + } + } + + pub fn base_catalog(&self) -> Arc { + Arc::clone(&self.base) + } + + pub fn get(&self, name: &str) -> Option> { + self.custom.get(name).or_else(|| self.base.get(name)) + } + + pub fn find_fallback_models(&self, name: &str) -> Vec { + let Some(_) = self.get(name) else { + return Vec::new(); + }; + + let mut visited = HashSet::new(); + let mut stack = vec![name.to_string()]; + while let Some(current) = stack.pop() { + if !visited.insert(current.clone()) { + continue; + } + + if !self.custom_model_names.contains(¤t) { + stack.extend( + self.base + .fallback_models(¤t) + .into_iter() + .filter(|model_name| !self.custom_model_names.contains(model_name)), + ); + } + if let Some(neighbors) = self.overlay_edges.get(¤t) { + stack.extend(neighbors.iter().cloned()); + } + } + + let mut names = vec![name.to_string()]; + visited.remove(name); + let mut equivalent_names: Vec = visited.into_iter().collect(); + equivalent_names.sort(); + names.extend(equivalent_names); + names + } +} + +enum FallbackModelSource<'a> { + Json(&'a str), + Parsed(HashMap>), +} + +impl ModelCatalog { + pub fn fallback_models(&self, name: &str) -> Vec { + let Some(_) = self.models.get(name) else { + return Vec::new(); + }; + + let mut names = vec![name.to_string()]; + if let Some(equivalent_names) = self.equivalence_index.get(name) { + names.extend(equivalent_names.iter().cloned()); + } + names + } + + pub fn add_fallback_models(&mut self, name: String, fallback_models: I) -> Result<()> + where + I: IntoIterator, + { + if !self.models.contains_key(&name) { + return Err(Error::InvalidRequest(format!( + "model '{name}' references fallback_models but is missing from catalog" + ))); + } + + let fallback_models: Vec = fallback_models + .into_iter() + .filter(|fallback_model| !fallback_model.is_empty()) + .collect(); + for fallback_model in &fallback_models { + if !self.models.contains_key(fallback_model) { + return Err(Error::InvalidRequest(format!( + "model '{name}' references missing fallback model '{fallback_model}'" + ))); + } + } + + let mut next_fallback_models = self.fallback_models.clone(); + let entry = next_fallback_models.entry(name).or_default(); + for fallback_model in fallback_models { + if entry.contains(&fallback_model) { + continue; + } + entry.push(fallback_model); + } + self.set_fallback_models(FallbackModelSource::Parsed(next_fallback_models), false)?; + Ok(()) + } + + pub fn add_external_fallback_models( + &mut self, + name: String, + fallback_models: I, + ) -> Result<()> + where + I: IntoIterator, + { + if !self.models.contains_key(&name) { + return Err(Error::InvalidRequest(format!( + "model '{name}' references fallback_models but is missing from catalog" + ))); + } + + let mut next_fallback_models = self.fallback_models.clone(); + let entry = next_fallback_models.entry(name).or_default(); + for fallback_model in fallback_models { + if fallback_model.is_empty() || entry.contains(&fallback_model) { + continue; + } + entry.push(fallback_model); + } + self.set_fallback_models(FallbackModelSource::Parsed(next_fallback_models), false)?; + Ok(()) + } + + pub(super) fn set_fallback_models_from_json( + &mut self, + content: &str, + validate_targets: bool, + ) -> Result<()> { + self.set_fallback_models(FallbackModelSource::Json(content), validate_targets) + } + + pub(super) fn set_fallback_models_from_parsed( + &mut self, + fallback_models: HashMap>, + validate_targets: bool, + ) -> Result<()> { + self.set_fallback_models( + FallbackModelSource::Parsed(fallback_models), + validate_targets, + ) + } + + fn set_fallback_models( + &mut self, + source: FallbackModelSource<'_>, + validate_targets: bool, + ) -> Result<()> { + let fallback_models = match source { + FallbackModelSource::Json(content) => parse_fallback_models(content)?, + FallbackModelSource::Parsed(fallback_models) => fallback_models, + }; + + if validate_targets { + validate_fallback_models(&self.models, &fallback_models)?; + } + self.equivalence_index = + build_equivalence_index(self.models.keys().cloned().collect(), &fallback_models); + self.fallback_models = fallback_models; + Ok(()) + } +} + +fn parse_fallback_models(content: &str) -> Result>> { + let raw: HashMap = serde_json::from_str(content)?; + let mut fallback_models = HashMap::new(); + for (name, value) in raw { + let Some(fallbacks) = value.get("fallback_models") else { + continue; + }; + let Some(fallbacks) = fallbacks.as_array() else { + return Err(Error::InvalidRequest(format!( + "model '{name}' has invalid fallback_models" + ))); + }; + let mut parsed = Vec::with_capacity(fallbacks.len()); + for fallback_model in fallbacks { + let Some(fallback_model) = fallback_model.as_str() else { + return Err(Error::InvalidRequest(format!( + "model '{name}' has invalid fallback_models" + ))); + }; + parsed.push(fallback_model.to_string()); + } + if !parsed.is_empty() { + fallback_models.insert(name, parsed); + } + } + Ok(fallback_models) +} + +fn validate_fallback_models( + models: &HashMap>, + fallback_models: &HashMap>, +) -> Result<()> { + for (name, fallback_models) in fallback_models { + for fallback_model in fallback_models { + if !models.contains_key(fallback_model) { + return Err(Error::InvalidRequest(format!( + "model '{name}' references missing fallback model '{fallback_model}'" + ))); + } + } + } + Ok(()) +} + +fn build_equivalence_index( model_names: HashSet, fallback_models: &HashMap>, ) -> HashMap> { diff --git a/crates/braintrust-llm-router/src/catalog/mod.rs b/crates/braintrust-llm-router/src/catalog/mod.rs index 5fff8358..be6be2f2 100644 --- a/crates/braintrust-llm-router/src/catalog/mod.rs +++ b/crates/braintrust-llm-router/src/catalog/mod.rs @@ -2,21 +2,20 @@ mod fallback; mod resolver; pub mod spec; -use fallback::build_equivalence_index; - +pub use fallback::OverlayModelCatalog; pub(crate) use resolver::is_gemini_api_model; pub use resolver::ModelResolver; pub use spec::{ModelFlavor, ModelSpec}; use lingua::ProviderFormat; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::fs::File; use std::io::Read; use std::path::Path; use std::sync::Arc; -use crate::error::{Error, Result}; +use crate::error::Result; #[derive(Debug, Clone, Default)] pub struct ModelCatalog { @@ -27,94 +26,6 @@ pub struct ModelCatalog { equivalence_index: HashMap>, } -/// A request-local catalog overlay. -/// -/// Secret-defined custom models live in `custom` and shadow entries in the -/// shared `base` catalog. This avoids cloning the base catalog when adding -/// per-request model definitions. -#[derive(Debug, Clone)] -pub struct OverlayModelCatalog { - base: Arc, - custom: ModelCatalog, - custom_model_names: HashSet, - overlay_edges: HashMap>, -} - -impl OverlayModelCatalog { - pub fn new(base: Arc, custom: ModelCatalog) -> Self { - let custom_model_names: HashSet = custom.models.keys().cloned().collect(); - let mut overlay_edges: HashMap> = HashMap::new(); - for (name, fallbacks) in &custom.fallback_models { - if !custom.models.contains_key(name) { - continue; - } - for fallback_model in fallbacks { - let fallback_is_visible = custom.models.contains_key(fallback_model) - || (base.models.contains_key(fallback_model) - && !custom_model_names.contains(fallback_model)); - if !fallback_is_visible { - continue; - } - overlay_edges - .entry(name.clone()) - .or_default() - .push(fallback_model.clone()); - overlay_edges - .entry(fallback_model.clone()) - .or_default() - .push(name.clone()); - } - } - Self { - base, - custom, - custom_model_names, - overlay_edges, - } - } - - pub fn base_catalog(&self) -> Arc { - Arc::clone(&self.base) - } - - pub fn get(&self, name: &str) -> Option> { - self.custom.get(name).or_else(|| self.base.get(name)) - } - - pub fn find_fallback_models(&self, name: &str) -> Vec { - let Some(_) = self.get(name) else { - return Vec::new(); - }; - - let mut visited = HashSet::new(); - let mut stack = vec![name.to_string()]; - while let Some(current) = stack.pop() { - if !visited.insert(current.clone()) { - continue; - } - - if !self.custom_model_names.contains(¤t) { - stack.extend( - self.base - .fallback_models(¤t) - .into_iter() - .filter(|model_name| !self.custom_model_names.contains(model_name)), - ); - } - if let Some(neighbors) = self.overlay_edges.get(¤t) { - stack.extend(neighbors.iter().cloned()); - } - } - - let mut names = vec![name.to_string()]; - visited.remove(name); - let mut equivalent_names: Vec = visited.into_iter().collect(); - equivalent_names.sort(); - names.extend(equivalent_names); - names - } -} - /// Catalog view used by the router resolver. /// /// `Base` preserves the existing router behavior. `Overlay` checks custom @@ -154,11 +65,6 @@ impl From> for CatalogResolver { } } -enum FallbackModelSource<'a> { - Json(&'a str), - Parsed(HashMap>), -} - impl ModelCatalog { pub fn empty() -> Self { Self::default() @@ -170,7 +76,7 @@ impl ModelCatalog { for (name, spec) in raw { catalog.insert(name, spec); } - catalog.set_fallback_models(FallbackModelSource::Json(content), true)?; + catalog.set_fallback_models_from_json(content, true)?; Ok(catalog) } @@ -189,18 +95,6 @@ impl ModelCatalog { self.models.get(name).cloned() } - pub fn fallback_models(&self, name: &str) -> Vec { - let Some(_) = self.models.get(name) else { - return Vec::new(); - }; - - let mut names = vec![name.to_string()]; - if let Some(equivalent_names) = self.equivalence_index.get(name) { - names.extend(equivalent_names.iter().cloned()); - } - names - } - pub fn resolve_format(&self, model: &str) -> Option { self.models.get(model).map(|spec| spec.format) } @@ -257,11 +151,8 @@ impl ModelCatalog { for (name, spec) in &self.models { out.insert(name.clone(), f(name, spec.as_ref())); } - out.set_fallback_models( - FallbackModelSource::Parsed(self.fallback_models.clone()), - false, - ) - .expect("existing catalog fallback_models remain valid after mapping specs"); + out.set_fallback_models_from_parsed(self.fallback_models.clone(), false) + .expect("existing catalog fallback_models remain valid after mapping specs"); out } @@ -286,127 +177,6 @@ impl ModelCatalog { self.by_parent.entry(parent).or_default().push(name); } } - - pub fn add_fallback_models(&mut self, name: String, fallback_models: I) -> Result<()> - where - I: IntoIterator, - { - if !self.models.contains_key(&name) { - return Err(Error::InvalidRequest(format!( - "model '{name}' references fallback_models but is missing from catalog" - ))); - } - - let fallback_models: Vec = fallback_models - .into_iter() - .filter(|fallback_model| !fallback_model.is_empty()) - .collect(); - for fallback_model in &fallback_models { - if !self.models.contains_key(fallback_model) { - return Err(Error::InvalidRequest(format!( - "model '{name}' references missing fallback model '{fallback_model}'" - ))); - } - } - - let mut next_fallback_models = self.fallback_models.clone(); - let entry = next_fallback_models.entry(name).or_default(); - for fallback_model in fallback_models { - if entry.contains(&fallback_model) { - continue; - } - entry.push(fallback_model); - } - self.set_fallback_models(FallbackModelSource::Parsed(next_fallback_models), false)?; - Ok(()) - } - - pub fn add_external_fallback_models( - &mut self, - name: String, - fallback_models: I, - ) -> Result<()> - where - I: IntoIterator, - { - if !self.models.contains_key(&name) { - return Err(Error::InvalidRequest(format!( - "model '{name}' references fallback_models but is missing from catalog" - ))); - } - - let mut next_fallback_models = self.fallback_models.clone(); - let entry = next_fallback_models.entry(name).or_default(); - for fallback_model in fallback_models { - if fallback_model.is_empty() || entry.contains(&fallback_model) { - continue; - } - entry.push(fallback_model); - } - self.set_fallback_models(FallbackModelSource::Parsed(next_fallback_models), false)?; - Ok(()) - } - - fn set_fallback_models( - &mut self, - source: FallbackModelSource<'_>, - validate_targets: bool, - ) -> Result<()> { - let fallback_models = match source { - FallbackModelSource::Json(content) => { - let raw: HashMap = serde_json::from_str(content)?; - let mut fallback_models = HashMap::new(); - for (name, value) in raw { - let Some(fallbacks) = value.get("fallback_models") else { - continue; - }; - let Some(fallbacks) = fallbacks.as_array() else { - return Err(Error::InvalidRequest(format!( - "model '{name}' has invalid fallback_models" - ))); - }; - let mut parsed = Vec::with_capacity(fallbacks.len()); - for fallback_model in fallbacks { - let Some(fallback_model) = fallback_model.as_str() else { - return Err(Error::InvalidRequest(format!( - "model '{name}' has invalid fallback_models" - ))); - }; - parsed.push(fallback_model.to_string()); - } - if !parsed.is_empty() { - fallback_models.insert(name, parsed); - } - } - fallback_models - } - FallbackModelSource::Parsed(fallback_models) => fallback_models, - }; - - if validate_targets { - Self::validate_fallback_models(&self.models, &fallback_models)?; - } - self.equivalence_index = - build_equivalence_index(self.models.keys().cloned().collect(), &fallback_models); - self.fallback_models = fallback_models; - Ok(()) - } - - fn validate_fallback_models( - models: &HashMap>, - fallback_models: &HashMap>, - ) -> Result<()> { - for (name, fallback_models) in fallback_models { - for fallback_model in fallback_models { - if !models.contains_key(fallback_model) { - return Err(Error::InvalidRequest(format!( - "model '{name}' references missing fallback model '{fallback_model}'" - ))); - } - } - } - Ok(()) - } } pub fn load_catalog_from_disk>(path: P) -> Result> { @@ -422,6 +192,7 @@ pub fn load_catalog_from_disk>(path: P) -> Result Option<&'static str> { + match format { + ProviderFormat::ChatCompletions + | ProviderFormat::Responses + | ProviderFormat::Anthropic + | ProviderFormat::Mistral => Some("model"), + ProviderFormat::Google + | ProviderFormat::Converse + | ProviderFormat::BedrockAnthropic + | ProviderFormat::VertexAnthropic + | ProviderFormat::Unknown => None, + } +} + +enum BodyModelRewrite { + Required, + NotRequired, + Unknown, +} + +#[derive(serde::Deserialize)] +struct BodyModel { + model: Option, +} + +fn body_model_rewrite_status( + payload: &[u8], + format: ProviderFormat, + model: &str, +) -> BodyModelRewrite { + match body_model_field(format) { + Some("model") => match serde_json::from_slice::(payload) { + Ok(parsed) => { + if parsed.model.as_deref() == Some(model) { + BodyModelRewrite::NotRequired + } else { + BodyModelRewrite::Required + } + } + Err(_) => BodyModelRewrite::Unknown, + }, + Some(_) | None => BodyModelRewrite::Unknown, + } +} + +pub(crate) fn rewrite_body_model_if_required( + payload: Bytes, + format: ProviderFormat, + model: &str, +) -> Bytes { + match body_model_rewrite_status(&payload, format, model) { + BodyModelRewrite::Required => {} + BodyModelRewrite::NotRequired | BodyModelRewrite::Unknown => return payload, + } + + let Some(model_field) = body_model_field(format) else { + return payload; + }; + let Ok(mut value) = serde_json::from_slice::(&payload) else { + return payload; + }; + let Some(object) = value.as_object_mut() else { + return payload; + }; + if object.get(model_field).and_then(Value::as_str) == Some(model) { + return payload; + } + object.insert(model_field.to_string(), Value::String(model.to_string())); + match serde_json::to_vec(&value) { + Ok(serialized) => Bytes::from(serialized), + Err(_) => payload, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rewrite_body_model_if_required_leaves_matching_model_bytes_unchanged() { + let payload = Bytes::from_static(br#"{"model":"gpt-4o","messages":[]}"#); + let original_ptr = payload.as_ptr(); + + let updated = + rewrite_body_model_if_required(payload, ProviderFormat::ChatCompletions, "gpt-4o"); + + assert_eq!(updated.as_ptr(), original_ptr); + } + + #[test] + fn rewrite_body_model_if_required_rewrites_mismatched_model_field() { + let payload = Bytes::from_static(br#"{"model":"gpt-4","messages":[]}"#); + + let updated = + rewrite_body_model_if_required(payload, ProviderFormat::ChatCompletions, "gpt-4o"); + let value: Value = serde_json::from_slice(&updated).unwrap(); + + assert_eq!(value.get("model").and_then(Value::as_str), Some("gpt-4o")); + } + + #[test] + fn rewrite_body_model_if_required_leaves_converse_payload_unchanged() { + let payload = Bytes::from_static( + br#"{"modelId":"model-a","messages":[{"role":"user","content":[]}]}"#, + ); + let original_ptr = payload.as_ptr(); + + let updated = rewrite_body_model_if_required(payload, ProviderFormat::Converse, "model-b"); + + assert_eq!(updated.as_ptr(), original_ptr); + } + + #[test] + fn rewrite_body_model_if_required_leaves_google_payload_unchanged() { + let payload = Bytes::from_static( + br#"{"model":"gemini-2.5-flash","contents":[{"role":"user","parts":[{"text":"hi"}]}]}"#, + ); + let original_ptr = payload.as_ptr(); + + let updated = rewrite_body_model_if_required( + payload, + ProviderFormat::Google, + "models/gemini-2.5-pro", + ); + + assert_eq!(updated.as_ptr(), original_ptr); + } + + #[test] + fn rewrite_body_model_if_required_leaves_unknown_payload_unchanged() { + let payload = Bytes::from_static(b"not-json"); + let original_ptr = payload.as_ptr(); + + let updated = + rewrite_body_model_if_required(payload, ProviderFormat::ChatCompletions, "gpt-4o"); + + assert_eq!(updated.as_ptr(), original_ptr); + } +} diff --git a/crates/braintrust-llm-router/src/providers/mod.rs b/crates/braintrust-llm-router/src/providers/mod.rs index 92980550..6f355ecc 100644 --- a/crates/braintrust-llm-router/src/providers/mod.rs +++ b/crates/braintrust-llm-router/src/providers/mod.rs @@ -1,6 +1,7 @@ pub(crate) mod anthropic; mod azure; mod bedrock; +mod body_model; mod databricks; mod google; mod mistral; @@ -11,6 +12,7 @@ pub use anthropic::{AnthropicConfig, AnthropicProvider}; pub use azure::{AzureConfig, AzureProvider}; pub(crate) use bedrock::{prepare_bedrock_request, requires_bedrock_request_preparation}; pub use bedrock::{BedrockConfig, BedrockProvider}; +pub(crate) use body_model::rewrite_body_model_if_required; pub use databricks::{DatabricksConfig, DatabricksProvider}; pub use google::{GoogleConfig, GoogleProvider}; pub use mistral::{MistralConfig, MistralProvider}; @@ -176,80 +178,6 @@ pub(crate) fn disable_streaming_payload(payload: Bytes) -> Bytes { } } -pub(crate) fn body_model_field(format: ProviderFormat) -> Option<&'static str> { - match format { - ProviderFormat::ChatCompletions - | ProviderFormat::Responses - | ProviderFormat::Anthropic - | ProviderFormat::Mistral => Some("model"), - ProviderFormat::Google - | ProviderFormat::Converse - | ProviderFormat::BedrockAnthropic - | ProviderFormat::VertexAnthropic - | ProviderFormat::Unknown => None, - } -} - -enum BodyModelRewrite { - Required, - NotRequired, - Unknown, -} - -#[derive(serde::Deserialize)] -struct BodyModel { - model: Option, -} - -fn body_model_rewrite_status( - payload: &[u8], - format: ProviderFormat, - model: &str, -) -> BodyModelRewrite { - match body_model_field(format) { - Some("model") => match serde_json::from_slice::(payload) { - Ok(parsed) => { - if parsed.model.as_deref() == Some(model) { - BodyModelRewrite::NotRequired - } else { - BodyModelRewrite::Required - } - } - Err(_) => BodyModelRewrite::Unknown, - }, - Some(_) | None => BodyModelRewrite::Unknown, - } -} - -pub(crate) fn rewrite_body_model_if_required( - payload: Bytes, - format: ProviderFormat, - model: &str, -) -> Bytes { - match body_model_rewrite_status(&payload, format, model) { - BodyModelRewrite::Required => {} - BodyModelRewrite::NotRequired | BodyModelRewrite::Unknown => return payload, - } - - let Some(model_field) = body_model_field(format) else { - return payload; - }; - let Ok(mut value) = serde_json::from_slice::(&payload) else { - return payload; - }; - let Some(object) = value.as_object_mut() else { - return payload; - }; - if object.get(model_field).and_then(Value::as_str) == Some(model) { - return payload; - } - object.insert(model_field.to_string(), Value::String(model.to_string())); - match serde_json::to_vec(&value) { - Ok(serialized) => Bytes::from(serialized), - Err(_) => payload, - } -} - pub(crate) fn enable_streaming_payload(payload: Bytes, format: ProviderFormat) -> Bytes { let Ok(mut value) = serde_json::from_slice::(&payload) else { return payload; @@ -383,67 +311,6 @@ impl dyn Provider { mod tests { use super::*; - #[test] - fn rewrite_body_model_if_required_leaves_matching_model_bytes_unchanged() { - let payload = Bytes::from_static(br#"{"model":"gpt-4o","messages":[]}"#); - let original_ptr = payload.as_ptr(); - - let updated = - rewrite_body_model_if_required(payload, ProviderFormat::ChatCompletions, "gpt-4o"); - - assert_eq!(updated.as_ptr(), original_ptr); - } - - #[test] - fn rewrite_body_model_if_required_rewrites_mismatched_model_field() { - let payload = Bytes::from_static(br#"{"model":"gpt-4","messages":[]}"#); - - let updated = - rewrite_body_model_if_required(payload, ProviderFormat::ChatCompletions, "gpt-4o"); - let value: Value = serde_json::from_slice(&updated).unwrap(); - - assert_eq!(value.get("model").and_then(Value::as_str), Some("gpt-4o")); - } - - #[test] - fn rewrite_body_model_if_required_leaves_converse_payload_unchanged() { - let payload = Bytes::from_static( - br#"{"modelId":"model-a","messages":[{"role":"user","content":[]}]}"#, - ); - let original_ptr = payload.as_ptr(); - - let updated = rewrite_body_model_if_required(payload, ProviderFormat::Converse, "model-b"); - - assert_eq!(updated.as_ptr(), original_ptr); - } - - #[test] - fn rewrite_body_model_if_required_leaves_google_payload_unchanged() { - let payload = Bytes::from_static( - br#"{"model":"gemini-2.5-flash","contents":[{"role":"user","parts":[{"text":"hi"}]}]}"#, - ); - let original_ptr = payload.as_ptr(); - - let updated = rewrite_body_model_if_required( - payload, - ProviderFormat::Google, - "models/gemini-2.5-pro", - ); - - assert_eq!(updated.as_ptr(), original_ptr); - } - - #[test] - fn rewrite_body_model_if_required_leaves_unknown_payload_unchanged() { - let payload = Bytes::from_static(b"not-json"); - let original_ptr = payload.as_ptr(); - - let updated = - rewrite_body_model_if_required(payload, ProviderFormat::ChatCompletions, "gpt-4o"); - - assert_eq!(updated.as_ptr(), original_ptr); - } - #[test] fn disable_streaming_payload_removes_stream_fields() { let payload = Bytes::from_static( From 53818fe484749dadf0654b45d6ca26f8e12671bb Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Thu, 18 Jun 2026 09:45:08 -0400 Subject: [PATCH 26/28] Update router.rs --- crates/braintrust-llm-router/src/router.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index b0ca76d3..465b25f7 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -1504,7 +1504,7 @@ mod tests { } #[tokio::test] - async fn prepare_provider_request_does_not_readd_model_for_vertex_anthropic() { + async fn prepare_provider_request_does_not_read_model_for_vertex_anthropic() { let body = Bytes::from_static( br#"{"model":"claude-sonnet-4-6","messages":[{"role":"user","content":"Ping"}]}"#, ); From 5e9689bf94aa3e0e99b0def57183fc4486dbe7f4 Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Thu, 18 Jun 2026 12:00:24 -0400 Subject: [PATCH 27/28] pass down params --- .../examples/custom_auth.rs | 2 +- .../examples/multi_provider.rs | 4 +- .../braintrust-llm-router/examples/simple.rs | 2 +- .../examples/streaming.rs | 4 +- crates/braintrust-llm-router/src/router.rs | 52 ++++--------------- crates/braintrust-llm-router/tests/router.rs | 6 ++- 6 files changed, 21 insertions(+), 49 deletions(-) diff --git a/crates/braintrust-llm-router/examples/custom_auth.rs b/crates/braintrust-llm-router/examples/custom_auth.rs index c45466fa..db19c454 100644 --- a/crates/braintrust-llm-router/examples/custom_auth.rs +++ b/crates/braintrust-llm-router/examples/custom_auth.rs @@ -143,7 +143,7 @@ async fn main() -> Result<()> { ProviderFormat::ChatCompletions, ))?; let (request, _metadata) = router - .create_request(body, ProviderFormat::ChatCompletions, route) + .create_request(body, ProviderFormat::ChatCompletions, route, false) .await?; let bytes = router.complete(request, &ClientHeaders::default()).await?; let response: Value = serde_json::from_slice(&bytes)?; diff --git a/crates/braintrust-llm-router/examples/multi_provider.rs b/crates/braintrust-llm-router/examples/multi_provider.rs index 4134ed38..62002ccf 100644 --- a/crates/braintrust-llm-router/examples/multi_provider.rs +++ b/crates/braintrust-llm-router/examples/multi_provider.rs @@ -85,7 +85,7 @@ async fn main() -> Result<()> { match router.resolve_provider_routes(model, ProviderFormat::ChatCompletions, &[]) { Ok(routes) => match routes.first() { Some(route) => match router - .create_request(body, ProviderFormat::ChatCompletions, route) + .create_request(body, ProviderFormat::ChatCompletions, route, false) .await { Ok((request, _metadata)) => { @@ -122,7 +122,7 @@ async fn main() -> Result<()> { match router.resolve_provider_routes(model, ProviderFormat::ChatCompletions, &[]) { Ok(routes) => match routes.first() { Some(route) => match router - .create_request(body, ProviderFormat::ChatCompletions, route) + .create_request(body, ProviderFormat::ChatCompletions, route, false) .await { Ok((request, _metadata)) => { diff --git a/crates/braintrust-llm-router/examples/simple.rs b/crates/braintrust-llm-router/examples/simple.rs index 0a52c9e2..41ef2868 100644 --- a/crates/braintrust-llm-router/examples/simple.rs +++ b/crates/braintrust-llm-router/examples/simple.rs @@ -68,7 +68,7 @@ async fn main() -> Result<()> { ProviderFormat::ChatCompletions, ))?; let (request, _metadata) = router - .create_request(body, ProviderFormat::ChatCompletions, route) + .create_request(body, ProviderFormat::ChatCompletions, route, false) .await?; let bytes = router.complete(request, &ClientHeaders::default()).await?; let response: Value = serde_json::from_slice(&bytes)?; diff --git a/crates/braintrust-llm-router/examples/streaming.rs b/crates/braintrust-llm-router/examples/streaming.rs index b4c02a14..b7eac7a2 100644 --- a/crates/braintrust-llm-router/examples/streaming.rs +++ b/crates/braintrust-llm-router/examples/streaming.rs @@ -67,7 +67,7 @@ async fn main() -> Result<()> { ProviderFormat::ChatCompletions, ))?; let (request, _metadata) = router - .create_stream_request(body, ProviderFormat::ChatCompletions, route) + .create_stream_request(body, ProviderFormat::ChatCompletions, route, false) .await?; let mut stream = router .complete_stream(request, &ClientHeaders::default(), None) @@ -150,7 +150,7 @@ async fn main() -> Result<()> { ProviderFormat::ChatCompletions, ))?; let (request, _metadata) = router - .create_stream_request(body, ProviderFormat::ChatCompletions, route) + .create_stream_request(body, ProviderFormat::ChatCompletions, route, false) .await?; let stream = router .complete_stream(request, &ClientHeaders::default(), None) diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index 107f9030..cdb73f3e 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -319,6 +319,7 @@ impl Router { /// * `body` - Raw request body bytes in any supported format (OpenAI, Anthropic, Google, etc.) /// * `output_format` - The output format, or None to auto-detect from body /// * `route` - The already-resolved provider route to prepare for + /// * `preserve_body_model` - Keep the request body's model instead of rewriting it to the route model /// /// The body will be automatically transformed to the selected provider format if needed. #[cfg_attr( @@ -334,24 +335,7 @@ impl Router { body: Bytes, output_format: ProviderFormat, route: &ProviderRoute, - ) -> Result<(PreparedRequest, RouterMetadata)> { - let (inner, metadata) = self - .create_prepared_request_internal( - body, - output_format, - route, - false, - RequestPreparationOptions::default(), - ) - .await?; - Ok((PreparedRequest { inner }, metadata)) - } - - pub async fn create_request_preserving_body_model( - &self, - body: Bytes, - output_format: ProviderFormat, - route: &ProviderRoute, + preserve_body_model: bool, ) -> Result<(PreparedRequest, RouterMetadata)> { let (inner, metadata) = self .create_prepared_request_internal( @@ -360,7 +344,7 @@ impl Router { route, false, RequestPreparationOptions { - rewrite_body_model: false, + rewrite_body_model: !preserve_body_model, }, ) .await?; @@ -451,6 +435,7 @@ impl Router { /// * `body` - Raw request body bytes in any supported format (OpenAI, Anthropic, Google, etc.) /// * `output_format` - The output format, or None to auto-detect from body /// * `route` - The already-resolved provider route to prepare for + /// * `preserve_body_model` - Keep the request body's model instead of rewriting it to the route model /// /// The body will be automatically transformed to the selected provider format if needed. #[cfg_attr( @@ -466,24 +451,7 @@ impl Router { body: Bytes, output_format: ProviderFormat, route: &ProviderRoute, - ) -> Result<(PreparedStreamRequest, RouterMetadata)> { - let (inner, metadata) = self - .create_prepared_request_internal( - body, - output_format, - route, - true, - RequestPreparationOptions::default(), - ) - .await?; - Ok((PreparedStreamRequest { inner }, metadata)) - } - - pub async fn create_stream_request_preserving_body_model( - &self, - body: Bytes, - output_format: ProviderFormat, - route: &ProviderRoute, + preserve_body_model: bool, ) -> Result<(PreparedStreamRequest, RouterMetadata)> { let (inner, metadata) = self .create_prepared_request_internal( @@ -492,7 +460,7 @@ impl Router { route, true, RequestPreparationOptions { - rewrite_body_model: false, + rewrite_body_model: !preserve_body_model, }, ) .await?; @@ -1441,7 +1409,9 @@ mod tests { let route = routes .first() .ok_or_else(|| Error::NoProvider(output_format))?; - router.create_request(body, output_format, route).await + router + .create_request(body, output_format, route, false) + .await } async fn create_test_stream_request( @@ -1455,7 +1425,7 @@ mod tests { .first() .ok_or_else(|| Error::NoProvider(output_format))?; router - .create_stream_request(body, output_format, route) + .create_stream_request(body, output_format, route, false) .await } @@ -3294,7 +3264,7 @@ mod tests { br#"{"model":"gpt-4o","messages":[{"role":"user","content":"Ping"}]}"#, ); let (request, _) = router - .create_request(body, ProviderFormat::ChatCompletions, fallback_route) + .create_request(body, ProviderFormat::ChatCompletions, fallback_route, false) .await .expect("request prepares"); let payload: Value = serde_json::from_slice(&request.inner.payload).expect("json"); diff --git a/crates/braintrust-llm-router/tests/router.rs b/crates/braintrust-llm-router/tests/router.rs index c285a7ad..9cd46ac3 100644 --- a/crates/braintrust-llm-router/tests/router.rs +++ b/crates/braintrust-llm-router/tests/router.rs @@ -28,7 +28,9 @@ async fn create_request( let route = routes .first() .ok_or_else(|| Error::NoProvider(output_format))?; - router.create_request(body, output_format, route).await + router + .create_request(body, output_format, route, false) + .await } #[derive(Clone)] @@ -918,7 +920,7 @@ async fn responses_required_model_uses_responses_for_anthropic_messages_output() .expect("resolve routes"); let route = routes.first().expect("route"); let (_request, metadata) = router - .create_request(body, ProviderFormat::Anthropic, route) + .create_request(body, ProviderFormat::Anthropic, route, false) .await .expect("create request"); From bbd19500ee63938b75c1b7af01a8ab6a5446ccd4 Mon Sep 17 00:00:00 2001 From: Erin McNulty Date: Thu, 18 Jun 2026 12:12:51 -0400 Subject: [PATCH 28/28] small fix --- crates/braintrust-llm-router/src/router.rs | 31 +++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs index cdb73f3e..a6dd5d46 100644 --- a/crates/braintrust-llm-router/src/router.rs +++ b/crates/braintrust-llm-router/src/router.rs @@ -233,8 +233,9 @@ async fn prepare_provider_request( return Ok((bytes, Some(format), format)); } + let model_override = options.rewrite_body_model.then_some(spec.model.as_str()); let (transformed, detected_format, actual_format, maybe_rewrite_model) = - match lingua::transform_request(body.clone(), format, Some(&spec.model)) { + match lingua::transform_request(body.clone(), format, model_override) { Ok(TransformResult::PassThrough(bytes)) => (bytes, None, format, true), Ok(TransformResult::Transformed { bytes, @@ -1611,6 +1612,34 @@ mod tests { ); } + #[tokio::test] + async fn prepare_provider_request_can_preserve_body_model_across_format_transform() { + let body = Bytes::from_static( + br#"{"model":"claude-3-5-haiku-20241022","max_tokens":128,"messages":[{"role":"user","content":"Ping"}]}"#, + ); + let spec = openai_spec("gpt-4o", ModelFlavor::Chat); + + let (payload, detected_format, actual_format) = prepare_provider_request( + body, + &spec, + ProviderFormat::ChatCompletions, + false, + RequestPreparationOptions { + rewrite_body_model: false, + }, + ) + .await + .expect("request prepares"); + let parsed: Value = serde_json::from_slice(&payload).expect("valid request json"); + + assert_eq!(detected_format, Some(ProviderFormat::Anthropic)); + assert_eq!(actual_format, ProviderFormat::ChatCompletions); + assert_eq!( + parsed.get("model").and_then(Value::as_str), + Some("claude-3-5-haiku-20241022") + ); + } + #[tokio::test] async fn prepare_provider_request_upgrades_actual_format_to_responses_for_reasoning_plus_tools() {