From 22cfcc21a1b99b4c4c3690ecb2827ca7f90d1225 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Fri, 20 Feb 2026 03:17:27 +0000 Subject: [PATCH] [Identity] Add background token refresh to async creds Signed-off-by: Paul Van Eck --- .../aio/_credentials/authorization_code.py | 2 +- .../identity/aio/_credentials/certificate.py | 2 +- .../aio/_credentials/client_assertion.py | 1 + .../aio/_credentials/client_secret.py | 2 +- .../azure/identity/aio/_credentials/imds.py | 1 + .../identity/aio/_credentials/on_behalf_of.py | 1 + .../identity/aio/_internal/get_token_mixin.py | 77 ++++- .../aio/_internal/managed_identity_base.py | 1 + .../tests/test_get_token_mixin_async.py | 267 +++++++++++++++++- 9 files changed, 341 insertions(+), 13 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py index 623a7b5ba095..6190effcca52 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/authorization_code.py @@ -47,7 +47,7 @@ async def __aenter__(self) -> "AuthorizationCodeCredential": async def close(self) -> None: """Close the credential's transport session.""" - + self._cancel_background_refresh_tasks() if self._client: await self._client.__aexit__() diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py index 8b19361c11e0..dd44f28b78d7 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/certificate.py @@ -67,7 +67,7 @@ async def __aenter__(self) -> "CertificateCredential": async def close(self) -> None: """Close the credential's transport session.""" - + self._cancel_background_refresh_tasks() await self._client.__aexit__() async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py index a316760455e1..1c5540a45602 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py @@ -64,6 +64,7 @@ async def __aenter__(self) -> "ClientAssertionCredential": async def close(self) -> None: """Close the credential's transport session.""" + self._cancel_background_refresh_tasks() await self._client.close() async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py index 50bbb3de9315..e8e8d41e6924 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_secret.py @@ -57,7 +57,7 @@ async def __aenter__(self) -> "ClientSecretCredential": async def close(self) -> None: """Close the credential's transport session.""" - + self._cancel_background_refresh_tasks() await self._client.__aexit__() async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py index 0fb758c3026c..38279502c7f2 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/imds.py @@ -62,6 +62,7 @@ async def __aenter__(self) -> "ImdsCredential": return self async def close(self) -> None: + self._cancel_background_refresh_tasks() await self._client.close() async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py index e3d1c8e47b44..3cf0d0aed0e5 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/on_behalf_of.py @@ -110,6 +110,7 @@ async def __aenter__(self) -> "OnBehalfOfCredential": async def close(self) -> None: """Close the credential's underlying HTTP client.""" + self._cancel_background_refresh_tasks() await self._client.close() async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py index 8b383fef9818..a838711202d6 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/get_token_mixin.py @@ -3,20 +3,25 @@ # Licensed under the MIT License. # ------------------------------------ import abc +import asyncio # pylint: disable=do-not-import-asyncio import logging import time -from typing import Any, Optional +from typing import Any, Dict, Optional, Tuple from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions from ..._constants import DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY from ..._internal import within_credential_chain _LOGGER = logging.getLogger(__name__) +_BACKGROUND_REFRESH_MIN_VALIDITY_SECONDS = 600 + +_BackgroundRefreshKey = Tuple[Tuple[str, ...], Optional[str], Optional[str], bool] class GetTokenMixin(abc.ABC): def __init__(self, *args: Any, **kwargs: Any) -> None: self._last_request_time = 0 + self._background_refresh_tasks: Dict[_BackgroundRefreshKey, "asyncio.Task[None]"] = {} # https://github.com/python/mypy/issues/5887 super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore @@ -45,6 +50,50 @@ async def _request_token(self, *scopes: str, **kwargs) -> AccessTokenInfo: :rtype: ~azure.core.credentials.AccessTokenInfo """ + @staticmethod + def _uses_asyncio() -> bool: + try: + asyncio.get_running_loop() + return True + except RuntimeError: + return False + + def _start_background_refresh(self, key: _BackgroundRefreshKey, *scopes: str, **kwargs: Any) -> None: + existing_task = self._background_refresh_tasks.get(key) + if existing_task is not None and not existing_task.done(): + return + + task = asyncio.create_task(self._background_refresh(*scopes, **kwargs)) + self._background_refresh_tasks[key] = task + + def _cleanup(done_task: "asyncio.Task[None]") -> None: + if self._background_refresh_tasks.get(key) is done_task: + self._background_refresh_tasks.pop(key, None) + + task.add_done_callback(_cleanup) + + async def _background_refresh(self, *scopes: str, **kwargs: Any) -> None: + try: + await self._request_token(*scopes, **kwargs) + except Exception as ex: # pylint:disable=broad-except + _LOGGER.debug("Background token refresh failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG)) + + def _cancel_background_refresh_tasks(self) -> None: + """Cancel all pending background refresh tasks. + + Credentials should call this from their ``close`` method to avoid tasks + running against a closed HTTP transport. + """ + tasks = list(self._background_refresh_tasks.values()) + self._background_refresh_tasks.clear() + for task in tasks: + task.cancel() + + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + state["_background_refresh_tasks"] = {} + return state + def _should_refresh(self, token: AccessTokenInfo) -> bool: now = int(time.time()) if token.refresh_on is not None and now >= token.refresh_on: @@ -132,19 +181,31 @@ async def _get_token_base( token = await self._acquire_token_silently( *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs ) + now = int(time.time()) if not token: - self._last_request_time = int(time.time()) + self._last_request_time = now token = await self._request_token( *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs ) elif self._should_refresh(token): - try: - self._last_request_time = int(time.time()) - token = await self._request_token( - *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs + self._last_request_time = now + if self._uses_asyncio() and token.expires_on - now >= _BACKGROUND_REFRESH_MIN_VALIDITY_SECONDS: + # Token has a certain remaining validity; refresh in the background and return the cached token. + self._start_background_refresh( + (scopes, claims, tenant_id, enable_cae), + *scopes, + claims=claims, + tenant_id=tenant_id, + enable_cae=enable_cae, + **kwargs, ) - except Exception: # pylint:disable=broad-except - pass + else: + try: + token = await self._request_token( + *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs + ) + except Exception: # pylint:disable=broad-except + pass _LOGGER.log( logging.DEBUG if within_credential_chain.get() else logging.INFO, "%s.%s succeeded", diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py index e07403a1982c..cd29c0b5db20 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/managed_identity_base.py @@ -45,6 +45,7 @@ async def __aexit__( await self._client.__aexit__(exc_type, exc_value, traceback) async def close(self) -> None: + self._cancel_background_refresh_tasks() await self.__aexit__() async def get_token( diff --git a/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py b/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py index 4258a4684148..694ac6169f32 100644 --- a/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py +++ b/sdk/identity/azure-identity/tests/test_get_token_mixin_async.py @@ -2,6 +2,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import asyncio +import concurrent.futures +import pickle import time from unittest import mock @@ -9,6 +12,7 @@ import pytest from azure.identity._constants import DEFAULT_REFRESH_OFFSET +from azure.identity.aio._internal import AsyncContextManager from azure.identity.aio._internal.get_token_mixin import GetTokenMixin from helpers import GET_TOKEN_METHODS @@ -16,7 +20,7 @@ pytestmark = pytest.mark.asyncio -class MockCredential(GetTokenMixin): +class MockCredential(AsyncContextManager, GetTokenMixin): NEW_TOKEN = AccessTokenInfo("new token", 42) def __init__(self, cached_token=None): @@ -25,6 +29,9 @@ def __init__(self, cached_token=None): self.request_token = mock.Mock(return_value=MockCredential.NEW_TOKEN) self.acquire_token_silently = mock.Mock(return_value=cached_token) + async def close(self) -> None: + self._cancel_background_refresh_tasks() + async def _acquire_token_silently(self, *scopes, **kwargs): return self.acquire_token_silently(*scopes, **kwargs) @@ -107,7 +114,7 @@ async def test_cached_token_outside_refresh_window(get_token_method): @pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) async def test_cached_token_within_refresh_window(get_token_method): - """A credential should request a new token when its cached one is within the refresh window""" + """A credential should request a new token inline when its cached one is within the refresh window""" credential = MockCredential( cached_token=AccessTokenInfo(CACHED_TOKEN, int(time.time() + DEFAULT_REFRESH_OFFSET - 1)) @@ -133,3 +140,259 @@ async def test_retry_delay(get_token_method): assert token.token == CACHED_TOKEN credential.acquire_token_silently.assert_called_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_background_refresh_does_not_block(get_token_method): + """Background refresh should return the cached token immediately while refreshing in the background""" + + now = int(time.time()) + # Token has plenty of time before expiry but refresh_on has passed, triggering background refresh + credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, now + 3600, refresh_on=now - 1)) + + token = await getattr(credential, get_token_method)(SCOPE) + + # The cached token is returned immediately, before the background task completes + assert token.token == CACHED_TOKEN + # Wait for background refresh task to complete + key = ((SCOPE,), None, None, False) + assert key in credential._background_refresh_tasks + await asyncio.gather(credential._background_refresh_tasks[key], return_exceptions=True) + credential.request_token.assert_called_once() + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_background_refresh_in_trio(get_token_method): + """When not in asyncio (e.g., trio), refresh should happen inline""" + + now = int(time.time()) + credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, now + 3600, refresh_on=now - 1)) + + # Simulate a non-asyncio environment by making _uses_asyncio return False + with mock.patch.object(type(credential), "_uses_asyncio", return_value=False): + token = await getattr(credential, get_token_method)(SCOPE) + + # Inline refresh returns the new token + assert token.token == MockCredential.NEW_TOKEN.token + credential.request_token.assert_called_once_with(SCOPE, claims=None, enable_cae=False, tenant_id=None) + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_no_duplicate_background_refresh(get_token_method): + """If a background refresh is already in progress, a new one should not be started""" + + refresh_started = asyncio.Event() + refresh_continue = asyncio.Event() + + class SlowMockCredential(MockCredential): + async def _request_token(self, *scopes, **kwargs): + refresh_started.set() + await refresh_continue.wait() + return self.request_token(*scopes, **kwargs) + + credential = SlowMockCredential( + cached_token=AccessTokenInfo(CACHED_TOKEN, int(time.time() + 3600), refresh_on=int(time.time()) - 1) + ) + + # First call triggers a background refresh + token1 = await getattr(credential, get_token_method)(SCOPE) + assert token1.token == CACHED_TOKEN + key = ((SCOPE,), None, None, False) + assert key in credential._background_refresh_tasks + first_refresh_task = credential._background_refresh_tasks[key] + + # Wait for the background task to start + await refresh_started.wait() + + # Second call while the first refresh is still in progress should NOT start another task + token2 = await getattr(credential, get_token_method)(SCOPE) + assert token2.token == CACHED_TOKEN + assert credential._background_refresh_tasks[key] is first_refresh_task + + # Let the background refresh complete + refresh_continue.set() + await asyncio.gather(first_refresh_task, return_exceptions=True) + + # _request_token should have been called exactly once (no duplicate) + credential.request_token.assert_called_once() + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_different_requests_do_not_share_background_refresh(get_token_method): + """Different request keys should be allowed to refresh concurrently in the background""" + + scope_1 = "scope-1" + scope_2 = "scope-2" + refresh_started_scope_1 = asyncio.Event() + refresh_started_scope_2 = asyncio.Event() + refresh_continue = asyncio.Event() + + class SlowMockCredential(MockCredential): + async def _request_token(self, *scopes, **kwargs): + if scopes[0] == scope_1: + refresh_started_scope_1.set() + elif scopes[0] == scope_2: + refresh_started_scope_2.set() + await refresh_continue.wait() + return self.request_token(*scopes, **kwargs) + + credential = SlowMockCredential( + cached_token=AccessTokenInfo(CACHED_TOKEN, int(time.time() + 3600), refresh_on=int(time.time()) - 1) + ) + + token1 = await getattr(credential, get_token_method)(scope_1) + assert token1.token == CACHED_TOKEN + await refresh_started_scope_1.wait() + + token2 = await getattr(credential, get_token_method)(scope_2) + assert token2.token == CACHED_TOKEN + await refresh_started_scope_2.wait() + + key_1 = ((scope_1,), None, None, False) + key_2 = ((scope_2,), None, None, False) + assert key_1 in credential._background_refresh_tasks + assert key_2 in credential._background_refresh_tasks + assert credential._background_refresh_tasks[key_1] is not credential._background_refresh_tasks[key_2] + + refresh_continue.set() + await asyncio.gather( + credential._background_refresh_tasks[key_1], + credential._background_refresh_tasks[key_2], + return_exceptions=True, + ) + assert credential.request_token.call_count == 2 + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_background_refresh_failure_returns_cached_token(get_token_method): + """If the background refresh fails, the caller should still get the cached token""" + + now = int(time.time()) + credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, now + 3600, refresh_on=now - 1)) + credential.request_token = mock.Mock(side_effect=Exception("transient error")) + + token = await getattr(credential, get_token_method)(SCOPE) + assert token.token == CACHED_TOKEN + + key = ((SCOPE,), None, None, False) + task = credential._background_refresh_tasks.get(key) + assert task is not None + await asyncio.gather(task, return_exceptions=True) + + # request_token was called and raised, but the caller was not affected + credential.request_token.assert_called_once() + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_background_refresh_task_cleanup(get_token_method): + """After a background refresh completes, its task should be removed from the dict""" + + now = int(time.time()) + credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, now + 3600, refresh_on=now - 1)) + + await getattr(credential, get_token_method)(SCOPE) + key = ((SCOPE,), None, None, False) + assert key in credential._background_refresh_tasks + + task = credential._background_refresh_tasks[key] + await asyncio.gather(task, return_exceptions=True) + # Allow the done-callback to run + await asyncio.sleep(0) + assert key not in credential._background_refresh_tasks + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_cancel_background_refresh_tasks(get_token_method): + """_cancel_background_refresh_tasks should cancel in-flight tasks (called on close)""" + + refresh_continue = asyncio.Event() + + class SlowMockCredential(MockCredential): + async def _request_token(self, *scopes, **kwargs): + await refresh_continue.wait() + return self.request_token(*scopes, **kwargs) + + now = int(time.time()) + credential = SlowMockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, now + 3600, refresh_on=now - 1)) + + await getattr(credential, get_token_method)(SCOPE) + key = ((SCOPE,), None, None, False) + task = credential._background_refresh_tasks[key] + assert not task.done() + + credential._cancel_background_refresh_tasks() + assert len(credential._background_refresh_tasks) == 0 + # Wait for the cancellation to fully propagate + await asyncio.gather(task, return_exceptions=True) + assert task.cancelled() + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_close_cancels_background_refresh(get_token_method): + """Calling close() directly (without async with) should cancel background tasks""" + + refresh_continue = asyncio.Event() + + class SlowMockCredential(MockCredential): + async def _request_token(self, *scopes, **kwargs): + await refresh_continue.wait() + return self.request_token(*scopes, **kwargs) + + now = int(time.time()) + credential = SlowMockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, now + 3600, refresh_on=now - 1)) + + await getattr(credential, get_token_method)(SCOPE) + key = ((SCOPE,), None, None, False) + task = credential._background_refresh_tasks[key] + assert not task.done() + + await credential.close() + assert len(credential._background_refresh_tasks) == 0 + await asyncio.gather(task, return_exceptions=True) + assert task.cancelled() + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_pickle_with_active_background_refresh(get_token_method): + """A credential should be picklable even when a background refresh task is active""" + + now = int(time.time()) + credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, now + 3600, refresh_on=now - 1)) + + await getattr(credential, get_token_method)(SCOPE) + key = ((SCOPE,), None, None, False) + task = credential._background_refresh_tasks.get(key) + assert task is not None + + # Mock attributes set to None since Mock objects aren't picklable. + credential.request_token = None + credential.acquire_token_silently = None + pickled = pickle.dumps(credential) + restored = pickle.loads(pickled) + assert restored._background_refresh_tasks == {} + assert restored._last_request_time == credential._last_request_time + + # Clean up + await asyncio.gather(task, return_exceptions=True) + + +@pytest.mark.parametrize("get_token_method", GET_TOKEN_METHODS) +async def test_background_refresh_multithread_event_loops(get_token_method): + """Background refresh should work when multiple threads each run their own event loop.""" + + now = int(time.time()) + credential = MockCredential(cached_token=AccessTokenInfo(CACHED_TOKEN, now + 3600, refresh_on=now - 1)) + + def run_in_thread(): + async def get_token(): + return await getattr(credential, get_token_method)(SCOPE) + + return asyncio.run(get_token()) + + num_threads = 4 + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(run_in_thread) for _ in range(num_threads)] + results = [f.result() for f in concurrent.futures.as_completed(futures)] + + for token in results: + assert token.token == CACHED_TOKEN