Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 76 additions & 20 deletions crates/openshell-cli/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ use openshell_core::proto::{
use openshell_core::settings::{self, SettingValueKind};
use openshell_core::{ObjectId, ObjectName};
use openshell_providers::{
ProviderRegistry, ProviderTypeProfile, detect_provider_from_command, normalize_provider_type,
parse_profile_json, parse_profile_yaml, profile_to_json, profile_to_yaml, profiles_to_json,
profiles_to_yaml,
ProviderRegistry, ProviderTypeProfile, RealDiscoveryContext, detect_provider_from_command,
discover_from_profile, normalize_provider_type, parse_profile_json, parse_profile_yaml,
profile_to_json, profile_to_yaml, profiles_to_json, profiles_to_yaml,
};
use owo_colors::OwoColorize;
use std::collections::{HashMap, HashSet};
Expand Down Expand Up @@ -1709,7 +1709,12 @@ pub async fn sandbox_create(
};
let requested_gpu = gpu || image.as_deref().is_some_and(image_requests_gpu);

let inferred_types: Vec<String> = inferred_provider_type(command).into_iter().collect();
let providers_v2_enabled = gateway_providers_v2_enabled(&mut client).await?;
let inferred_types: Vec<String> = if providers_v2_enabled {
Vec::new()
} else {
inferred_provider_type(command).into_iter().collect()
};
let configured_providers = ensure_required_providers(
&mut client,
providers,
Expand Down Expand Up @@ -3631,9 +3636,8 @@ async fn auto_create_provider(
return Ok(());
}

let registry = ProviderRegistry::new();
let discovered = registry
.discover_existing(provider_type)
let discovered = discover_existing_provider_data(client, provider_type)
.await
.map_err(|err| miette::miette!("failed to discover provider '{provider_type}': {err}"))?;
let Some(discovered) = discovered else {
eprintln!(
Expand Down Expand Up @@ -4094,6 +4098,68 @@ fn service_url_for_gateway(service_url: &str, gateway_endpoint: &str) -> String
service_url.to_string()
}

async fn gateway_providers_v2_enabled(client: &mut crate::tls::GrpcClient) -> Result<bool> {
let response = client
.get_gateway_config(GetGatewayConfigRequest {})
.await
.into_diagnostic()?
.into_inner();
let Some(setting) = response.settings.get(settings::PROVIDERS_V2_ENABLED_KEY) else {
return Ok(false);
};
match setting.value.as_ref() {
Some(setting_value::Value::BoolValue(enabled)) => Ok(*enabled),
None => Ok(false),
Some(_) => Err(miette::miette!(
"gateway setting '{}' has invalid value type; expected bool",
settings::PROVIDERS_V2_ENABLED_KEY
)),
}
}

async fn fetch_provider_profile(
client: &mut crate::tls::GrpcClient,
provider_type: &str,
) -> Result<ProviderProfile> {
let response = client
.get_provider_profile(GetProviderProfileRequest {
id: provider_type.to_string(),
})
.await
.map_err(|status| {
if status.code() == Code::NotFound {
miette::miette!(
"provider profile '{provider_type}' not found; providers v2 discovery requires a provider profile"
)
} else {
miette::miette!(status.to_string())
}
})?;

response
.into_inner()
.profile
.ok_or_else(|| miette::miette!("provider profile '{provider_type}' missing from response"))
}

async fn discover_existing_provider_data(
client: &mut crate::tls::GrpcClient,
provider_type: &str,
) -> Result<Option<openshell_providers::DiscoveredProvider>> {
if gateway_providers_v2_enabled(client).await? {
let profile = fetch_provider_profile(client, provider_type).await?;
let profile = ProviderTypeProfile::from_proto(&profile);
discover_from_profile(&profile, &RealDiscoveryContext).map_err(|err| {
miette::miette!("failed to discover existing provider data from profile: {err}")
})
} else {
let registry = ProviderRegistry::new();
registry
.discover_existing(provider_type)
.map_err(|err| miette::miette!("failed to discover existing provider data: {err}"))
}
}

pub async fn provider_create(
server: &str,
name: &str,
Expand Down Expand Up @@ -4143,10 +4209,7 @@ pub async fn provider_create(
let mut config_map = parse_key_value_pairs(config, "--config")?;

if from_existing {
let registry = ProviderRegistry::new();
let discovered = registry
.discover_existing(&provider_type)
.map_err(|err| miette::miette!("failed to discover existing provider data: {err}"))?;
let discovered = discover_existing_provider_data(&mut client, &provider_type).await?;
let Some(discovered) = discovered else {
return Err(miette::miette!(
"no existing local credentials/config found for provider type '{provider_type}'"
Expand All @@ -4162,13 +4225,9 @@ pub async fn provider_create(
}

if credential_map.is_empty() {
let allows_refresh_bootstrap = client
.get_provider_profile(GetProviderProfileRequest {
id: provider_type.clone(),
})
let allows_refresh_bootstrap = fetch_provider_profile(&mut client, &provider_type)
.await
.ok()
.and_then(|response| response.into_inner().profile)
.is_some_and(|profile| provider_profile_allows_refresh_bootstrap(&profile));
if !allows_refresh_bootstrap {
return Err(miette::miette!(
Expand Down Expand Up @@ -4911,10 +4970,7 @@ pub async fn provider_update(
.ok_or_else(|| miette::miette!("provider '{name}' not found"))?;

let provider_type = existing.r#type;
let registry = ProviderRegistry::new();
let discovered = registry
.discover_existing(&provider_type)
.map_err(|err| miette::miette!("failed to discover existing provider data: {err}"))?;
let discovered = discover_existing_provider_data(&mut client, &provider_type).await?;
let Some(discovered) = discovered else {
return Err(miette::miette!(
"no existing local credentials/config found for provider type '{provider_type}'"
Expand Down
Loading
Loading