From 53e38a960d62c80a74d29e44f66d9b96465a5471 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Feb 2026 20:08:49 +0000 Subject: [PATCH 1/6] Initial plan From ff0ed7a6267485657e7d295e13d3f7a0b1708b08 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Feb 2026 20:12:33 +0000 Subject: [PATCH 2/6] Add async function extensions (then and catch) for coroutines Co-authored-by: neriyaco <42520786+neriyaco@users.noreply.github.com> --- README.md | 49 ++++++-- pyproject.toml | 2 +- .../builtin_extensions/coroutine_ext.py | 90 +++++++++++++++ src/extype/builtin_extensions/extend_all.py | 2 + src/extype/builtin_extensions/meson.build | 3 +- tests/test_builtin_extensions.py | 106 ++++++++++++++++++ 6 files changed, 241 insertions(+), 11 deletions(-) create mode 100644 src/extype/builtin_extensions/coroutine_ext.py diff --git a/README.md b/README.md index 4978476..fb97f34 100644 --- a/README.md +++ b/README.md @@ -47,15 +47,16 @@ list_ext.extend() Currently, we provide the following extensions: -| file | extended types | -|:---------------:|:----------------------------------:| -| dict_ext.py | dict_keys, dict_values, dict_items | -| float_ext.py | float | -| function_ext.py | FunctionType, LambdaType | -| int_ext.py | int | -| list_ext.py | list | -| seq_ext.py | map, filter, range, zip | -| str_ext.py | str | +| file | extended types | +|:----------------:|:----------------------------------:| +| coroutine_ext.py | coroutine (async functions) | +| dict_ext.py | dict_keys, dict_values, dict_items | +| float_ext.py | float | +| function_ext.py | FunctionType, LambdaType | +| int_ext.py | int | +| list_ext.py | list | +| seq_ext.py | map, filter, range, zip | +| str_ext.py | str | @@ -214,6 +215,36 @@ list.last(self: List[T]) -> T, raise IndexError ``` Returns the last element in the list, or raises `IndexError` if the list is empty. +```py +coroutine.then(self: Awaitable[T], fn: Callable[[T], Awaitable[U] | U]) -> Awaitable[U] +``` +Maps the result of the awaitable via an optionally async function. If the function is async, it is awaited in the context of the wrapped awaitable. + +Example: +```py +async def get_value(): + return 10 + +result = await get_value().then(lambda x: x * 2) # result is 20 +``` + +```py +coroutine.catch(self: Awaitable[T], fn: Callable[[E], Awaitable[U] | U], *, exception: type[E] = Exception) -> Awaitable[T | U] +``` +Catches an exception of the given type and calls the passed function with the caught exception. + +If no exception was raised inside the wrapped awaitable, the function will not be called. +The passed function can optionally return a value to be returned in case of an error. +The passed function can be either sync or async. If it's async, it is awaited in the context of the wrapped awaitable. + +Example: +```py +async def might_fail(): + raise ValueError("error") + +result = await might_fail().catch(lambda e: "default", exception=ValueError) # result is "default" +``` + ```py float.round(self: float) -> int ``` diff --git a/pyproject.toml b/pyproject.toml index d806e6f..2b98ab4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,4 +23,4 @@ dev = [ "meson-python>=0.17.1", "ninja>=1.11.1.4", ] -test = ["pytest>=7.4.4", "typing-extensions>=4.7.1"] +test = ["pytest>=7.4.4", "pytest-asyncio>=0.21.0", "typing-extensions>=4.7.1"] diff --git a/src/extype/builtin_extensions/coroutine_ext.py b/src/extype/builtin_extensions/coroutine_ext.py new file mode 100644 index 0000000..293ff1c --- /dev/null +++ b/src/extype/builtin_extensions/coroutine_ext.py @@ -0,0 +1,90 @@ +from inspect import iscoroutinefunction +from typing import Awaitable, Callable, TypeVar + +from ..extension_utils import extend_type_with, extension + + +__all__ = [ + "extend", + "CoroutineExtension" +] + + +_T = TypeVar("_T") +_U = TypeVar("_U") +_E = TypeVar("_E", bound=BaseException) + + +class CoroutineExtension: + """ + A class that contains methods to extend coroutine objects (async functions). + """ + + @extension + def then(self: Awaitable[_T], fn: Callable[[_T], Awaitable[_U] | _U]) -> Awaitable[_U]: + """ + Maps the result of the awaitable via an optionally async function. + + If the function is async, it is awaited in the context of the wrapped awaitable. + + Args: + fn: A function that takes the result of the awaitable and returns a value or awaitable. + + Returns: + An awaitable that resolves to the result of the function. + """ + async def _then(): + result = fn(await self) + if iscoroutinefunction(fn): + return await result + return result + + return _then() + + @extension + def catch( + self: Awaitable[_T], + fn: Callable[[_E], Awaitable[_U] | _U], + *, + exception: type[_E] = Exception + ) -> Awaitable[_T | _U]: + """ + Catches an exception of the given type and calls the passed function with the caught exception. + + If no exception was raised inside the wrapped awaitable, the function will not be called. + The passed function can optionally return a value to be returned in case of an error. + The passed function can be either sync or async. If it's async, it is awaited. + + Args: + fn: A function that takes the exception and returns a value or awaitable. + exception: The type of exception to catch (default: Exception). + + Returns: + An awaitable that resolves to the original result or the result of the error handler. + """ + async def _catch(): + try: + return await self + except exception as e: + result = fn(e) + if iscoroutinefunction(fn): + return await result + return result + + return _catch() + + +def extend(): + """ + Applies the coroutine extensions to coroutine objects. + """ + # Get the coroutine type by creating a coroutine and getting its type + async def _dummy(): + pass + + coro = _dummy() + coroutine_type = type(coro) + extend_type_with(coroutine_type, CoroutineExtension) + + # Close the coroutine to avoid warnings + coro.close() diff --git a/src/extype/builtin_extensions/extend_all.py b/src/extype/builtin_extensions/extend_all.py index 59a4651..95176fe 100644 --- a/src/extype/builtin_extensions/extend_all.py +++ b/src/extype/builtin_extensions/extend_all.py @@ -6,6 +6,7 @@ function_ext, dict_ext, str_ext, + coroutine_ext, ) for ext in [ @@ -16,5 +17,6 @@ function_ext, dict_ext, str_ext, + coroutine_ext, ]: ext.extend() diff --git a/src/extype/builtin_extensions/meson.build b/src/extype/builtin_extensions/meson.build index 5c68d17..ef6032a 100644 --- a/src/extype/builtin_extensions/meson.build +++ b/src/extype/builtin_extensions/meson.build @@ -7,7 +7,8 @@ python_sources = [ 'int_ext.py', 'list_ext.py', 'seq_ext.py', - 'str_ext.py' + 'str_ext.py', + 'coroutine_ext.py' ] diff --git a/tests/test_builtin_extensions.py b/tests/test_builtin_extensions.py index ee7ede0..62c21b2 100644 --- a/tests/test_builtin_extensions.py +++ b/tests/test_builtin_extensions.py @@ -322,3 +322,109 @@ def test_str_to_float(): ################################################### + +# coroutine extensions tests + + +@pytest.mark.asyncio +async def test_coroutine_then_sync(): + async def foo(): + return 10 + + result = await foo().then(lambda x: x + 5) + assert result == 15 + + +@pytest.mark.asyncio +async def test_coroutine_then_async(): + async def foo(): + return 10 + + async def add_five(x): + return x + 5 + + result = await foo().then(add_five) + assert result == 15 + + +@pytest.mark.asyncio +async def test_coroutine_then_chaining(): + async def foo(): + return 10 + + async def add_five(x): + return x + 5 + + result = await foo().then(lambda x: x * 2).then(add_five).then(lambda x: x - 3) + assert result == 22 # (10 * 2) + 5 - 3 = 22 + + +@pytest.mark.asyncio +async def test_coroutine_catch_no_exception(): + async def foo(): + return 42 + + result = await foo().catch(lambda e: 0) + assert result == 42 + + +@pytest.mark.asyncio +async def test_coroutine_catch_with_exception(): + async def foo(): + raise ValueError("test error") + + result = await foo().catch(lambda e: 100, exception=ValueError) + assert result == 100 + + +@pytest.mark.asyncio +async def test_coroutine_catch_async_handler(): + async def foo(): + raise ValueError("test error") + + async def handle_error(e): + return 200 + + result = await foo().catch(handle_error, exception=ValueError) + assert result == 200 + + +@pytest.mark.asyncio +async def test_coroutine_catch_wrong_exception_type(): + async def foo(): + raise ValueError("test error") + + with pytest.raises(ValueError): + await foo().catch(lambda e: 0, exception=TypeError) + + +@pytest.mark.asyncio +async def test_coroutine_catch_default_exception(): + async def foo(): + raise RuntimeError("test error") + + result = await foo().catch(lambda e: 300) + assert result == 300 + + +@pytest.mark.asyncio +async def test_coroutine_then_and_catch_combined(): + async def foo(): + return 10 + + result = await foo().then(lambda x: x * 2).catch(lambda e: 0) + assert result == 20 + + +@pytest.mark.asyncio +async def test_coroutine_catch_and_then_combined(): + async def foo(): + raise ValueError("error") + + result = await foo().catch(lambda e: 50, exception=ValueError).then(lambda x: x + 10) + assert result == 60 + + +################################################### + + From f79f0628f6099a4112a3b2f20ed5fc772dc32a9e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Feb 2026 20:14:07 +0000 Subject: [PATCH 3/6] Fix linting issues (trailing whitespace) Co-authored-by: neriyaco <42520786+neriyaco@users.noreply.github.com> --- .../builtin_extensions/coroutine_ext.py | 20 ++++++------- tests/test_builtin_extensions.py | 30 +++++++++---------- 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/src/extype/builtin_extensions/coroutine_ext.py b/src/extype/builtin_extensions/coroutine_ext.py index 293ff1c..8bfa7be 100644 --- a/src/extype/builtin_extensions/coroutine_ext.py +++ b/src/extype/builtin_extensions/coroutine_ext.py @@ -24,12 +24,12 @@ class CoroutineExtension: def then(self: Awaitable[_T], fn: Callable[[_T], Awaitable[_U] | _U]) -> Awaitable[_U]: """ Maps the result of the awaitable via an optionally async function. - + If the function is async, it is awaited in the context of the wrapped awaitable. - + Args: fn: A function that takes the result of the awaitable and returns a value or awaitable. - + Returns: An awaitable that resolves to the result of the function. """ @@ -38,7 +38,7 @@ async def _then(): if iscoroutinefunction(fn): return await result return result - + return _then() @extension @@ -50,15 +50,15 @@ def catch( ) -> Awaitable[_T | _U]: """ Catches an exception of the given type and calls the passed function with the caught exception. - + If no exception was raised inside the wrapped awaitable, the function will not be called. The passed function can optionally return a value to be returned in case of an error. The passed function can be either sync or async. If it's async, it is awaited. - + Args: fn: A function that takes the exception and returns a value or awaitable. exception: The type of exception to catch (default: Exception). - + Returns: An awaitable that resolves to the original result or the result of the error handler. """ @@ -70,7 +70,7 @@ async def _catch(): if iscoroutinefunction(fn): return await result return result - + return _catch() @@ -81,10 +81,10 @@ def extend(): # Get the coroutine type by creating a coroutine and getting its type async def _dummy(): pass - + coro = _dummy() coroutine_type = type(coro) extend_type_with(coroutine_type, CoroutineExtension) - + # Close the coroutine to avoid warnings coro.close() diff --git a/tests/test_builtin_extensions.py b/tests/test_builtin_extensions.py index 62c21b2..ed2aea3 100644 --- a/tests/test_builtin_extensions.py +++ b/tests/test_builtin_extensions.py @@ -1,5 +1,5 @@ import pytest -from extype.builtin_extensions import extend_all +from extype.builtin_extensions import extend_all # noqa: F401 # dict keys extension tests @@ -330,7 +330,7 @@ def test_str_to_float(): async def test_coroutine_then_sync(): async def foo(): return 10 - + result = await foo().then(lambda x: x + 5) assert result == 15 @@ -339,10 +339,10 @@ async def foo(): async def test_coroutine_then_async(): async def foo(): return 10 - + async def add_five(x): return x + 5 - + result = await foo().then(add_five) assert result == 15 @@ -351,10 +351,10 @@ async def add_five(x): async def test_coroutine_then_chaining(): async def foo(): return 10 - + async def add_five(x): return x + 5 - + result = await foo().then(lambda x: x * 2).then(add_five).then(lambda x: x - 3) assert result == 22 # (10 * 2) + 5 - 3 = 22 @@ -363,7 +363,7 @@ async def add_five(x): async def test_coroutine_catch_no_exception(): async def foo(): return 42 - + result = await foo().catch(lambda e: 0) assert result == 42 @@ -372,7 +372,7 @@ async def foo(): async def test_coroutine_catch_with_exception(): async def foo(): raise ValueError("test error") - + result = await foo().catch(lambda e: 100, exception=ValueError) assert result == 100 @@ -381,10 +381,10 @@ async def foo(): async def test_coroutine_catch_async_handler(): async def foo(): raise ValueError("test error") - + async def handle_error(e): return 200 - + result = await foo().catch(handle_error, exception=ValueError) assert result == 200 @@ -393,7 +393,7 @@ async def handle_error(e): async def test_coroutine_catch_wrong_exception_type(): async def foo(): raise ValueError("test error") - + with pytest.raises(ValueError): await foo().catch(lambda e: 0, exception=TypeError) @@ -402,7 +402,7 @@ async def foo(): async def test_coroutine_catch_default_exception(): async def foo(): raise RuntimeError("test error") - + result = await foo().catch(lambda e: 300) assert result == 300 @@ -411,7 +411,7 @@ async def foo(): async def test_coroutine_then_and_catch_combined(): async def foo(): return 10 - + result = await foo().then(lambda x: x * 2).catch(lambda e: 0) assert result == 20 @@ -420,11 +420,9 @@ async def foo(): async def test_coroutine_catch_and_then_combined(): async def foo(): raise ValueError("error") - + result = await foo().catch(lambda e: 50, exception=ValueError).then(lambda x: x + 10) assert result == 60 ################################################### - - From 14145769bb26564961107caceb075273b0264acb Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Feb 2026 20:15:08 +0000 Subject: [PATCH 4/6] Fix code review issue: use iscoroutine instead of iscoroutinefunction Co-authored-by: neriyaco <42520786+neriyaco@users.noreply.github.com> --- src/extype/builtin_extensions/coroutine_ext.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/extype/builtin_extensions/coroutine_ext.py b/src/extype/builtin_extensions/coroutine_ext.py index 8bfa7be..6525def 100644 --- a/src/extype/builtin_extensions/coroutine_ext.py +++ b/src/extype/builtin_extensions/coroutine_ext.py @@ -1,4 +1,4 @@ -from inspect import iscoroutinefunction +from inspect import iscoroutine from typing import Awaitable, Callable, TypeVar from ..extension_utils import extend_type_with, extension @@ -35,7 +35,7 @@ def then(self: Awaitable[_T], fn: Callable[[_T], Awaitable[_U] | _U]) -> Awaitab """ async def _then(): result = fn(await self) - if iscoroutinefunction(fn): + if iscoroutine(result): return await result return result @@ -67,7 +67,7 @@ async def _catch(): return await self except exception as e: result = fn(e) - if iscoroutinefunction(fn): + if iscoroutine(result): return await result return result From b545dd14d90b0ddb703cd92bf59e0e19afde46d5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Feb 2026 10:31:37 +0000 Subject: [PATCH 5/6] Fix CI: install pytest-asyncio in all workflow test steps Co-authored-by: binyamin555 <42520501+binyamin555@users.noreply.github.com> --- .github/workflows/python-test-linux.yml | 4 ++-- .github/workflows/python-test-macos.yml | 4 ++-- .github/workflows/python-test-windows.yml | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/python-test-linux.yml b/.github/workflows/python-test-linux.yml index bc1c9c5..55f47c8 100644 --- a/.github/workflows/python-test-linux.yml +++ b/.github/workflows/python-test-linux.yml @@ -39,7 +39,7 @@ jobs: pip install dist/*.whl - name: Test with pytest run: | - python -m pip install typing-extensions + python -m pip install typing-extensions pytest-asyncio pytest build: @@ -71,6 +71,6 @@ jobs: pip install dist/*.whl - name: Test with pytest run: | - python -m pip install typing-extensions + python -m pip install typing-extensions pytest-asyncio pytest diff --git a/.github/workflows/python-test-macos.yml b/.github/workflows/python-test-macos.yml index 55d5488..6a0d1c6 100644 --- a/.github/workflows/python-test-macos.yml +++ b/.github/workflows/python-test-macos.yml @@ -39,7 +39,7 @@ jobs: pip install dist/*.whl - name: Test with pytest run: | - python -m pip install typing-extensions + python -m pip install typing-extensions pytest-asyncio pytest build: @@ -71,6 +71,6 @@ jobs: pip install dist/*.whl - name: Test with pytest run: | - python -m pip install typing-extensions + python -m pip install typing-extensions pytest-asyncio pytest diff --git a/.github/workflows/python-test-windows.yml b/.github/workflows/python-test-windows.yml index eb3545b..7674e42 100644 --- a/.github/workflows/python-test-windows.yml +++ b/.github/workflows/python-test-windows.yml @@ -46,6 +46,6 @@ jobs: } - name: Test with pytest run: | - python -m pip install typing-extensions + python -m pip install typing-extensions pytest-asyncio pytest From 2e969bd239e8c44f8917c3d5cf23d3ca6042a330 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Feb 2026 10:41:00 +0000 Subject: [PATCH 6/6] Fix Python 3.7+ compatibility: use Union and Type instead of | and type Co-authored-by: binyamin555 <42520501+binyamin555@users.noreply.github.com> --- src/extype/builtin_extensions/coroutine_ext.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/extype/builtin_extensions/coroutine_ext.py b/src/extype/builtin_extensions/coroutine_ext.py index 6525def..44c0bf8 100644 --- a/src/extype/builtin_extensions/coroutine_ext.py +++ b/src/extype/builtin_extensions/coroutine_ext.py @@ -1,5 +1,5 @@ from inspect import iscoroutine -from typing import Awaitable, Callable, TypeVar +from typing import Awaitable, Callable, Type, TypeVar, Union from ..extension_utils import extend_type_with, extension @@ -21,7 +21,7 @@ class CoroutineExtension: """ @extension - def then(self: Awaitable[_T], fn: Callable[[_T], Awaitable[_U] | _U]) -> Awaitable[_U]: + def then(self: Awaitable[_T], fn: Callable[[_T], Union[Awaitable[_U], _U]]) -> Awaitable[_U]: """ Maps the result of the awaitable via an optionally async function. @@ -44,10 +44,10 @@ async def _then(): @extension def catch( self: Awaitable[_T], - fn: Callable[[_E], Awaitable[_U] | _U], + fn: Callable[[_E], Union[Awaitable[_U], _U]], *, - exception: type[_E] = Exception - ) -> Awaitable[_T | _U]: + exception: Type[_E] = Exception + ) -> Awaitable[Union[_T, _U]]: """ Catches an exception of the given type and calls the passed function with the caught exception.