Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 95 additions & 8 deletions crates/forge_infra/src/mcp_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,18 +181,26 @@ impl ForgeMcpClient {
// Do NOT allow interactive auth during normal connection
self.create_oauth_connection(http, oauth_config, false)
.await?
} else if self.has_stored_credentials(http).await {
// Auto-detect with stored credentials: the user has already
// authenticated via `mcp login`, so use those credentials
// directly. This avoids an unauthenticated probe whose auth
// error can be masked by transport fallbacks (e.g. a server
// that returns 405 on the SSE GET), which would otherwise
// prevent the stored token from ever being used.
tracing::debug!(
"Found stored OAuth credentials for: {}, using them directly",
http.url
);
let default_config = forge_domain::McpOAuthConfig::default();
self.create_oauth_connection(http, &default_config, false)
.await?
} else {
// Auto-detect: try standard first, fall back to OAuth on auth errors
match self.create_standard_http_connection(http).await {
Ok(client) => Arc::new(client),
Err(e) => {
let error_str = e.to_string().to_lowercase();
if error_str.contains("401")
|| error_str.contains("unauthorized")
|| error_str.contains("authentication required")
|| error_str.contains("auth required")
|| error_str.contains("oauth")
{
if Self::is_auth_error(&e) {
tracing::info!(
"Standard connection failed with auth error for: {}, trying stored credentials",
http.url
Expand All @@ -214,6 +222,44 @@ impl ForgeMcpClient {
Ok(client)
}

/// Returns true if there are stored OAuth credentials for the given server.
///
/// Used by the auto-detect path to decide whether to authenticate directly
/// with stored credentials (after a prior `mcp login`) instead of probing
/// the server with an unauthenticated request first.
async fn has_stored_credentials(&self, http: &McpHttpServer) -> bool {
use crate::auth::McpTokenStorage;

let storage = McpTokenStorage::new(http.url.clone(), self.environment.clone());
matches!(storage.load_credentials().await, Ok(Some(_)))
}

/// Returns true if the error represents an authentication failure.
///
/// Inspects the rmcp transport error chain for a streamable-HTTP
/// `AuthRequired` error and also falls back to string matching for cases
/// where the auth failure surfaces only as a textual message.
fn is_auth_error(error: &anyhow::Error) -> bool {
// Walk the error source chain looking for an explicit auth-required
// marker emitted by rmcp's streamable HTTP transport.
let mut source: Option<&(dyn std::error::Error + 'static)> = Some(error.as_ref());
while let Some(err) = source {
let message = err.to_string().to_lowercase();
if message.contains("auth required")
|| message.contains("authrequired")
|| message.contains("www-authenticate")
|| message.contains("401")
|| message.contains("unauthorized")
|| message.contains("authentication required")
|| message.contains("oauth")
{
return true;
}
source = err.source();
}
false
}

/// Create a standard HTTP connection without OAuth
async fn create_standard_http_connection(
&self,
Expand All @@ -227,7 +273,15 @@ impl ForgeMcpClient {
);
match self.client_info().serve(transport).await {
Ok(client) => Ok(client),
Err(_e) => {
Err(e) => {
// Do not fall back to SSE on authentication failures. The SSE
// transport uses a GET request which many servers reject with a
// non-auth status (e.g. 405), masking the original 401 and
// breaking OAuth auto-detection. Surface the auth error instead.
let e = anyhow::Error::new(e);
if Self::is_auth_error(&e) {
return Err(e);
}
let transport = SseClientTransport::start_with_client(
client.as_ref().clone(),
SseClientConfig { sse_endpoint: http.url.clone().into(), ..Default::default() },
Expand Down Expand Up @@ -811,6 +865,39 @@ mod tests {

use super::*;

#[test]
fn test_is_auth_error_detects_auth_required() {
let fixture = anyhow::anyhow!("Transport [streamable-http] error: Auth required");
let actual = ForgeMcpClient::is_auth_error(&fixture);
let expected = true;
assert_eq!(actual, expected);
}

#[test]
fn test_is_auth_error_detects_unauthorized() {
let fixture = anyhow::anyhow!("HTTP status 401 Unauthorized");
let actual = ForgeMcpClient::is_auth_error(&fixture);
let expected = true;
assert_eq!(actual, expected);
}

#[test]
fn test_is_auth_error_detects_in_source_chain() {
let source = std::io::Error::other("server responded with www-authenticate: Bearer");
let fixture = anyhow::Error::new(source).context("failed to connect to MCP server");
let actual = ForgeMcpClient::is_auth_error(&fixture);
let expected = true;
assert_eq!(actual, expected);
}

#[test]
fn test_is_auth_error_ignores_non_auth_error() {
let fixture = anyhow::anyhow!("HTTP status 405 Method Not Allowed");
let actual = ForgeMcpClient::is_auth_error(&fixture);
let expected = false;
assert_eq!(actual, expected);
}

#[test]
fn test_resolve_http_templates_with_env() {
let env_vars = BTreeMap::from([
Expand Down
Loading