From 1da251be083f7c1aa20d3ef98e76dce455418b9c Mon Sep 17 00:00:00 2001 From: Amit Singh Date: Fri, 29 May 2026 12:02:11 +0530 Subject: [PATCH] fix(mcp): use stored OAuth credentials for auto-detect HTTP servers Co-Authored-By: ForgeCode --- crates/forge_infra/src/mcp_client.rs | 103 ++++++++++++++++++++++++--- 1 file changed, 95 insertions(+), 8 deletions(-) diff --git a/crates/forge_infra/src/mcp_client.rs b/crates/forge_infra/src/mcp_client.rs index 23e246a1f3..e3f53ac08d 100644 --- a/crates/forge_infra/src/mcp_client.rs +++ b/crates/forge_infra/src/mcp_client.rs @@ -181,18 +181,26 @@ impl ForgeMcpClient { // Do NOT allow interactive auth during normal connection self.create_oauth_connection(http, oauth_config, false) .await? + } else if self.has_stored_credentials(http).await { + // Auto-detect with stored credentials: the user has already + // authenticated via `mcp login`, so use those credentials + // directly. This avoids an unauthenticated probe whose auth + // error can be masked by transport fallbacks (e.g. a server + // that returns 405 on the SSE GET), which would otherwise + // prevent the stored token from ever being used. + tracing::debug!( + "Found stored OAuth credentials for: {}, using them directly", + http.url + ); + let default_config = forge_domain::McpOAuthConfig::default(); + self.create_oauth_connection(http, &default_config, false) + .await? } else { // Auto-detect: try standard first, fall back to OAuth on auth errors match self.create_standard_http_connection(http).await { Ok(client) => Arc::new(client), Err(e) => { - let error_str = e.to_string().to_lowercase(); - if error_str.contains("401") - || error_str.contains("unauthorized") - || error_str.contains("authentication required") - || error_str.contains("auth required") - || error_str.contains("oauth") - { + if Self::is_auth_error(&e) { tracing::info!( "Standard connection failed with auth error for: {}, trying stored credentials", http.url @@ -214,6 +222,44 @@ impl ForgeMcpClient { Ok(client) } + /// Returns true if there are stored OAuth credentials for the given server. + /// + /// Used by the auto-detect path to decide whether to authenticate directly + /// with stored credentials (after a prior `mcp login`) instead of probing + /// the server with an unauthenticated request first. + async fn has_stored_credentials(&self, http: &McpHttpServer) -> bool { + use crate::auth::McpTokenStorage; + + let storage = McpTokenStorage::new(http.url.clone(), self.environment.clone()); + matches!(storage.load_credentials().await, Ok(Some(_))) + } + + /// Returns true if the error represents an authentication failure. + /// + /// Inspects the rmcp transport error chain for a streamable-HTTP + /// `AuthRequired` error and also falls back to string matching for cases + /// where the auth failure surfaces only as a textual message. + fn is_auth_error(error: &anyhow::Error) -> bool { + // Walk the error source chain looking for an explicit auth-required + // marker emitted by rmcp's streamable HTTP transport. + let mut source: Option<&(dyn std::error::Error + 'static)> = Some(error.as_ref()); + while let Some(err) = source { + let message = err.to_string().to_lowercase(); + if message.contains("auth required") + || message.contains("authrequired") + || message.contains("www-authenticate") + || message.contains("401") + || message.contains("unauthorized") + || message.contains("authentication required") + || message.contains("oauth") + { + return true; + } + source = err.source(); + } + false + } + /// Create a standard HTTP connection without OAuth async fn create_standard_http_connection( &self, @@ -227,7 +273,15 @@ impl ForgeMcpClient { ); match self.client_info().serve(transport).await { Ok(client) => Ok(client), - Err(_e) => { + Err(e) => { + // Do not fall back to SSE on authentication failures. The SSE + // transport uses a GET request which many servers reject with a + // non-auth status (e.g. 405), masking the original 401 and + // breaking OAuth auto-detection. Surface the auth error instead. + let e = anyhow::Error::new(e); + if Self::is_auth_error(&e) { + return Err(e); + } let transport = SseClientTransport::start_with_client( client.as_ref().clone(), SseClientConfig { sse_endpoint: http.url.clone().into(), ..Default::default() }, @@ -811,6 +865,39 @@ mod tests { use super::*; + #[test] + fn test_is_auth_error_detects_auth_required() { + let fixture = anyhow::anyhow!("Transport [streamable-http] error: Auth required"); + let actual = ForgeMcpClient::is_auth_error(&fixture); + let expected = true; + assert_eq!(actual, expected); + } + + #[test] + fn test_is_auth_error_detects_unauthorized() { + let fixture = anyhow::anyhow!("HTTP status 401 Unauthorized"); + let actual = ForgeMcpClient::is_auth_error(&fixture); + let expected = true; + assert_eq!(actual, expected); + } + + #[test] + fn test_is_auth_error_detects_in_source_chain() { + let source = std::io::Error::other("server responded with www-authenticate: Bearer"); + let fixture = anyhow::Error::new(source).context("failed to connect to MCP server"); + let actual = ForgeMcpClient::is_auth_error(&fixture); + let expected = true; + assert_eq!(actual, expected); + } + + #[test] + fn test_is_auth_error_ignores_non_auth_error() { + let fixture = anyhow::anyhow!("HTTP status 405 Method Not Allowed"); + let actual = ForgeMcpClient::is_auth_error(&fixture); + let expected = false; + assert_eq!(actual, expected); + } + #[test] fn test_resolve_http_templates_with_env() { let env_vars = BTreeMap::from([