diff --git a/google/auth/compute_engine/_metadata.py b/google/auth/compute_engine/_metadata.py index e59736585..d3ca4c8d7 100644 --- a/google/auth/compute_engine/_metadata.py +++ b/google/auth/compute_engine/_metadata.py @@ -53,11 +53,7 @@ # Timeout in seconds to wait for the GCE metadata server when detecting the # GCE environment. -try: - _METADATA_DEFAULT_TIMEOUT = int(os.getenv("GCE_METADATA_TIMEOUT", 3)) -except ValueError: # pragma: NO COVER - _METADATA_DEFAULT_TIMEOUT = 3 - +_METADATA_PING_DEFAULT_TIMEOUT = 3 # Detect GCE Residency _GOOGLE = "Google" _GCE_PRODUCT_NAME_FILE = "/sys/class/dmi/id/product_name" @@ -100,7 +96,7 @@ def detect_gce_residency_linux(): return content.startswith(_GOOGLE) -def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): +def ping(request, timeout=None, retry_count=3): """Checks to see if the metadata server is available. Args: @@ -119,6 +115,14 @@ def ping(request, timeout=_METADATA_DEFAULT_TIMEOUT, retry_count=3): # could lead to false negatives in the event that we are on GCE, but # the metadata resolution was particularly slow. The latter case is # "unlikely". + if timeout is None: + try: + timeout = float(os.getenv( + environment_vars.GCE_METADATA_TIMEOUT, + str(_METADATA_PING_DEFAULT_TIMEOUT))) + except ValueError: + timeout = _METADATA_PING_DEFAULT_TIMEOUT + retries = 0 headers = _METADATA_HEADERS.copy() headers[metrics.API_CLIENT_HEADER] = metrics.mds_ping() diff --git a/google/auth/environment_vars.py b/google/auth/environment_vars.py index 81f31571e..93cc8838a 100644 --- a/google/auth/environment_vars.py +++ b/google/auth/environment_vars.py @@ -72,6 +72,9 @@ Used to distinguish between GAE gen1 and GAE gen2+. """ +GCE_METADATA_TIMEOUT = "GCE_METADATA_TIMEOUT" +"""Environment variable for setting timeouts in seconds for metadata queries.""" + # AWS environment variables used with AWS workload identity pools to retrieve # AWS security credentials and the AWS region needed to create a serialized # signed requests to the AWS STS GetCalledIdentity API that can be exchanged diff --git a/tests/compute_engine/test__metadata.py b/tests/compute_engine/test__metadata.py index 60ae355ac..846e968b9 100644 --- a/tests/compute_engine/test__metadata.py +++ b/tests/compute_engine/test__metadata.py @@ -105,8 +105,45 @@ def test_ping_success(mock_metrics_header_value): request.assert_called_once_with( method="GET", url=_metadata._METADATA_IP_ROOT, - headers=MDS_PING_REQUEST_HEADER, - timeout=_metadata._METADATA_DEFAULT_TIMEOUT, + headers=_metadata._METADATA_HEADER, + timeout=_metadata._METADATA_PING_DEFAULT_TIMEOUT, + ) + +@mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE) +def test_ping_success_with_gce_metadata_timeout(mock_metrics_header_value): + request = make_request("", headers=_metadata._METADATA_HEADERS) + gce_metadata_timeout = .5 + os.environ[ + environment_vars.GCE_METADATA_TIMEOUT] = str(gce_metadata_timeout) + + try: + assert _metadata.ping(request) + finally: + del os.environ[environment_vars.GCE_METADATA_TIMEOUT] + + request.assert_called_once_with( + method="GET", + url=_metadata._METADATA_IP_ROOT, + headers=_metadata._METADATA_HEADER, + timeout=gce_metadata_timeout, + ) + +@mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE) +def test_ping_success_with_invalid_gce_metadata_timeout(mock_metrics_header_value): + request = make_request("", headers=_metadata._METADATA_HEADERS) + os.environ[ + environment_vars.GCE_METADATA_TIMEOUT] = "Not a valid float value!" + + try: + assert _metadata.ping(request) + finally: + del os.environ[environment_vars.GCE_METADATA_TIMEOUT] + + request.assert_called_once_with( + method="GET", + url=_metadata._METADATA_IP_ROOT, + headers=_metadata._METADATA_HEADERS, + timeout=_metadata._METADATA_PING_DEFAULT_TIMEOUT, # Fallback value. ) @@ -156,7 +193,7 @@ def test_ping_success_custom_root(mock_metrics_header_value): method="GET", url="http://" + fake_ip, headers=MDS_PING_REQUEST_HEADER, - timeout=_metadata._METADATA_DEFAULT_TIMEOUT, + timeout=_metadata._METADATA_PING_DEFAULT_TIMEOUT, )