diff --git a/Cargo.lock b/Cargo.lock index e365998f93..524283c6d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1497,6 +1497,7 @@ dependencies = [ "hyper-util", "lazy_static", "logfmt", + "mac_address", "metrics-endpoint", "opentelemetry", "opentelemetry-otlp", @@ -1512,6 +1513,7 @@ dependencies = [ "tokio-rustls", "tokio-stream", "tokio-util", + "tonic", "tower", "tower-http", "tracing", diff --git a/crates/api-db/src/machine_interface.rs b/crates/api-db/src/machine_interface.rs index 964ca0ba13..7209567adb 100644 --- a/crates/api-db/src/machine_interface.rs +++ b/crates/api-db/src/machine_interface.rs @@ -279,6 +279,22 @@ pub async fn find_by_mac_address( find_by(txn, ObjectColumnFilter::One(MacAddressColumn, &macaddr)).await } +/// This function returns only an IP for efficiency, we don't need to fetch/deserialize the entire +/// MachineInterfaceSnapshot +pub async fn lookup_bmc_ip_by_mac_address( + db: impl DbReader<'_>, + mac_address: MacAddress, +) -> DatabaseResult> { + let query = r"SELECT mia.address FROM machine_interfaces mi + INNER JOIN machine_interface_addresses mia ON (mia.interface_id = mi.id) + WHERE mi.mac_address = $1"; + sqlx::query_scalar(query) + .bind(mac_address) + .fetch_all(db) + .await + .map_err(|e| DatabaseError::query(query, e)) +} + pub async fn find_by_ip( txn: impl DbReader<'_>, ip: IpAddr, diff --git a/crates/api/src/api.rs b/crates/api/src/api.rs index 15d37ccb81..88c43a8a3f 100644 --- a/crates/api/src/api.rs +++ b/crates/api/src/api.rs @@ -3365,6 +3365,13 @@ impl Forge for Api { templates, })) } + + async fn find_bmc_ips( + &self, + request: Request<::rpc::forge::FindBmcIpsRequest>, + ) -> Result, Status> { + crate::handlers::machine_interface::find_bmc_ips(self, request).await + } } fn ipxe_template_scope_to_proto( diff --git a/crates/api/src/auth/internal_rbac_rules.rs b/crates/api/src/auth/internal_rbac_rules.rs index 7d54319f9b..867c4f14d3 100644 --- a/crates/api/src/auth/internal_rbac_rules.rs +++ b/crates/api/src/auth/internal_rbac_rules.rs @@ -496,6 +496,7 @@ impl InternalRBACRules { vec![SiteAgent, ForgeAdminCLI], ); x.perm("FindMacAddressByBmcIp", vec![SiteAgent, BmcProxy]); + x.perm("FindBmcIps", vec![ForgeAdminCLI, BmcProxy]); x.perm("BmcCredentialStatus", vec![ForgeAdminCLI, SiteAgent]); x.perm( "GetMachineValidationExternalConfigs", diff --git a/crates/api/src/handlers/machine_interface.rs b/crates/api/src/handlers/machine_interface.rs index 436a5a22fd..33c9a234a9 100644 --- a/crates/api/src/handlers/machine_interface.rs +++ b/crates/api/src/handlers/machine_interface.rs @@ -19,6 +19,8 @@ use std::net::IpAddr; use std::str::FromStr; use ::rpc::forge as rpc; +use db::WithTransaction; +use futures_util::FutureExt; use itertools::Itertools; use model::machine_interface::InterfaceType; use tonic::{Request, Response, Status}; @@ -145,3 +147,71 @@ pub(crate) async fn find_mac_address_by_bmc_ip( mac_address: interface.mac_address.to_string(), })) } + +pub(crate) async fn find_bmc_ips( + api: &Api, + request: Request, +) -> Result, Status> { + use rpc::find_bmc_ips_request::LookupBy; + + log_request_data(&request); + + let req = request.into_inner(); + + let bmc_ips = match req.lookup_by { + Some(LookupBy::MacAddress(mac_address)) => { + db::machine_interface::lookup_bmc_ip_by_mac_address( + &api.database_connection, + mac_address.parse().map_err(|e| { + CarbideError::InvalidArgument(format!("Invalid MAC address: {e}")) + })?, + ) + .await? + } + Some(LookupBy::Serial(serial)) => { + // Get the machine ID for this serial + let machine_ids = + db::machine_topology::find_by_serial(&api.database_connection, &serial).await?; + if machine_ids.len() > 1 { + tracing::warn!( + serial, + "Multiple machines match serial number, cannot resolve to BMC IP" + ); + return Ok(Response::new(rpc::BmcIpList::default())); + } + let Some(machine_id) = machine_ids.into_iter().next() else { + return Ok(Response::new(rpc::BmcIpList::default())); + }; + + // Get the machine topology for this machine + let Some(machine_topology) = api + .with_txn(|txn| { + async move { + db::machine_topology::find_latest_by_machine_ids(txn, &[machine_id]).await + } + .boxed() + }) + .await?? + .into_values() + .next() + else { + return Ok(Response::new(rpc::BmcIpList::default())); + }; + + // Get the BMC IP out of the machine topology + let bmc_ip = match machine_topology.topology.bmc_info.ip.map(|ip| ip.parse()) { + Some(Ok(ip)) => ip, + None | Some(Err(_)) => { + return Ok(Response::new(rpc::BmcIpList::default())); + } + }; + + vec![bmc_ip] + } + None => return Err(CarbideError::MissingArgument("lookup_by").into()), + }; + + Ok(Response::new(rpc::BmcIpList { + bmc_ips: bmc_ips.into_iter().map(|ip| ip.to_string()).collect(), + })) +} diff --git a/crates/bmc-proxy/Cargo.toml b/crates/bmc-proxy/Cargo.toml index c08028409e..999f0a06b5 100644 --- a/crates/bmc-proxy/Cargo.toml +++ b/crates/bmc-proxy/Cargo.toml @@ -54,6 +54,7 @@ hyper-timeout = { workspace = true } hyper-util = { workspace = true } hyper = { workspace = true, features = ["full"] } lazy_static = { workspace = true } +mac_address = { workspace = true } opentelemetry = { workspace = true, features = ["logs"] } opentelemetry-otlp = { workspace = true, features = ["grpc-tonic"] } opentelemetry-prometheus.workspace = true @@ -74,6 +75,7 @@ tokio-rustls = { workspace = true } tokio-stream = { workspace = true } tokio-util = { workspace = true } tokio = { workspace = true } +tonic = { workspace = true } tower = { workspace = true } tower-http = { features = [ "add-extension", diff --git a/crates/bmc-proxy/src/bmc_proxy.rs b/crates/bmc-proxy/src/bmc_proxy.rs index 99bbf5d0ff..97e891a090 100644 --- a/crates/bmc-proxy/src/bmc_proxy.rs +++ b/crates/bmc-proxy/src/bmc_proxy.rs @@ -39,9 +39,11 @@ use http::{HeaderMap, Method, Request, Response, StatusCode, Uri}; use hyper_util::rt::{TokioExecutor, TokioIo}; use hyper_util::server::conn::auto; use hyper_util::service::TowerToHyperService; +use mac_address::{MacAddress, MacParseError}; use opentelemetry::KeyValue; use opentelemetry::metrics::Meter; use rpc::forge; +use rpc::forge::find_bmc_ips_request::LookupBy; use rpc::forge_api_client::ForgeApiClient; use rpc::forge_tls_client::{ApiConfig, ForgeClientConfig}; use tokio::net::TcpListener; @@ -60,8 +62,8 @@ const MAX_BODY_SIZE: usize = 8 * 1024 * 1024; // 8MiB body size limit (matches n #[derive(thiserror::Error, Debug)] pub enum BmcProxyError { - #[error("Error resolving BMC credentials through Carbide API: {0}")] - ApiCredentials(String), + #[error("Error resolving BMC information through Carbide API: {0}")] + Api(String), #[error("Invalid configuration: {0}")] InvalidConfiguration(String), #[error("Internal error proxying request: {0}")] @@ -86,10 +88,27 @@ struct BmcProxyState { api_client: ForgeApiClient, credential_cache: CredentialCache, client_cache: HttpClientCache, + ip_cache: LookupToIpCache, } type CredentialCache = Arc>>; type HttpClientCache = Arc>>; +type LookupToIpCache = Arc>>; + +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +enum ForwardedTarget<'a> { + Ip(IpAddr), + Mac(MacAddress), + Serial(&'a str), +} + +#[derive(thiserror::Error, Debug)] +enum ForwardedHeaderParseError { + #[error("Invalid IP in Forwarded host header: {0}")] + Ip(#[from] AddrParseError), + #[error("Invalid MAC address in Forwarded host header: {0}")] + Mac(#[from] MacParseError), +} impl BmcProxyState { fn allows(&self, request: &Request) -> bool { @@ -152,8 +171,9 @@ pub async fn start( let state = BmcProxyState { config, api_client, - credential_cache: Arc::new(Mutex::new(HashMap::new())), - client_cache: Arc::new(Mutex::new(HashMap::new())), + credential_cache: Default::default(), + client_cache: Default::default(), + ip_cache: Default::default(), meter, }; @@ -470,17 +490,39 @@ async fn proxy_request( return Ok(error_response((StatusCode::FORBIDDEN, "Forbidden").into())); } let (parts, body) = request.into_parts(); - let target_ip = forwarded_host_ip(&parts.headers) + let forwarded_target = forwarded_header_value(&parts.headers) + .map_err(|e| error_response((StatusCode::BAD_REQUEST, e.to_string()).into()))? .ok_or_else(|| { error_response( ( StatusCode::BAD_REQUEST, - "missing Forwarded host in request header", + "missing Forwarded host/mac/serial in request header", ) .into(), ) - })? - .map_err(|e| error_response((StatusCode::BAD_REQUEST, e.to_string()).into()))?; + })?; + + let target_ip = match ip_for_forwarded_target(&forwarded_target, &state).await { + Ok(Some(ip)) => ip, + Ok(None) => { + return Err(error_response( + ( + StatusCode::BAD_REQUEST, + "Could not find BMC from forwarded header", + ) + .into(), + )); + } + Err(e) => { + return Err(error_response( + ( + StatusCode::BAD_GATEWAY, + format!("Failure looking up BMC IP from target: {e}"), + ) + .into(), + )); + } + }; let path_and_query = parts .uri @@ -543,6 +585,88 @@ async fn proxy_request( Ok(build_response(status, &headers, body)) } +async fn ip_for_forwarded_target( + forwarded_target: &ForwardedTarget<'_>, + state: &BmcProxyState, +) -> Result, tonic::Status> { + let lookup_by = match forwarded_target { + ForwardedTarget::Ip(ip) => { + // No need to look up + return Ok(Some(*ip)); + } + ForwardedTarget::Mac(mac) => LookupBy::MacAddress(mac.to_string()), + ForwardedTarget::Serial(serial) => LookupBy::Serial(serial.to_string()), + }; + + if let Some(ip) = state.ip_cache.lock().await.get(&lookup_by) { + return Ok(Some(*ip)); + } + + let lookup_by_str = match &lookup_by { + LookupBy::Serial(serial) => format!("Serial number {serial}"), + LookupBy::MacAddress(mac) => format!("MAC address {mac}"), + }; + + let ips = state + .api_client + .find_bmc_ips(forge::FindBmcIpsRequest { + lookup_by: Some(lookup_by.clone()), + }) + .await? + .bmc_ips + .iter() + .filter_map(|s| { + IpAddr::from_str(s) + .inspect_err(|e| tracing::error!("Invalid IP address returned by API: {e}")) + .ok() + }) + .collect::>(); + + if ips.is_empty() { + return Ok(None); + } + + let (v4_ips, v6_ips): (Vec, Vec) = ips.into_iter().partition(|ip| ip.is_ipv4()); + + let ip = match (v4_ips.len(), v6_ips.len()) { + (0, 1..) => { + if v6_ips.len() > 1 { + tracing::warn!( + "Multiple IPv6 BMC IP's found for {} ({}), using first one", + lookup_by_str, + v6_ips + .iter() + .map(|ip| ip.to_string()) + .collect::>() + .join(", "), + ); + } + v6_ips.into_iter().next() + } + _ => { + // TODO: We may want to be smart about when to pick IPv6 vs IPv4, but for now just pick IPv4 + // first, in case of broken dual-stack setups. + if v4_ips.len() > 1 { + tracing::warn!( + "Multiple IPv4 BMC IP's found for {} ({}), using first one", + lookup_by_str, + v4_ips + .iter() + .map(|ip| ip.to_string()) + .collect::>() + .join(", "), + ); + } + v4_ips.into_iter().next() + } + }; + + if let Some(ip) = ip { + state.ip_cache.lock().await.insert(lookup_by, ip); + } + Ok(ip) +} + async fn authorize_proxy_request( State(state): State, request: Request, @@ -634,7 +758,9 @@ fn is_hop_by_hop_header(name: &str) -> bool { ) } -fn forwarded_host_ip(headers: &HeaderMap) -> Option> { +fn forwarded_header_value( + headers: &HeaderMap, +) -> Result>, ForwardedHeaderParseError> { let values = headers.get_all("forwarded"); for raw_value in values { let Ok(raw_value) = raw_value.to_str() else { @@ -645,36 +771,41 @@ fn forwarded_host_ip(headers: &HeaderMap) -> Option Option> { +fn parse_forwarded_host_value(value: &str) -> Result { let value = value.trim_matches('"'); - if let Ok(ip) = IpAddr::from_str(value) { - return Some(Ok(ip)); + let result = IpAddr::from_str(value); + if let Ok(ip) = result { + return Ok(ip); } + // If it failed to parse, maybe it's a bracked ipv6 address, support that if let Some(rest) = value.strip_prefix('[') && let Some((host, _)) = rest.split_once(']') { - return Some(IpAddr::from_str(host)); - } - - if let Some((host, _port)) = value.rsplit_once(':') - && let Ok(ip) = IpAddr::from_str(host) - { - return Some(Ok(ip)); + IpAddr::from_str(host) + } else { + // Nope, just return the failure + result } - - None } fn error_response(error: ProxyError) -> Response { @@ -744,7 +875,7 @@ impl TryFrom for BmcCredentials { Some(forge::bmc_credentials::Type::SessionToken(value)) => { Ok(Self::SessionToken { token: value.token }) } - None => Err(BmcProxyError::ApiCredentials( + None => Err(BmcProxyError::Api( "missing credential type in API response".to_string(), )), } @@ -771,7 +902,7 @@ async fn create_client( let mut header_map = HeaderMap::new(); if add_custom_header { header_map.insert("forwarded", format!("host={ip}").parse().unwrap()); - }; + } let http_client = get_http_client(ip, client_cache).await?; let credentials = get_bmc_credentials(ip, api_client, credential_cache).await?; @@ -814,7 +945,7 @@ async fn get_bmc_credentials( bmc_ip: ip.to_string(), }) .await - .map_err(|e| BmcProxyError::ApiCredentials(e.to_string()))? + .map_err(|e| BmcProxyError::Api(e.to_string()))? .mac_address; let credentials: BmcCredentials = api_client @@ -822,7 +953,7 @@ async fn get_bmc_credentials( mac_addr: bmc_mac_address, }) .await - .map_err(|e| BmcProxyError::ApiCredentials(e.to_string()))? + .map_err(|e| BmcProxyError::Api(e.to_string()))? .credentials .ok_or(BmcProxyError::NoCredentials(ip))? .try_into()?; @@ -878,17 +1009,49 @@ mod tests { use std::sync::Arc; use axum::http::{HeaderMap, HeaderName, HeaderValue}; + use mac_address::MacAddress; + use opentelemetry::global; + use rpc::forge::find_bmc_ips_request::LookupBy; + use rpc::forge_api_client::ForgeApiClient; + use rpc::forge_tls_client::{ApiConfig, ForgeClientConfig}; use tokio::sync::Mutex; use super::{ - BmcCredentials, CredentialCache, evict_cached_credentials, forwarded_host_ip, - parse_forwarded_host_value, + BmcCredentials, BmcProxyState, CredentialCache, ForwardedTarget, evict_cached_credentials, + forwarded_header_value, ip_for_forwarded_target, parse_forwarded_host_value, }; + fn test_state_with_ip_cache(ip_cache: HashMap) -> BmcProxyState { + let client_config = ForgeClientConfig::default(); + let api_config = ApiConfig::new("https://example.com", &client_config); + + BmcProxyState { + config: Arc::new( + crate::Config::parse( + r#" + [tls] + identity_pemfile_path = "" + identity_keyfile_path = "" + root_cafile_path = "" + admin_root_cafile_path = "" + + [auth] + "#, + ) + .expect("test config should parse"), + ), + meter: global::meter("carbide-bmc-proxy-test"), + api_client: ForgeApiClient::new(&api_config), + credential_cache: Default::default(), + client_cache: Default::default(), + ip_cache: Arc::new(Mutex::new(ip_cache)), + } + } + #[test] fn parses_forwarded_ipv4() { assert_eq!( - parse_forwarded_host_value("10.0.0.5").unwrap().unwrap(), + parse_forwarded_host_value("10.0.0.5").unwrap(), IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)) ); } @@ -896,9 +1059,7 @@ mod tests { #[test] fn parses_forwarded_ipv6_with_port() { assert_eq!( - parse_forwarded_host_value("\"[2001:db8::1]:443\"") - .unwrap() - .unwrap(), + parse_forwarded_host_value("\"[2001:db8::1]:443\"").unwrap(), IpAddr::V6(Ipv6Addr::from_str("2001:db8::1").unwrap()) ); } @@ -911,8 +1072,93 @@ mod tests { HeaderValue::from_static("proto=https;host=10.1.2.3;for=10.0.0.1"), ); assert_eq!( - forwarded_host_ip(&headers).unwrap().unwrap(), - IpAddr::V4(Ipv4Addr::new(10, 1, 2, 3)) + forwarded_header_value(&headers).unwrap().unwrap(), + ForwardedTarget::Ip(IpAddr::V4(Ipv4Addr::new(10, 1, 2, 3))), + ); + } + + #[test] + fn finds_forwarded_mac_target() { + let mut headers = HeaderMap::new(); + headers.insert( + HeaderName::from_static("forwarded"), + HeaderValue::from_static("proto=https;mac=00:11:22:33:44:55;for=10.0.0.1"), + ); + + assert_eq!( + forwarded_header_value(&headers).unwrap().unwrap(), + ForwardedTarget::Mac(MacAddress::from_str("00:11:22:33:44:55").unwrap()), + ); + } + + #[test] + fn finds_forwarded_serial_target() { + let mut headers = HeaderMap::new(); + headers.insert( + HeaderName::from_static("forwarded"), + HeaderValue::from_static("proto=https; serial = DGX-A100-0001 ; for=10.0.0.1"), + ); + + assert_eq!( + forwarded_header_value(&headers).unwrap().unwrap(), + ForwardedTarget::Serial("DGX-A100-0001"), + ); + } + + #[test] + fn rejects_invalid_forwarded_mac_target() { + let mut headers = HeaderMap::new(); + headers.insert( + HeaderName::from_static("forwarded"), + HeaderValue::from_static("mac=not-a-mac-address"), + ); + + assert!(matches!( + forwarded_header_value(&headers), + Err(super::ForwardedHeaderParseError::Mac(_)) + )); + } + + #[tokio::test] + async fn forwarded_ip_target_resolves_without_lookup() { + let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)); + let state = test_state_with_ip_cache(HashMap::new()); + + assert_eq!( + ip_for_forwarded_target(&ForwardedTarget::Ip(ip), &state) + .await + .unwrap(), + Some(ip) + ); + } + + #[tokio::test] + async fn forwarded_mac_target_resolves_from_ip_cache() { + let mac = MacAddress::from_str("00:11:22:33:44:55").unwrap(); + let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)); + let state = + test_state_with_ip_cache(HashMap::from([(LookupBy::MacAddress(mac.to_string()), ip)])); + + assert_eq!( + ip_for_forwarded_target(&ForwardedTarget::Mac(mac), &state) + .await + .unwrap(), + Some(ip) + ); + } + + #[tokio::test] + async fn forwarded_serial_target_resolves_from_ip_cache() { + let serial = "DGX-A100-0001"; + let ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 5)); + let state = + test_state_with_ip_cache(HashMap::from([(LookupBy::Serial(serial.to_string()), ip)])); + + assert_eq!( + ip_for_forwarded_target(&ForwardedTarget::Serial(serial), &state) + .await + .unwrap(), + Some(ip) ); } diff --git a/crates/rpc/proto/forge.proto b/crates/rpc/proto/forge.proto index 135f90eab8..98eae73a72 100644 --- a/crates/rpc/proto/forge.proto +++ b/crates/rpc/proto/forge.proto @@ -234,6 +234,7 @@ service Forge { rpc FindConnectedDevicesByDpuMachineIds(common.MachineIdList) returns (ConnectedDeviceList); rpc FindMachineIdsByBmcIps(BmcIpList) returns (MachineIdBmcIpPairs); rpc FindMacAddressByBmcIp(BmcIp) returns (MacAddressBmcIp); + rpc FindBmcIps(FindBmcIpsRequest) returns (BmcIpList); rpc IdentifyUuid(IdentifyUuidRequest) returns (IdentifyUuidResponse); rpc IdentifyMac(IdentifyMacRequest) returns (IdentifyMacResponse); rpc IdentifySerial(IdentifySerialRequest) returns (IdentifySerialResponse); @@ -5187,6 +5188,13 @@ enum UuidType { UuidTypeComputeAllocationId = 6; } +message FindBmcIpsRequest { + oneof lookup_by { + string mac_address = 1; + string serial = 2; + } +} + message IdentifyMacRequest { string mac_address = 1; } diff --git a/crates/rpc/src/lib.rs b/crates/rpc/src/lib.rs index 83a290c0dd..0d311abd38 100644 --- a/crates/rpc/src/lib.rs +++ b/crates/rpc/src/lib.rs @@ -884,6 +884,12 @@ impl forge_agent_control_response::Action { } } +impl From for forge::find_bmc_ips_request::LookupBy { + fn from(addr: MacAddress) -> Self { + Self::MacAddress(addr.to_string()) + } +} + #[cfg(feature = "cli")] // This impl allows us to use the RPC RouteServerSourceType type // as a first class enum with clap, for the purpose of allowing