Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions google/auth/_constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Shared constants."""

_SERVICE_ACCOUNT_TRUST_BOUNDARY_LOOKUP_ENDPOINT = "https://iamcredentials.{universe_domain}/v1/projects/-/serviceAccounts/{service_account_email}/allowedLocations"
_WORKFORCE_POOL_TRUST_BOUNDARY_LOOKUP_ENDPOINT = "https://iamcredentials.{universe_domain}/v1/locations/global/workforcePools/{pool_id}/allowedLocations"
_WORKLOAD_IDENTITY_POOL_TRUST_BOUNDARY_LOOKUP_ENDPOINT = "https://iamcredentials.{universe_domain}/v1/projects/{project_number}/locations/global/workloadIdentityPools/{pool_id}/allowedLocations"
_SERVICE_ACCOUNT_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT = "https://iamcredentials.{universe_domain}/v1/projects/-/serviceAccounts/{service_account_email}/allowedLocations"
_WORKFORCE_POOL_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT = "https://iamcredentials.{universe_domain}/v1/locations/global/workforcePools/{pool_id}/allowedLocations"
_WORKLOAD_IDENTITY_POOL_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT = "https://iamcredentials.{universe_domain}/v1/projects/{project_number}/locations/global/workloadIdentityPools/{pool_id}/allowedLocations"
87 changes: 87 additions & 0 deletions google/auth/_regional_access_boundary_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Utilities for Regional Access Boundary management."""

import threading
import datetime

from google.auth import _helpers
from google.auth import exceptions
from google.auth._default import _LOGGER


# The default lifetime for a cached Regional Access Boundary.
DEFAULT_REGIONAL_ACCESS_BOUNDARY_TTL = datetime.timedelta(hours=6)

# The initial cooldown period for a failed Regional Access Boundary lookup.
DEFAULT_REGIONAL_ACCESS_BOUNDARY_COOLDOWN = datetime.timedelta(minutes=15)


class _RegionalAccessBoundaryRefreshThread(threading.Thread):
"""Thread for background refreshing of the Regional Access Boundary."""

def __init__(self, credentials, request):
super(_RegionalAccessBoundaryRefreshThread, self).__init__()
self._credentials = credentials
self._request = request

def run(self):
"""
Performs the Regional Access Boundary lookup. This method is run in a separate thread.

It includes a short-term retry loop for transient server errors. If the
lookup fails completely, it sets a longer-term cooldown period on the
credential to avoid overwhelming the lookup service.
"""
regional_access_boundary_info = self._credentials._lookup_regional_access_boundary_with_retry(
self._request
)

if regional_access_boundary_info:
# On success, update the boundary and its expiry, and clear any cooldown.
self._credentials._regional_access_boundary = regional_access_boundary_info
self._credentials._regional_access_boundary_expiry = (
_helpers.utcnow() + DEFAULT_REGIONAL_ACCESS_BOUNDARY_TTL
)
self._credentials._regional_access_boundary_cooldown_expiry = None
if _helpers.is_logging_enabled(_LOGGER):
_LOGGER.debug(
"Asynchronous Regional Access Boundary lookup successful."
)
else:
# On complete failure, set a cooldown period. The existing
# _regional_access_boundary and _regional_access_boundary_expiry
# will be kept as they are considered safe to use until explicitly
# invalidated by a "stale Regional Access Boundary" API error.
if _helpers.is_logging_enabled(_LOGGER):
_LOGGER.warning(
"Asynchronous Regional Access Boundary lookup failed. Entering cooldown."
)

self._credentials._regional_access_boundary_cooldown_expiry = (
_helpers.utcnow() + DEFAULT_REGIONAL_ACCESS_BOUNDARY_COOLDOWN
)


class _RegionalAccessBoundaryRefreshManager(object):
"""Manages a thread for background refreshing of the Regional Access Boundary."""

def __init__(self):
self._lock = threading.Lock()
self._worker = None

def start_refresh(self, credentials, request):
"""
Starts a background thread to refresh the Regional Access Boundary if one is not already running.

Args:
credentials (CredentialsWithRegionalAccessBoundary): The credentials
to refresh.
request (google.auth.transport.Request): The object used to make
HTTP requests.
"""
with self._lock:
if self._worker and self._worker.is_alive():
# A refresh is already in progress.
return

self._worker = _RegionalAccessBoundaryRefreshThread(credentials, request)
self._worker.start()
40 changes: 27 additions & 13 deletions google/auth/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,9 @@ def _generate_authentication_header_map(
full_headers[key.lower()] = additional_headers[key]
# Add AWS session token if available.
if aws_security_credentials.session_token is not None:
full_headers[
_AWS_SECURITY_TOKEN_HEADER
] = aws_security_credentials.session_token
full_headers[_AWS_SECURITY_TOKEN_HEADER] = (
aws_security_credentials.session_token
)

# Required headers
full_headers["host"] = host
Expand Down Expand Up @@ -348,10 +348,10 @@ def _generate_authentication_header_map(
class AwsSecurityCredentials:
"""A class that models AWS security credentials with an optional session token.

Attributes:
access_key_id (str): The AWS security credentials access key id.
secret_access_key (str): The AWS security credentials secret access key.
session_token (Optional[str]): The optional AWS security credentials session token. This should be set when using temporary credentials.
Attributes:
access_key_id (str): The AWS security credentials access key id.
secret_access_key (str): The AWS security credentials secret access key.
session_token (Optional[str]): The optional AWS security credentials session token. This should be set when using temporary credentials.
"""

access_key_id: str
Expand Down Expand Up @@ -641,7 +641,7 @@ def __init__(
"regional_cred_verification_url": "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15",
"region_url": "http://169.254.169.254/latest/meta-data/placement/availability-zone",
"url": "http://169.254.169.254/latest/meta-data/iam/security-credentials",
imdsv2_session_token_url": "http://169.254.169.254/latest/api/token"
"imdsv2_session_token_url": "http://169.254.169.254/latest/api/token"
}

aws_security_credentials_supplier (Optional [AwsSecurityCredentialsSupplier]): Optional AWS security credentials supplier.
Expand All @@ -660,6 +660,9 @@ def __init__(
:meth:`from_file` or
:meth:`from_info` are used instead of calling the constructor directly.
"""
# Pop regional_access_boundary from kwargs to avoid passing it to the parent constructor.
kwargs.pop("regional_access_boundary", None)

super(Credentials, self).__init__(
audience=audience,
subject_token_type=subject_token_type,
Expand Down Expand Up @@ -688,8 +691,8 @@ def __init__(
)
else:
environment_id = credential_source.get("environment_id") or ""
self._aws_security_credentials_supplier = _DefaultAwsSecurityCredentialsSupplier(
credential_source
self._aws_security_credentials_supplier = (
_DefaultAwsSecurityCredentialsSupplier(credential_source)
)
self._cred_verification_url = credential_source.get(
"regional_cred_verification_url"
Expand Down Expand Up @@ -759,8 +762,10 @@ def retrieve_subject_token(self, request):

# Retrieve the AWS security credentials needed to generate the signed
# request.
aws_security_credentials = self._aws_security_credentials_supplier.get_aws_security_credentials(
self._supplier_context, request
aws_security_credentials = (
self._aws_security_credentials_supplier.get_aws_security_credentials(
self._supplier_context, request
)
)
# Generate the signed request to AWS STS GetCallerIdentity API.
# Use the required regional endpoint. Otherwise, the request will fail.
Expand Down Expand Up @@ -845,7 +850,16 @@ def from_info(cls, info, **kwargs):
kwargs.update(
{"aws_security_credentials_supplier": aws_security_credentials_supplier}
)
return super(Credentials, cls).from_info(info, **kwargs)
regional_access_boundary = info.pop("regional_access_boundary", None)

credentials = super(Credentials, cls).from_info(info, **kwargs)

if regional_access_boundary:
credentials = credentials.with_regional_access_boundary(
regional_access_boundary
)

return credentials

@classmethod
def from_file(cls, filename, **kwargs):
Expand Down
78 changes: 29 additions & 49 deletions google/auth/compute_engine/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import datetime

import warnings
from google.auth import _helpers
from google.auth import credentials
from google.auth import exceptions
Expand All @@ -30,7 +31,7 @@
from google.auth.compute_engine import _metadata
from google.oauth2 import _client

_TRUST_BOUNDARY_LOOKUP_ENDPOINT = (
_REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT = (
"https://iamcredentials.{}/v1/projects/-/serviceAccounts/{}/allowedLocations"
)

Expand All @@ -39,7 +40,7 @@ class Credentials(
credentials.Scoped,
credentials.CredentialsWithQuotaProject,
credentials.CredentialsWithUniverseDomain,
credentials.CredentialsWithTrustBoundary,
credentials.CredentialsWithRegionalAccessBoundary,
):
"""Compute Engine Credentials.

Expand All @@ -66,7 +67,7 @@ def __init__(
scopes=None,
default_scopes=None,
universe_domain=None,
trust_boundary=None,
regional_access_boundary=None,
):
"""
Args:
Expand All @@ -82,7 +83,7 @@ def __init__(
provided or None, credential will attempt to fetch the value
from metadata server. If metadata server doesn't have universe
domain endpoint, then the default googleapis.com will be used.
trust_boundary (Mapping[str,str]): A credential trust boundary.
regional_access_boundary (Mapping[str,str]): A credential Regional Access Boundary.
"""
super(Credentials, self).__init__()
self._service_account_email = service_account_email
Expand All @@ -93,7 +94,6 @@ def __init__(
if universe_domain:
self._universe_domain = universe_domain
self._universe_domain_cached = True
self._trust_boundary = trust_boundary

def _retrieve_info(self, request):
"""Retrieve information about the service account.
Expand Down Expand Up @@ -146,8 +146,8 @@ def _refresh_token(self, request):
new_exc = exceptions.RefreshError(caught_exc)
raise new_exc from caught_exc

def _build_trust_boundary_lookup_url(self):
"""Builds and returns the URL for the trust boundary lookup API for GCE."""
def _build_regional_access_boundary_lookup_url(self):
"""Builds and returns the URL for the Regional Access Boundary lookup API for GCE."""
# If the service account email is 'default', we need to get the
# actual email address from the metadata server.
if self._service_account_email == "default":
Expand All @@ -165,15 +165,15 @@ def _build_trust_boundary_lookup_url(self):

except exceptions.TransportError as e:
# If fetching the service account email fails due to a transport error,
# it means we cannot build the trust boundary lookup URL.
# Wrap this in a RefreshError so it's caught by _refresh_trust_boundary.
# it means we cannot build the Regional Access Boundary lookup URL.
# Wrap this in a RefreshError so it's caught by _refresh_regional_access_boundary.
raise exceptions.RefreshError(
"Failed to get service account email for trust boundary lookup: {}".format(
"Failed to get service account email for Regional Access Boundary lookup: {}".format(
e
)
) from e

return _TRUST_BOUNDARY_LOOKUP_ENDPOINT.format(
return _REGIONAL_ACCESS_BOUNDARY_LOOKUP_ENDPOINT.format(
self.universe_domain, self.service_account_email
)

Expand Down Expand Up @@ -211,57 +211,37 @@ def get_cred_info(self):
"principal": self.service_account_email,
}

@_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
def with_quota_project(self, quota_project_id):
creds = self.__class__(
def _make_copy(self):
"""Create a copy of the current credentials."""
new_creds = self.__class__(
service_account_email=self._service_account_email,
quota_project_id=quota_project_id,
quota_project_id=self._quota_project_id,
scopes=self._scopes,
default_scopes=self._default_scopes,
universe_domain=self._universe_domain,
trust_boundary=self._trust_boundary,
)
creds._universe_domain_cached = self._universe_domain_cached
new_creds._universe_domain_cached = self._universe_domain_cached
self._copy_regional_access_boundary_state(new_creds)
return new_creds

@_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
def with_quota_project(self, quota_project_id):
creds = self._make_copy()
creds._quota_project_id = quota_project_id
return creds

@_helpers.copy_docstring(credentials.Scoped)
def with_scopes(self, scopes, default_scopes=None):
# Compute Engine credentials can not be scoped (the metadata service
# ignores the scopes parameter). App Engine, Cloud Run and Flex support
# requesting scopes.
creds = self.__class__(
scopes=scopes,
default_scopes=default_scopes,
service_account_email=self._service_account_email,
quota_project_id=self._quota_project_id,
universe_domain=self._universe_domain,
trust_boundary=self._trust_boundary,
)
creds._universe_domain_cached = self._universe_domain_cached
creds = self._make_copy()
creds._scopes = scopes
creds._default_scopes = default_scopes
return creds

@_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain)
def with_universe_domain(self, universe_domain):
return self.__class__(
scopes=self._scopes,
default_scopes=self._default_scopes,
service_account_email=self._service_account_email,
quota_project_id=self._quota_project_id,
trust_boundary=self._trust_boundary,
universe_domain=universe_domain,
)

@_helpers.copy_docstring(credentials.CredentialsWithTrustBoundary)
def with_trust_boundary(self, trust_boundary):
creds = self.__class__(
service_account_email=self._service_account_email,
quota_project_id=self._quota_project_id,
scopes=self._scopes,
default_scopes=self._default_scopes,
universe_domain=self._universe_domain,
trust_boundary=trust_boundary,
)
creds._universe_domain_cached = self._universe_domain_cached
creds = self._make_copy()
creds._universe_domain = universe_domain
creds._universe_domain_cached = True
return creds


Expand Down
Loading