diff --git a/docs/source/changes.md b/docs/source/changes.md
index f7a82dd..99e69ea 100644
--- a/docs/source/changes.md
+++ b/docs/source/changes.md
@@ -30,12 +30,13 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
   or processes automatically.
 - {pull}`96` handles local paths with remote executors. `PathNode`s are not supported
   as dependencies or products (except for return annotations).
-- {pull}`99` changes that all tasks that are ready are being scheduled. It improves
-  interactions with adaptive scaling. {issue}`98` does handle the resulting issues: no
-  strong adherence to priorities, no pending status.
+- {pull}`99` changes the scheduling so that all ready tasks are submitted. It improves
+  interactions with adaptive scaling. {issue}`98` tracks the resulting issues: no
+  strong adherence to priorities, no pending status.
 - {pull}`100` adds project management with rye.
 - {pull}`101` adds syncing for local paths as dependencies or products in remote
   environments with the same OS.
+- {pull}`102` implements a pending status for scheduled but not yet started tasks.
 - {pull}`106` fixes {pull}`99` such that only when there are coiled functions, all
   ready tasks are submitted.
 - {pull}`107` removes status from `pytask_execute_task_log_start` hook call.
diff --git a/docs/source/coiled.md b/docs/source/coiled.md
index 69408b5..0253c97 100644
--- a/docs/source/coiled.md
+++ b/docs/source/coiled.md
@@ -1,7 +1,7 @@
 # coiled
 
 ```{caution}
-Currently, the coiled backend can only be used if your workflow code is organized in a
+Currently, the coiled backend can only be used if your workflow code is organized as a
 package due to how pytask imports your code and dask serializes task functions
 ([issue](https://github.com/dask/distributed/issues/8607)).
``` diff --git a/pyproject.toml b/pyproject.toml index 8520904..d010638 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ docs = [ "matplotlib", "myst-parser", "nbsphinx", - "sphinx", + "sphinx<9", "sphinx-autobuild", "sphinx-click", "sphinx-copybutton", diff --git a/src/pytask_parallel/execute.py b/src/pytask_parallel/execute.py index ef6529a..60c6c66 100644 --- a/src/pytask_parallel/execute.py +++ b/src/pytask_parallel/execute.py @@ -2,8 +2,13 @@ from __future__ import annotations +import os +import queue import sys import time +from collections import deque +from contextlib import ExitStack +from multiprocessing import Manager from typing import TYPE_CHECKING from typing import Any from typing import cast @@ -17,6 +22,7 @@ from pytask import PTask from pytask import PythonNode from pytask import Session +from pytask import TaskExecutionStatus from pytask import console from pytask import get_marks from pytask import hookimpl @@ -24,6 +30,7 @@ from pytask.tree_util import tree_map from pytask.tree_util import tree_structure +from pytask_parallel.backends import ParallelBackend from pytask_parallel.backends import WorkerType from pytask_parallel.backends import registry from pytask_parallel.typing import CarryOverPath @@ -33,7 +40,9 @@ from pytask_parallel.utils import parse_future_result if TYPE_CHECKING: + from collections.abc import Callable from concurrent.futures import Future + from multiprocessing.managers import SyncManager from pytask_parallel.wrappers import WrapperResult @@ -53,6 +62,33 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091 __tracebackhide__ = True reports = session.execution_reports running_tasks: dict[str, Future[Any]] = {} + running_try_last: set[str] = set() + queued_try_first_tasks: deque[str] = deque() + queued_tasks: deque[str] = deque() + queued_try_last_tasks: deque[str] = deque() + sleeper = _Sleeper() + debug_status = _is_debug_status_enabled() + + # Create a shared queue to differentiate between running and pending tasks for + # some parallel backends. + if session.config["parallel_backend"] in ( + ParallelBackend.PROCESSES, + ParallelBackend.LOKY, + ): + manager_cls: Callable[[], SyncManager] | type[ExitStack] = Manager + start_execution_state = TaskExecutionStatus.PENDING + status_queue_factory = "manager" + elif session.config["parallel_backend"] == ParallelBackend.THREADS: + manager_cls = ExitStack + start_execution_state = TaskExecutionStatus.PENDING + status_queue_factory = "simple" + else: + manager_cls = ExitStack + start_execution_state = TaskExecutionStatus.RUNNING + status_queue_factory = None + + # Get the live execution manager from the registry if it exists. 
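+    # ``get_plugin`` returns ``None`` when no plugin is registered under this
+    # name; in that case the pending-to-running updates below are skipped
+    # unless debug logging is enabled.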
+ live_execution = session.config["pm"].get_plugin("live_execution") any_coiled_task = any(is_coiled_function(task) for task in session.tasks) # The executor can only be created after the collection to give users the @@ -60,26 +96,47 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091 session.config["_parallel_executor"] = registry.get_parallel_backend( session.config["parallel_backend"], n_workers=session.config["n_workers"] ) + with session.config["_parallel_executor"], manager_cls() as manager: + if status_queue_factory == "manager": + session.config["_status_queue"] = manager.Queue() # type: ignore[union-attr] + elif status_queue_factory == "simple": + session.config["_status_queue"] = queue.SimpleQueue() - with session.config["_parallel_executor"]: - sleeper = _Sleeper() + if live_execution: + live_execution.initial_status = start_execution_state i = 0 + prefetch_factor = ( + 2 + if session.config["parallel_backend"] + in ( + ParallelBackend.PROCESSES, + ParallelBackend.LOKY, + ParallelBackend.THREADS, + ) + else 1 + ) + use_prefetch_queue = prefetch_factor > 1 while session.scheduler.is_active(): try: newly_collected_reports = [] + did_enqueue = False + did_submit = False # If there is any coiled function, the user probably wants to exploit # adaptive scaling. Thus, we need to submit all ready tasks. # Unfortunately, all submitted tasks are shown as running although some # are pending. # - # Without coiled functions, we submit as many tasks as there are - # available workers since we cannot reliably detect a pending status. - # - # See #98 for more information. if any_coiled_task: n_new_tasks = 10_000 + elif use_prefetch_queue: + n_new_tasks = (session.config["n_workers"] * prefetch_factor) - ( + len(running_tasks) + + len(queued_try_first_tasks) + + len(queued_tasks) + + len(queued_try_last_tasks) + ) else: n_new_tasks = session.config["n_workers"] - len(running_tasks) @@ -89,31 +146,96 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091 else [] ) - for task_name in ready_tasks: - task = session.dag.nodes[task_name]["task"] - session.hook.pytask_execute_task_log_start( - session=session, task=task - ) - try: - session.hook.pytask_execute_task_setup( - session=session, task=task + if use_prefetch_queue: + for task_signature in ready_tasks: + task = session.dag.nodes[task_signature]["task"] + if debug_status: + _log_status("PENDING", task_signature) + session.hook.pytask_execute_task_log_start( + session=session, + task=task, + status=start_execution_state, ) - running_tasks[task_name] = session.hook.pytask_execute_task( - session=session, task=task + if get_marks(task, "try_first"): + queued_try_first_tasks.append(task_signature) + elif get_marks(task, "try_last"): + queued_try_last_tasks.append(task_signature) + else: + queued_tasks.append(task_signature) + did_enqueue = True + + def _can_run_try_last() -> bool: + return not ( + queued_try_first_tasks + or queued_tasks + or (len(running_tasks) > len(running_try_last)) ) - sleeper.reset() - except Exception: # noqa: BLE001 - report = ExecutionReport.from_task_and_exception( - task, sys.exc_info() + + while len(running_tasks) < session.config["n_workers"]: + if queued_try_first_tasks: + task_signature = queued_try_first_tasks.popleft() + elif queued_tasks: + task_signature = queued_tasks.popleft() + elif queued_try_last_tasks and _can_run_try_last(): + task_signature = queued_try_last_tasks.popleft() + else: + break + task = session.dag.nodes[task_signature]["task"] + try: 
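+                            # The setup hook may raise, e.g. for unresolvable
+                            # dependencies; the failure is converted into a
+                            # report below and the task is marked as done.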
+ session.hook.pytask_execute_task_setup( + session=session, task=task + ) + running_tasks[task_signature] = ( + session.hook.pytask_execute_task( + session=session, task=task + ) + ) + if get_marks(task, "try_last"): + running_try_last.add(task_signature) + sleeper.reset() + did_submit = True + except Exception: # noqa: BLE001 + report = ExecutionReport.from_task_and_exception( + task, sys.exc_info() + ) + newly_collected_reports.append(report) + session.scheduler.done(task_signature) + else: + for task_signature in ready_tasks: + task = session.dag.nodes[task_signature]["task"] + if debug_status: + _log_status( + "PENDING" + if start_execution_state == TaskExecutionStatus.PENDING + else "RUNNING", + task_signature, + ) + session.hook.pytask_execute_task_log_start( + session=session, task=task, status=start_execution_state ) - newly_collected_reports.append(report) - session.scheduler.done(task_name) + try: + session.hook.pytask_execute_task_setup( + session=session, task=task + ) + running_tasks[task_signature] = ( + session.hook.pytask_execute_task( + session=session, task=task + ) + ) + sleeper.reset() + did_submit = True + except Exception: # noqa: BLE001 + report = ExecutionReport.from_task_and_exception( + task, sys.exc_info() + ) + newly_collected_reports.append(report) + session.scheduler.done(task_signature) - if not ready_tasks: + if not ready_tasks and not did_enqueue and not did_submit: sleeper.increment() - for task_name in list(running_tasks): - future = running_tasks[task_name] + for task_signature in list(running_tasks): + future = running_tasks[task_signature] if future.done(): wrapper_result = parse_future_result(future) @@ -129,17 +251,18 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091 ) if wrapper_result.exc_info is not None: - task = session.dag.nodes[task_name]["task"] + task = session.dag.nodes[task_signature]["task"] newly_collected_reports.append( ExecutionReport.from_task_and_exception( task, wrapper_result.exc_info, # type: ignore[arg-type] ) ) - running_tasks.pop(task_name) - session.scheduler.done(task_name) + running_tasks.pop(task_signature) + running_try_last.discard(task_signature) + session.scheduler.done(task_signature) else: - task = session.dag.nodes[task_name]["task"] + task = session.dag.nodes[task_signature]["task"] _update_carry_over_products( task, wrapper_result.carry_over_products ) @@ -155,9 +278,29 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091 else: report = ExecutionReport.from_task(task) - running_tasks.pop(task_name) + running_tasks.pop(task_signature) + running_try_last.discard(task_signature) newly_collected_reports.append(report) - session.scheduler.done(task_name) + session.scheduler.done(task_signature) + + # Check if tasks are not pending but running and update the live + # status. 
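+                    # Workers put a task's signature on the status queue right
+                    # before executing it, so draining the queue reveals which
+                    # tasks have left the pending state.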
+                    if (
+                        live_execution or debug_status
+                    ) and "_status_queue" in session.config:
+                        status_queue = session.config["_status_queue"]
+                        while True:
+                            try:
+                                started_task = status_queue.get(block=False)
+                            except queue.Empty:
+                                break
+                            if started_task in running_tasks:
+                                if live_execution:
+                                    live_execution.update_task(
+                                        started_task, status=TaskExecutionStatus.RUNNING
+                                    )
+                                if debug_status:
+                                    _log_status("RUNNING", started_task)

             for report in newly_collected_reports:
                 session.hook.pytask_execute_task_process_report(
@@ -239,6 +382,7 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
             kwargs=kwargs,
             remote=remote,
             session_filterwarnings=session.config["filterwarnings"],
+            status_queue=session.config.get("_status_queue"),
             show_locals=session.config["show_locals"],
             task_filterwarnings=get_marks(task, "filterwarnings"),
         )
@@ -248,7 +392,11 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
         from pytask_parallel.wrappers import wrap_task_in_thread  # noqa: PLC0415
 
         return session.config["_parallel_executor"].submit(
-            wrap_task_in_thread, task=task, remote=False, **kwargs
+            wrap_task_in_thread,
+            task=task,
+            remote=False,
+            status_queue=session.config.get("_status_queue"),
+            **kwargs,
         )
 
     msg = f"Unknown worker type {worker_type}"
@@ -261,6 +409,17 @@ def pytask_unconfigure() -> None:
     registry.reset()
 
 
+def _is_debug_status_enabled() -> bool:
+    """Return whether to emit debug status updates."""
+    value = os.environ.get("PYTASK_PARALLEL_DEBUG_STATUS", "")
+    return value.strip().lower() in {"1", "true", "yes", "on"}
+
+
+def _log_status(status: str, task_signature: str) -> None:
+    """Log a status transition without interpreting the prefix as rich markup."""
+    console.print(f"[pytask-parallel] {status}: {task_signature}", markup=False)
+
+
 def _update_carry_over_products(
     task: PTask, carry_over_products: PyTree[CarryOverPath | PythonNode | None] | None
 ) -> None:
diff --git a/src/pytask_parallel/wrappers.py b/src/pytask_parallel/wrappers.py
index e663e8a..72be158 100644
--- a/src/pytask_parallel/wrappers.py
+++ b/src/pytask_parallel/wrappers.py
@@ -37,6 +37,8 @@ if TYPE_CHECKING:
     from collections.abc import Callable
+    from queue import Queue
+    from queue import SimpleQueue
     from types import TracebackType
 
     from pytask import Mark
@@ -57,7 +59,13 @@ class WrapperResult:
     stderr: str
 
 
-def wrap_task_in_thread(task: PTask, *, remote: bool, **kwargs: Any) -> WrapperResult:
+def wrap_task_in_thread(
+    task: PTask,
+    *,
+    remote: bool,
+    status_queue: Queue[str] | SimpleQueue[str] | None = None,
+    **kwargs: Any,
+) -> WrapperResult:
     """Mock execution function such that it returns the same as for processes.
 
     The function for processes returns ``warning_reports`` and an ``exception``. With
@@ -66,6 +74,11 @@ def wrap_task_in_thread(task: PTask, *, remote: bool, **kwargs: Any) -> WrapperR
     """
     __tracebackhide__ = True
+
+    # Add task to the status queue to indicate that it is currently being executed.
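+    # The queue is ``None`` for backends that start tasks in the running
+    # state, such as the remote dask- and coiled-based executors.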
+ if status_queue is not None: + status_queue.put(task.signature) + try: out = task.function(**kwargs) except Exception: # noqa: BLE001 @@ -81,6 +94,7 @@ def wrap_task_in_thread(task: PTask, *, remote: bool, **kwargs: Any) -> WrapperR else: _handle_function_products(task, out, remote=remote) exc_info = None + return WrapperResult( carry_over_products=None, warning_reports=[], @@ -97,6 +111,7 @@ def wrap_task_in_process( # noqa: PLR0913 kwargs: dict[str, Any], remote: bool, session_filterwarnings: tuple[str, ...], + status_queue: Queue[str] | SimpleQueue[str] | None = None, show_locals: bool, task_filterwarnings: tuple[Mark, ...], ) -> WrapperResult: @@ -109,6 +124,10 @@ def wrap_task_in_process( # noqa: PLR0913 # Hide this function from tracebacks. __tracebackhide__ = True + # Add task to the status queue to indicate that it is currently being executed. + if status_queue is not None: + status_queue.put(task.signature) + # Patch set_trace and breakpoint to show a better error message. _patch_set_trace_and_breakpoint() @@ -217,7 +236,7 @@ def _render_traceback_to_string( traceback = Traceback(exc_info, show_locals=show_locals) segments = console.render(cast("Any", traceback), options=console_options) text = "".join(segment.text for segment in segments) - return (*exc_info[:2], text) # ty: ignore[invalid-return-type] + return (*exc_info[:2], text) def _handle_function_products(
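For reference, the pending/running handshake that `execute.py` and `wrappers.py` implement above follows the pattern sketched below. This is a minimal, self-contained illustration, not pytask-parallel's actual API; `run_task` and `schedule` are hypothetical names. A worker announces a task by putting its signature on a shared queue before executing it, and the scheduling loop drains that queue without blocking to promote tasks from pending to running.

```python
from __future__ import annotations

import queue
import time
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Manager


def run_task(signature: str, status_queue) -> str:
    """Hypothetical worker: announce the start of the task, then run it."""
    # Putting the signature on the queue is what flips PENDING to RUNNING.
    status_queue.put(signature)
    time.sleep(0.1)  # Stand-in for the real task function.
    return signature


def schedule(signatures: list[str]) -> None:
    """Hypothetical scheduling loop mirroring the drain in execute.py."""
    with Manager() as manager, ProcessPoolExecutor(max_workers=2) as pool:
        # A manager queue proxy is picklable and therefore survives the
        # round-trip to worker processes; a plain multiprocessing.Queue
        # cannot be passed through ProcessPoolExecutor.submit.
        status_queue = manager.Queue()
        futures = {s: pool.submit(run_task, s, status_queue) for s in signatures}
        pending = set(signatures)
        while futures:
            # Drain without blocking; queue.Empty ends the drain.
            while True:
                try:
                    started = status_queue.get(block=False)
                except queue.Empty:
                    break
                if started in pending:
                    pending.discard(started)
                    print(f"RUNNING: {started}")
            for signature, future in list(futures.items()):
                if future.done():
                    futures.pop(signature)
                    print(f"DONE: {future.result()}")
            time.sleep(0.02)


if __name__ == "__main__":
    schedule(["task_a", "task_b", "task_c", "task_d"])
```

The diff above applies the same pattern per backend: a `Manager().Queue()` for the `processes` and `loky` backends, a `queue.SimpleQueue` for threads (where no pickling is needed), and no queue at all for the remote backends, whose tasks are reported as running from submission.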