diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index 1ff2ddd7..1ee1ceeb 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -1,4 +1,8 @@ -use std::{collections::HashMap, sync::Arc, time::Duration}; +use std::{ + collections::HashMap, + sync::Arc, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; use async_trait::async_trait; use oauth2::{ @@ -61,6 +65,8 @@ pub struct StoredCredentials { pub token_response: Option, #[serde(default)] pub granted_scopes: Vec, + #[serde(default)] + pub token_received_at: Option, } /// Trait for storing and retrieving OAuth2 credentials @@ -943,34 +949,67 @@ impl AuthorizationManager { client_id, token_response: Some(token_result.clone()), granted_scopes, + token_received_at: Some(Self::now_epoch_secs()), }; self.credential_store.save(stored).await?; Ok(token_result) } + fn now_epoch_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + } + + /// Proactive refresh buffer: refresh tokens this many seconds before they expire + /// to avoid races between token retrieval and the actual HTTP request. + const REFRESH_BUFFER_SECS: u64 = 30; + /// get access token, if expired, refresh it automatically pub async fn get_access_token(&self) -> Result { - // Load credentials from store let stored = self.credential_store.load().await?; - let credentials = stored.and_then(|s| s.token_response); - - if let Some(creds) = credentials.as_ref() { - // check token expiry if we have a refresh token or an expiry time - if creds.refresh_token().is_some() || creds.expires_in().is_some() { - let expires_in = creds.expires_in().unwrap_or(Duration::from_secs(0)); - if expires_in <= Duration::from_secs(0) { - tracing::info!("Access token expired, refreshing."); - - let new_creds = self.refresh_token().await?; - tracing::info!("Refreshed access token."); - return Ok(new_creds.access_token().secret().to_string()); - } + let Some(stored_creds) = stored else { + return Err(AuthError::AuthorizationRequired); + }; + let Some(creds) = stored_creds.token_response.as_ref() else { + return Err(AuthError::AuthorizationRequired); + }; + + if let (Some(expires_in), Some(received_at)) = + (creds.expires_in(), stored_creds.token_received_at) + { + let elapsed = Self::now_epoch_secs().saturating_sub(received_at); + let remaining = expires_in.as_secs().saturating_sub(elapsed); + + if remaining < Self::REFRESH_BUFFER_SECS { + tracing::info!( + remaining_secs = remaining, + "Access token expired or nearly expired, refreshing." + ); + return self.try_refresh_or_reauth().await; } + } - Ok(creds.access_token().secret().to_string()) - } else { - Err(AuthError::AuthorizationRequired) + Ok(creds.access_token().secret().to_string()) + } + + /// Attempt to refresh the token. If refresh fails because there is no + /// refresh token or the server rejected it, return `AuthorizationRequired` + /// so the caller can re-prompt the user. Infrastructure errors (e.g. store + /// I/O failures, misconfigured client) are propagated as-is. + async fn try_refresh_or_reauth(&self) -> Result { + match self.refresh_token().await { + Ok(new_creds) => { + tracing::info!("Refreshed access token."); + Ok(new_creds.access_token().secret().to_string()) + } + Err(AuthError::AuthorizationRequired | AuthError::TokenRefreshFailed(_)) => { + tracing::warn!("Token refresh not possible, re-authorization required."); + Err(AuthError::AuthorizationRequired) + } + Err(e) => Err(e), } } @@ -999,10 +1038,10 @@ impl AuthorizationManager { .await .map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?; - let granted_scopes: Vec = token_result - .scopes() - .map(|scopes| scopes.iter().map(|s| s.to_string()).collect()) - .unwrap_or_else(|| self.current_scopes.blocking_read().clone()); + let granted_scopes: Vec = match token_result.scopes() { + Some(scopes) => scopes.iter().map(|s| s.to_string()).collect(), + None => self.current_scopes.read().await.clone(), + }; *self.current_scopes.write().await = granted_scopes.clone(); @@ -1011,6 +1050,7 @@ impl AuthorizationManager { client_id, token_response: Some(token_result.clone()), granted_scopes, + token_received_at: Some(Self::now_epoch_secs()), }; self.credential_store.save(stored).await?; @@ -1618,6 +1658,7 @@ impl OAuthState { client_id: client_id.to_string(), token_response: Some(credentials), granted_scopes, + token_received_at: Some(AuthorizationManager::now_epoch_secs()), }; manager.credential_store.save(stored).await?; @@ -2636,4 +2677,116 @@ mod tests { *manager.scope_upgrade_attempts.write().await = 1; assert!(manager.can_attempt_scope_upgrade().await); } + + // -- get_access_token -- + + fn make_token_response(access_token: &str, expires_in_secs: Option) -> OAuthTokenResponse { + use oauth2::{AccessToken, EmptyExtraTokenFields, basic::BasicTokenType}; + let mut resp = OAuthTokenResponse::new( + AccessToken::new(access_token.to_string()), + BasicTokenType::Bearer, + EmptyExtraTokenFields {}, + ); + if let Some(secs) = expires_in_secs { + resp.set_expires_in(Some(&std::time::Duration::from_secs(secs))); + } + resp + } + + use super::{OAuthTokenResponse, StoredCredentials}; + + #[tokio::test] + async fn get_access_token_returns_error_when_no_credentials() { + let manager = AuthorizationManager::new("http://localhost").await.unwrap(); + let err = manager.get_access_token().await.unwrap_err(); + assert!(matches!(err, AuthError::AuthorizationRequired)); + } + + #[tokio::test] + async fn get_access_token_returns_token_when_not_expired() { + let manager = AuthorizationManager::new("http://localhost").await.unwrap(); + let stored = StoredCredentials { + client_id: "test".to_string(), + token_response: Some(make_token_response("my-access-token", Some(3600))), + granted_scopes: vec![], + token_received_at: Some(AuthorizationManager::now_epoch_secs()), + }; + manager.credential_store.save(stored).await.unwrap(); + + let token = manager.get_access_token().await.unwrap(); + assert_eq!(token, "my-access-token"); + } + + #[tokio::test] + async fn get_access_token_requires_reauth_when_expired_and_no_refresh_token() { + let mut manager = manager_with_metadata(None).await; + manager.configure_client(test_client_config()).unwrap(); + + let stored = StoredCredentials { + client_id: "my-client".to_string(), + token_response: Some(make_token_response("stale-token", Some(3600))), + granted_scopes: vec![], + token_received_at: Some(AuthorizationManager::now_epoch_secs() - 7200), + }; + manager.credential_store.save(stored).await.unwrap(); + + let err = manager.get_access_token().await.unwrap_err(); + assert!( + matches!(err, AuthError::AuthorizationRequired), + "expected AuthorizationRequired when token is expired and refresh is impossible, got: {err:?}" + ); + } + + #[tokio::test] + async fn get_access_token_returns_token_without_expiry_info() { + let manager = AuthorizationManager::new("http://localhost").await.unwrap(); + let stored = StoredCredentials { + client_id: "test".to_string(), + token_response: Some(make_token_response("no-expiry-token", None)), + granted_scopes: vec![], + token_received_at: None, + }; + manager.credential_store.save(stored).await.unwrap(); + + let token = manager.get_access_token().await.unwrap(); + assert_eq!(token, "no-expiry-token"); + } + + #[tokio::test] + async fn get_access_token_requires_reauth_when_within_refresh_buffer() { + let mut manager = manager_with_metadata(None).await; + manager.configure_client(test_client_config()).unwrap(); + + let stored = StoredCredentials { + client_id: "my-client".to_string(), + token_response: Some(make_token_response("almost-expired", Some(3600))), + granted_scopes: vec![], + token_received_at: Some(AuthorizationManager::now_epoch_secs() - 3590), + }; + manager.credential_store.save(stored).await.unwrap(); + + let err = manager.get_access_token().await.unwrap_err(); + assert!( + matches!(err, AuthError::AuthorizationRequired), + "expected AuthorizationRequired when token is within refresh buffer, got: {err:?}" + ); + } + + #[tokio::test] + async fn get_access_token_propagates_internal_errors() { + let manager = AuthorizationManager::new("http://localhost").await.unwrap(); + let stored = StoredCredentials { + client_id: "test".to_string(), + token_response: Some(make_token_response("stale-token", Some(3600))), + granted_scopes: vec![], + token_received_at: Some(AuthorizationManager::now_epoch_secs() - 7200), + }; + manager.credential_store.save(stored).await.unwrap(); + + let err = manager.get_access_token().await.unwrap_err(); + assert!( + matches!(err, AuthError::InternalError(_)), + "expected InternalError when OAuth client is not configured, got: {err:?}" + ); + } }