Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def __aenter__(self) -> "AuthorizationCodeCredential":

async def close(self) -> None:
"""Close the credential's transport session."""

self._cancel_background_refresh_tasks()
if self._client:
await self._client.__aexit__()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def __aenter__(self) -> "CertificateCredential":

async def close(self) -> None:
"""Close the credential's transport session."""

self._cancel_background_refresh_tasks()
await self._client.__aexit__()

async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ async def __aenter__(self) -> "ClientAssertionCredential":

async def close(self) -> None:
"""Close the credential's transport session."""
self._cancel_background_refresh_tasks()
await self._client.close()

async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ async def __aenter__(self) -> "ClientSecretCredential":

async def close(self) -> None:
"""Close the credential's transport session."""

self._cancel_background_refresh_tasks()
await self._client.__aexit__()

async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ async def __aenter__(self) -> "ImdsCredential":
return self

async def close(self) -> None:
self._cancel_background_refresh_tasks()
await self._client.close()

async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ async def __aenter__(self) -> "OnBehalfOfCredential":

async def close(self) -> None:
"""Close the credential's underlying HTTP client."""
self._cancel_background_refresh_tasks()
await self._client.close()

async def _acquire_token_silently(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenInfo]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,25 @@
# Licensed under the MIT License.
# ------------------------------------
import abc
import asyncio # pylint: disable=do-not-import-asyncio
import logging
import time
from typing import Any, Optional
from typing import Any, Dict, Optional, Tuple

from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions
from ..._constants import DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY
from ..._internal import within_credential_chain

_LOGGER = logging.getLogger(__name__)
_BACKGROUND_REFRESH_MIN_VALIDITY_SECONDS = 600

_BackgroundRefreshKey = Tuple[Tuple[str, ...], Optional[str], Optional[str], bool]


class GetTokenMixin(abc.ABC):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self._last_request_time = 0
self._background_refresh_tasks: Dict[_BackgroundRefreshKey, "asyncio.Task[None]"] = {}

# https://github.com/python/mypy/issues/5887
super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore
Expand Down Expand Up @@ -45,6 +50,50 @@ async def _request_token(self, *scopes: str, **kwargs) -> AccessTokenInfo:
:rtype: ~azure.core.credentials.AccessTokenInfo
"""

@staticmethod
def _uses_asyncio() -> bool:
try:
asyncio.get_running_loop()
return True
except RuntimeError:
return False

def _start_background_refresh(self, key: _BackgroundRefreshKey, *scopes: str, **kwargs: Any) -> None:
existing_task = self._background_refresh_tasks.get(key)
if existing_task is not None and not existing_task.done():
return

task = asyncio.create_task(self._background_refresh(*scopes, **kwargs))
self._background_refresh_tasks[key] = task

def _cleanup(done_task: "asyncio.Task[None]") -> None:
if self._background_refresh_tasks.get(key) is done_task:
self._background_refresh_tasks.pop(key, None)

task.add_done_callback(_cleanup)

async def _background_refresh(self, *scopes: str, **kwargs: Any) -> None:
try:
await self._request_token(*scopes, **kwargs)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.debug("Background token refresh failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))

def _cancel_background_refresh_tasks(self) -> None:
"""Cancel all pending background refresh tasks.

Credentials should call this from their ``close`` method to avoid tasks
running against a closed HTTP transport.
"""
tasks = list(self._background_refresh_tasks.values())
self._background_refresh_tasks.clear()
for task in tasks:
task.cancel()

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
state["_background_refresh_tasks"] = {}
return state

def _should_refresh(self, token: AccessTokenInfo) -> bool:
now = int(time.time())
if token.refresh_on is not None and now >= token.refresh_on:
Expand Down Expand Up @@ -132,19 +181,31 @@ async def _get_token_base(
token = await self._acquire_token_silently(
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
)
now = int(time.time())
if not token:
self._last_request_time = int(time.time())
self._last_request_time = now
token = await self._request_token(
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
)
elif self._should_refresh(token):
try:
self._last_request_time = int(time.time())
token = await self._request_token(
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
self._last_request_time = now
if self._uses_asyncio() and token.expires_on - now >= _BACKGROUND_REFRESH_MIN_VALIDITY_SECONDS:
# Token has a certain remaining validity; refresh in the background and return the cached token.
self._start_background_refresh(
(scopes, claims, tenant_id, enable_cae),
*scopes,
claims=claims,
tenant_id=tenant_id,
enable_cae=enable_cae,
**kwargs,
)
except Exception: # pylint:disable=broad-except
pass
else:
try:
token = await self._request_token(
*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
)
except Exception: # pylint:disable=broad-except
pass
_LOGGER.log(
logging.DEBUG if within_credential_chain.get() else logging.INFO,
"%s.%s succeeded",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ async def __aexit__(
await self._client.__aexit__(exc_type, exc_value, traceback)

async def close(self) -> None:
self._cancel_background_refresh_tasks()
await self.__aexit__()

async def get_token(
Expand Down
Loading
Loading