Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
6a0b408
detect equivalent models between providers
erin2722 Jun 12, 2026
045daf3
fixes
erin2722 Jun 12, 2026
6a15bd0
clean up provider alias stuff
erin2722 Jun 12, 2026
b0d9bdd
allow custom models to specify equivalent models
erin2722 Jun 15, 2026
f46669b
merge base and custom catalog, this is getting too complicated to hav…
erin2722 Jun 15, 2026
5c9b327
comments
erin2722 Jun 15, 2026
c8458d1
Revert "merge base and custom catalog, this is getting too complicate…
erin2722 Jun 15, 2026
f3e2c0e
make shared equivalence index
erin2722 Jun 15, 2026
9817f7e
address comments
erin2722 Jun 16, 2026
0248a51
comment
erin2722 Jun 16, 2026
73fc1c0
address comment
erin2722 Jun 16, 2026
7cb2937
comment
erin2722 Jun 16, 2026
ba08768
equivalent => fallback
erin2722 Jun 16, 2026
8c5cd8a
don't recompute full base equivalence index each time and consolidate …
erin2722 Jun 16, 2026
7bc9338
rename
erin2722 Jun 16, 2026
471718f
rename
erin2722 Jun 16, 2026
9a2743f
one more rename
erin2722 Jun 16, 2026
8674f41
comment fix
erin2722 Jun 16, 2026
0eb3148
stop doing bedrock request transforms in a separate spot than everywh…
erin2722 Jun 16, 2026
8189e2d
oops
erin2722 Jun 16, 2026
199be5b
more holistic approach to re-write model name on failover in request …
erin2722 Jun 16, 2026
6363aa8
Merge branch 'main' of https://github.com/braintrustdata/lingua into …
erin2722 Jun 16, 2026
ce6b753
only rewrite model name in body if absolutely needed
erin2722 Jun 16, 2026
192cea1
tighten more
erin2722 Jun 16, 2026
41e2a9f
if the model name diff was intentional, don't change it
erin2722 Jun 16, 2026
0985268
move stuff out of mod.rs
erin2722 Jun 16, 2026
53818fe
Update router.rs
erin2722 Jun 18, 2026
9a0eeef
Merge branch 'main' of https://github.com/braintrustdata/lingua into …
erin2722 Jun 18, 2026
4afb9d8
Merge branch 'equivalent-model-endpoints-btwn-providers' of https://g…
erin2722 Jun 18, 2026
5e9689b
pass down params
erin2722 Jun 18, 2026
bbd1950
small fix
erin2722 Jun 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/braintrust-llm-router/examples/custom_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ async fn main() -> Result<()> {
ProviderFormat::ChatCompletions,
))?;
let (request, _metadata) = router
.create_request(body, ProviderFormat::ChatCompletions, route)
.create_request(body, ProviderFormat::ChatCompletions, route, false)
.await?;
let bytes = router.complete(request, &ClientHeaders::default()).await?;
let response: Value = serde_json::from_slice(&bytes)?;
Expand Down
4 changes: 2 additions & 2 deletions crates/braintrust-llm-router/examples/multi_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async fn main() -> Result<()> {
match router.resolve_provider_routes(model, ProviderFormat::ChatCompletions, &[]) {
Ok(routes) => match routes.first() {
Some(route) => match router
.create_request(body, ProviderFormat::ChatCompletions, route)
.create_request(body, ProviderFormat::ChatCompletions, route, false)
.await
{
Ok((request, _metadata)) => {
Expand Down Expand Up @@ -122,7 +122,7 @@ async fn main() -> Result<()> {
match router.resolve_provider_routes(model, ProviderFormat::ChatCompletions, &[]) {
Ok(routes) => match routes.first() {
Some(route) => match router
.create_request(body, ProviderFormat::ChatCompletions, route)
.create_request(body, ProviderFormat::ChatCompletions, route, false)
.await
{
Ok((request, _metadata)) => {
Expand Down
2 changes: 1 addition & 1 deletion crates/braintrust-llm-router/examples/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async fn main() -> Result<()> {
ProviderFormat::ChatCompletions,
))?;
let (request, _metadata) = router
.create_request(body, ProviderFormat::ChatCompletions, route)
.create_request(body, ProviderFormat::ChatCompletions, route, false)
.await?;
let bytes = router.complete(request, &ClientHeaders::default()).await?;
let response: Value = serde_json::from_slice(&bytes)?;
Expand Down
4 changes: 2 additions & 2 deletions crates/braintrust-llm-router/examples/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async fn main() -> Result<()> {
ProviderFormat::ChatCompletions,
))?;
let (request, _metadata) = router
.create_stream_request(body, ProviderFormat::ChatCompletions, route)
.create_stream_request(body, ProviderFormat::ChatCompletions, route, false)
.await?;
let mut stream = router
.complete_stream(request, &ClientHeaders::default(), None)
Expand Down Expand Up @@ -150,7 +150,7 @@ async fn main() -> Result<()> {
ProviderFormat::ChatCompletions,
))?;
let (request, _metadata) = router
.create_stream_request(body, ProviderFormat::ChatCompletions, route)
.create_stream_request(body, ProviderFormat::ChatCompletions, route, false)
.await?;
let stream = router
.complete_stream(request, &ClientHeaders::default(), None)
Expand Down
320 changes: 320 additions & 0 deletions crates/braintrust-llm-router/src/catalog/fallback.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,320 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use super::{ModelCatalog, ModelSpec};
use crate::error::{Error, Result};

/// A request-local catalog overlay.
///
/// Secret-defined custom models live in `custom` and shadow entries in the
/// shared `base` catalog. This avoids cloning the base catalog when adding
/// per-request model definitions.
#[derive(Debug, Clone)]
pub struct OverlayModelCatalog {
/// Shared base catalog; only the `Arc` handle is stored, never a copy.
base: Arc<ModelCatalog>,
/// Request-local models; `get` consults this catalog before `base`.
custom: ModelCatalog,
/// Names of every model defined in `custom`, captured when the overlay is built.
custom_model_names: HashSet<String>,
/// Fallback adjacency derived from `custom`'s fallback declarations.
/// Each accepted declaration is recorded in both directions.
overlay_edges: HashMap<String, Vec<String>>,
}

impl OverlayModelCatalog {
    /// Layers `custom` on top of the shared `base` catalog.
    ///
    /// Fallback declarations found in `custom` become undirected edges. A
    /// declaration is dropped when its source is not a custom model, and an
    /// individual target is dropped when neither catalog defines it.
    pub fn new(base: Arc<ModelCatalog>, custom: ModelCatalog) -> Self {
        let custom_model_names: HashSet<String> = custom.models.keys().cloned().collect();

        let mut overlay_edges: HashMap<String, Vec<String>> = HashMap::new();
        for (source, targets) in &custom.fallback_models {
            if !custom_model_names.contains(source) {
                continue;
            }

            // A target is usable when the overlay can resolve it: it is either
            // a custom model or a base model. (A shadowed base model is, by
            // definition, also a custom model, so no extra check is needed.)
            let resolvable = targets.iter().filter(|target| {
                custom_model_names.contains(*target) || base.models.contains_key(*target)
            });
            for target in resolvable {
                overlay_edges
                    .entry(source.clone())
                    .or_default()
                    .push(target.clone());
                overlay_edges
                    .entry(target.clone())
                    .or_default()
                    .push(source.clone());
            }
        }

        Self {
            base,
            custom,
            custom_model_names,
            overlay_edges,
        }
    }

    /// Returns another handle to the shared base catalog.
    pub fn base_catalog(&self) -> Arc<ModelCatalog> {
        Arc::clone(&self.base)
    }

    /// Looks up `name`, preferring the custom definition over the base one.
    pub fn get(&self, name: &str) -> Option<Arc<ModelSpec>> {
        let custom_hit = self.custom.get(name);
        custom_hit.or_else(|| self.base.get(name))
    }

    /// Lists `name` followed by every model reachable from it through fallback
    /// relations, with the reachable names in sorted order.
    ///
    /// Returns an empty vector when `name` is not defined in either catalog.
    ///
    /// NOTE(review): for base models the whole precomputed base group is used
    /// (minus shadowed names), so two base models that were only linked through
    /// a now-shadowed model still end up in the same result — confirm intended.
    pub fn find_fallback_models(&self, name: &str) -> Vec<String> {
        if self.get(name).is_none() {
            return Vec::new();
        }

        // Depth-first walk over base groups plus overlay edges.
        let mut seen: HashSet<String> = HashSet::new();
        let mut pending = vec![name.to_string()];
        while let Some(current) = pending.pop() {
            if seen.contains(&current) {
                continue;
            }

            // Custom models never inherit the base catalog's relations.
            if !self.custom_model_names.contains(&current) {
                let base_group = self.base.fallback_models(&current);
                pending.extend(
                    base_group
                        .into_iter()
                        .filter(|candidate| !self.custom_model_names.contains(candidate)),
                );
            }
            if let Some(neighbors) = self.overlay_edges.get(&current) {
                pending.extend(neighbors.iter().cloned());
            }

            seen.insert(current);
        }

        // The queried model always comes first; the rest follow alphabetically.
        seen.remove(name);
        let mut others: Vec<String> = seen.into_iter().collect();
        others.sort_unstable();

        let mut names = Vec::with_capacity(others.len() + 1);
        names.push(name.to_string());
        names.extend(others);
        names
    }
}

/// Input accepted by `ModelCatalog::set_fallback_models`.
enum FallbackModelSource<'a> {
/// Raw JSON text that is still to be run through `parse_fallback_models`.
Json(&'a str),
/// An already-built map from model name to its declared fallback models.
Parsed(HashMap<String, Vec<String>>),
}

impl ModelCatalog {
pub fn fallback_models(&self, name: &str) -> Vec<String> {
let Some(_) = self.models.get(name) else {
return Vec::new();
};

let mut names = vec![name.to_string()];
if let Some(equivalent_names) = self.equivalence_index.get(name) {
names.extend(equivalent_names.iter().cloned());
}
names
}

pub fn add_fallback_models<I>(&mut self, name: String, fallback_models: I) -> Result<()>
where
I: IntoIterator<Item = String>,
{
if !self.models.contains_key(&name) {
return Err(Error::InvalidRequest(format!(
"model '{name}' references fallback_models but is missing from catalog"
)));
}

let fallback_models: Vec<String> = fallback_models
.into_iter()
.filter(|fallback_model| !fallback_model.is_empty())
.collect();
for fallback_model in &fallback_models {
if !self.models.contains_key(fallback_model) {
return Err(Error::InvalidRequest(format!(
"model '{name}' references missing fallback model '{fallback_model}'"
)));
}
}

let mut next_fallback_models = self.fallback_models.clone();
let entry = next_fallback_models.entry(name).or_default();
for fallback_model in fallback_models {
if entry.contains(&fallback_model) {
continue;
}
entry.push(fallback_model);
}
self.set_fallback_models(FallbackModelSource::Parsed(next_fallback_models), false)?;
Ok(())
}

pub fn add_external_fallback_models<I>(
&mut self,
name: String,
fallback_models: I,
) -> Result<()>
where
I: IntoIterator<Item = String>,
{
if !self.models.contains_key(&name) {
return Err(Error::InvalidRequest(format!(
"model '{name}' references fallback_models but is missing from catalog"
)));
}

let mut next_fallback_models = self.fallback_models.clone();
let entry = next_fallback_models.entry(name).or_default();
for fallback_model in fallback_models {
if fallback_model.is_empty() || entry.contains(&fallback_model) {
continue;
}
entry.push(fallback_model);
}
self.set_fallback_models(FallbackModelSource::Parsed(next_fallback_models), false)?;
Ok(())
}

pub(super) fn set_fallback_models_from_json(
&mut self,
content: &str,
validate_targets: bool,
) -> Result<()> {
self.set_fallback_models(FallbackModelSource::Json(content), validate_targets)
}

pub(super) fn set_fallback_models_from_parsed(
&mut self,
fallback_models: HashMap<String, Vec<String>>,
validate_targets: bool,
) -> Result<()> {
self.set_fallback_models(
FallbackModelSource::Parsed(fallback_models),
validate_targets,
)
}

fn set_fallback_models(
&mut self,
source: FallbackModelSource<'_>,
validate_targets: bool,
) -> Result<()> {
let fallback_models = match source {
FallbackModelSource::Json(content) => parse_fallback_models(content)?,
FallbackModelSource::Parsed(fallback_models) => fallback_models,
};

if validate_targets {
validate_fallback_models(&self.models, &fallback_models)?;
}
self.equivalence_index =
build_equivalence_index(self.models.keys().cloned().collect(), &fallback_models);
self.fallback_models = fallback_models;
Ok(())
}
}

/// Extracts fallback declarations from a JSON object keyed by model name.
///
/// Entries without a `fallback_models` field, or with an empty list, are
/// omitted from the result.
///
/// # Errors
///
/// Fails when `content` is not valid JSON of the expected shape, or when a
/// `fallback_models` field is not an array made up entirely of strings.
fn parse_fallback_models(content: &str) -> Result<HashMap<String, Vec<String>>> {
    let raw: HashMap<String, serde_json::Value> = serde_json::from_str(content)?;

    let mut fallback_models = HashMap::new();
    for (name, value) in raw {
        let Some(declared) = value.get("fallback_models") else {
            continue;
        };

        // `None` covers both "not an array" and "array with a non-string item".
        let parsed: Option<Vec<String>> = declared.as_array().and_then(|items| {
            items
                .iter()
                .map(|item| item.as_str().map(str::to_string))
                .collect()
        });
        let Some(parsed) = parsed else {
            return Err(Error::InvalidRequest(format!(
                "model '{name}' has invalid fallback_models"
            )));
        };

        if !parsed.is_empty() {
            fallback_models.insert(name, parsed);
        }
    }
    Ok(fallback_models)
}

fn validate_fallback_models(
models: &HashMap<String, Arc<ModelSpec>>,
fallback_models: &HashMap<String, Vec<String>>,
) -> Result<()> {
for (name, fallback_models) in fallback_models {
for fallback_model in fallback_models {
if !models.contains_key(fallback_model) {
return Err(Error::InvalidRequest(format!(
"model '{name}' references missing fallback model '{fallback_model}'"
)));
}
}
}
Ok(())
}

/// Groups models into connected components of the fallback graph.
///
/// Declarations are treated as undirected edges; edges whose source or target
/// is not in `model_names` are ignored. Every model that shares a component
/// with at least one other model is mapped to the sorted list of the other
/// members. Models with no peers get no entry.
fn build_equivalence_index(
    model_names: HashSet<String>,
    fallback_models: &HashMap<String, Vec<String>>,
) -> HashMap<String, Vec<String>> {
    // Undirected adjacency restricted to known models; names are borrowed.
    let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
    for (source, targets) in fallback_models {
        if !model_names.contains(source) {
            continue;
        }
        let known_targets = targets
            .iter()
            .filter(|target| model_names.contains(*target));
        for target in known_targets {
            adjacency
                .entry(source.as_str())
                .or_default()
                .push(target.as_str());
            adjacency
                .entry(target.as_str())
                .or_default()
                .push(source.as_str());
        }
    }

    let mut seen: HashSet<&str> = HashSet::with_capacity(model_names.len());
    let mut index = HashMap::new();
    for root in &model_names {
        if seen.contains(root.as_str()) {
            continue;
        }

        // Depth-first walk collecting the component that contains `root`.
        let mut group: Vec<&str> = Vec::new();
        let mut pending = vec![root.as_str()];
        while let Some(current) = pending.pop() {
            if !seen.insert(current) {
                continue;
            }
            group.push(current);
            if let Some(neighbors) = adjacency.get(current) {
                pending.extend(neighbors.iter().copied());
            }
        }

        // Singletons have no peers and therefore no index entry.
        if group.len() < 2 {
            continue;
        }
        group.sort_unstable();
        for &member in &group {
            let peers: Vec<String> = group
                .iter()
                .filter(|&&other| other != member)
                .map(|&other| other.to_string())
                .collect();
            index.insert(member.to_string(), peers);
        }
    }

    index
}
Loading
Loading