diff --git a/cwms/api.py b/cwms/api.py index e96526e..4f82099 100644 --- a/cwms/api.py +++ b/cwms/api.py @@ -37,6 +37,7 @@ from typing import Any, Optional, cast from requests import Response, adapters +from requests.exceptions import RetryError as RequestsRetryError from requests_toolbelt import sessions # type: ignore from requests_toolbelt.sessions import BaseUrlSession # type: ignore from urllib3.util.retry import Retry @@ -55,12 +56,12 @@ status_forcelist=[ 403, 429, - 500, 502, 503, 504, ], # Example: also retry on these HTTP status codes allowed_methods=["GET", "PUT", "POST", "PATCH", "DELETE"], # Methods to retry + raise_on_status=False, ) SESSION = sessions.BaseUrlSession(base_url=API_ROOT) adapter = adapters.HTTPAdapter( @@ -140,6 +141,27 @@ class PermissionError(ApiError): """Raised when the CDA request is not authorized for the current caller.""" +def _unwrap_retry_error(error: RequestsRetryError) -> Exception: + """Return the original retry cause when requests wraps it in RetryError.""" + + current: Exception = error + cause = error.__cause__ + while isinstance(cause, Exception): + current = cause + cause = cause.__cause__ + + if current is error and error.args: + first_arg = error.args[0] + if isinstance(first_arg, Exception): + current = first_arg + reason = getattr(current, "reason", None) + while isinstance(reason, Exception): + current = reason + reason = getattr(current, "reason", None) + + return current + + def init_session( *, api_root: Optional[str] = None, @@ -308,11 +330,14 @@ def get( """ headers = {"Accept": api_version_text(api_version)} - with SESSION.get(endpoint, params=params, headers=headers) as response: - if not response.ok: - logging.error(f"CDA Error: response={response}") - raise ApiError(response) - return _process_response(response) + try: + with SESSION.get(endpoint, params=params, headers=headers) as response: + if not response.ok: + logging.error(f"CDA Error: response={response}") + raise ApiError(response) + return _process_response(response) + except RequestsRetryError as error: + raise _unwrap_retry_error(error) from None def get_with_paging( @@ -367,11 +392,16 @@ def _post_function( headers = {"accept": "*/*", "Content-Type": api_version_text(api_version)} if isinstance(data, dict) or isinstance(data, list): data = json.dumps(data) - with SESSION.post(endpoint, params=params, headers=headers, data=data) as response: - if not response.ok: - logging.error(f"CDA Error: response={response}") - raise ApiError(response) - return response + try: + with SESSION.post( + endpoint, params=params, headers=headers, data=data + ) as response: + if not response.ok: + logging.error(f"CDA Error: response={response}") + raise ApiError(response) + return response + except RequestsRetryError as error: + raise _unwrap_retry_error(error) from None def post( @@ -461,10 +491,15 @@ def patch( if data and isinstance(data, dict) or isinstance(data, list): data = json.dumps(data) - with SESSION.patch(endpoint, params=params, headers=headers, data=data) as response: - if not response.ok: - logging.error(f"CDA Error: response={response}") - raise ApiError(response) + try: + with SESSION.patch( + endpoint, params=params, headers=headers, data=data + ) as response: + if not response.ok: + logging.error(f"CDA Error: response={response}") + raise ApiError(response) + except RequestsRetryError as error: + raise _unwrap_retry_error(error) from None def delete( @@ -488,7 +523,10 @@ def delete( """ headers = {"Accept": api_version_text(api_version)} - with SESSION.delete(endpoint, params=params, headers=headers) as response: - if not response.ok: - logging.error(f"CDA Error: response={response}") - raise ApiError(response) + try: + with SESSION.delete(endpoint, params=params, headers=headers) as response: + if not response.ok: + logging.error(f"CDA Error: response={response}") + raise ApiError(response) + except RequestsRetryError as error: + raise _unwrap_retry_error(error) from None diff --git a/tests/mock/api_test.py b/tests/mock/api_test.py index 5680c24..1b27d4e 100644 --- a/tests/mock/api_test.py +++ b/tests/mock/api_test.py @@ -1,6 +1,11 @@ import pytest +from requests.exceptions import RetryError as RequestsRetryError +from urllib3.exceptions import MaxRetryError, ResponseError -from cwms.api import SESSION, InvalidVersion, api_version_text, init_session +import cwms.api +from cwms.api import SESSION, ApiError, api_version_text, init_session + +TEST_ENDPOINT = "/test-endpoint" def test_session_default(): @@ -53,3 +58,60 @@ def test_api_headers(): version = api_version_text(api_version=2) assert version == "application/json;version=2" + + +def test_retry_strategy_configuration(): + """Verify retry behavior preserves the original CDA error path.""" + + retries = SESSION.adapters["https://"].max_retries + + assert 500 not in retries.status_forcelist + assert retries.raise_on_status is False + + +def test_post_500_raises_api_error(monkeypatch): + """Verify a 500 response is surfaced directly as ApiError.""" + + class ResponseStub: + url = "https://example.com/cwms-data/test-endpoint" + ok = False + status_code = 500 + reason = "Internal Server Error" + content = b"incident identifier 34566432" + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + class SessionStub: + def post(self, endpoint, params=None, headers=None, data=None): + return ResponseStub() + + monkeypatch.setattr(cwms.api, "SESSION", SessionStub()) + + with pytest.raises(ApiError) as error: + cwms.api._post_function(endpoint=TEST_ENDPOINT, data={}) + + assert error.value.response.status_code == 500 + assert "Internal Server Error" in str(error.value) + assert "incident identifier 34566432" in str(error.value) + + +def test_retry_error_unwraps_original_cause(monkeypatch): + """Verify wrapped retry failures propagate the underlying cause.""" + + original_error = ResponseError("too many 503 error responses") + wrapped_error = RequestsRetryError( + MaxRetryError(pool=None, url=TEST_ENDPOINT, reason=original_error) + ) + + class SessionStub: + def get(self, endpoint, params=None, headers=None): + raise wrapped_error + + monkeypatch.setattr(cwms.api, "SESSION", SessionStub()) + + with pytest.raises(ResponseError, match="too many 503 error responses"): + cwms.api.get(TEST_ENDPOINT)