diff --git a/qlib/utils/paral.py b/qlib/utils/paral.py index a6177833413..320f49c19d2 100644 --- a/qlib/utils/paral.py +++ b/qlib/utils/paral.py @@ -6,7 +6,6 @@ from threading import Thread from typing import Callable, Text, Union -import joblib from joblib import Parallel, delayed from joblib._parallel_backends import MultiprocessingBackend import pandas as pd @@ -22,12 +21,7 @@ def __init__(self, *args, **kwargs): maxtasksperchild = kwargs.pop("maxtasksperchild", None) super(ParallelExt, self).__init__(*args, **kwargs) if isinstance(self._backend, MultiprocessingBackend): - # 2025-05-04 joblib released version 1.5.0, in which _backend_args was removed and replaced by _backend_kwargs. - # Ref: https://github.com/joblib/joblib/pull/1525/files#diff-e4dff8042ce45b443faf49605b75a58df35b8c195978d4a57f4afa695b406bdc - if joblib.__version__ < "1.5.0": - self._backend_args["maxtasksperchild"] = maxtasksperchild # pylint: disable=E1101 - else: - self._backend_kwargs["maxtasksperchild"] = maxtasksperchild # pylint: disable=E1101 + self._backend_kwargs["maxtasksperchild"] = maxtasksperchild # pylint: disable=E1101 def datetime_groupby_apply( diff --git a/tests/test_parallel_ext.py b/tests/test_parallel_ext.py new file mode 100644 index 00000000000..30ecb95e021 --- /dev/null +++ b/tests/test_parallel_ext.py @@ -0,0 +1,22 @@ +"""Test for Issue #1927: ParallelExt _backend_kwargs attribute fix.""" +import pytest +from joblib import delayed +from qlib.utils.paral import ParallelExt + + +def test_parallel_ext_with_maxtasksperchild(): + """ParallelExt should accept maxtasksperchild without AttributeError.""" + p = ParallelExt(n_jobs=1, backend="loky", maxtasksperchild=10) + results = p(delayed(lambda x: x * 2)(i) for i in range(5)) + assert results == [0, 2, 4, 6, 8] + + +def test_parallel_ext_without_maxtasksperchild(): + """ParallelExt should work normally without maxtasksperchild.""" + p = ParallelExt(n_jobs=1) + results = p(delayed(sum)([i, 1]) for i in range(3)) + assert results == [1, 2, 3] + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])