diff --git a/.github/workflows/python-test-linux.yml b/.github/workflows/python-test-linux.yml index 0c1d8ec..9bd5f81 100644 --- a/.github/workflows/python-test-linux.yml +++ b/.github/workflows/python-test-linux.yml @@ -42,7 +42,7 @@ jobs: pip install dist/*.whl - name: Test with pytest run: | - python -m pip install typing-extensions + python -m pip install typing-extensions pytest-asyncio pytest build: @@ -74,6 +74,6 @@ jobs: pip install dist/*.whl - name: Test with pytest run: | - python -m pip install typing-extensions + python -m pip install typing-extensions pytest-asyncio pytest diff --git a/.github/workflows/python-test-macos.yml b/.github/workflows/python-test-macos.yml index b04f553..96208d6 100644 --- a/.github/workflows/python-test-macos.yml +++ b/.github/workflows/python-test-macos.yml @@ -42,7 +42,7 @@ jobs: pip install dist/*.whl - name: Test with pytest run: | - python -m pip install typing-extensions + python -m pip install typing-extensions pytest-asyncio pytest build: @@ -74,6 +74,6 @@ jobs: pip install dist/*.whl - name: Test with pytest run: | - python -m pip install typing-extensions + python -m pip install typing-extensions pytest-asyncio pytest diff --git a/.github/workflows/python-test-windows.yml b/.github/workflows/python-test-windows.yml index d72904c..d7d7c79 100644 --- a/.github/workflows/python-test-windows.yml +++ b/.github/workflows/python-test-windows.yml @@ -49,6 +49,6 @@ jobs: } - name: Test with pytest run: | - python -m pip install typing-extensions + python -m pip install typing-extensions pytest-asyncio pytest diff --git a/README.md b/README.md index 4978476..fb97f34 100644 --- a/README.md +++ b/README.md @@ -47,15 +47,16 @@ list_ext.extend() Currently, we provide the following extensions: -| file | extended types | -|:---------------:|:----------------------------------:| -| dict_ext.py | dict_keys, dict_values, dict_items | -| float_ext.py | float | -| function_ext.py | FunctionType, LambdaType | -| int_ext.py | int | -| list_ext.py | list | -| seq_ext.py | map, filter, range, zip | -| str_ext.py | str | +| file | extended types | +|:----------------:|:----------------------------------:| +| coroutine_ext.py | coroutine (async functions) | +| dict_ext.py | dict_keys, dict_values, dict_items | +| float_ext.py | float | +| function_ext.py | FunctionType, LambdaType | +| int_ext.py | int | +| list_ext.py | list | +| seq_ext.py | map, filter, range, zip | +| str_ext.py | str | @@ -214,6 +215,36 @@ list.last(self: List[T]) -> T, raise IndexError ``` Returns the last element in the list, or raises `IndexError` if the list is empty. +```py +coroutine.then(self: Awaitable[T], fn: Callable[[T], Awaitable[U] | U]) -> Awaitable[U] +``` +Maps the result of the awaitable via an optionally async function. If the function is async, it is awaited in the context of the wrapped awaitable. + +Example: +```py +async def get_value(): + return 10 + +result = await get_value().then(lambda x: x * 2) # result is 20 +``` + +```py +coroutine.catch(self: Awaitable[T], fn: Callable[[E], Awaitable[U] | U], *, exception: type[E] = Exception) -> Awaitable[T | U] +``` +Catches an exception of the given type and calls the passed function with the caught exception. + +If no exception was raised inside the wrapped awaitable, the function will not be called. +The passed function can optionally return a value to be returned in case of an error. +The passed function can be either sync or async. If it's async, it is awaited in the context of the wrapped awaitable. + +Example: +```py +async def might_fail(): + raise ValueError("error") + +result = await might_fail().catch(lambda e: "default", exception=ValueError) # result is "default" +``` + ```py float.round(self: float) -> int ``` diff --git a/pyproject.toml b/pyproject.toml index d806e6f..2b98ab4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,4 +23,4 @@ dev = [ "meson-python>=0.17.1", "ninja>=1.11.1.4", ] -test = ["pytest>=7.4.4", "typing-extensions>=4.7.1"] +test = ["pytest>=7.4.4", "pytest-asyncio>=0.21.0", "typing-extensions>=4.7.1"] diff --git a/src/extype/builtin_extensions/coroutine_ext.py b/src/extype/builtin_extensions/coroutine_ext.py new file mode 100644 index 0000000..44c0bf8 --- /dev/null +++ b/src/extype/builtin_extensions/coroutine_ext.py @@ -0,0 +1,90 @@ +from inspect import iscoroutine +from typing import Awaitable, Callable, Type, TypeVar, Union + +from ..extension_utils import extend_type_with, extension + + +__all__ = [ + "extend", + "CoroutineExtension" +] + + +_T = TypeVar("_T") +_U = TypeVar("_U") +_E = TypeVar("_E", bound=BaseException) + + +class CoroutineExtension: + """ + A class that contains methods to extend coroutine objects (async functions). + """ + + @extension + def then(self: Awaitable[_T], fn: Callable[[_T], Union[Awaitable[_U], _U]]) -> Awaitable[_U]: + """ + Maps the result of the awaitable via an optionally async function. + + If the function is async, it is awaited in the context of the wrapped awaitable. + + Args: + fn: A function that takes the result of the awaitable and returns a value or awaitable. + + Returns: + An awaitable that resolves to the result of the function. + """ + async def _then(): + result = fn(await self) + if iscoroutine(result): + return await result + return result + + return _then() + + @extension + def catch( + self: Awaitable[_T], + fn: Callable[[_E], Union[Awaitable[_U], _U]], + *, + exception: Type[_E] = Exception + ) -> Awaitable[Union[_T, _U]]: + """ + Catches an exception of the given type and calls the passed function with the caught exception. + + If no exception was raised inside the wrapped awaitable, the function will not be called. + The passed function can optionally return a value to be returned in case of an error. + The passed function can be either sync or async. If it's async, it is awaited. + + Args: + fn: A function that takes the exception and returns a value or awaitable. + exception: The type of exception to catch (default: Exception). + + Returns: + An awaitable that resolves to the original result or the result of the error handler. + """ + async def _catch(): + try: + return await self + except exception as e: + result = fn(e) + if iscoroutine(result): + return await result + return result + + return _catch() + + +def extend(): + """ + Applies the coroutine extensions to coroutine objects. + """ + # Get the coroutine type by creating a coroutine and getting its type + async def _dummy(): + pass + + coro = _dummy() + coroutine_type = type(coro) + extend_type_with(coroutine_type, CoroutineExtension) + + # Close the coroutine to avoid warnings + coro.close() diff --git a/src/extype/builtin_extensions/extend_all.py b/src/extype/builtin_extensions/extend_all.py index 59a4651..95176fe 100644 --- a/src/extype/builtin_extensions/extend_all.py +++ b/src/extype/builtin_extensions/extend_all.py @@ -6,6 +6,7 @@ function_ext, dict_ext, str_ext, + coroutine_ext, ) for ext in [ @@ -16,5 +17,6 @@ function_ext, dict_ext, str_ext, + coroutine_ext, ]: ext.extend() diff --git a/src/extype/builtin_extensions/meson.build b/src/extype/builtin_extensions/meson.build index 5c68d17..ef6032a 100644 --- a/src/extype/builtin_extensions/meson.build +++ b/src/extype/builtin_extensions/meson.build @@ -7,7 +7,8 @@ python_sources = [ 'int_ext.py', 'list_ext.py', 'seq_ext.py', - 'str_ext.py' + 'str_ext.py', + 'coroutine_ext.py' ] diff --git a/tests/test_builtin_extensions.py b/tests/test_builtin_extensions.py index ee7ede0..ed2aea3 100644 --- a/tests/test_builtin_extensions.py +++ b/tests/test_builtin_extensions.py @@ -1,5 +1,5 @@ import pytest -from extype.builtin_extensions import extend_all +from extype.builtin_extensions import extend_all # noqa: F401 # dict keys extension tests @@ -322,3 +322,107 @@ def test_str_to_float(): ################################################### + +# coroutine extensions tests + + +@pytest.mark.asyncio +async def test_coroutine_then_sync(): + async def foo(): + return 10 + + result = await foo().then(lambda x: x + 5) + assert result == 15 + + +@pytest.mark.asyncio +async def test_coroutine_then_async(): + async def foo(): + return 10 + + async def add_five(x): + return x + 5 + + result = await foo().then(add_five) + assert result == 15 + + +@pytest.mark.asyncio +async def test_coroutine_then_chaining(): + async def foo(): + return 10 + + async def add_five(x): + return x + 5 + + result = await foo().then(lambda x: x * 2).then(add_five).then(lambda x: x - 3) + assert result == 22 # (10 * 2) + 5 - 3 = 22 + + +@pytest.mark.asyncio +async def test_coroutine_catch_no_exception(): + async def foo(): + return 42 + + result = await foo().catch(lambda e: 0) + assert result == 42 + + +@pytest.mark.asyncio +async def test_coroutine_catch_with_exception(): + async def foo(): + raise ValueError("test error") + + result = await foo().catch(lambda e: 100, exception=ValueError) + assert result == 100 + + +@pytest.mark.asyncio +async def test_coroutine_catch_async_handler(): + async def foo(): + raise ValueError("test error") + + async def handle_error(e): + return 200 + + result = await foo().catch(handle_error, exception=ValueError) + assert result == 200 + + +@pytest.mark.asyncio +async def test_coroutine_catch_wrong_exception_type(): + async def foo(): + raise ValueError("test error") + + with pytest.raises(ValueError): + await foo().catch(lambda e: 0, exception=TypeError) + + +@pytest.mark.asyncio +async def test_coroutine_catch_default_exception(): + async def foo(): + raise RuntimeError("test error") + + result = await foo().catch(lambda e: 300) + assert result == 300 + + +@pytest.mark.asyncio +async def test_coroutine_then_and_catch_combined(): + async def foo(): + return 10 + + result = await foo().then(lambda x: x * 2).catch(lambda e: 0) + assert result == 20 + + +@pytest.mark.asyncio +async def test_coroutine_catch_and_then_combined(): + async def foo(): + raise ValueError("error") + + result = await foo().catch(lambda e: 50, exception=ValueError).then(lambda x: x + 10) + assert result == 60 + + +###################################################