Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion indico_toolkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@

from .errors import *
from .client import create_client
from .retry import retry
2 changes: 1 addition & 1 deletion indico_toolkit/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from indico_toolkit.retry import retry


@retry((IndicoRequestError, ConnectionError))
@retry(IndicoRequestError, ConnectionError)
def create_client(
host: str,
api_token_path: str = None,
Expand Down
2 changes: 1 addition & 1 deletion indico_toolkit/indico_wrapper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .indico_wrapper import IndicoWrapper, retry
from .indico_wrapper import IndicoWrapper
from .workflow import Workflow
from .dataset import Datasets
from .reviewer import Reviewer
Expand Down
9 changes: 5 additions & 4 deletions indico_toolkit/indico_wrapper/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import pandas as pd
from indico.types.export import Export
from indico import IndicoClient, IndicoRequestError
from indico_toolkit import retry, ToolkitInputError
from indico_toolkit import ToolkitInputError
from indico_toolkit.retry import retry
from indico.queries import (
RetrieveStorageObject,
DownloadExport,
Expand Down Expand Up @@ -103,14 +104,14 @@ def _download_pdfs_from_export(
return max_files_to_download
return export_df.shape[0]

@retry((IndicoRequestError, ConnectionError))
@retry(IndicoRequestError, ConnectionError)
def _download_export(self, export_id: int) -> pd.DataFrame:
"""
Download a dataframe representation of your dataset export
"""
return self.client.call(DownloadExport(export_id=export_id))

@retry((IndicoRequestError, ConnectionError))
@retry(IndicoRequestError, ConnectionError)
def _create_export(
self,
dataset_id: int,
Expand Down Expand Up @@ -142,7 +143,7 @@ def _create_export(
)
)

@retry((IndicoRequestError, ConnectionError))
@retry(IndicoRequestError, ConnectionError)
def _retrieve_storage_object(self, url: str):
return self.client.call(RetrieveStorageObject(url))

Expand Down
7 changes: 4 additions & 3 deletions indico_toolkit/indico_wrapper/indico_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from indico import IndicoClient
from indico.errors import IndicoRequestError

from indico_toolkit import ToolkitInputError
from indico_toolkit.retry import retry
from indico_toolkit.types import Predictions
from indico_toolkit import ToolkitStatusError, retry


class IndicoWrapper:
Expand Down Expand Up @@ -68,7 +69,7 @@ def train_model(
)
)

@retry((IndicoRequestError, ConnectionError))
@retry(IndicoRequestError, ConnectionError)
def get_storage_object(self, storage_url: str):
return self.client.call(RetrieveStorageObject(storage_url))

Expand All @@ -78,7 +79,7 @@ def create_storage_urls(self, file_paths: List[str]) -> List[str]:
def get_job_status(self, job_id: int, wait: bool = True, timeout: float = None):
return self.client.call(JobStatus(id=job_id, wait=wait, timeout=timeout))

@retry((IndicoRequestError, ConnectionError))
@retry(IndicoRequestError, ConnectionError)
def graphQL_request(self, graphql_query: str, variables: dict = None):
return self.client.call(
GraphQLRequest(query=graphql_query, variables=variables)
Expand Down
122 changes: 100 additions & 22 deletions indico_toolkit/retry.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,105 @@
from functools import wraps
import asyncio
import time
from functools import wraps
from inspect import iscoroutinefunction
from random import random
from typing import TYPE_CHECKING, overload

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from typing import ParamSpec, TypeVar

ArgumentsType = ParamSpec("ArgumentsType")
OuterReturnType = TypeVar("OuterReturnType")
InnerReturnType = TypeVar("InnerReturnType")

def retry(exceptions, num_retries=5, wait=0.5):

class MaxRetriesExceeded(Exception):
"""
Decorator for retrying functions after specified exceptions are raised
Args:
exceptions (Exception or Tuple[Exception]): exceptions that should be retried on
wait (float): time in seconds to wait before retrying
num_retries (int): the number of times to retry the wrapped function
Raised when a function has retried more than `count` number of times.
"""
def retry_decorator(fn):
@wraps(fn)
def retry_func(*args, **kwargs):
retries = 0
while True:
try:
return fn(*args, **kwargs)
except exceptions as e:
if retries >= num_retries:
raise e
else:
retries += 1
time.sleep(wait)
return retry_func
return retry_decorator


def retry(
*errors: "type[Exception]",
count: int = 4,
wait: float = 1,
backoff: float = 4,
jitter: float = 0.5,
) -> "Callable[[Callable[ArgumentsType, OuterReturnType]], Callable[ArgumentsType, OuterReturnType]]": # noqa: E501
"""
Decorate a function or coroutine to retry when it raises specified errors,
apply exponential backoff and jitter to the wait time,
and raise `MaxRetriesExceeded` after it retries too many times.

By default, the decorated function or coroutine will be retried up to 4 times over
the course of ~2 minutes (waiting 1, 4, 16, and 64 seconds; plus up to 50% jitter)
before raising `MaxRetriesExceeded` from the last error.

Arguments:
errors: Retry the function when it raises one of these errors.
count: Retry the function this many times before raising `MaxRetriesExceeded`.
wait: Wait this many seconds after the first error before retrying.
backoff: Multiply the wait time by this amount for each additional error.
jitter: Add a random amount of time (up to this percent as a decimal)
to the wait time to prevent simultaneous retries.
"""

def wait_time(times_retried: int) -> float:
"""
Calculate the sleep time based on number of times retried.
"""
return wait * backoff**times_retried * (1 + jitter * random())

@overload
def retry_decorator(
decorated: "Callable[ArgumentsType, Awaitable[InnerReturnType]]",
) -> "Callable[ArgumentsType, Awaitable[InnerReturnType]]": ...
@overload
def retry_decorator(
decorated: "Callable[ArgumentsType, InnerReturnType]",
) -> "Callable[ArgumentsType, InnerReturnType]": ...
def retry_decorator(
decorated: "Callable[ArgumentsType, InnerReturnType]",
) -> "Callable[ArgumentsType, Awaitable[InnerReturnType]] | Callable[ArgumentsType, InnerReturnType]": # noqa: E501
"""
Decorate either a function or coroutine as appropriate.
"""
if iscoroutinefunction(decorated):

@wraps(decorated)
async def retrying_coroutine( # type: ignore[return]
*args: "ArgumentsType.args", **kwargs: "ArgumentsType.kwargs"
) -> "InnerReturnType":
for times_retried in range(count + 1):
try:
return await decorated(*args, **kwargs) # type: ignore[no-any-return]
except errors as error:
last_error = error

if times_retried >= count:
raise MaxRetriesExceeded() from last_error

await asyncio.sleep(wait_time(times_retried))

return retrying_coroutine
else:

@wraps(decorated)
def retrying_function( # type: ignore[return]
*args: "ArgumentsType.args", **kwargs: "ArgumentsType.kwargs"
) -> "InnerReturnType":
for times_retried in range(count + 1):
try:
return decorated(*args, **kwargs)
except errors as error:
last_error = error

if times_retried >= count:
raise MaxRetriesExceeded() from last_error

time.sleep(wait_time(times_retried))

return retrying_function

return retry_decorator
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ requires = [

[tool.flit.metadata.requires-extra]
test = [
"pytest>=6.2.5",
"requests-mock>=1.7.0-7",
"pytest-dependency==0.5.1"
"pytest==8.3.4",
"pytest-asyncio==0.25.2",
"pytest-dependency==0.6.0",
"requests-mock>=1.7.0-7"
]
full = [
"spacy>=3.1.4,<4"
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
indico-client>=5.1.4
python-dateutil==2.8.1
pytz==2021.1
pytest==6.2.5
pytest-dependency==0.5.1
pytest==8.3.4
pytest-asyncio==0.25.2
pytest-dependency==0.6.0
pytest-mock==3.11.1
coverage==5.5
black==22.3
Expand Down
84 changes: 67 additions & 17 deletions tests/test_retry.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,73 @@
import pytest

from indico_toolkit import retry
from indico_toolkit.retry import retry, MaxRetriesExceeded


counter = 0


def test_retry_decor():
def test_no_errors() -> None:
@retry(Exception)
def no_exceptions():
def no_errors() -> bool:
return True

@retry((RuntimeError, ConnectionError), num_retries=7)
def raises_exceptions():
global counter
counter +=1
raise RuntimeError("Test runtime fail")

assert no_exceptions()
with pytest.raises(RuntimeError):
raises_exceptions()
assert counter == 8

assert no_errors()


def test_raises_errors() -> None:
calls = 0

@retry(RuntimeError, count=4, wait=0)
def raises_errors() -> None:
nonlocal calls
calls += 1
raise RuntimeError()

with pytest.raises(MaxRetriesExceeded):
raises_errors()

assert calls == 5


def test_raises_other_errors() -> None:
calls = 0

@retry(RuntimeError, count=4, wait=0)
def raises_errors() -> None:
nonlocal calls
calls += 1
raise ValueError()

with pytest.raises(ValueError):
raises_errors()

assert calls == 1


@pytest.mark.asyncio
async def test_raises_errors_async() -> None:
calls = 0

@retry(RuntimeError, count=4, wait=0)
async def raises_errors() -> None:
nonlocal calls
calls += 1
raise RuntimeError()

with pytest.raises(MaxRetriesExceeded):
await raises_errors()

assert calls == 5


@pytest.mark.asyncio
async def test_raises_other_errors_async() -> None:
calls = 0

@retry(RuntimeError, count=4, wait=0)
async def raises_errors() -> None:
nonlocal calls
calls += 1
raise ValueError()

with pytest.raises(ValueError):
await raises_errors()

assert calls == 1
Loading