diff --git a/docs/source/changes.md b/docs/source/changes.md
index f7a82dd..85c2e52 100644
--- a/docs/source/changes.md
+++ b/docs/source/changes.md
@@ -11,6 +11,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
 - {pull}`130` switches type checking to ty.
 - {pull}`131` updates pre-commit hooks.
 - {pull}`132` removes the tox configuration in favor of uv and just.
+- {pull}`137` fixes pickling errors in parallel workers when task modules contain
+  non-picklable globals. Fixes {issue}`136`.
 
 ## 0.5.1 - 2025-03-09
 
diff --git a/pyproject.toml b/pyproject.toml
index 8520904..d010638 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,7 +33,7 @@ docs = [
     "matplotlib",
     "myst-parser",
     "nbsphinx",
-    "sphinx",
+    "sphinx<9",
     "sphinx-autobuild",
     "sphinx-click",
     "sphinx-copybutton",
diff --git a/requirements-dev.lock b/requirements-dev.lock
deleted file mode 100644
index 70e3e4a..0000000
--- a/requirements-dev.lock
+++ /dev/null
@@ -1,86 +0,0 @@
-# generated by rye
-# use `rye lock` or `rye sync` to update this lockfile
-#
-# last locked with the following flags:
-#   pre: false
-#   features: []
-#   all-features: false
-#   with-sources: false
-
--e file:.
-aiobotocore==2.12.3
-    # via s3fs
-aiohttp==3.9.4
-    # via aiobotocore
-    # via s3fs
-aioitertools==0.11.0
-    # via aiobotocore
-aiosignal==1.3.1
-    # via aiohttp
-attrs==23.2.0
-    # via aiohttp
-    # via pytask
-    # via pytask-parallel
-botocore==1.34.69
-    # via aiobotocore
-click==8.1.7
-    # via click-default-group
-    # via pytask
-    # via pytask-parallel
-click-default-group==1.2.4
-    # via pytask
-cloudpickle==3.0.0
-    # via loky
-    # via pytask-parallel
-frozenlist==1.4.1
-    # via aiohttp
-    # via aiosignal
-fsspec==2024.3.1
-    # via s3fs
-greenlet==3.0.3
-    # via sqlalchemy
-idna==3.7
-    # via yarl
-jmespath==1.0.1
-    # via botocore
-loky==3.4.1
-    # via pytask-parallel
-markdown-it-py==3.0.0
-    # via rich
-mdurl==0.1.2
-    # via markdown-it-py
-multidict==6.0.5
-    # via aiohttp
-    # via yarl
-networkx==3.3
-    # via pytask
-optree==0.11.0
-    # via pytask
-packaging==24.0
-    # via pytask
-pluggy==1.4.0
-    # via pytask
-    # via pytask-parallel
-pygments==2.17.2
-    # via rich
-pytask==0.4.7
-    # via pytask-parallel
-python-dateutil==2.9.0.post0
-    # via botocore
-rich==13.7.1
-    # via pytask
-    # via pytask-parallel
-s3fs==2024.3.1
-six==1.16.0
-    # via python-dateutil
-sqlalchemy==2.0.29
-    # via pytask
-typing-extensions==4.11.0
-    # via optree
-    # via sqlalchemy
-urllib3==2.2.1
-    # via botocore
-wrapt==1.16.0
-    # via aiobotocore
-yarl==1.9.4
-    # via aiohttp
diff --git a/requirements.lock b/requirements.lock
deleted file mode 100644
index 4979697..0000000
--- a/requirements.lock
+++ /dev/null
@@ -1,51 +0,0 @@
-# generated by rye
-# use `rye lock` or `rye sync` to update this lockfile
-#
-# last locked with the following flags:
-#   pre: false
-#   features: []
-#   all-features: false
-#   with-sources: false
-
--e file:.
-attrs==23.2.0
-    # via pytask
-    # via pytask-parallel
-click==8.1.7
-    # via click-default-group
-    # via pytask
-    # via pytask-parallel
-click-default-group==1.2.4
-    # via pytask
-cloudpickle==3.0.0
-    # via loky
-    # via pytask-parallel
-greenlet==3.0.3
-    # via sqlalchemy
-loky==3.4.1
-    # via pytask-parallel
-markdown-it-py==3.0.0
-    # via rich
-mdurl==0.1.2
-    # via markdown-it-py
-networkx==3.3
-    # via pytask
-optree==0.11.0
-    # via pytask
-packaging==24.0
-    # via pytask
-pluggy==1.4.0
-    # via pytask
-    # via pytask-parallel
-pygments==2.17.2
-    # via rich
-pytask==0.4.7
-    # via pytask-parallel
-rich==13.7.1
-    # via pytask
-    # via pytask-parallel
-sqlalchemy==2.0.29
-    # via pytask
-typing-extensions==4.11.0
-    # via optree
-    # via sqlalchemy
diff --git a/src/pytask_parallel/backends.py b/src/pytask_parallel/backends.py
index f5aa6ee..2021782 100644
--- a/src/pytask_parallel/backends.py
+++ b/src/pytask_parallel/backends.py
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+import os
+import sys
 import warnings
 from concurrent.futures import Executor
 from concurrent.futures import Future
@@ -19,7 +21,46 @@
 if TYPE_CHECKING:
     from collections.abc import Callable
 
-__all__ = ["ParallelBackend", "ParallelBackendRegistry", "WorkerType", "registry"]
+__all__ = [
+    "ParallelBackend",
+    "ParallelBackendRegistry",
+    "WorkerType",
+    "registry",
+    "set_worker_root",
+]
+
+_WORKER_ROOT: str | None = None
+
+
+def set_worker_root(path: os.PathLike[str] | str) -> None:
+    """Configure the root path for worker processes.
+
+    Spawned workers (notably on Windows) start with a clean interpreter and may not
+    inherit the parent's import path. We set both ``sys.path`` and ``PYTHONPATH`` so
+    task modules are importable by reference, which avoids pickling module globals.
+
+    """
+    root = os.fspath(path)
+    global _WORKER_ROOT  # noqa: PLW0603
+    _WORKER_ROOT = root
+    if root not in sys.path:
+        sys.path.insert(0, root)
+    # Ensure custom process backends can import task modules by reference.
+    separator = os.pathsep
+    current = os.environ.get("PYTHONPATH", "")
+    parts = [p for p in current.split(separator) if p] if current else []
+    if root not in parts:
+        parts.insert(0, root)
+    os.environ["PYTHONPATH"] = separator.join(parts)
+
+
+def _configure_worker(root: str | None) -> None:
+    """Set cwd and sys.path for worker processes."""
+    if not root:
+        return
+    os.chdir(root)
+    if root not in sys.path:
+        sys.path.insert(0, root)
 
 
 def _deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any:
@@ -75,12 +116,20 @@ def _get_dask_executor(n_workers: int) -> Executor:
 
 def _get_loky_executor(n_workers: int) -> Executor:
     """Get a loky executor."""
-    return get_reusable_executor(max_workers=n_workers)
+    return get_reusable_executor(
+        max_workers=n_workers,
+        initializer=_configure_worker,
+        initargs=(_WORKER_ROOT,),
+    )
 
 
 def _get_process_pool_executor(n_workers: int) -> Executor:
     """Get a process pool executor."""
-    return _CloudpickleProcessPoolExecutor(max_workers=n_workers)
+    return _CloudpickleProcessPoolExecutor(
+        max_workers=n_workers,
+        initializer=_configure_worker,
+        initargs=(_WORKER_ROOT,),
+    )
 
 
 def _get_thread_pool_executor(n_workers: int) -> Executor:
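
Reviewer note: the `initializer`/`initargs` hooks used above run once in every worker process before any submitted work executes, which is why `_configure_worker` can patch `sys.path` before a task module is imported by reference. A minimal, self-contained sketch of the pattern (not part of this patch; `/tmp/project` is a placeholder path):

```python
import sys
from concurrent.futures import ProcessPoolExecutor


def init_worker(root: str) -> None:
    # Runs once per worker process, before any submitted call.
    if root not in sys.path:
        sys.path.insert(0, root)


def first_path_entry() -> str:
    return sys.path[0]


if __name__ == "__main__":
    with ProcessPoolExecutor(
        max_workers=2, initializer=init_worker, initargs=("/tmp/project",)
    ) as pool:
        print(pool.submit(first_path_entry).result())  # -> /tmp/project
```

loky's `get_reusable_executor` accepts the same keyword arguments, so both process backends can share `_configure_worker`.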
diff --git a/src/pytask_parallel/execute.py b/src/pytask_parallel/execute.py
index ef6529a..ec0bc1a 100644
--- a/src/pytask_parallel/execute.py
+++ b/src/pytask_parallel/execute.py
@@ -26,11 +26,13 @@
 
 from pytask_parallel.backends import WorkerType
 from pytask_parallel.backends import registry
+from pytask_parallel.backends import set_worker_root
 from pytask_parallel.typing import CarryOverPath
 from pytask_parallel.typing import is_coiled_function
 from pytask_parallel.utils import create_kwargs_for_task
 from pytask_parallel.utils import get_module
 from pytask_parallel.utils import parse_future_result
+from pytask_parallel.utils import should_pickle_module_by_value
 
 if TYPE_CHECKING:
     from concurrent.futures import Future
@@ -57,6 +59,7 @@ def pytask_execute_build(session: Session) -> bool | None:  # noqa: C901, PLR091
 
     # The executor can only be created after the collection to give users the
     # possibility to inject their own executors.
+    set_worker_root(session.config["root"])
     session.config["_parallel_executor"] = registry.get_parallel_backend(
         session.config["parallel_backend"], n_workers=session.config["n_workers"]
     )
@@ -208,7 +211,8 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
     # cloudpickle will pickle it with the function. See cloudpickle#417, pytask#373
     # and pytask#374.
     task_module = get_module(task.function, getattr(task, "path", None))
-    cloudpickle.register_pickle_by_value(task_module)
+    if should_pickle_module_by_value(task_module):
+        cloudpickle.register_pickle_by_value(task_module)
 
     return cast("Any", wrapper_func).submit(
         task=task,
@@ -230,7 +234,8 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
     # cloudpickle will pickle it with the function. See cloudpickle#417, pytask#373
     # and pytask#374.
     task_module = get_module(task.function, getattr(task, "path", None))
-    cloudpickle.register_pickle_by_value(task_module)
+    if should_pickle_module_by_value(task_module):
+        cloudpickle.register_pickle_by_value(task_module)
 
     return session.config["_parallel_executor"].submit(
         wrap_task_in_process,
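
Why the `register_pickle_by_value` call is now gated: once a module is registered, cloudpickle serializes functions from it by value together with the globals they reference, so a single non-picklable module global can break every task in the module. A self-contained sketch of the failure mode (the module name `fake_task_module` is illustrative only):

```python
import sys
import threading
import types

import cloudpickle

# Fabricate a "task module" whose function references a non-picklable global.
mod = types.ModuleType("fake_task_module")
mod.lock = threading.Lock()
exec("def task(): return lock.locked()", mod.__dict__)
sys.modules["fake_task_module"] = mod

cloudpickle.dumps(mod.task)  # fine: the function is pickled by reference

cloudpickle.register_pickle_by_value(mod)
try:
    cloudpickle.dumps(mod.task)  # the referenced global now travels along
except TypeError as error:
    print(error)  # cannot pickle '_thread.lock' object
```

Pickling by reference only works when the worker can import the module by name, which is what `set_worker_root` in `backends.py` arranges.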
diff --git a/src/pytask_parallel/utils.py b/src/pytask_parallel/utils.py
index 269c72d..c0be884 100644
--- a/src/pytask_parallel/utils.py
+++ b/src/pytask_parallel/utils.py
@@ -2,8 +2,10 @@
 
 from __future__ import annotations
 
+import importlib.util
 import inspect
 from functools import partial
+from pathlib import Path
 from typing import TYPE_CHECKING
 from typing import Any
 
@@ -20,7 +22,6 @@
 if TYPE_CHECKING:
     from collections.abc import Callable
     from concurrent.futures import Future
-    from pathlib import Path
     from types import ModuleType
     from types import TracebackType
 
@@ -39,6 +40,7 @@ class CoiledFunction: ...
     "create_kwargs_for_task",
     "get_module",
     "parse_future_result",
+    "should_pickle_module_by_value",
 ]
 
 
@@ -150,3 +152,30 @@ def get_module(func: Callable[..., Any], path: Path | None) -> ModuleType:
     if path:
         return inspect.getmodule(func, path.as_posix())  # type: ignore[return-value]
     return inspect.getmodule(func)  # type: ignore[return-value]
+
+
+def should_pickle_module_by_value(module: ModuleType) -> bool:
+    """Return whether a module should be pickled by value.
+
+    We only pickle by value when the module is not importable by name in the worker.
+    This avoids serializing all module globals, which can fail for non-picklable
+    objects (e.g., closed file handles or locks stored at module scope).
+
+    """
+    module_name = getattr(module, "__name__", None)
+    module_file = getattr(module, "__file__", None)
+    if not module_name or module_name == "__main__" or module_file is None:
+        return True
+
+    try:
+        spec = importlib.util.find_spec(module_name)
+    except (ImportError, ValueError, AttributeError):
+        return True
+
+    if spec is None or spec.origin is None:
+        return True
+
+    try:
+        return Path(spec.origin).resolve() != Path(module_file).resolve()
+    except OSError:
+        return True
diff --git a/src/pytask_parallel/wrappers.py b/src/pytask_parallel/wrappers.py
index e663e8a..e26deba 100644
--- a/src/pytask_parallel/wrappers.py
+++ b/src/pytask_parallel/wrappers.py
@@ -217,7 +217,7 @@ def _render_traceback_to_string(
     traceback = Traceback(exc_info, show_locals=show_locals)
     segments = console.render(cast("Any", traceback), options=console_options)
     text = "".join(segment.text for segment in segments)
-    return (*exc_info[:2], text)  # ty: ignore[invalid-return-type]
+    return (*exc_info[:2], text)
 
 
 def _handle_function_products(
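
A quick sanity check of the new helper's contract (a hypothetical interactive snippet, not part of the test suite):

```python
import json
import types

from pytask_parallel.utils import should_pickle_module_by_value

# Importable by name in any worker, so pickling by reference is safe.
assert should_pickle_module_by_value(json) is False

# A dynamically created module has no __file__, so it must go by value.
dynamic = types.ModuleType("only_exists_in_this_process")
assert should_pickle_module_by_value(dynamic) is True
```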
diff --git a/tests/test_execute.py b/tests/test_execute.py
index c06919a..6df9ef5 100644
--- a/tests/test_execute.py
+++ b/tests/test_execute.py
@@ -364,3 +364,80 @@ def task_create_file(
     )
     assert result.exit_code == ExitCode.OK
     assert tmp_path.joinpath("file.txt").read_text() == "This is the text."
+
+
+@pytest.mark.parametrize(
+    "parallel_backend",
+    [
+        ParallelBackend.PROCESSES,
+        pytest.param(ParallelBackend.LOKY, marks=skip_if_deadlock),
+    ],
+)
+def test_parallel_execution_with_mark_import(runner, tmp_path, parallel_backend):
+    source = """
+    from pytask import mark, task
+
+    @task
+    def task_assert_math():
+        assert 2 + 2 == 4
+    """
+    tmp_path.joinpath("task_mark.py").write_text(textwrap.dedent(source))
+    result = runner.invoke(
+        cli, [tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend]
+    )
+    assert result.exit_code == ExitCode.OK
+
+
+@pytest.mark.parametrize(
+    "parallel_backend",
+    [
+        ParallelBackend.PROCESSES,
+        pytest.param(ParallelBackend.LOKY, marks=skip_if_deadlock),
+    ],
+)
+def test_parallel_execution_with_mark_import_in_loop(
+    runner, tmp_path, parallel_backend
+):
+    source = """
+    from pytask import mark, task
+
+    for data_name in ("a", "b", "c"):
+
+        @task(id=data_name)
+        def task_assert_math_loop():
+            assert 2 + 2 == 4
+    """
+    tmp_path.joinpath("task_mark_loop.py").write_text(textwrap.dedent(source))
+    result = runner.invoke(
+        cli, [tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend]
+    )
+    assert result.exit_code == ExitCode.OK
+
+
+@pytest.mark.parametrize(
+    "parallel_backend",
+    [
+        ParallelBackend.PROCESSES,
+        pytest.param(ParallelBackend.LOKY, marks=skip_if_deadlock),
+    ],
+)
+def test_parallel_execution_with_closed_file_handle(runner, tmp_path, parallel_backend):
+    source = """
+    from pathlib import Path
+    from pytask import task
+
+    data_path = Path(__file__).parent / "data.txt"
+    data_path.write_text("hello", encoding="utf-8")
+
+    with data_path.open(encoding="utf-8") as f:
+        content = f.read()
+
+    @task
+    def task_assert_math():
+        assert content == "hello"
+    """
+    tmp_path.joinpath("task_file.py").write_text(textwrap.dedent(source))
+    result = runner.invoke(
+        cli, [tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend]
+    )
+    assert result.exit_code == ExitCode.OK