Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions taskgroup/taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(self) -> None:
self._errors: list[BaseException] | None = []
self._base_error: BaseException | None = None
self._on_completed_fut: futures.Future[Literal[True]] | None = None
self._cancel_on_enter = False

def __repr__(self) -> str:
info = [""]
Expand All @@ -106,6 +107,8 @@ async def __aenter__(self) -> Self:
if self._parent_task is None:
raise RuntimeError(f"TaskGroup {self!r} cannot determine the parent task")
self._entered = True
if self._cancel_on_enter:
self.cancel()

return self

Expand Down Expand Up @@ -222,6 +225,10 @@ async def _aexit(
finally:
exc = None

# Suppress any remaining exception (exceptions deserving to be raised
# were raised above).
return True

def create_task(
self,
coro: _TaskCompatibleCoro[_T_co],
Expand Down Expand Up @@ -327,6 +334,33 @@ def _on_task_done(self, task: tasks.Task[object]) -> None:
self._parent_cancel_requested = True
self._parent_task.cancel()

def cancel(self):
"""Cancel the task group

`cancel()` will be called on any tasks in the group that aren't yet
done, as well as the parent (body) of the group. This will cause the
task group context manager to exit *without* `asyncio.CancelledError`
being raised.

If `cancel()` is called before entering the task group, the group will be
cancelled upon entry. This is useful for patterns where one piece of
code passes an unused TaskGroup instance to another in order to have
the ability to cancel anything run within the group.

`cancel()` is idempotent and may be called after the task group has
already exited.
"""
if not self._entered:
self._cancel_on_enter = True
return
if self._exiting and not self._tasks:
return
if not self._aborting:
self._abort()
if self._parent_task and not self._parent_cancel_requested:
self._parent_cancel_requested = True
self._parent_task.cancel()


class TaskGroup(_TaskGroup):
__stack: contextlib.AsyncExitStack
Expand Down
Loading