diff --git a/indico_toolkit/__init__.py b/indico_toolkit/__init__.py index ce9c4313..2797debf 100644 --- a/indico_toolkit/__init__.py +++ b/indico_toolkit/__init__.py @@ -3,4 +3,3 @@ from .errors import * from .client import create_client -from .retry import retry diff --git a/indico_toolkit/client.py b/indico_toolkit/client.py index 000ccb7c..ec449633 100644 --- a/indico_toolkit/client.py +++ b/indico_toolkit/client.py @@ -4,7 +4,7 @@ from indico_toolkit.retry import retry -@retry((IndicoRequestError, ConnectionError)) +@retry(IndicoRequestError, ConnectionError) def create_client( host: str, api_token_path: str = None, diff --git a/indico_toolkit/indico_wrapper/__init__.py b/indico_toolkit/indico_wrapper/__init__.py index b1bf3b33..7c993624 100644 --- a/indico_toolkit/indico_wrapper/__init__.py +++ b/indico_toolkit/indico_wrapper/__init__.py @@ -1,4 +1,4 @@ -from .indico_wrapper import IndicoWrapper, retry +from .indico_wrapper import IndicoWrapper from .workflow import Workflow from .dataset import Datasets from .reviewer import Reviewer diff --git a/indico_toolkit/indico_wrapper/download.py b/indico_toolkit/indico_wrapper/download.py index af44f7d5..2e5d2759 100644 --- a/indico_toolkit/indico_wrapper/download.py +++ b/indico_toolkit/indico_wrapper/download.py @@ -3,7 +3,8 @@ import pandas as pd from indico.types.export import Export from indico import IndicoClient, IndicoRequestError -from indico_toolkit import retry, ToolkitInputError +from indico_toolkit import ToolkitInputError +from indico_toolkit.retry import retry from indico.queries import ( RetrieveStorageObject, DownloadExport, @@ -103,14 +104,14 @@ def _download_pdfs_from_export( return max_files_to_download return export_df.shape[0] - @retry((IndicoRequestError, ConnectionError)) + @retry(IndicoRequestError, ConnectionError) def _download_export(self, export_id: int) -> pd.DataFrame: """ Download a dataframe representation of your dataset export """ return self.client.call(DownloadExport(export_id=export_id)) - @retry((IndicoRequestError, ConnectionError)) + @retry(IndicoRequestError, ConnectionError) def _create_export( self, dataset_id: int, @@ -142,7 +143,7 @@ def _create_export( ) ) - @retry((IndicoRequestError, ConnectionError)) + @retry(IndicoRequestError, ConnectionError) def _retrieve_storage_object(self, url: str): return self.client.call(RetrieveStorageObject(url)) diff --git a/indico_toolkit/indico_wrapper/indico_wrapper.py b/indico_toolkit/indico_wrapper/indico_wrapper.py index 6ad751e5..b6b46206 100644 --- a/indico_toolkit/indico_wrapper/indico_wrapper.py +++ b/indico_toolkit/indico_wrapper/indico_wrapper.py @@ -11,8 +11,9 @@ from indico import IndicoClient from indico.errors import IndicoRequestError +from indico_toolkit import ToolkitInputError +from indico_toolkit.retry import retry from indico_toolkit.types import Predictions -from indico_toolkit import ToolkitStatusError, retry class IndicoWrapper: @@ -68,7 +69,7 @@ def train_model( ) ) - @retry((IndicoRequestError, ConnectionError)) + @retry(IndicoRequestError, ConnectionError) def get_storage_object(self, storage_url: str): return self.client.call(RetrieveStorageObject(storage_url)) @@ -78,7 +79,7 @@ def create_storage_urls(self, file_paths: List[str]) -> List[str]: def get_job_status(self, job_id: int, wait: bool = True, timeout: float = None): return self.client.call(JobStatus(id=job_id, wait=wait, timeout=timeout)) - @retry((IndicoRequestError, ConnectionError)) + @retry(IndicoRequestError, ConnectionError) def graphQL_request(self, graphql_query: str, variables: dict = None): return self.client.call( GraphQLRequest(query=graphql_query, variables=variables) diff --git a/indico_toolkit/retry.py b/indico_toolkit/retry.py index 43111f01..24362344 100644 --- a/indico_toolkit/retry.py +++ b/indico_toolkit/retry.py @@ -1,27 +1,105 @@ -from functools import wraps +import asyncio import time +from functools import wraps +from inspect import iscoroutinefunction +from random import random +from typing import TYPE_CHECKING, overload + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + from typing import ParamSpec, TypeVar + ArgumentsType = ParamSpec("ArgumentsType") + OuterReturnType = TypeVar("OuterReturnType") + InnerReturnType = TypeVar("InnerReturnType") -def retry(exceptions, num_retries=5, wait=0.5): + +class MaxRetriesExceeded(Exception): """ - Decorator for retrying functions after specified exceptions are raised - Args: - exceptions (Exception or Tuple[Exception]): exceptions that should be retried on - wait (float): time in seconds to wait before retrying - num_retries (int): the number of times to retry the wrapped function + Raised when a function has retried more than `count` number of times. """ - def retry_decorator(fn): - @wraps(fn) - def retry_func(*args, **kwargs): - retries = 0 - while True: - try: - return fn(*args, **kwargs) - except exceptions as e: - if retries >= num_retries: - raise e - else: - retries += 1 - time.sleep(wait) - return retry_func - return retry_decorator \ No newline at end of file + + +def retry( + *errors: "type[Exception]", + count: int = 4, + wait: float = 1, + backoff: float = 4, + jitter: float = 0.5, +) -> "Callable[[Callable[ArgumentsType, OuterReturnType]], Callable[ArgumentsType, OuterReturnType]]": # noqa: E501 + """ + Decorate a function or coroutine to retry when it raises specified errors, + apply exponential backoff and jitter to the wait time, + and raise `MaxRetriesExceeded` after it retries too many times. + + By default, the decorated function or coroutine will be retried up to 4 times over + the course of ~2 minutes (waiting 1, 4, 16, and 64 seconds; plus up to 50% jitter) + before raising `MaxRetriesExceeded` from the last error. + + Arguments: + errors: Retry the function when it raises one of these errors. + count: Retry the function this many times before raising `MaxRetriesExceeded`. + wait: Wait this many seconds after the first error before retrying. + backoff: Multiply the wait time by this amount for each additional error. + jitter: Add a random amount of time (up to this percent as a decimal) + to the wait time to prevent simultaneous retries. + """ + + def wait_time(times_retried: int) -> float: + """ + Calculate the sleep time based on number of times retried. + """ + return wait * backoff**times_retried * (1 + jitter * random()) + + @overload + def retry_decorator( + decorated: "Callable[ArgumentsType, Awaitable[InnerReturnType]]", + ) -> "Callable[ArgumentsType, Awaitable[InnerReturnType]]": ... + @overload + def retry_decorator( + decorated: "Callable[ArgumentsType, InnerReturnType]", + ) -> "Callable[ArgumentsType, InnerReturnType]": ... + def retry_decorator( + decorated: "Callable[ArgumentsType, InnerReturnType]", + ) -> "Callable[ArgumentsType, Awaitable[InnerReturnType]] | Callable[ArgumentsType, InnerReturnType]": # noqa: E501 + """ + Decorate either a function or coroutine as appropriate. + """ + if iscoroutinefunction(decorated): + + @wraps(decorated) + async def retrying_coroutine( # type: ignore[return] + *args: "ArgumentsType.args", **kwargs: "ArgumentsType.kwargs" + ) -> "InnerReturnType": + for times_retried in range(count + 1): + try: + return await decorated(*args, **kwargs) # type: ignore[no-any-return] + except errors as error: + last_error = error + + if times_retried >= count: + raise MaxRetriesExceeded() from last_error + + await asyncio.sleep(wait_time(times_retried)) + + return retrying_coroutine + else: + + @wraps(decorated) + def retrying_function( # type: ignore[return] + *args: "ArgumentsType.args", **kwargs: "ArgumentsType.kwargs" + ) -> "InnerReturnType": + for times_retried in range(count + 1): + try: + return decorated(*args, **kwargs) + except errors as error: + last_error = error + + if times_retried >= count: + raise MaxRetriesExceeded() from last_error + + time.sleep(wait_time(times_retried)) + + return retrying_function + + return retry_decorator diff --git a/pyproject.toml b/pyproject.toml index 24c47f59..2919a7a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,9 +22,10 @@ requires = [ [tool.flit.metadata.requires-extra] test = [ - "pytest>=6.2.5", - "requests-mock>=1.7.0-7", - "pytest-dependency==0.5.1" + "pytest==8.3.4", + "pytest-asyncio==0.25.2", + "pytest-dependency==0.6.0", + "requests-mock>=1.7.0-7" ] full = [ "spacy>=3.1.4,<4" diff --git a/requirements.txt b/requirements.txt index ca0cfe38..0e8b39ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,9 @@ indico-client>=5.1.4 python-dateutil==2.8.1 pytz==2021.1 -pytest==6.2.5 -pytest-dependency==0.5.1 +pytest==8.3.4 +pytest-asyncio==0.25.2 +pytest-dependency==0.6.0 pytest-mock==3.11.1 coverage==5.5 black==22.3 diff --git a/tests/test_retry.py b/tests/test_retry.py index c91fa875..f78c10c7 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -1,23 +1,73 @@ import pytest -from indico_toolkit import retry +from indico_toolkit.retry import retry, MaxRetriesExceeded -counter = 0 - - -def test_retry_decor(): +def test_no_errors() -> None: @retry(Exception) - def no_exceptions(): + def no_errors() -> bool: return True - - @retry((RuntimeError, ConnectionError), num_retries=7) - def raises_exceptions(): - global counter - counter +=1 - raise RuntimeError("Test runtime fail") - - assert no_exceptions() - with pytest.raises(RuntimeError): - raises_exceptions() - assert counter == 8 \ No newline at end of file + + assert no_errors() + + +def test_raises_errors() -> None: + calls = 0 + + @retry(RuntimeError, count=4, wait=0) + def raises_errors() -> None: + nonlocal calls + calls += 1 + raise RuntimeError() + + with pytest.raises(MaxRetriesExceeded): + raises_errors() + + assert calls == 5 + + +def test_raises_other_errors() -> None: + calls = 0 + + @retry(RuntimeError, count=4, wait=0) + def raises_errors() -> None: + nonlocal calls + calls += 1 + raise ValueError() + + with pytest.raises(ValueError): + raises_errors() + + assert calls == 1 + + +@pytest.mark.asyncio +async def test_raises_errors_async() -> None: + calls = 0 + + @retry(RuntimeError, count=4, wait=0) + async def raises_errors() -> None: + nonlocal calls + calls += 1 + raise RuntimeError() + + with pytest.raises(MaxRetriesExceeded): + await raises_errors() + + assert calls == 5 + + +@pytest.mark.asyncio +async def test_raises_other_errors_async() -> None: + calls = 0 + + @retry(RuntimeError, count=4, wait=0) + async def raises_errors() -> None: + nonlocal calls + calls += 1 + raise ValueError() + + with pytest.raises(ValueError): + await raises_errors() + + assert calls == 1