diff --git a/src/webserver/oidc.rs b/src/webserver/oidc.rs index a6d83e20..21a40e24 100644 --- a/src/webserver/oidc.rs +++ b/src/webserver/oidc.rs @@ -1,8 +1,9 @@ use std::collections::HashSet; use std::future::ready; use std::rc::Rc; -use std::time::{Duration, Instant}; +use std::time::Duration; use std::{future::Future, pin::Pin, str::FromStr, sync::Arc}; +use tokio::time::Instant; use crate::webserver::http_client::get_http_client_from_appdata; use crate::{app_config::AppConfig, AppState}; @@ -34,7 +35,6 @@ use openidconnect::{ StandardTokenResponse, }; use serde::{Deserialize, Serialize}; -use tokio::sync::{RwLock, RwLockReadGuard}; use super::error::anyhow_err_to_actix_resp; use super::http_client::make_http_client; @@ -189,15 +189,21 @@ fn get_app_host(config: &AppConfig) -> String { host } -pub struct ClientWithTime { +/// A point-in-time snapshot of the OIDC provider's client and metadata. +/// Cheaply cloneable via Arc — callers never hold a lock while using this. +struct OidcSnapshot { client: OidcClient, end_session_endpoint: Option, - last_update: Instant, + created_at: Instant, } pub struct OidcState { pub config: OidcConfig, - client: RwLock, + /// Current snapshot. The lock is only held for the instant + /// needed to clone/swap the Arc — never across await points. + snapshot: std::sync::RwLock>, + /// Prevents concurrent background refreshes. + refresh_in_progress: std::sync::atomic::AtomicBool, } impl OidcState { @@ -207,63 +213,60 @@ impl OidcState { Ok(Self { config: oidc_cfg, - client: RwLock::new(ClientWithTime { + snapshot: std::sync::RwLock::new(Arc::new(OidcSnapshot { client, end_session_endpoint, - last_update: Instant::now(), - }), + created_at: Instant::now(), + })), + refresh_in_progress: std::sync::atomic::AtomicBool::new(false), }) } - async fn refresh(&self, service_request: &ServiceRequest) { - let mut write_guard = self.client.write().await; - match build_oidc_client_from_appdata(&self.config, service_request).await { - Ok((http_client, end_session_endpoint)) => { - *write_guard = ClientWithTime { - client: http_client, - end_session_endpoint, - last_update: Instant::now(), - } - } - Err(e) => log::error!("Failed to refresh OIDC client: {e:#}"), - } + /// Returns the current snapshot. Never blocks in practice. + fn snapshot(&self) -> Arc { + self.snapshot.read().unwrap().clone() } - /// Refreshes the OIDC client from the provider metadata URL if it has expired. - /// Most providers update their signing keys periodically. - pub async fn refresh_if_expired(&self, service_request: &ServiceRequest) { - if self.client.read().await.last_update.elapsed() > OIDC_CLIENT_MAX_REFRESH_INTERVAL { - self.refresh(service_request).await; + /// If the snapshot is older than `max_age` and no refresh is already running, + /// spawns a background task to fetch new provider metadata. + /// Returns immediately — never blocks the caller on I/O. + pub fn maybe_refresh(self: &Arc, http_client: &Client, max_age: Duration) { + use std::sync::atomic::Ordering; + if self.snapshot().created_at.elapsed() <= max_age { + return; } - } - - /// When an authentication error is encountered, refresh the OIDC client info faster - pub async fn refresh_on_error(&self, service_request: &ServiceRequest) { - if self.client.read().await.last_update.elapsed() > OIDC_CLIENT_MIN_REFRESH_INTERVAL { - self.refresh(service_request).await; + if self.refresh_in_progress.swap(true, Ordering::AcqRel) { + return; } + let state = Arc::clone(self); + let http_client = http_client.clone(); + tokio::task::spawn_local(async move { + match build_oidc_client(&state.config, &http_client).await { + Ok((client, end_session_endpoint)) => { + *state.snapshot.write().unwrap() = Arc::new(OidcSnapshot { + client, + end_session_endpoint, + created_at: Instant::now(), + }); + } + Err(e) => log::error!("Failed to refresh OIDC client: {e:#}"), + } + state.refresh_in_progress.store(false, Ordering::Release); + }); } - /// Gets a reference to the oidc client, potentially generating a new one if needed - pub async fn get_client(&self) -> RwLockReadGuard<'_, OidcClient> { - RwLockReadGuard::map( - self.client.read().await, - |ClientWithTime { client, .. }| client, - ) - } - - pub async fn get_end_session_endpoint(&self) -> Option { - self.client.read().await.end_session_endpoint.clone() + pub fn end_session_endpoint(&self) -> Option { + self.snapshot().end_session_endpoint.clone() } - /// Validate and decode the claims of an OIDC token, without refreshing the client. - async fn get_token_claims( + /// Validate and decode the claims of an OIDC token. + fn get_token_claims( &self, id_token: OidcToken, expected_nonce: &Nonce, ) -> anyhow::Result { - let client = &self.get_client().await; - let verifier = self.config.create_id_token_verifier(client); + let snapshot = self.snapshot(); + let verifier = self.config.create_id_token_verifier(&snapshot.client); let nonce_verifier = |nonce: Option<&Nonce>| check_nonce(nonce, expected_nonce); let claims: OidcClaims = id_token .into_claims(&verifier, nonce_verifier) @@ -271,13 +274,14 @@ impl OidcState { Ok(claims) } - /// Builds an absolute redirect URI by joining the relative redirect URI with the client's redirect URL - pub async fn build_absolute_redirect_uri( + /// Builds an absolute redirect URI from the client's configured redirect URL. + pub fn build_absolute_redirect_uri( &self, relative_redirect_uri: &str, ) -> anyhow::Result { - let client_guard = self.get_client().await; - let client_redirect_url = client_guard + let snapshot = self.snapshot(); + let client_redirect_url = snapshot + .client .redirect_uri() .ok_or_else(|| anyhow!("OIDC client has no redirect URL configured"))?; let absolute_redirect_uri = client_redirect_url @@ -309,14 +313,6 @@ pub async fn initialize_oidc_state( ))) } -async fn build_oidc_client_from_appdata( - cfg: &OidcConfig, - req: &ServiceRequest, -) -> anyhow::Result<(OidcClient, Option)> { - let http_client = get_http_client_from_appdata(req)?; - build_oidc_client(cfg, http_client).await -} - async fn build_oidc_client( oidc_cfg: &OidcConfig, http_client: &Client, @@ -405,9 +401,15 @@ enum MiddlewareResponse { Respond(ServiceResponse), } -async fn handle_request(oidc_state: &OidcState, request: ServiceRequest) -> MiddlewareResponse { +async fn handle_request( + oidc_state: &Arc, + request: ServiceRequest, +) -> MiddlewareResponse { log::trace!("Started OIDC middleware request handling"); - oidc_state.refresh_if_expired(&request).await; + let http_client = get_http_client_from_appdata(&request).ok(); + if let Some(c) = http_client { + oidc_state.maybe_refresh(c, OIDC_CLIENT_MAX_REFRESH_INTERVAL); + } if request.path() == oidc_state.config.redirect_uri { let response = handle_oidc_callback(oidc_state, request).await; @@ -415,11 +417,11 @@ async fn handle_request(oidc_state: &OidcState, request: ServiceRequest) -> Midd } if request.path() == oidc_state.config.logout_uri { - let response = handle_oidc_logout(oidc_state, request).await; + let response = handle_oidc_logout(oidc_state, request); return MiddlewareResponse::Respond(response); } - match get_authenticated_user_info(oidc_state, &request).await { + match get_authenticated_user_info(oidc_state, &request) { Ok(Some(claims)) => { log::trace!("Storing authenticated user info in request extensions: {claims:?}"); request.extensions_mut().insert(claims); @@ -427,17 +429,19 @@ async fn handle_request(oidc_state: &OidcState, request: ServiceRequest) -> Midd } Ok(None) => { log::trace!("No authenticated user found"); - handle_unauthenticated_request(oidc_state, request).await + handle_unauthenticated_request(oidc_state, request) } Err(e) => { log::debug!("An auth cookie is present but could not be verified. Redirecting to OIDC provider to re-authenticate. {e:?}"); - oidc_state.refresh_on_error(&request).await; - handle_unauthenticated_request(oidc_state, request).await + if let Some(c) = http_client { + oidc_state.maybe_refresh(c, OIDC_CLIENT_MIN_REFRESH_INTERVAL); + } + handle_unauthenticated_request(oidc_state, request) } } } -async fn handle_unauthenticated_request( +fn handle_unauthenticated_request( oidc_state: &OidcState, request: ServiceRequest, ) -> MiddlewareResponse { @@ -451,35 +455,39 @@ async fn handle_unauthenticated_request( let initial_url = request.uri().to_string(); let redirect_count = get_redirect_count(&request); - let response = - build_auth_provider_redirect_response(oidc_state, &initial_url, redirect_count).await; + let response = build_auth_provider_redirect_response(oidc_state, &initial_url, redirect_count); MiddlewareResponse::Respond(request.into_response(response)) } -async fn handle_oidc_callback(oidc_state: &OidcState, request: ServiceRequest) -> ServiceResponse { +async fn handle_oidc_callback( + oidc_state: &Arc, + request: ServiceRequest, +) -> ServiceResponse { match process_oidc_callback(oidc_state, &request).await { Ok(mut response) => { clear_redirect_count_cookie(&mut response); request.into_response(response) } - Err(e) => handle_oidc_callback_error(oidc_state, request, e).await, + Err(e) => handle_oidc_callback_error(oidc_state, request, &e), } } -async fn handle_oidc_callback_error( - oidc_state: &OidcState, +fn handle_oidc_callback_error( + oidc_state: &Arc, request: ServiceRequest, - e: anyhow::Error, + e: &anyhow::Error, ) -> ServiceResponse { let redirect_count = get_redirect_count(&request); if redirect_count >= MAX_OIDC_REDIRECTS { - return handle_max_redirect_count_reached(request, &e, redirect_count); + return handle_max_redirect_count_reached(request, e, redirect_count); } log::error!( "Failed to process OIDC callback (attempt {redirect_count}). Refreshing oidc provider metadata, then redirecting to home page: {e:#}" ); - oidc_state.refresh_on_error(&request).await; - let resp = build_auth_provider_redirect_response(oidc_state, "/", redirect_count).await; + if let Ok(http_client) = get_http_client_from_appdata(&request) { + oidc_state.maybe_refresh(http_client, OIDC_CLIENT_MIN_REFRESH_INTERVAL); + } + let resp = build_auth_provider_redirect_response(oidc_state, "/", redirect_count); request.into_response(resp) } @@ -496,8 +504,8 @@ fn handle_max_redirect_count_reached( request.into_response(resp) } -async fn handle_oidc_logout(oidc_state: &OidcState, request: ServiceRequest) -> ServiceResponse { - match process_oidc_logout(oidc_state, &request).await { +fn handle_oidc_logout(oidc_state: &OidcState, request: ServiceRequest) -> ServiceResponse { + match process_oidc_logout(oidc_state, &request) { Ok(response) => request.into_response(response), Err(e) => { log::error!("Failed to process OIDC logout: {e:#}"); @@ -525,7 +533,7 @@ fn parse_logout_params(query: &str) -> anyhow::Result { .map(Query::into_inner) } -async fn process_oidc_logout( +fn process_oidc_logout( oidc_state: &OidcState, request: &ServiceRequest, ) -> anyhow::Result { @@ -541,34 +549,31 @@ async fn process_oidc_logout( .ok() .flatten(); - let mut response = - if let Some(end_session_endpoint) = oidc_state.get_end_session_endpoint().await { - let absolute_redirect_uri = oidc_state - .build_absolute_redirect_uri(¶ms.redirect_uri) - .await?; + let mut response = if let Some(end_session_endpoint) = oidc_state.end_session_endpoint() { + let absolute_redirect_uri = oidc_state.build_absolute_redirect_uri(¶ms.redirect_uri)?; - let post_logout_redirect_uri = - PostLogoutRedirectUrl::new(absolute_redirect_uri.clone()).with_context(|| { - format!("Invalid post_logout_redirect_uri: {absolute_redirect_uri}") - })?; + let post_logout_redirect_uri = PostLogoutRedirectUrl::new(absolute_redirect_uri.clone()) + .with_context(|| { + format!("Invalid post_logout_redirect_uri: {absolute_redirect_uri}") + })?; - let mut logout_request = LogoutRequest::from(end_session_endpoint) - .set_post_logout_redirect_uri(post_logout_redirect_uri); + let mut logout_request = LogoutRequest::from(end_session_endpoint) + .set_post_logout_redirect_uri(post_logout_redirect_uri); - if let Some(ref token) = id_token { - logout_request = logout_request.set_id_token_hint(token); - } + if let Some(ref token) = id_token { + logout_request = logout_request.set_id_token_hint(token); + } - let logout_url = logout_request.http_get_url(); - log::info!("Redirecting to OIDC logout URL: {logout_url}"); - build_redirect_response(logout_url.to_string()) - } else { - log::info!( - "No end_session_endpoint, redirecting to {}", - params.redirect_uri - ); - build_redirect_response(params.redirect_uri) - }; + let logout_url = logout_request.http_get_url(); + log::info!("Redirecting to OIDC logout URL: {logout_url}"); + build_redirect_response(logout_url.to_string()) + } else { + log::info!( + "No end_session_endpoint, redirecting to {}", + params.redirect_uri + ); + build_redirect_response(params.redirect_uri) + }; response.add_removal_cookie( &Cookie::build(SQLPAGE_AUTH_COOKIE_NAME, "") @@ -663,9 +668,9 @@ async fn process_oidc_callback( .into_inner(); log::debug!("Processing OIDC callback with params: {params:?}. Requesting token..."); let mut tmp_login_flow_state_cookie = get_tmp_login_flow_state_cookie(request, ¶ms.state)?; - let client = oidc_state.get_client().await; + let snapshot = oidc_state.snapshot(); let http_client = get_http_client_from_appdata(request)?; - let id_token = exchange_code_for_token(&client, http_client, params.clone()).await?; + let id_token = exchange_code_for_token(&snapshot.client, http_client, params.clone()).await?; log::debug!("Received token response: {id_token:?}"); let LoginFlowState { nonce, @@ -679,7 +684,6 @@ async fn process_oidc_callback( set_auth_cookie(&mut response, &id_token); let claims = oidc_state .get_token_claims(id_token, &nonce) - .await .context("The identity provider returned an invalid ID token")?; log::debug!("{} successfully logged in", claims.subject().as_str()); let nonce_cookie = create_final_nonce_cookie(&nonce); @@ -730,12 +734,12 @@ fn set_auth_cookie(response: &mut HttpResponse, id_token: &OidcToken) { response.add_cookie(&cookie).unwrap(); } -async fn build_auth_provider_redirect_response( +fn build_auth_provider_redirect_response( oidc_state: &OidcState, initial_url: &str, redirect_count: u8, ) -> HttpResponse { - let AuthUrl { url, params } = build_auth_url(oidc_state).await; + let AuthUrl { url, params } = build_auth_url(oidc_state); let tmp_login_flow_state_cookie = create_tmp_login_flow_state_cookie(¶ms, initial_url); let redirect_count_cookie = Cookie::build( SQLPAGE_OIDC_REDIRECT_COUNT_COOKIE, @@ -782,7 +786,7 @@ fn build_oidc_error_response(request: &ServiceRequest, e: &anyhow::Error) -> Htt } /// Returns the claims from the ID token in the `SQLPage` auth cookie. -async fn get_authenticated_user_info( +fn get_authenticated_user_info( oidc_state: &OidcState, request: &ServiceRequest, ) -> anyhow::Result> { @@ -795,7 +799,7 @@ async fn get_authenticated_user_info( let nonce = get_final_nonce_from_cookie(request)?; log::debug!("Verifying id token: {id_token:?}"); - let claims = oidc_state.get_token_claims(id_token, &nonce).await?; + let claims = oidc_state.get_token_claims(id_token, &nonce)?; log::debug!("The current user is: {claims:?}"); Ok(Some(claims)) } @@ -963,12 +967,13 @@ struct AuthUrlParams { nonce: Nonce, } -async fn build_auth_url(oidc_state: &OidcState) -> AuthUrl { +fn build_auth_url(oidc_state: &OidcState) -> AuthUrl { let nonce_source = Nonce::new_random(); let hashed_nonce = Nonce::new(hash_nonce(&nonce_source)); let scopes = &oidc_state.config.scopes; - let client_lock = oidc_state.get_client().await; - let (url, csrf_token, _nonce) = client_lock + let snapshot = oidc_state.snapshot(); + let (url, csrf_token, _nonce) = snapshot + .client .authorize_url( CoreAuthenticationFlow::AuthorizationCode, CsrfToken::new_random, diff --git a/tests/oidc/mod.rs b/tests/oidc/mod.rs index 027fe6ac..9375cff4 100644 --- a/tests/oidc/mod.rs +++ b/tests/oidc/mod.rs @@ -52,6 +52,7 @@ struct ProviderState<'a> { auth_codes: HashMap, // code -> nonce jwt_customizer: Option>>, token_endpoint_delay: Duration, + discovery_count: usize, } type ProviderStateWithLifetime<'a> = ProviderState<'a>; @@ -83,7 +84,8 @@ struct TokenResponse { } async fn discovery_endpoint(state: Data) -> impl Responder { - let state = state.lock().unwrap(); + let mut state = state.lock().unwrap(); + state.discovery_count += 1; let discovery = DiscoveryResponse { issuer: state.issuer_url.clone(), authorization_endpoint: format!("{}/auth", state.issuer_url), @@ -94,6 +96,7 @@ async fn discovery_endpoint(state: Data) -> impl Responder id_token_signing_alg_values_supported: vec!["HS256".to_string()], end_session_endpoint: format!("{}/logout", state.issuer_url), }; + drop(state); HttpResponse::Ok() .insert_header((header::CONTENT_TYPE, "application/json")) .json(discovery) @@ -196,6 +199,7 @@ impl FakeOidcProvider { auth_codes: HashMap::new(), jwt_customizer: None, token_endpoint_delay: Duration::ZERO, + discovery_count: 0, })); let state_for_server = Arc::clone(&state); @@ -241,6 +245,10 @@ impl FakeOidcProvider { self.with_state_mut(|s| s.token_endpoint_delay = delay); } + pub fn discovery_count(&self) -> usize { + self.state.lock().unwrap().discovery_count + } + pub fn store_auth_code(&self, code: String, nonce: String) { self.with_state_mut(|s| { s.auth_codes.insert(code, nonce); @@ -556,8 +564,54 @@ async fn test_oidc_logout_uses_correct_scheme() { assert_eq!(post_logout, "https://example.com/logged_out"); } -/// A slow OIDC provider must not freeze the server. -/// See https://github.com/sqlpage/SQLPage/issues/1231 +/// An OIDC provider metadata refresh must not block authenticated requests. +/// The refresh should happen in the background while existing requests are +/// served using the current (possibly stale) OIDC client. +#[actix_web::test] +async fn test_slow_discovery_does_not_block_authenticated_requests() { + let (app, provider) = setup_oidc_test(|_| {}).await; + let mut cookies: Vec> = Vec::new(); + + // Complete a full login to get auth cookies + let resp = request_with_cookies!(app, test::TestRequest::get().uri("/"), cookies); + assert_eq!(resp.status(), StatusCode::SEE_OTHER); + let auth_url = Url::parse(resp.headers().get("location").unwrap().to_str().unwrap()).unwrap(); + let state_param = get_query_param(&auth_url, "state"); + let nonce = get_query_param(&auth_url, "nonce"); + let redirect_uri = get_query_param(&auth_url, "redirect_uri"); + provider.store_auth_code("test_auth_code".to_string(), nonce); + let callback_uri = format!( + "{}?code=test_auth_code&state={}", + Url::parse(&redirect_uri).unwrap().path(), + state_param + ); + let callback_resp = + request_with_cookies!(app, test::TestRequest::get().uri(&callback_uri), cookies); + assert_eq!(callback_resp.status(), StatusCode::SEE_OTHER); + + // Advance time so the OIDC snapshot appears stale. + // The next request triggers a background refresh. + let count_before = provider.discovery_count(); + tokio::time::pause(); + tokio::time::advance(Duration::from_secs(3601)).await; + // Resume real time so the DB pool and background refresh work normally. + tokio::time::resume(); + + // An authenticated request must succeed immediately, even though + // it triggers a background refresh. + let resp = request_with_cookies!(app, test::TestRequest::get().uri("/"), cookies); + assert_eq!(resp.status(), StatusCode::OK); + + // Let the background refresh task complete. + tokio::task::yield_now().await; + assert!( + provider.discovery_count() > count_before, + "OIDC provider metadata was not refreshed" + ); +} + +/// A slow OIDC token endpoint must not freeze the server. +/// The body-read timeout fires and the request completes with a redirect. #[actix_web::test] async fn test_slow_token_endpoint_does_not_freeze_server() { let (app, provider) = setup_oidc_test(|_| {}).await; @@ -587,15 +641,12 @@ async fn test_slow_token_endpoint_does_not_freeze_server() { test::call_service(&app, req.to_request()).await }); - // Let the localhost TCP round-trip complete so awc reads response headers. + // Let the TCP round-trip complete so awc reads HTTP headers, + // then advance past the body-read timeout. tokio::time::sleep(Duration::from_millis(50)).await; - - // Freeze time and advance past the body-read timeout. tokio::time::pause(); tokio::time::advance(Duration::from_secs(60)).await; - // The body timeout should have fired, completing the request with an error - // that SQLPage handles by redirecting to the OIDC provider. let resp = tokio::time::timeout(Duration::from_secs(1), handle) .await .expect("OIDC callback hung on a slow token endpoint")